"""Unit tests for KTD-7: restore_budget_state() survives _execute_loop's reset(). Regression coverage for the P1 finding where ``_execute_loop`` called ``self.reset()`` AFTER ``restore_budget_state()`` had set the checkpoint counters, zeroing them out and breaking checkpoint reconstruction. Covers: - restore_budget_state() sets _state_restored flag - execute() does NOT zero out restored counters (reset skipped) - _state_restored flag is cleared after execute() finishes - A subsequent execute() without restore resets counters normally """ from __future__ import annotations from unittest.mock import AsyncMock, MagicMock from agentkit.core.phase import WILDCARD, PhasePolicy, PhaseState from agentkit.core.react import ReActEngine from agentkit.llm.gateway import LLMGateway from agentkit.llm.protocol import LLMResponse, TokenUsage # ── helpers ─────────────────────────────────────────────────────────── def make_mock_gateway(responses: list[LLMResponse]) -> MagicMock: """Mock LLMGateway. chat returns the responses in order.""" gateway = MagicMock(spec=LLMGateway) gateway.chat = AsyncMock(side_effect=responses) return gateway def make_response(content: str = "") -> LLMResponse: return LLMResponse( content=content, model="test-model", usage=TokenUsage(prompt_tokens=10, completion_tokens=20), tool_calls=[], ) def _wildcard_policy(start: PhaseState) -> PhasePolicy: """PhasePolicy allowing all tools in all phases.""" return PhasePolicy( whitelist={ PhaseState.PLANNING: frozenset({WILDCARD}), PhaseState.BUILDING: frozenset({WILDCARD}), PhaseState.VERIFICATION: frozenset({WILDCARD}), PhaseState.DELIVERY: frozenset({WILDCARD}), }, start_phase=start, ) # ── restore_budget_state + execute() integration (KTD-7) ────────────── class TestRestoreBudgetStateSurvivesExecute: """KTD-7: restored counters must survive into _execute_loop (not zeroed).""" async def test_restored_counters_survive_execute(self) -> None: """restore_budget_state() then execute() — counters must NOT be zeroed. Without the fix, _execute_loop calls self.reset() which zeros _think_count/_verify_count/_reflect_count. The _state_restored flag guards against this. """ # Start in VERIFICATION so think_count is not incremented by the loop # (the increment only happens in PLANNING/BUILDING phases). policy = _wildcard_policy(start=PhaseState.VERIFICATION) gateway = make_mock_gateway([make_response(content="done")]) engine = ReActEngine( llm_gateway=gateway, phase_policy=policy, phase_budgets={"think": 7, "verify": 2, "reflect": 1}, ) # Simulate checkpoint restore engine.restore_budget_state(think=5, verify=2, reflect=1) assert engine._state_restored is True assert engine._think_count == 5 assert engine._verify_count == 2 assert engine._reflect_count == 1 # Execute — _execute_loop must skip reset() due to _state_restored await engine.execute( messages=[{"role": "user", "content": "resume checkpoint"}], ) # Counters survived (think=5 unchanged because we started in VERIFICATION; # verify/reflect unchanged because verification_enabled=False default). assert engine._think_count == 5, ( f"Expected _think_count==5 (restored), got {engine._think_count} " "(reset() zeroed the restored checkpoint — KTD-7 regression)" ) assert engine._verify_count == 2 assert engine._reflect_count == 1 async def test_state_restored_flag_cleared_after_execute(self) -> None: """_state_restored must be cleared in finally so next execute() resets.""" policy = _wildcard_policy(start=PhaseState.VERIFICATION) gateway = make_mock_gateway([make_response(content="done")]) engine = ReActEngine( llm_gateway=gateway, phase_policy=policy, phase_budgets={"think": 7, "verify": 2, "reflect": 1}, ) engine.restore_budget_state(think=5, verify=2, reflect=1) assert engine._state_restored is True await engine.execute( messages=[{"role": "user", "content": "resume"}], ) # Flag cleared in finally block assert engine._state_restored is False, ( "_state_restored not cleared after execute() — subsequent execute() " "calls would incorrectly skip reset()" ) async def test_second_execute_without_restore_resets_counters(self) -> None: """After a restored execute(), the next execute() must reset normally.""" policy = _wildcard_policy(start=PhaseState.VERIFICATION) gateway = make_mock_gateway( [make_response(content="first"), make_response(content="second")] ) engine = ReActEngine( llm_gateway=gateway, phase_policy=policy, phase_budgets={"think": 7, "verify": 2, "reflect": 1}, ) # First execute with restored state engine.restore_budget_state(think=5, verify=2, reflect=1) await engine.execute(messages=[{"role": "user", "content": "resume"}]) assert engine._think_count == 5 # survived # Second execute WITHOUT restore — must reset to 0 await engine.execute(messages=[{"role": "user", "content": "fresh"}]) assert engine._think_count == 0, ( f"Expected _think_count==0 after fresh execute(), got " f"{engine._think_count} (flag not cleared, reset incorrectly skipped)" ) assert engine._verify_count == 0 assert engine._reflect_count == 0 async def test_execute_without_restore_behaves_unchanged(self) -> None: """No restore_budget_state() call — execute() resets as before (backward compat).""" policy = _wildcard_policy(start=PhaseState.VERIFICATION) gateway = make_mock_gateway([make_response(content="done")]) engine = ReActEngine( llm_gateway=gateway, phase_policy=policy, phase_budgets={"think": 7, "verify": 2, "reflect": 1}, ) # Manually set counters (simulating stale state from a prior run) engine._think_count = 9 engine._verify_count = 3 engine._reflect_count = 2 assert engine._state_restored is False await engine.execute(messages=[{"role": "user", "content": "fresh"}]) # reset() ran normally, zeroing the stale counters assert engine._think_count == 0 assert engine._verify_count == 0 assert engine._reflect_count == 0