2307 lines
104 KiB
Python
2307 lines
104 KiB
Python
"""ReAct 推理-行动循环引擎
|
||
|
||
实现 ReAct (Reasoning-Action) 模式,使 Agent 能够自主推理、
|
||
选择工具并根据中间结果调整策略。
|
||
"""
|
||
|
||
import asyncio
import inspect
import json
import logging
import re
import time
from collections import Counter, deque
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any

from agentkit.core.exceptions import LoopDetectedError, TaskCancelledError, TaskTimeoutError
from agentkit.core.protocol import CancellationToken
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMResponse
from agentkit.tools.base import Tool, ToolValidationError
from agentkit.telemetry.tracing import start_span, _OTEL_AVAILABLE
from agentkit.telemetry.metrics import (
    agent_request_counter,
    agent_duration_histogram,
)
|
||
|
||
if TYPE_CHECKING:
|
||
from agentkit.core.compressor import CompressionStrategy
|
||
from agentkit.core.middleware import MiddlewareChain
|
||
from agentkit.core.trace import TraceRecorder
|
||
from agentkit.memory.retriever import MemoryRetriever
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
class ReActStep:
    """One recorded entry of a ReAct trajectory.

    Attributes:
        step: Number of the loop iteration this entry belongs to.
        action: Kind of entry, ``"tool_call"`` or ``"final_answer"``.
        tool_name: Name of the tool that was invoked, if any.
        arguments: Arguments passed to the tool, if any.
        result: Value produced by the tool, if any.
        content: Text content attached to the entry, if any.
        tokens: Token count associated with the entry.
    """

    # Required fields
    step: int
    action: str

    # Optional fields (only some are populated, depending on ``action``)
    tool_name: str | None = None
    arguments: dict[str, Any] | None = None
    result: Any = None
    content: str | None = None
    tokens: int = 0
|
||
|
||
|
||
async def _ensure_async_iterable(obj: Any, label: str = "<obj>"):
    """Defensive helper: iterate ``obj`` as an async iterable.

    Guards against the recurring ``'async for' requires an object with
    __aiter__ method, got coroutine`` error, which is raised when the value
    handed to ``async for`` is an un-awaited awaitable (for example the
    result of calling an ``async def`` function that *returns* an async
    generator instead of containing a ``yield`` itself).

    This helper accepts either:
      - An async iterable (e.g. an async generator) -> iterated as-is.
      - Any awaitable (coroutine, future, task, or object implementing
        ``__await__``) that resolves to an async iterable -> awaited, then
        iterated.
      - Anything else -> raises a clear, actionable error naming ``label``.

    Use it like::

        async for chunk in _ensure_async_iterable(
            some_func_that_should_stream(), label="some_func"
        ):
            ...

    Args:
        obj: The object to iterate, typically the return value of calling
            an ``async def`` function.
        label: A short human-readable name used in error messages to help
            locate the source of the bug.

    Yields:
        Items from the resolved async iterable.

    Raises:
        TypeError: If ``obj`` is neither an async iterable nor an
            awaitable that resolves to one. The error message names
            ``label`` so the offending call site is easy to find.
    """
    # Case 1: already an async iterable (the normal case). Checked first so
    # an object that is both iterable and awaitable is never awaited.
    if hasattr(obj, "__aiter__"):
        async for item in obj:
            yield item
        return

    # Case 2: an awaitable that hasn't been awaited yet (the bug we're
    # guarding against). ``inspect.isawaitable`` covers coroutines, futures
    # and any object exposing ``__await__``, matching the documented contract.
    if inspect.isawaitable(obj):
        resolved = await obj
        if not hasattr(resolved, "__aiter__"):
            raise TypeError(
                f"{label}: awaited value is not async iterable (got {type(resolved).__name__})"
            )
        async for item in resolved:
            yield item
        return

    # Case 3: anything else — surface a clear, actionable error rather
    # than the cryptic CPython ``TypeError: 'async for' requires...``.
    raise TypeError(
        f"{label}: expected an async iterable, got {type(obj).__name__}. "
        "This usually means the called function returned a coroutine "
        "instead of an async generator — check that it contains at "
        "least one reachable ``yield`` statement."
    )
|
||
|
||
|
||
@dataclass
class ReActResult:
    """Outcome of a ReAct run.

    Attributes:
        output: Final output text.
        trajectory: Recorded steps of the run.
        total_steps: Number of steps reported for the run.
        total_tokens: Number of tokens reported for the run.
        status: One of ``"success"``, ``"timeout"``, ``"cancelled"`` or
            ``"partial"``; defaults to ``"success"``.
        fallback_strategy: Name of a fallback strategy, e.g.
            ``"simplified_rewoo"``, ``"react"`` or ``"direct"``; ``None``
            by default.
    """

    # Required fields
    output: str
    trajectory: list[ReActStep]
    total_steps: int
    total_tokens: int

    # Optional fields
    status: str = "success"
    fallback_strategy: str | None = None
|
||
|
||
|
||
def _utc_now_iso() -> str:
    """Return the current UTC time formatted as an ISO 8601 string."""
    return datetime.now(timezone.utc).isoformat()


@dataclass
class ReActEvent:
    """Event emitted while a ReAct run is executing.

    Attributes:
        event_type: Kind of event: ``"thinking"``, ``"token"``,
            ``"tool_call"``, ``"tool_result"``, ``"confirmation_request"``,
            ``"final_answer"`` or ``"error"``.
        step: Step number the event belongs to.
        data: Event payload; a fresh empty dict per instance by default.
        timestamp: ISO 8601 UTC timestamp, taken at construction time
            unless supplied explicitly.
    """

    event_type: str
    step: int
    data: dict[str, Any] = field(default_factory=dict)
    timestamp: str = field(default_factory=_utc_now_iso)
|
||
|
||
|
||
class ReActEngine:
    """ReAct reasoning-action loop engine.

    Runs a Think (LLM call) -> Act (tool execution) -> Observe (result)
    loop so that an agent can reason autonomously and choose tools to
    complete a task.
    """

    # Default core tools that always get full descriptions injected into the
    # prompt. ``tool_search`` is included so its full description is always
    # available to the LLM when tiered injection is active.
    _DEFAULT_CORE_TOOLS: tuple[str, ...] = (
        "read_file",
        "write_file",
        "bash",
        "search",
        "tool_search",
    )
|
||
|
||
def __init__(
    self,
    llm_gateway: LLMGateway,
    max_steps: int = 10,
    default_timeout: float = 300.0,
    parallel_tools: bool | str = False,
    compressor: "CompressionStrategy | None" = None,
    verification_enabled: bool = False,
    verification_commands: list[str] | None = None,
    core_tool_names: list[str] | None = None,
    enable_tool_search: bool = True,
    middleware_chain: "MiddlewareChain | None" = None,
    prompt_cache_enable: bool = True,
    flush_interval_ms: int = 0,
):
    """Configure the engine.

    Args:
        llm_gateway: Gateway used for LLM calls; also handed to the default
            compressor when ``compressor`` is not supplied.
        max_steps: Upper bound on loop iterations; must be at least 1.
        default_timeout: Timeout stored for runs that do not pass their own.
        parallel_tools: ``True``, ``False`` or the string ``"auto"``.
        compressor: Context compression strategy. When ``None`` a
            ``ContextCompressor`` with ``keep_recent=10`` is created.
        verification_enabled: Whether verification is enabled.
        verification_commands: Commands stored for verification.
        core_tool_names: Core tool names for tiered tool description
            injection; stored as a tuple, or ``None`` when not given.
        enable_tool_search: Stored tool-search switch.
        middleware_chain: Optional middleware chain.
        prompt_cache_enable: Prompt-cache two-block structure switch
            (U2/G2).
        flush_interval_ms: Token chunk throttle interval in milliseconds
            (U3/G8); 0 means yield chunk by chunk.

    Raises:
        ValueError: If ``max_steps`` is below 1, or ``parallel_tools`` is a
            string other than ``"auto"``.
    """
    # --- argument validation -------------------------------------------
    if max_steps < 1:
        raise ValueError(f"max_steps must be >= 1, got {max_steps}")
    is_unknown_mode = isinstance(parallel_tools, str) and parallel_tools != "auto"
    if is_unknown_mode:
        raise ValueError(
            f"parallel_tools must be True, False, or 'auto', got {parallel_tools!r}"
        )

    # --- plain configuration -------------------------------------------
    self._llm_gateway = llm_gateway
    self._max_steps = max_steps
    self._default_timeout = default_timeout
    self._parallel_tools = parallel_tools
    self._verification_enabled = verification_enabled
    self._verification_commands = verification_commands
    # U2/G2: prompt-cache two-block switch (when True, Anthropic uses
    # cache_control blocks; other providers rely on string concatenation
    # and automatic prefix caching).
    self._prompt_cache_enable = prompt_cache_enable
    # U3/G8: token chunk throttle interval (ms). 0 = yield every chunk
    # (backward compatible).
    self._flush_interval_ms = flush_interval_ms

    # --- tiered tool description injection -----------------------------
    frozen_core_names: tuple[str, ...] | None = None
    if core_tool_names is not None:
        frozen_core_names = tuple(core_tool_names)
    self._core_tool_names = frozen_core_names
    self._enable_tool_search = enable_tool_search

    # --- context compression (default keeps the last 10 turns) ---------
    if compressor is None:
        # Imported lazily, only when the default compressor is needed.
        from agentkit.core.compressor import ContextCompressor

        compressor = ContextCompressor(llm_gateway=llm_gateway, keep_recent=10)
    self._compressor = compressor

    # --- loop detection -------------------------------------------------
    # Sliding window of tool-call hashes used to catch repeated identical
    # tool calls (hash-based, not semantic).
    self._loop_window: deque[str] = deque(maxlen=5)
    self._loop_threshold: int = 2
    self._loop_corrected: bool = False

    # U6: middleware chain (parallel integration, feature-flag controlled)
    self._middleware_chain = middleware_chain
|
||
|
||
def reset(self) -> None:
    """Clear the loop-detection state so the engine can be reused.

    Empties the sliding window of tool-call hashes and re-arms the one-time
    loop correction. Configuration set in ``__init__`` is left untouched.
    """
    self._loop_corrected = False
    self._loop_window.clear()
|
||
|
||
def _check_tool_loop(self, tool_calls: list[Any]) -> str | None:
    """Detect a repeated tool-call pattern.

    Each call in ``tool_calls`` is fingerprinted from its name and its
    JSON-serialised arguments (keys sorted) and pushed onto the sliding
    window ``self._loop_window``. The method then reports the first call of
    the *current* step whose fingerprint occurs at least
    ``self._loop_threshold`` times in the window.

    Only fingerprints of the current step are examined: a stale repetition
    left in the window by earlier steps must neither be reported nor hide a
    loop on the call being made now.

    Matching is exact (hash-based), not semantic.

    Args:
        tool_calls: Objects exposing ``name`` and ``arguments`` attributes.

    Returns:
        The name of the looping tool, or ``None`` when no call of the
        current step is repeated often enough.
    """
    current: dict[str, str] = {}
    for tc in tool_calls:
        # default=str keeps non-JSON-serialisable argument values hashable.
        args_str = json.dumps(tc.arguments, sort_keys=True, default=str)
        fingerprint = str(hash(f"{tc.name}:{args_str}"))
        self._loop_window.append(fingerprint)
        current[fingerprint] = tc.name

    counts = Counter(self._loop_window)
    for fingerprint, name in current.items():
        if counts[fingerprint] >= self._loop_threshold:
            return name
    return None
