diff --git a/src/agentkit/core/config_driven.py b/src/agentkit/core/config_driven.py index 0af0591..822015b 100644 --- a/src/agentkit/core/config_driven.py +++ b/src/agentkit/core/config_driven.py @@ -13,10 +13,14 @@ import logging import os from collections.abc import AsyncGenerator, Awaitable from datetime import datetime, timezone -from typing import Any, Callable, Coroutine +from typing import TYPE_CHECKING, Any, Callable, Coroutine import yaml +if TYPE_CHECKING: + from agentkit.core.spec_manager import SpecManager + from agentkit.evolution.pitfall_detector import PitfallDetector + from agentkit.core.base import BaseAgent from agentkit.core.exceptions import ConfigValidationError, TaskCancelledError from agentkit.core.protocol import ( @@ -243,6 +247,11 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin): llm_gateway: object | None = None, # NEW v2 param: LLMGateway mcp_servers: dict[str, str] | None = None, # NEW v2 param: MCP server URLs compressor: object | None = None, # CompressionStrategy | None + # U7/R12 + U8/R8: app-state singletons threaded through to PlanExecEngine + # (KTD-5). None = skip pitfall injection / spec review gate (backward compat). + pitfall_detector: "PitfallDetector | None" = None, + spec_review_handler: Any | None = None, + spec_manager: "SpecManager | None" = None, ): # v2: If SkillConfig, extract skill info from agentkit.skills.base import SkillConfig, Skill @@ -324,6 +333,14 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin): # v2: Store compressor for ReAct engine self._compressor = compressor + # U7/R12 + U8/R8: app-state singletons threaded through to PlanExecEngine + # so PLAN_EXEC streaming/non-streaming paths actually invoke pitfall + # injection (R12) and the spec review gate (R8). None = no-op (backward + # compat). See _handle_plan_exec_stream / _handle_plan_exec. + self._pitfall_detector = pitfall_detector + self._spec_review_handler = spec_review_handler + self._spec_manager = spec_manager + # 从配置构建 Prompt 模板 if config.prompt: sections = PromptSection( @@ -929,6 +946,9 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin): llm_gateway=self._llm_gateway, max_replans=2, default_timeout=300.0, + pitfall_detector=self._pitfall_detector, + spec_review_handler=self._spec_review_handler, + spec_manager=self._spec_manager, ) async for event in plan_exec_engine.execute_stream( messages=user_messages, @@ -1118,6 +1138,9 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin): llm_gateway=self._llm_gateway, max_replans=2, default_timeout=300.0, + pitfall_detector=self._pitfall_detector, + spec_review_handler=self._spec_review_handler, + spec_manager=self._spec_manager, ) result = await plan_exec_engine.execute( diff --git a/src/agentkit/core/react.py b/src/agentkit/core/react.py index 9b03aca..15bb080 100644 --- a/src/agentkit/core/react.py +++ b/src/agentkit/core/react.py @@ -291,6 +291,11 @@ class ReActEngine: # tracks error reinjections). Incremented each time a reflection is # generated and injected for retry. self._reflection_count: int = 0 + # KTD-7: guard flag set by restore_budget_state() so _execute_loop's + # self.reset() call does NOT zero out the restored counters. Cleared in + # _execute_loop's finally block so subsequent execute() calls without a + # restore still reset properly. + self._state_restored: bool = False def reset(self) -> None: """Reset internal state for reuse across conversations. @@ -391,6 +396,11 @@ class ReActEngine: On resume, counters derive from persisted plan phase statuses, not reset to zero. Call AFTER ``reset()`` but BEFORE ``execute()``. + Sets ``_state_restored`` so the subsequent ``execute()``/``execute_stream()`` + call (which invokes ``_execute_loop`` → ``self.reset()``) does NOT zero out + the restored counters. The flag is cleared in ``_execute_loop``'s finally + block so the next call without a restore resets normally. + Args: think: Spent think steps (PLANNING/BUILDING phases). verify: Spent verify attempts. @@ -399,6 +409,7 @@ class ReActEngine: self._think_count = think self._verify_count = verify self._reflect_count = reflect + self._state_restored = True def _force_advance_to_verification(self) -> None: """Force advance to VERIFICATION phase, skipping remaining think phases. @@ -764,8 +775,12 @@ class ReActEngine: effective_timeout: 超时秒数;stream=True 时在循环内检查, stream=False 时由 caller 的 asyncio.wait_for 强制 """ - # P2 #9: Reset loop detection state so reuse across conversations is clean - self.reset() + # P2 #9: Reset loop detection state so reuse across conversations is clean. + # KTD-7: skip reset when restore_budget_state() was called so restored + # counters survive into the loop. Flag is cleared in the finally block + # below so the next execute() without a restore resets normally. + if not self._state_restored: + self.reset() tools = tools or [] if tools: tools = self._maybe_add_tool_search(tools) @@ -1747,6 +1762,9 @@ class ReActEngine: data={"result": final_result}, ) finally: + # KTD-7: clear the restore guard so the next execute() without a + # restore_budget_state() call resets counters normally. + self._state_restored = False # 结束轨迹记录 — always runs even if consumer doesn't fully iterate if trace_recorder is not None: trace_recorder.end_trace(outcome=trace_outcome) diff --git a/src/agentkit/server/app.py b/src/agentkit/server/app.py index 5d0c547..ed75eaa 100644 --- a/src/agentkit/server/app.py +++ b/src/agentkit/server/app.py @@ -4,6 +4,7 @@ import asyncio import logging import os from contextlib import asynccontextmanager +from typing import Any from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware @@ -152,14 +153,84 @@ def _build_skill_registry(config: ServerConfig) -> SkillRegistry: return registry +def _try_get_experience_store(server_config) -> Any | None: + """Build a PostgreSQL ExperienceStore from server_config, or None if unavailable. + + Mirrors cli/skill.py:_try_get_experience_store. database_url lookup order: + 1. server_config.evolution.database_url + 2. server_config.memory.episodic.database_url + 3. DATABASE_URL env var + + Returns an ExperienceStore instance or None (lazy import — return type is + Any to avoid a module-level dependency on the experience_store module). + """ + database_url: str | None = None + + evo_conf = getattr(server_config, "evolution", None) or {} + database_url = evo_conf.get("database_url") if isinstance(evo_conf, dict) else None + + if not database_url: + epi_conf = (getattr(server_config, "memory", None) or {}).get("episodic", {}) + database_url = epi_conf.get("database_url") if isinstance(epi_conf, dict) else None + + if not database_url: + database_url = os.environ.get("DATABASE_URL") + + if not database_url: + return None + + try: + from agentkit.evolution.experience_store import ExperienceStore + from agentkit.memory.models import ExperienceModel, create_experience_session_factory + + session_factory = create_experience_session_factory(database_url) + return ExperienceStore( + session_factory=session_factory, + experience_model=ExperienceModel, + ) + except Exception as e: + logger.warning(f"Failed to create PostgreSQL ExperienceStore: {e}") + return None + + @asynccontextmanager async def lifespan(app: FastAPI): # Startup task_store = app.state.task_store await task_store.start_cleanup() - # Start config watcher if server_config is available + # U7/R12 + U8/R8 (KTD-5): instantiate PitfallDetector + SpecManager as + # app-state singletons so PLAN_EXEC tasks can access them. PitfallDetector + # requires the PostgreSQL ExperienceStore; if unavailable (no DB), it is + # skipped gracefully (pitfall injection becomes a no-op). SpecManager is + # file-based and always available. + app.state.pitfall_detector = None + app.state.spec_manager = None server_config = getattr(app.state, "server_config", None) + try: + from agentkit.core.spec_manager import SpecManager + + app.state.spec_manager = SpecManager() + logger.info("SpecManager initialized (file-based)") + except Exception: # noqa: BLE001 — SpecManager init; must not block startup + logger.debug("SpecManager init failed — spec persistence unavailable", exc_info=True) + + try: + experience_store = _try_get_experience_store(server_config) + if experience_store is not None: + from agentkit.evolution.pitfall_detector import PitfallDetector + + app.state.pitfall_detector = PitfallDetector(experience_store) + logger.info("PitfallDetector initialized (ExperienceStore ready)") + else: + logger.debug( + "PitfallDetector skipped — no PostgreSQL ExperienceStore configured " + "(pitfall injection is a no-op for PLAN_EXEC)" + ) + except Exception: # noqa: BLE001 — PitfallDetector init; must not block startup + logger.debug("PitfallDetector init failed — pitfall injection disabled", exc_info=True) + + # Start config watcher if server_config is available if server_config is not None and server_config._config_path: server_config.on_change = lambda cfg: _on_config_change(app, cfg) server_config.watch_config() @@ -253,6 +324,20 @@ async def lifespan(app: FastAPI): try: agent = await app.state.agent_pool.create_agent(default_config) + # U7/R12 + U8/R8 (KTD-5): wire app-state singletons onto the default + # agent so its PLAN_EXEC path (ConfigDrivenAgent._handle_plan_exec_*) + # threads pitfall_detector + spec_manager into PlanExecEngine. + # ponytail: known gap — agents created later via + # AgentPool.create_agent/create_agent_from_skill (skill-loaded agents) + # do NOT receive these singletons because AgentPool does not forward + # them yet. Upgrade path: add pitfall_detector/spec_manager params to + # AgentPool.__init__ and pass through in create_agent(). The default + # chat agent is wired here as the most critical path; skill agents + # fall back to None (no pitfall injection / spec review) until the + # pool is updated. + agent._pitfall_detector = app.state.pitfall_detector + agent._spec_manager = app.state.spec_manager + # Register tools into the agent's tool registry search_api_keys = { "tavily_api_key": os.environ.get("TAVILY_API_KEY"), diff --git a/tests/unit/test_budget_restore.py b/tests/unit/test_budget_restore.py new file mode 100644 index 0000000..3a20ea0 --- /dev/null +++ b/tests/unit/test_budget_restore.py @@ -0,0 +1,170 @@ +"""Unit tests for KTD-7: restore_budget_state() survives _execute_loop's reset(). + +Regression coverage for the P1 finding where ``_execute_loop`` called +``self.reset()`` AFTER ``restore_budget_state()`` had set the checkpoint +counters, zeroing them out and breaking checkpoint reconstruction. + +Covers: +- restore_budget_state() sets _state_restored flag +- execute() does NOT zero out restored counters (reset skipped) +- _state_restored flag is cleared after execute() finishes +- A subsequent execute() without restore resets counters normally +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +from agentkit.core.phase import WILDCARD, PhasePolicy, PhaseState +from agentkit.core.react import ReActEngine +from agentkit.llm.gateway import LLMGateway +from agentkit.llm.protocol import LLMResponse, TokenUsage + + +# ── helpers ─────────────────────────────────────────────────────────── + + +def make_mock_gateway(responses: list[LLMResponse]) -> MagicMock: + """Mock LLMGateway. chat returns the responses in order.""" + gateway = MagicMock(spec=LLMGateway) + gateway.chat = AsyncMock(side_effect=responses) + return gateway + + +def make_response(content: str = "") -> LLMResponse: + return LLMResponse( + content=content, + model="test-model", + usage=TokenUsage(prompt_tokens=10, completion_tokens=20), + tool_calls=[], + ) + + +def _wildcard_policy(start: PhaseState) -> PhasePolicy: + """PhasePolicy allowing all tools in all phases.""" + return PhasePolicy( + whitelist={ + PhaseState.PLANNING: frozenset({WILDCARD}), + PhaseState.BUILDING: frozenset({WILDCARD}), + PhaseState.VERIFICATION: frozenset({WILDCARD}), + PhaseState.DELIVERY: frozenset({WILDCARD}), + }, + start_phase=start, + ) + + +# ── restore_budget_state + execute() integration (KTD-7) ────────────── + + +class TestRestoreBudgetStateSurvivesExecute: + """KTD-7: restored counters must survive into _execute_loop (not zeroed).""" + + async def test_restored_counters_survive_execute(self) -> None: + """restore_budget_state() then execute() — counters must NOT be zeroed. + + Without the fix, _execute_loop calls self.reset() which zeros + _think_count/_verify_count/_reflect_count. The _state_restored flag + guards against this. + """ + # Start in VERIFICATION so think_count is not incremented by the loop + # (the increment only happens in PLANNING/BUILDING phases). + policy = _wildcard_policy(start=PhaseState.VERIFICATION) + gateway = make_mock_gateway([make_response(content="done")]) + engine = ReActEngine( + llm_gateway=gateway, + phase_policy=policy, + phase_budgets={"think": 7, "verify": 2, "reflect": 1}, + ) + + # Simulate checkpoint restore + engine.restore_budget_state(think=5, verify=2, reflect=1) + assert engine._state_restored is True + assert engine._think_count == 5 + assert engine._verify_count == 2 + assert engine._reflect_count == 1 + + # Execute — _execute_loop must skip reset() due to _state_restored + await engine.execute( + messages=[{"role": "user", "content": "resume checkpoint"}], + ) + + # Counters survived (think=5 unchanged because we started in VERIFICATION; + # verify/reflect unchanged because verification_enabled=False default). + assert engine._think_count == 5, ( + f"Expected _think_count==5 (restored), got {engine._think_count} " + "(reset() zeroed the restored checkpoint — KTD-7 regression)" + ) + assert engine._verify_count == 2 + assert engine._reflect_count == 1 + + async def test_state_restored_flag_cleared_after_execute(self) -> None: + """_state_restored must be cleared in finally so next execute() resets.""" + policy = _wildcard_policy(start=PhaseState.VERIFICATION) + gateway = make_mock_gateway([make_response(content="done")]) + engine = ReActEngine( + llm_gateway=gateway, + phase_policy=policy, + phase_budgets={"think": 7, "verify": 2, "reflect": 1}, + ) + + engine.restore_budget_state(think=5, verify=2, reflect=1) + assert engine._state_restored is True + + await engine.execute( + messages=[{"role": "user", "content": "resume"}], + ) + + # Flag cleared in finally block + assert engine._state_restored is False, ( + "_state_restored not cleared after execute() — subsequent execute() " + "calls would incorrectly skip reset()" + ) + + async def test_second_execute_without_restore_resets_counters(self) -> None: + """After a restored execute(), the next execute() must reset normally.""" + policy = _wildcard_policy(start=PhaseState.VERIFICATION) + gateway = make_mock_gateway( + [make_response(content="first"), make_response(content="second")] + ) + engine = ReActEngine( + llm_gateway=gateway, + phase_policy=policy, + phase_budgets={"think": 7, "verify": 2, "reflect": 1}, + ) + + # First execute with restored state + engine.restore_budget_state(think=5, verify=2, reflect=1) + await engine.execute(messages=[{"role": "user", "content": "resume"}]) + assert engine._think_count == 5 # survived + + # Second execute WITHOUT restore — must reset to 0 + await engine.execute(messages=[{"role": "user", "content": "fresh"}]) + assert engine._think_count == 0, ( + f"Expected _think_count==0 after fresh execute(), got " + f"{engine._think_count} (flag not cleared, reset incorrectly skipped)" + ) + assert engine._verify_count == 0 + assert engine._reflect_count == 0 + + async def test_execute_without_restore_behaves_unchanged(self) -> None: + """No restore_budget_state() call — execute() resets as before (backward compat).""" + policy = _wildcard_policy(start=PhaseState.VERIFICATION) + gateway = make_mock_gateway([make_response(content="done")]) + engine = ReActEngine( + llm_gateway=gateway, + phase_policy=policy, + phase_budgets={"think": 7, "verify": 2, "reflect": 1}, + ) + + # Manually set counters (simulating stale state from a prior run) + engine._think_count = 9 + engine._verify_count = 3 + engine._reflect_count = 2 + assert engine._state_restored is False + + await engine.execute(messages=[{"role": "user", "content": "fresh"}]) + + # reset() ran normally, zeroing the stale counters + assert engine._think_count == 0 + assert engine._verify_count == 0 + assert engine._reflect_count == 0