feat(U4): G1 verify 失败回灌 ReAct

- ReActEngine 新增 max_reinjections 构造参数(默认 1,=0 等价原行为)
- execute()/execute_stream() verify 块从循环后移到循环内 final-answer 检测点:
  - verify 通过 → 正常 break
  - verify 失败 + reinjections < max + step < max_steps → errors 作为 user 消息回灌 conversation, continue 让 LLM 自纠正
  - verify 失败 + 达到 max_reinjections 或 max_steps → 记录 verify log 到 trajectory, trace_outcome="verify_failed", break
- execute_stream 的 final_answer 事件在 verify 通过后才 yield,避免客户端过早收到完成信号
- ReActResult.status 现在传递 trace_outcome(原默认 "success")
- ServerConfig.verification 配置项(max_reinjections)
- test_verify_reinjection.py 10 测试:characterization(max=0)+ 新行为(R1/R2/R3/R14)
This commit is contained in:
chiguyong 2026-06-29 21:35:08 +08:00
parent 0f3f0a7550
commit cd211c6cd9
3 changed files with 574 additions and 97 deletions

View File

@ -165,6 +165,7 @@ class ReActEngine:
middleware_chain: "MiddlewareChain | None" = None,
prompt_cache_enable: bool = True,
flush_interval_ms: int = 0,
max_reinjections: int = 1,
):
if max_steps < 1:
raise ValueError(f"max_steps must be >= 1, got {max_steps}")
@ -184,6 +185,10 @@ class ReActEngine:
# U3/G8: token chunk 节流间隔(ms)。0 = 逐 chunk yield(向后兼容)。
# 用 time.monotonic() 不受系统时钟跳变影响。
self._flush_interval_ms = flush_interval_ms
# U4/G1: verify 失败回灌最大重试次数。0 = 不回灌(当前行为,仅记录 trajectory);
# 1 = 首次失败回灌一次 errors 给 LLM 自纠正,二次失败中断。
# 受 max_steps 上限约束(不无限循环)。verification_enabled=False 时无效。
self._max_reinjections = max_reinjections
# Tiered tool description injection config
self._core_tool_names: tuple[str, ...] | None = (
tuple(core_tool_names) if core_tool_names is not None else None
@ -445,11 +450,14 @@ class ReActEngine:
query = str(messages[-1].get("content", "")) if messages else ""
top_k = (retrieval_config or {}).get("top_k", 5)
token_budget = (retrieval_config or {}).get("token_budget", 2000)
memory_context = await memory_retriever.get_context_string(
memory_context = (
await memory_retriever.get_context_string(
query=query,
top_k=top_k,
token_budget=token_budget,
) or ""
)
or ""
)
except Exception as e:
logger.warning(
f"Memory retrieval failed, continuing without context: {e}", exc_info=True
@ -478,6 +486,8 @@ class ReActEngine:
trace_outcome = "success"
step = 0
output = ""
# U4/G1: verify 失败回灌计数器。受 max_steps 上限约束(不无限循环)。
reinjections = 0
while step < self._max_steps:
step += 1
@ -854,16 +864,14 @@ class ReActEngine:
f"Step {step}: content contains <tool_use> but "
f"parsing failed — injecting correction"
)
conversation.append(
{"role": "assistant", "content": response.content}
)
conversation.append({"role": "assistant", "content": response.content})
conversation.append(
{
"role": "user",
"content": (
"你上一次的工具调用格式有误,无法解析。"
"请使用正确的格式重新调用工具:\n"
'<tool_use>\n'
"<tool_use>\n"
'{"name": "工具名", "arguments": {"参数名": "参数值"}}\n'
"</tool_use>\n"
"确保 JSON 完整且不要混入其他标签。"
@ -891,9 +899,11 @@ class ReActEngine:
duration_ms=llm_duration_ms,
tokens_used=step_tokens,
)
break
# Verification: 如果启用验证,在 final answer 后运行测试
# U4/G1: verify at final-answer point with reinjection.
# 原为循环后一次性运行;现改为循环内检测 final answer 后立即 verify,
# 失败则把 errors 作为 user 消息回灌 conversation,continue 主循环让 LLM 自纠正。
# max_reinjections=0 等价于原行为(仅记录 trajectory,不回灌)。
if self._verification_enabled and output:
try:
from agentkit.core.verification_loop import VerificationLoop
@ -901,9 +911,29 @@ class ReActEngine:
vloop = VerificationLoop(commands=self._verification_commands)
vresult = await vloop.verify()
if not vresult.passed:
# 将验证失败信息作为 ReActStep 添加到轨迹
if (
reinjections < self._max_reinjections
and step < self._max_steps
):
# 回灌 errors 作为 user 消息,让 LLM 自纠正
errors_text = "\n".join(vresult.errors)
conversation.append(
{
"role": "user",
"content": (f"验证失败,错误如下:\n{errors_text}"),
}
)
reinjections += 1
logger.info(
"Verification failed (reinjection %d/%d), "
"errors injected into conversation",
reinjections,
self._max_reinjections,
)
continue
# 达到 max_reinjections 或 max_steps → 记录 verify log 并中断
verification_step = ReActStep(
step=step + 1,
step=step,
action="tool_call",
tool_name="verification",
arguments={"commands": self._verification_commands},
@ -912,16 +942,23 @@ class ReActEngine:
"errors": vresult.errors,
"test_output": vresult.test_output,
},
content=(f"Verification failed:\n{vresult.test_output[:2000]}"),
content=(
f"Verification failed:\n{vresult.test_output[:2000]}"
),
)
trajectory.append(verification_step)
trace_outcome = "verify_failed"
logger.info(
"Verification failed after final answer, "
"appended feedback to trajectory"
"Verification failed after %d reinjections, "
"interrupting with verify log",
reinjections,
)
break
except Exception as e:
logger.warning(f"Verification loop failed: {e}")
break # verify 通过或未启用 → 正常退出
# 达到 max_steps 时,返回当前最佳输出
if step >= self._max_steps and not output:
trace_outcome = "partial"
@ -964,6 +1001,7 @@ class ReActEngine:
trajectory=trajectory,
total_steps=len(trajectory),
total_tokens=total_tokens,
status=trace_outcome,
)
finally:
# Telemetry: end span and record duration — always runs
@ -1058,11 +1096,14 @@ class ReActEngine:
query = str(messages[-1].get("content", "")) if messages else ""
top_k = (retrieval_config or {}).get("top_k", 5)
token_budget = (retrieval_config or {}).get("token_budget", 2000)
memory_context = await memory_retriever.get_context_string(
memory_context = (
await memory_retriever.get_context_string(
query=query,
top_k=top_k,
token_budget=token_budget,
) or ""
)
or ""
)
except Exception as e:
logger.warning(f"Memory retrieval failed, continuing without context: {e}")
@ -1090,6 +1131,8 @@ class ReActEngine:
step = 0
output = ""
trace_outcome = "success"
# U4/G1: verify 失败回灌计数器(execute_stream 版)。受 max_steps 上限约束。
reinjections = 0
_stream_start = time.monotonic()
effective_timeout = (
timeout_seconds if timeout_seconds is not None else self._default_timeout
@ -1579,16 +1622,14 @@ class ReActEngine:
f"Step {step}: content contains <tool_use> but "
f"parsing failed — injecting correction (stream)"
)
conversation.append(
{"role": "assistant", "content": response.content}
)
conversation.append({"role": "assistant", "content": response.content})
conversation.append(
{
"role": "user",
"content": (
"你上一次的工具调用格式有误,无法解析。"
"请使用正确的格式重新调用工具:\n"
'<tool_use>\n'
"<tool_use>\n"
'{"name": "工具名", "arguments": {"参数名": "参数值"}}\n'
"</tool_use>\n"
"确保 JSON 完整且不要混入其他标签。"
@ -1622,18 +1663,10 @@ class ReActEngine:
tokens_used=step_tokens,
)
yield ReActEvent(
event_type="final_answer",
step=step,
data={
"output": output,
"total_steps": len(trajectory),
"total_tokens": total_tokens,
},
)
break
# Verification: 如果启用验证,在 final answer 后运行测试
# U4/G1: verify at final-answer point with reinjection (stream 版)。
# 与 execute() 同模式:失败回灌 errors 作为 user 消息,continue 主循环。
# max_reinjections=0 等价于原行为(仅记录 trajectory,不回灌)。
# 注意:final_answer 事件在 verify 通过后才 yield,避免客户端过早收到完成信号。
if self._verification_enabled and output:
try:
from agentkit.core.verification_loop import VerificationLoop
@ -1641,8 +1674,34 @@ class ReActEngine:
vloop = VerificationLoop(commands=self._verification_commands)
vresult = await vloop.verify()
if not vresult.passed:
if (
reinjections < self._max_reinjections
and step < self._max_steps
):
# 回灌 errors,不发 final_answer 事件,继续循环
errors_text = "\n".join(vresult.errors)
conversation.append(
{
"role": "user",
"content": (f"验证失败,错误如下:\n{errors_text}"),
}
)
reinjections += 1
yield ReActEvent(
event_type="step",
step=step,
data={
"message": (
f"验证失败,已注入错误信息让 LLM 自纠正 "
f"(reinjection {reinjections}/{self._max_reinjections})"
),
"verify_errors": vresult.errors,
},
)
continue
# 达到 max_reinjections 或 max_steps → 记录 verify log 并中断
verification_step = ReActStep(
step=step + 1,
step=step,
action="tool_call",
tool_name="verification",
arguments={"commands": self._verification_commands},
@ -1651,12 +1710,15 @@ class ReActEngine:
"errors": vresult.errors,
"test_output": vresult.test_output,
},
content=(f"Verification failed:\n{vresult.test_output[:2000]}"),
content=(
f"Verification failed:\n{vresult.test_output[:2000]}"
),
)
trajectory.append(verification_step)
trace_outcome = "verify_failed"
yield ReActEvent(
event_type="tool_result",
step=step + 1,
step=step,
data={
"tool_name": "verification",
"result": {
@ -1667,12 +1729,25 @@ class ReActEngine:
},
)
logger.info(
"Verification failed after final answer, "
"appended feedback to trajectory"
"Verification failed after %d reinjections, "
"interrupting with verify log",
reinjections,
)
break
except Exception as e:
logger.warning(f"Verification loop failed: {e}")
yield ReActEvent(
event_type="final_answer",
step=step,
data={
"output": output,
"total_steps": len(trajectory),
"total_tokens": total_tokens,
},
)
break # verify 通过或未启用 → 正常退出
if step >= self._max_steps and not output:
trace_outcome = "partial"
if trajectory and trajectory[-1].content:
@ -1784,16 +1859,20 @@ class ReActEngine:
if provider_name == "anthropic":
blocks: list[dict[str, Any]] = []
if stable:
blocks.append({
blocks.append(
{
"type": "text",
"text": stable,
"cache_control": {"type": "ephemeral"},
})
}
)
if volatile:
blocks.append({
blocks.append(
{
"type": "text",
"text": f"## 参考信息\n{volatile}",
})
}
)
return blocks if blocks else None
# 非 Anthropic:字符串拼接,stable 前缀命中 OpenAI/DashScope 自动前缀缓存
@ -2213,7 +2292,7 @@ class ReActEngine:
# 兜底解析:从 <tool_use> 后提取 JSON 片段,用大括号匹配法恢复完整 JSON
open_pattern = re.compile(r"<tool_use>\s*", re.IGNORECASE)
for match in open_pattern.finditer(content):
remainder = content[match.end():]
remainder = content[match.end() :]
parsed = self._extract_tool_call_from_malformed(remainder)
if parsed:
calls.append(parsed)

View File

@ -118,6 +118,7 @@ class ServerConfig:
board: dict[str, Any] | None = None,
prompt_cache: dict[str, Any] | None = None,
streaming: dict[str, Any] | None = None,
verification: dict[str, Any] | None = None,
on_change: Callable[["ServerConfig"], None] | None = None,
):
self.host = host
@ -149,6 +150,9 @@ class ServerConfig:
self.prompt_cache = prompt_cache or {}
# U3/G8: streaming.flush_interval_ms 控制 token chunk 节流(默认 0 = 逐 chunk yield)
self.streaming = streaming or {}
# U4/G1: verification.max_reinjections 控制 verify 失败回灌次数(默认 1)
# verification_enabled=False 时此配置无效
self.verification = verification or {}
self.on_change = on_change
# Config watching state

View File

@ -0,0 +1,394 @@
"""U4/G1: Verify 失败回灌 ReAct 测试
characterization-first: 先覆盖现有 verify 行为(max_reinjections=0 等价于不回灌),
再测新回灌行为(reinjection on first fail, break on second fail)
R1: verify 首次失败 errors 注入 conversation LLM 自纠正 二次 verify 通过
R2: verify 二次失败 中断返回错误附 verify log
R3: max_reinjections 可配置(默认 1),=0 等价于不回灌;回灌受 max_steps 约束
R13: ServerConfig.verification 配置项
R14: max_reinjections 默认值为 1
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
from agentkit.core.react import ReActEngine
from agentkit.core.verification_loop import VerificationResult
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMResponse, TokenUsage
# ── Helpers ──────────────────────────────────────────────
def make_mock_gateway(responses: list[LLMResponse]) -> LLMGateway:
"""创建按顺序返回给定响应的 mock LLMGateway。"""
gateway = MagicMock(spec=LLMGateway)
gateway.chat = AsyncMock(side_effect=responses)
gateway.get_provider_name_for_model = MagicMock(return_value=None)
return gateway
def make_response(content: str = "") -> LLMResponse:
return LLMResponse(
content=content,
model="test-model",
usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
tool_calls=[],
)
def make_verify_result(passed: bool, errors: list[str] | None = None) -> VerificationResult:
return VerificationResult(
passed=passed,
attempts=1,
test_output="$ pytest\nFAILED test_x.py" if not passed else "$ pytest\nOK",
errors=errors or ([] if passed else ["test_x.py::test_failed"]),
)
def make_mock_vloop(verify_results: list[VerificationResult]) -> MagicMock:
"""创建一个 mock VerificationLoop,verify() 按顺序返回给定结果。"""
vloop = MagicMock()
vloop.verify = AsyncMock(side_effect=verify_results)
return vloop
# ── Characterization: max_reinjections=0 等价于当前行为 ──────────
class TestVerifyCharacterization:
"""现有 verify 行为(max_reinjections=0):失败仅记录 trajectory,不回灌。"""
async def test_verify_disabled_no_verify_step(self):
"""verification_enabled=False → 不运行 verify,trajectory 无 verification step。"""
gateway = make_mock_gateway([make_response("final answer")])
engine = ReActEngine(llm_gateway=gateway, max_steps=3)
result = await engine.execute(
messages=[{"role": "user", "content": "do something"}],
)
assert result.output == "final answer"
assert all(s.tool_name != "verification" for s in result.trajectory)
async def test_verify_pass_no_extra_step(self):
"""verify 通过 → 不追加 verification step。"""
gateway = make_mock_gateway([make_response("answer")])
engine = ReActEngine(
llm_gateway=gateway,
max_steps=3,
verification_enabled=True,
verification_commands=["echo ok"],
max_reinjections=0,
)
with patch(
"agentkit.core.verification_loop.VerificationLoop",
return_value=make_mock_vloop([make_verify_result(passed=True)]),
):
result = await engine.execute(
messages=[{"role": "user", "content": "do something"}],
)
assert result.output == "answer"
assert all(s.tool_name != "verification" for s in result.trajectory)
async def test_verify_fail_max_zero_no_reinjection(self):
"""max_reinjections=0 + verify 失败 → 仅记录 trajectory,不回灌 LLM。
这是当前行为的 characterization:gateway.chat 只被调用一次
"""
gateway = make_mock_gateway([make_response("bad answer")])
engine = ReActEngine(
llm_gateway=gateway,
max_steps=3,
verification_enabled=True,
verification_commands=["false"],
max_reinjections=0,
)
with patch(
"agentkit.core.verification_loop.VerificationLoop",
return_value=make_mock_vloop([make_verify_result(passed=False)]),
):
result = await engine.execute(
messages=[{"role": "user", "content": "do something"}],
)
# LLM 只被调用一次(无回灌)
assert gateway.chat.await_count == 1
# 输出仍保留
assert result.output == "bad answer"
# trajectory 包含 verification step
verify_steps = [s for s in result.trajectory if s.tool_name == "verification"]
assert len(verify_steps) == 1
assert verify_steps[0].result["passed"] is False
# ── R1: 回灌后 LLM 自纠正 → 二次 verify 通过 ──────────
class TestVerifyReinjection:
"""verify 失败回灌 conversation,LLM 自纠正后二次 verify 通过。"""
async def test_first_fail_reinject_second_pass(self):
"""R1: verify 首次失败 → errors 注入 conversation → LLM 修正 → 二次 verify 通过。"""
gateway = make_mock_gateway(
[
make_response("bad code"), # 第一次:错误答案
make_response("fixed code"), # 第二次:修正后答案
]
)
engine = ReActEngine(
llm_gateway=gateway,
max_steps=5,
verification_enabled=True,
verification_commands=["pytest"],
max_reinjections=1, # 默认值,允许 1 次回灌
)
with patch(
"agentkit.core.verification_loop.VerificationLoop",
return_value=make_mock_vloop(
[
make_verify_result(passed=False), # 第一次 verify 失败
make_verify_result(passed=True), # 第二次 verify 通过
]
),
):
result = await engine.execute(
messages=[{"role": "user", "content": "write code"}],
)
# LLM 被调用两次
assert gateway.chat.await_count == 2
# 最终输出是修正后的答案
assert result.output == "fixed code"
# 二次 verify 通过,不追加 verification step
verify_steps = [s for s in result.trajectory if s.tool_name == "verification"]
assert len(verify_steps) == 0
async def test_reinjected_user_message_appears_in_conversation(self):
"""R1 集成:回灌的 user 消息出现在 conversation,含 errors 文本。"""
gateway = make_mock_gateway(
[
make_response("bad"),
make_response("good"),
]
)
engine = ReActEngine(
llm_gateway=gateway,
max_steps=5,
verification_enabled=True,
verification_commands=["pytest"],
max_reinjections=1,
)
with patch(
"agentkit.core.verification_loop.VerificationLoop",
return_value=make_mock_vloop(
[
make_verify_result(passed=False, errors=["AssertionError: x != y"]),
make_verify_result(passed=True),
]
),
):
await engine.execute(
messages=[{"role": "user", "content": "write code"}],
)
# 第二次 LLM 调用时,conversation 应包含回灌的 user 消息
second_call = gateway.chat.await_args_list[1]
msgs_sent = second_call.kwargs.get("messages") or second_call[1].get("messages")
reinjected = [
m for m in msgs_sent if m.get("role") == "user" and "验证失败" in m.get("content", "")
]
assert len(reinjected) >= 1
assert "AssertionError: x != y" in reinjected[-1]["content"]
# ── R2: 二次 verify 失败 → 中断返回错误 ──────────
class TestVerifyDoubleFailure:
"""verify 二次失败 → 中断,返回错误附 verify log。"""
async def test_second_fail_breaks_with_verify_log(self):
"""R2: 二次 verify 失败 → 中断,trajectory 含 verify log + errors。"""
gateway = make_mock_gateway(
[
make_response("bad v1"),
make_response("bad v2"),
]
)
engine = ReActEngine(
llm_gateway=gateway,
max_steps=5,
verification_enabled=True,
verification_commands=["pytest"],
max_reinjections=1,
)
with patch(
"agentkit.core.verification_loop.VerificationLoop",
return_value=make_mock_vloop(
[
make_verify_result(passed=False, errors=["err1"]),
make_verify_result(passed=False, errors=["err2"]),
]
),
):
result = await engine.execute(
messages=[{"role": "user", "content": "write code"}],
)
# LLM 被调用两次(initial + 1 reinjection)
assert gateway.chat.await_count == 2
# 状态标记 verify 失败
assert result.status == "verify_failed"
# 输出保留(LLM 最后的答案)
assert result.output == "bad v2"
# trajectory 含 verification step with errors
verify_steps = [s for s in result.trajectory if s.tool_name == "verification"]
assert len(verify_steps) == 1
assert verify_steps[0].result["passed"] is False
assert "err2" in verify_steps[0].result["errors"]
# ── R3: 配置 + 边界 ──────────
class TestVerifyReinjectionConfig:
"""max_reinjections 配置测试。"""
def test_default_max_reinjections_is_one(self):
"""R14 self-check: max_reinjections 默认值为 1。"""
gateway = make_mock_gateway([])
engine = ReActEngine(llm_gateway=gateway)
assert engine._max_reinjections == 1
async def test_max_reinjections_zero_skips_reinjection(self):
"""R3: max_reinjections=0 → 等价于不回灌(当前行为)。"""
gateway = make_mock_gateway([make_response("only answer")])
engine = ReActEngine(
llm_gateway=gateway,
max_steps=5,
verification_enabled=True,
verification_commands=["false"],
max_reinjections=0,
)
with patch(
"agentkit.core.verification_loop.VerificationLoop",
return_value=make_mock_vloop([make_verify_result(passed=False)]),
):
result = await engine.execute(
messages=[{"role": "user", "content": "do something"}],
)
assert gateway.chat.await_count == 1 # 无回灌
assert result.output == "only answer"
async def test_reinjection_hits_max_steps_interrupts(self):
"""R3 edge: 回灌期间达到 max_steps → 中断(不无限循环)。"""
# max_steps=2, max_reinjections=5(max_steps 先到)
# LLM 调用 1:final answer → verify 失败 → reinject
# LLM 调用 2:final answer → verify 失败 → step=2=max_steps → 中断
gateway = make_mock_gateway(
[
make_response("ans1"),
make_response("ans2"),
]
)
engine = ReActEngine(
llm_gateway=gateway,
max_steps=2,
verification_enabled=True,
verification_commands=["false"],
max_reinjections=5, # 远大于 max_steps,验证 max_steps 优先
)
with patch(
"agentkit.core.verification_loop.VerificationLoop",
return_value=make_mock_vloop(
[
make_verify_result(passed=False),
make_verify_result(passed=False),
]
),
):
result = await engine.execute(
messages=[{"role": "user", "content": "do something"}],
)
# LLM 被调用 2 次(受 max_steps=2 限制)
assert gateway.chat.await_count == 2
# 第二次 verify 失败,因 max_reinjections=5 未到但 max_steps 到了
# → 应中断(verify_failed 或 partial)
assert result.status in ("verify_failed", "partial")
# ── execute_stream 回灌 ──────────
class TestVerifyReinjectionStream:
"""execute_stream 模式下的回灌行为。"""
async def test_stream_reinjection_first_fail_second_pass(self):
"""R1 stream: verify 首次失败 → 回灌 → 二次通过。"""
from agentkit.llm.protocol import StreamChunk
def make_stream_chunks(content: str):
"""返回一个返回 chunk 列表的 async generator factory。"""
async def _stream(**kwargs):
# Simulate streaming: yield content in 2 chunks
mid = len(content) // 2
yield StreamChunk(content=content[:mid], model="test-model")
yield StreamChunk(content=content[mid:], model="test-model")
return _stream
gateway = MagicMock(spec=LLMGateway)
gateway.chat_stream = MagicMock(
side_effect=[
make_stream_chunks("bad code")(),
make_stream_chunks("fixed code")(),
]
)
gateway.get_provider_name_for_model = MagicMock(return_value=None)
engine = ReActEngine(
llm_gateway=gateway,
max_steps=5,
verification_enabled=True,
verification_commands=["pytest"],
max_reinjections=1,
)
with patch(
"agentkit.core.verification_loop.VerificationLoop",
return_value=make_mock_vloop(
[
make_verify_result(passed=False),
make_verify_result(passed=True),
]
),
):
events = []
async for event in engine.execute_stream(
messages=[{"role": "user", "content": "write code"}],
):
events.append(event)
# chat_stream 被调用两次
assert gateway.chat_stream.call_count == 2
# 有 final_answer 事件
final_events = [e for e in events if e.event_type == "final_answer"]
assert len(final_events) >= 1
# 最终输出是修正后的答案
assert "fixed code" in final_events[-1].data.get("output", "")