|
||
|
||
async def execute(
    self,
    messages: list[dict[str, str]],
    tools: list[Tool] | None = None,
    model: str = "default",
    agent_name: str = "",
    task_type: str = "",
    system_prompt: str | None = None,
    trace_recorder: "TraceRecorder | None" = None,
    memory_retriever: "MemoryRetriever | None" = None,
    task_id: str | None = None,
    compressor: "CompressionStrategy | None" = None,
    retrieval_config: dict[str, Any] | None = None,
    cancellation_token: CancellationToken | None = None,
    timeout_seconds: float | None = None,
    confirmation_handler: Any | None = None,
) -> ReActResult:
    """Run the ReAct loop and return its result.

    1. Build the initial messages (system prompt + task messages).
    2. Loop: Think (LLM call) -> Act (tool execution) -> Observe (result).
    3. Stop when the LLM returns no tool calls, or ``max_steps`` is reached.
    4. Return a ``ReActResult`` holding the output and the trajectory.

    The loop itself lives in ``_execute_loop``; this method resets the
    loop-detection state, resolves the effective compressor and timeout,
    optionally routes the call through the middleware chain, and applies
    the timeout.

    Args:
        messages: Task messages.
        tools: Tools offered to the LLM.
        model: Model identifier forwarded to the loop.
        agent_name: Agent name forwarded to the loop.
        task_type: Task type forwarded to the loop.
        system_prompt: Optional system prompt.
        trace_recorder: Optional trace recorder.
        memory_retriever: Optional memory retriever.
        task_id: Optional task identifier; also reported in
            ``TaskTimeoutError``.
        compressor: Compression strategy; the instance default is used
            when ``None``.
        retrieval_config: Optional memory retrieval settings.
        cancellation_token: Cooperative cancellation token.
        timeout_seconds: Timeout in seconds. ``None`` uses the engine's
            default timeout; a value that is not greater than 0 disables
            the timeout.
        confirmation_handler: Optional confirmation callback forwarded to
            the loop.

    Returns:
        The ``ReActResult`` produced by the loop (or by the middleware
        chain wrapping it).

    Raises:
        TaskTimeoutError: If ``asyncio.TimeoutError`` is raised while
            running; the original exception is chained as ``__cause__``.
    """
    # P2 #9: reset loop detection state so reuse across conversations is clean.
    self.reset()
    effective_compressor = compressor if compressor is not None else self._compressor
    effective_timeout = (
        timeout_seconds if timeout_seconds is not None else self._default_timeout
    )

    async def _run_loop(
        loop_messages: Any,
        loop_tools: Any,
        loop_model: Any,
        loop_agent_name: Any,
        loop_task_type: Any,
        loop_system_prompt: Any,
        loop_task_id: Any,
    ) -> ReActResult:
        # Single place that forwards to _execute_loop; the trailing keyword
        # arguments are identical on every path.
        return await self._execute_loop(
            messages=loop_messages,
            tools=loop_tools,
            model=loop_model,
            agent_name=loop_agent_name,
            task_type=loop_task_type,
            system_prompt=loop_system_prompt,
            trace_recorder=trace_recorder,
            memory_retriever=memory_retriever,
            task_id=loop_task_id,
            compressor=effective_compressor,
            retrieval_config=retrieval_config,
            cancellation_token=cancellation_token,
            confirmation_handler=confirmation_handler,
        )

    # U6: Middleware chain (parallel integration, KTD1).
    # If a middleware chain is present the loop runs as its handler,
    # otherwise the loop is invoked directly (backward compatible).
    if self._middleware_chain is not None:
        from agentkit.core.middleware import RequestContext

        chain = self._middleware_chain
        ctx = RequestContext(
            messages=messages,
            tools=tools or [],
            system_prompt=system_prompt,
            model=model,
            agent_name=agent_name,
            task_type=task_type,
            task_id=task_id,
        )

        async def _handler(c: RequestContext) -> ReActResult:
            # Read the request back from the context that went through the
            # chain rather than from the original arguments.
            return await _run_loop(
                c.messages,
                c.tools or None,
                c.model,
                c.agent_name,
                c.task_type,
                c.system_prompt,
                c.task_id,
            )

        def _start() -> Any:
            return chain.execute(ctx, _handler)

    else:

        def _start() -> Any:
            return _run_loop(
                messages, tools, model, agent_name, task_type, system_prompt, task_id
            )

    try:
        if effective_timeout > 0:
            return await asyncio.wait_for(_start(), timeout=effective_timeout)
        return await _start()
    except asyncio.TimeoutError as exc:
        # Chain the original exception so the traceback shows where the
        # timeout actually surfaced.
        raise TaskTimeoutError(
            task_id=task_id or "",
            timeout_seconds=int(effective_timeout),
        ) from exc
