fix(review): wire pitfall_detector/spec_review to PlanExecEngine + fix restore_budget_state reset order

This commit is contained in:
chiguyong 2026-07-03 22:05:51 +08:00
parent f1f2e72cad
commit ffb7a51d77
4 changed files with 300 additions and 4 deletions

View File

@ -13,10 +13,14 @@ import logging
import os
from collections.abc import AsyncGenerator, Awaitable
from datetime import datetime, timezone
from typing import Any, Callable, Coroutine
from typing import TYPE_CHECKING, Any, Callable, Coroutine
import yaml
if TYPE_CHECKING:
from agentkit.core.spec_manager import SpecManager
from agentkit.evolution.pitfall_detector import PitfallDetector
from agentkit.core.base import BaseAgent
from agentkit.core.exceptions import ConfigValidationError, TaskCancelledError
from agentkit.core.protocol import (
@ -243,6 +247,11 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
llm_gateway: object | None = None, # NEW v2 param: LLMGateway
mcp_servers: dict[str, str] | None = None, # NEW v2 param: MCP server URLs
compressor: object | None = None, # CompressionStrategy | None
# U7/R12 + U8/R8: app-state singletons threaded through to PlanExecEngine
# (KTD-5). None = skip pitfall injection / spec review gate (backward compat).
pitfall_detector: "PitfallDetector | None" = None,
spec_review_handler: Any | None = None,
spec_manager: "SpecManager | None" = None,
):
# v2: If SkillConfig, extract skill info
from agentkit.skills.base import SkillConfig, Skill
@ -324,6 +333,14 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
# v2: Store compressor for ReAct engine
self._compressor = compressor
# U7/R12 + U8/R8: app-state singletons threaded through to PlanExecEngine
# so PLAN_EXEC streaming/non-streaming paths actually invoke pitfall
# injection (R12) and the spec review gate (R8). None = no-op (backward
# compat). See _handle_plan_exec_stream / _handle_plan_exec.
self._pitfall_detector = pitfall_detector
self._spec_review_handler = spec_review_handler
self._spec_manager = spec_manager
# 从配置构建 Prompt 模板
if config.prompt:
sections = PromptSection(
@ -929,6 +946,9 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
llm_gateway=self._llm_gateway,
max_replans=2,
default_timeout=300.0,
pitfall_detector=self._pitfall_detector,
spec_review_handler=self._spec_review_handler,
spec_manager=self._spec_manager,
)
async for event in plan_exec_engine.execute_stream(
messages=user_messages,
@ -1118,6 +1138,9 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
llm_gateway=self._llm_gateway,
max_replans=2,
default_timeout=300.0,
pitfall_detector=self._pitfall_detector,
spec_review_handler=self._spec_review_handler,
spec_manager=self._spec_manager,
)
result = await plan_exec_engine.execute(

View File

@ -291,6 +291,11 @@ class ReActEngine:
# tracks error reinjections). Incremented each time a reflection is
# generated and injected for retry.
self._reflection_count: int = 0
# KTD-7: guard flag set by restore_budget_state() so _execute_loop's
# self.reset() call does NOT zero out the restored counters. Cleared in
# _execute_loop's finally block so subsequent execute() calls without a
# restore still reset properly.
self._state_restored: bool = False
def reset(self) -> None:
"""Reset internal state for reuse across conversations.
@ -391,6 +396,11 @@ class ReActEngine:
On resume, counters derive from persisted plan phase statuses, not
reset to zero. Call AFTER ``reset()`` but BEFORE ``execute()``.
Sets ``_state_restored`` so the subsequent ``execute()``/``execute_stream()``
call (which invokes ``_execute_loop`` ``self.reset()``) does NOT zero out
the restored counters. The flag is cleared in ``_execute_loop``'s finally
block so the next call without a restore resets normally.
Args:
think: Spent think steps (PLANNING/BUILDING phases).
verify: Spent verify attempts.
@ -399,6 +409,7 @@ class ReActEngine:
self._think_count = think
self._verify_count = verify
self._reflect_count = reflect
self._state_restored = True
def _force_advance_to_verification(self) -> None:
"""Force advance to VERIFICATION phase, skipping remaining think phases.
@ -764,7 +775,11 @@ class ReActEngine:
effective_timeout: 超时秒数stream=True 时在循环内检查
stream=False 时由 caller asyncio.wait_for 强制
"""
# P2 #9: Reset loop detection state so reuse across conversations is clean
# P2 #9: Reset loop detection state so reuse across conversations is clean.
# KTD-7: skip reset when restore_budget_state() was called so restored
# counters survive into the loop. Flag is cleared in the finally block
# below so the next execute() without a restore resets normally.
if not self._state_restored:
self.reset()
tools = tools or []
if tools:
@ -1747,6 +1762,9 @@ class ReActEngine:
data={"result": final_result},
)
finally:
# KTD-7: clear the restore guard so the next execute() without a
# restore_budget_state() call resets counters normally.
self._state_restored = False
# 结束轨迹记录 — always runs even if consumer doesn't fully iterate
if trace_recorder is not None:
trace_recorder.end_trace(outcome=trace_outcome)

View File

@ -4,6 +4,7 @@ import asyncio
import logging
import os
from contextlib import asynccontextmanager
from typing import Any
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
@ -152,14 +153,84 @@ def _build_skill_registry(config: ServerConfig) -> SkillRegistry:
return registry
def _try_get_experience_store(server_config) -> Any | None:
"""Build a PostgreSQL ExperienceStore from server_config, or None if unavailable.
Mirrors cli/skill.py:_try_get_experience_store. database_url lookup order:
1. server_config.evolution.database_url
2. server_config.memory.episodic.database_url
3. DATABASE_URL env var
Returns an ExperienceStore instance or None (lazy import return type is
Any to avoid a module-level dependency on the experience_store module).
"""
database_url: str | None = None
evo_conf = getattr(server_config, "evolution", None) or {}
database_url = evo_conf.get("database_url") if isinstance(evo_conf, dict) else None
if not database_url:
epi_conf = (getattr(server_config, "memory", None) or {}).get("episodic", {})
database_url = epi_conf.get("database_url") if isinstance(epi_conf, dict) else None
if not database_url:
database_url = os.environ.get("DATABASE_URL")
if not database_url:
return None
try:
from agentkit.evolution.experience_store import ExperienceStore
from agentkit.memory.models import ExperienceModel, create_experience_session_factory
session_factory = create_experience_session_factory(database_url)
return ExperienceStore(
session_factory=session_factory,
experience_model=ExperienceModel,
)
except Exception as e:
logger.warning(f"Failed to create PostgreSQL ExperienceStore: {e}")
return None
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
task_store = app.state.task_store
await task_store.start_cleanup()
# Start config watcher if server_config is available
# U7/R12 + U8/R8 (KTD-5): instantiate PitfallDetector + SpecManager as
# app-state singletons so PLAN_EXEC tasks can access them. PitfallDetector
# requires the PostgreSQL ExperienceStore; if unavailable (no DB), it is
# skipped gracefully (pitfall injection becomes a no-op). SpecManager is
# file-based and always available.
app.state.pitfall_detector = None
app.state.spec_manager = None
server_config = getattr(app.state, "server_config", None)
try:
from agentkit.core.spec_manager import SpecManager
app.state.spec_manager = SpecManager()
logger.info("SpecManager initialized (file-based)")
except Exception: # noqa: BLE001 — SpecManager init; must not block startup
logger.debug("SpecManager init failed — spec persistence unavailable", exc_info=True)
try:
experience_store = _try_get_experience_store(server_config)
if experience_store is not None:
from agentkit.evolution.pitfall_detector import PitfallDetector
app.state.pitfall_detector = PitfallDetector(experience_store)
logger.info("PitfallDetector initialized (ExperienceStore ready)")
else:
logger.debug(
"PitfallDetector skipped — no PostgreSQL ExperienceStore configured "
"(pitfall injection is a no-op for PLAN_EXEC)"
)
except Exception: # noqa: BLE001 — PitfallDetector init; must not block startup
logger.debug("PitfallDetector init failed — pitfall injection disabled", exc_info=True)
# Start config watcher if server_config is available
if server_config is not None and server_config._config_path:
server_config.on_change = lambda cfg: _on_config_change(app, cfg)
server_config.watch_config()
@ -253,6 +324,20 @@ async def lifespan(app: FastAPI):
try:
agent = await app.state.agent_pool.create_agent(default_config)
# U7/R12 + U8/R8 (KTD-5): wire app-state singletons onto the default
# agent so its PLAN_EXEC path (ConfigDrivenAgent._handle_plan_exec_*)
# threads pitfall_detector + spec_manager into PlanExecEngine.
# ponytail: known gap — agents created later via
# AgentPool.create_agent/create_agent_from_skill (skill-loaded agents)
# do NOT receive these singletons because AgentPool does not forward
# them yet. Upgrade path: add pitfall_detector/spec_manager params to
# AgentPool.__init__ and pass through in create_agent(). The default
# chat agent is wired here as the most critical path; skill agents
# fall back to None (no pitfall injection / spec review) until the
# pool is updated.
agent._pitfall_detector = app.state.pitfall_detector
agent._spec_manager = app.state.spec_manager
# Register tools into the agent's tool registry
search_api_keys = {
"tavily_api_key": os.environ.get("TAVILY_API_KEY"),

View File

@ -0,0 +1,170 @@
"""Unit tests for KTD-7: restore_budget_state() survives _execute_loop's reset().
Regression coverage for the P1 finding where ``_execute_loop`` called
``self.reset()`` AFTER ``restore_budget_state()`` had set the checkpoint
counters, zeroing them out and breaking checkpoint reconstruction.
Covers:
- restore_budget_state() sets _state_restored flag
- execute() does NOT zero out restored counters (reset skipped)
- _state_restored flag is cleared after execute() finishes
- A subsequent execute() without restore resets counters normally
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
from agentkit.core.phase import WILDCARD, PhasePolicy, PhaseState
from agentkit.core.react import ReActEngine
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMResponse, TokenUsage
# ── helpers ───────────────────────────────────────────────────────────
def make_mock_gateway(responses: list[LLMResponse]) -> MagicMock:
"""Mock LLMGateway. chat returns the responses in order."""
gateway = MagicMock(spec=LLMGateway)
gateway.chat = AsyncMock(side_effect=responses)
return gateway
def make_response(content: str = "") -> LLMResponse:
return LLMResponse(
content=content,
model="test-model",
usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
tool_calls=[],
)
def _wildcard_policy(start: PhaseState) -> PhasePolicy:
"""PhasePolicy allowing all tools in all phases."""
return PhasePolicy(
whitelist={
PhaseState.PLANNING: frozenset({WILDCARD}),
PhaseState.BUILDING: frozenset({WILDCARD}),
PhaseState.VERIFICATION: frozenset({WILDCARD}),
PhaseState.DELIVERY: frozenset({WILDCARD}),
},
start_phase=start,
)
# ── restore_budget_state + execute() integration (KTD-7) ──────────────
class TestRestoreBudgetStateSurvivesExecute:
"""KTD-7: restored counters must survive into _execute_loop (not zeroed)."""
async def test_restored_counters_survive_execute(self) -> None:
"""restore_budget_state() then execute() — counters must NOT be zeroed.
Without the fix, _execute_loop calls self.reset() which zeros
_think_count/_verify_count/_reflect_count. The _state_restored flag
guards against this.
"""
# Start in VERIFICATION so think_count is not incremented by the loop
# (the increment only happens in PLANNING/BUILDING phases).
policy = _wildcard_policy(start=PhaseState.VERIFICATION)
gateway = make_mock_gateway([make_response(content="done")])
engine = ReActEngine(
llm_gateway=gateway,
phase_policy=policy,
phase_budgets={"think": 7, "verify": 2, "reflect": 1},
)
# Simulate checkpoint restore
engine.restore_budget_state(think=5, verify=2, reflect=1)
assert engine._state_restored is True
assert engine._think_count == 5
assert engine._verify_count == 2
assert engine._reflect_count == 1
# Execute — _execute_loop must skip reset() due to _state_restored
await engine.execute(
messages=[{"role": "user", "content": "resume checkpoint"}],
)
# Counters survived (think=5 unchanged because we started in VERIFICATION;
# verify/reflect unchanged because verification_enabled=False default).
assert engine._think_count == 5, (
f"Expected _think_count==5 (restored), got {engine._think_count} "
"(reset() zeroed the restored checkpoint — KTD-7 regression)"
)
assert engine._verify_count == 2
assert engine._reflect_count == 1
async def test_state_restored_flag_cleared_after_execute(self) -> None:
"""_state_restored must be cleared in finally so next execute() resets."""
policy = _wildcard_policy(start=PhaseState.VERIFICATION)
gateway = make_mock_gateway([make_response(content="done")])
engine = ReActEngine(
llm_gateway=gateway,
phase_policy=policy,
phase_budgets={"think": 7, "verify": 2, "reflect": 1},
)
engine.restore_budget_state(think=5, verify=2, reflect=1)
assert engine._state_restored is True
await engine.execute(
messages=[{"role": "user", "content": "resume"}],
)
# Flag cleared in finally block
assert engine._state_restored is False, (
"_state_restored not cleared after execute() — subsequent execute() "
"calls would incorrectly skip reset()"
)
async def test_second_execute_without_restore_resets_counters(self) -> None:
"""After a restored execute(), the next execute() must reset normally."""
policy = _wildcard_policy(start=PhaseState.VERIFICATION)
gateway = make_mock_gateway(
[make_response(content="first"), make_response(content="second")]
)
engine = ReActEngine(
llm_gateway=gateway,
phase_policy=policy,
phase_budgets={"think": 7, "verify": 2, "reflect": 1},
)
# First execute with restored state
engine.restore_budget_state(think=5, verify=2, reflect=1)
await engine.execute(messages=[{"role": "user", "content": "resume"}])
assert engine._think_count == 5 # survived
# Second execute WITHOUT restore — must reset to 0
await engine.execute(messages=[{"role": "user", "content": "fresh"}])
assert engine._think_count == 0, (
f"Expected _think_count==0 after fresh execute(), got "
f"{engine._think_count} (flag not cleared, reset incorrectly skipped)"
)
assert engine._verify_count == 0
assert engine._reflect_count == 0
async def test_execute_without_restore_behaves_unchanged(self) -> None:
"""No restore_budget_state() call — execute() resets as before (backward compat)."""
policy = _wildcard_policy(start=PhaseState.VERIFICATION)
gateway = make_mock_gateway([make_response(content="done")])
engine = ReActEngine(
llm_gateway=gateway,
phase_policy=policy,
phase_budgets={"think": 7, "verify": 2, "reflect": 1},
)
# Manually set counters (simulating stale state from a prior run)
engine._think_count = 9
engine._verify_count = 3
engine._reflect_count = 2
assert engine._state_restored is False
await engine.execute(messages=[{"role": "user", "content": "fresh"}])
# reset() ran normally, zeroing the stale counters
assert engine._think_count == 0
assert engine._verify_count == 0
assert engine._reflect_count == 0