# fischer-agentkit/tests/unit/test_budget_restore.py

"""Unit tests for KTD-7: restore_budget_state() survives _execute_loop's reset().

Regression coverage for the P1 finding where ``_execute_loop`` called
``self.reset()`` AFTER ``restore_budget_state()`` had set the checkpoint
counters, zeroing them out and breaking checkpoint reconstruction.

Covers:

- restore_budget_state() sets the _state_restored flag
- execute() does NOT zero out restored counters (reset skipped)
- the _state_restored flag is cleared after execute() finishes
- a subsequent execute() without a restore resets counters normally
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
from agentkit.core.phase import WILDCARD, PhasePolicy, PhaseState
from agentkit.core.react import ReActEngine
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMResponse, TokenUsage
# ── helpers ───────────────────────────────────────────────────────────
def make_mock_gateway(responses: list[LLMResponse]) -> MagicMock:
    """Build a ``MagicMock`` restricted to the ``LLMGateway`` interface.

    Args:
        responses: Scripted replies; each successive await of
            ``gateway.chat(...)`` yields the next item, in list order.

    Returns:
        The mock gateway, with ``chat`` replaced by an ``AsyncMock``.
    """
    scripted_chat = AsyncMock(side_effect=responses)
    mock_gateway = MagicMock(spec=LLMGateway)
    mock_gateway.chat = scripted_chat
    return mock_gateway
def make_response(content: str = "") -> LLMResponse:
    """Return a canned ``LLMResponse`` that carries no tool calls.

    The model name is ``"test-model"`` and token usage is fixed at
    10 prompt / 20 completion tokens.

    Args:
        content: Text placed in the response's ``content`` field.
    """
    fixed_usage = TokenUsage(prompt_tokens=10, completion_tokens=20)
    return LLMResponse(
        content=content,
        model="test-model",
        usage=fixed_usage,
        tool_calls=[],
    )
def _wildcard_policy(start: PhaseState) -> PhasePolicy:
    """Build a ``PhasePolicy`` that permits every tool in each listed phase.

    Only PLANNING, BUILDING, VERIFICATION and DELIVERY are whitelisted; each
    maps to ``frozenset({WILDCARD})``.

    Args:
        start: Phase passed through as the policy's ``start_phase``.
    """
    permissive_phases = (
        PhaseState.PLANNING,
        PhaseState.BUILDING,
        PhaseState.VERIFICATION,
        PhaseState.DELIVERY,
    )
    return PhasePolicy(
        whitelist={phase: frozenset({WILDCARD}) for phase in permissive_phases},
        start_phase=start,
    )
# ── restore_budget_state + execute() integration (KTD-7) ──────────────
class TestRestoreBudgetStateSurvivesExecute:
    """KTD-7: restored counters must survive into _execute_loop (not zeroed).

    Every test builds its engine through ``_make_engine`` so the policy,
    start phase and phase budgets are identical across cases.
    """

    @staticmethod
    def _make_engine(*contents: str) -> tuple[ReActEngine, MagicMock]:
        """Return a fresh ``(engine, gateway)`` pair for a single test.

        The engine starts in VERIFICATION so the loop does not increment
        ``_think_count`` (that increment only happens in PLANNING/BUILDING
        phases), and uses budgets think=7 / verify=2 / reflect=1.

        Args:
            *contents: One string per scripted LLM reply; successive
                ``gateway.chat`` awaits return them in order.
        """
        gateway = make_mock_gateway([make_response(content=c) for c in contents])
        engine = ReActEngine(
            llm_gateway=gateway,
            phase_policy=_wildcard_policy(start=PhaseState.VERIFICATION),
            phase_budgets={"think": 7, "verify": 2, "reflect": 1},
        )
        return engine, gateway

    async def test_restored_counters_survive_execute(self) -> None:
        """restore_budget_state() then execute() — counters must NOT be zeroed.

        Without the fix, _execute_loop calls self.reset() which zeros
        _think_count/_verify_count/_reflect_count. The _state_restored flag
        guards against this.
        """
        engine, gateway = self._make_engine("done")

        # Simulate checkpoint restore
        engine.restore_budget_state(think=5, verify=2, reflect=1)
        assert engine._state_restored is True
        assert engine._think_count == 5
        assert engine._verify_count == 2
        assert engine._reflect_count == 1

        # Execute — _execute_loop must skip reset() due to _state_restored
        await engine.execute(
            messages=[{"role": "user", "content": "resume checkpoint"}],
        )

        # Guard against a vacuous pass: if execute() returned before the loop
        # reached the LLM, reset() would never have run and the counters would
        # "survive" regardless of the _state_restored fix.
        gateway.chat.assert_awaited()

        # Counters survived (think=5 unchanged because we started in VERIFICATION;
        # verify/reflect unchanged because verification_enabled=False default).
        assert engine._think_count == 5, (
            f"Expected _think_count==5 (restored), got {engine._think_count} "
            "(reset() zeroed the restored checkpoint — KTD-7 regression)"
        )
        assert engine._verify_count == 2
        assert engine._reflect_count == 1

    async def test_state_restored_flag_cleared_after_execute(self) -> None:
        """_state_restored must be cleared in finally so next execute() resets."""
        engine, _ = self._make_engine("done")

        engine.restore_budget_state(think=5, verify=2, reflect=1)
        assert engine._state_restored is True

        await engine.execute(
            messages=[{"role": "user", "content": "resume"}],
        )

        # Flag cleared in finally block
        assert engine._state_restored is False, (
            "_state_restored not cleared after execute() — subsequent execute() "
            "calls would incorrectly skip reset()"
        )

    async def test_second_execute_without_restore_resets_counters(self) -> None:
        """After a restored execute(), the next execute() must reset normally."""
        engine, _ = self._make_engine("first", "second")

        # First execute with restored state
        engine.restore_budget_state(think=5, verify=2, reflect=1)
        await engine.execute(messages=[{"role": "user", "content": "resume"}])
        assert engine._think_count == 5  # survived

        # Second execute WITHOUT restore — must reset to 0
        await engine.execute(messages=[{"role": "user", "content": "fresh"}])
        assert engine._think_count == 0, (
            f"Expected _think_count==0 after fresh execute(), got "
            f"{engine._think_count} (flag not cleared, reset incorrectly skipped)"
        )
        assert engine._verify_count == 0
        assert engine._reflect_count == 0

    async def test_execute_without_restore_behaves_unchanged(self) -> None:
        """No restore_budget_state() call — execute() resets as before (backward compat)."""
        engine, _ = self._make_engine("done")

        # Manually set counters (simulating stale state from a prior run)
        engine._think_count = 9
        engine._verify_count = 3
        engine._reflect_count = 2
        assert engine._state_restored is False

        await engine.execute(messages=[{"role": "user", "content": "fresh"}])

        # reset() ran normally, zeroing the stale counters
        assert engine._think_count == 0
        assert engine._verify_count == 0
        assert engine._reflect_count == 0