|
||
|
||
async def _execute_loop(
    self,
    messages: list[dict[str, str]],
    tools: list[Tool] | None = None,
    model: str = "default",
    agent_name: str = "",
    task_type: str = "",
    system_prompt: str | None = None,
    trace_recorder: "TraceRecorder | None" = None,
    memory_retriever: "MemoryRetriever | None" = None,
    task_id: str | None = None,
    compressor: "CompressionStrategy | None" = None,
    retrieval_config: dict[str, Any] | None = None,
    cancellation_token: CancellationToken | None = None,
    confirmation_handler: Any | None = None,
) -> ReActResult:
    """Run the Think -> Act -> Observe loop without any timeout handling.

    Builds the conversation (system message with tool descriptions and
    retrieved memory, followed by ``messages``), then iterates up to
    ``self._max_steps`` times: call the LLM, execute any tool calls it
    requests (native tool calls or ``<tool_use>`` text), append the tool
    results to the conversation, and stop at the first response that
    contains no tool call.

    Args:
        messages: Task messages appended after the system message.
        tools: Tools offered to the LLM.
        model: Model identifier passed to the LLM gateway.
        agent_name: Agent name used for the LLM call and telemetry.
        task_type: Task type used for the LLM call and telemetry.
        system_prompt: Optional system prompt; tool descriptions are
            appended to it when tools are present.
        trace_recorder: Optional recorder receiving per-step records.
        memory_retriever: Optional retriever queried before the loop and,
            if it has ``store_episode``, written to afterwards.
        task_id: Optional task identifier used for the episodic memory key.
        compressor: Optional compression strategy for the conversation and
            tool results.
        retrieval_config: Optional dict; ``top_k`` (default 5) and
            ``token_budget`` (default 2000) are read from it.
        cancellation_token: Checked at the start of every iteration.
        confirmation_handler: Awaited with ``(confirmation_id, command,
            reason)`` when a serially executed tool asks for confirmation.

    Returns:
        A ``ReActResult`` with the output, the trajectory, the trajectory
        length as ``total_steps`` and the summed LLM token usage.

    Raises:
        LoopDetectedError: If the same tool call is repeated again after a
            correction message has already been injected.
    """
    tools = tools or []
    if tools:
        tools = self._maybe_add_tool_search(tools)
    tool_schemas = self._build_tool_schemas(tools) if tools else None
    if tool_schemas:
        tool_names = [s["function"]["name"] for s in tool_schemas]
        logger.info(f"ReActEngine executing with {len(tool_schemas)} tools: {tool_names}")
    else:
        logger.info("ReActEngine executing with NO tools")

    # Prompt-based tool calling: inject tool descriptions into system prompt
    # when tools are available, so LLM can use <tool_use> format even if
    # the provider doesn't support native function calling.
    if tools and system_prompt is not None:
        tool_desc = self._build_tool_use_prompt(tools)
        system_prompt = f"{system_prompt}\n\n{tool_desc}"
    elif tools and system_prompt is None:
        system_prompt = self._build_tool_use_prompt(tools)

    # Telemetry: record agent request
    agent_request_counter().add(
        1, {"agent.name": agent_name, "agent.type": task_type or "react"}
    )

    # Start telemetry span for the entire agent execution
    _span_cm = None
    _span = None
    _exec_start = time.monotonic()

    if _OTEL_AVAILABLE:
        _span_cm = start_span(
            "agent.execute",
            attributes={"agent.name": agent_name, "agent.type": task_type or "react"},
        )
        _span = _span_cm.__enter__()

    # Initialize before try so finally can access them.
    # ``trace_outcome`` stays "error" until the run has actually completed,
    # so an exception escaping the loop is not reported as a success.
    trajectory: list[ReActStep] = []
    total_tokens = 0
    trace_outcome = "error"

    try:
        # Start trace recording
        if trace_recorder is not None:
            # NOTE(review): an empty task_id is passed although ``task_id``
            # is available here — confirm whether that is intentional.
            trace_recorder.start_trace(
                task_id="",
                agent_name=agent_name,
                skill_name=task_type or None,
            )

        # Memory retrieval: fetch related context before execution and
        # inject it into the system message as the volatile layer.
        # U2/G2: no longer appended to the stable part (system_prompt);
        # _build_system_message assembles the two-block structure instead.
        memory_context = ""
        if memory_retriever:
            try:
                query = str(messages[-1].get("content", "")) if messages else ""
                top_k = (retrieval_config or {}).get("top_k", 5)
                token_budget = (retrieval_config or {}).get("token_budget", 2000)
                memory_context = await memory_retriever.get_context_string(
                    query=query,
                    top_k=top_k,
                    token_budget=token_budget,
                ) or ""
            except Exception as e:
                # Best effort: a retrieval failure must not abort the run.
                logger.warning(
                    f"Memory retrieval failed, continuing without context: {e}", exc_info=True
                )

        # Build the initial conversation
        conversation: list[dict[str, Any]] = []
        system_content = self._build_system_message(
            stable=system_prompt or "",
            volatile=memory_context,
            model=model,
        )
        if system_content is not None:
            conversation.append({"role": "system", "content": system_content})
        conversation.extend(messages)

        # Context compression: compress an over-long conversation history
        if compressor:
            try:
                conversation = await compressor.compress(conversation)
            except Exception as e:
                logger.warning(
                    f"Context compression failed, continuing with original messages: {e}"
                )

        # Outcome of a run that completes normally; published to
        # ``trace_outcome`` only once the loop has finished.
        outcome = "success"
        step = 0
        output = ""

        while step < self._max_steps:
            step += 1

            # Cooperative cancellation check
            if cancellation_token is not None:
                cancellation_token.check()

            # Think: call the LLM
            llm_start = time.monotonic()
            response = await self._llm_gateway.chat(
                messages=conversation,
                model=model,
                agent_name=agent_name,
                task_type=task_type,
                tools=tool_schemas,
            )
            llm_duration_ms = int((time.monotonic() - llm_start) * 1000)

            step_tokens = response.usage.total_tokens
            total_tokens += step_tokens

            # Check for function-calling tool_calls
            if response.has_tool_calls:
                # Loop detection: is the same tool + arguments being repeated?
                looped_tool = self._check_tool_loop(response.tool_calls)
                if looped_tool is not None:
                    if not self._loop_corrected:
                        # First detection: inject a correction message to give
                        # the LLM a chance to change strategy. The requested
                        # tool calls of this step are not executed.
                        logger.warning(
                            f"Loop detected: tool '{looped_tool}' repeated, "
                            f"injecting correction at step {step}"
                        )
                        correction_msg = {
                            "role": "user",
                            "content": (
                                f"You are repeatedly calling tool '{looped_tool}' "
                                f"with the same arguments. This indicates a loop. "
                                f"Please change your strategy or provide a final answer."
                            ),
                        }
                        conversation.append(correction_msg)
                        self._loop_corrected = True
                        continue
                    else:
                        # Second detection: still looping after the correction,
                        # abort.
                        raise LoopDetectedError(
                            tool_name=looped_tool,
                            repetitions=self._loop_threshold + 1,
                        )

                # Record the LLM call step
                if trace_recorder is not None:
                    trace_recorder.record_step(
                        step=step,
                        action="llm_call",
                        duration_ms=llm_duration_ms,
                        tokens_used=step_tokens,
                    )

                # Act: execute tool calls.
                # First record the assistant message (with tool_calls) in the
                # conversation history.
                assistant_msg: dict[str, Any] = {
                    "role": "assistant",
                    "content": response.content or "",
                    "tool_calls": [
                        {
                            "id": tc.id,
                            "type": "function",
                            "function": {
                                "name": tc.name,
                                "arguments": json.dumps(tc.arguments),
                            },
                        }
                        for tc in response.tool_calls
                    ],
                }
                conversation.append(assistant_msg)

                # Execute the tool calls
                if self._parallel_tools == "auto" and len(response.tool_calls) > 1:
                    # Auto mode: mixed parallel/serial based on _parallelizable flag
                    parallelizable_set = set(
                        self._get_parallelizable_indices(response.tool_calls)
                    )
                    serial_calls = [
                        (i, tc)
                        for i, tc in enumerate(response.tool_calls)
                        if i not in parallelizable_set
                    ]
                    parallel_calls = [
                        (i, tc)
                        for i, tc in enumerate(response.tool_calls)
                        if i in parallelizable_set
                    ]

                    # Result slots indexed by original position
                    all_results: list[Any] = [None] * len(response.tool_calls)

                    # Execute serial tools first (in order)
                    for i, tc in serial_calls:
                        tool_start = time.monotonic()
                        tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
                        tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
                        all_results[i] = (tc, tool_result, tool_duration_ms)

                    # Execute parallelizable tools in parallel.
                    # Durations of these calls are recorded as 0.
                    if len(parallel_calls) > 1:
                        para_results = await asyncio.gather(
                            *[
                                self._execute_tool(tc.name, tc.arguments, tools)
                                for _, tc in parallel_calls
                            ],
                            return_exceptions=True,
                        )
                        for j, (i, tc) in enumerate(parallel_calls):
                            tool_result = para_results[j]
                            if isinstance(tool_result, Exception):
                                tool_result = {"error": str(tool_result)}
                            all_results[i] = (tc, tool_result, 0)
                    elif len(parallel_calls) == 1:
                        i, tc = parallel_calls[0]
                        tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
                        all_results[i] = (tc, tool_result, 0)

                    # Process all results in original order
                    for i, tc in enumerate(response.tool_calls):
                        _, tool_result, tool_duration_ms = all_results[i]
                        react_step = ReActStep(
                            step=step,
                            action="tool_call",
                            tool_name=tc.name,
                            arguments=tc.arguments,
                            result=tool_result,
                            tokens=step_tokens,
                        )
                        trajectory.append(react_step)

                        if trace_recorder is not None:
                            tool_error = None
                            if isinstance(tool_result, dict) and "error" in tool_result:
                                tool_error = tool_result["error"]
                            trace_recorder.record_step(
                                step=step,
                                action="tool_call",
                                tool_name=tc.name,
                                input_data=tc.arguments,
                                output_data=tool_result,
                                duration_ms=tool_duration_ms,
                                tokens_used=0,
                                error=tool_error,
                            )

                        tool_msg = await self._build_tool_result_message(
                            tc.id, tool_result, compressor, tc.name
                        )
                        conversation.append(tool_msg)
                elif self._should_execute_parallel(response.tool_calls):
                    # Execute multiple tool calls in parallel (parallel_tools=True)
                    tool_results = await asyncio.gather(
                        *[
                            self._execute_tool(tc.name, tc.arguments, tools)
                            for tc in response.tool_calls
                        ],
                        return_exceptions=True,
                    )
                    for idx, tc in enumerate(response.tool_calls):
                        tool_result = tool_results[idx]
                        if isinstance(tool_result, Exception):
                            tool_result = {"error": str(tool_result)}

                        react_step = ReActStep(
                            step=step,
                            action="tool_call",
                            tool_name=tc.name,
                            arguments=tc.arguments,
                            result=tool_result,
                            tokens=step_tokens,
                        )
                        trajectory.append(react_step)

                        if trace_recorder is not None:
                            tool_error = None
                            if isinstance(tool_result, dict) and "error" in tool_result:
                                tool_error = tool_result["error"]
                            trace_recorder.record_step(
                                step=step,
                                action="tool_call",
                                tool_name=tc.name,
                                input_data=tc.arguments,
                                output_data=tool_result,
                                duration_ms=0,
                                tokens_used=0,
                                error=tool_error,
                            )

                        tool_msg = await self._build_tool_result_message(
                            tc.id, tool_result, compressor, tc.name
                        )
                        conversation.append(tool_msg)
                else:
                    # Serial execution (single tool or parallel_tools=False)
                    for tc in response.tool_calls:
                        tool_start = time.monotonic()
                        tool_result = await self._execute_tool(tc.name, tc.arguments, tools)

                        # Handle confirmation flow.
                        # NOTE(review): only this serial branch handles
                        # ``needs_confirmation`` results; the parallel branches
                        # pass them through unchanged — confirm that is intended.
                        if isinstance(tool_result, dict) and tool_result.get(
                            "needs_confirmation"
                        ):
                            confirmation_id = tool_result["confirmation_id"]
                            command = tool_result.get("command", "")
                            reason = tool_result.get("reason", "")

                            # Without a handler, or if the handler fails, the
                            # request is treated as rejected.
                            approved = False
                            if confirmation_handler is not None:
                                try:
                                    approved = await confirmation_handler(
                                        confirmation_id, command, reason
                                    )
                                except Exception as e:
                                    logger.warning(f"Confirmation handler error: {e}")

                            if approved:
                                tool = self._find_tool(tc.name, tools)
                                if tool and hasattr(tool, "_is_dangerous"):
                                    # Drop internal "_"-prefixed arguments and
                                    # re-run with the dangerous check skipped.
                                    clean_args = {
                                        k: v
                                        for k, v in tc.arguments.items()
                                        if not k.startswith("_")
                                    }
                                    clean_args["_skip_dangerous_check"] = True
                                    try:
                                        tool_result = await tool.safe_execute(**clean_args)
                                    except Exception as e:
                                        tool_result = {
                                            "error": f"Tool '{tc.name}' execution failed: {e}"
                                        }
                                else:
                                    # Non-dangerous tool: confirmation was for the overall action,
                                    # re-execute with skip flag to avoid re-triggering confirmation
                                    clean_args = {
                                        k: v
                                        for k, v in tc.arguments.items()
                                        if not k.startswith("_")
                                    }
                                    clean_args["_skip_dangerous_check"] = True
                                    try:
                                        tool_result = (
                                            await tool.safe_execute(**clean_args)
                                            if tool
                                            else {"error": f"Tool '{tc.name}' not found"}
                                        )
                                    except Exception as e:
                                        tool_result = {
                                            "error": f"Tool '{tc.name}' execution failed: {e}"
                                        }
                            else:
                                # Rejected: report a permission-denied result
                                # back to the LLM instead of executing.
                                tool_result = {
                                    "output": "",
                                    "exit_code": 126,
                                    "is_error": True,
                                    "error_type": "permission_denied",
                                    "message": f"用户拒绝执行命令: {command[:100]}",
                                }

                        # Includes any time spent waiting for confirmation.
                        tool_duration_ms = int((time.monotonic() - tool_start) * 1000)

                        react_step = ReActStep(
                            step=step,
                            action="tool_call",
                            tool_name=tc.name,
                            arguments=tc.arguments,
                            result=tool_result,
                            tokens=step_tokens,
                        )
                        trajectory.append(react_step)

                        # Record the tool call step
                        if trace_recorder is not None:
                            tool_error = None
                            if isinstance(tool_result, dict) and "error" in tool_result:
                                tool_error = tool_result["error"]
                            trace_recorder.record_step(
                                step=step,
                                action="tool_call",
                                tool_name=tc.name,
                                input_data=tc.arguments,
                                output_data=tool_result,
                                duration_ms=tool_duration_ms,
                                tokens_used=0,
                                error=tool_error,
                            )

                        # Observe: add the tool result to the conversation history
                        tool_msg = await self._build_tool_result_message(
                            tc.id, tool_result, compressor, tc.name
                        )
                        conversation.append(tool_msg)

                # Incremental compression: compress conversation if it's getting long
                if self._should_compress(conversation, compressor):
                    try:
                        conversation = await compressor.compress(conversation)
                    except Exception as e:
                        logger.warning(f"Incremental compression failed: {e}")

            else:
                # Check text-parsing mode
                parsed_calls = self._parse_text_tool_calls(response.content or "")
                if parsed_calls and tools:
                    # Record the LLM call step
                    if trace_recorder is not None:
                        trace_recorder.record_step(
                            step=step,
                            action="llm_call",
                            duration_ms=llm_duration_ms,
                            tokens_used=step_tokens,
                        )

                    # Execute tools in text-parsing mode
                    conversation.append({"role": "assistant", "content": response.content})

                    for pc in parsed_calls:
                        tool_start = time.monotonic()
                        tool_result = await self._execute_tool(
                            pc["name"], pc["arguments"], tools
                        )
                        tool_duration_ms = int((time.monotonic() - tool_start) * 1000)

                        react_step = ReActStep(
                            step=step,
                            action="tool_call",
                            tool_name=pc["name"],
                            arguments=pc["arguments"],
                            result=tool_result,
                            tokens=step_tokens,
                        )
                        trajectory.append(react_step)

                        # Record the tool call step
                        if trace_recorder is not None:
                            tool_error = None
                            if isinstance(tool_result, dict) and "error" in tool_result:
                                tool_error = tool_result["error"]
                            trace_recorder.record_step(
                                step=step,
                                action="tool_call",
                                tool_name=pc["name"],
                                input_data=pc["arguments"],
                                output_data=tool_result,
                                duration_ms=tool_duration_ms,
                                tokens_used=0,
                                error=tool_error,
                            )

                        # Add the tool result to the conversation history
                        tool_msg = await self._build_tool_result_message(
                            pc.get("id", f"text_tc_{step}"), tool_result, compressor, pc["name"]
                        )
                        conversation.append(tool_msg)

                    # Incremental compression: compress conversation if it's getting long
                    if self._should_compress(conversation, compressor):
                        try:
                            conversation = await compressor.compress(conversation)
                        except Exception as e:
                            logger.warning(f"Incremental compression failed: {e}")
                else:
                    # Check for a malformed tool call (contains <tool_use> but
                    # parsing failed). If so, inject a correction message so the
                    # model retries, instead of leaking the raw XML as the final
                    # answer.
                    if "<tool_use>" in (response.content or ""):
                        logger.warning(
                            f"Step {step}: content contains <tool_use> but "
                            f"parsing failed — injecting correction"
                        )
                        conversation.append(
                            {"role": "assistant", "content": response.content}
                        )
                        conversation.append(
                            {
                                "role": "user",
                                "content": (
                                    "你上一次的工具调用格式有误,无法解析。"
                                    "请使用正确的格式重新调用工具:\n"
                                    '<tool_use>\n'
                                    '{"name": "工具名", "arguments": {"参数名": "参数值"}}\n'
                                    "</tool_use>\n"
                                    "确保 JSON 完整且不要混入其他标签。"
                                ),
                            }
                        )
                        continue

                    # Final answer: the LLM called no tool, return its answer
                    react_step = ReActStep(
                        step=step,
                        action="final_answer",
                        content=response.content,
                        tokens=step_tokens,
                    )
                    trajectory.append(react_step)
                    output = response.content or ""

                    # Record the final answer step
                    if trace_recorder is not None:
                        trace_recorder.record_step(
                            step=step,
                            action="final_answer",
                            output_data={"content": response.content},
                            duration_ms=llm_duration_ms,
                            tokens_used=step_tokens,
                        )
                    break

        # Verification: if enabled, run the checks after the final answer
        if self._verification_enabled and output:
            try:
                from agentkit.core.verification_loop import VerificationLoop

                vloop = VerificationLoop(commands=self._verification_commands)
                vresult = await vloop.verify()
                if not vresult.passed:
                    # Append the verification failure to the trajectory as a
                    # ReActStep
                    verification_step = ReActStep(
                        step=step + 1,
                        action="tool_call",
                        tool_name="verification",
                        arguments={"commands": self._verification_commands},
                        result={
                            "passed": vresult.passed,
                            "errors": vresult.errors,
                            "test_output": vresult.test_output,
                        },
                        content=(f"Verification failed:\n{vresult.test_output[:2000]}"),
                    )
                    trajectory.append(verification_step)
                    logger.info(
                        "Verification failed after final answer, "
                        "appended feedback to trajectory"
                    )
            except Exception as e:
                logger.warning(f"Verification loop failed: {e}")

        # When max_steps was reached, return the best output available
        if step >= self._max_steps and not output:
            outcome = "partial"
            # Use the content of the last step as the output
            if trajectory and trajectory[-1].content:
                output = trajectory[-1].content
            elif trajectory and trajectory[-1].result is not None:
                output = str(trajectory[-1].result)
            else:
                output = response.content or ""

        # Fallback: make sure output is never empty or whitespace-only
        if not output or not output.strip():
            from agentkit.core.fallback import EMPTY_LLM_RESPONSE, MAX_STEPS_REACHED

            if step >= self._max_steps:
                output = MAX_STEPS_REACHED
            else:
                output = EMPTY_LLM_RESPONSE
            outcome = "empty_fallback"

        # The run completed: publish the real outcome for the recorder, the
        # episodic memory entry and the telemetry span set in ``finally``.
        trace_outcome = outcome

        # End trace recording
        if trace_recorder is not None:
            trace_recorder.end_trace(outcome=trace_outcome)

        # Memory storage: write a summary of the run to episodic memory
        if memory_retriever and hasattr(memory_retriever, "store_episode"):
            try:
                summary = output[:500] if output else ""
                await memory_retriever.store_episode(
                    key=f"task:{task_id or 'unknown'}",
                    value={"output_summary": summary, "agent_name": agent_name},
                    metadata={"task_type": task_type, "outcome": trace_outcome},
                )
            except Exception as e:
                logger.warning(f"Failed to store task result in episodic memory: {e}")

        # ``status`` is not passed, so the result keeps its default value.
        return ReActResult(
            output=output,
            trajectory=trajectory,
            total_steps=len(trajectory),
            total_tokens=total_tokens,
        )
    finally:
        # Telemetry: end span and record duration — always runs
        _duration_ms = int((time.monotonic() - _exec_start) * 1000)
        if _span is not None:
            _span.set_attribute("agent.total_steps", len(trajectory))
            _span.set_attribute("agent.total_tokens", total_tokens)
            _span.set_attribute("agent.outcome", trace_outcome)
            _span.set_attribute("agent.duration_ms", _duration_ms)
        if _span_cm is not None:
            _span_cm.__exit__(None, None, None)
        agent_duration_histogram().record(_duration_ms, {"agent.name": agent_name})
