1040 lines
45 KiB
Python
1040 lines
45 KiB
Python
"""ReAct 推理-行动循环引擎
|
||
|
||
实现 ReAct (Reasoning-Action) 模式,使 Agent 能够自主推理、
|
||
选择工具并根据中间结果调整策略。
|
||
"""
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import re
|
||
import time
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime, timezone
|
||
from typing import TYPE_CHECKING, Any
|
||
|
||
from agentkit.core.exceptions import TaskCancelledError, TaskTimeoutError
|
||
from agentkit.core.protocol import CancellationToken
|
||
from agentkit.llm.gateway import LLMGateway
|
||
from agentkit.llm.protocol import LLMResponse
|
||
from agentkit.tools.base import Tool
|
||
from agentkit.telemetry.tracing import get_tracer, start_span, _OTEL_AVAILABLE
|
||
from agentkit.telemetry.metrics import (
|
||
agent_request_counter,
|
||
agent_duration_histogram,
|
||
)
|
||
|
||
if TYPE_CHECKING:
|
||
from agentkit.core.compressor import CompressionStrategy, ContextCompressor
|
||
from agentkit.core.trace import TraceRecorder
|
||
from agentkit.memory.retriever import MemoryRetriever
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
|
||
class ReActStep:
|
||
"""ReAct 单步记录"""
|
||
|
||
step: int
|
||
action: str # "tool_call" or "final_answer"
|
||
tool_name: str | None = None
|
||
arguments: dict[str, Any] | None = None
|
||
result: Any = None
|
||
content: str | None = None
|
||
tokens: int = 0
|
||
|
||
|
||
@dataclass
|
||
class ReActResult:
|
||
"""ReAct 执行结果"""
|
||
|
||
output: str
|
||
trajectory: list[ReActStep]
|
||
total_steps: int
|
||
total_tokens: int
|
||
status: str = "success" # "success" | "timeout" | "cancelled" | "partial"
|
||
fallback_strategy: str | None = None # e.g. "simplified_rewoo", "react", "direct"
|
||
|
||
|
||
@dataclass
|
||
class ReActEvent:
|
||
"""ReAct 执行事件"""
|
||
|
||
event_type: str # "thinking", "token", "tool_call", "tool_result", "confirmation_request", "final_answer", "error"
|
||
step: int
|
||
data: dict[str, Any] = field(default_factory=dict)
|
||
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||
|
||
|
||
class ReActEngine:
|
||
"""ReAct 推理-行动循环引擎
|
||
|
||
通过 Think (LLM 调用) → Act (工具执行) → Observe (结果观察) 的循环,
|
||
使 Agent 能够自主推理并选择工具完成任务。
|
||
"""
|
||
|
||
def __init__(self, llm_gateway: LLMGateway, max_steps: int = 10, default_timeout: float = 300.0, parallel_tools: bool = True):
|
||
if max_steps < 1:
|
||
raise ValueError(f"max_steps must be >= 1, got {max_steps}")
|
||
self._llm_gateway = llm_gateway
|
||
self._max_steps = max_steps
|
||
self._default_timeout = default_timeout
|
||
self._parallel_tools = parallel_tools
|
||
|
||
async def execute(
|
||
self,
|
||
messages: list[dict[str, str]],
|
||
tools: list[Tool] | None = None,
|
||
model: str = "default",
|
||
agent_name: str = "",
|
||
task_type: str = "",
|
||
system_prompt: str | None = None,
|
||
trace_recorder: "TraceRecorder | None" = None,
|
||
memory_retriever: "MemoryRetriever | None" = None,
|
||
task_id: str | None = None,
|
||
compressor: "CompressionStrategy | None" = None,
|
||
retrieval_config: dict[str, Any] | None = None,
|
||
cancellation_token: CancellationToken | None = None,
|
||
timeout_seconds: float | None = None,
|
||
) -> ReActResult:
|
||
"""执行 ReAct 循环
|
||
|
||
1. 构建初始消息(system_prompt + 任务消息)
|
||
2. 循环:Think (LLM 调用) → Act (工具执行) → Observe (结果)
|
||
3. 停止条件:LLM 不返回 tool_calls,或达到 max_steps
|
||
4. 返回 ReActResult 包含输出和轨迹
|
||
|
||
Args:
|
||
cancellation_token: 协作式取消令牌,每次循环迭代检查是否已取消
|
||
timeout_seconds: 超时秒数,0 表示无超时,None 使用 default_timeout
|
||
"""
|
||
effective_timeout = timeout_seconds if timeout_seconds is not None else self._default_timeout
|
||
|
||
try:
|
||
if effective_timeout > 0:
|
||
result = await asyncio.wait_for(
|
||
self._execute_loop(
|
||
messages=messages,
|
||
tools=tools,
|
||
model=model,
|
||
agent_name=agent_name,
|
||
task_type=task_type,
|
||
system_prompt=system_prompt,
|
||
trace_recorder=trace_recorder,
|
||
memory_retriever=memory_retriever,
|
||
task_id=task_id,
|
||
compressor=compressor,
|
||
retrieval_config=retrieval_config,
|
||
cancellation_token=cancellation_token,
|
||
),
|
||
timeout=effective_timeout,
|
||
)
|
||
else:
|
||
result = await self._execute_loop(
|
||
messages=messages,
|
||
tools=tools,
|
||
model=model,
|
||
agent_name=agent_name,
|
||
task_type=task_type,
|
||
system_prompt=system_prompt,
|
||
trace_recorder=trace_recorder,
|
||
memory_retriever=memory_retriever,
|
||
task_id=task_id,
|
||
compressor=compressor,
|
||
retrieval_config=retrieval_config,
|
||
cancellation_token=cancellation_token,
|
||
)
|
||
except asyncio.TimeoutError:
|
||
raise TaskTimeoutError(
|
||
task_id=task_id or "",
|
||
timeout_seconds=int(effective_timeout),
|
||
)
|
||
except TaskCancelledError:
|
||
raise
|
||
|
||
return result
|
||
|
||
async def _execute_loop(
|
||
self,
|
||
messages: list[dict[str, str]],
|
||
tools: list[Tool] | None = None,
|
||
model: str = "default",
|
||
agent_name: str = "",
|
||
task_type: str = "",
|
||
system_prompt: str | None = None,
|
||
trace_recorder: "TraceRecorder | None" = None,
|
||
memory_retriever: "MemoryRetriever | None" = None,
|
||
task_id: str | None = None,
|
||
compressor: "CompressionStrategy | None" = None,
|
||
retrieval_config: dict[str, Any] | None = None,
|
||
cancellation_token: CancellationToken | None = None,
|
||
) -> ReActResult:
|
||
tools = tools or []
|
||
tool_schemas = self._build_tool_schemas(tools) if tools else None
|
||
if tool_schemas:
|
||
tool_names = [s["function"]["name"] for s in tool_schemas]
|
||
logger.info(f"ReActEngine executing with {len(tool_schemas)} tools: {tool_names}")
|
||
else:
|
||
logger.info("ReActEngine executing with NO tools")
|
||
|
||
# Telemetry: record agent request
|
||
agent_request_counter().add(1, {"agent.name": agent_name, "agent.type": task_type or "react"})
|
||
|
||
# Start telemetry span for the entire agent execution
|
||
_span_cm = None
|
||
_span = None
|
||
_exec_start = time.monotonic()
|
||
|
||
if _OTEL_AVAILABLE:
|
||
_span_cm = start_span(
|
||
"agent.execute",
|
||
attributes={"agent.name": agent_name, "agent.type": task_type or "react"},
|
||
)
|
||
_span = _span_cm.__enter__()
|
||
|
||
# Initialize before try so finally can access them
|
||
trajectory: list[ReActStep] = []
|
||
total_tokens = 0
|
||
trace_outcome = "error"
|
||
|
||
try:
|
||
# 启动轨迹记录
|
||
if trace_recorder is not None:
|
||
trace_recorder.start_trace(
|
||
task_id="",
|
||
agent_name=agent_name,
|
||
skill_name=task_type or None,
|
||
)
|
||
|
||
# Memory retrieval: 执行前检索相关上下文注入 system_prompt
|
||
if memory_retriever:
|
||
try:
|
||
query = str(messages[-1].get("content", "")) if messages else ""
|
||
top_k = (retrieval_config or {}).get("top_k", 5)
|
||
token_budget = (retrieval_config or {}).get("token_budget", 2000)
|
||
memory_context = await memory_retriever.get_context_string(
|
||
query=query,
|
||
top_k=top_k,
|
||
token_budget=token_budget,
|
||
)
|
||
if memory_context:
|
||
if system_prompt:
|
||
system_prompt += f"\n\n## 参考信息\n{memory_context}"
|
||
else:
|
||
system_prompt = f"## 参考信息\n{memory_context}"
|
||
except Exception as e:
|
||
logger.warning(f"Memory retrieval failed, continuing without context: {e}")
|
||
|
||
# 构建初始消息
|
||
conversation: list[dict[str, Any]] = []
|
||
if system_prompt:
|
||
conversation.append({"role": "system", "content": system_prompt})
|
||
conversation.extend(messages)
|
||
|
||
# Context compression: 压缩超长对话历史
|
||
if compressor:
|
||
try:
|
||
conversation = await compressor.compress(conversation)
|
||
except Exception as e:
|
||
logger.warning(f"Context compression failed, continuing with original messages: {e}")
|
||
|
||
trace_outcome = "success"
|
||
step = 0
|
||
output = ""
|
||
|
||
while step < self._max_steps:
|
||
step += 1
|
||
|
||
# 协作式取消检查
|
||
if cancellation_token is not None:
|
||
cancellation_token.check()
|
||
|
||
# Think: 调用 LLM
|
||
llm_start = time.monotonic()
|
||
response = await self._llm_gateway.chat(
|
||
messages=conversation,
|
||
model=model,
|
||
agent_name=agent_name,
|
||
task_type=task_type,
|
||
tools=tool_schemas,
|
||
)
|
||
llm_duration_ms = int((time.monotonic() - llm_start) * 1000)
|
||
|
||
step_tokens = response.usage.total_tokens
|
||
total_tokens += step_tokens
|
||
|
||
# 检查是否有 Function Calling 的 tool_calls
|
||
if response.has_tool_calls:
|
||
# 记录 LLM 调用步骤
|
||
if trace_recorder is not None:
|
||
trace_recorder.record_step(
|
||
step=step,
|
||
action="llm_call",
|
||
duration_ms=llm_duration_ms,
|
||
tokens_used=step_tokens,
|
||
)
|
||
|
||
# Act: 执行工具调用
|
||
# 先记录 assistant 消息(含 tool_calls)到对话历史
|
||
assistant_msg: dict[str, Any] = {
|
||
"role": "assistant",
|
||
"content": response.content or "",
|
||
"tool_calls": [
|
||
{
|
||
"id": tc.id,
|
||
"type": "function",
|
||
"function": {
|
||
"name": tc.name,
|
||
"arguments": json.dumps(tc.arguments),
|
||
},
|
||
}
|
||
for tc in response.tool_calls
|
||
],
|
||
}
|
||
conversation.append(assistant_msg)
|
||
|
||
# 执行工具调用
|
||
if self._parallel_tools and len(response.tool_calls) > 1:
|
||
# 并行执行多个工具调用
|
||
tool_results = await asyncio.gather(
|
||
*[self._execute_tool(tc.name, tc.arguments, tools) for tc in response.tool_calls],
|
||
return_exceptions=True,
|
||
)
|
||
for idx, tc in enumerate(response.tool_calls):
|
||
tool_result = tool_results[idx]
|
||
if isinstance(tool_result, Exception):
|
||
tool_result = {"error": str(tool_result)}
|
||
|
||
react_step = ReActStep(
|
||
step=step,
|
||
action="tool_call",
|
||
tool_name=tc.name,
|
||
arguments=tc.arguments,
|
||
result=tool_result,
|
||
tokens=step_tokens,
|
||
)
|
||
trajectory.append(react_step)
|
||
|
||
if trace_recorder is not None:
|
||
tool_error = None
|
||
if isinstance(tool_result, dict) and "error" in tool_result:
|
||
tool_error = tool_result["error"]
|
||
trace_recorder.record_step(
|
||
step=step,
|
||
action="tool_call",
|
||
tool_name=tc.name,
|
||
input_data=tc.arguments,
|
||
output_data=tool_result,
|
||
duration_ms=0,
|
||
tokens_used=0,
|
||
error=tool_error,
|
||
)
|
||
|
||
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
|
||
conversation.append(tool_msg)
|
||
else:
|
||
# 串行执行(单工具或 parallel_tools=False)
|
||
for tc in response.tool_calls:
|
||
tool_start = time.monotonic()
|
||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||
|
||
react_step = ReActStep(
|
||
step=step,
|
||
action="tool_call",
|
||
tool_name=tc.name,
|
||
arguments=tc.arguments,
|
||
result=tool_result,
|
||
tokens=step_tokens,
|
||
)
|
||
trajectory.append(react_step)
|
||
|
||
# 记录工具调用步骤
|
||
if trace_recorder is not None:
|
||
tool_error = None
|
||
if isinstance(tool_result, dict) and "error" in tool_result:
|
||
tool_error = tool_result["error"]
|
||
trace_recorder.record_step(
|
||
step=step,
|
||
action="tool_call",
|
||
tool_name=tc.name,
|
||
input_data=tc.arguments,
|
||
output_data=tool_result,
|
||
duration_ms=tool_duration_ms,
|
||
tokens_used=0,
|
||
error=tool_error,
|
||
)
|
||
|
||
# Observe: 将工具结果添加到对话历史
|
||
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
|
||
conversation.append(tool_msg)
|
||
|
||
# Incremental compression: compress conversation if it's getting long
|
||
if self._should_compress(conversation, compressor):
|
||
try:
|
||
conversation = await compressor.compress(conversation)
|
||
except Exception as e:
|
||
logger.warning(f"Incremental compression failed: {e}")
|
||
|
||
else:
|
||
# 检查文本解析模式
|
||
parsed_calls = self._parse_text_tool_calls(response.content or "")
|
||
if parsed_calls and tools:
|
||
# 记录 LLM 调用步骤
|
||
if trace_recorder is not None:
|
||
trace_recorder.record_step(
|
||
step=step,
|
||
action="llm_call",
|
||
duration_ms=llm_duration_ms,
|
||
tokens_used=step_tokens,
|
||
)
|
||
|
||
# 文本解析模式执行工具
|
||
conversation.append({"role": "assistant", "content": response.content})
|
||
|
||
for pc in parsed_calls:
|
||
tool_start = time.monotonic()
|
||
tool_result = await self._execute_tool(pc["name"], pc["arguments"], tools)
|
||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||
|
||
react_step = ReActStep(
|
||
step=step,
|
||
action="tool_call",
|
||
tool_name=pc["name"],
|
||
arguments=pc["arguments"],
|
||
result=tool_result,
|
||
tokens=step_tokens,
|
||
)
|
||
trajectory.append(react_step)
|
||
|
||
# 记录工具调用步骤
|
||
if trace_recorder is not None:
|
||
tool_error = None
|
||
if isinstance(tool_result, dict) and "error" in tool_result:
|
||
tool_error = tool_result["error"]
|
||
trace_recorder.record_step(
|
||
step=step,
|
||
action="tool_call",
|
||
tool_name=pc["name"],
|
||
input_data=pc["arguments"],
|
||
output_data=tool_result,
|
||
duration_ms=tool_duration_ms,
|
||
tokens_used=0,
|
||
error=tool_error,
|
||
)
|
||
|
||
# 将工具结果添加到对话历史
|
||
tool_msg = await self._build_tool_result_message(pc.get("id", f"text_tc_{step}"), tool_result, compressor, pc["name"])
|
||
conversation.append(tool_msg)
|
||
|
||
# Incremental compression: compress conversation if it's getting long
|
||
if self._should_compress(conversation, compressor):
|
||
try:
|
||
conversation = await compressor.compress(conversation)
|
||
except Exception as e:
|
||
logger.warning(f"Incremental compression failed: {e}")
|
||
else:
|
||
# Final answer: LLM 没有调用工具,返回最终答案
|
||
react_step = ReActStep(
|
||
step=step,
|
||
action="final_answer",
|
||
content=response.content,
|
||
tokens=step_tokens,
|
||
)
|
||
trajectory.append(react_step)
|
||
output = response.content or ""
|
||
|
||
# 记录最终答案步骤
|
||
if trace_recorder is not None:
|
||
trace_recorder.record_step(
|
||
step=step,
|
||
action="final_answer",
|
||
output_data={"content": response.content},
|
||
duration_ms=llm_duration_ms,
|
||
tokens_used=step_tokens,
|
||
)
|
||
break
|
||
|
||
# 达到 max_steps 时,返回当前最佳输出
|
||
if step >= self._max_steps and not output:
|
||
trace_outcome = "partial"
|
||
# 使用最后一步的内容作为输出
|
||
if trajectory and trajectory[-1].content:
|
||
output = trajectory[-1].content
|
||
elif trajectory and trajectory[-1].result is not None:
|
||
output = str(trajectory[-1].result)
|
||
else:
|
||
output = response.content or ""
|
||
|
||
# 结束轨迹记录
|
||
if trace_recorder is not None:
|
||
trace_recorder.end_trace(outcome=trace_outcome)
|
||
|
||
# Memory storage: 执行后写入轨迹摘要到 EpisodicMemory
|
||
if memory_retriever and hasattr(memory_retriever, "store_episode"):
|
||
try:
|
||
summary = output[:500] if output else ""
|
||
await memory_retriever.store_episode(
|
||
key=f"task:{task_id or 'unknown'}",
|
||
value={"output_summary": summary, "agent_name": agent_name},
|
||
metadata={"task_type": task_type, "outcome": trace_outcome},
|
||
)
|
||
except Exception as e:
|
||
logger.warning(f"Failed to store task result in episodic memory: {e}")
|
||
|
||
return ReActResult(
|
||
output=output,
|
||
trajectory=trajectory,
|
||
total_steps=len(trajectory),
|
||
total_tokens=total_tokens,
|
||
)
|
||
finally:
|
||
# Telemetry: end span and record duration — always runs
|
||
_duration_ms = int((time.monotonic() - _exec_start) * 1000)
|
||
if _span is not None:
|
||
_span.set_attribute("agent.total_steps", len(trajectory))
|
||
_span.set_attribute("agent.total_tokens", total_tokens)
|
||
_span.set_attribute("agent.outcome", trace_outcome)
|
||
_span.set_attribute("agent.duration_ms", _duration_ms)
|
||
if _span_cm is not None:
|
||
_span_cm.__exit__(None, None, None)
|
||
agent_duration_histogram().record(_duration_ms, {"agent.name": agent_name})
|
||
|
||
async def execute_stream(
|
||
self,
|
||
messages: list[dict[str, str]],
|
||
tools: list[Tool] | None = None,
|
||
model: str = "default",
|
||
agent_name: str = "",
|
||
task_type: str = "",
|
||
system_prompt: str | None = None,
|
||
trace_recorder: "TraceRecorder | None" = None,
|
||
memory_retriever: "MemoryRetriever | None" = None,
|
||
task_id: str | None = None,
|
||
compressor: "CompressionStrategy | None" = None,
|
||
retrieval_config: dict[str, Any] | None = None,
|
||
cancellation_token: CancellationToken | None = None,
|
||
timeout_seconds: float | None = None,
|
||
confirmation_handler: Any | None = None,
|
||
):
|
||
"""Execute ReAct loop, yielding ReActEvent objects.
|
||
|
||
Same logic as execute() but yields events at each step instead of
|
||
accumulating a result.
|
||
"""
|
||
tools = tools or []
|
||
tool_schemas = self._build_tool_schemas(tools) if tools else None
|
||
if tool_schemas:
|
||
tool_names = [s["function"]["name"] for s in tool_schemas]
|
||
logger.info(f"ReActEngine executing with {len(tool_schemas)} tools: {tool_names}")
|
||
else:
|
||
logger.info("ReActEngine executing with NO tools")
|
||
|
||
# 启动轨迹记录
|
||
if trace_recorder is not None:
|
||
trace_recorder.start_trace(
|
||
task_id="",
|
||
agent_name=agent_name,
|
||
skill_name=task_type or None,
|
||
)
|
||
|
||
# Memory retrieval: 执行前检索相关上下文注入 system_prompt
|
||
if memory_retriever:
|
||
try:
|
||
query = str(messages[-1].get("content", "")) if messages else ""
|
||
top_k = (retrieval_config or {}).get("top_k", 5)
|
||
token_budget = (retrieval_config or {}).get("token_budget", 2000)
|
||
memory_context = await memory_retriever.get_context_string(
|
||
query=query,
|
||
top_k=top_k,
|
||
token_budget=token_budget,
|
||
)
|
||
if memory_context:
|
||
if system_prompt:
|
||
system_prompt += f"\n\n## 参考信息\n{memory_context}"
|
||
else:
|
||
system_prompt = f"## 参考信息\n{memory_context}"
|
||
except Exception as e:
|
||
logger.warning(f"Memory retrieval failed, continuing without context: {e}")
|
||
|
||
conversation: list[dict[str, Any]] = []
|
||
if system_prompt:
|
||
conversation.append({"role": "system", "content": system_prompt})
|
||
conversation.extend(messages)
|
||
|
||
# Context compression: 压缩超长对话历史
|
||
if compressor:
|
||
try:
|
||
conversation = await compressor.compress(conversation)
|
||
except Exception as e:
|
||
logger.warning(f"Context compression failed, continuing with original messages: {e}")
|
||
|
||
trajectory: list[ReActStep] = []
|
||
total_tokens = 0
|
||
step = 0
|
||
output = ""
|
||
trace_outcome = "success"
|
||
|
||
try:
|
||
while step < self._max_steps:
|
||
step += 1
|
||
|
||
# Yield thinking event
|
||
yield ReActEvent(
|
||
event_type="thinking",
|
||
step=step,
|
||
data={"message": f"Step {step}: Calling LLM..."},
|
||
)
|
||
|
||
# Think: call LLM (with optional token streaming)
|
||
llm_start = time.monotonic()
|
||
|
||
# Use streaming for token-by-token output
|
||
stream_content = ""
|
||
stream_usage = None
|
||
stream_tool_calls: list[Any] = []
|
||
stream_model = model
|
||
|
||
async for chunk in self._llm_gateway.chat_stream(
|
||
messages=conversation,
|
||
model=model,
|
||
agent_name=agent_name,
|
||
task_type=task_type,
|
||
tools=tool_schemas,
|
||
):
|
||
if chunk.content:
|
||
stream_content += chunk.content
|
||
yield ReActEvent(
|
||
event_type="token",
|
||
step=step,
|
||
data={"content": chunk.content},
|
||
)
|
||
if chunk.usage:
|
||
stream_usage = chunk.usage
|
||
if chunk.tool_calls:
|
||
stream_tool_calls = chunk.tool_calls
|
||
if chunk.model:
|
||
stream_model = chunk.model
|
||
|
||
# Build response-like object from stream
|
||
response = self._build_response_from_stream(
|
||
content=stream_content,
|
||
tool_calls=stream_tool_calls,
|
||
usage=stream_usage,
|
||
model=stream_model,
|
||
)
|
||
llm_duration_ms = int((time.monotonic() - llm_start) * 1000)
|
||
|
||
step_tokens = response.usage.total_tokens
|
||
total_tokens += step_tokens
|
||
|
||
if response.has_tool_calls:
|
||
# 记录 LLM 调用步骤
|
||
if trace_recorder is not None:
|
||
trace_recorder.record_step(
|
||
step=step,
|
||
action="llm_call",
|
||
duration_ms=llm_duration_ms,
|
||
tokens_used=step_tokens,
|
||
)
|
||
|
||
# Record assistant message
|
||
assistant_msg: dict[str, Any] = {
|
||
"role": "assistant",
|
||
"content": response.content or "",
|
||
"tool_calls": [
|
||
{
|
||
"id": tc.id,
|
||
"type": "function",
|
||
"function": {
|
||
"name": tc.name,
|
||
"arguments": json.dumps(tc.arguments),
|
||
},
|
||
}
|
||
for tc in response.tool_calls
|
||
],
|
||
}
|
||
conversation.append(assistant_msg)
|
||
|
||
for tc in response.tool_calls:
|
||
# Yield tool_call event
|
||
yield ReActEvent(
|
||
event_type="tool_call",
|
||
step=step,
|
||
data={"tool_name": tc.name, "arguments": tc.arguments},
|
||
)
|
||
|
||
tool_start = time.monotonic()
|
||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||
|
||
# 检测工具返回的确认请求
|
||
if isinstance(tool_result, dict) and tool_result.get("needs_confirmation"):
|
||
confirmation_id = tool_result["confirmation_id"]
|
||
command = tool_result.get("command", "")
|
||
reason = tool_result.get("reason", "")
|
||
|
||
# Yield 确认请求事件
|
||
yield ReActEvent(
|
||
event_type="confirmation_request",
|
||
step=step,
|
||
data={
|
||
"confirmation_id": confirmation_id,
|
||
"tool_name": tc.name,
|
||
"command": command,
|
||
"reason": reason,
|
||
},
|
||
)
|
||
|
||
# 等待用户确认
|
||
approved = False
|
||
if confirmation_handler is not None:
|
||
try:
|
||
approved = await confirmation_handler(confirmation_id, command, reason)
|
||
except Exception as e:
|
||
logger.warning(f"Confirmation handler error: {e}")
|
||
|
||
if approved:
|
||
# 用户确认执行:临时绕过安全检查重新执行
|
||
tool = self._find_tool(tc.name, tools)
|
||
if tool and hasattr(tool, '_is_dangerous'):
|
||
# 保存原始 _is_dangerous 并临时禁用
|
||
original_is_dangerous = tool._is_dangerous
|
||
tool._is_dangerous = lambda cmd: False
|
||
try:
|
||
tool_result = await tool.safe_execute(**tc.arguments)
|
||
finally:
|
||
tool._is_dangerous = original_is_dangerous
|
||
else:
|
||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||
|
||
yield ReActEvent(
|
||
event_type="confirmation_result",
|
||
step=step,
|
||
data={"confirmation_id": confirmation_id, "approved": True},
|
||
)
|
||
else:
|
||
# 用户拒绝执行
|
||
tool_result = {
|
||
"output": "",
|
||
"exit_code": 126,
|
||
"is_error": True,
|
||
"error_type": "permission_denied",
|
||
"message": f"用户拒绝执行命令: {command[:100]}",
|
||
}
|
||
yield ReActEvent(
|
||
event_type="confirmation_result",
|
||
step=step,
|
||
data={"confirmation_id": confirmation_id, "approved": False},
|
||
)
|
||
|
||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||
|
||
react_step = ReActStep(
|
||
step=step,
|
||
action="tool_call",
|
||
tool_name=tc.name,
|
||
arguments=tc.arguments,
|
||
result=tool_result,
|
||
tokens=step_tokens,
|
||
)
|
||
trajectory.append(react_step)
|
||
|
||
# 记录工具调用步骤
|
||
if trace_recorder is not None:
|
||
tool_error = None
|
||
if isinstance(tool_result, dict) and "error" in tool_result:
|
||
tool_error = tool_result["error"]
|
||
trace_recorder.record_step(
|
||
step=step,
|
||
action="tool_call",
|
||
tool_name=tc.name,
|
||
input_data=tc.arguments,
|
||
output_data=tool_result,
|
||
duration_ms=tool_duration_ms,
|
||
tokens_used=0,
|
||
error=tool_error,
|
||
)
|
||
|
||
# Yield tool_result event
|
||
yield ReActEvent(
|
||
event_type="tool_result",
|
||
step=step,
|
||
data={"tool_name": tc.name, "result": tool_result},
|
||
)
|
||
|
||
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
|
||
conversation.append(tool_msg)
|
||
|
||
# Incremental compression: compress conversation if it's getting long
|
||
if self._should_compress(conversation, compressor):
|
||
try:
|
||
conversation = await compressor.compress(conversation)
|
||
except Exception as e:
|
||
logger.warning(f"Incremental compression failed: {e}")
|
||
|
||
else:
|
||
# Check text parsing mode
|
||
parsed_calls = self._parse_text_tool_calls(response.content or "")
|
||
if parsed_calls and tools:
|
||
# 记录 LLM 调用步骤
|
||
if trace_recorder is not None:
|
||
trace_recorder.record_step(
|
||
step=step,
|
||
action="llm_call",
|
||
duration_ms=llm_duration_ms,
|
||
tokens_used=step_tokens,
|
||
)
|
||
|
||
conversation.append({"role": "assistant", "content": response.content})
|
||
|
||
for pc in parsed_calls:
|
||
yield ReActEvent(
|
||
event_type="tool_call",
|
||
step=step,
|
||
data={"tool_name": pc["name"], "arguments": pc["arguments"]},
|
||
)
|
||
tool_start = time.monotonic()
|
||
tool_result = await self._execute_tool(pc["name"], pc["arguments"], tools)
|
||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||
trajectory.append(ReActStep(
|
||
step=step,
|
||
action="tool_call",
|
||
tool_name=pc["name"],
|
||
arguments=pc["arguments"],
|
||
result=tool_result,
|
||
tokens=step_tokens,
|
||
))
|
||
# 记录工具调用步骤
|
||
if trace_recorder is not None:
|
||
tool_error = None
|
||
if isinstance(tool_result, dict) and "error" in tool_result:
|
||
tool_error = tool_result["error"]
|
||
trace_recorder.record_step(
|
||
step=step,
|
||
action="tool_call",
|
||
tool_name=pc["name"],
|
||
input_data=pc["arguments"],
|
||
output_data=tool_result,
|
||
duration_ms=tool_duration_ms,
|
||
tokens_used=0,
|
||
error=tool_error,
|
||
)
|
||
yield ReActEvent(
|
||
event_type="tool_result",
|
||
step=step,
|
||
data={"tool_name": pc["name"], "result": tool_result},
|
||
)
|
||
tool_msg = await self._build_tool_result_message(
|
||
pc.get("id", f"text_tc_{step}"), tool_result, compressor, pc["name"]
|
||
)
|
||
conversation.append(tool_msg)
|
||
|
||
# Incremental compression: compress conversation if it's getting long
|
||
if self._should_compress(conversation, compressor):
|
||
try:
|
||
conversation = await compressor.compress(conversation)
|
||
except Exception as e:
|
||
logger.warning(f"Incremental compression failed: {e}")
|
||
else:
|
||
# Final answer
|
||
react_step = ReActStep(
|
||
step=step,
|
||
action="final_answer",
|
||
content=response.content,
|
||
tokens=step_tokens,
|
||
)
|
||
trajectory.append(react_step)
|
||
output = response.content or ""
|
||
|
||
# 记录最终答案步骤
|
||
if trace_recorder is not None:
|
||
trace_recorder.record_step(
|
||
step=step,
|
||
action="final_answer",
|
||
output_data={"content": response.content},
|
||
duration_ms=llm_duration_ms,
|
||
tokens_used=step_tokens,
|
||
)
|
||
|
||
yield ReActEvent(
|
||
event_type="final_answer",
|
||
step=step,
|
||
data={
|
||
"output": output,
|
||
"total_steps": len(trajectory),
|
||
"total_tokens": total_tokens,
|
||
},
|
||
)
|
||
break
|
||
|
||
if step >= self._max_steps and not output:
|
||
trace_outcome = "partial"
|
||
if trajectory and trajectory[-1].content:
|
||
output = trajectory[-1].content
|
||
elif trajectory and trajectory[-1].result is not None:
|
||
output = str(trajectory[-1].result)
|
||
else:
|
||
output = response.content or ""
|
||
|
||
yield ReActEvent(
|
||
event_type="final_answer",
|
||
step=step,
|
||
data={
|
||
"output": output,
|
||
"total_steps": len(trajectory),
|
||
"total_tokens": total_tokens,
|
||
"max_steps_reached": True,
|
||
},
|
||
)
|
||
finally:
|
||
# 结束轨迹记录 — always runs even if consumer doesn't fully iterate
|
||
if trace_recorder is not None:
|
||
trace_recorder.end_trace(outcome=trace_outcome)
|
||
|
||
# Memory storage: 执行后写入轨迹摘要到 EpisodicMemory
|
||
if memory_retriever and hasattr(memory_retriever, "store_episode"):
|
||
try:
|
||
summary = output[:500] if output else ""
|
||
await memory_retriever.store_episode(
|
||
key=f"task:{task_id or 'unknown'}",
|
||
value={"output_summary": summary, "agent_name": agent_name},
|
||
metadata={"task_type": task_type, "outcome": trace_outcome},
|
||
)
|
||
except Exception as e:
|
||
logger.warning(f"Failed to store task result in episodic memory: {e}")
|
||
|
||
def _build_tool_schemas(self, tools: list[Tool]) -> list[dict]:
|
||
"""将 Tool 对象转换为 OpenAI Function Calling schema 格式"""
|
||
schemas = []
|
||
for tool in tools:
|
||
schema = {
|
||
"type": "function",
|
||
"function": {
|
||
"name": tool.name,
|
||
"description": tool.description,
|
||
"parameters": tool.input_schema or {"type": "object", "properties": {}},
|
||
},
|
||
}
|
||
schemas.append(schema)
|
||
return schemas
|
||
|
||
@staticmethod
|
||
def _build_response_from_stream(
|
||
content: str,
|
||
tool_calls: list[Any],
|
||
usage: Any,
|
||
model: str,
|
||
) -> LLMResponse:
|
||
"""Build an LLMResponse from accumulated stream chunks."""
|
||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||
if usage is None:
|
||
usage = TokenUsage()
|
||
return LLMResponse(
|
||
content=content,
|
||
tool_calls=tool_calls,
|
||
usage=usage,
|
||
model=model,
|
||
)
|
||
|
||
def _find_tool(self, name: str, tools: list[Tool]) -> Tool | None:
|
||
"""根据名称从可用工具中查找工具"""
|
||
for tool in tools:
|
||
if tool.name == name:
|
||
return tool
|
||
return None
|
||
|
||
# Default token threshold for incremental compression
|
||
_DEFAULT_COMPRESS_THRESHOLD = 8000
|
||
|
||
def _should_compress(self, conversation: list[dict], compressor: "CompressionStrategy | None") -> bool:
|
||
"""检查是否需要增量压缩"""
|
||
if not compressor:
|
||
return False
|
||
# Estimate tokens in conversation (rough: 4 chars ≈ 1 token)
|
||
total_chars = sum(len(str(m.get("content", ""))) for m in conversation)
|
||
estimated_tokens = total_chars // 4
|
||
return estimated_tokens > self._DEFAULT_COMPRESS_THRESHOLD
|
||
|
||
async def _build_tool_result_message(
|
||
self,
|
||
tool_call_id: str,
|
||
result: Any,
|
||
compressor: "CompressionStrategy | None" = None,
|
||
tool_name: str | None = None,
|
||
) -> dict:
|
||
"""构建工具结果消息用于对话历史"""
|
||
content = str(result)
|
||
if compressor and tool_name:
|
||
try:
|
||
content = await compressor.compress_tool_result(tool_name, result)
|
||
except Exception as e:
|
||
logger.warning(f"Tool result compression failed for '{tool_name}': {e}")
|
||
content = str(result)
|
||
return {
|
||
"role": "tool",
|
||
"tool_call_id": tool_call_id,
|
||
"content": content,
|
||
}
|
||
|
||
async def _execute_tool(
|
||
self, tool_name: str, arguments: dict[str, Any], tools: list[Tool]
|
||
) -> dict:
|
||
"""执行工具调用,处理成功和失败情况"""
|
||
tool = self._find_tool(tool_name, tools)
|
||
if tool is None:
|
||
error_msg = f"Tool '{tool_name}' not found"
|
||
logger.warning(error_msg)
|
||
return {"error": error_msg}
|
||
|
||
try:
|
||
result = await tool.safe_execute(**arguments)
|
||
return result
|
||
except Exception as e:
|
||
error_msg = f"Tool '{tool_name}' execution failed: {e}"
|
||
logger.warning(error_msg)
|
||
return {"error": error_msg}
|
||
|
||
def _parse_text_tool_calls(self, content: str) -> list[dict[str, Any]]:
|
||
"""从文本中解析工具调用模式
|
||
|
||
支持两种格式:
|
||
1. Action: tool_name(args)
|
||
2. ```tool\\n{"name": "...", "arguments": {...}}\\n```
|
||
"""
|
||
calls: list[dict[str, Any]] = []
|
||
|
||
# 格式 1: Action: tool_name(args)
|
||
action_pattern = re.compile(
|
||
r"Action:\s*(\w+)\((.+?)\)", re.DOTALL
|
||
)
|
||
for match in action_pattern.finditer(content):
|
||
name = match.group(1)
|
||
args_str = match.group(2)
|
||
try:
|
||
arguments = json.loads(args_str)
|
||
except (json.JSONDecodeError, TypeError):
|
||
arguments = {"raw_input": args_str}
|
||
calls.append({"name": name, "arguments": arguments})
|
||
|
||
if calls:
|
||
return calls
|
||
|
||
# 格式 2: ```tool\n{"name": "...", "arguments": {...}}\n```
|
||
code_block_pattern = re.compile(
|
||
r"```tool\s*\n(.*?)\n\s*```", re.DOTALL
|
||
)
|
||
for match in code_block_pattern.finditer(content):
|
||
json_str = match.group(1).strip()
|
||
try:
|
||
parsed = json.loads(json_str)
|
||
name = parsed.get("name", "")
|
||
arguments = parsed.get("arguments", {})
|
||
if name:
|
||
calls.append({"name": name, "arguments": arguments})
|
||
except (json.JSONDecodeError, TypeError):
|
||
logger.warning(f"Failed to parse tool call from text: {json_str}")
|
||
|
||
return calls
|