From cd211c6cd9ac6929b71bc470a9d73a39623d2325 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Mon, 29 Jun 2026 21:35:08 +0800 Subject: [PATCH] =?UTF-8?q?feat(U4):=20G1=20verify=20=E5=A4=B1=E8=B4=A5?= =?UTF-8?q?=E5=9B=9E=E7=81=8C=20ReAct?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - ReActEngine 新增 max_reinjections 构造参数(默认 1,=0 等价原行为) - execute()/execute_stream() verify 块从循环后移到循环内 final-answer 检测点: - verify 通过 → 正常 break - verify 失败 + reinjections < max + step < max_steps → errors 作为 user 消息回灌 conversation, continue 让 LLM 自纠正 - verify 失败 + 达到 max_reinjections 或 max_steps → 记录 verify log 到 trajectory, trace_outcome="verify_failed", break - execute_stream 的 final_answer 事件在 verify 通过后才 yield,避免客户端过早收到完成信号 - ReActResult.status 现在传递 trace_outcome(原默认 "success") - ServerConfig.verification 配置项(max_reinjections) - test_verify_reinjection.py 10 测试:characterization(max=0)+ 新行为(R1/R2/R3/R14) --- src/agentkit/core/react.py | 273 +++++++++++------- src/agentkit/server/config.py | 4 + tests/unit/test_verify_reinjection.py | 394 ++++++++++++++++++++++++++ 3 files changed, 574 insertions(+), 97 deletions(-) create mode 100644 tests/unit/test_verify_reinjection.py diff --git a/src/agentkit/core/react.py b/src/agentkit/core/react.py index 6c78bb2..a5259b6 100644 --- a/src/agentkit/core/react.py +++ b/src/agentkit/core/react.py @@ -165,6 +165,7 @@ class ReActEngine: middleware_chain: "MiddlewareChain | None" = None, prompt_cache_enable: bool = True, flush_interval_ms: int = 0, + max_reinjections: int = 1, ): if max_steps < 1: raise ValueError(f"max_steps must be >= 1, got {max_steps}") @@ -184,6 +185,10 @@ class ReActEngine: # U3/G8: token chunk 节流间隔(ms)。0 = 逐 chunk yield(向后兼容)。 # 用 time.monotonic() 不受系统时钟跳变影响。 self._flush_interval_ms = flush_interval_ms + # U4/G1: verify 失败回灌最大重试次数。0 = 不回灌(当前行为,仅记录 trajectory); + # 1 = 首次失败回灌一次 errors 给 LLM 自纠正,二次失败中断。 + # 受 max_steps 上限约束(不无限循环)。verification_enabled=False 时无效。 + self._max_reinjections = max_reinjections # Tiered tool description injection config self._core_tool_names: tuple[str, ...] | None = ( tuple(core_tool_names) if core_tool_names is not None else None @@ -445,11 +450,14 @@ class ReActEngine: query = str(messages[-1].get("content", "")) if messages else "" top_k = (retrieval_config or {}).get("top_k", 5) token_budget = (retrieval_config or {}).get("token_budget", 2000) - memory_context = await memory_retriever.get_context_string( - query=query, - top_k=top_k, - token_budget=token_budget, - ) or "" + memory_context = ( + await memory_retriever.get_context_string( + query=query, + top_k=top_k, + token_budget=token_budget, + ) + or "" + ) except Exception as e: logger.warning( f"Memory retrieval failed, continuing without context: {e}", exc_info=True @@ -478,6 +486,8 @@ class ReActEngine: trace_outcome = "success" step = 0 output = "" + # U4/G1: verify 失败回灌计数器。受 max_steps 上限约束(不无限循环)。 + reinjections = 0 while step < self._max_steps: step += 1 @@ -854,16 +864,14 @@ class ReActEngine: f"Step {step}: content contains but " f"parsing failed — injecting correction" ) - conversation.append( - {"role": "assistant", "content": response.content} - ) + conversation.append({"role": "assistant", "content": response.content}) conversation.append( { "role": "user", "content": ( "你上一次的工具调用格式有误,无法解析。" "请使用正确的格式重新调用工具:\n" - '\n' + "\n" '{"name": "工具名", "arguments": {"参数名": "参数值"}}\n' "\n" "确保 JSON 完整且不要混入其他标签。" @@ -891,36 +899,65 @@ class ReActEngine: duration_ms=llm_duration_ms, tokens_used=step_tokens, ) - break - # Verification: 如果启用验证,在 final answer 后运行测试 - if self._verification_enabled and output: - try: - from agentkit.core.verification_loop import VerificationLoop + # U4/G1: verify at final-answer point with reinjection. + # 原为循环后一次性运行;现改为循环内检测 final answer 后立即 verify, + # 失败则把 errors 作为 user 消息回灌 conversation,continue 主循环让 LLM 自纠正。 + # max_reinjections=0 等价于原行为(仅记录 trajectory,不回灌)。 + if self._verification_enabled and output: + try: + from agentkit.core.verification_loop import VerificationLoop - vloop = VerificationLoop(commands=self._verification_commands) - vresult = await vloop.verify() - if not vresult.passed: - # 将验证失败信息作为 ReActStep 添加到轨迹 - verification_step = ReActStep( - step=step + 1, - action="tool_call", - tool_name="verification", - arguments={"commands": self._verification_commands}, - result={ - "passed": vresult.passed, - "errors": vresult.errors, - "test_output": vresult.test_output, - }, - content=(f"Verification failed:\n{vresult.test_output[:2000]}"), - ) - trajectory.append(verification_step) - logger.info( - "Verification failed after final answer, " - "appended feedback to trajectory" - ) - except Exception as e: - logger.warning(f"Verification loop failed: {e}") + vloop = VerificationLoop(commands=self._verification_commands) + vresult = await vloop.verify() + if not vresult.passed: + if ( + reinjections < self._max_reinjections + and step < self._max_steps + ): + # 回灌 errors 作为 user 消息,让 LLM 自纠正 + errors_text = "\n".join(vresult.errors) + conversation.append( + { + "role": "user", + "content": (f"验证失败,错误如下:\n{errors_text}"), + } + ) + reinjections += 1 + logger.info( + "Verification failed (reinjection %d/%d), " + "errors injected into conversation", + reinjections, + self._max_reinjections, + ) + continue + # 达到 max_reinjections 或 max_steps → 记录 verify log 并中断 + verification_step = ReActStep( + step=step, + action="tool_call", + tool_name="verification", + arguments={"commands": self._verification_commands}, + result={ + "passed": vresult.passed, + "errors": vresult.errors, + "test_output": vresult.test_output, + }, + content=( + f"Verification failed:\n{vresult.test_output[:2000]}" + ), + ) + trajectory.append(verification_step) + trace_outcome = "verify_failed" + logger.info( + "Verification failed after %d reinjections, " + "interrupting with verify log", + reinjections, + ) + break + except Exception as e: + logger.warning(f"Verification loop failed: {e}") + + break # verify 通过或未启用 → 正常退出 # 达到 max_steps 时,返回当前最佳输出 if step >= self._max_steps and not output: @@ -964,6 +1001,7 @@ class ReActEngine: trajectory=trajectory, total_steps=len(trajectory), total_tokens=total_tokens, + status=trace_outcome, ) finally: # Telemetry: end span and record duration — always runs @@ -1058,11 +1096,14 @@ class ReActEngine: query = str(messages[-1].get("content", "")) if messages else "" top_k = (retrieval_config or {}).get("top_k", 5) token_budget = (retrieval_config or {}).get("token_budget", 2000) - memory_context = await memory_retriever.get_context_string( - query=query, - top_k=top_k, - token_budget=token_budget, - ) or "" + memory_context = ( + await memory_retriever.get_context_string( + query=query, + top_k=top_k, + token_budget=token_budget, + ) + or "" + ) except Exception as e: logger.warning(f"Memory retrieval failed, continuing without context: {e}") @@ -1090,6 +1131,8 @@ class ReActEngine: step = 0 output = "" trace_outcome = "success" + # U4/G1: verify 失败回灌计数器(execute_stream 版)。受 max_steps 上限约束。 + reinjections = 0 _stream_start = time.monotonic() effective_timeout = ( timeout_seconds if timeout_seconds is not None else self._default_timeout @@ -1579,16 +1622,14 @@ class ReActEngine: f"Step {step}: content contains but " f"parsing failed — injecting correction (stream)" ) - conversation.append( - {"role": "assistant", "content": response.content} - ) + conversation.append({"role": "assistant", "content": response.content}) conversation.append( { "role": "user", "content": ( "你上一次的工具调用格式有误,无法解析。" "请使用正确的格式重新调用工具:\n" - '\n' + "\n" '{"name": "工具名", "arguments": {"参数名": "参数值"}}\n' "\n" "确保 JSON 完整且不要混入其他标签。" @@ -1622,6 +1663,80 @@ class ReActEngine: tokens_used=step_tokens, ) + # U4/G1: verify at final-answer point with reinjection (stream 版)。 + # 与 execute() 同模式:失败回灌 errors 作为 user 消息,continue 主循环。 + # max_reinjections=0 等价于原行为(仅记录 trajectory,不回灌)。 + # 注意:final_answer 事件在 verify 通过后才 yield,避免客户端过早收到完成信号。 + if self._verification_enabled and output: + try: + from agentkit.core.verification_loop import VerificationLoop + + vloop = VerificationLoop(commands=self._verification_commands) + vresult = await vloop.verify() + if not vresult.passed: + if ( + reinjections < self._max_reinjections + and step < self._max_steps + ): + # 回灌 errors,不发 final_answer 事件,继续循环 + errors_text = "\n".join(vresult.errors) + conversation.append( + { + "role": "user", + "content": (f"验证失败,错误如下:\n{errors_text}"), + } + ) + reinjections += 1 + yield ReActEvent( + event_type="step", + step=step, + data={ + "message": ( + f"验证失败,已注入错误信息让 LLM 自纠正 " + f"(reinjection {reinjections}/{self._max_reinjections})" + ), + "verify_errors": vresult.errors, + }, + ) + continue + # 达到 max_reinjections 或 max_steps → 记录 verify log 并中断 + verification_step = ReActStep( + step=step, + action="tool_call", + tool_name="verification", + arguments={"commands": self._verification_commands}, + result={ + "passed": vresult.passed, + "errors": vresult.errors, + "test_output": vresult.test_output, + }, + content=( + f"Verification failed:\n{vresult.test_output[:2000]}" + ), + ) + trajectory.append(verification_step) + trace_outcome = "verify_failed" + yield ReActEvent( + event_type="tool_result", + step=step, + data={ + "tool_name": "verification", + "result": { + "passed": vresult.passed, + "errors": vresult.errors, + "test_output": vresult.test_output, + }, + }, + ) + logger.info( + "Verification failed after %d reinjections, " + "interrupting with verify log", + reinjections, + ) + break + except Exception as e: + logger.warning(f"Verification loop failed: {e}") + yield ReActEvent( event_type="final_answer", step=step, @@ -1631,47 +1746,7 @@ class ReActEngine: "total_tokens": total_tokens, }, ) - break - - # Verification: 如果启用验证,在 final answer 后运行测试 - if self._verification_enabled and output: - try: - from agentkit.core.verification_loop import VerificationLoop - - vloop = VerificationLoop(commands=self._verification_commands) - vresult = await vloop.verify() - if not vresult.passed: - verification_step = ReActStep( - step=step + 1, - action="tool_call", - tool_name="verification", - arguments={"commands": self._verification_commands}, - result={ - "passed": vresult.passed, - "errors": vresult.errors, - "test_output": vresult.test_output, - }, - content=(f"Verification failed:\n{vresult.test_output[:2000]}"), - ) - trajectory.append(verification_step) - yield ReActEvent( - event_type="tool_result", - step=step + 1, - data={ - "tool_name": "verification", - "result": { - "passed": vresult.passed, - "errors": vresult.errors, - "test_output": vresult.test_output, - }, - }, - ) - logger.info( - "Verification failed after final answer, " - "appended feedback to trajectory" - ) - except Exception as e: - logger.warning(f"Verification loop failed: {e}") + break # verify 通过或未启用 → 正常退出 if step >= self._max_steps and not output: trace_outcome = "partial" @@ -1784,16 +1859,20 @@ class ReActEngine: if provider_name == "anthropic": blocks: list[dict[str, Any]] = [] if stable: - blocks.append({ - "type": "text", - "text": stable, - "cache_control": {"type": "ephemeral"}, - }) + blocks.append( + { + "type": "text", + "text": stable, + "cache_control": {"type": "ephemeral"}, + } + ) if volatile: - blocks.append({ - "type": "text", - "text": f"## 参考信息\n{volatile}", - }) + blocks.append( + { + "type": "text", + "text": f"## 参考信息\n{volatile}", + } + ) return blocks if blocks else None # 非 Anthropic:字符串拼接,stable 前缀命中 OpenAI/DashScope 自动前缀缓存 @@ -2213,7 +2292,7 @@ class ReActEngine: # 兜底解析:从 后提取 JSON 片段,用大括号匹配法恢复完整 JSON open_pattern = re.compile(r"\s*", re.IGNORECASE) for match in open_pattern.finditer(content): - remainder = content[match.end():] + remainder = content[match.end() :] parsed = self._extract_tool_call_from_malformed(remainder) if parsed: calls.append(parsed) diff --git a/src/agentkit/server/config.py b/src/agentkit/server/config.py index 2e8a540..b5b33da 100644 --- a/src/agentkit/server/config.py +++ b/src/agentkit/server/config.py @@ -118,6 +118,7 @@ class ServerConfig: board: dict[str, Any] | None = None, prompt_cache: dict[str, Any] | None = None, streaming: dict[str, Any] | None = None, + verification: dict[str, Any] | None = None, on_change: Callable[["ServerConfig"], None] | None = None, ): self.host = host @@ -149,6 +150,9 @@ class ServerConfig: self.prompt_cache = prompt_cache or {} # U3/G8: streaming.flush_interval_ms 控制 token chunk 节流(默认 0 = 逐 chunk yield) self.streaming = streaming or {} + # U4/G1: verification.max_reinjections 控制 verify 失败回灌次数(默认 1) + # verification_enabled=False 时此配置无效 + self.verification = verification or {} self.on_change = on_change # Config watching state diff --git a/tests/unit/test_verify_reinjection.py b/tests/unit/test_verify_reinjection.py new file mode 100644 index 0000000..985e2cc --- /dev/null +++ b/tests/unit/test_verify_reinjection.py @@ -0,0 +1,394 @@ +"""U4/G1: Verify 失败回灌 ReAct 测试 + +characterization-first: 先覆盖现有 verify 行为(max_reinjections=0 等价于不回灌), +再测新回灌行为(reinjection on first fail, break on second fail)。 + +R1: verify 首次失败 → errors 注入 conversation → LLM 自纠正 → 二次 verify 通过 +R2: verify 二次失败 → 中断返回错误附 verify log +R3: max_reinjections 可配置(默认 1),=0 等价于不回灌;回灌受 max_steps 约束 +R13: ServerConfig.verification 配置项 +R14: max_reinjections 默认值为 1 +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + + +from agentkit.core.react import ReActEngine +from agentkit.core.verification_loop import VerificationResult +from agentkit.llm.gateway import LLMGateway +from agentkit.llm.protocol import LLMResponse, TokenUsage + + +# ── Helpers ────────────────────────────────────────────── + + +def make_mock_gateway(responses: list[LLMResponse]) -> LLMGateway: + """创建按顺序返回给定响应的 mock LLMGateway。""" + gateway = MagicMock(spec=LLMGateway) + gateway.chat = AsyncMock(side_effect=responses) + gateway.get_provider_name_for_model = MagicMock(return_value=None) + return gateway + + +def make_response(content: str = "") -> LLMResponse: + return LLMResponse( + content=content, + model="test-model", + usage=TokenUsage(prompt_tokens=10, completion_tokens=20), + tool_calls=[], + ) + + +def make_verify_result(passed: bool, errors: list[str] | None = None) -> VerificationResult: + return VerificationResult( + passed=passed, + attempts=1, + test_output="$ pytest\nFAILED test_x.py" if not passed else "$ pytest\nOK", + errors=errors or ([] if passed else ["test_x.py::test_failed"]), + ) + + +def make_mock_vloop(verify_results: list[VerificationResult]) -> MagicMock: + """创建一个 mock VerificationLoop,verify() 按顺序返回给定结果。""" + vloop = MagicMock() + vloop.verify = AsyncMock(side_effect=verify_results) + return vloop + + +# ── Characterization: max_reinjections=0 等价于当前行为 ────────── + + +class TestVerifyCharacterization: + """现有 verify 行为(max_reinjections=0):失败仅记录 trajectory,不回灌。""" + + async def test_verify_disabled_no_verify_step(self): + """verification_enabled=False → 不运行 verify,trajectory 无 verification step。""" + gateway = make_mock_gateway([make_response("final answer")]) + engine = ReActEngine(llm_gateway=gateway, max_steps=3) + + result = await engine.execute( + messages=[{"role": "user", "content": "do something"}], + ) + + assert result.output == "final answer" + assert all(s.tool_name != "verification" for s in result.trajectory) + + async def test_verify_pass_no_extra_step(self): + """verify 通过 → 不追加 verification step。""" + gateway = make_mock_gateway([make_response("answer")]) + engine = ReActEngine( + llm_gateway=gateway, + max_steps=3, + verification_enabled=True, + verification_commands=["echo ok"], + max_reinjections=0, + ) + + with patch( + "agentkit.core.verification_loop.VerificationLoop", + return_value=make_mock_vloop([make_verify_result(passed=True)]), + ): + result = await engine.execute( + messages=[{"role": "user", "content": "do something"}], + ) + + assert result.output == "answer" + assert all(s.tool_name != "verification" for s in result.trajectory) + + async def test_verify_fail_max_zero_no_reinjection(self): + """max_reinjections=0 + verify 失败 → 仅记录 trajectory,不回灌 LLM。 + + 这是当前行为的 characterization:gateway.chat 只被调用一次。 + """ + gateway = make_mock_gateway([make_response("bad answer")]) + engine = ReActEngine( + llm_gateway=gateway, + max_steps=3, + verification_enabled=True, + verification_commands=["false"], + max_reinjections=0, + ) + + with patch( + "agentkit.core.verification_loop.VerificationLoop", + return_value=make_mock_vloop([make_verify_result(passed=False)]), + ): + result = await engine.execute( + messages=[{"role": "user", "content": "do something"}], + ) + + # LLM 只被调用一次(无回灌) + assert gateway.chat.await_count == 1 + # 输出仍保留 + assert result.output == "bad answer" + # trajectory 包含 verification step + verify_steps = [s for s in result.trajectory if s.tool_name == "verification"] + assert len(verify_steps) == 1 + assert verify_steps[0].result["passed"] is False + + +# ── R1: 回灌后 LLM 自纠正 → 二次 verify 通过 ────────── + + +class TestVerifyReinjection: + """verify 失败回灌 conversation,LLM 自纠正后二次 verify 通过。""" + + async def test_first_fail_reinject_second_pass(self): + """R1: verify 首次失败 → errors 注入 conversation → LLM 修正 → 二次 verify 通过。""" + gateway = make_mock_gateway( + [ + make_response("bad code"), # 第一次:错误答案 + make_response("fixed code"), # 第二次:修正后答案 + ] + ) + engine = ReActEngine( + llm_gateway=gateway, + max_steps=5, + verification_enabled=True, + verification_commands=["pytest"], + max_reinjections=1, # 默认值,允许 1 次回灌 + ) + + with patch( + "agentkit.core.verification_loop.VerificationLoop", + return_value=make_mock_vloop( + [ + make_verify_result(passed=False), # 第一次 verify 失败 + make_verify_result(passed=True), # 第二次 verify 通过 + ] + ), + ): + result = await engine.execute( + messages=[{"role": "user", "content": "write code"}], + ) + + # LLM 被调用两次 + assert gateway.chat.await_count == 2 + # 最终输出是修正后的答案 + assert result.output == "fixed code" + # 二次 verify 通过,不追加 verification step + verify_steps = [s for s in result.trajectory if s.tool_name == "verification"] + assert len(verify_steps) == 0 + + async def test_reinjected_user_message_appears_in_conversation(self): + """R1 集成:回灌的 user 消息出现在 conversation,含 errors 文本。""" + gateway = make_mock_gateway( + [ + make_response("bad"), + make_response("good"), + ] + ) + engine = ReActEngine( + llm_gateway=gateway, + max_steps=5, + verification_enabled=True, + verification_commands=["pytest"], + max_reinjections=1, + ) + + with patch( + "agentkit.core.verification_loop.VerificationLoop", + return_value=make_mock_vloop( + [ + make_verify_result(passed=False, errors=["AssertionError: x != y"]), + make_verify_result(passed=True), + ] + ), + ): + await engine.execute( + messages=[{"role": "user", "content": "write code"}], + ) + + # 第二次 LLM 调用时,conversation 应包含回灌的 user 消息 + second_call = gateway.chat.await_args_list[1] + msgs_sent = second_call.kwargs.get("messages") or second_call[1].get("messages") + reinjected = [ + m for m in msgs_sent if m.get("role") == "user" and "验证失败" in m.get("content", "") + ] + assert len(reinjected) >= 1 + assert "AssertionError: x != y" in reinjected[-1]["content"] + + +# ── R2: 二次 verify 失败 → 中断返回错误 ────────── + + +class TestVerifyDoubleFailure: + """verify 二次失败 → 中断,返回错误附 verify log。""" + + async def test_second_fail_breaks_with_verify_log(self): + """R2: 二次 verify 失败 → 中断,trajectory 含 verify log + errors。""" + gateway = make_mock_gateway( + [ + make_response("bad v1"), + make_response("bad v2"), + ] + ) + engine = ReActEngine( + llm_gateway=gateway, + max_steps=5, + verification_enabled=True, + verification_commands=["pytest"], + max_reinjections=1, + ) + + with patch( + "agentkit.core.verification_loop.VerificationLoop", + return_value=make_mock_vloop( + [ + make_verify_result(passed=False, errors=["err1"]), + make_verify_result(passed=False, errors=["err2"]), + ] + ), + ): + result = await engine.execute( + messages=[{"role": "user", "content": "write code"}], + ) + + # LLM 被调用两次(initial + 1 reinjection) + assert gateway.chat.await_count == 2 + # 状态标记 verify 失败 + assert result.status == "verify_failed" + # 输出保留(LLM 最后的答案) + assert result.output == "bad v2" + # trajectory 含 verification step with errors + verify_steps = [s for s in result.trajectory if s.tool_name == "verification"] + assert len(verify_steps) == 1 + assert verify_steps[0].result["passed"] is False + assert "err2" in verify_steps[0].result["errors"] + + +# ── R3: 配置 + 边界 ────────── + + +class TestVerifyReinjectionConfig: + """max_reinjections 配置测试。""" + + def test_default_max_reinjections_is_one(self): + """R14 self-check: max_reinjections 默认值为 1。""" + gateway = make_mock_gateway([]) + engine = ReActEngine(llm_gateway=gateway) + assert engine._max_reinjections == 1 + + async def test_max_reinjections_zero_skips_reinjection(self): + """R3: max_reinjections=0 → 等价于不回灌(当前行为)。""" + gateway = make_mock_gateway([make_response("only answer")]) + engine = ReActEngine( + llm_gateway=gateway, + max_steps=5, + verification_enabled=True, + verification_commands=["false"], + max_reinjections=0, + ) + + with patch( + "agentkit.core.verification_loop.VerificationLoop", + return_value=make_mock_vloop([make_verify_result(passed=False)]), + ): + result = await engine.execute( + messages=[{"role": "user", "content": "do something"}], + ) + + assert gateway.chat.await_count == 1 # 无回灌 + assert result.output == "only answer" + + async def test_reinjection_hits_max_steps_interrupts(self): + """R3 edge: 回灌期间达到 max_steps → 中断(不无限循环)。""" + # max_steps=2, max_reinjections=5(max_steps 先到) + # LLM 调用 1:final answer → verify 失败 → reinject + # LLM 调用 2:final answer → verify 失败 → step=2=max_steps → 中断 + gateway = make_mock_gateway( + [ + make_response("ans1"), + make_response("ans2"), + ] + ) + engine = ReActEngine( + llm_gateway=gateway, + max_steps=2, + verification_enabled=True, + verification_commands=["false"], + max_reinjections=5, # 远大于 max_steps,验证 max_steps 优先 + ) + + with patch( + "agentkit.core.verification_loop.VerificationLoop", + return_value=make_mock_vloop( + [ + make_verify_result(passed=False), + make_verify_result(passed=False), + ] + ), + ): + result = await engine.execute( + messages=[{"role": "user", "content": "do something"}], + ) + + # LLM 被调用 2 次(受 max_steps=2 限制) + assert gateway.chat.await_count == 2 + # 第二次 verify 失败,因 max_reinjections=5 未到但 max_steps 到了 + # → 应中断(verify_failed 或 partial) + assert result.status in ("verify_failed", "partial") + + +# ── execute_stream 回灌 ────────── + + +class TestVerifyReinjectionStream: + """execute_stream 模式下的回灌行为。""" + + async def test_stream_reinjection_first_fail_second_pass(self): + """R1 stream: verify 首次失败 → 回灌 → 二次通过。""" + from agentkit.llm.protocol import StreamChunk + + def make_stream_chunks(content: str): + """返回一个返回 chunk 列表的 async generator factory。""" + + async def _stream(**kwargs): + # Simulate streaming: yield content in 2 chunks + mid = len(content) // 2 + yield StreamChunk(content=content[:mid], model="test-model") + yield StreamChunk(content=content[mid:], model="test-model") + + return _stream + + gateway = MagicMock(spec=LLMGateway) + gateway.chat_stream = MagicMock( + side_effect=[ + make_stream_chunks("bad code")(), + make_stream_chunks("fixed code")(), + ] + ) + gateway.get_provider_name_for_model = MagicMock(return_value=None) + + engine = ReActEngine( + llm_gateway=gateway, + max_steps=5, + verification_enabled=True, + verification_commands=["pytest"], + max_reinjections=1, + ) + + with patch( + "agentkit.core.verification_loop.VerificationLoop", + return_value=make_mock_vloop( + [ + make_verify_result(passed=False), + make_verify_result(passed=True), + ] + ), + ): + events = [] + async for event in engine.execute_stream( + messages=[{"role": "user", "content": "write code"}], + ): + events.append(event) + + # chat_stream 被调用两次 + assert gateway.chat_stream.call_count == 2 + # 有 final_answer 事件 + final_events = [e for e in events if e.event_type == "final_answer"] + assert len(final_events) >= 1 + # 最终输出是修正后的答案 + assert "fixed code" in final_events[-1].data.get("output", "")