|
||
|
||
async def execute_stream(
    self,
    messages: list[dict[str, str]],
    tools: list[Tool] | None = None,
    model: str = "default",
    agent_name: str = "",
    task_type: str = "",
    system_prompt: str | None = None,
    trace_recorder: "TraceRecorder | None" = None,
    memory_retriever: "MemoryRetriever | None" = None,
    task_id: str | None = None,
    compressor: "CompressionStrategy | None" = None,
    retrieval_config: dict[str, Any] | None = None,
    cancellation_token: CancellationToken | None = None,
    timeout_seconds: float | None = None,
    confirmation_handler: Any | None = None,
):
    """Execute ReAct loop, yielding ReActEvent objects.

    Same logic as execute() but yields events at each step instead of
    accumulating a result. Event types emitted by this method are
    ``thinking``, ``token``, ``step``, ``tool_call``, ``tool_result``,
    ``confirmation_request``, ``confirmation_result`` and ``final_answer``.

    Args:
        messages: Conversation messages appended after the system message.
            The content of the last message is used as the memory-retrieval
            query.
        tools: Tools offered to the LLM. ``_maybe_add_tool_search`` may
            extend the list.
        model: Model identifier forwarded to the LLM gateway.
        agent_name: Forwarded to the gateway and used as a telemetry attribute.
        task_type: Forwarded to the gateway; telemetry falls back to
            ``"react"`` when it is empty.
        system_prompt: Stable system prompt. When tools are present the
            prompt-based tool-use instructions are appended to it (or used
            on their own when this is None).
        trace_recorder: Optional recorder receiving start/step/end calls.
        memory_retriever: Optional retriever; ``get_context_string`` is
            awaited before the loop and ``store_episode`` (if present) after.
        task_id: Used to build the episodic-memory key.
        compressor: Compression strategy; when None the instance default
            (``self._compressor``) is used.
        retrieval_config: Optional ``top_k`` (default 5) and ``token_budget``
            (default 2000) for memory retrieval.
        cancellation_token: When given, ``check()`` is called at the start
            of every step.
        timeout_seconds: Overall budget in seconds; falls back to
            ``self._default_timeout``. Only checked at the start of a step,
            and only when the effective value is greater than 0.
        confirmation_handler: Awaitable callback invoked as
            ``handler(confirmation_id, command, reason)``; a truthy result
            approves the pending tool call.

    Yields:
        ReActEvent: One event per notable point in the loop.

    Raises:
        asyncio.TimeoutError: When the elapsed time exceeds the timeout.
        LoopDetectedError: When a tool loop is detected again after a
            correction message was already injected.
    """
    # P2 #9: Reset loop detection state so reuse across conversations is clean
    self.reset()
    effective_compressor = compressor if compressor is not None else self._compressor
    tools = tools or []
    if tools:
        tools = self._maybe_add_tool_search(tools)
    tool_schemas = self._build_tool_schemas(tools) if tools else None
    if tool_schemas:
        tool_names = [s["function"]["name"] for s in tool_schemas]
        logger.info(f"ReActEngine executing with {len(tool_schemas)} tools: {tool_names}")
    else:
        logger.info("ReActEngine executing with NO tools")

    # Prompt-based tool calling: inject tool descriptions into system prompt
    # when tools are available, so LLM can use <tool_use> format even if
    # the provider doesn't support native function calling.
    if tools and system_prompt is not None:
        tool_desc = self._build_tool_use_prompt(tools)
        system_prompt = f"{system_prompt}\n\n{tool_desc}"
    elif tools and system_prompt is None:
        system_prompt = self._build_tool_use_prompt(tools)

    # Telemetry: record agent request
    agent_request_counter().add(
        1, {"agent.name": agent_name, "agent.type": task_type or "react"}
    )

    # Start telemetry span for the entire agent execution
    _span_cm = None
    _span = None
    _exec_start = time.monotonic()

    if _OTEL_AVAILABLE:
        _span_cm = start_span(
            "agent.execute_stream",
            attributes={"agent.name": agent_name, "agent.type": task_type or "react"},
        )
        # The span context manager is entered manually; it is exited in the
        # ``finally`` block at the end of this method.
        _span = _span_cm.__enter__()

    # Start trace recording.
    # NOTE(review): task_id is passed as "" here even though this method
    # receives a task_id argument — confirm this is intentional.
    if trace_recorder is not None:
        trace_recorder.start_trace(
            task_id="",
            agent_name=agent_name,
            skill_name=task_type or None,
        )

    # Memory retrieval: fetch relevant context before execution and inject it
    # as the volatile layer of the system message.
    # U2/G2: no longer appended to the end of the stable part (system_prompt),
    # which would break the cache prefix; _build_system_message assembles the
    # two-block structure (stable + volatile) and marks the stable block with
    # cache_control for the Anthropic provider.
    memory_context = ""
    if memory_retriever:
        try:
            query = str(messages[-1].get("content", "")) if messages else ""
            top_k = (retrieval_config or {}).get("top_k", 5)
            token_budget = (retrieval_config or {}).get("token_budget", 2000)
            memory_context = await memory_retriever.get_context_string(
                query=query,
                top_k=top_k,
                token_budget=token_budget,
            ) or ""
        except Exception as e:
            # Best effort: a retrieval failure must not abort the run.
            logger.warning(f"Memory retrieval failed, continuing without context: {e}")

    conversation: list[dict[str, Any]] = []
    system_content = self._build_system_message(
        stable=system_prompt or "",
        volatile=memory_context,
        model=model,
    )
    if system_content is not None:
        conversation.append({"role": "system", "content": system_content})
    conversation.extend(messages)

    # Context compression: compress an over-long conversation history
    if effective_compressor:
        try:
            conversation = await effective_compressor.compress(conversation)
        except Exception as e:
            logger.warning(
                f"Context compression failed, continuing with original messages: {e}"
            )

    trajectory: list[ReActStep] = []
    total_tokens = 0
    step = 0
    output = ""
    trace_outcome = "success"
    _stream_start = time.monotonic()
    effective_timeout = (
        timeout_seconds if timeout_seconds is not None else self._default_timeout
    )

    try:
        while step < self._max_steps:
            step += 1

            # Cooperative cancellation check
            if cancellation_token is not None:
                cancellation_token.check()

            # Timeout check (only evaluated here, at the start of each step)
            if effective_timeout > 0:
                elapsed = time.monotonic() - _stream_start
                if elapsed > effective_timeout:
                    trace_outcome = "timeout"
                    raise asyncio.TimeoutError(
                        f"execute_stream exceeded {effective_timeout}s timeout after {elapsed:.1f}s"
                    )

            # Yield thinking event
            yield ReActEvent(
                event_type="thinking",
                step=step,
                data={"message": f"Step {step}: Calling LLM..."},
            )

            # Think: call LLM (with optional token streaming)
            llm_start = time.monotonic()

            # Use streaming for token-by-token output
            stream_content_chunks: list[str] = []
            stream_usage = None
            stream_tool_calls: list[Any] = []
            stream_model = model
            # U3/G8: delta_flush throttle buffer; token events are yielded in
            # batches according to flush_interval_ms
            _flush_buffer: list[str] = []
            _last_flush_ts = time.monotonic()

            # NOTE(review): the nesting of the flush check under
            # ``if chunk.content`` is reconstructed from a source whose
            # indentation was lost — confirm against the original file.
            async for chunk in _ensure_async_iterable(
                self._llm_gateway.chat_stream(
                    messages=conversation,
                    model=model,
                    agent_name=agent_name,
                    task_type=task_type,
                    tools=tool_schemas,
                ),
                label=f"llm_gateway.chat_stream(model={model!r})",
            ):
                if chunk.content:
                    stream_content_chunks.append(chunk.content)
                    _flush_buffer.append(chunk.content)
                    now = time.monotonic()
                    # flush_interval_ms=0 → yield per chunk (backward
                    # compatible; the condition short-circuits to True)
                    if (
                        self._flush_interval_ms == 0
                        or now - _last_flush_ts >= self._flush_interval_ms / 1000
                    ):
                        yield ReActEvent(
                            event_type="token",
                            step=step,
                            data={"content": "".join(_flush_buffer)},
                        )
                        _flush_buffer = []
                        _last_flush_ts = now
                # The latest non-empty usage / tool_calls / model seen on any
                # chunk wins (each assignment overwrites the previous value).
                if chunk.usage:
                    stream_usage = chunk.usage
                if chunk.tool_calls:
                    stream_tool_calls = chunk.tool_calls
                if chunk.model:
                    stream_model = chunk.model

            # U3/G8: stream ended mid-interval → flush whatever is left in the
            # buffer so no characters are lost
            if _flush_buffer:
                yield ReActEvent(
                    event_type="token",
                    step=step,
                    data={"content": "".join(_flush_buffer)},
                )
                _flush_buffer = []

            # Build response-like object from stream
            stream_content = "".join(stream_content_chunks)
            response = self._build_response_from_stream(
                content=stream_content,
                tool_calls=stream_tool_calls,
                usage=stream_usage,
                model=stream_model,
            )
            llm_duration_ms = int((time.monotonic() - llm_start) * 1000)

            step_tokens = response.usage.total_tokens
            total_tokens += step_tokens

            if response.has_tool_calls:
                # Loop detection: check whether the same tool + arguments are
                # being called repeatedly
                looped_tool = self._check_tool_loop(response.tool_calls)
                if looped_tool is not None:
                    if not self._loop_corrected:
                        # First detection: inject a correction message and
                        # retry instead of failing.
                        logger.warning(
                            f"Loop detected (stream): tool '{looped_tool}' repeated, "
                            f"injecting correction at step {step}"
                        )
                        correction_msg = {
                            "role": "user",
                            "content": (
                                f"You are repeatedly calling tool '{looped_tool}' "
                                f"with the same arguments. This indicates a loop. "
                                f"Please change your strategy or provide a final answer."
                            ),
                        }
                        conversation.append(correction_msg)
                        self._loop_corrected = True
                        yield ReActEvent(
                            event_type="step",
                            step=step,
                            data={
                                "message": f"Loop detected: tool '{looped_tool}' repeated. Correction injected.",
                                "loop_detected": True,
                                "tool_name": looped_tool,
                            },
                        )
                        continue
                    else:
                        # A correction was already injected once: give up.
                        raise LoopDetectedError(
                            tool_name=looped_tool,
                            repetitions=self._loop_threshold + 1,
                        )

                # Record the LLM call step
                if trace_recorder is not None:
                    trace_recorder.record_step(
                        step=step,
                        action="llm_call",
                        duration_ms=llm_duration_ms,
                        tokens_used=step_tokens,
                    )

                # Record assistant message
                assistant_msg: dict[str, Any] = {
                    "role": "assistant",
                    "content": response.content or "",
                    "tool_calls": [
                        {
                            "id": tc.id,
                            "type": "function",
                            "function": {
                                "name": tc.name,
                                "arguments": json.dumps(tc.arguments),
                            },
                        }
                        for tc in response.tool_calls
                    ],
                }
                conversation.append(assistant_msg)

                # Execute tool calls with parallel support
                if (
                    self._parallel_tools
                    and len(response.tool_calls) > 1
                    and self._should_execute_parallel(response.tool_calls)
                ):
                    # Parallel execution path
                    parallelizable_set = (
                        set(self._get_parallelizable_indices(response.tool_calls))
                        if self._parallel_tools == "auto"
                        else set(range(len(response.tool_calls)))
                    )
                    serial_calls = [
                        (i, tc)
                        for i, tc in enumerate(response.tool_calls)
                        if i not in parallelizable_set
                    ]
                    parallel_calls = [
                        (i, tc)
                        for i, tc in enumerate(response.tool_calls)
                        if i in parallelizable_set
                    ]

                    # Slot i holds (tool_call, result, duration_ms) for call i.
                    all_results: list[Any] = [None] * len(response.tool_calls)

                    # Execute serial tools first (handles confirmation flow)
                    for i, tc in serial_calls:
                        yield ReActEvent(
                            event_type="tool_call",
                            step=step,
                            data={"tool_name": tc.name, "arguments": tc.arguments},
                        )
                        tool_start = time.monotonic()
                        (
                            tool_result,
                            confirm_events,
                        ) = await self._execute_tool_with_confirmation(
                            tc, tools, step, confirmation_handler
                        )
                        # The helper returns its events as a list, so they are
                        # only forwarded after it has finished.
                        for ev in confirm_events:
                            yield ev
                        tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
                        all_results[i] = (tc, tool_result, tool_duration_ms)

                    # Execute parallelizable tools concurrently.
                    # NOTE(review): durations for these calls are recorded as 0.
                    if len(parallel_calls) > 1:
                        para_results = await asyncio.gather(
                            *[
                                self._execute_tool(tc.name, tc.arguments, tools)
                                for _, tc in parallel_calls
                            ],
                            return_exceptions=True,
                        )
                        for j, (i, tc) in enumerate(parallel_calls):
                            tool_result = para_results[j]
                            if isinstance(tool_result, Exception):
                                tool_result = {"error": str(tool_result)}
                            all_results[i] = (tc, tool_result, 0)
                    elif len(parallel_calls) == 1:
                        i, tc = parallel_calls[0]
                        tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
                        all_results[i] = (tc, tool_result, 0)

                    # Process all results in original order
                    for i, tc in enumerate(response.tool_calls):
                        tc_obj, tool_result, tool_duration_ms = all_results[i]
                        # NOTE(review): serial calls already emitted a
                        # tool_call event in the loop above, so they are
                        # announced twice — confirm consumers expect this.
                        yield ReActEvent(
                            event_type="tool_call",
                            step=step,
                            data={"tool_name": tc.name, "arguments": tc.arguments},
                        )

                        react_step = ReActStep(
                            step=step,
                            action="tool_call",
                            tool_name=tc.name,
                            arguments=tc.arguments,
                            result=tool_result,
                            tokens=step_tokens,
                        )
                        trajectory.append(react_step)

                        if trace_recorder is not None:
                            tool_error = None
                            if isinstance(tool_result, dict) and "error" in tool_result:
                                tool_error = tool_result["error"]
                            trace_recorder.record_step(
                                step=step,
                                action="tool_call",
                                tool_name=tc.name,
                                input_data=tc.arguments,
                                output_data=tool_result,
                                duration_ms=tool_duration_ms,
                                tokens_used=0,
                                error=tool_error,
                            )

                        yield ReActEvent(
                            event_type="tool_result",
                            step=step,
                            data={"tool_name": tc.name, "result": tool_result},
                        )
                        tool_msg = await self._build_tool_result_message(
                            tc.id, tool_result, effective_compressor, tc.name
                        )
                        conversation.append(tool_msg)
                else:
                    # Serial execution path (with confirmation flow)
                    for tc in response.tool_calls:
                        # Yield tool_call event
                        yield ReActEvent(
                            event_type="tool_call",
                            step=step,
                            data={"tool_name": tc.name, "arguments": tc.arguments},
                        )

                        tool_start = time.monotonic()
                        tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
                        tool_duration_ms = int((time.monotonic() - tool_start) * 1000)

                        # Detect a confirmation request returned by the tool
                        if isinstance(tool_result, dict) and tool_result.get(
                            "needs_confirmation"
                        ):
                            confirmation_id = tool_result["confirmation_id"]
                            command = tool_result.get("command", "")
                            reason = tool_result.get("reason", "")

                            # Yield the confirmation request event
                            yield ReActEvent(
                                event_type="confirmation_request",
                                step=step,
                                data={
                                    "confirmation_id": confirmation_id,
                                    "tool_name": tc.name,
                                    "command": command,
                                    "reason": reason,
                                },
                            )

                            # Wait for the user's decision; without a handler,
                            # or when the handler raises, the call is treated
                            # as not approved.
                            approved = False
                            if confirmation_handler is not None:
                                try:
                                    approved = await confirmation_handler(
                                        confirmation_id, command, reason
                                    )
                                except Exception as e:
                                    logger.warning(f"Confirmation handler error: {e}")

                            if approved:
                                # User approved: use a per-call override to
                                # bypass the safety check
                                tool = self._find_tool(tc.name, tools)
                                if tool and hasattr(tool, "_is_dangerous"):
                                    # Strip internal metadata and pass skip_dangerous_check flag
                                    clean_args = {
                                        k: v
                                        for k, v in tc.arguments.items()
                                        if not k.startswith("_")
                                    }
                                    clean_args["_skip_dangerous_check"] = True
                                    # NOTE(review): unlike the sibling branch
                                    # below, an exception raised here is not
                                    # caught and propagates out of the generator.
                                    try:
                                        tool_result = await tool.safe_execute(**clean_args)
                                    finally:
                                        pass  # No shared state mutation needed
                                else:
                                    # Non-dangerous tool: re-execute with skip flag
                                    clean_args = {
                                        k: v
                                        for k, v in tc.arguments.items()
                                        if not k.startswith("_")
                                    }
                                    clean_args["_skip_dangerous_check"] = True
                                    try:
                                        tool_result = (
                                            await tool.safe_execute(**clean_args)
                                            if tool
                                            else {"error": f"Tool '{tc.name}' not found"}
                                        )
                                    except Exception as e:
                                        tool_result = {
                                            "error": f"Tool '{tc.name}' execution failed: {e}"
                                        }

                                yield ReActEvent(
                                    event_type="confirmation_result",
                                    step=step,
                                    data={"confirmation_id": confirmation_id, "approved": True},
                                )
                            else:
                                # User rejected the execution
                                tool_result = {
                                    "output": "",
                                    "exit_code": 126,
                                    "is_error": True,
                                    "error_type": "permission_denied",
                                    "message": f"用户拒绝执行命令: {command[:100]}",
                                }
                                yield ReActEvent(
                                    event_type="confirmation_result",
                                    step=step,
                                    data={
                                        "confirmation_id": confirmation_id,
                                        "approved": False,
                                    },
                                )

                            # Re-measure so the duration covers the
                            # confirmation round trip as well.
                            # NOTE(review): placement inside the confirmation
                            # branch is reconstructed — confirm.
                            tool_duration_ms = int((time.monotonic() - tool_start) * 1000)

                        react_step = ReActStep(
                            step=step,
                            action="tool_call",
                            tool_name=tc.name,
                            arguments=tc.arguments,
                            result=tool_result,
                            tokens=step_tokens,
                        )
                        trajectory.append(react_step)

                        if trace_recorder is not None:
                            tool_error = None
                            if isinstance(tool_result, dict) and "error" in tool_result:
                                tool_error = tool_result["error"]
                            trace_recorder.record_step(
                                step=step,
                                action="tool_call",
                                tool_name=tc.name,
                                input_data=tc.arguments,
                                output_data=tool_result,
                                duration_ms=tool_duration_ms,
                                tokens_used=0,
                                error=tool_error,
                            )

                        # Yield tool_result event
                        yield ReActEvent(
                            event_type="tool_result",
                            step=step,
                            data={"tool_name": tc.name, "result": tool_result},
                        )

                        tool_msg = await self._build_tool_result_message(
                            tc.id, tool_result, effective_compressor, tc.name
                        )
                        conversation.append(tool_msg)

                # Incremental compression: compress conversation if it's getting long
                if self._should_compress(conversation, effective_compressor):
                    try:
                        conversation = await effective_compressor.compress(conversation)
                    except Exception as e:
                        logger.warning(f"Incremental compression failed: {e}")

            else:
                # Check text parsing mode
                parsed_calls = self._parse_text_tool_calls(response.content or "")
                if parsed_calls and tools:
                    # Record the LLM call step
                    if trace_recorder is not None:
                        trace_recorder.record_step(
                            step=step,
                            action="llm_call",
                            duration_ms=llm_duration_ms,
                            tokens_used=step_tokens,
                        )

                    conversation.append({"role": "assistant", "content": response.content})

                    for pc in parsed_calls:
                        yield ReActEvent(
                            event_type="tool_call",
                            step=step,
                            data={"tool_name": pc["name"], "arguments": pc["arguments"]},
                        )
                        tool_start = time.monotonic()
                        tool_result = await self._execute_tool(
                            pc["name"], pc["arguments"], tools
                        )
                        tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
                        trajectory.append(
                            ReActStep(
                                step=step,
                                action="tool_call",
                                tool_name=pc["name"],
                                arguments=pc["arguments"],
                                result=tool_result,
                                tokens=step_tokens,
                            )
                        )
                        # Record the tool call step
                        if trace_recorder is not None:
                            tool_error = None
                            if isinstance(tool_result, dict) and "error" in tool_result:
                                tool_error = tool_result["error"]
                            trace_recorder.record_step(
                                step=step,
                                action="tool_call",
                                tool_name=pc["name"],
                                input_data=pc["arguments"],
                                output_data=tool_result,
                                duration_ms=tool_duration_ms,
                                tokens_used=0,
                                error=tool_error,
                            )
                        yield ReActEvent(
                            event_type="tool_result",
                            step=step,
                            data={"tool_name": pc["name"], "result": tool_result},
                        )
                        # Text-parsed calls may carry no id; a per-step
                        # placeholder id is used in that case.
                        tool_msg = await self._build_tool_result_message(
                            pc.get("id", f"text_tc_{step}"),
                            tool_result,
                            effective_compressor,
                            pc["name"],
                        )
                        conversation.append(tool_msg)

                    # Incremental compression: compress conversation if it's getting long
                    if self._should_compress(conversation, effective_compressor):
                        try:
                            conversation = await effective_compressor.compress(conversation)
                        except Exception as e:
                            logger.warning(f"Incremental compression failed: {e}")
                else:
                    # ponytail: check for a malformed tool call (content has
                    # <tool_use> but parsing failed). If so, inject a
                    # correction message so the model retries, instead of
                    # leaking the raw XML as the final answer.
                    if "<tool_use>" in (response.content or ""):
                        logger.warning(
                            f"Step {step}: content contains <tool_use> but "
                            f"parsing failed — injecting correction (stream)"
                        )
                        conversation.append(
                            {"role": "assistant", "content": response.content}
                        )
                        conversation.append(
                            {
                                "role": "user",
                                "content": (
                                    "你上一次的工具调用格式有误,无法解析。"
                                    "请使用正确的格式重新调用工具:\n"
                                    '<tool_use>\n'
                                    '{"name": "工具名", "arguments": {"参数名": "参数值"}}\n'
                                    "</tool_use>\n"
                                    "确保 JSON 完整且不要混入其他标签。"
                                ),
                            }
                        )
                        yield ReActEvent(
                            event_type="step",
                            step=step,
                            data={"message": "工具调用格式异常,已注入纠正消息"},
                        )
                        continue

                    # Final answer
                    react_step = ReActStep(
                        step=step,
                        action="final_answer",
                        content=response.content,
                        tokens=step_tokens,
                    )
                    trajectory.append(react_step)
                    output = response.content or ""

                    # Record the final answer step
                    if trace_recorder is not None:
                        trace_recorder.record_step(
                            step=step,
                            action="final_answer",
                            output_data={"content": response.content},
                            duration_ms=llm_duration_ms,
                            tokens_used=step_tokens,
                        )

                    yield ReActEvent(
                        event_type="final_answer",
                        step=step,
                        data={
                            "output": output,
                            "total_steps": len(trajectory),
                            "total_tokens": total_tokens,
                        },
                    )
                    break

        # Verification: when enabled, run the tests after the final answer
        if self._verification_enabled and output:
            try:
                from agentkit.core.verification_loop import VerificationLoop

                vloop = VerificationLoop(commands=self._verification_commands)
                vresult = await vloop.verify()
                if not vresult.passed:
                    # Only failures are surfaced; a passing verification emits
                    # no event and adds nothing to the trajectory.
                    verification_step = ReActStep(
                        step=step + 1,
                        action="tool_call",
                        tool_name="verification",
                        arguments={"commands": self._verification_commands},
                        result={
                            "passed": vresult.passed,
                            "errors": vresult.errors,
                            "test_output": vresult.test_output,
                        },
                        content=(f"Verification failed:\n{vresult.test_output[:2000]}"),
                    )
                    trajectory.append(verification_step)
                    yield ReActEvent(
                        event_type="tool_result",
                        step=step + 1,
                        data={
                            "tool_name": "verification",
                            "result": {
                                "passed": vresult.passed,
                                "errors": vresult.errors,
                                "test_output": vresult.test_output,
                            },
                        },
                    )
                    logger.info(
                        "Verification failed after final answer, "
                        "appended feedback to trajectory"
                    )
            except Exception as e:
                logger.warning(f"Verification loop failed: {e}")

        # Step budget exhausted without a final answer: fall back to the last
        # trajectory entry's content, then its result, then the last response.
        # NOTE(review): ``response`` is only bound inside the loop; if the
        # loop body never ran this would raise — confirm max_steps >= 1.
        if step >= self._max_steps and not output:
            trace_outcome = "partial"
            if trajectory and trajectory[-1].content:
                output = trajectory[-1].content
            elif trajectory and trajectory[-1].result is not None:
                output = str(trajectory[-1].result)
            else:
                output = response.content or ""

            yield ReActEvent(
                event_type="final_answer",
                step=step,
                data={
                    "output": output,
                    "total_steps": len(trajectory),
                    "total_tokens": total_tokens,
                    "max_steps_reached": True,
                },
            )

        # Last resort: make sure output is never an empty string.
        # NOTE(review): whether ``trace_outcome = "empty_fallback"`` belongs
        # to the inner ``else`` or to the outer ``if`` cannot be recovered
        # from the de-indented source; the outer level is assumed — confirm.
        if not output or not output.strip():
            from agentkit.core.fallback import EMPTY_LLM_RESPONSE, MAX_STEPS_REACHED

            if step >= self._max_steps:
                output = MAX_STEPS_REACHED
            else:
                output = EMPTY_LLM_RESPONSE
            trace_outcome = "empty_fallback"
            yield ReActEvent(
                event_type="final_answer",
                step=step,
                data={
                    "output": output,
                    "total_steps": len(trajectory),
                    "total_tokens": total_tokens,
                    "empty_fallback": True,
                },
            )
    finally:
        # End trace recording — always runs even if consumer doesn't fully iterate
        if trace_recorder is not None:
            trace_recorder.end_trace(outcome=trace_outcome)

        # Telemetry: end span and record duration — always runs
        _duration_ms = int((time.monotonic() - _exec_start) * 1000)
        if _span is not None:
            _span.set_attribute("agent.total_steps", len(trajectory))
            _span.set_attribute("agent.total_tokens", total_tokens)
            _span.set_attribute("agent.outcome", trace_outcome)
            _span.set_attribute("agent.duration_ms", _duration_ms)
        if _span_cm is not None:
            # Exited without exception info, so the span itself is not told
            # about a failure even when one occurred.
            _span_cm.__exit__(None, None, None)
        agent_duration_histogram().record(_duration_ms, {"agent.name": agent_name})

    # Memory storage: after execution, write a trajectory summary to EpisodicMemory.
    # NOTE(review): assumed to sit after the try/finally (normal completion
    # only), not inside ``finally`` — the source indentation was lost; confirm.
    if memory_retriever and hasattr(memory_retriever, "store_episode"):
        try:
            summary = output[:500] if output else ""
            await memory_retriever.store_episode(
                key=f"task:{task_id or 'unknown'}",
                value={"output_summary": summary, "agent_name": agent_name},
                metadata={"task_type": task_type, "outcome": trace_outcome},
            )
        except Exception as e:
            logger.warning(f"Failed to store task result in episodic memory: {e}")
