# File-viewer header residue: 171 lines, 6.8 KiB, Python.
"""Unit tests for KTD-7: restore_budget_state() survives _execute_loop's reset().

Regression coverage for the P1 finding where ``_execute_loop`` called
``self.reset()`` AFTER ``restore_budget_state()`` had set the checkpoint
counters, zeroing them out and breaking checkpoint reconstruction.

Covers:
- restore_budget_state() sets the _state_restored flag
- execute() does NOT zero out restored counters (reset is skipped)
- the _state_restored flag is cleared after execute() finishes
- a subsequent execute() without restore resets counters normally
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
from agentkit.core.phase import WILDCARD, PhasePolicy, PhaseState
|
|
from agentkit.core.react import ReActEngine
|
|
from agentkit.llm.gateway import LLMGateway
|
|
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
|
|
|
|
|
# ── helpers ───────────────────────────────────────────────────────────
|
|
|
|
|
|
def make_mock_gateway(responses: list[LLMResponse]) -> MagicMock:
    """Build a stand-in LLMGateway whose ``chat`` hands back *responses* in order.

    Args:
        responses: Replies used as the ``side_effect`` of the async ``chat``
            mock, so each awaited call yields the next item of the list.

    Returns:
        A ``MagicMock`` specced against ``LLMGateway`` with ``chat`` replaced
        by an ``AsyncMock``.
    """
    chat_stub = AsyncMock(side_effect=responses)
    mock_gateway = MagicMock(spec=LLMGateway)
    mock_gateway.chat = chat_stub
    return mock_gateway
|
|
|
|
|
|
def make_response(content: str = "") -> LLMResponse:
    """Create a canned LLMResponse carrying *content* and no tool calls.

    Args:
        content: Text placed in the response's ``content`` field.

    Returns:
        An ``LLMResponse`` for model ``"test-model"`` with a fixed token usage
        of 10 prompt / 20 completion tokens and an empty ``tool_calls`` list.
    """
    fixed_usage = TokenUsage(prompt_tokens=10, completion_tokens=20)
    return LLMResponse(
        content=content,
        model="test-model",
        usage=fixed_usage,
        tool_calls=[],
    )
|
|
|
|
|
|
def _wildcard_policy(start: PhaseState) -> PhasePolicy:
    """Return a PhasePolicy whose whitelist is the wildcard for every phase.

    Args:
        start: Phase passed through as the policy's ``start_phase``.

    Returns:
        A ``PhasePolicy`` mapping PLANNING, BUILDING, VERIFICATION and DELIVERY
        each to ``frozenset({WILDCARD})``.
    """
    every_phase = (
        PhaseState.PLANNING,
        PhaseState.BUILDING,
        PhaseState.VERIFICATION,
        PhaseState.DELIVERY,
    )
    # A fresh frozenset per phase, matching the literal-per-key original.
    open_whitelist = {phase: frozenset({WILDCARD}) for phase in every_phase}
    return PhasePolicy(whitelist=open_whitelist, start_phase=start)
|
|
|
|
|
|
# ── restore_budget_state + execute() integration (KTD-7) ──────────────
|
|
|
|
|
|
class TestRestoreBudgetStateSurvivesExecute:
    """KTD-7 regression suite: restored counters must reach _execute_loop intact.

    Every test builds an engine that starts in the VERIFICATION phase with
    budgets ``think=7, verify=2, reflect=1`` and a mocked gateway.

    NOTE(review): the test methods are coroutines without an explicit asyncio
    marker — presumably the project runs pytest-asyncio in auto mode; confirm
    against the pytest configuration.
    """

    @staticmethod
    def _engine_replying(*contents: str) -> ReActEngine:
        """Build a ReActEngine whose mocked LLM answers with *contents* in order."""
        policy = _wildcard_policy(start=PhaseState.VERIFICATION)
        gateway = make_mock_gateway(
            [make_response(content=text) for text in contents]
        )
        return ReActEngine(
            llm_gateway=gateway,
            phase_policy=policy,
            phase_budgets={"think": 7, "verify": 2, "reflect": 1},
        )

    async def test_restored_counters_survive_execute(self) -> None:
        """Counters set via restore_budget_state() must not be zeroed by execute().

        Before the fix, _execute_loop called self.reset(), wiping
        _think_count/_verify_count/_reflect_count. The _state_restored flag is
        what guards against that.
        """
        # The engine starts in VERIFICATION; per the original test's note the
        # loop only increments think_count in PLANNING/BUILDING, so the value
        # should come back untouched.
        engine = self._engine_replying("done")

        # Simulate a checkpoint restore and confirm it landed.
        engine.restore_budget_state(think=5, verify=2, reflect=1)
        assert engine._state_restored is True
        assert engine._think_count == 5
        assert (engine._verify_count, engine._reflect_count) == (2, 1)

        # _execute_loop is expected to skip reset() because of _state_restored.
        resume_messages = [{"role": "user", "content": "resume checkpoint"}]
        await engine.execute(messages=resume_messages)

        # think stays at 5 (started in VERIFICATION); verify/reflect stay put
        # because verification is off by default (verification_enabled=False,
        # as stated by the original test — confirm against ReActEngine).
        assert engine._think_count == 5, (
            f"Expected _think_count==5 (restored), got {engine._think_count} "
            "(reset() zeroed the restored checkpoint — KTD-7 regression)"
        )
        assert engine._verify_count == 2
        assert engine._reflect_count == 1

    async def test_state_restored_flag_cleared_after_execute(self) -> None:
        """execute() must clear _state_restored so the next run resets again."""
        engine = self._engine_replying("done")

        engine.restore_budget_state(think=5, verify=2, reflect=1)
        assert engine._state_restored is True

        resume_messages = [{"role": "user", "content": "resume"}]
        await engine.execute(messages=resume_messages)

        # The engine is expected to drop the flag once execute() has finished.
        assert engine._state_restored is False, (
            "_state_restored not cleared after execute() — subsequent execute() "
            "calls would incorrectly skip reset()"
        )

    async def test_second_execute_without_restore_resets_counters(self) -> None:
        """A plain execute() following a restored one must reset the counters."""
        engine = self._engine_replying("first", "second")

        # Run 1: resumes from restored state, so the counter must survive.
        engine.restore_budget_state(think=5, verify=2, reflect=1)
        await engine.execute(messages=[{"role": "user", "content": "resume"}])
        assert engine._think_count == 5  # survived

        # Run 2: no restore beforehand, so everything must drop back to zero.
        await engine.execute(messages=[{"role": "user", "content": "fresh"}])
        assert engine._think_count == 0, (
            f"Expected _think_count==0 after fresh execute(), got "
            f"{engine._think_count} (flag not cleared, reset incorrectly skipped)"
        )
        assert engine._verify_count == 0
        assert engine._reflect_count == 0

    async def test_execute_without_restore_behaves_unchanged(self) -> None:
        """Without restore_budget_state(), execute() still resets (backward compat)."""
        engine = self._engine_replying("done")

        # Plant stale counter values, as if left over from an earlier run.
        engine._think_count, engine._verify_count, engine._reflect_count = 9, 3, 2
        assert engine._state_restored is False

        await engine.execute(messages=[{"role": "user", "content": "fresh"}])

        # reset() ran as usual and wiped the stale values.
        assert engine._think_count == 0
        assert engine._verify_count == 0
        assert engine._reflect_count == 0
|