fix(review): wire pitfall_detector/spec_review to PlanExecEngine + fix restore_budget_state reset order

This commit is contained in:
chiguyong 2026-07-03 22:05:51 +08:00
parent f1f2e72cad
commit ffb7a51d77
4 changed files with 300 additions and 4 deletions

View File

@ -13,10 +13,14 @@ import logging
import os import os
from collections.abc import AsyncGenerator, Awaitable from collections.abc import AsyncGenerator, Awaitable
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Callable, Coroutine from typing import TYPE_CHECKING, Any, Callable, Coroutine
import yaml import yaml
if TYPE_CHECKING:
from agentkit.core.spec_manager import SpecManager
from agentkit.evolution.pitfall_detector import PitfallDetector
from agentkit.core.base import BaseAgent from agentkit.core.base import BaseAgent
from agentkit.core.exceptions import ConfigValidationError, TaskCancelledError from agentkit.core.exceptions import ConfigValidationError, TaskCancelledError
from agentkit.core.protocol import ( from agentkit.core.protocol import (
@ -243,6 +247,11 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
llm_gateway: object | None = None, # NEW v2 param: LLMGateway llm_gateway: object | None = None, # NEW v2 param: LLMGateway
mcp_servers: dict[str, str] | None = None, # NEW v2 param: MCP server URLs mcp_servers: dict[str, str] | None = None, # NEW v2 param: MCP server URLs
compressor: object | None = None, # CompressionStrategy | None compressor: object | None = None, # CompressionStrategy | None
# U7/R12 + U8/R8: app-state singletons threaded through to PlanExecEngine
# (KTD-5). None = skip pitfall injection / spec review gate (backward compat).
pitfall_detector: "PitfallDetector | None" = None,
spec_review_handler: Any | None = None,
spec_manager: "SpecManager | None" = None,
): ):
# v2: If SkillConfig, extract skill info # v2: If SkillConfig, extract skill info
from agentkit.skills.base import SkillConfig, Skill from agentkit.skills.base import SkillConfig, Skill
@ -324,6 +333,14 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
# v2: Store compressor for ReAct engine # v2: Store compressor for ReAct engine
self._compressor = compressor self._compressor = compressor
# U7/R12 + U8/R8: app-state singletons threaded through to PlanExecEngine
# so PLAN_EXEC streaming/non-streaming paths actually invoke pitfall
# injection (R12) and the spec review gate (R8). None = no-op (backward
# compat). See _handle_plan_exec_stream / _handle_plan_exec.
self._pitfall_detector = pitfall_detector
self._spec_review_handler = spec_review_handler
self._spec_manager = spec_manager
# 从配置构建 Prompt 模板 # 从配置构建 Prompt 模板
if config.prompt: if config.prompt:
sections = PromptSection( sections = PromptSection(
@ -929,6 +946,9 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
llm_gateway=self._llm_gateway, llm_gateway=self._llm_gateway,
max_replans=2, max_replans=2,
default_timeout=300.0, default_timeout=300.0,
pitfall_detector=self._pitfall_detector,
spec_review_handler=self._spec_review_handler,
spec_manager=self._spec_manager,
) )
async for event in plan_exec_engine.execute_stream( async for event in plan_exec_engine.execute_stream(
messages=user_messages, messages=user_messages,
@ -1118,6 +1138,9 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
llm_gateway=self._llm_gateway, llm_gateway=self._llm_gateway,
max_replans=2, max_replans=2,
default_timeout=300.0, default_timeout=300.0,
pitfall_detector=self._pitfall_detector,
spec_review_handler=self._spec_review_handler,
spec_manager=self._spec_manager,
) )
result = await plan_exec_engine.execute( result = await plan_exec_engine.execute(

View File

@ -291,6 +291,11 @@ class ReActEngine:
# tracks error reinjections). Incremented each time a reflection is # tracks error reinjections). Incremented each time a reflection is
# generated and injected for retry. # generated and injected for retry.
self._reflection_count: int = 0 self._reflection_count: int = 0
# KTD-7: guard flag set by restore_budget_state() so _execute_loop's
# self.reset() call does NOT zero out the restored counters. Cleared in
# _execute_loop's finally block so subsequent execute() calls without a
# restore still reset properly.
self._state_restored: bool = False
def reset(self) -> None: def reset(self) -> None:
"""Reset internal state for reuse across conversations. """Reset internal state for reuse across conversations.
@ -391,6 +396,11 @@ class ReActEngine:
On resume, counters derive from persisted plan phase statuses, not On resume, counters derive from persisted plan phase statuses, not
reset to zero. Call AFTER ``reset()`` but BEFORE ``execute()``. reset to zero. Call AFTER ``reset()`` but BEFORE ``execute()``.
Sets ``_state_restored`` so the subsequent ``execute()``/``execute_stream()``
call (whose ``_execute_loop`` invokes ``self.reset()``) does NOT zero out
the restored counters. The flag is cleared in ``_execute_loop``'s finally
block so the next call without a restore resets normally.
Args: Args:
think: Spent think steps (PLANNING/BUILDING phases). think: Spent think steps (PLANNING/BUILDING phases).
verify: Spent verify attempts. verify: Spent verify attempts.
@ -399,6 +409,7 @@ class ReActEngine:
self._think_count = think self._think_count = think
self._verify_count = verify self._verify_count = verify
self._reflect_count = reflect self._reflect_count = reflect
self._state_restored = True
def _force_advance_to_verification(self) -> None: def _force_advance_to_verification(self) -> None:
"""Force advance to VERIFICATION phase, skipping remaining think phases. """Force advance to VERIFICATION phase, skipping remaining think phases.
@ -764,8 +775,12 @@ class ReActEngine:
effective_timeout: 超时秒数stream=True 时在循环内检查 effective_timeout: 超时秒数stream=True 时在循环内检查
stream=False 时由 caller asyncio.wait_for 强制 stream=False 时由 caller asyncio.wait_for 强制
""" """
# P2 #9: Reset loop detection state so reuse across conversations is clean # P2 #9: Reset loop detection state so reuse across conversations is clean.
self.reset() # KTD-7: skip reset when restore_budget_state() was called so restored
# counters survive into the loop. Flag is cleared in the finally block
# below so the next execute() without a restore resets normally.
if not self._state_restored:
self.reset()
tools = tools or [] tools = tools or []
if tools: if tools:
tools = self._maybe_add_tool_search(tools) tools = self._maybe_add_tool_search(tools)
@ -1747,6 +1762,9 @@ class ReActEngine:
data={"result": final_result}, data={"result": final_result},
) )
finally: finally:
# KTD-7: clear the restore guard so the next execute() without a
# restore_budget_state() call resets counters normally.
self._state_restored = False
# 结束轨迹记录 — always runs even if consumer doesn't fully iterate # 结束轨迹记录 — always runs even if consumer doesn't fully iterate
if trace_recorder is not None: if trace_recorder is not None:
trace_recorder.end_trace(outcome=trace_outcome) trace_recorder.end_trace(outcome=trace_outcome)