|
||
|
||
def _build_tool_schemas(self, tools: list[Tool]) -> list[dict]:
|
||
"""将 Tool 对象转换为 OpenAI Function Calling schema 格式"""
|
||
schemas = []
|
||
for tool in tools:
|
||
schema = {
|
||
"type": "function",
|
||
"function": {
|
||
"name": tool.name,
|
||
"description": tool.description,
|
||
"parameters": tool.input_schema or {"type": "object", "properties": {}},
|
||
},
|
||
}
|
||
schemas.append(schema)
|
||
return schemas
|
||
|
||
def _build_system_message(
|
||
self,
|
||
stable: str,
|
||
volatile: str,
|
||
*,
|
||
model: str,
|
||
) -> str | list[dict[str, Any]] | None:
|
||
"""构建双块结构 system message(stable + volatile)。
|
||
|
||
- prompt_cache_enable=False 或无 stable+volatile → 返回 str(或 None)
|
||
- Anthropic provider → 返回 content blocks 列表,stable 块带 cache_control
|
||
- 其他 provider → 返回字符串拼接(stable + volatile),依赖 stable 前缀命中自动前缀缓存
|
||
|
||
ponytail: 断点数硬编码为 1(stable 层),不暴露配置(YAGNI — 双块结构 >1 无语义)。
|
||
"""
|
||
if not stable and not volatile:
|
||
return None
|
||
if not self._prompt_cache_enable:
|
||
# 退化为字符串拼接(向后兼容,行为同改动前)
|
||
if stable and volatile:
|
||
return f"{stable}\n\n## 参考信息\n{volatile}"
|
||
if volatile:
|
||
return f"## 参考信息\n{volatile}"
|
||
return stable
|
||
|
||
provider_name = self._get_provider_name(model)
|
||
if provider_name == "anthropic":
|
||
blocks: list[dict[str, Any]] = []
|
||
if stable:
|
||
blocks.append({
|
||
"type": "text",
|
||
"text": stable,
|
||
"cache_control": {"type": "ephemeral"},
|
||
})
|
||
if volatile:
|
||
blocks.append({
|
||
"type": "text",
|
||
"text": f"## 参考信息\n{volatile}",
|
||
})
|
||
return blocks if blocks else None
|
||
|
||
# 非 Anthropic:字符串拼接,stable 前缀命中 OpenAI/DashScope 自动前缀缓存
|
||
if stable and volatile:
|
||
return f"{stable}\n\n## 参考信息\n{volatile}"
|
||
if volatile:
|
||
return f"## 参考信息\n{volatile}"
|
||
return stable
|
||
|
||
def _get_provider_name(self, model: str) -> str | None:
|
||
"""通过 gateway 查询 model 对应的 provider 名。失败回退 None(字符串拼接)。"""
|
||
try:
|
||
return self._llm_gateway.get_provider_name_for_model(model)
|
||
except Exception:
|
||
# ponytail: 测试中 gateway 可能是 MagicMock,无该方法;回退保守路径
|
||
return None
|
||
|
||
def _build_tool_use_prompt(self, tools: list[Tool]) -> str:
|
||
"""Build prompt-based tool calling instructions with tiered injection.
|
||
|
||
Core tools (defined by ``self._core_tool_names`` or
|
||
:attr:`_DEFAULT_CORE_TOOLS`) get full descriptions (name +
|
||
description + parameters). Extended tools get only name + a
|
||
one-line description. When ``tool_search`` is present alongside
|
||
extended tools, a hint is added telling the LLM to call
|
||
``tool_search`` for full parameter details.
|
||
|
||
Instructs the LLM to use ``<tool_use>`` XML format for tool
|
||
invocation (Hermes pattern: model-agnostic prompt-based tool calling).
|
||
"""
|
||
core_names = set(self._core_tool_names or self._DEFAULT_CORE_TOOLS)
|
||
core_tools = [t for t in tools if t.name in core_names]
|
||
extended_tools = [t for t in tools if t.name not in core_names]
|
||
|
||
sections: list[str] = []
|
||
if core_tools:
|
||
sections.append(self._render_core_tools(core_tools))
|
||
if extended_tools:
|
||
sections.append(self._render_extended_tools(extended_tools))
|
||
|
||
tools_text = "\n\n".join(sections)
|
||
|
||
has_tool_search = any(t.name == "tool_search" for t in tools)
|
||
search_hint = ""
|
||
if has_tool_search and extended_tools:
|
||
search_hint = (
|
||
"\n\n注意:上方「扩展工具」仅显示名称和简短描述。"
|
||
'如需使用某个扩展工具,请先调用 tool_search(query="关键词") '
|
||
"获取其完整参数说明。"
|
||
)
|
||
|
||
return (
|
||
"## 可用工具\n\n"
|
||
"你可以使用以下工具来完成任务。当需要调用工具时,使用以下格式:\n\n"
|
||
"<tool_use>\n"
|
||
'{"name": "工具名", "arguments": {"参数名": "参数值"}}\n'
|
||
"</tool_use>\n\n"
|
||
"重要规则:\n"
|
||
"1. 每次只调用一个工具\n"
|
||
"2. 等待工具返回结果后再决定下一步\n"
|
||
"3. 如果不需要工具就能回答,直接回答即可\n"
|
||
"4. 不要在回答中重复工具的输出,而是基于结果给出有用的总结\n\n"
|
||
f"工具列表:\n\n{tools_text}{search_hint}"
|
||
)
|
||
|
||
@staticmethod
|
||
def _render_core_tools(tools: list[Tool]) -> str:
|
||
"""Render core tools with full descriptions (name + description + parameters)."""
|
||
descriptions: list[str] = []
|
||
for tool in tools:
|
||
params_desc = ""
|
||
if tool.input_schema:
|
||
props = tool.input_schema.get("properties", {})
|
||
required = tool.input_schema.get("required", [])
|
||
param_parts: list[str] = []
|
||
for pname, pinfo in props.items():
|
||
ptype = pinfo.get("type", "string")
|
||
pdesc = pinfo.get("description", "")
|
||
req_flag = " (required)" if pname in required else ""
|
||
param_parts.append(f" - {pname}: {ptype}{req_flag} — {pdesc}")
|
||
if param_parts:
|
||
params_desc = "\n".join(param_parts)
|
||
descriptions.append(f"- {tool.name}: {tool.description}\n{params_desc}")
|
||
return "### 核心工具(完整描述)\n\n" + "\n\n".join(descriptions)
|
||
|
||
@staticmethod
|
||
def _render_extended_tools(tools: list[Tool]) -> str:
|
||
"""Render extended tools with name + one-line description only."""
|
||
lines: list[str] = []
|
||
for tool in tools:
|
||
desc = tool.description.strip().split("\n")[0]
|
||
if len(desc) > 100:
|
||
desc = desc[:97] + "..."
|
||
lines.append(f"- {tool.name}: {desc}")
|
||
return "### 扩展工具(仅名称和简短描述,使用 tool_search 获取详情)\n\n" + "\n".join(lines)
|
||
|
||
def _maybe_add_tool_search(self, tools: list[Tool]) -> list[Tool]:
|
||
"""Add ``tool_search`` tool if enabled and there are extended tools.
|
||
|
||
Builds a :class:`ToolSearchIndex` from the extended tools so the
|
||
LLM can discover full tool descriptions on demand via BM25 search.
|
||
If all tools are core tools, or ``tool_search`` is already present,
|
||
or ``enable_tool_search`` is False, the list is returned unchanged.
|
||
"""
|
||
if not self._enable_tool_search:
|
||
return tools
|
||
if any(t.name == "tool_search" for t in tools):
|
||
return tools
|
||
|
||
core_names = set(self._core_tool_names or self._DEFAULT_CORE_TOOLS)
|
||
extended_tools = [t for t in tools if t.name not in core_names]
|
||
if not extended_tools:
|
||
return tools
|
||
|
||
from agentkit.tools.builtin import ToolSearchTool
|
||
from agentkit.tools.search import ToolSearchIndex
|
||
|
||
index = ToolSearchIndex(extended_tools)
|
||
search_tool = ToolSearchTool(search_index=index)
|
||
return tools + [search_tool]
|
||
|
||
@staticmethod
def _build_response_from_stream(
    content: str,
    tool_calls: list[Any],
    usage: Any,
    model: str,
) -> LLMResponse:
    """Wrap values accumulated from a stream into a single ``LLMResponse``.

    Args:
        content: Concatenated text of all streamed chunks.
        tool_calls: Tool calls reported by the stream.
        usage: Token usage reported by the stream; when ``None`` a
            default-constructed ``TokenUsage`` is substituted.
        model: Model name to record on the response.

    Returns:
        The assembled ``LLMResponse``.
    """
    from agentkit.llm.protocol import LLMResponse, TokenUsage

    effective_usage = TokenUsage() if usage is None else usage
    return LLMResponse(
        content=content,
        tool_calls=tool_calls,
        usage=effective_usage,
        model=model,
    )
