feat(agent): Wave 1 quick wins (G1/G2/G3/G8) + review fixes #4
|
|
@ -165,6 +165,7 @@ class ReActEngine:
|
|||
middleware_chain: "MiddlewareChain | None" = None,
|
||||
prompt_cache_enable: bool = True,
|
||||
flush_interval_ms: int = 0,
|
||||
max_reinjections: int = 1,
|
||||
):
|
||||
if max_steps < 1:
|
||||
raise ValueError(f"max_steps must be >= 1, got {max_steps}")
|
||||
|
|
@ -184,6 +185,10 @@ class ReActEngine:
|
|||
# U3/G8: token chunk 节流间隔(ms)。0 = 逐 chunk yield(向后兼容)。
|
||||
# 用 time.monotonic() 不受系统时钟跳变影响。
|
||||
self._flush_interval_ms = flush_interval_ms
|
||||
# U4/G1: verify 失败回灌最大重试次数。0 = 不回灌(当前行为,仅记录 trajectory);
|
||||
# 1 = 首次失败回灌一次 errors 给 LLM 自纠正,二次失败中断。
|
||||
# 受 max_steps 上限约束(不无限循环)。verification_enabled=False 时无效。
|
||||
self._max_reinjections = max_reinjections
|
||||
# Tiered tool description injection config
|
||||
self._core_tool_names: tuple[str, ...] | None = (
|
||||
tuple(core_tool_names) if core_tool_names is not None else None
|
||||
|
|
@ -445,11 +450,14 @@ class ReActEngine:
|
|||
query = str(messages[-1].get("content", "")) if messages else ""
|
||||
top_k = (retrieval_config or {}).get("top_k", 5)
|
||||
token_budget = (retrieval_config or {}).get("token_budget", 2000)
|
||||
memory_context = await memory_retriever.get_context_string(
|
||||
memory_context = (
|
||||
await memory_retriever.get_context_string(
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
token_budget=token_budget,
|
||||
) or ""
|
||||
)
|
||||
or ""
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Memory retrieval failed, continuing without context: {e}", exc_info=True
|
||||
|
|
@ -478,6 +486,8 @@ class ReActEngine:
|
|||
trace_outcome = "success"
|
||||
step = 0
|
||||
output = ""
|
||||
# U4/G1: verify 失败回灌计数器。受 max_steps 上限约束(不无限循环)。
|
||||
reinjections = 0
|
||||
|
||||
while step < self._max_steps:
|
||||
step += 1
|
||||
|
|
@ -854,16 +864,14 @@ class ReActEngine:
|
|||
f"Step {step}: content contains <tool_use> but "
|
||||
f"parsing failed — injecting correction"
|
||||
)
|
||||
conversation.append(
|
||||
{"role": "assistant", "content": response.content}
|
||||
)
|
||||
conversation.append({"role": "assistant", "content": response.content})
|
||||
conversation.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"你上一次的工具调用格式有误,无法解析。"
|
||||
"请使用正确的格式重新调用工具:\n"
|
||||
'<tool_use>\n'
|
||||
"<tool_use>\n"
|
||||
'{"name": "工具名", "arguments": {"参数名": "参数值"}}\n'
|
||||
"</tool_use>\n"
|
||||
"确保 JSON 完整且不要混入其他标签。"
|
||||
|
|
@ -891,9 +899,11 @@ class ReActEngine:
|
|||
duration_ms=llm_duration_ms,
|
||||
tokens_used=step_tokens,
|
||||
)
|
||||
break
|
||||
|
||||
# Verification: 如果启用验证,在 final answer 后运行测试
|
||||
# U4/G1: verify at final-answer point with reinjection.
|
||||
# 原为循环后一次性运行;现改为循环内检测 final answer 后立即 verify,
|
||||
# 失败则把 errors 作为 user 消息回灌 conversation,continue 主循环让 LLM 自纠正。
|
||||
# max_reinjections=0 等价于原行为(仅记录 trajectory,不回灌)。
|
||||
if self._verification_enabled and output:
|
||||
try:
|
||||
from agentkit.core.verification_loop import VerificationLoop
|
||||
|
|
@ -901,9 +911,29 @@ class ReActEngine:
|
|||
vloop = VerificationLoop(commands=self._verification_commands)
|
||||
vresult = await vloop.verify()
|
||||
if not vresult.passed:
|
||||
# 将验证失败信息作为 ReActStep 添加到轨迹
|
||||
if (
|
||||
reinjections < self._max_reinjections
|
||||
and step < self._max_steps
|
||||
):
|
||||
# 回灌 errors 作为 user 消息,让 LLM 自纠正
|
||||
errors_text = "\n".join(vresult.errors)
|
||||
conversation.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": (f"验证失败,错误如下:\n{errors_text}"),
|
||||
}
|
||||
)
|
||||
reinjections += 1
|
||||
logger.info(
|
||||
"Verification failed (reinjection %d/%d), "
|
||||
"errors injected into conversation",
|
||||
reinjections,
|
||||
self._max_reinjections,
|
||||
)
|
||||
continue
|
||||
# 达到 max_reinjections 或 max_steps → 记录 verify log 并中断
|
||||
verification_step = ReActStep(
|
||||
step=step + 1,
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name="verification",
|
||||
arguments={"commands": self._verification_commands},
|
||||
|
|
@ -912,16 +942,23 @@ class ReActEngine:
|
|||
"errors": vresult.errors,
|
||||
"test_output": vresult.test_output,
|
||||
},
|
||||
content=(f"Verification failed:\n{vresult.test_output[:2000]}"),
|
||||
content=(
|
||||
f"Verification failed:\n{vresult.test_output[:2000]}"
|
||||
),
|
||||
)
|
||||
trajectory.append(verification_step)
|
||||
trace_outcome = "verify_failed"
|
||||
logger.info(
|
||||
"Verification failed after final answer, "
|
||||
"appended feedback to trajectory"
|
||||
"Verification failed after %d reinjections, "
|
||||
"interrupting with verify log",
|
||||
reinjections,
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Verification loop failed: {e}")
|
||||
|
||||
break # verify 通过或未启用 → 正常退出
|
||||
|
||||
# 达到 max_steps 时,返回当前最佳输出
|
||||
if step >= self._max_steps and not output:
|
||||
trace_outcome = "partial"
|
||||
|
|
@ -964,6 +1001,7 @@ class ReActEngine:
|
|||
trajectory=trajectory,
|
||||
total_steps=len(trajectory),
|
||||
total_tokens=total_tokens,
|
||||
status=trace_outcome,
|
||||
)
|
||||
finally:
|
||||
# Telemetry: end span and record duration — always runs
|
||||
|
|
@ -1058,11 +1096,14 @@ class ReActEngine:
|
|||
query = str(messages[-1].get("content", "")) if messages else ""
|
||||
top_k = (retrieval_config or {}).get("top_k", 5)
|
||||
token_budget = (retrieval_config or {}).get("token_budget", 2000)
|
||||
memory_context = await memory_retriever.get_context_string(
|
||||
memory_context = (
|
||||
await memory_retriever.get_context_string(
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
token_budget=token_budget,
|
||||
) or ""
|
||||
)
|
||||
or ""
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Memory retrieval failed, continuing without context: {e}")
|
||||
|
||||
|
|
@ -1090,6 +1131,8 @@ class ReActEngine:
|
|||
step = 0
|
||||
output = ""
|
||||
trace_outcome = "success"
|
||||
# U4/G1: verify 失败回灌计数器(execute_stream 版)。受 max_steps 上限约束。
|
||||
reinjections = 0
|
||||
_stream_start = time.monotonic()
|
||||
effective_timeout = (
|
||||
timeout_seconds if timeout_seconds is not None else self._default_timeout
|
||||
|
|
@ -1579,16 +1622,14 @@ class ReActEngine:
|
|||
f"Step {step}: content contains <tool_use> but "
|
||||
f"parsing failed — injecting correction (stream)"
|
||||
)
|
||||
conversation.append(
|
||||
{"role": "assistant", "content": response.content}
|
||||
)
|
||||
conversation.append({"role": "assistant", "content": response.content})
|
||||
conversation.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"你上一次的工具调用格式有误,无法解析。"
|
||||
"请使用正确的格式重新调用工具:\n"
|
||||
'<tool_use>\n'
|
||||
"<tool_use>\n"
|
||||
'{"name": "工具名", "arguments": {"参数名": "参数值"}}\n'
|
||||
"</tool_use>\n"
|
||||
"确保 JSON 完整且不要混入其他标签。"
|
||||
|
|
@ -1622,18 +1663,10 @@ class ReActEngine:
|
|||
tokens_used=step_tokens,
|
||||
)
|
||||
|
||||
yield ReActEvent(
|
||||
event_type="final_answer",
|
||||
step=step,
|
||||
data={
|
||||
"output": output,
|
||||
"total_steps": len(trajectory),
|
||||
"total_tokens": total_tokens,
|
||||
},
|
||||
)
|
||||
break
|
||||
|
||||
# Verification: 如果启用验证,在 final answer 后运行测试
|
||||
# U4/G1: verify at final-answer point with reinjection (stream 版)。
|
||||
# 与 execute() 同模式:失败回灌 errors 作为 user 消息,continue 主循环。
|
||||
# max_reinjections=0 等价于原行为(仅记录 trajectory,不回灌)。
|
||||
# 注意:final_answer 事件在 verify 通过后才 yield,避免客户端过早收到完成信号。
|
||||
if self._verification_enabled and output:
|
||||
try:
|
||||
from agentkit.core.verification_loop import VerificationLoop
|
||||
|
|
@ -1641,8 +1674,34 @@ class ReActEngine:
|
|||
vloop = VerificationLoop(commands=self._verification_commands)
|
||||
vresult = await vloop.verify()
|
||||
if not vresult.passed:
|
||||
if (
|
||||
reinjections < self._max_reinjections
|
||||
and step < self._max_steps
|
||||
):
|
||||
# 回灌 errors,不发 final_answer 事件,继续循环
|
||||
errors_text = "\n".join(vresult.errors)
|
||||
conversation.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": (f"验证失败,错误如下:\n{errors_text}"),
|
||||
}
|
||||
)
|
||||
reinjections += 1
|
||||
yield ReActEvent(
|
||||
event_type="step",
|
||||
step=step,
|
||||
data={
|
||||
"message": (
|
||||
f"验证失败,已注入错误信息让 LLM 自纠正 "
|
||||
f"(reinjection {reinjections}/{self._max_reinjections})"
|
||||
),
|
||||
"verify_errors": vresult.errors,
|
||||
},
|
||||
)
|
||||
continue
|
||||
# 达到 max_reinjections 或 max_steps → 记录 verify log 并中断
|
||||
verification_step = ReActStep(
|
||||
step=step + 1,
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name="verification",
|
||||
arguments={"commands": self._verification_commands},
|
||||
|
|
@ -1651,12 +1710,15 @@ class ReActEngine:
|
|||
"errors": vresult.errors,
|
||||
"test_output": vresult.test_output,
|
||||
},
|
||||
content=(f"Verification failed:\n{vresult.test_output[:2000]}"),
|
||||
content=(
|
||||
f"Verification failed:\n{vresult.test_output[:2000]}"
|
||||
),
|
||||
)
|
||||
trajectory.append(verification_step)
|
||||
trace_outcome = "verify_failed"
|
||||
yield ReActEvent(
|
||||
event_type="tool_result",
|
||||
step=step + 1,
|
||||
step=step,
|
||||
data={
|
||||
"tool_name": "verification",
|
||||
"result": {
|
||||
|
|
@ -1667,12 +1729,25 @@ class ReActEngine:
|
|||
},
|
||||
)
|
||||
logger.info(
|
||||
"Verification failed after final answer, "
|
||||
"appended feedback to trajectory"
|
||||
"Verification failed after %d reinjections, "
|
||||
"interrupting with verify log",
|
||||
reinjections,
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Verification loop failed: {e}")
|
||||
|
||||
yield ReActEvent(
|
||||
event_type="final_answer",
|
||||
step=step,
|
||||
data={
|
||||
"output": output,
|
||||
"total_steps": len(trajectory),
|
||||
"total_tokens": total_tokens,
|
||||
},
|
||||
)
|
||||
break # verify 通过或未启用 → 正常退出
|
||||
|
||||
if step >= self._max_steps and not output:
|
||||
trace_outcome = "partial"
|
||||
if trajectory and trajectory[-1].content:
|
||||
|
|
@ -1784,16 +1859,20 @@ class ReActEngine:
|
|||
if provider_name == "anthropic":
|
||||
blocks: list[dict[str, Any]] = []
|
||||
if stable:
|
||||
blocks.append({
|
||||
blocks.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": stable,
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
})
|
||||
}
|
||||
)
|
||||
if volatile:
|
||||
blocks.append({
|
||||
blocks.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"## 参考信息\n{volatile}",
|
||||
})
|
||||
}
|
||||
)
|
||||
return blocks if blocks else None
|
||||
|
||||
# 非 Anthropic:字符串拼接,stable 前缀命中 OpenAI/DashScope 自动前缀缓存
|
||||
|
|
@ -2213,7 +2292,7 @@ class ReActEngine:
|
|||
# 兜底解析:从 <tool_use> 后提取 JSON 片段,用大括号匹配法恢复完整 JSON
|
||||
open_pattern = re.compile(r"<tool_use>\s*", re.IGNORECASE)
|
||||
for match in open_pattern.finditer(content):
|
||||
remainder = content[match.end():]
|
||||
remainder = content[match.end() :]
|
||||
parsed = self._extract_tool_call_from_malformed(remainder)
|
||||
if parsed:
|
||||
calls.append(parsed)
|
||||
|
|
|
|||
|
|
@ -118,6 +118,7 @@ class ServerConfig:
|
|||
board: dict[str, Any] | None = None,
|
||||
prompt_cache: dict[str, Any] | None = None,
|
||||
streaming: dict[str, Any] | None = None,
|
||||
verification: dict[str, Any] | None = None,
|
||||
on_change: Callable[["ServerConfig"], None] | None = None,
|
||||
):
|
||||
self.host = host
|
||||
|
|
@ -149,6 +150,9 @@ class ServerConfig:
|
|||
self.prompt_cache = prompt_cache or {}
|
||||
# U3/G8: streaming.flush_interval_ms 控制 token chunk 节流(默认 0 = 逐 chunk yield)
|
||||
self.streaming = streaming or {}
|
||||
# U4/G1: verification.max_reinjections 控制 verify 失败回灌次数(默认 1)
|
||||
# verification_enabled=False 时此配置无效
|
||||
self.verification = verification or {}
|
||||
self.on_change = on_change
|
||||
|
||||
# Config watching state
|
||||
|
|
|
|||
|
|
@ -0,0 +1,394 @@
|
|||
"""U4/G1: Verify 失败回灌 ReAct 测试
|
||||
|
||||
characterization-first: 先覆盖现有 verify 行为(max_reinjections=0 等价于不回灌),
|
||||
再测新回灌行为(reinjection on first fail, break on second fail)。
|
||||
|
||||
R1: verify 首次失败 → errors 注入 conversation → LLM 自纠正 → 二次 verify 通过
|
||||
R2: verify 二次失败 → 中断返回错误附 verify log
|
||||
R3: max_reinjections 可配置(默认 1),=0 等价于不回灌;回灌受 max_steps 约束
|
||||
R13: ServerConfig.verification 配置项
|
||||
R14: max_reinjections 默认值为 1
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
from agentkit.core.react import ReActEngine
|
||||
from agentkit.core.verification_loop import VerificationResult
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def make_mock_gateway(responses: list[LLMResponse]) -> LLMGateway:
    """Build a mock ``LLMGateway`` whose ``chat`` replays ``responses`` in order.

    ``chat`` is an ``AsyncMock`` driven by ``side_effect``, so every awaited
    call hands back the next item of ``responses``.
    ``get_provider_name_for_model`` is stubbed to always return ``None``.
    """
    mock_gateway = MagicMock(spec=LLMGateway)
    mock_gateway.get_provider_name_for_model = MagicMock(return_value=None)
    mock_gateway.chat = AsyncMock(side_effect=responses)
    return mock_gateway
|
||||
|
||||
|
||||
def make_response(content: str = "") -> LLMResponse:
    """Return an ``LLMResponse`` carrying ``content`` and no tool calls.

    The model name and token usage are fixed placeholder values.
    """
    token_usage = TokenUsage(prompt_tokens=10, completion_tokens=20)
    return LLMResponse(
        content=content,
        model="test-model",
        usage=token_usage,
        tool_calls=[],
    )
|
||||
|
||||
|
||||
def make_verify_result(passed: bool, errors: list[str] | None = None) -> VerificationResult:
    """Build a ``VerificationResult`` fixture for a passing or failing run.

    Args:
        passed: Whether the simulated verification succeeded.
        errors: Error strings to attach. When omitted (``None``), a passing
            result gets an empty list and a failing result gets one
            placeholder error. An explicitly supplied list -- including an
            empty one -- is used as given.

    Returns:
        A ``VerificationResult`` with ``attempts=1`` and a canned pytest-style
        ``test_output`` matching ``passed``.
    """
    # Compare against None rather than relying on truthiness, so that an
    # explicit ``errors=[]`` is not silently replaced by the placeholder.
    if errors is None:
        errors = [] if passed else ["test_x.py::test_failed"]
    test_output = "$ pytest\nOK" if passed else "$ pytest\nFAILED test_x.py"
    return VerificationResult(
        passed=passed,
        attempts=1,
        test_output=test_output,
        errors=errors,
    )
|
||||
|
||||
|
||||
def make_mock_vloop(verify_results: list[VerificationResult]) -> MagicMock:
    """Build a stand-in verification loop whose ``verify()`` replays results.

    Each awaited ``verify()`` call returns the next item of ``verify_results``.
    """
    verify_mock = AsyncMock(side_effect=verify_results)
    fake_loop = MagicMock()
    fake_loop.verify = verify_mock
    return fake_loop
|
||||
|
||||
|
||||
# ── Characterization: max_reinjections=0 等价于当前行为 ──────────
|
||||
|
||||
|
||||
class TestVerifyCharacterization:
    """Characterize verify behaviour with reinjection off (``max_reinjections=0``).

    A failed verification is only recorded in the trajectory; nothing is fed
    back to the LLM.
    """

    async def test_verify_disabled_no_verify_step(self):
        """Engine built without enabling verification records no verification step."""
        llm = make_mock_gateway([make_response("final answer")])
        react = ReActEngine(llm_gateway=llm, max_steps=3)

        outcome = await react.execute(
            messages=[{"role": "user", "content": "do something"}],
        )

        assert outcome.output == "final answer"
        assert not any(s.tool_name == "verification" for s in outcome.trajectory)

    async def test_verify_pass_no_extra_step(self):
        """A passing verification does not append a verification step."""
        llm = make_mock_gateway([make_response("answer")])
        react = ReActEngine(
            llm_gateway=llm,
            max_steps=3,
            verification_enabled=True,
            verification_commands=["echo ok"],
            max_reinjections=0,
        )
        passing_vloop = make_mock_vloop([make_verify_result(passed=True)])

        with patch(
            "agentkit.core.verification_loop.VerificationLoop",
            return_value=passing_vloop,
        ):
            outcome = await react.execute(
                messages=[{"role": "user", "content": "do something"}],
            )

        assert outcome.output == "answer"
        assert not any(s.tool_name == "verification" for s in outcome.trajectory)

    async def test_verify_fail_max_zero_no_reinjection(self):
        """``max_reinjections=0`` plus a failed verify: record only, no reinjection.

        Characterizes the pre-existing behaviour: ``gateway.chat`` is awaited
        exactly once.
        """
        llm = make_mock_gateway([make_response("bad answer")])
        react = ReActEngine(
            llm_gateway=llm,
            max_steps=3,
            verification_enabled=True,
            verification_commands=["false"],
            max_reinjections=0,
        )
        failing_vloop = make_mock_vloop([make_verify_result(passed=False)])

        with patch(
            "agentkit.core.verification_loop.VerificationLoop",
            return_value=failing_vloop,
        ):
            outcome = await react.execute(
                messages=[{"role": "user", "content": "do something"}],
            )

        # The LLM is called only once (no reinjection).
        assert llm.chat.await_count == 1
        # The output is still preserved.
        assert outcome.output == "bad answer"
        # The trajectory contains exactly one verification step, marked failed.
        recorded = [s for s in outcome.trajectory if s.tool_name == "verification"]
        assert len(recorded) == 1
        assert recorded[0].result["passed"] is False
|
||||
|
||||
|
||||
# ── R1: 回灌后 LLM 自纠正 → 二次 verify 通过 ──────────
|
||||
|
||||
|
||||
class TestVerifyReinjection:
    """Failed verify is fed back into the conversation; the retry then passes."""

    async def test_first_fail_reinject_second_pass(self):
        """R1: first verify fails, errors are reinjected, second verify passes."""
        llm = make_mock_gateway(
            [
                make_response("bad code"),  # first call: wrong answer
                make_response("fixed code"),  # second call: corrected answer
            ]
        )
        react = ReActEngine(
            llm_gateway=llm,
            max_steps=5,
            verification_enabled=True,
            verification_commands=["pytest"],
            max_reinjections=1,  # allow a single reinjection
        )
        fail_then_pass = make_mock_vloop(
            [
                make_verify_result(passed=False),  # first verify fails
                make_verify_result(passed=True),  # second verify passes
            ]
        )

        with patch(
            "agentkit.core.verification_loop.VerificationLoop",
            return_value=fail_then_pass,
        ):
            outcome = await react.execute(
                messages=[{"role": "user", "content": "write code"}],
            )

        # The LLM is called twice.
        assert llm.chat.await_count == 2
        # The final output is the corrected answer.
        assert outcome.output == "fixed code"
        # The second verify passed, so no verification step is appended.
        recorded = [s for s in outcome.trajectory if s.tool_name == "verification"]
        assert len(recorded) == 0

    async def test_reinjected_user_message_appears_in_conversation(self):
        """R1 integration: the reinjected user message carries the error text."""
        llm = make_mock_gateway(
            [
                make_response("bad"),
                make_response("good"),
            ]
        )
        react = ReActEngine(
            llm_gateway=llm,
            max_steps=5,
            verification_enabled=True,
            verification_commands=["pytest"],
            max_reinjections=1,
        )
        fail_then_pass = make_mock_vloop(
            [
                make_verify_result(passed=False, errors=["AssertionError: x != y"]),
                make_verify_result(passed=True),
            ]
        )

        with patch(
            "agentkit.core.verification_loop.VerificationLoop",
            return_value=fail_then_pass,
        ):
            await react.execute(
                messages=[{"role": "user", "content": "write code"}],
            )

        # On the second LLM call the conversation should include the
        # reinjected user message.
        retry_call = llm.chat.await_args_list[1]
        sent_messages = retry_call.kwargs.get("messages") or retry_call[1].get("messages")
        feedback_messages = []
        for message in sent_messages:
            if message.get("role") == "user" and "验证失败" in message.get("content", ""):
                feedback_messages.append(message)
        assert len(feedback_messages) >= 1
        assert "AssertionError: x != y" in feedback_messages[-1]["content"]
|
||||
|
||||
|
||||
# ── R2: 二次 verify 失败 → 中断返回错误 ──────────
|
||||
|
||||
|
||||
class TestVerifyDoubleFailure:
    """A second verify failure interrupts the run and attaches the verify log."""

    async def test_second_fail_breaks_with_verify_log(self):
        """R2: second verify failure stops the run; trajectory holds the errors."""
        llm = make_mock_gateway(
            [
                make_response("bad v1"),
                make_response("bad v2"),
            ]
        )
        react = ReActEngine(
            llm_gateway=llm,
            max_steps=5,
            verification_enabled=True,
            verification_commands=["pytest"],
            max_reinjections=1,
        )
        always_failing = make_mock_vloop(
            [
                make_verify_result(passed=False, errors=["err1"]),
                make_verify_result(passed=False, errors=["err2"]),
            ]
        )

        with patch(
            "agentkit.core.verification_loop.VerificationLoop",
            return_value=always_failing,
        ):
            outcome = await react.execute(
                messages=[{"role": "user", "content": "write code"}],
            )

        # The LLM is called twice (initial answer + one reinjection).
        assert llm.chat.await_count == 2
        # The status flags the verification failure.
        assert outcome.status == "verify_failed"
        # The output is preserved (the LLM's last answer).
        assert outcome.output == "bad v2"
        # The trajectory holds a single verification step with the last errors.
        recorded = [s for s in outcome.trajectory if s.tool_name == "verification"]
        assert len(recorded) == 1
        failed_step = recorded[0]
        assert failed_step.result["passed"] is False
        assert "err2" in failed_step.result["errors"]
|
||||
|
||||
|
||||
# ── R3: 配置 + 边界 ──────────
|
||||
|
||||
|
||||
class TestVerifyReinjectionConfig:
    """Configuration and boundary tests for ``max_reinjections``."""

    def test_default_max_reinjections_is_one(self):
        """R14 self-check: ``max_reinjections`` defaults to 1."""
        react = ReActEngine(llm_gateway=make_mock_gateway([]))
        assert react._max_reinjections == 1

    async def test_max_reinjections_zero_skips_reinjection(self):
        """R3: ``max_reinjections=0`` means no reinjection takes place."""
        llm = make_mock_gateway([make_response("only answer")])
        react = ReActEngine(
            llm_gateway=llm,
            max_steps=5,
            verification_enabled=True,
            verification_commands=["false"],
            max_reinjections=0,
        )
        failing_vloop = make_mock_vloop([make_verify_result(passed=False)])

        with patch(
            "agentkit.core.verification_loop.VerificationLoop",
            return_value=failing_vloop,
        ):
            outcome = await react.execute(
                messages=[{"role": "user", "content": "do something"}],
            )

        assert llm.chat.await_count == 1  # no reinjection
        assert outcome.output == "only answer"

    async def test_reinjection_hits_max_steps_interrupts(self):
        """R3 edge: hitting ``max_steps`` during reinjection interrupts the run."""
        # max_steps=2, max_reinjections=5 (max_steps is reached first).
        # LLM call 1: final answer -> verify fails -> reinject.
        # LLM call 2: final answer -> verify fails -> step == max_steps -> stop.
        llm = make_mock_gateway(
            [
                make_response("ans1"),
                make_response("ans2"),
            ]
        )
        react = ReActEngine(
            llm_gateway=llm,
            max_steps=2,
            verification_enabled=True,
            verification_commands=["false"],
            max_reinjections=5,  # far above max_steps, so max_steps wins
        )
        always_failing = make_mock_vloop(
            [
                make_verify_result(passed=False),
                make_verify_result(passed=False),
            ]
        )

        with patch(
            "agentkit.core.verification_loop.VerificationLoop",
            return_value=always_failing,
        ):
            outcome = await react.execute(
                messages=[{"role": "user", "content": "do something"}],
            )

        # The LLM is called twice (bounded by max_steps=2).
        assert llm.chat.await_count == 2
        # The second verify fails with reinjections left but no steps left,
        # so the run must stop (verify_failed or partial).
        assert outcome.status in ("verify_failed", "partial")
|
||||
|
||||
|
||||
# ── execute_stream 回灌 ──────────
|
||||
|
||||
|
||||
class TestVerifyReinjectionStream:
    """Reinjection behaviour when running through ``execute_stream``."""

    async def test_stream_reinjection_first_fail_second_pass(self):
        """R1 stream: first verify fails, errors are reinjected, second passes."""
        from agentkit.llm.protocol import StreamChunk

        async def stream_in_two_chunks(text: str):
            """Simulate streaming by yielding ``text`` as two halves."""
            split_at = len(text) // 2
            for piece in (text[:split_at], text[split_at:]):
                yield StreamChunk(content=piece, model="test-model")

        llm = MagicMock(spec=LLMGateway)
        # Each chat_stream call hands back the next prepared async generator.
        llm.chat_stream = MagicMock(
            side_effect=[
                stream_in_two_chunks("bad code"),
                stream_in_two_chunks("fixed code"),
            ]
        )
        llm.get_provider_name_for_model = MagicMock(return_value=None)

        react = ReActEngine(
            llm_gateway=llm,
            max_steps=5,
            verification_enabled=True,
            verification_commands=["pytest"],
            max_reinjections=1,
        )
        fail_then_pass = make_mock_vloop(
            [
                make_verify_result(passed=False),
                make_verify_result(passed=True),
            ]
        )

        with patch(
            "agentkit.core.verification_loop.VerificationLoop",
            return_value=fail_then_pass,
        ):
            collected = [
                event
                async for event in react.execute_stream(
                    messages=[{"role": "user", "content": "write code"}],
                )
            ]

        # chat_stream is called twice.
        assert llm.chat_stream.call_count == 2
        # At least one final_answer event was emitted.
        final_answers = [e for e in collected if e.event_type == "final_answer"]
        assert len(final_answers) >= 1
        # The final output is the corrected answer.
        assert "fixed code" in final_answers[-1].data.get("output", "")
|
||||
Loading…
Reference in New Issue