From e61f98898f82994895c53899a2a2e79640f9f76f Mon Sep 17 00:00:00 2001 From: chiguyong Date: Tue, 30 Jun 2026 16:07:00 +0800 Subject: [PATCH] refactor(core): unify ReActEngine execute/execute_stream via async generator (U1) - Convert _execute_loop to async generator yielding ReActEvent; both execute and execute_stream delegate to it, eliminating ~760 lines of duplicated loop logic (execute_stream 813 -> 53 lines). - Add 'final_result' event_type carrying ReActResult; execute extracts result from final event, execute_stream forwards events (backward-compatible 'final_answer' retained). - Unify _drain_phase_violations across both paths. - Add 14 golden-trajectory characterization tests. - Fix test_execute_stream_with_compressor mock gateway (chat_stream test-infra gap). 130 react tests pass, 762 core+experts pass, no regressions. --- src/agentkit/core/react.py | 1102 ++++++-------------- tests/unit/test_react_compression.py | 33 +- tests/unit/test_react_golden_trajectory.py | 617 +++++++++++ 3 files changed, 947 insertions(+), 805 deletions(-) create mode 100644 tests/unit/test_react_golden_trajectory.py diff --git a/src/agentkit/core/react.py b/src/agentkit/core/react.py index 8716df9..8d84faa 100644 --- a/src/agentkit/core/react.py +++ b/src/agentkit/core/react.py @@ -10,6 +10,7 @@ import logging import re import time from collections import Counter, deque +from collections.abc import AsyncGenerator from dataclasses import dataclass, field from datetime import datetime, timezone from typing import TYPE_CHECKING, Any @@ -130,7 +131,7 @@ class ReActResult: class ReActEvent: """ReAct 执行事件""" - event_type: str # "thinking", "token", "tool_call", "tool_result", "confirmation_request", "final_answer", "error" + event_type: str # "thinking","token","tool_call","tool_result","confirmation_request","confirmation_result","phase_violation","step","final_answer","final_result","error" step: int data: dict[str, Any] = field(default_factory=dict) timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) @@ -421,9 +422,10 @@ class ReActEngine: compressor: 压缩策略,None 时使用实例默认压缩器 cancellation_token: 协作式取消令牌,每次循环迭代检查是否已取消 timeout_seconds: 超时秒数,0 表示无超时,None 使用 default_timeout + + U1: execute() 现在通过 _run_loop_and_extract 收集 _execute_loop async + generator 产出的事件,并从最后的 'final_result' 事件提取 ReActResult。 """ - # P2 #9: Reset loop detection state so reuse across conversations is clean - self.reset() effective_compressor = compressor if compressor is not None else self._compressor effective_timeout = ( timeout_seconds if timeout_seconds is not None else self._default_timeout @@ -446,7 +448,7 @@ class ReActEngine: ) async def _handler(c: RequestContext) -> ReActResult: - return await self._execute_loop( + return await self._run_loop_and_extract( messages=c.messages, tools=c.tools or None, model=c.model, @@ -460,6 +462,8 @@ class ReActEngine: retrieval_config=retrieval_config, cancellation_token=cancellation_token, confirmation_handler=confirmation_handler, + stream=False, + effective_timeout=effective_timeout, ) try: @@ -483,7 +487,7 @@ class ReActEngine: try: if effective_timeout > 0: result = await asyncio.wait_for( - self._execute_loop( + self._run_loop_and_extract( messages=messages, tools=tools, model=model, @@ -497,11 +501,13 @@ class ReActEngine: retrieval_config=retrieval_config, cancellation_token=cancellation_token, confirmation_handler=confirmation_handler, + stream=False, + effective_timeout=effective_timeout, ), timeout=effective_timeout, ) else: - result = await self._execute_loop( + result = await self._run_loop_and_extract( messages=messages, tools=tools, model=model, @@ -515,6 +521,8 @@ class ReActEngine: retrieval_config=retrieval_config, cancellation_token=cancellation_token, confirmation_handler=confirmation_handler, + stream=False, + effective_timeout=effective_timeout, ) except asyncio.TimeoutError: raise TaskTimeoutError( @@ -526,6 +534,24 @@ class ReActEngine: return result + async def _run_loop_and_extract( + self, + **kwargs: Any, + ) -> ReActResult: + """Collect all events from _execute_loop and extract the final ReActResult. + + This is the bridge between the async generator _execute_loop and the + coroutine-based execute() method. It fully iterates the generator and + extracts the ReActResult from the final 'final_result' event. + """ + final_result: ReActResult | None = None + async for event in self._execute_loop(**kwargs): + if event.event_type == "final_result": + final_result = event.data["result"] + if final_result is None: + raise RuntimeError("_execute_loop did not yield a final_result event") + return final_result + async def _execute_loop( self, messages: list[dict[str, str]], @@ -541,7 +567,30 @@ class ReActEngine: retrieval_config: dict[str, Any] | None = None, cancellation_token: CancellationToken | None = None, confirmation_handler: Any | None = None, - ) -> ReActResult: + stream: bool = False, + effective_timeout: float = 0.0, + ) -> AsyncGenerator[ReActEvent, None]: + """Unified ReAct loop — async generator yielding ReActEvent objects. + + When stream=False: uses gateway.chat() (non-streaming), no token events. + When stream=True: uses gateway.chat_stream() (streaming), yields token + events, checks timeout inside the loop. + + Always yields a 'final_result' event at the end with + data={'result': ReActResult}. Callers that need the ReActResult + (execute) collect all events and extract the final_result. Callers + that need streaming (execute_stream) transparently pass through + all events. + + Args: + compressor: 压缩策略(caller 负责 computing effective_compressor) + cancellation_token: 协作式取消令牌 + stream: True 用 chat_stream(流式),False 用 chat(非流式) + effective_timeout: 超时秒数;stream=True 时在循环内检查, + stream=False 时由 caller 的 asyncio.wait_for 强制 + """ + # P2 #9: Reset loop detection state so reuse across conversations is clean + self.reset() tools = tools or [] if tools: tools = self._maybe_add_tool_search(tools) @@ -573,7 +622,7 @@ class ReActEngine: if _OTEL_AVAILABLE: _span_cm = start_span( - "agent.execute", + "agent.execute_stream" if stream else "agent.execute", attributes={"agent.name": agent_name, "agent.type": task_type or "react"}, ) _span = _span_cm.__enter__() @@ -582,6 +631,9 @@ class ReActEngine: trajectory: list[ReActStep] = [] total_tokens = 0 trace_outcome = "error" + output = "" + step = 0 + response: LLMResponse | None = None try: # 启动轨迹记录 @@ -593,7 +645,6 @@ class ReActEngine: ) # Memory retrieval: 执行前检索相关上下文,作为 volatile 层注入 system message - # U2/G2: 不再拼到 stable(system_prompt)末尾,改由 _build_system_message 组装双块结构 memory_context = "" if memory_retriever: try: @@ -634,10 +685,9 @@ class ReActEngine: ) trace_outcome = "success" - step = 0 - output = "" # U4/G1: verify 失败回灌计数器。受 max_steps 上限约束(不无限循环)。 reinjections = 0 + _loop_start = time.monotonic() while step < self._max_steps: step += 1 @@ -647,671 +697,13 @@ class ReActEngine: cancellation_token.check() # U3/G6: phase auto-advance safety net. - # Incremented per step (LLM call), not per tool_call. When - # auto_advance_after_steps is set, advance the phase after - # the LLM has been stuck in the same phase for N steps. if self._phase_policy is not None: self._steps_in_phase += 1 self._maybe_auto_advance() - # Think: 调用 LLM - llm_start = time.monotonic() - response = await self._llm_gateway.chat( - messages=conversation, - model=model, - agent_name=agent_name, - task_type=task_type, - tools=tool_schemas, - ) - llm_duration_ms = int((time.monotonic() - llm_start) * 1000) - - step_tokens = response.usage.total_tokens - total_tokens += step_tokens - - # 检查是否有 Function Calling 的 tool_calls - if response.has_tool_calls: - # 循环检测:检查是否重复调用相同工具+参数 - looped_tool = self._check_tool_loop(response.tool_calls) - if looped_tool is not None: - if not self._loop_corrected: - # 第一次检测:注入纠正消息,给 LLM 改变策略的机会 - logger.warning( - f"Loop detected: tool '{looped_tool}' repeated, " - f"injecting correction at step {step}" - ) - correction_msg = { - "role": "user", - "content": ( - f"You are repeatedly calling tool '{looped_tool}' " - f"with the same arguments. This indicates a loop. " - f"Please change your strategy or provide a final answer." - ), - } - conversation.append(correction_msg) - self._loop_corrected = True - continue - else: - # 第二次检测:纠正后仍未改变,强制中断 - raise LoopDetectedError( - tool_name=looped_tool, - repetitions=self._loop_threshold + 1, - ) - - # 记录 LLM 调用步骤 - if trace_recorder is not None: - trace_recorder.record_step( - step=step, - action="llm_call", - duration_ms=llm_duration_ms, - tokens_used=step_tokens, - ) - - # Act: 执行工具调用 - # 先记录 assistant 消息(含 tool_calls)到对话历史 - assistant_msg: dict[str, Any] = { - "role": "assistant", - "content": response.content or "", - "tool_calls": [ - { - "id": tc.id, - "type": "function", - "function": { - "name": tc.name, - "arguments": json.dumps(tc.arguments), - }, - } - for tc in response.tool_calls - ], - } - conversation.append(assistant_msg) - - # 执行工具调用 - if self._parallel_tools == "auto" and len(response.tool_calls) > 1: - # Auto mode: mixed parallel/serial based on _parallelizable flag - parallelizable_set = set( - self._get_parallelizable_indices(response.tool_calls) - ) - serial_calls = [ - (i, tc) - for i, tc in enumerate(response.tool_calls) - if i not in parallelizable_set - ] - parallel_calls = [ - (i, tc) - for i, tc in enumerate(response.tool_calls) - if i in parallelizable_set - ] - - # Result slots indexed by original position - all_results: list[Any] = [None] * len(response.tool_calls) - - # Execute serial tools first (in order) - for i, tc in serial_calls: - tool_start = time.monotonic() - tool_result = await self._execute_tool(tc.name, tc.arguments, tools) - tool_duration_ms = int((time.monotonic() - tool_start) * 1000) - all_results[i] = (tc, tool_result, tool_duration_ms) - - # Execute parallelizable tools in parallel - if len(parallel_calls) > 1: - para_results = await asyncio.gather( - *[ - self._execute_tool(tc.name, tc.arguments, tools) - for _, tc in parallel_calls - ], - return_exceptions=True, - ) - for j, (i, tc) in enumerate(parallel_calls): - tool_result = para_results[j] - if isinstance(tool_result, Exception): - tool_result = {"error": str(tool_result)} - all_results[i] = (tc, tool_result, 0) - elif len(parallel_calls) == 1: - i, tc = parallel_calls[0] - tool_result = await self._execute_tool(tc.name, tc.arguments, tools) - all_results[i] = (tc, tool_result, 0) - - # Process all results in original order - for i, tc in enumerate(response.tool_calls): - tc_obj, tool_result, tool_duration_ms = all_results[i] - react_step = ReActStep( - step=step, - action="tool_call", - tool_name=tc.name, - arguments=tc.arguments, - result=tool_result, - tokens=step_tokens, - ) - trajectory.append(react_step) - - if trace_recorder is not None: - tool_error = None - if isinstance(tool_result, dict) and "error" in tool_result: - tool_error = tool_result["error"] - trace_recorder.record_step( - step=step, - action="tool_call", - tool_name=tc.name, - input_data=tc.arguments, - output_data=tool_result, - duration_ms=tool_duration_ms, - tokens_used=0, - error=tool_error, - ) - - tool_msg = await self._build_tool_result_message( - tc.id, tool_result, compressor, tc.name - ) - conversation.append(tool_msg) - elif self._should_execute_parallel(response.tool_calls): - # 并行执行多个工具调用 (parallel_tools=True) - tool_results = await asyncio.gather( - *[ - self._execute_tool(tc.name, tc.arguments, tools) - for tc in response.tool_calls - ], - return_exceptions=True, - ) - for idx, tc in enumerate(response.tool_calls): - tool_result = tool_results[idx] - if isinstance(tool_result, Exception): - tool_result = {"error": str(tool_result)} - - react_step = ReActStep( - step=step, - action="tool_call", - tool_name=tc.name, - arguments=tc.arguments, - result=tool_result, - tokens=step_tokens, - ) - trajectory.append(react_step) - - if trace_recorder is not None: - tool_error = None - if isinstance(tool_result, dict) and "error" in tool_result: - tool_error = tool_result["error"] - trace_recorder.record_step( - step=step, - action="tool_call", - tool_name=tc.name, - input_data=tc.arguments, - output_data=tool_result, - duration_ms=0, - tokens_used=0, - error=tool_error, - ) - - tool_msg = await self._build_tool_result_message( - tc.id, tool_result, compressor, tc.name - ) - conversation.append(tool_msg) - else: - # 串行执行(单工具或 parallel_tools=False) - for tc in response.tool_calls: - tool_start = time.monotonic() - tool_result = await self._execute_tool(tc.name, tc.arguments, tools) - - # Handle confirmation flow - if isinstance(tool_result, dict) and tool_result.get( - "needs_confirmation" - ): - confirmation_id = tool_result["confirmation_id"] - command = tool_result.get("command", "") - reason = tool_result.get("reason", "") - - approved = False - if confirmation_handler is not None: - try: - approved = await confirmation_handler( - confirmation_id, command, reason - ) - except Exception as e: - logger.warning(f"Confirmation handler error: {e}") - - if approved: - tool = self._find_tool(tc.name, tools) - if tool and hasattr(tool, "_is_dangerous"): - clean_args = { - k: v - for k, v in tc.arguments.items() - if not k.startswith("_") - } - clean_args["_skip_dangerous_check"] = True - try: - tool_result = await tool.safe_execute(**clean_args) - except Exception as e: - tool_result = { - "error": f"Tool '{tc.name}' execution failed: {e}" - } - else: - # Non-dangerous tool: confirmation was for the overall action, - # re-execute with skip flag to avoid re-triggering confirmation - clean_args = { - k: v - for k, v in tc.arguments.items() - if not k.startswith("_") - } - clean_args["_skip_dangerous_check"] = True - try: - tool_result = ( - await tool.safe_execute(**clean_args) - if tool - else {"error": f"Tool '{tc.name}' not found"} - ) - except Exception as e: - tool_result = { - "error": f"Tool '{tc.name}' execution failed: {e}" - } - else: - tool_result = { - "output": "", - "exit_code": 126, - "is_error": True, - "error_type": "permission_denied", - "message": f"用户拒绝执行命令: {command[:100]}", - } - - tool_duration_ms = int((time.monotonic() - tool_start) * 1000) - - react_step = ReActStep( - step=step, - action="tool_call", - tool_name=tc.name, - arguments=tc.arguments, - result=tool_result, - tokens=step_tokens, - ) - trajectory.append(react_step) - - # 记录工具调用步骤 - if trace_recorder is not None: - tool_error = None - if isinstance(tool_result, dict) and "error" in tool_result: - tool_error = tool_result["error"] - trace_recorder.record_step( - step=step, - action="tool_call", - tool_name=tc.name, - input_data=tc.arguments, - output_data=tool_result, - duration_ms=tool_duration_ms, - tokens_used=0, - error=tool_error, - ) - - # Observe: 将工具结果添加到对话历史 - tool_msg = await self._build_tool_result_message( - tc.id, tool_result, compressor, tc.name - ) - conversation.append(tool_msg) - - # Incremental compression: compress conversation if it's getting long - if self._should_compress(conversation, compressor): - try: - conversation = await compressor.compress(conversation) - except Exception as e: - logger.warning(f"Incremental compression failed: {e}") - - else: - # 检查文本解析模式 - parsed_calls = self._parse_text_tool_calls(response.content or "") - if parsed_calls and tools: - # 记录 LLM 调用步骤 - if trace_recorder is not None: - trace_recorder.record_step( - step=step, - action="llm_call", - duration_ms=llm_duration_ms, - tokens_used=step_tokens, - ) - - # 文本解析模式执行工具 - conversation.append({"role": "assistant", "content": response.content}) - - for pc in parsed_calls: - tool_start = time.monotonic() - tool_result = await self._execute_tool( - pc["name"], pc["arguments"], tools - ) - tool_duration_ms = int((time.monotonic() - tool_start) * 1000) - - react_step = ReActStep( - step=step, - action="tool_call", - tool_name=pc["name"], - arguments=pc["arguments"], - result=tool_result, - tokens=step_tokens, - ) - trajectory.append(react_step) - - # 记录工具调用步骤 - if trace_recorder is not None: - tool_error = None - if isinstance(tool_result, dict) and "error" in tool_result: - tool_error = tool_result["error"] - trace_recorder.record_step( - step=step, - action="tool_call", - tool_name=pc["name"], - input_data=pc["arguments"], - output_data=tool_result, - duration_ms=tool_duration_ms, - tokens_used=0, - error=tool_error, - ) - - # 将工具结果添加到对话历史 - tool_msg = await self._build_tool_result_message( - pc.get("id", f"text_tc_{step}"), tool_result, compressor, pc["name"] - ) - conversation.append(tool_msg) - - # Incremental compression: compress conversation if it's getting long - if self._should_compress(conversation, compressor): - try: - conversation = await compressor.compress(conversation) - except Exception as e: - logger.warning(f"Incremental compression failed: {e}") - else: - # ponytail: 检查是否为畸形工具调用(含 但解析失败) - # 如果是,注入纠正消息让模型重试,而不是把原始 XML 作为最终答案泄漏 - if "" in (response.content or ""): - logger.warning( - f"Step {step}: content contains but " - f"parsing failed — injecting correction" - ) - conversation.append({"role": "assistant", "content": response.content}) - conversation.append( - { - "role": "user", - "content": ( - "你上一次的工具调用格式有误,无法解析。" - "请使用正确的格式重新调用工具:\n" - "\n" - '{"name": "工具名", "arguments": {"参数名": "参数值"}}\n' - "\n" - "确保 JSON 完整且不要混入其他标签。" - ), - } - ) - continue - - # Final answer: LLM 没有调用工具,返回最终答案 - react_step = ReActStep( - step=step, - action="final_answer", - content=response.content, - tokens=step_tokens, - ) - trajectory.append(react_step) - output = response.content or "" - - # 记录最终答案步骤 - if trace_recorder is not None: - trace_recorder.record_step( - step=step, - action="final_answer", - output_data={"content": response.content}, - duration_ms=llm_duration_ms, - tokens_used=step_tokens, - ) - - # U4/G1: verify at final-answer point with reinjection. - # 原为循环后一次性运行;现改为循环内检测 final answer 后立即 verify, - # 失败则把 errors 作为 user 消息回灌 conversation,continue 主循环让 LLM 自纠正。 - # max_reinjections=0 等价于原行为(仅记录 trajectory,不回灌)。 - if self._verification_enabled and output: - try: - from agentkit.core.verification_loop import VerificationLoop - - vloop = VerificationLoop(commands=self._verification_commands) - vresult = await vloop.verify() - if not vresult.passed: - if ( - reinjections < self._max_reinjections - and step < self._max_steps - ): - # 回灌 errors 作为 user 消息,让 LLM 自纠正 - errors_text = "\n".join(vresult.errors) - conversation.append( - { - "role": "user", - "content": (f"验证失败,错误如下:\n{errors_text}"), - } - ) - reinjections += 1 - logger.info( - "Verification failed (reinjection %d/%d), " - "errors injected into conversation", - reinjections, - self._max_reinjections, - ) - continue - # 达到 max_reinjections 或 max_steps → 记录 verify log 并中断 - verification_step = ReActStep( - step=step, - action="tool_call", - tool_name="verification", - arguments={"commands": self._verification_commands}, - result={ - "passed": vresult.passed, - "errors": vresult.errors, - "test_output": vresult.test_output, - }, - content=( - f"Verification failed:\n{vresult.test_output[:2000]}" - ), - ) - trajectory.append(verification_step) - trace_outcome = "verify_failed" - logger.info( - "Verification failed after %d reinjections, " - "interrupting with verify log", - reinjections, - ) - break - except Exception as e: - logger.warning(f"Verification loop failed: {e}") - - break # verify 通过或未启用 → 正常退出 - - # 达到 max_steps 时,返回当前最佳输出 - if step >= self._max_steps and not output: - trace_outcome = "partial" - # 使用最后一步的内容作为输出 - if trajectory and trajectory[-1].content: - output = trajectory[-1].content - elif trajectory and trajectory[-1].result is not None: - output = str(trajectory[-1].result) - else: - output = response.content or "" - - # 兜底:确保 output 永远不为空字符串 - if not output or not output.strip(): - from agentkit.core.fallback import EMPTY_LLM_RESPONSE, MAX_STEPS_REACHED - - if step >= self._max_steps: - output = MAX_STEPS_REACHED - else: - output = EMPTY_LLM_RESPONSE - trace_outcome = "empty_fallback" - - # 结束轨迹记录 - if trace_recorder is not None: - trace_recorder.end_trace(outcome=trace_outcome) - - # Memory storage: 执行后写入轨迹摘要到 EpisodicMemory - if memory_retriever and hasattr(memory_retriever, "store_episode"): - try: - summary = output[:500] if output else "" - await memory_retriever.store_episode( - key=f"task:{task_id or 'unknown'}", - value={"output_summary": summary, "agent_name": agent_name}, - metadata={"task_type": task_type, "outcome": trace_outcome}, - ) - except Exception as e: - logger.warning(f"Failed to store task result in episodic memory: {e}") - - return ReActResult( - output=output, - trajectory=trajectory, - total_steps=len(trajectory), - total_tokens=total_tokens, - status=trace_outcome, - ) - finally: - # Telemetry: end span and record duration — always runs - _duration_ms = int((time.monotonic() - _exec_start) * 1000) - if _span is not None: - _span.set_attribute("agent.total_steps", len(trajectory)) - _span.set_attribute("agent.total_tokens", total_tokens) - _span.set_attribute("agent.outcome", trace_outcome) - _span.set_attribute("agent.duration_ms", _duration_ms) - if _span_cm is not None: - _span_cm.__exit__(None, None, None) - agent_duration_histogram().record(_duration_ms, {"agent.name": agent_name}) - - async def execute_stream( - self, - messages: list[dict[str, str]], - tools: list[Tool] | None = None, - model: str = "default", - agent_name: str = "", - task_type: str = "", - system_prompt: str | None = None, - trace_recorder: "TraceRecorder | None" = None, - memory_retriever: "MemoryRetriever | None" = None, - task_id: str | None = None, - compressor: "CompressionStrategy | None" = None, - retrieval_config: dict[str, Any] | None = None, - cancellation_token: CancellationToken | None = None, - timeout_seconds: float | None = None, - confirmation_handler: Any | None = None, - ): - """Execute ReAct loop, yielding ReActEvent objects. - - Same logic as execute() but yields events at each step instead of - accumulating a result. - - Args: - compressor: 压缩策略,None 时使用实例默认压缩器 - """ - # P2 #9: Reset loop detection state so reuse across conversations is clean - self.reset() - effective_compressor = compressor if compressor is not None else self._compressor - tools = tools or [] - if tools: - tools = self._maybe_add_tool_search(tools) - tool_schemas = self._build_tool_schemas(tools) if tools else None - if tool_schemas: - tool_names = [s["function"]["name"] for s in tool_schemas] - logger.info(f"ReActEngine executing with {len(tool_schemas)} tools: {tool_names}") - else: - logger.info("ReActEngine executing with NO tools") - - # Prompt-based tool calling: inject tool descriptions into system prompt - # when tools are available, so LLM can use format even if - # the provider doesn't support native function calling. - if tools and system_prompt is not None: - tool_desc = self._build_tool_use_prompt(tools) - system_prompt = f"{system_prompt}\n\n{tool_desc}" - elif tools and system_prompt is None: - system_prompt = self._build_tool_use_prompt(tools) - - # Telemetry: record agent request - agent_request_counter().add( - 1, {"agent.name": agent_name, "agent.type": task_type or "react"} - ) - - # Start telemetry span for the entire agent execution - _span_cm = None - _span = None - _exec_start = time.monotonic() - - if _OTEL_AVAILABLE: - _span_cm = start_span( - "agent.execute_stream", - attributes={"agent.name": agent_name, "agent.type": task_type or "react"}, - ) - _span = _span_cm.__enter__() - - # 启动轨迹记录 - if trace_recorder is not None: - trace_recorder.start_trace( - task_id="", - agent_name=agent_name, - skill_name=task_type or None, - ) - - # Memory retrieval: 执行前检索相关上下文,作为 volatile 层注入 system message - # U2/G2: 不再拼到 stable(system_prompt)末尾破坏 cache 前缀,改由 _build_system_message - # 组装双块结构(stable + volatile),Anthropic provider 在 stable 上加 cache_control。 - memory_context = "" - if memory_retriever: - try: - query = str(messages[-1].get("content", "")) if messages else "" - top_k = (retrieval_config or {}).get("top_k", 5) - token_budget = (retrieval_config or {}).get("token_budget", 2000) - memory_context = ( - await memory_retriever.get_context_string( - query=query, - top_k=top_k, - token_budget=token_budget, - ) - or "" - ) - except Exception as e: - logger.warning(f"Memory retrieval failed, continuing without context: {e}") - - conversation: list[dict[str, Any]] = [] - system_content = self._build_system_message( - stable=system_prompt or "", - volatile=memory_context, - model=model, - ) - if system_content is not None: - conversation.append({"role": "system", "content": system_content}) - conversation.extend(messages) - - # Context compression: 压缩超长对话历史 - if effective_compressor: - try: - conversation = await effective_compressor.compress(conversation) - except Exception as e: - logger.warning( - f"Context compression failed, continuing with original messages: {e}" - ) - - trajectory: list[ReActStep] = [] - total_tokens = 0 - step = 0 - output = "" - trace_outcome = "success" - # U4/G1: verify 失败回灌计数器(execute_stream 版)。受 max_steps 上限约束。 - reinjections = 0 - _stream_start = time.monotonic() - effective_timeout = ( - timeout_seconds if timeout_seconds is not None else self._default_timeout - ) - - try: - while step < self._max_steps: - step += 1 - - # 协作式取消检查 - if cancellation_token is not None: - cancellation_token.check() - - # U3/G6: phase auto-advance safety net (mirrors _execute_loop). - if self._phase_policy is not None: - self._steps_in_phase += 1 - self._maybe_auto_advance() - - # 超时检查 - if effective_timeout > 0: - elapsed = time.monotonic() - _stream_start + # 超时检查(仅 stream=True;stream=False 由 asyncio.wait_for 强制) + if stream and effective_timeout > 0: + elapsed = time.monotonic() - _loop_start if elapsed > effective_timeout: trace_outcome = "timeout" raise asyncio.TimeoutError( @@ -1325,80 +717,89 @@ class ReActEngine: data={"message": f"Step {step}: Calling LLM..."}, ) - # Think: call LLM (with optional token streaming) + # Think: 调用 LLM llm_start = time.monotonic() - # Use streaming for token-by-token output - stream_content_chunks: list[str] = [] - stream_usage = None - stream_tool_calls: list[Any] = [] - stream_model = model - # U3/G8: delta_flush 节流 buffer,按 flush_interval_ms 批量 yield - _flush_buffer: list[str] = [] - _last_flush_ts = time.monotonic() + if stream: + # 流式模式:用 chat_stream,yield token events + stream_content_chunks: list[str] = [] + stream_usage = None + stream_tool_calls: list[Any] = [] + stream_model = model + # U3/G8: delta_flush 节流 buffer + _flush_buffer: list[str] = [] + _last_flush_ts = time.monotonic() - async for chunk in _ensure_async_iterable( - self._llm_gateway.chat_stream( + async for chunk in _ensure_async_iterable( + self._llm_gateway.chat_stream( + messages=conversation, + model=model, + agent_name=agent_name, + task_type=task_type, + tools=tool_schemas, + ), + label=f"llm_gateway.chat_stream(model={model!r})", + ): + if chunk.content: + stream_content_chunks.append(chunk.content) + _flush_buffer.append(chunk.content) + now = time.monotonic() + if ( + self._flush_interval_ms == 0 + or now - _last_flush_ts >= self._flush_interval_ms / 1000 + ): + yield ReActEvent( + event_type="token", + step=step, + data={"content": "".join(_flush_buffer)}, + ) + _flush_buffer = [] + _last_flush_ts = now + if chunk.usage: + stream_usage = chunk.usage + if chunk.tool_calls: + stream_tool_calls = chunk.tool_calls + if chunk.model: + stream_model = chunk.model + + # 流结束 mid-interval → 最终 flush 剩余 buffer + if _flush_buffer: + yield ReActEvent( + event_type="token", + step=step, + data={"content": "".join(_flush_buffer)}, + ) + _flush_buffer = [] + + stream_content = "".join(stream_content_chunks) + response = self._build_response_from_stream( + content=stream_content, + tool_calls=stream_tool_calls, + usage=stream_usage, + model=stream_model, + ) + else: + # 非流式模式:用 chat + response = await self._llm_gateway.chat( messages=conversation, model=model, agent_name=agent_name, task_type=task_type, tools=tool_schemas, - ), - label=f"llm_gateway.chat_stream(model={model!r})", - ): - if chunk.content: - stream_content_chunks.append(chunk.content) - _flush_buffer.append(chunk.content) - now = time.monotonic() - # flush_interval_ms=0 → 逐 chunk yield(向后兼容,条件短路为 True) - if ( - self._flush_interval_ms == 0 - or now - _last_flush_ts >= self._flush_interval_ms / 1000 - ): - yield ReActEvent( - event_type="token", - step=step, - data={"content": "".join(_flush_buffer)}, - ) - _flush_buffer = [] - _last_flush_ts = now - if chunk.usage: - stream_usage = chunk.usage - if chunk.tool_calls: - stream_tool_calls = chunk.tool_calls - if chunk.model: - stream_model = chunk.model - - # U3/G8: 流结束 mid-interval → 最终 flush 剩余 buffer(不丢字符) - if _flush_buffer: - yield ReActEvent( - event_type="token", - step=step, - data={"content": "".join(_flush_buffer)}, ) - _flush_buffer = [] - # Build response-like object from stream - stream_content = "".join(stream_content_chunks) - response = self._build_response_from_stream( - content=stream_content, - tool_calls=stream_tool_calls, - usage=stream_usage, - model=stream_model, - ) llm_duration_ms = int((time.monotonic() - llm_start) * 1000) - step_tokens = response.usage.total_tokens total_tokens += step_tokens + # 检查是否有 Function Calling 的 tool_calls if response.has_tool_calls: # 循环检测:检查是否重复调用相同工具+参数 looped_tool = self._check_tool_loop(response.tool_calls) if looped_tool is not None: if not self._loop_corrected: logger.warning( - f"Loop detected (stream): tool '{looped_tool}' repeated, " + f"Loop detected: tool '{looped_tool}' repeated, " f"injecting correction at step {step}" ) correction_msg = { @@ -1436,7 +837,7 @@ class ReActEngine: tokens_used=step_tokens, ) - # Record assistant message + # Act: 记录 assistant 消息(含 tool_calls)到对话历史 assistant_msg: dict[str, Any] = { "role": "assistant", "content": response.content or "", @@ -1454,17 +855,11 @@ class ReActEngine: } conversation.append(assistant_msg) - # Execute tool calls with parallel support - if ( - self._parallel_tools - and len(response.tool_calls) > 1 - and self._should_execute_parallel(response.tool_calls) - ): - # Parallel execution path - parallelizable_set = ( - set(self._get_parallelizable_indices(response.tool_calls)) - if self._parallel_tools == "auto" - else set(range(len(response.tool_calls))) + # 执行工具调用 + if self._parallel_tools == "auto" and len(response.tool_calls) > 1: + # Auto mode: mixed parallel/serial based on _parallelizable flag + parallelizable_set = set( + self._get_parallelizable_indices(response.tool_calls) ) serial_calls = [ (i, tc) @@ -1556,18 +951,72 @@ class ReActEngine: step=step, data={"tool_name": tc.name, "result": tool_result}, ) - # Wave 4 U2: drain phase violations recorded by - # _check_phase_permission during this tool call. + # Wave 4 U2: drain phase violations. for _ev in self._drain_phase_violations(step): yield _ev tool_msg = await self._build_tool_result_message( - tc.id, tool_result, effective_compressor, tc.name + tc.id, tool_result, compressor, tc.name + ) + conversation.append(tool_msg) + elif self._should_execute_parallel(response.tool_calls): + # 并行执行多个工具调用 (parallel_tools=True) + tool_results = await asyncio.gather( + *[ + self._execute_tool(tc.name, tc.arguments, tools) + for tc in response.tool_calls + ], + return_exceptions=True, + ) + for idx, tc in enumerate(response.tool_calls): + tool_result = tool_results[idx] + if isinstance(tool_result, Exception): + tool_result = {"error": str(tool_result)} + + yield ReActEvent( + event_type="tool_call", + step=step, + data={"tool_name": tc.name, "arguments": tc.arguments}, + ) + + react_step = ReActStep( + step=step, + action="tool_call", + tool_name=tc.name, + arguments=tc.arguments, + result=tool_result, + tokens=step_tokens, + ) + trajectory.append(react_step) + + if trace_recorder is not None: + tool_error = None + if isinstance(tool_result, dict) and "error" in tool_result: + tool_error = tool_result["error"] + trace_recorder.record_step( + step=step, + action="tool_call", + tool_name=tc.name, + input_data=tc.arguments, + output_data=tool_result, + duration_ms=0, + tokens_used=0, + error=tool_error, + ) + + yield ReActEvent( + event_type="tool_result", + step=step, + data={"tool_name": tc.name, "result": tool_result}, + ) + for _ev in self._drain_phase_violations(step): + yield _ev + tool_msg = await self._build_tool_result_message( + tc.id, tool_result, compressor, tc.name ) conversation.append(tool_msg) else: - # Serial execution path (with confirmation flow) + # 串行执行(单工具或 parallel_tools=False) for tc in response.tool_calls: - # Yield tool_call event yield ReActEvent( event_type="tool_call", step=step, @@ -1578,7 +1027,7 @@ class ReActEngine: tool_result = await self._execute_tool(tc.name, tc.arguments, tools) tool_duration_ms = int((time.monotonic() - tool_start) * 1000) - # 检测工具返回的确认请求 + # Handle confirmation flow if isinstance(tool_result, dict) and tool_result.get( "needs_confirmation" ): @@ -1586,7 +1035,6 @@ class ReActEngine: command = tool_result.get("command", "") reason = tool_result.get("reason", "") - # Yield 确认请求事件 yield ReActEvent( event_type="confirmation_request", step=step, @@ -1598,7 +1046,6 @@ class ReActEngine: }, ) - # 等待用户确认 approved = False if confirmation_handler is not None: try: @@ -1609,10 +1056,8 @@ class ReActEngine: logger.warning(f"Confirmation handler error: {e}") if approved: - # 用户确认执行:使用 per-call override 绕过安全检查 tool = self._find_tool(tc.name, tools) if tool and hasattr(tool, "_is_dangerous"): - # Strip internal metadata and pass skip_dangerous_check flag clean_args = { k: v for k, v in tc.arguments.items() @@ -1621,10 +1066,11 @@ class ReActEngine: clean_args["_skip_dangerous_check"] = True try: tool_result = await tool.safe_execute(**clean_args) - finally: - pass # No shared state mutation needed + except Exception as e: + tool_result = { + "error": f"Tool '{tc.name}' execution failed: {e}" + } else: - # Non-dangerous tool: re-execute with skip flag clean_args = { k: v for k, v in tc.arguments.items() @@ -1642,13 +1088,13 @@ class ReActEngine: "error": f"Tool '{tc.name}' execution failed: {e}" } - yield ReActEvent( - event_type="confirmation_result", - step=step, - data={"confirmation_id": confirmation_id, "approved": True}, - ) - else: - # 用户拒绝执行 + yield ReActEvent( + event_type="confirmation_result", + step=step, + data={"confirmation_id": confirmation_id, "approved": approved}, + ) + + if not approved: tool_result = { "output": "", "exit_code": 126, @@ -1656,14 +1102,6 @@ class ReActEngine: "error_type": "permission_denied", "message": f"用户拒绝执行命令: {command[:100]}", } - yield ReActEvent( - event_type="confirmation_result", - step=step, - data={ - "confirmation_id": confirmation_id, - "approved": False, - }, - ) tool_duration_ms = int((time.monotonic() - tool_start) * 1000) @@ -1692,30 +1130,27 @@ class ReActEngine: error=tool_error, ) - # Yield tool_result event yield ReActEvent( event_type="tool_result", step=step, data={"tool_name": tc.name, "result": tool_result}, ) - # Wave 4 U2: drain phase violations. for _ev in self._drain_phase_violations(step): yield _ev - tool_msg = await self._build_tool_result_message( - tc.id, tool_result, effective_compressor, tc.name + tc.id, tool_result, compressor, tc.name ) conversation.append(tool_msg) - # Incremental compression: compress conversation if it's getting long - if self._should_compress(conversation, effective_compressor): + # Incremental compression + if self._should_compress(conversation, compressor): try: - conversation = await effective_compressor.compress(conversation) + conversation = await compressor.compress(conversation) except Exception as e: logger.warning(f"Incremental compression failed: {e}") else: - # Check text parsing mode + # 检查文本解析模式 parsed_calls = self._parse_text_tool_calls(response.content or "") if parsed_calls and tools: # 记录 LLM 调用步骤 @@ -1740,17 +1175,17 @@ class ReActEngine: pc["name"], pc["arguments"], tools ) tool_duration_ms = int((time.monotonic() - tool_start) * 1000) - trajectory.append( - ReActStep( - step=step, - action="tool_call", - tool_name=pc["name"], - arguments=pc["arguments"], - result=tool_result, - tokens=step_tokens, - ) + + react_step = ReActStep( + step=step, + action="tool_call", + tool_name=pc["name"], + arguments=pc["arguments"], + result=tool_result, + tokens=step_tokens, ) - # 记录工具调用步骤 + trajectory.append(react_step) + if trace_recorder is not None: tool_error = None if isinstance(tool_result, dict) and "error" in tool_result: @@ -1765,35 +1200,31 @@ class ReActEngine: tokens_used=0, error=tool_error, ) + yield ReActEvent( event_type="tool_result", step=step, data={"tool_name": pc["name"], "result": tool_result}, ) - # Wave 4 U2: drain phase violations. for _ev in self._drain_phase_violations(step): yield _ev tool_msg = await self._build_tool_result_message( - pc.get("id", f"text_tc_{step}"), - tool_result, - effective_compressor, - pc["name"], + pc.get("id", f"text_tc_{step}"), tool_result, compressor, pc["name"] ) conversation.append(tool_msg) - # Incremental compression: compress conversation if it's getting long - if self._should_compress(conversation, effective_compressor): + # Incremental compression + if self._should_compress(conversation, compressor): try: - conversation = await effective_compressor.compress(conversation) + conversation = await compressor.compress(conversation) except Exception as e: logger.warning(f"Incremental compression failed: {e}") else: # ponytail: 检查是否为畸形工具调用(含 但解析失败) - # 如果是,注入纠正消息让模型重试,而不是把原始 XML 作为最终答案泄漏 if "" in (response.content or ""): logger.warning( f"Step {step}: content contains but " - f"parsing failed — injecting correction (stream)" + f"parsing failed — injecting correction" ) conversation.append({"role": "assistant", "content": response.content}) conversation.append( @@ -1816,7 +1247,7 @@ class ReActEngine: ) continue - # Final answer + # Final answer: LLM 没有调用工具,返回最终答案 react_step = ReActStep( step=step, action="final_answer", @@ -1826,7 +1257,6 @@ class ReActEngine: trajectory.append(react_step) output = response.content or "" - # 记录最终答案步骤 if trace_recorder is not None: trace_recorder.record_step( step=step, @@ -1836,10 +1266,7 @@ class ReActEngine: tokens_used=step_tokens, ) - # U4/G1: verify at final-answer point with reinjection (stream 版)。 - # 与 execute() 同模式:失败回灌 errors 作为 user 消息,continue 主循环。 - # max_reinjections=0 等价于原行为(仅记录 trajectory,不回灌)。 - # 注意:final_answer 事件在 verify 通过后才 yield,避免客户端过早收到完成信号。 + # U4/G1: verify at final-answer point with reinjection. if self._verification_enabled and output: try: from agentkit.core.verification_loop import VerificationLoop @@ -1851,7 +1278,6 @@ class ReActEngine: reinjections < self._max_reinjections and step < self._max_steps ): - # 回灌 errors,不发 final_answer 事件,继续循环 errors_text = "\n".join(vresult.errors) conversation.append( { @@ -1872,7 +1298,6 @@ class ReActEngine: }, ) continue - # 达到 max_reinjections 或 max_steps → 记录 verify log 并中断 verification_step = ReActStep( step=step, action="tool_call", @@ -1910,6 +1335,7 @@ class ReActEngine: except Exception as e: logger.warning(f"Verification loop failed: {e}") + # Yield final_answer event (legacy format for execute_stream consumers) yield ReActEvent( event_type="final_answer", step=step, @@ -1921,14 +1347,17 @@ class ReActEngine: ) break # verify 通过或未启用 → 正常退出 + # 达到 max_steps 时,返回当前最佳输出 if step >= self._max_steps and not output: trace_outcome = "partial" if trajectory and trajectory[-1].content: output = trajectory[-1].content elif trajectory and trajectory[-1].result is not None: output = str(trajectory[-1].result) - else: + elif response is not None: output = response.content or "" + else: + output = "" yield ReActEvent( event_type="final_answer", @@ -1960,6 +1389,20 @@ class ReActEngine: "empty_fallback": True, }, ) + + # Yield final_result event (new — carries ReActResult for execute() to extract) + final_result = ReActResult( + output=output, + trajectory=trajectory, + total_steps=len(trajectory), + total_tokens=total_tokens, + status=trace_outcome, + ) + yield ReActEvent( + event_type="final_result", + step=step, + data={"result": final_result}, + ) finally: # 结束轨迹记录 — always runs even if consumer doesn't fully iterate if trace_recorder is not None: @@ -1988,6 +1431,59 @@ class ReActEngine: except Exception as e: logger.warning(f"Failed to store task result in episodic memory: {e}") + async def execute_stream( + self, + messages: list[dict[str, str]], + tools: list[Tool] | None = None, + model: str = "default", + agent_name: str = "", + task_type: str = "", + system_prompt: str | None = None, + trace_recorder: "TraceRecorder | None" = None, + memory_retriever: "MemoryRetriever | None" = None, + task_id: str | None = None, + compressor: "CompressionStrategy | None" = None, + retrieval_config: dict[str, Any] | None = None, + cancellation_token: CancellationToken | None = None, + timeout_seconds: float | None = None, + confirmation_handler: Any | None = None, + ) -> AsyncGenerator[ReActEvent, None]: + """Execute ReAct loop, yielding ReActEvent objects. + + U1: execute_stream() now transparently passes through events from the + unified _execute_loop async generator (stream=True). The ~800 lines of + duplicated loop logic have been removed; both execute() and + execute_stream() share the same _execute_loop skeleton. + + Args: + compressor: 压缩策略,None 时使用实例默认压缩器 + timeout_seconds: 超时秒数,0 表示无超时,None 使用 default_timeout + """ + effective_compressor = compressor if compressor is not None else self._compressor + effective_timeout = ( + timeout_seconds if timeout_seconds is not None else self._default_timeout + ) + + # 透传 _execute_loop 的所有事件(stream=True 启用 chat_stream + token events) + async for event in self._execute_loop( + messages=messages, + tools=tools, + model=model, + agent_name=agent_name, + task_type=task_type, + system_prompt=system_prompt, + trace_recorder=trace_recorder, + memory_retriever=memory_retriever, + task_id=task_id, + compressor=effective_compressor, + retrieval_config=retrieval_config, + cancellation_token=cancellation_token, + confirmation_handler=confirmation_handler, + stream=True, + effective_timeout=effective_timeout, + ): + yield event + def _build_tool_schemas(self, tools: list[Tool]) -> list[dict]: """将 Tool 对象转换为 OpenAI Function Calling schema 格式""" schemas = [] diff --git a/tests/unit/test_react_compression.py b/tests/unit/test_react_compression.py index 60999a3..a8fdaf7 100644 --- a/tests/unit/test_react_compression.py +++ b/tests/unit/test_react_compression.py @@ -6,7 +6,7 @@ import pytest from agentkit.core.compressor import CompressionStrategy, ContextCompressor from agentkit.core.react import ReActEngine -from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall +from agentkit.llm.protocol import LLMResponse, StreamChunk, TokenUsage, ToolCall # ── Helpers ────────────────────────────────────────── @@ -27,7 +27,10 @@ def make_mock_gateway() -> MagicMock: def make_mock_gateway_with_tool_call() -> MagicMock: - """创建一个返回 tool_call 的 mock LLMGateway,第二次调用返回最终答案""" + """创建一个返回 tool_call 的 mock LLMGateway,第二次调用返回最终答案 + + 同时设置 chat 和 chat_stream,使 execute 和 execute_stream 路径都能正常工作。 + """ from agentkit.llm.gateway import LLMGateway gateway = MagicMock(spec=LLMGateway) @@ -47,6 +50,32 @@ def make_mock_gateway_with_tool_call() -> MagicMock: usage=TokenUsage(prompt_tokens=10, completion_tokens=10), ) gateway.chat = AsyncMock(side_effect=[tool_response, final_response]) + + # ponytail: chat_stream yields StreamChunk equivalents of the chat responses + # so execute_stream (which uses chat_stream) exercises the same tool path. + tool_chunk = StreamChunk( + content="", + model="test-model", + tool_calls=[ToolCall(id="call_1", name="search", arguments={"query": "test"})], + usage=TokenUsage(prompt_tokens=10, completion_tokens=10), + is_final=True, + ) + final_chunk = StreamChunk( + content="Final answer after tool", + model="test-model", + usage=TokenUsage(prompt_tokens=10, completion_tokens=10), + is_final=True, + ) + + async def _stream(**kwargs): + # Closure state tracks which response to yield (1st call=tool, 2nd=final) + _stream._call_count = getattr(_stream, "_call_count", 0) + 1 + if _stream._call_count == 1: + yield tool_chunk + else: + yield final_chunk + + gateway.chat_stream = _stream return gateway diff --git a/tests/unit/test_react_golden_trajectory.py b/tests/unit/test_react_golden_trajectory.py new file mode 100644 index 0000000..08df91e --- /dev/null +++ b/tests/unit/test_react_golden_trajectory.py @@ -0,0 +1,617 @@ +"""Golden trajectory characterization tests for ReActEngine. + +Locks in the current behavior of execute() and execute_stream() with fixed +mock LLM responses. These tests must pass BEFORE and AFTER the U1 refactor +(_execute_loop unification). Per plan KTD6: characterization-first. + +Scenarios covered (per plan U1 Test scenarios): +- Happy path: single tool call -> final answer (execute + execute_stream) +- Happy path streaming equivalence: execute vs execute_stream same output +- Multi-step loop: 3 tool calls then final answer +- Empty tools: LLM returns text directly +- Max steps: loop reaches max_steps -> status='partial' +- Tool failure: tool raises exception -> error in observation, loop continues +- LLM failure: gateway raises exception -> propagate +- Phase violation: tool blocked by phase policy -> phase_violation event +- Cancellation: CancellationToken cancelled -> TaskCancelledError +- Compression triggered: long conversation triggers compressor.compress() +- Golden trajectory snapshot: fixed mock -> event type sequence +""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from agentkit.core.react import ReActEvent, ReActResult, ReActStep +from agentkit.llm.gateway import LLMGateway +from agentkit.llm.protocol import LLMResponse, StreamChunk, TokenUsage, ToolCall +from agentkit.tools.base import Tool + + +# ── Helpers ────────────────────────────────────────── + + +class FakeTool(Tool): + """Minimal Tool implementation for trajectory tests.""" + + def __init__( + self, + name: str = "fake_tool", + description: str = "A fake tool for testing", + result: dict | None = None, + should_fail: bool = False, + ): + super().__init__(name=name, description=description) + self._result = result or {"status": "ok"} + self._should_fail = should_fail + self.call_count = 0 + + async def execute(self, **kwargs) -> dict: + self.call_count += 1 + if self._should_fail: + raise RuntimeError(f"Tool '{self.name}' execution failed") + return self._result + + +def make_response( + content: str = "", + tool_calls: list[ToolCall] | None = None, + prompt_tokens: int = 10, + completion_tokens: int = 20, +) -> LLMResponse: + """Quick LLMResponse builder for non-streaming gateway mocks.""" + return LLMResponse( + content=content, + model="test-model", + usage=TokenUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ), + tool_calls=tool_calls or [], + ) + + +def make_mock_gateway(responses: list[LLMResponse]) -> MagicMock: + """Mock LLMGateway whose chat() returns responses in order.""" + gateway = MagicMock(spec=LLMGateway) + gateway.chat = AsyncMock(side_effect=responses) + return gateway + + +def make_mock_stream_gateway(chunks_list: list[list[StreamChunk]]) -> MagicMock: + """Mock LLMGateway whose chat_stream() yields chunks in order. + + Each call to chat_stream consumes one inner list from chunks_list. + """ + gateway = MagicMock(spec=LLMGateway) + + async def _stream(**kwargs): + for chunks in chunks_list: + for chunk in chunks: + yield chunk + # Remove after use so a second call would raise StopIteration + chunks_list.pop(0) + + gateway.chat_stream = _stream + return gateway + + +def _tc(name: str, args: dict | None = None, tc_id: str = "tc_1") -> ToolCall: + """Quick ToolCall builder.""" + return ToolCall(id=tc_id, name=name, arguments=args or {}) + + +def _step_summary(step: ReActStep) -> str: + """Compact ReActStep summary for snapshot comparison.""" + return f"{step.action}@{step.step}:{step.tool_name or ''}" + + +def _stream_tool_call_chunk( + name: str, + args: dict | None = None, + tc_id: str = "tc_1", + prompt_tokens: int = 10, + completion_tokens: int = 20, +) -> StreamChunk: + """Single StreamChunk carrying a tool_call (simulates function-calling stream).""" + return StreamChunk( + content="", + model="test-model", + tool_calls=[ToolCall(id=tc_id, name=name, arguments=args or {})], + usage=TokenUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens), + is_final=True, + ) + + +def _stream_content_chunk( + content: str, + prompt_tokens: int = 10, + completion_tokens: int = 20, +) -> StreamChunk: + """Single StreamChunk carrying final text content.""" + return StreamChunk( + content=content, + model="test-model", + usage=TokenUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens), + is_final=True, + ) + + +# ── Happy path: single tool call ────────────────────────────────────────── + + +class TestGoldenHappyPath: + """Single tool call -> final answer. Locks in execute() result shape.""" + + async def test_execute_single_tool_call_trajectory(self): + from agentkit.core.react import ReActEngine + + tool = FakeTool(name="calculator", result={"value": 42}) + gateway = make_mock_gateway( + [ + make_response(tool_calls=[_tc("calculator", {"expr": "6*7"})]), + make_response(content="The result is 42"), + ] + ) + engine = ReActEngine(llm_gateway=gateway) + + result = await engine.execute( + messages=[{"role": "user", "content": "Calculate 6*7"}], + tools=[tool], + ) + + # Golden trajectory snapshot — locking current shape + assert result.status == "success" + assert result.output == "The result is 42" + assert result.total_steps == 2 + assert result.total_tokens == 60 # (10+20) * 2 + assert [_step_summary(s) for s in result.trajectory] == [ + "tool_call@1:calculator", + "final_answer@2:", + ] + assert result.trajectory[0].result == {"value": 42} + assert result.trajectory[1].content == "The result is 42" + + async def test_execute_stream_single_tool_call_event_types(self): + """execute_stream event type sequence for single tool call. + + Locks current event types. After U1 refactor, an additional + 'final_result' event may appear at the end (not asserted here). + """ + from agentkit.core.react import ReActEngine + + tool = FakeTool(name="calculator", result={"value": 42}) + gateway = make_mock_stream_gateway( + [ + [_stream_tool_call_chunk("calculator", {"expr": "6*7"})], + [_stream_content_chunk("The result is 42")], + ] + ) + engine = ReActEngine(llm_gateway=gateway) + + events = [] + async for event in engine.execute_stream( + messages=[{"role": "user", "content": "Calculate 6*7"}], + tools=[tool], + ): + events.append(event) + + event_types = [e.event_type for e in events] + # Golden sequence: thinking -> tool_call -> tool_result -> thinking -> final_answer + assert "thinking" in event_types + assert "tool_call" in event_types + assert "tool_result" in event_types + assert "final_answer" in event_types + # tool_result must come after tool_call + assert event_types.index("tool_result") > event_types.index("tool_call") + # final_answer must come after tool_result + assert event_types.index("final_answer") > event_types.index("tool_result") + + # Verify tool_call event data + tool_call_event = next(e for e in events if e.event_type == "tool_call") + assert tool_call_event.data["tool_name"] == "calculator" + assert tool_call_event.data["arguments"] == {"expr": "6*7"} + + # Verify tool_result event data + tool_result_event = next(e for e in events if e.event_type == "tool_result") + assert tool_result_event.data["tool_name"] == "calculator" + assert tool_result_event.data["result"] == {"value": 42} + + # Verify final_answer event data + final_event = next(e for e in events if e.event_type == "final_answer") + assert final_event.data["output"] == "The result is 42" + assert final_event.data["total_steps"] == 2 + assert final_event.data["total_tokens"] == 60 + + +# ── Streaming equivalence ───────────────────────────────────────── + + +class TestStreamingEquivalence: + """execute() and execute_stream() produce equivalent results for same input. + + After U1 refactor, both delegate to the same _execute_loop, so equivalence + is structural. Before refactor, this test characterizes the current drift + (e.g., compress_tool_result called by execute but not execute_stream). + """ + + async def test_execute_and_stream_same_output(self): + from agentkit.core.react import ReActEngine + + tool = FakeTool(name="search", result={"results": ["data"]}) + gateway_exec = make_mock_gateway( + [ + make_response(tool_calls=[_tc("search", {"q": "test"})]), + make_response(content="Found data"), + ] + ) + engine_exec = ReActEngine(llm_gateway=gateway_exec) + result = await engine_exec.execute( + messages=[{"role": "user", "content": "Search"}], + tools=[tool], + ) + + # execute_stream path with equivalent stream chunks + tool2 = FakeTool(name="search", result={"results": ["data"]}) + gateway_stream = make_mock_stream_gateway( + [ + [_stream_tool_call_chunk("search", {"q": "test"})], + [_stream_content_chunk("Found data")], + ] + ) + engine_stream = ReActEngine(llm_gateway=gateway_stream) + events = [] + async for event in engine_stream.execute_stream( + messages=[{"role": "user", "content": "Search"}], + tools=[tool2], + ): + events.append(event) + + final_answer_events = [e for e in events if e.event_type == "final_answer"] + assert len(final_answer_events) == 1 + stream_final = final_answer_events[0].data + + # Equivalence on the user-visible fields + assert result.output == stream_final["output"] + assert result.total_steps == stream_final["total_steps"] + assert result.total_tokens == stream_final["total_tokens"] + + +# ── Multi-step loop ───────────────────────────────────────── + + +class TestGoldenMultiStep: + """3 tool calls then final answer.""" + + async def test_execute_three_step_trajectory(self): + from agentkit.core.react import ReActEngine + + search = FakeTool(name="search", result={"results": ["a"]}) + calc = FakeTool(name="calculator", result={"value": 100}) + gateway = make_mock_gateway( + [ + make_response(tool_calls=[_tc("search", {"query": "Python"})]), + make_response(tool_calls=[_tc("calculator", {"expr": "10*10"})]), + make_response(content="Based on search and calculation, the answer is 100"), + ] + ) + engine = ReActEngine(llm_gateway=gateway) + + result = await engine.execute( + messages=[{"role": "user", "content": "Search and calculate"}], + tools=[search, calc], + ) + + assert [_step_summary(s) for s in result.trajectory] == [ + "tool_call@1:search", + "tool_call@2:calculator", + "final_answer@3:", + ] + assert result.total_steps == 3 + assert search.call_count == 1 + assert calc.call_count == 1 + + +# ── Empty tools ───────────────────────────────────────── + + +class TestGoldenEmptyTools: + """No tools -> LLM returns text directly.""" + + async def test_execute_no_tools_direct_answer(self): + from agentkit.core.react import ReActEngine + + gateway = make_mock_gateway([make_response(content="Direct answer")]) + engine = ReActEngine(llm_gateway=gateway) + + result = await engine.execute( + messages=[{"role": "user", "content": "Hello"}], + tools=None, + ) + + assert result.output == "Direct answer" + assert result.total_steps == 1 + assert result.status == "success" + assert [_step_summary(s) for s in result.trajectory] == ["final_answer@1:"] + + +# ── Max steps ───────────────────────────────────────── + + +class TestGoldenMaxSteps: + """Loop reaches max_steps -> status='partial'.""" + + async def test_execute_max_steps_partial_status(self): + from agentkit.core.react import ReActEngine + + tool = FakeTool(name="search", result={"results": []}) + # Each step uses a different query to avoid loop detection + responses = [ + make_response( + content="Thinking...", + tool_calls=[_tc("search", {"query": f"attempt_{i}"}, tc_id=f"tc_{i}")], + ) + for i in range(20) + ] + gateway = make_mock_gateway(responses) + engine = ReActEngine(llm_gateway=gateway, max_steps=3) + + result = await engine.execute( + messages=[{"role": "user", "content": "Keep searching"}], + tools=[tool], + ) + + assert result.total_steps == 3 + assert result.status == "partial" + # All 3 steps are tool_calls (no final_answer) + assert all(s.action == "tool_call" for s in result.trajectory) + + +# ── Tool failure ───────────────────────────────────────── + + +class TestGoldenToolFailure: + """Tool raises exception -> error in observation, loop continues.""" + + async def test_execute_tool_failure_continues(self): + from agentkit.core.react import ReActEngine + + failing = FakeTool(name="broken", should_fail=True) + gateway = make_mock_gateway( + [ + make_response(tool_calls=[_tc("broken", {})]), + make_response(content="Recovered from tool failure"), + ] + ) + engine = ReActEngine(llm_gateway=gateway) + + result = await engine.execute( + messages=[{"role": "user", "content": "Use broken tool"}], + tools=[failing], + ) + + assert result.trajectory[0].action == "tool_call" + assert "failed" in str(result.trajectory[0].result).lower() + assert result.trajectory[1].action == "final_answer" + assert result.output == "Recovered from tool failure" + assert result.total_steps == 2 + + +# ── LLM failure ───────────────────────────────────────── + + +class TestGoldenLLMFailure: + """LLM gateway raises exception -> propagate to caller.""" + + async def test_execute_llm_failure_propagates(self): + from agentkit.core.react import ReActEngine + + gateway = MagicMock(spec=LLMGateway) + gateway.chat = AsyncMock(side_effect=RuntimeError("LLM down")) + engine = ReActEngine(llm_gateway=gateway) + + with pytest.raises(RuntimeError, match="LLM down"): + await engine.execute(messages=[{"role": "user", "content": "Hi"}]) + + +# ── Phase violation ───────────────────────────────────────── + + +class TestGoldenPhaseViolation: + """Tool blocked by phase policy -> phase_violation event in stream.""" + + async def test_stream_phase_violation_event(self): + from agentkit.core.phase import default_policy + from agentkit.core.react import ReActEngine + + async def _stream(**kwargs): + yield _stream_tool_call_chunk("write_file", {"path": "/x"}) + yield _stream_content_chunk("done") + + gateway = MagicMock(spec=LLMGateway) + gateway.chat_stream = _stream + engine = ReActEngine( + llm_gateway=gateway, + phase_policy=default_policy(), + max_steps=2, + ) + # write_file is blocked in PLANNING; _find_tool won't be reached + engine._find_tool = lambda name, tools: None + + events: list[ReActEvent] = [] + async for event in engine.execute_stream( + messages=[{"role": "user", "content": "test"}], + tools=[], + ): + events.append(event) + + violation_events = [e for e in events if e.event_type == "phase_violation"] + assert len(violation_events) >= 1 + v = violation_events[0].data + assert v["tool"] == "write_file" + assert v["current_phase"] == "planning" + assert v["error"] == "phase_violation" + + +# ── Cancellation ───────────────────────────────────────── + + +class TestGoldenCancellation: + """CancellationToken cancelled -> TaskCancelledError.""" + + async def test_execute_cancelled_before_start(self): + from agentkit.core.exceptions import TaskCancelledError + from agentkit.core.protocol import CancellationToken + from agentkit.core.react import ReActEngine + + gateway = make_mock_gateway([make_response(content="hi")]) + engine = ReActEngine(llm_gateway=gateway) + token = CancellationToken() + token.cancel() + + with pytest.raises(TaskCancelledError): + await engine.execute( + messages=[{"role": "user", "content": "Hi"}], + cancellation_token=token, + ) + + +# ── Compression triggered ───────────────────────────────────────── + + +class TestGoldenCompression: + """Long conversation triggers compressor.compress().""" + + async def test_execute_compression_triggered(self): + from agentkit.core.compressor import CompressionStrategy + from agentkit.core.react import ReActEngine + + compressor = MagicMock(spec=CompressionStrategy) + # passthrough — return messages unchanged + compressor.compress = AsyncMock(side_effect=lambda msgs: msgs) + compressor.is_available = MagicMock(return_value=True) + compressor.should_compress = MagicMock(return_value=True) + + gateway = make_mock_gateway( + [ + make_response(tool_calls=[_tc("search", {"q": "test"})]), + make_response(content="Done"), + ] + ) + engine = ReActEngine(llm_gateway=gateway) + + mock_tool = MagicMock() + mock_tool.name = "search" + mock_tool.safe_execute = AsyncMock(return_value="result") + + long_content = "x" * 40000 + await engine.execute( + messages=[{"role": "user", "content": long_content}], + tools=[mock_tool], + compressor=compressor, + ) + + # compress should be called (initial + incremental) + assert compressor.compress.call_count >= 1 + + async def test_execute_tool_result_compressed(self): + """execute() path calls compress_tool_result via _build_tool_result_message. + + This is the behavior the U1 refactor must preserve (and which + execute_stream currently lacks — see test_execute_stream_with_compressor + in test_react_compression.py). + """ + from agentkit.core.compressor import CompressionStrategy + from agentkit.core.react import ReActEngine + + compressor = MagicMock(spec=CompressionStrategy) + compressor.compress = AsyncMock(side_effect=lambda msgs: msgs) + compressor.compress_tool_result = AsyncMock(return_value="COMPRESSED") + compressor.is_available = MagicMock(return_value=True) + compressor.should_compress = MagicMock(return_value=False) + + gateway = make_mock_gateway( + [ + make_response(tool_calls=[_tc("search", {"q": "test"})]), + make_response(content="Done"), + ] + ) + engine = ReActEngine(llm_gateway=gateway) + + mock_tool = MagicMock() + mock_tool.name = "search" + mock_tool.safe_execute = AsyncMock(return_value="original result") + + await engine.execute( + messages=[{"role": "user", "content": "Search"}], + tools=[mock_tool], + compressor=compressor, + ) + + # execute() path MUST call compress_tool_result — this is the + # behavior that test_execute_stream_with_compressor expects + # execute_stream to also have after U1 unification. + compressor.compress_tool_result.assert_called_once_with("search", "original result") + + +# ── Golden trajectory snapshot (full event sequence) ──────────────────── + + +class TestGoldenTrajectorySnapshot: + """Full event sequence snapshot for execute_stream. + + Locks the EXACT event type sequence for a fixed 2-step flow. + Any change indicates a behavior change. + """ + + async def test_stream_two_step_event_sequence(self): + from agentkit.core.react import ReActEngine + + tool = FakeTool(name="search", result={"results": ["data"]}) + gateway = make_mock_stream_gateway( + [ + [_stream_tool_call_chunk("search", {"q": "test"})], + [_stream_content_chunk("Final answer")], + ] + ) + engine = ReActEngine(llm_gateway=gateway) + + events: list[ReActEvent] = [] + async for event in engine.execute_stream( + messages=[{"role": "user", "content": "Search"}], + tools=[tool], + ): + events.append(event) + + event_types = [e.event_type for e in events] + + # Snapshot (pre-refactor): thinking, tool_call, tool_result, thinking, final_answer + # Post-refactor may append 'final_result' at the end (not asserted here). + # Verify the relative ordering of key events is preserved. + assert event_types[0] == "thinking" + assert "tool_call" in event_types + assert "tool_result" in event_types + assert event_types.index("tool_result") > event_types.index("tool_call") + assert "final_answer" in event_types + assert event_types.index("final_answer") > event_types.index("tool_result") + + # Verify step numbers: tool events on step 1, final on step 2 + tool_call_event = next(e for e in events if e.event_type == "tool_call") + assert tool_call_event.step == 1 + final_event = next(e for e in events if e.event_type == "final_answer") + assert final_event.step == 2 + + async def test_execute_returns_react_result(self): + """execute() returns a ReActResult (not events). Locks the type contract.""" + from agentkit.core.react import ReActEngine + + gateway = make_mock_gateway([make_response(content="Done")]) + engine = ReActEngine(llm_gateway=gateway) + + result = await engine.execute(messages=[{"role": "user", "content": "Hi"}]) + + assert isinstance(result, ReActResult) + assert result.output == "Done" + assert result.status == "success" + assert isinstance(result.trajectory, list) + assert all(isinstance(s, ReActStep) for s in result.trajectory)