feat(core): step budget phases + keep working bias (U4, R11/R10)
This commit is contained in:
parent
b8418968c2
commit
4255cb33ba
|
|
@ -95,6 +95,10 @@ class PhasePolicy:
|
|||
# VerificationLoop defaults). An empty list means "no commands" (verification
|
||||
# passes trivially — for non-coding tasks using Spec-declared commands instead).
|
||||
verification_commands: list[str] | None = None
|
||||
# U4/R11: total step budget for the plan (sum of think+verify+reflect).
|
||||
# None = use ReActEngine's max_steps. Provides a checkpoint-reconstructable
|
||||
# record of the plan's total step budget (KTD-7).
|
||||
step_budget: int | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Fail-fast: empty whitelist for a non-wildcard phase = bug.
|
||||
|
|
@ -142,6 +146,7 @@ class PhasePolicy:
|
|||
"auto_advance_after_steps": self.auto_advance_after_steps,
|
||||
"start_phase": self.start_phase.value,
|
||||
"verification_commands": self.verification_commands,
|
||||
"step_budget": self.step_budget,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -41,6 +41,10 @@ logger = logging.getLogger(__name__)
|
|||
# 最大重规划次数
|
||||
_DEFAULT_MAX_REPLANS = 2
|
||||
|
||||
# U4/R11: default phase budgets for PLAN_EXEC. think=7 (exploration),
|
||||
# verify=2 (two verification attempts), reflect=1 (one re-injection).
|
||||
_DEFAULT_PHASE_BUDGETS: dict[str, int] = {"think": 7, "verify": 2, "reflect": 1}
|
||||
|
||||
|
||||
@dataclass
|
||||
class _StreamState:
|
||||
|
|
@ -82,6 +86,10 @@ class PlanExecEngine:
|
|||
# default) — only PLAN_EXEC/TEAM_COLLAB verify.
|
||||
verification_enabled: bool = True,
|
||||
verification_commands: list[str] | None = None,
|
||||
# U4/R11: per-phase step quotas for PLAN_EXEC. None = use defaults
|
||||
# (think=7, verify=2, reflect=1). Threaded through to each step's
|
||||
# ReActEngine.
|
||||
phase_budgets: dict[str, int] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
|
@ -97,6 +105,8 @@ class PlanExecEngine:
|
|||
verification_commands: Optional override for the verification
|
||||
commands. None = let ReActEngine / VerificationLoop use its
|
||||
own defaults (pytest -x -q, ruff check src/).
|
||||
phase_budgets: U4/R11 — per-phase step quotas. None = use
|
||||
_DEFAULT_PHASE_BUDGETS (think=7, verify=2, reflect=1).
|
||||
"""
|
||||
self._llm_gateway = llm_gateway
|
||||
self._max_replans = max_replans
|
||||
|
|
@ -107,6 +117,10 @@ class PlanExecEngine:
|
|||
self._confirmation_handler: Any | None = None
|
||||
self._verification_enabled = verification_enabled
|
||||
self._verification_commands = verification_commands
|
||||
# U4/R11: copy the default to avoid mutating the module-level dict.
|
||||
self._phase_budgets = (
|
||||
dict(phase_budgets) if phase_budgets is not None else dict(_DEFAULT_PHASE_BUDGETS)
|
||||
)
|
||||
|
||||
# 组合子组件
|
||||
self._planner = GoalPlanner(llm_gateway=llm_gateway)
|
||||
|
|
@ -946,6 +960,7 @@ class PlanExecEngine:
|
|||
confirmation_handler=self._confirmation_handler,
|
||||
verification_enabled=self._verification_enabled,
|
||||
verification_commands=self._verification_commands,
|
||||
phase_budgets=self._phase_budgets,
|
||||
)
|
||||
return PlanExecutor(
|
||||
agent_pool=step_executor,
|
||||
|
|
@ -1108,6 +1123,7 @@ class ReActStepExecutor:
|
|||
confirmation_handler: Any | None = None,
|
||||
verification_enabled: bool = False,
|
||||
verification_commands: list[str] | None = None,
|
||||
phase_budgets: dict[str, int] | None = None,
|
||||
):
|
||||
self._llm_gateway = llm_gateway
|
||||
self._messages = messages or []
|
||||
|
|
@ -1118,6 +1134,8 @@ class ReActStepExecutor:
|
|||
self._confirmation_handler = confirmation_handler
|
||||
self._verification_enabled = verification_enabled
|
||||
self._verification_commands = verification_commands
|
||||
# U4/R11: thread through to each step's ReActEngine.
|
||||
self._phase_budgets = phase_budgets
|
||||
self._agents: dict[str, _ReActStepAgent] = {}
|
||||
|
||||
async def create_agent_from_skill(self, skill_name: str):
|
||||
|
|
@ -1133,6 +1151,7 @@ class ReActStepExecutor:
|
|||
confirmation_handler=self._confirmation_handler,
|
||||
verification_enabled=self._verification_enabled,
|
||||
verification_commands=self._verification_commands,
|
||||
phase_budgets=self._phase_budgets,
|
||||
)
|
||||
self._agents[skill_name] = agent
|
||||
return agent
|
||||
|
|
@ -1151,6 +1170,7 @@ class ReActStepExecutor:
|
|||
max_steps=self._max_steps,
|
||||
verification_enabled=self._verification_enabled,
|
||||
verification_commands=self._verification_commands,
|
||||
phase_budgets=self._phase_budgets,
|
||||
)
|
||||
self._agents[key] = agent
|
||||
return agent
|
||||
|
|
@ -1175,6 +1195,7 @@ class _ReActStepAgent:
|
|||
confirmation_handler: Any | None = None,
|
||||
verification_enabled: bool = False,
|
||||
verification_commands: list[str] | None = None,
|
||||
phase_budgets: dict[str, int] | None = None,
|
||||
):
|
||||
self.name = name
|
||||
self._llm_gateway = llm_gateway
|
||||
|
|
@ -1186,6 +1207,8 @@ class _ReActStepAgent:
|
|||
self._confirmation_handler = confirmation_handler
|
||||
self._verification_enabled = verification_enabled
|
||||
self._verification_commands = verification_commands
|
||||
# U4/R11: per-phase step quotas, passed to ReActEngine.
|
||||
self._phase_budgets = phase_budgets
|
||||
|
||||
async def execute(self, task_msg: TaskMessage) -> "TaskResult":
|
||||
"""执行步骤:通过 ReActEngine 循环调用"""
|
||||
|
|
@ -1215,6 +1238,7 @@ class _ReActStepAgent:
|
|||
max_steps=self._max_steps,
|
||||
verification_enabled=self._verification_enabled,
|
||||
verification_commands=self._verification_commands,
|
||||
phase_budgets=self._phase_budgets,
|
||||
)
|
||||
|
||||
# 构建 messages
|
||||
|
|
|
|||
|
|
@ -189,6 +189,14 @@ class ReActEngine:
|
|||
# cannot make outbound network calls during verification. None = no
|
||||
# sandbox (backward compat for DIRECT_CHAT/REACT and existing tests).
|
||||
sandbox: "WorkspaceSandbox | None" = None,
|
||||
# U4/R11: per-phase step quotas (opt-in for PLAN_EXEC/TEAM_COLLAB).
|
||||
# None = current behavior (max_steps total budget). When set:
|
||||
# think — max steps in PLANNING/BUILDING before forced verify
|
||||
# verify — max verification attempts before returning best result
|
||||
# reflect — max re-injections after verify fail (overrides
|
||||
# max_reinjections)
|
||||
# Loop detector threshold raised from 2 to 3 (R10/RV22).
|
||||
phase_budgets: dict[str, int] | None = None,
|
||||
):
|
||||
if max_steps < 1:
|
||||
raise ValueError(f"max_steps must be >= 1, got {max_steps}")
|
||||
|
|
@ -259,6 +267,19 @@ class ReActEngine:
|
|||
# U3/RV3: minimum sandbox. When set and current phase is VERIFICATION,
|
||||
# _execute_tool wraps tool.safe_execute() in sandbox.network_block().
|
||||
self._sandbox = sandbox
|
||||
# U4/R11: per-phase budget quotas.
|
||||
self._phase_budgets = phase_budgets
|
||||
if phase_budgets is not None:
|
||||
# R10/RV22: keep-working mode raises loop threshold 2->3.
|
||||
self._loop_threshold = 3
|
||||
# R10: reflect quota overrides _max_reinjections.
|
||||
if "reflect" in phase_budgets:
|
||||
self._max_reinjections = phase_budgets["reflect"]
|
||||
# U4/KTD-7: budget counters (checkpoint-reconstructable via
|
||||
# restore_budget_state). Reset to 0 on fresh execute().
|
||||
self._think_count: int = 0
|
||||
self._verify_count: int = 0
|
||||
self._reflect_count: int = 0
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset internal state for reuse across conversations.
|
||||
|
|
@ -269,8 +290,7 @@ class ReActEngine:
|
|||
# ReActEngine is stateless between calls — conversation history,
|
||||
# step counts, and trajectory are local to each execute call.
|
||||
# This method exists for API clarity and future stateful extensions.
|
||||
self._loop_window.clear()
|
||||
self._loop_corrected = False
|
||||
self._reset_loop_detector()
|
||||
# U3/G6: reset phase state to start_phase (if policy set). Each
|
||||
# execute() call begins a fresh PLANNING phase.
|
||||
if self._phase_policy is not None:
|
||||
|
|
@ -278,6 +298,57 @@ class ReActEngine:
|
|||
self._steps_in_phase = 0
|
||||
# Wave 4 U2: clear any pending violations from a prior run.
|
||||
self._phase_violations = []
|
||||
# U4/KTD-7: reset budget counters on fresh execute(). For checkpoint
|
||||
# resume, use restore_budget_state() AFTER reset() to override.
|
||||
self._think_count = 0
|
||||
self._verify_count = 0
|
||||
self._reflect_count = 0
|
||||
|
||||
def _reset_loop_detector(self) -> None:
|
||||
"""Clear loop detection state only (KTD-9).
|
||||
|
||||
Called between reflexion retry attempts to prevent the loop detector
|
||||
from misfiring due to ``_loop_window`` state leaking across attempts.
|
||||
Does NOT reset phase state or budget counters (KTD-7).
|
||||
"""
|
||||
self._loop_window.clear()
|
||||
self._loop_corrected = False
|
||||
|
||||
def restore_budget_state(self, think: int, verify: int, reflect: int) -> None:
|
||||
"""Restore budget counters from checkpoint (KTD-7).
|
||||
|
||||
On resume, counters derive from persisted plan phase statuses, not
|
||||
reset to zero. Call AFTER ``reset()`` but BEFORE ``execute()``.
|
||||
|
||||
Args:
|
||||
think: Spent think steps (PLANNING/BUILDING phases).
|
||||
verify: Spent verify attempts.
|
||||
reflect: Spent reflect (re-injection) attempts.
|
||||
"""
|
||||
self._think_count = think
|
||||
self._verify_count = verify
|
||||
self._reflect_count = reflect
|
||||
|
||||
def _force_advance_to_verification(self) -> None:
|
||||
"""Force advance to VERIFICATION phase, skipping remaining think phases.
|
||||
|
||||
Called when the think quota is exhausted (U4/R11). Advances through
|
||||
PLANNING/BUILDING until VERIFICATION is reached or no more phases.
|
||||
No-op if no phase_policy is set.
|
||||
"""
|
||||
if self._phase_policy is None or self._current_phase is None:
|
||||
return
|
||||
from agentkit.core.phase import PhaseState
|
||||
|
||||
while self._current_phase not in (PhaseState.VERIFICATION, PhaseState.DELIVERY):
|
||||
nxt = self.advance_phase()
|
||||
if nxt is None:
|
||||
break
|
||||
logger.info(
|
||||
"Think quota exhausted (%d steps), forced advance to %s",
|
||||
self._think_count,
|
||||
self._current_phase.value if self._current_phase else "?",
|
||||
)
|
||||
|
||||
# ── U3/G6: phase state machine ────────────────────────────────────
|
||||
|
||||
|
|
@ -716,7 +787,8 @@ class ReActEngine:
|
|||
|
||||
trace_outcome = "success"
|
||||
# U4/G1: verify 失败回灌计数器。受 max_steps 上限约束(不无限循环)。
|
||||
reinjections = 0
|
||||
# U4/KTD-7: initialize from restored budget state (checkpoint resume).
|
||||
reinjections = self._reflect_count
|
||||
_loop_start = time.monotonic()
|
||||
|
||||
while step < self._max_steps:
|
||||
|
|
@ -731,6 +803,21 @@ class ReActEngine:
|
|||
self._steps_in_phase += 1
|
||||
self._maybe_auto_advance()
|
||||
|
||||
# U4/R11: think quota enforcement. Count steps in PLANNING/
|
||||
# BUILDING and force advance to VERIFICATION when exhausted.
|
||||
if (
|
||||
self._phase_budgets is not None
|
||||
and self._phase_policy is not None
|
||||
and self._current_phase is not None
|
||||
):
|
||||
from agentkit.core.phase import PhaseState as _PS
|
||||
|
||||
if self._current_phase in (_PS.PLANNING, _PS.BUILDING):
|
||||
self._think_count += 1
|
||||
think_quota = self._phase_budgets.get("think")
|
||||
if think_quota is not None and self._think_count >= think_quota:
|
||||
self._force_advance_to_verification()
|
||||
|
||||
# 超时检查(仅 stream=True;stream=False 由 asyncio.wait_for 强制)
|
||||
if stream and effective_timeout > 0:
|
||||
elapsed = time.monotonic() - _loop_start
|
||||
|
|
@ -1324,6 +1411,32 @@ class ReActEngine:
|
|||
|
||||
# U4/G1: verify at final-answer point with reinjection.
|
||||
if self._verification_enabled and output:
|
||||
# U4/R11: verify quota -- skip verification when
|
||||
# exhausted, return best result as-is.
|
||||
verify_quota = (
|
||||
self._phase_budgets.get("verify")
|
||||
if self._phase_budgets is not None
|
||||
else None
|
||||
)
|
||||
if verify_quota is not None and self._verify_count >= verify_quota:
|
||||
logger.info(
|
||||
"Verify quota exhausted (%d/%d), "
|
||||
"returning best result without verify",
|
||||
self._verify_count,
|
||||
verify_quota,
|
||||
)
|
||||
yield ReActEvent(
|
||||
event_type="final_answer",
|
||||
step=step,
|
||||
data={
|
||||
"output": output,
|
||||
"total_steps": len(trajectory),
|
||||
"total_tokens": total_tokens,
|
||||
"verify_quota_exhausted": True,
|
||||
},
|
||||
)
|
||||
break
|
||||
self._verify_count += 1
|
||||
try:
|
||||
from agentkit.core.verification_loop import VerificationLoop
|
||||
|
||||
|
|
@ -1342,6 +1455,16 @@ class ReActEngine:
|
|||
}
|
||||
)
|
||||
reinjections += 1
|
||||
# U4/R10: track reflect count for
|
||||
# checkpoint reconstruction (KTD-7).
|
||||
self._reflect_count += 1
|
||||
# U4/KTD-9: reset loop detector
|
||||
# between retry attempts so
|
||||
# _loop_window state doesn't leak.
|
||||
self._reset_loop_detector()
|
||||
# U4/R10: reset think quota for the
|
||||
# next attempt (keep-working bias).
|
||||
self._think_count = 0
|
||||
yield ReActEvent(
|
||||
event_type="step",
|
||||
step=step,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,633 @@
|
|||
"""Unit tests for U4: step budget phases + keep working bias (R11/R10).
|
||||
|
||||
Covers:
|
||||
- ReActEngine.phase_budgets configuration (R11)
|
||||
- Loop detector threshold 3 with budgets vs 2 without (R10/RV22)
|
||||
- _reset_loop_detector preserves budget counters (KTD-9)
|
||||
- restore_budget_state checkpoint reconstruction (KTD-7)
|
||||
- PhasePolicy.step_budget field + serialization
|
||||
- PlanExecEngine threads phase_budgets through to ReActEngine
|
||||
- _force_advance_to_verification behavior
|
||||
- Integration: think quota forces phase advance
|
||||
- Integration: verify quota exhausted returns best result
|
||||
- Integration: reflect quota overrides max_reinjections
|
||||
- Backward compat: no phase_budgets = unchanged behavior
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from agentkit.core.phase import WILDCARD, PhasePolicy, PhaseState
|
||||
from agentkit.core.plan_exec_engine import PlanExecEngine, ReActStepExecutor
|
||||
from agentkit.core.react import ReActEngine
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
|
||||
# ── helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def make_mock_gateway(responses: list[LLMResponse] | None = None) -> MagicMock:
|
||||
"""Mock LLMGateway. If responses given, chat returns them in order."""
|
||||
gateway = MagicMock(spec=LLMGateway)
|
||||
if responses is not None:
|
||||
gateway.chat = AsyncMock(side_effect=responses)
|
||||
else:
|
||||
gateway.chat = AsyncMock(return_value=MagicMock())
|
||||
return gateway
|
||||
|
||||
|
||||
def make_response(
|
||||
content: str = "",
|
||||
tool_calls: list[ToolCall] | None = None,
|
||||
prompt_tokens: int = 10,
|
||||
completion_tokens: int = 20,
|
||||
) -> LLMResponse:
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model="test-model",
|
||||
usage=TokenUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
|
||||
tool_calls=tool_calls or [],
|
||||
)
|
||||
|
||||
|
||||
class _FakeTool(Tool):
|
||||
"""Minimal tool for integration tests."""
|
||||
|
||||
def __init__(self, name: str = "search", result: dict | None = None) -> None:
|
||||
super().__init__(name=name, description="fake tool")
|
||||
self._result = result or {"status": "ok"}
|
||||
|
||||
async def execute(self, **kwargs) -> dict:
|
||||
return self._result
|
||||
|
||||
|
||||
def _wildcard_policy(start: PhaseState = PhaseState.PLANNING) -> PhasePolicy:
|
||||
"""PhasePolicy allowing all tools in all phases."""
|
||||
return PhasePolicy(
|
||||
whitelist={
|
||||
PhaseState.PLANNING: frozenset({WILDCARD}),
|
||||
PhaseState.BUILDING: frozenset({WILDCARD}),
|
||||
PhaseState.VERIFICATION: frozenset({WILDCARD}),
|
||||
PhaseState.DELIVERY: frozenset({WILDCARD}),
|
||||
},
|
||||
start_phase=start,
|
||||
)
|
||||
|
||||
|
||||
# ── Configuration tests (R11) ─────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPhaseBudgetsConfig:
|
||||
def test_phase_budgets_stored(self) -> None:
|
||||
engine = ReActEngine(
|
||||
llm_gateway=make_mock_gateway(),
|
||||
phase_budgets={"think": 7, "verify": 2, "reflect": 1},
|
||||
)
|
||||
assert engine._phase_budgets == {"think": 7, "verify": 2, "reflect": 1}
|
||||
|
||||
def test_phase_budgets_default_none(self) -> None:
|
||||
engine = ReActEngine(llm_gateway=make_mock_gateway())
|
||||
assert engine._phase_budgets is None
|
||||
|
||||
def test_loop_threshold_raised_to_3_with_budgets(self) -> None:
|
||||
engine = ReActEngine(
|
||||
llm_gateway=make_mock_gateway(),
|
||||
phase_budgets={"think": 1},
|
||||
)
|
||||
assert engine._loop_threshold == 3
|
||||
|
||||
def test_loop_threshold_default_2_without_budgets(self) -> None:
|
||||
engine = ReActEngine(llm_gateway=make_mock_gateway())
|
||||
assert engine._loop_threshold == 2
|
||||
|
||||
def test_max_reinjections_overridden_by_reflect_budget(self) -> None:
|
||||
engine = ReActEngine(
|
||||
llm_gateway=make_mock_gateway(),
|
||||
max_reinjections=5,
|
||||
phase_budgets={"reflect": 2},
|
||||
)
|
||||
assert engine._max_reinjections == 2
|
||||
|
||||
def test_max_reinjections_unchanged_without_reflect_budget(self) -> None:
|
||||
engine = ReActEngine(
|
||||
llm_gateway=make_mock_gateway(),
|
||||
max_reinjections=3,
|
||||
phase_budgets={"think": 5},
|
||||
)
|
||||
assert engine._max_reinjections == 3
|
||||
|
||||
def test_budget_counters_init_zero(self) -> None:
|
||||
engine = ReActEngine(
|
||||
llm_gateway=make_mock_gateway(),
|
||||
phase_budgets={"think": 1},
|
||||
)
|
||||
assert engine._think_count == 0
|
||||
assert engine._verify_count == 0
|
||||
assert engine._reflect_count == 0
|
||||
|
||||
|
||||
# ── _reset_loop_detector (KTD-9) ──────────────────────────────────────
|
||||
|
||||
|
||||
class TestResetLoopDetector:
|
||||
def test_clears_loop_window(self) -> None:
|
||||
engine = ReActEngine(llm_gateway=make_mock_gateway())
|
||||
engine._loop_window.append("hash1")
|
||||
engine._loop_window.append("hash2")
|
||||
engine._loop_corrected = True
|
||||
engine._reset_loop_detector()
|
||||
assert len(engine._loop_window) == 0
|
||||
assert engine._loop_corrected is False
|
||||
|
||||
def test_preserves_budget_counters(self) -> None:
|
||||
"""KTD-9: _reset_loop_detector must NOT reset budget counters."""
|
||||
engine = ReActEngine(
|
||||
llm_gateway=make_mock_gateway(),
|
||||
phase_budgets={"think": 5},
|
||||
)
|
||||
engine._think_count = 3
|
||||
engine._verify_count = 1
|
||||
engine._reflect_count = 2
|
||||
engine._loop_window.append("hash1")
|
||||
engine._reset_loop_detector()
|
||||
assert engine._think_count == 3
|
||||
assert engine._verify_count == 1
|
||||
assert engine._reflect_count == 2
|
||||
|
||||
def test_preserves_phase_state(self) -> None:
|
||||
"""KTD-9: _reset_loop_detector must NOT reset phase state."""
|
||||
policy = _wildcard_policy()
|
||||
engine = ReActEngine(
|
||||
llm_gateway=make_mock_gateway(),
|
||||
phase_policy=policy,
|
||||
phase_budgets={"think": 5},
|
||||
)
|
||||
engine._current_phase = PhaseState.BUILDING
|
||||
engine._steps_in_phase = 4
|
||||
engine._reset_loop_detector()
|
||||
assert engine._current_phase == PhaseState.BUILDING
|
||||
assert engine._steps_in_phase == 4
|
||||
|
||||
|
||||
# ── restore_budget_state (KTD-7) ──────────────────────────────────────
|
||||
|
||||
|
||||
class TestRestoreBudgetState:
|
||||
def test_restores_counters(self) -> None:
|
||||
engine = ReActEngine(
|
||||
llm_gateway=make_mock_gateway(),
|
||||
phase_budgets={"think": 5},
|
||||
)
|
||||
engine.restore_budget_state(think=4, verify=2, reflect=1)
|
||||
assert engine._think_count == 4
|
||||
assert engine._verify_count == 2
|
||||
assert engine._reflect_count == 1
|
||||
|
||||
def test_restore_after_reset(self) -> None:
|
||||
"""KTD-7: restore_budget_state called after reset() overrides zeros."""
|
||||
engine = ReActEngine(
|
||||
llm_gateway=make_mock_gateway(),
|
||||
phase_budgets={"think": 5},
|
||||
)
|
||||
engine._think_count = 3
|
||||
engine._verify_count = 1
|
||||
engine._reflect_count = 1
|
||||
engine.reset()
|
||||
assert engine._think_count == 0
|
||||
engine.restore_budget_state(think=3, verify=1, reflect=1)
|
||||
assert engine._think_count == 3
|
||||
assert engine._verify_count == 1
|
||||
assert engine._reflect_count == 1
|
||||
|
||||
|
||||
# ── reset() behavior ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestResetClearsBudgets:
|
||||
def test_reset_zeros_budget_counters(self) -> None:
|
||||
engine = ReActEngine(
|
||||
llm_gateway=make_mock_gateway(),
|
||||
phase_budgets={"think": 5},
|
||||
)
|
||||
engine._think_count = 7
|
||||
engine._verify_count = 3
|
||||
engine._reflect_count = 2
|
||||
engine.reset()
|
||||
assert engine._think_count == 0
|
||||
assert engine._verify_count == 0
|
||||
assert engine._reflect_count == 0
|
||||
|
||||
def test_reset_clears_loop_detector(self) -> None:
|
||||
engine = ReActEngine(llm_gateway=make_mock_gateway())
|
||||
engine._loop_window.append("hash1")
|
||||
engine._loop_corrected = True
|
||||
engine.reset()
|
||||
assert len(engine._loop_window) == 0
|
||||
assert engine._loop_corrected is False
|
||||
|
||||
|
||||
# ── _check_tool_loop threshold (R10/RV22) ─────────────────────────────
|
||||
|
||||
|
||||
class TestCheckToolLoopThreshold:
|
||||
def test_threshold_3_with_budgets(self) -> None:
|
||||
"""R10/RV22: loop threshold raised from 2 to 3 with phase_budgets."""
|
||||
engine = ReActEngine(
|
||||
llm_gateway=make_mock_gateway(),
|
||||
phase_budgets={"think": 5},
|
||||
)
|
||||
assert engine._loop_threshold == 3
|
||||
tc = [ToolCall(id="1", name="search", arguments={"q": "x"})]
|
||||
# 1st call: count=1 < 3
|
||||
assert engine._check_tool_loop(tc) is None
|
||||
# 2nd call: count=2 < 3
|
||||
assert engine._check_tool_loop(tc) is None
|
||||
# 3rd call: count=3 >= 3
|
||||
assert engine._check_tool_loop(tc) == "search"
|
||||
|
||||
def test_threshold_2_without_budgets(self) -> None:
|
||||
"""Backward compat: threshold stays 2 without phase_budgets."""
|
||||
engine = ReActEngine(llm_gateway=make_mock_gateway())
|
||||
assert engine._loop_threshold == 2
|
||||
tc = [ToolCall(id="1", name="search", arguments={"q": "x"})]
|
||||
# 1st call: count=1 < 2
|
||||
assert engine._check_tool_loop(tc) is None
|
||||
# 2nd call: count=2 >= 2
|
||||
assert engine._check_tool_loop(tc) == "search"
|
||||
|
||||
|
||||
# ── PhasePolicy.step_budget (KTD-7) ───────────────────────────────────
|
||||
|
||||
|
||||
class TestPhasePolicyStepBudget:
|
||||
def test_step_budget_defaults_none(self) -> None:
|
||||
policy = PhasePolicy(
|
||||
whitelist={PhaseState.PLANNING: frozenset({WILDCARD})},
|
||||
)
|
||||
assert policy.step_budget is None
|
||||
|
||||
def test_step_budget_set(self) -> None:
|
||||
policy = PhasePolicy(
|
||||
whitelist={PhaseState.PLANNING: frozenset({WILDCARD})},
|
||||
step_budget=42,
|
||||
)
|
||||
assert policy.step_budget == 42
|
||||
|
||||
def test_to_dict_includes_step_budget(self) -> None:
|
||||
policy = PhasePolicy(
|
||||
whitelist={PhaseState.PLANNING: frozenset({WILDCARD})},
|
||||
step_budget=10,
|
||||
)
|
||||
d = policy.to_dict()
|
||||
assert d["step_budget"] == 10
|
||||
|
||||
def test_to_dict_step_budget_none(self) -> None:
|
||||
policy = PhasePolicy(
|
||||
whitelist={PhaseState.PLANNING: frozenset({WILDCARD})},
|
||||
)
|
||||
d = policy.to_dict()
|
||||
assert d["step_budget"] is None
|
||||
|
||||
|
||||
# ── PlanExecEngine threading (R11) ────────────────────────────────────
|
||||
|
||||
|
||||
class TestPlanExecEngineBudgets:
|
||||
def test_default_phase_budgets(self) -> None:
|
||||
engine = PlanExecEngine(llm_gateway=None)
|
||||
assert engine._phase_budgets == {"think": 7, "verify": 2, "reflect": 1}
|
||||
|
||||
def test_custom_phase_budgets(self) -> None:
|
||||
custom = {"think": 10, "verify": 3, "reflect": 2}
|
||||
engine = PlanExecEngine(llm_gateway=None, phase_budgets=custom)
|
||||
assert engine._phase_budgets == custom
|
||||
# Ensure the module-level default wasn't mutated.
|
||||
assert engine._phase_budgets is not custom
|
||||
|
||||
def test_executor_threads_budgets(self) -> None:
|
||||
executor = ReActStepExecutor(
|
||||
phase_budgets={"think": 5, "verify": 1, "reflect": 0},
|
||||
)
|
||||
assert executor._phase_budgets == {"think": 5, "verify": 1, "reflect": 0}
|
||||
|
||||
def test_executor_defaults_none(self) -> None:
|
||||
executor = ReActStepExecutor()
|
||||
assert executor._phase_budgets is None
|
||||
|
||||
|
||||
# ── _force_advance_to_verification ────────────────────────────────────
|
||||
|
||||
|
||||
class TestForceAdvanceToVerification:
|
||||
def test_advances_from_planning_to_verification(self) -> None:
|
||||
policy = _wildcard_policy(start=PhaseState.PLANNING)
|
||||
engine = ReActEngine(
|
||||
llm_gateway=make_mock_gateway(),
|
||||
phase_policy=policy,
|
||||
phase_budgets={"think": 1},
|
||||
)
|
||||
assert engine.current_phase == PhaseState.PLANNING
|
||||
engine._force_advance_to_verification()
|
||||
assert engine.current_phase == PhaseState.VERIFICATION
|
||||
|
||||
def test_advances_from_building_to_verification(self) -> None:
|
||||
policy = _wildcard_policy(start=PhaseState.BUILDING)
|
||||
engine = ReActEngine(
|
||||
llm_gateway=make_mock_gateway(),
|
||||
phase_policy=policy,
|
||||
)
|
||||
assert engine.current_phase == PhaseState.BUILDING
|
||||
engine._force_advance_to_verification()
|
||||
assert engine.current_phase == PhaseState.VERIFICATION
|
||||
|
||||
def test_no_op_when_already_verification(self) -> None:
|
||||
policy = _wildcard_policy(start=PhaseState.VERIFICATION)
|
||||
engine = ReActEngine(
|
||||
llm_gateway=make_mock_gateway(),
|
||||
phase_policy=policy,
|
||||
)
|
||||
engine._force_advance_to_verification()
|
||||
assert engine.current_phase == PhaseState.VERIFICATION
|
||||
|
||||
def test_no_op_without_policy(self) -> None:
|
||||
engine = ReActEngine(llm_gateway=make_mock_gateway())
|
||||
engine._force_advance_to_verification()
|
||||
assert engine.current_phase is None
|
||||
|
||||
|
||||
# ── Integration: think quota forces phase advance ─────────────────────
|
||||
|
||||
|
||||
class TestThinkQuotaIntegration:
|
||||
async def test_think_quota_forces_advance_to_verification(self) -> None:
|
||||
"""R11: think quota exhausted forces advance to VERIFICATION."""
|
||||
policy = _wildcard_policy(start=PhaseState.PLANNING)
|
||||
tool = _FakeTool(name="search", result={"found": True})
|
||||
gateway = make_mock_gateway(
|
||||
[
|
||||
make_response(tool_calls=[ToolCall(id="tc_1", name="search", arguments={})]),
|
||||
make_response(content="Done"),
|
||||
]
|
||||
)
|
||||
engine = ReActEngine(
|
||||
llm_gateway=gateway,
|
||||
phase_policy=policy,
|
||||
phase_budgets={"think": 1},
|
||||
)
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "search and report"}],
|
||||
tools=[tool],
|
||||
)
|
||||
# After 1 think step, phase should have advanced to VERIFICATION.
|
||||
assert engine.current_phase == PhaseState.VERIFICATION
|
||||
assert result.status == "success"
|
||||
assert result.output == "Done"
|
||||
|
||||
async def test_think_quota_not_triggered_when_in_verification(self) -> None:
|
||||
"""Think quota only counts PLANNING/BUILDING steps, not VERIFICATION."""
|
||||
policy = _wildcard_policy(start=PhaseState.VERIFICATION)
|
||||
tool = _FakeTool(name="search", result={"found": True})
|
||||
gateway = make_mock_gateway(
|
||||
[
|
||||
make_response(tool_calls=[ToolCall(id="tc_1", name="search", arguments={})]),
|
||||
make_response(tool_calls=[ToolCall(id="tc_2", name="search", arguments={})]),
|
||||
make_response(content="Done"),
|
||||
]
|
||||
)
|
||||
engine = ReActEngine(
|
||||
llm_gateway=gateway,
|
||||
phase_policy=policy,
|
||||
phase_budgets={"think": 1},
|
||||
)
|
||||
await engine.execute(
|
||||
messages=[{"role": "user", "content": "verify stuff"}],
|
||||
tools=[tool],
|
||||
)
|
||||
# Starting in VERIFICATION, think_count should stay 0.
|
||||
assert engine._think_count == 0
|
||||
assert engine.current_phase == PhaseState.VERIFICATION
|
||||
|
||||
|
||||
# ── Integration: verify quota exhausted returns best result ────────────
|
||||
|
||||
|
||||
class TestVerifyQuotaIntegration:
|
||||
async def test_verify_quota_exhausted_returns_best(self, monkeypatch) -> None:
|
||||
"""R11: when verify quota exhausted, return best result without verify."""
|
||||
from agentkit.core.verification_loop import VerificationResult
|
||||
|
||||
class _FailVLoop:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
async def verify(self) -> VerificationResult:
|
||||
return VerificationResult(
|
||||
passed=False, attempts=1, test_output="fail", errors=["err"]
|
||||
)
|
||||
|
||||
monkeypatch.setattr("agentkit.core.verification_loop.VerificationLoop", _FailVLoop)
|
||||
|
||||
policy = _wildcard_policy(start=PhaseState.VERIFICATION)
|
||||
gateway = make_mock_gateway(
|
||||
[
|
||||
make_response(content="answer 1"),
|
||||
make_response(content="answer 2"),
|
||||
]
|
||||
)
|
||||
engine = ReActEngine(
|
||||
llm_gateway=gateway,
|
||||
phase_policy=policy,
|
||||
verification_enabled=True,
|
||||
verification_commands=["pytest"],
|
||||
phase_budgets={"think": 5, "verify": 1, "reflect": 1},
|
||||
)
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "do something"}],
|
||||
)
|
||||
# First answer: verify_count=0 < 1, verify fails, reinject.
|
||||
# Second answer: verify_count=1 >= 1, skip verify, return best.
|
||||
assert result.output == "answer 2"
|
||||
assert engine._verify_count == 1
|
||||
|
||||
async def test_verify_quota_zero_skips_verification(self, monkeypatch) -> None:
|
||||
"""R11: verify quota 0 means never verify."""
|
||||
from agentkit.core.verification_loop import VerificationResult
|
||||
|
||||
class _NeverCalledVLoop:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
async def verify(self) -> VerificationResult:
|
||||
raise AssertionError("verify() should not be called with quota 0")
|
||||
|
||||
monkeypatch.setattr("agentkit.core.verification_loop.VerificationLoop", _NeverCalledVLoop)
|
||||
|
||||
policy = _wildcard_policy(start=PhaseState.VERIFICATION)
|
||||
gateway = make_mock_gateway(
|
||||
[
|
||||
make_response(content="immediate answer"),
|
||||
]
|
||||
)
|
||||
engine = ReActEngine(
|
||||
llm_gateway=gateway,
|
||||
phase_policy=policy,
|
||||
verification_enabled=True,
|
||||
verification_commands=["pytest"],
|
||||
phase_budgets={"think": 5, "verify": 0, "reflect": 0},
|
||||
)
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "quick task"}],
|
||||
)
|
||||
assert result.output == "immediate answer"
|
||||
assert engine._verify_count == 0
|
||||
|
||||
|
||||
# ── Integration: reflect quota (R10 keep-working bias) ─────────────────
|
||||
|
||||
|
||||
class TestReflectQuotaIntegration:
|
||||
async def test_reflect_quota_resets_loop_detector(self, monkeypatch) -> None:
|
||||
"""R10/KTD-9: reflect reinjection resets loop detector between attempts."""
|
||||
from agentkit.core.verification_loop import VerificationResult
|
||||
|
||||
class _FailVLoop:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
async def verify(self) -> VerificationResult:
|
||||
return VerificationResult(
|
||||
passed=False, attempts=1, test_output="fail", errors=["err"]
|
||||
)
|
||||
|
||||
monkeypatch.setattr("agentkit.core.verification_loop.VerificationLoop", _FailVLoop)
|
||||
|
||||
policy = _wildcard_policy(start=PhaseState.VERIFICATION)
|
||||
gateway = make_mock_gateway(
|
||||
[
|
||||
make_response(content="attempt 1"),
|
||||
make_response(content="attempt 2"),
|
||||
]
|
||||
)
|
||||
engine = ReActEngine(
|
||||
llm_gateway=gateway,
|
||||
phase_policy=policy,
|
||||
verification_enabled=True,
|
||||
verification_commands=["pytest"],
|
||||
phase_budgets={"think": 5, "verify": 3, "reflect": 1},
|
||||
)
|
||||
await engine.execute(
|
||||
messages=[{"role": "user", "content": "do something"}],
|
||||
)
|
||||
# After reinjection, _reflect_count should be 1 and loop_window cleared.
|
||||
assert engine._reflect_count == 1
|
||||
assert len(engine._loop_window) == 0
|
||||
assert engine._loop_corrected is False
|
||||
|
||||
async def test_reflect_quota_resets_think_count(self, monkeypatch) -> None:
|
||||
"""R10: reflect reinjection resets think quota for next attempt."""
|
||||
from agentkit.core.verification_loop import VerificationResult
|
||||
|
||||
class _FailVLoop:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
async def verify(self) -> VerificationResult:
|
||||
return VerificationResult(
|
||||
passed=False, attempts=1, test_output="fail", errors=["err"]
|
||||
)
|
||||
|
||||
monkeypatch.setattr("agentkit.core.verification_loop.VerificationLoop", _FailVLoop)
|
||||
|
||||
policy = _wildcard_policy(start=PhaseState.VERIFICATION)
|
||||
gateway = make_mock_gateway(
|
||||
[
|
||||
make_response(content="attempt 1"),
|
||||
make_response(content="attempt 2"),
|
||||
]
|
||||
)
|
||||
engine = ReActEngine(
|
||||
llm_gateway=gateway,
|
||||
phase_policy=policy,
|
||||
verification_enabled=True,
|
||||
verification_commands=["pytest"],
|
||||
phase_budgets={"think": 5, "verify": 3, "reflect": 1},
|
||||
)
|
||||
await engine.execute(
|
||||
messages=[{"role": "user", "content": "do something"}],
|
||||
)
|
||||
# After reinjection, think_count should be reset to 0.
|
||||
assert engine._think_count == 0
|
||||
|
||||
async def test_reflect_quota_exhausted_breaks(self, monkeypatch) -> None:
|
||||
"""R10: when reflect quota exhausted, verify fail breaks (not reinject)."""
|
||||
from agentkit.core.verification_loop import VerificationResult
|
||||
|
||||
class _FailVLoop:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
async def verify(self) -> VerificationResult:
|
||||
return VerificationResult(
|
||||
passed=False, attempts=1, test_output="fail", errors=["err"]
|
||||
)
|
||||
|
||||
monkeypatch.setattr("agentkit.core.verification_loop.VerificationLoop", _FailVLoop)
|
||||
|
||||
policy = _wildcard_policy(start=PhaseState.VERIFICATION)
|
||||
gateway = make_mock_gateway(
|
||||
[
|
||||
make_response(content="only attempt"),
|
||||
]
|
||||
)
|
||||
engine = ReActEngine(
|
||||
llm_gateway=gateway,
|
||||
phase_policy=policy,
|
||||
verification_enabled=True,
|
||||
verification_commands=["pytest"],
|
||||
phase_budgets={"think": 5, "verify": 3, "reflect": 0},
|
||||
)
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "do something"}],
|
||||
)
|
||||
# reflect=0 means max_reinjections=0, so verify fail breaks immediately.
|
||||
assert engine._reflect_count == 0
|
||||
assert result.status == "verify_failed"
|
||||
|
||||
|
||||
# ── Backward compatibility ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBackwardCompat:
|
||||
async def test_no_budgets_unchanged_behavior(self) -> None:
|
||||
"""Without phase_budgets, engine behaves identically to before U4."""
|
||||
gateway = make_mock_gateway(
|
||||
[
|
||||
make_response(content="hello"),
|
||||
]
|
||||
)
|
||||
engine = ReActEngine(llm_gateway=gateway)
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
)
|
||||
assert result.output == "hello"
|
||||
assert result.status == "success"
|
||||
assert engine._loop_threshold == 2
|
||||
assert engine._phase_budgets is None
|
||||
|
||||
async def test_no_budgets_loop_threshold_2(self) -> None:
|
||||
"""Without phase_budgets, loop detector still uses threshold 2."""
|
||||
engine = ReActEngine(llm_gateway=make_mock_gateway())
|
||||
assert engine._loop_threshold == 2
|
||||
tc = [ToolCall(id="1", name="search", arguments={"q": "x"})]
|
||||
assert engine._check_tool_loop(tc) is None
|
||||
assert engine._check_tool_loop(tc) == "search"
|
||||
|
||||
def test_max_reinjections_respected_without_budgets(self) -> None:
|
||||
engine = ReActEngine(
|
||||
llm_gateway=make_mock_gateway(),
|
||||
max_reinjections=3,
|
||||
)
|
||||
assert engine._max_reinjections == 3
|
||||
Loading…
Reference in New Issue