View File

@ -4,6 +4,7 @@ import asyncio
import logging import logging
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@ -152,14 +153,84 @@ def _build_skill_registry(config: ServerConfig) -> SkillRegistry:
return registry return registry
def _try_get_experience_store(server_config) -> Any | None:
"""Build a PostgreSQL ExperienceStore from server_config, or None if unavailable.
Mirrors cli/skill.py:_try_get_experience_store. database_url lookup order:
1. server_config.evolution.database_url
2. server_config.memory.episodic.database_url
3. DATABASE_URL env var
Returns an ExperienceStore instance or None (lazy import return type is
Any to avoid a module-level dependency on the experience_store module).
"""
database_url: str | None = None
evo_conf = getattr(server_config, "evolution", None) or {}
database_url = evo_conf.get("database_url") if isinstance(evo_conf, dict) else None
if not database_url:
epi_conf = (getattr(server_config, "memory", None) or {}).get("episodic", {})
database_url = epi_conf.get("database_url") if isinstance(epi_conf, dict) else None
if not database_url:
database_url = os.environ.get("DATABASE_URL")
if not database_url:
return None
try:
from agentkit.evolution.experience_store import ExperienceStore
from agentkit.memory.models import ExperienceModel, create_experience_session_factory
session_factory = create_experience_session_factory(database_url)
return ExperienceStore(
session_factory=session_factory,
experience_model=ExperienceModel,
)
except Exception as e:
logger.warning(f"Failed to create PostgreSQL ExperienceStore: {e}")
return None
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
# Startup # Startup
task_store = app.state.task_store task_store = app.state.task_store
await task_store.start_cleanup() await task_store.start_cleanup()
# Start config watcher if server_config is available # U7/R12 + U8/R8 (KTD-5): instantiate PitfallDetector + SpecManager as
# app-state singletons so PLAN_EXEC tasks can access them. PitfallDetector
# requires the PostgreSQL ExperienceStore; if unavailable (no DB), it is
# skipped gracefully (pitfall injection becomes a no-op). SpecManager is
# file-based and always available.
app.state.pitfall_detector = None
app.state.spec_manager = None
server_config = getattr(app.state, "server_config", None) server_config = getattr(app.state, "server_config", None)
try:
from agentkit.core.spec_manager import SpecManager
app.state.spec_manager = SpecManager()
logger.info("SpecManager initialized (file-based)")
except Exception: # noqa: BLE001 — SpecManager init; must not block startup
logger.debug("SpecManager init failed — spec persistence unavailable", exc_info=True)
try:
experience_store = _try_get_experience_store(server_config)
if experience_store is not None:
from agentkit.evolution.pitfall_detector import PitfallDetector
app.state.pitfall_detector = PitfallDetector(experience_store)
logger.info("PitfallDetector initialized (ExperienceStore ready)")
else:
logger.debug(
"PitfallDetector skipped — no PostgreSQL ExperienceStore configured "
"(pitfall injection is a no-op for PLAN_EXEC)"
)
except Exception: # noqa: BLE001 — PitfallDetector init; must not block startup
logger.debug("PitfallDetector init failed — pitfall injection disabled", exc_info=True)
# Start config watcher if server_config is available
if server_config is not None and server_config._config_path: if server_config is not None and server_config._config_path:
server_config.on_change = lambda cfg: _on_config_change(app, cfg) server_config.on_change = lambda cfg: _on_config_change(app, cfg)
server_config.watch_config() server_config.watch_config()
@ -253,6 +324,20 @@ async def lifespan(app: FastAPI):
try: try:
agent = await app.state.agent_pool.create_agent(default_config) agent = await app.state.agent_pool.create_agent(default_config)
# U7/R12 + U8/R8 (KTD-5): wire app-state singletons onto the default
# agent so its PLAN_EXEC path (ConfigDrivenAgent._handle_plan_exec_*)
# threads pitfall_detector + spec_manager into PlanExecEngine.
# ponytail: known gap — agents created later via
# AgentPool.create_agent/create_agent_from_skill (skill-loaded agents)
# do NOT receive these singletons because AgentPool does not forward
# them yet. Upgrade path: add pitfall_detector/spec_manager params to
# AgentPool.__init__ and pass through in create_agent(). The default
# chat agent is wired here as the most critical path; skill agents
# fall back to None (no pitfall injection / spec review) until the
# pool is updated.
agent._pitfall_detector = app.state.pitfall_detector
agent._spec_manager = app.state.spec_manager
# Register tools into the agent's tool registry # Register tools into the agent's tool registry
search_api_keys = { search_api_keys = {
"tavily_api_key": os.environ.get("TAVILY_API_KEY"), "tavily_api_key": os.environ.get("TAVILY_API_KEY"),

View File

@ -0,0 +1,170 @@
"""Unit tests for KTD-7: restore_budget_state() survives _execute_loop's reset().
Regression coverage for the P1 finding where ``_execute_loop`` called
``self.reset()`` AFTER ``restore_budget_state()`` had set the checkpoint
counters, zeroing them out and breaking checkpoint reconstruction.
Covers:
- restore_budget_state() sets _state_restored flag
- execute() does NOT zero out restored counters (reset skipped)
- _state_restored flag is cleared after execute() finishes
- A subsequent execute() without restore resets counters normally
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
from agentkit.core.phase import WILDCARD, PhasePolicy, PhaseState
from agentkit.core.react import ReActEngine
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMResponse, TokenUsage
# ── helpers ───────────────────────────────────────────────────────────
def make_mock_gateway(responses: list[LLMResponse]) -> MagicMock:
    """Return a MagicMock specced to LLMGateway whose ``chat`` yields ``responses`` in order."""
    # AsyncMock with a list side_effect hands out one item per awaited call.
    scripted_chat = AsyncMock(side_effect=responses)
    fake_gateway = MagicMock(spec=LLMGateway)
    fake_gateway.chat = scripted_chat
    return fake_gateway
def make_response(content: str = "") -> LLMResponse:
    """Build a tool-call-free LLMResponse carrying ``content`` and fixed token usage."""
    fixed_usage = TokenUsage(prompt_tokens=10, completion_tokens=20)
    reply = LLMResponse(
        content=content,
        model="test-model",
        usage=fixed_usage,
        tool_calls=[],
    )
    return reply
def _wildcard_policy(start: PhaseState) -> PhasePolicy:
    """Return a PhasePolicy whose whitelist is the wildcard in every phase, starting at ``start``."""
    every_phase = (
        PhaseState.PLANNING,
        PhaseState.BUILDING,
        PhaseState.VERIFICATION,
        PhaseState.DELIVERY,
    )
    # One wildcard-only frozenset per phase, in the same phase order as before.
    allow_all = {phase: frozenset({WILDCARD}) for phase in every_phase}
    return PhasePolicy(whitelist=allow_all, start_phase=start)
# ── restore_budget_state + execute() integration (KTD-7) ──────────────
class TestRestoreBudgetStateSurvivesExecute:
    """KTD-7: counters set by restore_budget_state() must reach _execute_loop intact."""

    @staticmethod
    def _make_engine(*contents: str) -> ReActEngine:
        """Build an engine starting in VERIFICATION with one scripted reply per ``contents`` item.

        The tests start in VERIFICATION so that, per their own expectations,
        the loop leaves ``_think_count`` untouched.
        """
        return ReActEngine(
            llm_gateway=make_mock_gateway([make_response(content=c) for c in contents]),
            phase_policy=_wildcard_policy(start=PhaseState.VERIFICATION),
            phase_budgets={"think": 7, "verify": 2, "reflect": 1},
        )

    async def test_restored_counters_survive_execute(self) -> None:
        """Counters restored before execute() must still hold their values afterwards.

        The regression being covered: _execute_loop calling self.reset() and
        zeroing _think_count/_verify_count/_reflect_count. The _state_restored
        flag is the guard against that.
        """
        engine = self._make_engine("done")

        # Simulate a checkpoint restore and confirm it took effect.
        engine.restore_budget_state(think=5, verify=2, reflect=1)
        assert engine._state_restored is True
        assert engine._think_count == 5
        assert engine._verify_count == 2
        assert engine._reflect_count == 1

        # _execute_loop is expected to skip reset() because of _state_restored.
        resume_messages = [{"role": "user", "content": "resume checkpoint"}]
        await engine.execute(messages=resume_messages)

        # All three counters are expected to be unchanged by the run.
        assert engine._think_count == 5, (
            f"Expected _think_count==5 (restored), got {engine._think_count} "
            "(reset() zeroed the restored checkpoint — KTD-7 regression)"
        )
        assert engine._verify_count == 2
        assert engine._reflect_count == 1

    async def test_state_restored_flag_cleared_after_execute(self) -> None:
        """The guard flag must be False again once execute() has returned."""
        engine = self._make_engine("done")

        engine.restore_budget_state(think=5, verify=2, reflect=1)
        assert engine._state_restored is True

        resume_messages = [{"role": "user", "content": "resume"}]
        await engine.execute(messages=resume_messages)

        # Expected to be cleared by the loop's finally block.
        assert engine._state_restored is False, (
            "_state_restored not cleared after execute() — subsequent execute() "
            "calls would incorrectly skip reset()"
        )

    async def test_second_execute_without_restore_resets_counters(self) -> None:
        """A fresh execute() following a restored one must zero the counters."""
        engine = self._make_engine("first", "second")

        # Run 1: restored state survives.
        engine.restore_budget_state(think=5, verify=2, reflect=1)
        await engine.execute(messages=[{"role": "user", "content": "resume"}])
        assert engine._think_count == 5  # survived

        # Run 2: no restore, so reset() should apply again.
        await engine.execute(messages=[{"role": "user", "content": "fresh"}])
        assert engine._think_count == 0, (
            f"Expected _think_count==0 after fresh execute(), got "
            f"{engine._think_count} (flag not cleared, reset incorrectly skipped)"
        )
        assert engine._verify_count == 0
        assert engine._reflect_count == 0

    async def test_execute_without_restore_behaves_unchanged(self) -> None:
        """Without restore_budget_state(), execute() still resets (backward compat)."""
        engine = self._make_engine("done")

        # Plant stale counters directly, as if left over from an earlier run.
        engine._think_count = 9
        engine._verify_count = 3
        engine._reflect_count = 2
        assert engine._state_restored is False

        await engine.execute(messages=[{"role": "user", "content": "fresh"}])

        # reset() is expected to have zeroed the stale counters.
        assert engine._think_count == 0
        assert engine._verify_count == 0
        assert engine._reflect_count == 0