|
||
|
||
def _find_tool(self, name: str, tools: list[Tool]) -> Tool | None:
|
||
"""根据名称从可用工具中查找工具"""
|
||
for tool in tools:
|
||
if tool.name == name:
|
||
return tool
|
||
return None
|
||
|
||
# Default token threshold for incremental compression
|
||
_DEFAULT_COMPRESS_THRESHOLD = 8000
|
||
|
||
def _should_compress(
|
||
self, conversation: list[dict], compressor: "CompressionStrategy | None"
|
||
) -> bool:
|
||
"""检查是否需要增量压缩"""
|
||
if not compressor:
|
||
return False
|
||
# U3: Skip if compressor reports unavailable
|
||
is_available_fn = getattr(compressor, "is_available", None)
|
||
if is_available_fn is not None and not is_available_fn():
|
||
return False
|
||
# U3: Delegate to compressor's headroom-based should_compress if available
|
||
should_compress_fn = getattr(compressor, "should_compress", None)
|
||
if should_compress_fn is not None:
|
||
return should_compress_fn(conversation)
|
||
# Fallback: fixed threshold for compressors without headroom support
|
||
total_chars = sum(len(str(m.get("content", ""))) for m in conversation)
|
||
estimated_tokens = total_chars // 4
|
||
return estimated_tokens > self._DEFAULT_COMPRESS_THRESHOLD
|
||
|
||
async def _build_tool_result_message(
|
||
self,
|
||
tool_call_id: str,
|
||
result: Any,
|
||
compressor: "CompressionStrategy | None" = None,
|
||
tool_name: str | None = None,
|
||
) -> dict:
|
||
"""构建工具结果消息用于对话历史"""
|
||
content = str(result)
|
||
if compressor and tool_name:
|
||
try:
|
||
content = await compressor.compress_tool_result(tool_name, result)
|
||
except Exception as e:
|
||
logger.warning(f"Tool result compression failed for '{tool_name}': {e}")
|
||
content = str(result)
|
||
return {
|
||
"role": "tool",
|
||
"tool_call_id": tool_call_id,
|
||
"content": content,
|
||
}
|
||
|
||
async def _execute_tool(
|
||
self, tool_name: str, arguments: dict[str, Any], tools: list[Tool]
|
||
) -> dict:
|
||
"""执行工具调用,处理成功和失败情况"""
|
||
tool = self._find_tool(tool_name, tools)
|
||
if tool is None:
|
||
error_msg = f"Tool '{tool_name}' not found"
|
||
logger.warning(error_msg)
|
||
return {"error": error_msg}
|
||
|
||
# Strip internal metadata keys before passing to tool
|
||
clean_args = {k: v for k, v in arguments.items() if not k.startswith("_")}
|
||
|
||
try:
|
||
result = await tool.safe_execute(**clean_args)
|
||
return result
|
||
except ToolValidationError as e:
|
||
# 保留类型化错误码,不被通用 except 平坦化为字符串
|
||
error_msg = f"Tool '{tool_name}' schema validation failed: {e}"
|
||
logger.warning(error_msg)
|
||
return {
|
||
"error": str(e),
|
||
"error_code": e.error_code,
|
||
"details": e.details,
|
||
}
|
||
except Exception as e:
|
||
error_msg = f"Tool '{tool_name}' execution failed: {e}"
|
||
logger.warning(error_msg)
|
||
return {"error": error_msg}
|
||
|
||
async def _execute_tool_with_confirmation(
    self,
    tc: Any,
    tools: list[Tool],
    step: int,
    confirmation_handler: Any,
) -> tuple[Any, list[ReActEvent]]:
    """Execute a tool call with confirmation flow support.

    Used in the parallel execution path for serial (non-parallelizable) tools
    that may require user confirmation before execution.

    The tool is first run through ``self._execute_tool``. If its result is a
    dict with a truthy ``needs_confirmation`` entry, a ``confirmation_request``
    event is recorded and ``confirmation_handler`` is awaited with
    ``(confirmation_id, command, reason)``. On approval the tool is re-run
    with ``_skip_dangerous_check=True``; otherwise a ``permission_denied``
    result is produced. A handler that is ``None`` or that raises counts as a
    rejection.

    Args:
        tc: Tool call object; ``tc.name`` and ``tc.arguments`` are read.
        tools: Tools available for lookup.
        step: Step number stamped on the emitted events.
        confirmation_handler: Awaitable callable returning a truthy value to
            approve, or ``None``.

    Returns:
        Tuple of (tool_result, list of ReActEvents to yield)
    """
    events: list[ReActEvent] = []
    tool_result = await self._execute_tool(tc.name, tc.arguments, tools)

    # Nothing more to do unless the tool asked for confirmation.
    if not (isinstance(tool_result, dict) and tool_result.get("needs_confirmation")):
        return tool_result, events

    confirmation_id = tool_result["confirmation_id"]
    command = tool_result.get("command", "")
    reason = tool_result.get("reason", "")

    events.append(
        ReActEvent(
            event_type="confirmation_request",
            step=step,
            data={
                "confirmation_id": confirmation_id,
                "tool_name": tc.name,
                "command": command,
                "reason": reason,
            },
        )
    )

    # Wait for the user's decision.
    approved = False
    if confirmation_handler is not None:
        try:
            approved = await confirmation_handler(confirmation_id, command, reason)
        except Exception as e:
            logger.warning(f"Confirmation handler error: {e}")

    if approved:
        # Re-execute directly on the tool: ``_execute_tool`` would strip the
        # ``_skip_dangerous_check`` flag together with the other ``_`` keys.
        tool = self._find_tool(tc.name, tools)
        if tool is None:
            tool_result = {"error": f"Tool '{tc.name}' not found"}
        else:
            clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")}
            clean_args["_skip_dangerous_check"] = True
            try:
                tool_result = await tool.safe_execute(**clean_args)
            except ToolValidationError as e:
                # Same typed error shape as ``_execute_tool`` returns.
                tool_result = {
                    "error": str(e),
                    "error_code": e.error_code,
                    "details": e.details,
                }
            except Exception as e:
                tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"}
    else:
        # User rejected (or no decision could be obtained).
        tool_result = {
            "output": "",
            "exit_code": 126,
            "is_error": True,
            "error_type": "permission_denied",
            "message": f"用户拒绝执行命令: {command[:100]}",
        }

    events.append(
        ReActEvent(
            event_type="confirmation_result",
            step=step,
            data={"confirmation_id": confirmation_id, "approved": bool(approved)},
        )
    )
    return tool_result, events
|
||
|
||
def _should_execute_parallel(self, tool_calls: list[Any]) -> bool:
|
||
"""Determine if tool calls should be executed in parallel.
|
||
|
||
- parallel_tools=True: always parallel (if >1 tool)
|
||
- parallel_tools=False: never parallel
|
||
- parallel_tools="auto": parallel if any tool_call has _parallelizable=true in arguments
|
||
"""
|
||
if len(tool_calls) <= 1:
|
||
return False
|
||
if self._parallel_tools is True:
|
||
return True
|
||
if self._parallel_tools is False:
|
||
return False
|
||
# "auto" mode: check _parallelizable metadata in tool call arguments
|
||
if self._parallel_tools == "auto":
|
||
parallelizable_indices = self._get_parallelizable_indices(tool_calls)
|
||
return len(parallelizable_indices) > 1
|
||
return False
|
||
|
||
def _get_parallelizable_indices(self, tool_calls: list[Any]) -> list[int]:
|
||
"""Get indices of tool_calls that have _parallelizable=true in arguments.
|
||
|
||
LLM marks parallelizable tools by including _parallelizable: true
|
||
in the tool_call arguments.
|
||
"""
|
||
indices = []
|
||
for i, tc in enumerate(tool_calls):
|
||
args = tc.arguments if hasattr(tc, "arguments") else {}
|
||
if isinstance(args, dict) and args.get("_parallelizable") is True:
|
||
indices.append(i)
|
||
return indices
|
||
|
||
def _parse_text_tool_calls(self, content: str) -> list[dict[str, Any]]:
|
||
"""从文本中解析工具调用模式
|
||
|
||
支持格式:
|
||
1. Action: tool_name(args)
|
||
2. ```tool\n{"name": "...", "arguments": {...}}\n```
|
||
3. <tool_use>\n{"name": "...", "arguments": {...}}\n</tool_use>
|
||
"""
|
||
calls: list[dict[str, Any]] = []
|
||
|
||
# 格式 1: Action: tool_name(args)
|
||
action_pattern = re.compile(r"Action:\s*(\w+)\((.+?)\)", re.DOTALL)
|
||
for match in action_pattern.finditer(content):
|
||
name = match.group(1)
|
||
args_str = match.group(2)
|
||
try:
|
||
arguments = json.loads(args_str)
|
||
except (json.JSONDecodeError, TypeError):
|
||
arguments = {"raw_input": args_str}
|
||
calls.append({"name": name, "arguments": arguments})
|
||
|
||
if calls:
|
||
return calls
|
||
|
||
# 格式 2: ```tool\n{"name": "...", "arguments": {...}}\n```
|
||
code_block_pattern = re.compile(r"```tool\s*\n(.*?)\n\s*```", re.DOTALL)
|
||
for match in code_block_pattern.finditer(content):
|
||
json_str = match.group(1).strip()
|
||
try:
|
||
parsed = json.loads(json_str)
|
||
name = parsed.get("name", "")
|
||
arguments = parsed.get("arguments", {})
|
||
if name:
|
||
calls.append({"name": name, "arguments": arguments})
|
||
except (json.JSONDecodeError, TypeError):
|
||
logger.warning(f"Failed to parse tool call from text: {json_str}")
|
||
|
||
if calls:
|
||
return calls
|
||
|
||
# 格式 3: <tool_use>\n{"name": "...", "arguments": {...}}\n</tool_use>
|
||
# 兼容 Anthropic/Qwen 等模型在文本中模拟的工具调用格式
|
||
tool_use_pattern = re.compile(r"<tool_use>\s*(.*?)\s*</tool_use>", re.DOTALL)
|
||
for match in tool_use_pattern.finditer(content):
|
||
json_str = match.group(1).strip()
|
||
try:
|
||
parsed = json.loads(json_str)
|
||
name = parsed.get("name", "")
|
||
arguments = parsed.get("arguments", {})
|
||
if name:
|
||
calls.append({"name": name, "arguments": arguments})
|
||
except (json.JSONDecodeError, TypeError):
|
||
# Try XML-like inner tags: <name>x</name><arguments>{...}</arguments>
|
||
name_match = re.search(r"<name>\s*(.*?)\s*</name>", json_str, re.DOTALL)
|
||
args_match = re.search(r"<arguments>\s*(.*?)\s*</arguments>", json_str, re.DOTALL)
|
||
if name_match:
|
||
name = name_match.group(1).strip()
|
||
args_str = args_match.group(1).strip() if args_match else "{}"
|
||
try:
|
||
arguments = json.loads(args_str)
|
||
except (json.JSONDecodeError, TypeError):
|
||
arguments = {"raw": args_str}
|
||
calls.append({"name": name, "arguments": arguments})
|
||
else:
|
||
logger.warning(f"Failed to parse tool_use block: {json_str[:200]}")
|
||
|
||
if calls:
|
||
return calls
|
||
|
||
# 格式 4: 畸形 <tool_use> — 缺少闭合标签或 JSON 被截断/混入杂标签
|
||
# 兜底解析:从 <tool_use> 后提取 JSON 片段,用大括号匹配法恢复完整 JSON
|
||
open_pattern = re.compile(r"<tool_use>\s*", re.IGNORECASE)
|
||
for match in open_pattern.finditer(content):
|
||
remainder = content[match.end():]
|
||
parsed = self._extract_tool_call_from_malformed(remainder)
|
||
if parsed:
|
||
calls.append(parsed)
|
||
|
||
return calls
|
||
|
||
@staticmethod
|
||
def _extract_tool_call_from_malformed(text: str) -> dict[str, Any] | None:
|
||
"""从畸形文本中尝试提取工具调用。
|
||
|
||
处理场景:
|
||
1. JSON 被截断(缺少闭合大括号)
|
||
2. JSON 中混入 <parameter> 等 XML 标签
|
||
3. 完全无法解析时返回 None
|
||
"""
|
||
# 尝试用大括号匹配提取第一个 JSON 对象
|
||
brace_start = text.find("{")
|
||
if brace_start == -1:
|
||
return None
|
||
|
||
depth = 0
|
||
json_end = -1
|
||
in_string = False
|
||
escape = False
|
||
for i in range(brace_start, len(text)):
|
||
ch = text[i]
|
||
if escape:
|
||
escape = False
|
||
continue
|
||
if ch == "\\":
|
||
escape = True
|
||
continue
|
||
if ch == '"':
|
||
in_string = not in_string
|
||
continue
|
||
if in_string:
|
||
continue
|
||
if ch == "{":
|
||
depth += 1
|
||
elif ch == "}":
|
||
depth -= 1
|
||
if depth == 0:
|
||
json_end = i + 1
|
||
break
|
||
|
||
if json_end == -1:
|
||
# JSON 被截断 — 尝试补全大括号后解析
|
||
json_str = text[brace_start:].strip()
|
||
# 截断掉非 JSON 尾部(如 </parameter>, <function> 等)
|
||
cut = json_str.find("}")
|
||
if cut != -1:
|
||
json_str = json_str[: cut + 1]
|
||
else:
|
||
# 补全缺失的大括号
|
||
open_braces = json_str.count("{") - json_str.count("}")
|
||
json_str = json_str + "}" * max(open_braces, 0)
|
||
else:
|
||
json_str = text[brace_start:json_end]
|
||
|
||
try:
|
||
parsed = json.loads(json_str)
|
||
name = parsed.get("name", "")
|
||
arguments = parsed.get("arguments", {})
|
||
if name:
|
||
return {"name": name, "arguments": arguments}
|
||
except (json.JSONDecodeError, TypeError):
|
||
pass
|
||
|
||
# 最终兜底:用正则提取 name 和已知的参数字段
|
||
name_match = re.search(r'"name"\s*:\s*"([^"]+)"', text)
|
||
if not name_match:
|
||
return None
|
||
name = name_match.group(1)
|
||
|
||
arguments: dict[str, Any] = {}
|
||
# 提取 "key": "value" 模式
|
||
for kv_match in re.finditer(r'"(\w+)"\s*:\s*"([^"]*)"', text):
|
||
key = kv_match.group(1)
|
||
if key in ("name",):
|
||
continue
|
||
arguments[key] = kv_match.group(2)
|
||
|
||
# 提取 <parameter=key>value</parameter> 模式
|
||
for pm in re.finditer(r"<parameter=(\w+)>\s*(.*?)\s*</parameter>", text, re.DOTALL):
|
||
arguments[pm.group(1)] = pm.group(2).strip()
|
||
|
||
if name:
|
||
return {"name": name, "arguments": arguments}
|
||
|
||
return None
|