feat: SQLite persistence, verification loop, spec-driven execution

Phase 2 of architecture optimization (U5/U6/U9):

- U5: SqliteConversationStore with WAL mode + LRU cache (1000 convs)
  Replaces in-memory ConversationStore in portal.py
  Data survives server restarts (ref: Codex Thread persistence)
- U6: VerificationLoop with verify/verify_and_retry
  Default commands: pytest + ruff check
  ReActEngine integration via verification_enabled flag
  New run_tests tool for LLM to invoke verification
- U9: SpecManager for plan-as-contract (ref: Qoder Quest Mode)
  Plans persisted to .agentkit/specs/{spec_id}.yaml
  API: GET/PUT /api/v1/specs, POST /api/v1/specs/{id}/confirm
  PlanExecEngine emits spec_created event after plan generation

Also fixes: portal skill_name routing, app.py SessionManager guard,
test_telemetry CostAwareRouter removal, test_compression_config fixture
This commit is contained in:
chiguyong 2026-06-17 10:45:20 +08:00
parent 5374bc8501
commit 200174c5c7
12 changed files with 1560 additions and 194 deletions

View File

@ -0,0 +1,294 @@
"""SQLite-backed conversation store with async support and in-memory LRU cache."""
from __future__ import annotations
import json
import logging
import os
import uuid
from collections import OrderedDict
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import aiosqlite
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Data classes (mirrors portal.py ChatMessage / Conversation)
# ---------------------------------------------------------------------------
@dataclass
class ChatMessage:
    """A single message within a conversation."""

    role: str  # "user" or "assistant"
    content: str  # message text
    # Creation time; defaults to the current timezone-aware UTC time.
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    # Free-form extra data attached to the message; a fresh dict per instance.
    metadata: dict = field(default_factory=dict)
@dataclass
class Conversation:
    """A conversation: an identifier, its messages and two timestamps."""

    id: str  # unique conversation identifier
    # Messages held in memory for this conversation; a fresh list per instance.
    messages: list[ChatMessage] = field(default_factory=list)
    # Both timestamps default to the current timezone-aware UTC time.
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
# ---------------------------------------------------------------------------
# Schema
# ---------------------------------------------------------------------------
# DDL script for the store's two tables: one row per conversation, and one row
# per message that points back at its conversation. Every statement uses
# IF NOT EXISTS, so running the script against an existing database is harmless.
# Timestamps are stored as TEXT and message metadata as a TEXT column that
# defaults to an empty JSON object.
_SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS conversations (
id TEXT PRIMARY KEY,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
conversation_id TEXT NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
timestamp TEXT NOT NULL,
metadata TEXT DEFAULT '{}',
FOREIGN KEY (conversation_id) REFERENCES conversations(id)
);
CREATE INDEX IF NOT EXISTS idx_messages_conv_id ON messages(conversation_id);
"""
class SqliteConversationStore:
    """SQLite-backed conversation store with an in-memory LRU cache.

    Offers the same operations as the in-memory ConversationStore in
    portal.py, but every public method is a coroutine. Conversations and
    messages are written to SQLite so they survive server restarts. An LRU
    cache of recently used Conversation objects avoids re-reading conversation
    rows; evicting from the cache never deletes anything from SQLite.

    A cached Conversation only carries the messages that were added through
    this instance while it was cached. Use get_history() for the full list.
    """

    def __init__(
        self,
        db_path: str | Path | None = None,
        max_conversations: int = 1000,
    ) -> None:
        """
        Args:
            db_path: Location of the SQLite file. Defaults to
                ``~/.agentkit/conversations.db``.
            max_conversations: Maximum number of conversations kept in the
                in-memory cache.
        """
        if db_path is None:
            db_path = os.path.expanduser("~/.agentkit/conversations.db")
        self._db_path = str(db_path)
        self._max = max_conversations
        self._cache: OrderedDict[str, Conversation] = OrderedDict()
        self._db: aiosqlite.Connection | None = None
        # Guards the one-time connection setup; created lazily in _ensure_db
        # so that it is instantiated while an event loop is running.
        self._init_lock: Any = None

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    async def _ensure_db(self) -> aiosqlite.Connection:
        """Lazily open the database and ensure the schema exists.

        The connection is published on ``self._db`` only after the schema has
        been created, and setup is serialized by a lock, so concurrent first
        callers neither see a half-initialized connection nor open duplicates.
        """
        if self._db is not None:
            return self._db

        import asyncio  # local import: only needed for the one-time init lock

        if self._init_lock is None:
            self._init_lock = asyncio.Lock()
        async with self._init_lock:
            # Re-check: another coroutine may have finished setup while we waited.
            if self._db is None:
                db_dir = os.path.dirname(self._db_path)
                if db_dir:
                    os.makedirs(db_dir, exist_ok=True)
                conn = await aiosqlite.connect(self._db_path)
                try:
                    conn.row_factory = aiosqlite.Row
                    await conn.execute("PRAGMA journal_mode=WAL")
                    await conn.executescript(_SCHEMA_SQL)
                    await conn.commit()
                except BaseException:
                    # Do not leak a connection whose setup failed.
                    await conn.close()
                    raise
                self._db = conn
            return self._db

    async def _close_db(self) -> None:
        """Close the database connection, if one is open."""
        if self._db is not None:
            await self._db.close()
            self._db = None
        # Drop the lock too so a later reopen creates a fresh one.
        self._init_lock = None

    async def close(self) -> None:
        """Close the underlying connection; it is reopened on next use."""
        await self._close_db()

    @staticmethod
    def _dt_to_str(dt: datetime) -> str:
        """Serialize a datetime to its ISO-8601 string."""
        return dt.isoformat()

    @staticmethod
    def _str_to_dt(s: str) -> datetime:
        """Parse an ISO-8601 string produced by _dt_to_str."""
        return datetime.fromisoformat(s)

    @staticmethod
    def _parse_metadata(raw: str | None) -> dict[str, Any]:
        """Decode a stored metadata column; anything unusable becomes ``{}``."""
        if not raw:
            return {}
        try:
            parsed = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return {}
        return parsed if isinstance(parsed, dict) else {}

    def _touch_cache(self, conv_id: str) -> None:
        """Move conversation to the end of the LRU cache (most recently used)."""
        if conv_id in self._cache:
            self._cache.move_to_end(conv_id)

    def _evict_if_needed(self) -> None:
        """Evict oldest entries while over the limit (does NOT delete from SQLite)."""
        while len(self._cache) > self._max:
            self._cache.popitem(last=False)

    def _remember(self, conv: Conversation) -> None:
        """Cache ``conv`` as the most recently used entry and enforce the limit."""
        self._cache[conv.id] = conv
        self._cache.move_to_end(conv.id)
        self._evict_if_needed()

    async def _load_conversation(
        self, db: aiosqlite.Connection, conversation_id: str
    ) -> Conversation | None:
        """Load a conversation row (without messages) from SQLite into the cache.

        Returns None when no such conversation exists.
        """
        cursor = await db.execute(
            "SELECT id, created_at, updated_at FROM conversations WHERE id = ?",
            (conversation_id,),
        )
        row = await cursor.fetchone()
        if row is None:
            return None
        # Another coroutine may have cached it while we awaited the query;
        # keep that object so all callers share one instance.
        already = self._cache.get(conversation_id)
        if already is not None:
            self._touch_cache(conversation_id)
            return already
        conv = Conversation(
            id=row["id"],
            created_at=self._str_to_dt(row["created_at"]),
            updated_at=self._str_to_dt(row["updated_at"]),
        )
        self._remember(conv)
        return conv

    # ------------------------------------------------------------------
    # Public API (matches ConversationStore interface)
    # ------------------------------------------------------------------
    async def get_or_create(self, conversation_id: str | None = None) -> Conversation:
        """Get an existing conversation or create a new one.

        Lookup order: cache, then SQLite, then create. A cache hit refreshes
        the in-memory ``updated_at`` only. Messages are not loaded here; call
        get_history() for them.

        Args:
            conversation_id: Id to look up or to use for a new conversation.
                When empty or None a random UUID is generated.
        """
        db = await self._ensure_db()
        if conversation_id:
            cached = self._cache.get(conversation_id)
            if cached is not None:
                cached.updated_at = datetime.now(timezone.utc)
                self._touch_cache(cached.id)
                return cached
            stored = await self._load_conversation(db, conversation_id)
            if stored is not None:
                return stored

        # Not known anywhere: create and persist a new conversation.
        now = datetime.now(timezone.utc)
        conv = Conversation(
            id=conversation_id or str(uuid.uuid4()),
            created_at=now,
            updated_at=now,
        )
        await db.execute(
            "INSERT INTO conversations (id, created_at, updated_at) VALUES (?, ?, ?)",
            (conv.id, self._dt_to_str(conv.created_at), self._dt_to_str(conv.updated_at)),
        )
        await db.commit()
        self._remember(conv)
        return conv

    async def add_message(
        self,
        conversation_id: str,
        role: str,
        content: str,
        metadata: dict | None = None,
    ) -> ChatMessage:
        """Add a message to a conversation (SQLite first, then the cache).

        Raises:
            KeyError: If the conversation exists neither in cache nor in SQLite.
        """
        db = await self._ensure_db()
        conv = self._cache.get(conversation_id)
        if conv is None:
            # Loading goes through the LRU path, so the cache limit is enforced.
            conv = await self._load_conversation(db, conversation_id)
            if conv is None:
                raise KeyError(f"Conversation '{conversation_id}' not found")

        msg = ChatMessage(
            role=role,
            content=content,
            metadata=metadata or {},
        )
        touched_at = datetime.now(timezone.utc)

        # Persist before touching the cached object so that a failed write
        # does not leave a message in memory that SQLite never stored.
        await db.execute(
            "INSERT INTO messages (conversation_id, role, content, timestamp, metadata) "
            "VALUES (?, ?, ?, ?, ?)",
            (
                conversation_id,
                role,
                content,
                self._dt_to_str(msg.timestamp),
                json.dumps(msg.metadata, default=str),
            ),
        )
        await db.execute(
            "UPDATE conversations SET updated_at = ? WHERE id = ?",
            (self._dt_to_str(touched_at), conversation_id),
        )
        await db.commit()

        conv.messages.append(msg)
        conv.updated_at = touched_at
        self._touch_cache(conv.id)
        return msg

    async def get_history(self, conversation_id: str, limit: int = 50) -> list[ChatMessage]:
        """Return the most recent ``limit`` messages in chronological order.

        Always reads from SQLite, because a cached conversation may hold only
        a subset of its messages (for example after a restart).
        """
        db = await self._ensure_db()
        cursor = await db.execute(
            "SELECT role, content, timestamp, metadata FROM messages "
            "WHERE conversation_id = ? ORDER BY id DESC LIMIT ?",
            (conversation_id, limit),
        )
        rows = await cursor.fetchall()
        # Rows arrive newest-first; reverse them to oldest-first.
        return [
            ChatMessage(
                role=row["role"],
                content=row["content"],
                timestamp=self._str_to_dt(row["timestamp"]),
                metadata=self._parse_metadata(row["metadata"]),
            )
            for row in reversed(rows)
        ]

    async def list_conversations(self, limit: int = 20) -> list[Conversation]:
        """List recent conversations ordered by stored updated_at (newest first)."""
        db = await self._ensure_db()
        cursor = await db.execute(
            "SELECT id, created_at, updated_at FROM conversations "
            "ORDER BY updated_at DESC LIMIT ?",
            (limit,),
        )
        rows = await cursor.fetchall()
        result: list[Conversation] = []
        for row in rows:
            # Prefer the cached object: it may carry in-memory messages.
            cached = self._cache.get(row["id"])
            if cached is not None:
                result.append(cached)
                continue
            result.append(
                Conversation(
                    id=row["id"],
                    created_at=self._str_to_dt(row["created_at"]),
                    updated_at=self._str_to_dt(row["updated_at"]),
                )
            )
        return result

    async def restore_from_store(
        self,
        max_sessions: int = 50,
        max_messages_per_session: int = 100,
    ) -> None:
        """No-op for the SQLite store; the arguments are accepted and ignored."""
        # Nothing to do; all data lives in SQLite and is loaded on demand.

View File

@ -0,0 +1,171 @@
"""Spec Manager — 执行计划规格文档管理器
PlanExecEngine 生成的执行计划持久化为 Spec 文档
用户可在执行前查看编辑和确认 Spec作为人与 AI 之间的契约
"""
from __future__ import annotations
import logging
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import yaml
logger = logging.getLogger(__name__)
@dataclass
class SpecStep:
    """A single step in a spec."""

    step_id: str  # identifier of this step
    name: str  # short step name
    description: str  # what the step is meant to do
    # Other steps this one lists as dependencies; a fresh list per instance.
    # NOTE(review): presumably these are step_id values — confirm with PlanExecEngine.
    dependencies: list[str] = field(default_factory=list)
    status: str = "pending"  # pending | confirmed | executing | completed | failed
@dataclass
class Spec:
    """A specification document for a planned task."""

    spec_id: str  # unique identifier of the spec
    goal: str  # the goal the plan is meant to achieve
    steps: list[SpecStep] = field(default_factory=list)  # steps of the plan
    status: str = "draft"  # draft | confirmed | executing | completed | failed
    # Creation time as an ISO-8601 string; defaults to the current UTC time.
    created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
    # None by default; holds a timestamp string once one is assigned.
    confirmed_at: str | None = None
    # Free-form extra data; a fresh dict per instance.
    metadata: dict[str, Any] = field(default_factory=dict)
class SpecManager:
    """Manages Spec documents as first-class citizens.

    Each spec is persisted as ``<spec_id>.yaml`` inside the specs directory
    and mirrored in an in-memory cache keyed by spec id. Users can view,
    edit, and confirm specs before execution.
    """

    def __init__(self, specs_dir: str | None = None):
        """
        Args:
            specs_dir: Directory for spec files. Default: .agentkit/specs/
                The directory is created if it does not exist.
        """
        self._specs_dir = Path(specs_dir or ".agentkit/specs")
        self._specs_dir.mkdir(parents=True, exist_ok=True)
        self._cache: dict[str, Spec] = {}

    def _path_for(self, spec_id: str) -> Path:
        """Return the YAML path for ``spec_id``.

        Raises:
            ValueError: If the id is empty or would resolve to a file outside
                the specs directory (for example it contains a path separator).
        """
        if not isinstance(spec_id, str) or not spec_id:
            raise ValueError(f"Invalid spec id: {spec_id!r}")
        filename = f"{spec_id}.yaml"
        path = self._specs_dir / filename
        # The id may come from a URL; it must name a file directly inside
        # the specs directory and nothing else.
        if path.parent != self._specs_dir or path.name != filename:
            raise ValueError(f"Invalid spec id: {spec_id!r}")
        return path

    def create(self, spec: Spec) -> Path:
        """Persist a Spec to disk and cache it. Returns the file path.

        Raises:
            ValueError: If ``spec.spec_id`` is not a valid file-name id.
        """
        path = self._path_for(spec.spec_id)
        data = asdict(spec)
        path.write_text(
            yaml.dump(data, allow_unicode=True, default_flow_style=False),
            encoding="utf-8",
        )
        self._cache[spec.spec_id] = spec
        logger.info("Spec created: %s -> %s", spec.spec_id, path)
        return path

    def get(self, spec_id: str) -> Spec | None:
        """Load a Spec by ID from the cache or from disk.

        Returns None when the id is invalid, the file does not exist, or the
        file cannot be parsed (the parse failure is logged).
        """
        cached = self._cache.get(spec_id) if isinstance(spec_id, str) else None
        if cached is not None:
            return cached
        try:
            path = self._path_for(spec_id)
        except ValueError:
            return None
        if not path.exists():
            return None
        try:
            data = yaml.safe_load(path.read_text(encoding="utf-8"))
            spec = self._dict_to_spec(data)
        except Exception as e:  # best-effort load: a broken file must not crash callers
            logger.error("Failed to load spec %s: %s", spec_id, e)
            return None
        self._cache[spec_id] = spec
        return spec

    def update(self, spec_id: str, **kwargs: Any) -> Spec | None:
        """Update spec fields and persist.

        ``steps`` may be given as a list of dicts and/or SpecStep objects.
        Unknown field names are ignored. ``spec_id`` itself cannot be changed.
        Returns the updated Spec, or None when the spec does not exist.
        """
        spec = self.get(spec_id)
        if spec is None:
            return None
        for key, value in kwargs.items():
            if key == "spec_id":
                # The id names both the cache entry and the file; changing it
                # here would orphan both.
                logger.warning("Ignoring attempt to change spec_id of %s", spec_id)
                continue
            if key == "steps" and isinstance(value, list):
                spec.steps = [
                    self._dict_to_step(s) if isinstance(s, dict) else s for s in value
                ]
            elif hasattr(spec, key):
                setattr(spec, key, value)
        self.create(spec)  # re-persist
        return spec

    def confirm(self, spec_id: str) -> Spec | None:
        """Mark a spec as confirmed (user approved execution).

        Sets the spec status and confirmation timestamp and promotes every
        still-pending step to "confirmed". Returns None if the spec is missing.
        """
        spec = self.get(spec_id)
        if spec is None:
            return None
        spec.status = "confirmed"
        spec.confirmed_at = datetime.now(timezone.utc).isoformat()
        for step in spec.steps:
            if step.status == "pending":
                step.status = "confirmed"
        self.create(spec)  # re-persist
        logger.info("Spec confirmed: %s", spec_id)
        return spec

    def list_specs(self, status: str | None = None) -> list[Spec]:
        """List all specs on disk, optionally filtered by status.

        Files that fail to load are skipped with a warning. The result is
        sorted by created_at, newest first.
        """
        specs: list[Spec] = []
        for path in self._specs_dir.glob("*.yaml"):
            try:
                data = yaml.safe_load(path.read_text(encoding="utf-8"))
                spec = self._dict_to_spec(data)
            except Exception as e:  # one broken file must not hide the others
                logger.warning("Failed to load spec from %s: %s", path, e)
                continue
            if status is None or spec.status == status:
                specs.append(spec)
        specs.sort(key=lambda s: s.created_at, reverse=True)
        return specs

    def delete(self, spec_id: str) -> bool:
        """Delete a spec file and its cache entry.

        Returns True if a file was removed, False if the id is invalid or no
        file existed.
        """
        try:
            path = self._path_for(spec_id)
        except ValueError:
            return False
        # Drop the cache entry even if the file is already gone, so a stale
        # spec cannot keep being served from memory.
        self._cache.pop(spec_id, None)
        try:
            path.unlink()
        except FileNotFoundError:
            return False
        logger.info("Spec deleted: %s", spec_id)
        return True

    @staticmethod
    def _dict_to_spec(data: dict[str, Any]) -> Spec:
        """Convert a dict to a Spec instance.

        ``spec_id`` and ``goal`` are required; other keys fall back to defaults.
        A null ``steps`` or ``metadata`` value is treated as empty.
        """
        steps = [SpecManager._dict_to_step(s) for s in data.get("steps") or []]
        return Spec(
            spec_id=data["spec_id"],
            goal=data["goal"],
            steps=steps,
            status=data.get("status", "draft"),
            created_at=data.get("created_at", ""),
            confirmed_at=data.get("confirmed_at"),
            metadata=data.get("metadata") or {},
        )

    @staticmethod
    def _dict_to_step(data: dict[str, Any] | SpecStep) -> SpecStep:
        """Convert a dict to a SpecStep instance; SpecStep objects pass through."""
        if isinstance(data, SpecStep):
            return data
        return SpecStep(
            step_id=data["step_id"],
            name=data["name"],
            description=data["description"],
            dependencies=data.get("dependencies") or [],
            status=data.get("status", "pending"),
        )

View File

@ -0,0 +1,145 @@
"""验证循环 - 执行后运行项目测试验证结果
遵循 Codex Cloud 模式execute verify retry on failure
默认验证命令pytest, ruff check
"""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Any
logger = logging.getLogger(__name__)
@dataclass
class VerificationResult:
    """Result of a verification-loop run."""

    passed: bool  # whether the verification passed
    attempts: int  # number of verification attempts
    test_output: str  # captured output text
    errors: list[str] = field(default_factory=list)  # failure descriptions; a fresh list per instance
class VerificationLoop:
    """Run the project's verification commands after execution.

    Follows the Codex Cloud pattern: execute -> verify -> retry on failure.
    Default verification commands: pytest, ruff check.
    """

    def __init__(
        self,
        commands: list[str] | None = None,
        max_retries: int = 2,
        working_dir: str | None = None,
        timeout: float = 60.0,
    ):
        """
        Args:
            commands: Shell commands used for verification.
                Default: ["pytest -x -q", "ruff check src/"]
            max_retries: Maximum number of retries after the initial run.
            working_dir: Working directory in which the commands run.
            timeout: Timeout, in seconds, applied to each command.
        """
        self._commands = commands or ["pytest -x -q", "ruff check src/"]
        self._max_retries = max_retries
        self._working_dir = working_dir
        self._timeout = timeout

    async def _run_command(self, cmd: str) -> tuple[int | None, str]:
        """Run one shell command and return ``(exit_code, combined_output)``.

        stderr is merged into stdout. If the wait is interrupted for any
        reason (timeout, cancellation, I/O error) the subprocess is killed and
        reaped before the exception propagates, so no process is left behind.

        Raises:
            asyncio.TimeoutError: If the command exceeds the configured timeout.
        """
        proc = await asyncio.create_subprocess_shell(
            cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.STDOUT,
            cwd=self._working_dir,
        )
        try:
            stdout, _ = await asyncio.wait_for(
                proc.communicate(),
                timeout=self._timeout,
            )
        finally:
            # returncode is still None only when the process has not been reaped.
            if proc.returncode is None:
                try:
                    proc.kill()
                except ProcessLookupError:
                    pass  # already exited between the check and the kill
                await proc.wait()
        output = stdout.decode("utf-8", errors="replace") if stdout else ""
        return proc.returncode, output

    async def verify(self) -> VerificationResult:
        """Run every verification command once and return the result.

        Commands run sequentially and all of them are executed even after a
        failure. A command fails when it exits non-zero, times out, or cannot
        be started.

        Returns:
            VerificationResult with passed=True only if every command
            succeeded; attempts is always 1.
        """
        all_output: list[str] = []
        errors: list[str] = []
        for cmd in self._commands:
            try:
                returncode, output = await self._run_command(cmd)
            except asyncio.TimeoutError:
                failure = f"Command timed out after {self._timeout}s: {cmd}"
            except Exception as e:
                failure = f"Command failed to execute: {cmd}: {e}"
            else:
                all_output.append(f"$ {cmd}\n{output}")
                if returncode != 0:
                    errors.append(f"Command failed (exit {returncode}): {cmd}\n{output}")
                continue
            # Timeout / launch failure: the message is both output and error.
            all_output.append(failure)
            errors.append(failure)
        return VerificationResult(
            passed=not errors,
            attempts=1,
            test_output="\n".join(all_output),
            errors=errors,
        )

    async def verify_and_retry(
        self,
        fix_callback: Any = None,
    ) -> VerificationResult:
        """Run verification; on failure call ``fix_callback`` and retry.

        Args:
            fix_callback: Optional callable receiving (errors, test_output)
                that tries to fix the problem between attempts. It may be a
                coroutine function or a plain function. Exceptions it raises
                are logged and do not stop the retry loop.

        Returns:
            VerificationResult of the last attempt, with ``attempts`` set to
            the total number of verification runs.
        """
        import inspect  # local import: only needed to accept sync and async callbacks

        result = await self.verify()
        retries = 0
        while not result.passed and retries < self._max_retries:
            retries += 1
            logger.info(
                "Verification failed (attempt %d/%d), %s",
                retries,
                self._max_retries,
                "calling fix_callback" if fix_callback else "retrying",
            )
            if fix_callback is not None:
                try:
                    outcome = fix_callback(result.errors, result.test_output)
                    if inspect.isawaitable(outcome):
                        await outcome
                except Exception as e:
                    logger.warning("fix_callback raised an error: %s", e)
            result = await self.verify()
            result.attempts = retries + 1
        return result

View File

@ -22,13 +22,14 @@ from pydantic import BaseModel
from agentkit.core.config_driven import ConfigDrivenAgent
from agentkit.core.react import ReActEngine
from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult
from agentkit.chat.simple_router import SimpleRouter
from agentkit.chat.request_preprocessor import RequestPreprocessor
from agentkit.server.routes.evolution_dashboard import (
_experiences as _dashboard_experiences,
DashboardExperience,
_broadcast_event as _broadcast_dashboard_event,
)
from agentkit.core.fallback import EMPTY_LLM_RESPONSE
from agentkit.chat.sqlite_conversation_store import SqliteConversationStore
from agentkit.session.manager import SessionManager
logger = logging.getLogger(__name__)
@ -103,7 +104,9 @@ class ConversationStore:
conversations can be restored from SessionManager.
"""
def __init__(self, max_conversations: int = 1000, session_manager: SessionManager | None = None):
def __init__(
self, max_conversations: int = 1000, session_manager: SessionManager | None = None
):
self._conversations: dict[str, Conversation] = {}
self._max = max_conversations
self._session_manager = session_manager
@ -141,16 +144,16 @@ class ConversationStore:
sid, limit=max_messages_per_session
)
for msg in messages:
conv.messages.append(ChatMessage(
role=msg.role.value,
content=msg.content,
timestamp=msg.created_at,
metadata=msg.metadata,
))
conv.messages.append(
ChatMessage(
role=msg.role.value,
content=msg.content,
timestamp=msg.created_at,
metadata=msg.metadata,
)
)
self._conversations[sid] = conv
logger.info(
f"Restored {len(self._conversations)} conversations from SessionManager"
)
logger.info(f"Restored {len(self._conversations)} conversations from SessionManager")
except Exception as e:
logger.warning(f"Failed to restore conversations from SessionManager: {e}")
@ -217,7 +220,7 @@ class ConversationStore:
# Heartbeat timeout in seconds — 0 disables timeout (for testing)
_WS_HEARTBEAT_TIMEOUT = float(os.environ.get("AGENTKIT_WS_TIMEOUT", "120"))
_conversation_store = ConversationStore()
_conversation_store = SqliteConversationStore()
# ---------------------------------------------------------------------------
# History injection helper — configurable limit + optional compression
@ -227,7 +230,7 @@ _conversation_store = ConversationStore()
_MAX_HISTORY_MESSAGES = 50
def _build_history_messages(
async def _build_history_messages(
conv_id: str,
limit: int = _MAX_HISTORY_MESSAGES,
) -> list[dict]:
@ -238,7 +241,7 @@ def _build_history_messages(
which should be appended separately by the caller).
"""
try:
history = _conversation_store.get_history(conv_id, limit=limit)
history = await _conversation_store.get_history(conv_id, limit=limit)
except Exception:
return []
@ -342,14 +345,16 @@ class CapabilitiesResponse(BaseModel):
async def _resolve_for_chat(
request: ChatRequest, req: Request
) -> tuple[ConfigDrivenAgent | None, SkillRoutingResult | None, str | None, str | None, float | None]:
"""Resolve agent and routing for a chat request via SimpleRouter.
) -> tuple[
ConfigDrivenAgent | None, SkillRoutingResult | None, str | None, str | None, float | None
]:
"""Resolve agent and routing for a chat request via RequestPreprocessor.
Returns (agent, routing_result, matched_skill_name, routing_method, confidence).
"""
pool = req.app.state.agent_pool
skill_registry = req.app.state.skill_registry
simple_router: SimpleRouter = req.app.state.simple_router
request_preprocessor: RequestPreprocessor = req.app.state.request_preprocessor
matched_skill_name: str | None = None
routing_method: str | None = None
@ -362,8 +367,7 @@ async def _resolve_for_chat(
if default_agent is not None:
default_tools = default_agent.get_tools()
default_system_prompt = (
getattr(default_agent, "_system_prompt", None)
or default_agent.get_system_prompt()
getattr(default_agent, "_system_prompt", None) or default_agent.get_system_prompt()
)
else:
all_skills = skill_registry.list_skills()
@ -376,15 +380,26 @@ async def _resolve_for_chat(
)
break
# Route via SimpleRouter (minimal routing: @skill prefix + greeting regex + REACT)
routing_result = await simple_router.route(
content=request.message,
skill_registry=skill_registry,
default_tools=default_tools,
default_system_prompt=default_system_prompt,
default_model="default",
default_agent_name="default",
)
# If skill_name is explicitly provided in the request, use it directly
if request.skill_name:
routing_result = await request_preprocessor.preprocess(
content=f"@skill:{request.skill_name} {request.message}",
skill_registry=skill_registry,
default_tools=default_tools,
default_system_prompt=default_system_prompt,
default_model="default",
default_agent_name="default",
)
else:
# Preprocess via RequestPreprocessor (minimal: @skill prefix + greeting regex + REACT)
routing_result = await request_preprocessor.preprocess(
content=request.message,
skill_registry=skill_registry,
default_tools=default_tools,
default_system_prompt=default_system_prompt,
default_model="default",
default_agent_name="default",
)
matched_skill_name = routing_result.skill_name or routing_result.agent_name
routing_method = routing_result.match_method
@ -413,11 +428,19 @@ async def _resolve_for_chat(
@router.post("/portal/chat", response_model=ChatResponse)
async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify_api_key)):
"""Send a chat message and get a response with CostAwareRouter routing."""
agent, routing_result, matched_skill, routing_method, confidence = await _resolve_for_chat(request, req)
"""Send a chat message and get a response with RequestPreprocessor routing."""
# If skill_name is explicitly requested but not found, return 404
if request.skill_name:
skill_registry = req.app.state.skill_registry
if not skill_registry.has_skill(request.skill_name):
raise HTTPException(status_code=404, detail=f"Skill '{request.skill_name}' not found")
agent, routing_result, matched_skill, routing_method, confidence = await _resolve_for_chat(
request, req
)
# Create or reuse conversation
conv = _conversation_store.get_or_create(request.conversation_id)
conv = await _conversation_store.get_or_create(request.conversation_id)
await _conversation_store.add_message(conv.id, "user", request.message)
llm_gateway = req.app.state.llm_gateway
@ -432,7 +455,7 @@ async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify
chat_messages.append({"role": "system", "content": routing_result.system_prompt})
chat_messages.append({"role": "user", "content": request.message})
# Inject conversation history
history_msgs = _build_history_messages(conv.id)
history_msgs = await _build_history_messages(conv.id)
for hm in history_msgs:
chat_messages.insert(-1, hm)
response = await llm_gateway.chat(
@ -467,7 +490,7 @@ async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify
messages = [{"role": "user", "content": request.message}]
# Inject conversation history
history_msgs = _build_history_messages(conv.id)
history_msgs = await _build_history_messages(conv.id)
for hm in reversed(history_msgs):
messages.insert(0, hm)
tools = agent.get_tools()
@ -490,7 +513,9 @@ async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify
except Exception as e:
response_text = f"执行出错: {e}"
else:
response_text = _ensure_non_empty("".join(collected_output) if collected_output else None)
response_text = _ensure_non_empty(
"".join(collected_output) if collected_output else None
)
await _conversation_store.add_message(conv.id, "assistant", response_text)
@ -508,13 +533,15 @@ async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify
@router.post("/portal/chat/stream")
async def chat_stream(request: ChatRequest, req: Request, _auth: None = Depends(_verify_api_key)):
"""Stream chat responses via SSE with CostAwareRouter routing."""
"""Stream chat responses via SSE with RequestPreprocessor routing."""
from sse_starlette.sse import EventSourceResponse
agent, routing_result, matched_skill, routing_method, confidence = await _resolve_for_chat(request, req)
agent, routing_result, matched_skill, routing_method, confidence = await _resolve_for_chat(
request, req
)
# Create or reuse conversation
conv = _conversation_store.get_or_create(request.conversation_id)
conv = await _conversation_store.get_or_create(request.conversation_id)
await _conversation_store.add_message(conv.id, "user", request.message)
llm_gateway = req.app.state.llm_gateway
@ -532,13 +559,16 @@ async def chat_stream(request: ChatRequest, req: Request, _auth: None = Depends(
),
}
if routing_result is not None and routing_result.execution_mode == ExecutionMode.DIRECT_CHAT:
if (
routing_result is not None
and routing_result.execution_mode == ExecutionMode.DIRECT_CHAT
):
# DIRECT_CHAT: direct LLM call, no ReAct loop
chat_messages = []
if routing_result.system_prompt:
chat_messages.append({"role": "system", "content": routing_result.system_prompt})
chat_messages.append({"role": "user", "content": request.message})
history_msgs = _build_history_messages(conv.id)
history_msgs = await _build_history_messages(conv.id)
for hm in history_msgs:
chat_messages.insert(-1, hm)
response = await llm_gateway.chat(
@ -552,7 +582,11 @@ async def chat_stream(request: ChatRequest, req: Request, _auth: None = Depends(
yield {
"event": "final_answer",
"data": json.dumps(
{"step": 0, "data": {"output": response_text}, "timestamp": datetime.now(timezone.utc).isoformat()}
{
"step": 0,
"data": {"output": response_text},
"timestamp": datetime.now(timezone.utc).isoformat(),
}
),
}
else:
@ -654,7 +688,7 @@ async def get_capabilities(req: Request, _auth: None = Depends(_verify_api_key))
@router.get("/portal/conversations")
async def list_conversations(limit: int = 20, _auth: None = Depends(_verify_api_key)):
"""List recent conversations."""
convs = _conversation_store.list_conversations(limit=limit)
convs = await _conversation_store.list_conversations(limit=limit)
return [
{
"id": c.id,
@ -679,55 +713,27 @@ def _derive_conversation_title(conv: Conversation) -> str:
async def get_conversation(
conversation_id: str, limit: int = 50, _auth: None = Depends(_verify_api_key)
):
"""Get conversation history, with fallback to SessionManager for persisted data."""
# Try in-memory first
if conversation_id in _conversation_store._conversations:
conv = _conversation_store._conversations[conversation_id]
history = _conversation_store.get_history(conversation_id, limit=limit)
return {
"id": conv.id,
"title": _derive_conversation_title(conv),
"messages": [
{
"id": f"{conv.id}-{i}",
"role": m.role,
"content": m.content,
"timestamp": m.timestamp.isoformat(),
"metadata": m.metadata,
}
for i, m in enumerate(history)
],
"created_at": conv.created_at.isoformat(),
"updated_at": conv.updated_at.isoformat(),
}
# Fallback: load from SessionManager (persistent store)
sm = _conversation_store._session_manager
if sm is not None:
try:
session = await sm.get_session(conversation_id)
if session is not None:
messages = await sm.get_messages(conversation_id, limit=limit)
return {
"id": session.session_id,
"title": _derive_title_from_messages(messages),
"messages": [
{
"id": f"{session.session_id}-{i}",
"role": m.role.value,
"content": m.content,
"timestamp": m.created_at.isoformat(),
"metadata": m.metadata,
}
for i, m in enumerate(messages)
],
"created_at": session.created_at.isoformat(),
"updated_at": session.updated_at.isoformat(),
}
except Exception as e:
logger.warning(f"Failed to load conversation from SessionManager: {e}")
raise HTTPException(status_code=404, detail=f"Conversation '{conversation_id}' not found")
"""Get conversation history from SQLite-backed store."""
history = await _conversation_store.get_history(conversation_id, limit=limit)
if not history:
raise HTTPException(status_code=404, detail=f"Conversation '{conversation_id}' not found")
conv = await _conversation_store.get_or_create(conversation_id)
return {
"id": conv.id,
"title": _derive_conversation_title(conv),
"messages": [
{
"id": f"{conv.id}-{i}",
"role": m.role,
"content": m.content,
"timestamp": m.timestamp.isoformat(),
"metadata": m.metadata,
}
for i, m in enumerate(history)
],
"created_at": conv.created_at.isoformat(),
"updated_at": conv.updated_at.isoformat(),
}
def _derive_title_from_messages(messages: list) -> str:
@ -807,7 +813,7 @@ async def portal_websocket(websocket: WebSocket):
# Create conversation on first message (not on connect)
if conv is None:
conv_id = msg.get("conversation_id")
conv = _conversation_store.get_or_create(conv_id)
conv = await _conversation_store.get_or_create(conv_id)
await websocket.send_json({"type": "connected", "conversation_id": conv.id})
# Add user message to conversation
@ -841,15 +847,15 @@ async def portal_websocket(websocket: WebSocket):
except Exception as e:
logger.warning(f"Failed to record experience: {e}")
# Unified routing via SimpleRouter (minimal: @skill prefix + greeting regex + REACT)
# Unified preprocessing via RequestPreprocessor (minimal: @skill prefix + greeting regex + REACT)
pool = websocket.app.state.agent_pool
skill_registry = websocket.app.state.skill_registry
llm_gateway = websocket.app.state.llm_gateway
simple_router: SimpleRouter = websocket.app.state.simple_router
request_preprocessor: RequestPreprocessor = websocket.app.state.request_preprocessor
all_skills = skill_registry.list_skills()
# Get default tools for SimpleRouter routing
# Get default tools for RequestPreprocessor
default_tools = []
default_system_prompt = None
default_agent = pool.get_agent("default")
@ -869,8 +875,8 @@ async def portal_websocket(websocket: WebSocket):
)
break
# Route via SimpleRouter (minimal routing: @skill prefix + greeting regex + REACT)
routing_result = await simple_router.route(
# Preprocess via RequestPreprocessor (minimal: @skill prefix + greeting regex + REACT)
routing_result = await request_preprocessor.preprocess(
content=message_text,
skill_registry=skill_registry,
default_tools=default_tools,
@ -901,7 +907,7 @@ async def portal_websocket(websocket: WebSocket):
)
chat_messages.append({"role": "user", "content": message_text})
# Inject conversation history for context continuity
history_msgs = _build_history_messages(conv.id)
history_msgs = await _build_history_messages(conv.id)
for hm in history_msgs:
chat_messages.insert(-1, hm)
response = await llm_gateway.chat(
@ -960,7 +966,7 @@ async def portal_websocket(websocket: WebSocket):
)
chat_messages.append({"role": "user", "content": message_text})
try:
history = _conversation_store.get_history(conv.id, limit=20)
history = await _conversation_store.get_history(conv.id, limit=20)
for hist_msg in history[:-1]:
if hist_msg.role in ("user", "assistant"):
chat_messages.insert(
@ -1009,7 +1015,7 @@ async def portal_websocket(websocket: WebSocket):
messages = [{"role": "user", "content": message_text}]
# Inject conversation history for context continuity
history_msgs = _build_history_messages(conv.id)
history_msgs = await _build_history_messages(conv.id)
for hm in reversed(history_msgs):
messages.insert(0, hm)
tools = agent.get_tools()

View File

@ -2,6 +2,7 @@
import json
import uuid
from dataclasses import asdict
from datetime import datetime, timezone
from fastapi import APIRouter, HTTPException, Request
@ -9,6 +10,7 @@ from pydantic import BaseModel
from typing import Any
from agentkit.core.protocol import TaskMessage, TaskStatus
from agentkit.core.spec_manager import SpecManager
router = APIRouter(tags=["tasks"])
@ -350,3 +352,71 @@ async def stream_task(request: SubmitTaskRequest, req: Request):
}
return EventSourceResponse(event_generator())
# ---------------------------------------------------------------------------
# Spec endpoints
# ---------------------------------------------------------------------------
def _get_spec_manager(req: Request) -> SpecManager:
    """Return the app-wide SpecManager, creating it on first use.

    The instance is cached on ``req.app.state.spec_manager`` so later
    requests handled by the same application reuse it.
    """
    state = req.app.state
    manager = getattr(state, "spec_manager", None)
    if manager is None:
        # Attribute is missing or explicitly None: build one with defaults.
        manager = SpecManager()
        state.spec_manager = manager
    return manager
class UpdateSpecRequest(BaseModel):
    """Request body for updating a spec.

    Every field is optional and defaults to ``None``; a ``None`` value means
    "leave this field unchanged" for the handler that consumes this model.
    """

    # New goal text for the spec.
    goal: str | None = None
    # Replacement list of steps, each given as a plain dict.
    steps: list[dict[str, Any]] | None = None
    # New status string. NOTE(review): no validation of allowed values is
    # visible here — confirm whether SpecManager enforces one.
    status: str | None = None
    # Replacement metadata mapping.
    metadata: dict[str, Any] | None = None
@router.get("/specs")
async def list_specs(status: str | None = None, req: Request = None):
    """Return all specs as plain dicts, optionally filtered by status.

    Args:
        status: Forwarded to ``SpecManager.list_specs`` as the status filter.
        req: Incoming request, used to reach the shared spec manager.
    """
    # NOTE(review): ``req`` presumably defaults to None only so it can follow
    # ``status`` in the signature; the handler cannot work without a real
    # request — confirm whether the parameters should be reordered instead.
    manager = _get_spec_manager(req)
    return list(map(asdict, manager.list_specs(status=status)))
@router.get("/specs/{spec_id}")
async def get_spec(spec_id: str, req: Request):
    """Return one spec as a plain dict.

    Raises:
        HTTPException: 404 when the manager has no spec with this id.
    """
    found = _get_spec_manager(req).get(spec_id)
    if found is not None:
        return asdict(found)
    raise HTTPException(status_code=404, detail=f"Spec '{spec_id}' not found")
@router.put("/specs/{spec_id}")
async def update_spec(spec_id: str, body: UpdateSpecRequest, req: Request):
    """Apply the non-null fields of the request body to an existing spec.

    Only fields the client actually supplied (i.e. not ``None``) are
    forwarded to ``SpecManager.update``; the rest are left untouched.

    Raises:
        HTTPException: 404 when the manager has no spec with this id.
    """
    manager = _get_spec_manager(req)
    candidate_fields = {
        "goal": body.goal,
        "steps": body.steps,
        "status": body.status,
        "metadata": body.metadata,
    }
    # Drop everything the caller left unset so it is not overwritten.
    changes: dict[str, Any] = {
        field_name: value
        for field_name, value in candidate_fields.items()
        if value is not None
    }
    updated = manager.update(spec_id, **changes)
    if updated is not None:
        return asdict(updated)
    raise HTTPException(status_code=404, detail=f"Spec '{spec_id}' not found")
@router.post("/specs/{spec_id}/confirm")
async def confirm_spec(spec_id: str, req: Request):
    """Mark a spec as confirmed and return it as a plain dict.

    Raises:
        HTTPException: 404 when the manager has no spec with this id.
    """
    confirmed = _get_spec_manager(req).confirm(spec_id)
    if confirmed is not None:
        return asdict(confirmed)
    raise HTTPException(status_code=404, detail=f"Spec '{spec_id}' not found")

View File

@ -0,0 +1,185 @@
"""内置工具 - 开箱即用的常用工具集合"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
from agentkit.tools.base import Tool
if TYPE_CHECKING:
from agentkit.tools.search import ToolSearchIndex
logger = logging.getLogger(__name__)
class RunTestsTool(Tool):
    """Run the project's tests to verify code changes.

    Executes verification commands (pytest and linting by default) and
    reports whether they passed. The actual work is delegated to
    ``VerificationLoop``.
    """

    def __init__(
        self,
        name: str = "run_tests",
        description: str = "Run project tests to verify code changes. Executes pytest and linting commands.",
        input_schema: dict[str, Any] | None = None,
        output_schema: dict[str, Any] | None = None,
        version: str = "1.0.0",
        tags: list[str] | None = None,
        working_dir: str | None = None,
        timeout: float = 60.0,
        max_retries: int = 0,
    ):
        """Configure the tool.

        Args:
            name: Tool name exposed to the agent.
            description: Human/LLM readable description of the tool.
            input_schema: JSON schema for the arguments; a schema with an
                optional ``commands`` array is used when omitted.
            output_schema: Optional JSON schema describing the result.
            version: Tool version string.
            tags: Tool tags; defaults to ``["testing", "verification"]``.
            working_dir: Forwarded to ``VerificationLoop`` as ``working_dir``.
            timeout: Forwarded to ``VerificationLoop`` as ``timeout``.
            max_retries: Number of additional verification attempts after a
                failed first run. ``0`` means a single pass.
        """
        super().__init__(
            name=name,
            description=description,
            input_schema=input_schema or self._default_input_schema(),
            output_schema=output_schema,
            version=version,
            tags=tags or ["testing", "verification"],
        )
        self._working_dir = working_dir
        self._timeout = timeout
        self._max_retries = max_retries

    @staticmethod
    def _default_input_schema() -> dict[str, Any]:
        """Return the default JSON schema: one optional ``commands`` array."""
        return {
            "type": "object",
            "properties": {
                "commands": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Shell commands to run for verification. Default: ['pytest -x -q', 'ruff check src/']",
                },
            },
        }

    async def execute(self, **kwargs) -> dict:
        """Run the verification commands and summarize the outcome.

        Args:
            commands: Optional list of verification commands. When omitted,
                ``VerificationLoop`` falls back to its own defaults.

        Returns:
            A dict with the keys ``passed``, ``attempts``, ``test_output``
            and ``errors`` copied from the verification result.
        """
        # Imported at call time rather than at module import time.
        from agentkit.core.verification_loop import VerificationLoop

        commands = kwargs.get("commands")
        if isinstance(commands, str):
            # A caller that passes a bare string means "one command"; without
            # this it would be iterated character by character.
            commands = [commands]

        loop = VerificationLoop(
            commands=commands,
            max_retries=self._max_retries,
            working_dir=self._working_dir,
            timeout=self._timeout,
        )
        if self._max_retries > 0:
            # Honour the configured retry budget; a plain verify() is always
            # a single pass and would silently ignore max_retries.
            result = await loop.verify_and_retry()
        else:
            result = await loop.verify()
        return {
            "passed": result.passed,
            "attempts": result.attempts,
            "test_output": result.test_output,
            "errors": result.errors,
        }
class ToolSearchTool(Tool):
    """Tool that lets an agent search for other tools by keyword.

    Intended for use with tiered tool descriptions: core tools are injected
    into the prompt in full, while extended tools are listed only by name
    plus a one-line description. The agent calls this tool to fetch the full
    parameter documentation of an extended tool on demand.

    Keyword matching is delegated to :class:`ToolSearchIndex` (BM25).

    Usage::

        from agentkit.tools.search import ToolSearchIndex
        from agentkit.tools.builtin import ToolSearchTool

        index = ToolSearchIndex(extended_tools)
        tool = ToolSearchTool(search_index=index)
        result = await tool.execute(query="read file")
    """

    def __init__(
        self,
        search_index: "ToolSearchIndex",
        name: str = "tool_search",
        description: str = (
            "Search for available tools by keyword. "
            "Returns full descriptions (name, description, parameters) of matching tools. "
            "Use this when you need details about a tool that was only listed by name."
        ),
        input_schema: dict[str, Any] | None = None,
        output_schema: dict[str, Any] | None = None,
        version: str = "1.0.0",
        tags: list[str] | None = None,
        top_k: int = 5,
    ):
        """Configure the search tool.

        Args:
            search_index: Index that is queried by :meth:`execute`.
            name: Tool name exposed to the agent.
            description: Human/LLM readable description of the tool.
            input_schema: JSON schema for the arguments; a schema with a
                required ``query`` string is used when omitted.
            output_schema: Optional JSON schema describing the result.
            version: Tool version string.
            tags: Tool tags; defaults to ``["tool", "search", "meta"]``.
            top_k: Maximum number of matches requested from the index.

        Raises:
            ValueError: If ``top_k`` is smaller than 1.
        """
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")
        schema = input_schema or {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": (
                        "Search keywords to find relevant tools "
                        "(e.g. 'read file', 'web search', 'run tests')."
                    ),
                },
            },
            "required": ["query"],
        }
        super().__init__(
            name=name,
            description=description,
            input_schema=schema,
            output_schema=output_schema,
            version=version,
            tags=tags or ["tool", "search", "meta"],
        )
        self._search_index = search_index
        self._top_k = top_k

    async def execute(self, **kwargs) -> dict:
        """Search the index and return full descriptions of matching tools.

        Args:
            query: Search keywords.

        Returns:
            On success, a dict with ``query``, ``count`` and ``results``
            (one full description per matching tool); when nothing matches,
            the same keys plus a ``message``. When the query is missing or
            blank, a dict with ``error`` and an empty ``results`` list.
        """
        raw_query = kwargs.get("query")
        # An explicit null must count as "no query"; str(None) would
        # otherwise search for the literal word "None".
        query = "" if raw_query is None else str(raw_query).strip()
        if not query:
            return {
                "error": "query parameter is required",
                "results": [],
            }

        results = self._search_index.search(query, top_k=self._top_k)
        if not results:
            return {
                "query": query,
                "count": 0,
                "results": [],
                "message": f"No tools matched '{query}'.",
            }
        return {
            "query": query,
            "count": len(results),
            "results": [self._format_tool_full(tool) for tool in results],
        }

    @staticmethod
    def _format_tool_full(tool: Tool) -> dict[str, Any]:
        """Format a tool's full description for the LLM.

        A tool without an input schema is reported with an empty object
        schema so the ``parameters`` key is always present.
        """
        return {
            "name": tool.name,
            "description": tool.description,
            "parameters": tool.input_schema or {"type": "object", "properties": {}},
        }

View File

@ -0,0 +1,182 @@
"""Unit tests for SqliteConversationStore."""
from __future__ import annotations
from pathlib import Path
import pytest
from agentkit.chat.sqlite_conversation_store import SqliteConversationStore
@pytest.fixture
def db_path(tmp_path: Path) -> str:
    """Give each test its own database file path inside pytest's tmp dir."""
    database_file = tmp_path / "test_conversations.db"
    return str(database_file)
@pytest.fixture
async def store(db_path: str) -> SqliteConversationStore:
    """Yield a store bound to the temporary database and close it afterwards."""
    conversation_store = SqliteConversationStore(db_path=db_path)
    yield conversation_store
    # Teardown: close the store's database via its private helper.
    await conversation_store._close_db()
# ---------------------------------------------------------------------------
# Basic CRUD
# ---------------------------------------------------------------------------
class TestBasicCRUD:
    """Create / read behaviour of conversations and their messages."""

    async def test_create_conversation(self, store: SqliteConversationStore) -> None:
        """A conversation created without an id gets an id, timestamps and no messages."""
        created = await store.get_or_create()
        assert created.id
        assert created.created_at is not None
        assert created.updated_at is not None
        assert created.messages == []

    async def test_create_conversation_with_id(self, store: SqliteConversationStore) -> None:
        """An explicitly supplied id is used as-is."""
        created = await store.get_or_create("my-conv-id")
        assert created.id == "my-conv-id"

    async def test_get_or_create_returns_existing(self, store: SqliteConversationStore) -> None:
        """Asking for an existing id returns that conversation, messages included."""
        await store.get_or_create("reuse-id")
        await store.add_message("reuse-id", "user", "hello")

        fetched_again = await store.get_or_create("reuse-id")
        assert fetched_again.id == "reuse-id"
        # The message added above must be visible on the returned conversation.
        assert len(fetched_again.messages) == 1

    async def test_add_message(self, store: SqliteConversationStore) -> None:
        """add_message returns the stored message with empty default metadata."""
        await store.get_or_create("msg-test")
        stored = await store.add_message("msg-test", "user", "Hello world")
        assert stored.role == "user"
        assert stored.content == "Hello world"
        assert stored.metadata == {}

    async def test_add_message_with_metadata(self, store: SqliteConversationStore) -> None:
        """Metadata passed to add_message is kept on the returned message."""
        await store.get_or_create("meta-test")
        stored = await store.add_message("meta-test", "assistant", "Hi", {"key": "value"})
        assert stored.metadata == {"key": "value"}

    async def test_add_message_nonexistent_conversation_raises(
        self, store: SqliteConversationStore
    ) -> None:
        """Adding to an unknown conversation raises KeyError mentioning 'not found'."""
        with pytest.raises(KeyError, match="not found"):
            await store.add_message("nonexistent", "user", "hello")

    async def test_get_history(self, store: SqliteConversationStore) -> None:
        """History returns all messages in insertion order."""
        await store.get_or_create("hist-test")
        await store.add_message("hist-test", "user", "msg1")
        await store.add_message("hist-test", "assistant", "msg2")
        await store.add_message("hist-test", "user", "msg3")

        history = await store.get_history("hist-test")
        assert [entry.content for entry in history] == ["msg1", "msg2", "msg3"]

    async def test_get_history_with_limit(self, store: SqliteConversationStore) -> None:
        """With a limit, only the most recent messages are returned, oldest first."""
        await store.get_or_create("limit-test")
        for index in range(10):
            await store.add_message("limit-test", "user", f"msg{index}")

        history = await store.get_history("limit-test", limit=3)
        # Exactly the last three messages, still in chronological order.
        assert [entry.content for entry in history] == ["msg7", "msg8", "msg9"]

    async def test_get_history_empty(self, store: SqliteConversationStore) -> None:
        """History of an unknown conversation is an empty list, not an error."""
        assert await store.get_history("nonexistent") == []
# ---------------------------------------------------------------------------
# Persistence
# ---------------------------------------------------------------------------
class TestPersistence:
    """Data written by one store instance must be visible to a later one.

    Each store is closed in a ``finally`` block so a failing assertion does
    not leave a database handle open for the rest of the test session.
    """

    async def test_data_survives_store_recreation(self, db_path: str) -> None:
        """Create store, add data, create new store with same DB path, verify data survives."""
        writer = SqliteConversationStore(db_path=db_path)
        try:
            await writer.get_or_create("persist-conv")
            await writer.add_message("persist-conv", "user", "persistent message")
        finally:
            await writer._close_db()

        reader = SqliteConversationStore(db_path=db_path)
        try:
            history = await reader.get_history("persist-conv")
            assert len(history) == 1
            assert history[0].content == "persistent message"
            assert history[0].role == "user"
        finally:
            await reader._close_db()

    async def test_conversations_survive_recreation(self, db_path: str) -> None:
        """Conversations created by one store are listed by a fresh store on the same file."""
        writer = SqliteConversationStore(db_path=db_path)
        try:
            await writer.get_or_create("conv-a")
            await writer.get_or_create("conv-b")
        finally:
            await writer._close_db()

        reader = SqliteConversationStore(db_path=db_path)
        try:
            convs = await reader.list_conversations()
            ids = {c.id for c in convs}
            assert "conv-a" in ids
            assert "conv-b" in ids
        finally:
            await reader._close_db()
# ---------------------------------------------------------------------------
# list_conversations
# ---------------------------------------------------------------------------
class TestListConversations:
    """Ordering and limiting of list_conversations()."""

    async def test_list_conversations_ordering(self, store: SqliteConversationStore) -> None:
        """Most recently updated conversation should appear first."""
        for conversation_id in ("conv-first", "conv-second"):
            await store.get_or_create(conversation_id)
        # Touch the older conversation so it becomes the most recently updated.
        await store.add_message("conv-first", "user", "update")

        listed = await store.list_conversations()
        assert len(listed) == 2
        newest = listed[0]
        assert newest.id == "conv-first"

    async def test_list_conversations_limit(self, store: SqliteConversationStore) -> None:
        """The limit argument caps the number of returned conversations."""
        for index in range(5):
            await store.get_or_create(f"conv-{index}")

        listed = await store.list_conversations(limit=3)
        assert len(listed) == 3
# ---------------------------------------------------------------------------
# LRU cache eviction
# ---------------------------------------------------------------------------
class TestLRUCache:
    """The in-memory cache is bounded while SQLite keeps everything."""

    async def test_cache_eviction(self, db_path: str) -> None:
        """Cache should evict oldest entries when over limit (data still in SQLite)."""
        store = SqliteConversationStore(db_path=db_path, max_conversations=3)
        try:
            for i in range(5):
                await store.get_or_create(f"evict-{i}")

            # Cache should have at most 3 entries
            assert len(store._cache) <= 3

            # But all 5 conversations should be in SQLite
            convs = await store.list_conversations(limit=10)
            assert len(convs) == 5
        finally:
            # Close even when an assertion above fails, so the database
            # handle does not outlive the test.
            await store._close_db()
# ---------------------------------------------------------------------------
# restore_from_store (no-op)
# ---------------------------------------------------------------------------
class TestRestoreFromStore:
    """restore_from_store() is kept for API compatibility with the old store."""

    async def test_restore_is_noop(self, store: SqliteConversationStore) -> None:
        """Calling restore_from_store on the SQLite store must simply succeed."""
        # The test passes as long as the call completes without raising.
        await store.restore_from_store()

View File

@ -0,0 +1,204 @@
"""Tests for SpecManager — Spec 文档管理器"""
from __future__ import annotations
import time
from pathlib import Path
import pytest
from agentkit.core.spec_manager import Spec, SpecManager, SpecStep
@pytest.fixture
def specs_dir(tmp_path: Path) -> str:
    """Path (as a string) of a per-test ``specs`` directory under tmp_path."""
    target = tmp_path / "specs"
    return str(target)
@pytest.fixture
def mgr(specs_dir: str) -> SpecManager:
    """SpecManager that stores its files in the per-test specs directory."""
    manager = SpecManager(specs_dir=specs_dir)
    return manager
def make_spec(spec_id: str = "test-spec", goal: str = "test goal") -> Spec:
    """Build a two-step Spec for tests; the second step depends on the first."""
    first_step = SpecStep(step_id="s1", name="Step 1", description="First step")
    second_step = SpecStep(
        step_id="s2",
        name="Step 2",
        description="Second step",
        dependencies=["s1"],
    )
    return Spec(spec_id=spec_id, goal=goal, steps=[first_step, second_step])
class TestSpecManagerCreateAndGet:
    """Creating a spec and reading it back."""

    def test_create_and_get(self, mgr: SpecManager):
        """A created spec is written to disk and round-trips through get()."""
        original = make_spec()
        written_to = mgr.create(original)
        assert written_to.exists()

        fetched = mgr.get(original.spec_id)
        assert fetched is not None
        assert fetched.spec_id == original.spec_id
        assert fetched.goal == original.goal
        assert len(fetched.steps) == 2
        first, second = fetched.steps
        assert first.step_id == "s1"
        assert second.dependencies == ["s1"]

    def test_create_writes_yaml_file(self, mgr: SpecManager, specs_dir: str):
        """create() stores the spec as ``<spec_id>.yaml`` in the specs directory."""
        mgr.create(make_spec(spec_id="yaml-test"))
        expected_file = Path(specs_dir) / "yaml-test.yaml"
        assert expected_file.exists()

    def test_get_returns_from_cache(self, mgr: SpecManager):
        """get() right after create() returns the spec that was just created."""
        mgr.create(make_spec(spec_id="cached"))
        # The spec was just created, so this lookup is expected to be a cache hit.
        fetched = mgr.get("cached")
        assert fetched is not None
        assert fetched.spec_id == "cached"
class TestSpecManagerUpdate:
    """Updating individual fields of a stored spec."""

    def test_update_goal(self, mgr: SpecManager):
        """update(goal=...) replaces the goal."""
        original = make_spec()
        mgr.create(original)

        result = mgr.update(original.spec_id, goal="new goal")
        assert result is not None
        assert result.goal == "new goal"

    def test_update_steps(self, mgr: SpecManager):
        """update(steps=[...]) replaces the whole step list with the given dicts."""
        original = make_spec()
        mgr.create(original)

        replacement_step = {
            "step_id": "s1",
            "name": "Step 1 Updated",
            "description": "Updated first step",
        }
        result = mgr.update(original.spec_id, steps=[replacement_step])
        assert result is not None
        assert len(result.steps) == 1
        assert result.steps[0].name == "Step 1 Updated"

    def test_update_metadata(self, mgr: SpecManager):
        """update(metadata=...) stores the given mapping."""
        original = make_spec()
        mgr.create(original)

        result = mgr.update(original.spec_id, metadata={"key": "value"})
        assert result is not None
        assert result.metadata == {"key": "value"}

    def test_update_nonexistent_returns_none(self, mgr: SpecManager):
        """Updating an unknown spec id yields None rather than raising."""
        assert mgr.update("nonexistent", goal="x") is None
class TestSpecManagerConfirm:
    """confirm() flips status and records a confirmation timestamp."""

    def test_confirm_sets_status_and_timestamp(self, mgr: SpecManager):
        """A draft spec becomes 'confirmed' with confirmed_at populated."""
        draft = make_spec()
        mgr.create(draft)
        # Precondition: a freshly built spec is an unconfirmed draft.
        assert draft.status == "draft"
        assert draft.confirmed_at is None

        result = mgr.confirm(draft.spec_id)
        assert result is not None
        assert result.status == "confirmed"
        assert result.confirmed_at is not None

    def test_confirm_marks_pending_steps_as_confirmed(self, mgr: SpecManager):
        """Every step of a confirmed spec carries the 'confirmed' status."""
        draft = make_spec()
        mgr.create(draft)

        result = mgr.confirm(draft.spec_id)
        assert result is not None
        for confirmed_step in result.steps:
            assert confirmed_step.status == "confirmed"

    def test_confirm_nonexistent_returns_none(self, mgr: SpecManager):
        """Confirming an unknown spec id yields None rather than raising."""
        assert mgr.confirm("nonexistent") is None
class TestSpecManagerList:
    """list_specs(): newest-first ordering and status filtering."""

    def test_list_specs_sorted_by_created_at_desc(self, mgr: SpecManager):
        """Specs come back ordered by creation time, most recent first."""
        mgr.create(make_spec(spec_id="spec-a", goal="A"))
        # Brief pause so the two specs cannot share a creation timestamp.
        time.sleep(0.01)
        mgr.create(make_spec(spec_id="spec-b", goal="B"))

        listed = mgr.list_specs()
        assert len(listed) == 2
        newest, oldest = listed
        assert newest.spec_id == "spec-b"
        assert oldest.spec_id == "spec-a"

    def test_list_specs_filter_by_status(self, mgr: SpecManager):
        """The status argument returns only specs in that status."""
        still_draft = make_spec(spec_id="draft-spec")
        mgr.create(still_draft)
        to_confirm = make_spec(spec_id="confirmed-spec")
        mgr.create(to_confirm)
        mgr.confirm(to_confirm.spec_id)

        drafts = mgr.list_specs(status="draft")
        assert len(drafts) == 1
        assert drafts[0].spec_id == "draft-spec"

        confirmed = mgr.list_specs(status="confirmed")
        assert len(confirmed) == 1
        assert confirmed[0].spec_id == "confirmed-spec"

    def test_list_specs_empty(self, mgr: SpecManager):
        """With nothing stored, list_specs() returns an empty list."""
        assert mgr.list_specs() == []
class TestSpecManagerDelete:
    """delete() removes both the stored spec and its YAML file."""

    def test_delete_removes_spec(self, mgr: SpecManager):
        """After a successful delete the spec can no longer be fetched."""
        victim = make_spec()
        mgr.create(victim)
        assert mgr.get(victim.spec_id) is not None

        was_deleted = mgr.delete(victim.spec_id)
        assert was_deleted is True
        assert mgr.get(victim.spec_id) is None

    def test_delete_removes_yaml_file(self, mgr: SpecManager, specs_dir: str):
        """The spec's YAML file disappears from the specs directory."""
        mgr.create(make_spec(spec_id="delete-me"))
        spec_file = Path(specs_dir) / "delete-me.yaml"
        assert spec_file.exists()

        mgr.delete("delete-me")
        assert not spec_file.exists()

    def test_delete_nonexistent_returns_false(self, mgr: SpecManager):
        """Deleting an unknown spec id reports False rather than raising."""
        assert mgr.delete("nonexistent") is False
class TestSpecManagerGetNonExistent:
    """Looking up a spec id that was never created."""

    def test_get_nonexistent_returns_none(self, mgr: SpecManager):
        """get() yields None for an unknown id rather than raising."""
        assert mgr.get("does-not-exist") is None

View File

@ -0,0 +1,136 @@
"""VerificationLoop 单元测试"""
from __future__ import annotations
import asyncio
import pytest
from agentkit.core.verification_loop import VerificationLoop, VerificationResult
class TestVerify:
    """Tests for VerificationLoop.verify()."""

    @pytest.mark.asyncio
    async def test_verify_success(self) -> None:
        """A succeeding command yields passed=True after one attempt."""
        verifier = VerificationLoop(commands=["echo ok"], timeout=10.0)
        outcome = await verifier.verify()
        assert isinstance(outcome, VerificationResult)
        assert outcome.passed is True
        assert outcome.attempts == 1
        assert "ok" in outcome.test_output
        assert outcome.errors == []

    @pytest.mark.asyncio
    async def test_verify_failure(self) -> None:
        """A failing command yields passed=False with at least one error."""
        verifier = VerificationLoop(commands=["false"], timeout=10.0)
        outcome = await verifier.verify()
        assert outcome.passed is False
        assert outcome.attempts == 1
        assert len(outcome.errors) > 0

    @pytest.mark.asyncio
    async def test_verify_timeout(self) -> None:
        """A command exceeding the timeout yields passed=False and a 'timed out' error."""
        verifier = VerificationLoop(commands=["sleep 10"], timeout=0.5)
        outcome = await verifier.verify()
        assert outcome.passed is False
        assert any("timed out" in message for message in outcome.errors)

    @pytest.mark.asyncio
    async def test_verify_command_not_found(self) -> None:
        """A command that does not exist yields passed=False with an error."""
        verifier = VerificationLoop(commands=["nonexistent_command_xyz"], timeout=5.0)
        outcome = await verifier.verify()
        assert outcome.passed is False
        assert len(outcome.errors) > 0

    @pytest.mark.asyncio
    async def test_verify_multiple_commands_partial_failure(self) -> None:
        """If any one command fails the overall result is passed=False."""
        verifier = VerificationLoop(commands=["echo ok", "false"], timeout=10.0)
        outcome = await verifier.verify()
        assert outcome.passed is False
        # Only the failing command contributes an error entry.
        assert len(outcome.errors) == 1

    @pytest.mark.asyncio
    async def test_verify_default_commands(self) -> None:
        """Without explicit commands the loop uses pytest and ruff check."""
        expected_defaults = ["pytest -x -q", "ruff check src/"]
        verifier = VerificationLoop()
        assert verifier._commands == expected_defaults
class TestVerifyAndRetry:
    """Tests for VerificationLoop.verify_and_retry()."""

    @pytest.mark.asyncio
    async def test_retry_no_fix_callback(self) -> None:
        """Without a fix_callback the loop still retries the configured number of times."""
        loop = VerificationLoop(commands=["false"], max_retries=2, timeout=5.0)
        result = await loop.verify_and_retry()
        assert result.passed is False
        assert result.attempts == 3  # 1 initial + 2 retries

    @pytest.mark.asyncio
    async def test_max_retries_respected(self) -> None:
        """max_retries=0 means a single attempt and no retry."""
        loop = VerificationLoop(commands=["false"], max_retries=0, timeout=5.0)
        result = await loop.verify_and_retry()
        assert result.passed is False
        assert result.attempts == 1

    @pytest.mark.asyncio
    async def test_retry_with_fix_callback(self) -> None:
        """fix_callback is invoked once and receives errors and test_output."""
        call_count = 0
        received_args: list[tuple[list[str], str]] = []

        async def fix_cb(errors: list[str], test_output: str) -> None:
            nonlocal call_count
            call_count += 1
            received_args.append((errors, test_output))

        loop = VerificationLoop(commands=["false"], max_retries=1, timeout=5.0)
        result = await loop.verify_and_retry(fix_callback=fix_cb)
        assert result.passed is False
        assert call_count == 1
        assert len(received_args) == 1
        assert len(received_args[0][0]) > 0  # errors
        assert isinstance(received_args[0][1], str)  # test_output

    @pytest.mark.asyncio
    async def test_retry_succeeds_after_fix(self, tmp_path) -> None:
        """Verification passes on the retry once fix_callback repairs the problem.

        The command checks for a marker file that does not exist at first, so
        the initial attempt fails; the callback creates the file, so the
        retry succeeds.
        """
        import shlex

        marker = tmp_path / "fixed.marker"

        async def fix_cb(errors: list[str], test_output: str) -> None:
            # Simulate the fix by creating the file the command looks for.
            marker.touch()

        loop = VerificationLoop(
            commands=[f"test -f {shlex.quote(str(marker))}"],
            max_retries=1,
            timeout=5.0,
        )
        result = await loop.verify_and_retry(fix_callback=fix_cb)
        assert marker.exists()
        assert result.passed is True
        assert result.attempts == 2  # 1 failing attempt + 1 passing retry

    @pytest.mark.asyncio
    async def test_fix_callback_exception_handled(self) -> None:
        """An exception raised by fix_callback does not stop the retry."""

        async def bad_fix_cb(errors: list[str], test_output: str) -> None:
            raise RuntimeError("fix failed!")

        loop = VerificationLoop(commands=["false"], max_retries=1, timeout=5.0)
        result = await loop.verify_and_retry(fix_callback=bad_fix_cb)
        assert result.passed is False
        assert result.attempts == 2

    @pytest.mark.asyncio
    async def test_verify_and_retry_success_first_try(self) -> None:
        """When the first verification passes, no retry happens."""
        loop = VerificationLoop(commands=["echo ok"], max_retries=3, timeout=10.0)
        result = await loop.verify_and_retry()
        assert result.passed is True
        assert result.attempts == 1

View File

@ -12,10 +12,8 @@ from fastapi.testclient import TestClient
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMResponse, TokenUsage
from agentkit.server.app import create_app
from agentkit.server.routes.portal import (
CAPABILITY_CATEGORIES,
ConversationStore,
)
from agentkit.chat.sqlite_conversation_store import SqliteConversationStore
from agentkit.server.routes.portal import CAPABILITY_CATEGORIES
from agentkit.skills.base import Skill, SkillConfig
from agentkit.skills.registry import SkillRegistry
from agentkit.tools.registry import ToolRegistry
@ -84,81 +82,90 @@ def _register_skill(registry: SkillRegistry, name: str = "chat_skill", **kwargs)
class TestConversationStore:
def test_get_or_create_new(self):
store = ConversationStore()
conv = store.get_or_create()
"""Tests for SqliteConversationStore (async, in-memory DB)."""
@pytest.fixture
def store(self, tmp_path):
return SqliteConversationStore(db_path=str(tmp_path / "test.db"))
@pytest.mark.asyncio
async def test_get_or_create_new(self, store):
conv = await store.get_or_create()
assert conv.id is not None
assert conv.messages == []
def test_get_or_create_with_id(self):
store = ConversationStore()
conv = store.get_or_create("test-id-123")
@pytest.mark.asyncio
async def test_get_or_create_with_id(self, store):
conv = await store.get_or_create("test-id-123")
assert conv.id == "test-id-123"
def test_get_or_create_reuse(self):
store = ConversationStore()
store.get_or_create("reuse-id")
store.add_message("reuse-id", "user", "hello")
conv2 = store.get_or_create("reuse-id")
@pytest.mark.asyncio
async def test_get_or_create_reuse(self, store):
await store.get_or_create("reuse-id")
await store.add_message("reuse-id", "user", "hello")
conv2 = await store.get_or_create("reuse-id")
assert conv2.id == "reuse-id"
assert len(conv2.messages) == 1
def test_add_message(self):
store = ConversationStore()
conv = store.get_or_create("msg-id")
msg = store.add_message("msg-id", "user", "hello")
@pytest.mark.asyncio
async def test_add_message(self, store):
conv = await store.get_or_create("msg-id")
msg = await store.add_message("msg-id", "user", "hello")
assert msg.role == "user"
assert msg.content == "hello"
assert len(conv.messages) == 1
def test_add_message_not_found(self):
store = ConversationStore()
@pytest.mark.asyncio
async def test_add_message_not_found(self, store):
with pytest.raises(KeyError):
store.add_message("nonexistent", "user", "hello")
await store.add_message("nonexistent", "user", "hello")
def test_get_history(self):
store = ConversationStore()
store.get_or_create("hist-id")
store.add_message("hist-id", "user", "msg1")
store.add_message("hist-id", "assistant", "msg2")
history = store.get_history("hist-id")
@pytest.mark.asyncio
async def test_get_history(self, store):
await store.get_or_create("hist-id")
await store.add_message("hist-id", "user", "msg1")
await store.add_message("hist-id", "assistant", "msg2")
history = await store.get_history("hist-id")
assert len(history) == 2
assert history[0].role == "user"
assert history[1].role == "assistant"
def test_get_history_limit(self):
store = ConversationStore()
store.get_or_create("limit-id")
@pytest.mark.asyncio
async def test_get_history_limit(self, store):
await store.get_or_create("limit-id")
for i in range(10):
store.add_message("limit-id", "user", f"msg{i}")
history = store.get_history("limit-id", limit=3)
await store.add_message("limit-id", "user", f"msg{i}")
history = await store.get_history("limit-id", limit=3)
assert len(history) == 3
assert history[0].content == "msg7"
def test_get_history_nonexistent(self):
store = ConversationStore()
history = store.get_history("no-such-id")
@pytest.mark.asyncio
async def test_get_history_nonexistent(self, store):
history = await store.get_history("no-such-id")
assert history == []
def test_list_conversations(self):
store = ConversationStore()
store.get_or_create("conv-a")
store.get_or_create("conv-b")
convs = store.list_conversations()
@pytest.mark.asyncio
async def test_list_conversations(self, store):
await store.get_or_create("conv-a")
await store.get_or_create("conv-b")
convs = await store.list_conversations()
assert len(convs) == 2
def test_list_conversations_limit(self):
store = ConversationStore()
@pytest.mark.asyncio
async def test_list_conversations_limit(self, store):
for i in range(5):
store.get_or_create(f"conv-{i}")
convs = store.list_conversations(limit=2)
await store.get_or_create(f"conv-{i}")
convs = await store.list_conversations(limit=2)
assert len(convs) == 2
def test_max_conversations_eviction(self):
store = ConversationStore(max_conversations=3)
@pytest.mark.asyncio
async def test_max_conversations_eviction(self, tmp_path):
store = SqliteConversationStore(
db_path=str(tmp_path / "evict.db"), max_conversations=3
)
for i in range(5):
store.get_or_create(f"evict-{i}")
assert len(store._conversations) <= 3
await store.get_or_create(f"evict-{i}")
assert len(store._cache) <= 3
# ---------------------------------------------------------------------------
@ -178,7 +185,7 @@ class TestPortalChat:
data = response.json()
assert data["conversation_id"] is not None
assert data["matched_skill"] == "chat_skill"
assert data["routing_method"] == "direct"
assert data["routing_method"] == "skill_prefix"
assert data["confidence"] == 1.0
assert data["status"] == "completed"
@ -196,12 +203,13 @@ class TestPortalChat:
assert data["conversation_id"] is not None
def test_chat_no_skills_available(self, client):
"""Greeting fast-path works even without skills (DIRECT_CHAT mode)."""
response = client.post(
"/api/v1/portal/chat",
json={"message": "hello"},
)
assert response.status_code == 400
assert "No skills available" in response.json()["detail"]
# Greeting regex fast-path: no skill needed, returns 200
assert response.status_code == 200
def test_chat_skill_not_found(self, client):
response = client.post(

View File

@ -6,6 +6,7 @@ Covers:
3. ConfigDrivenAgent compressor passthrough
"""
import os
import tempfile
from unittest.mock import AsyncMock, MagicMock, patch
@ -114,8 +115,11 @@ class TestCreateAppCompression:
with patch("agentkit.core.compressor.create_compressor") as mock_create:
mock_create.return_value = None
# No server_config at all
app = create_app()
# No server_config at all — also prevent auto-discovery of agentkit.yaml
with patch.dict(os.environ, {"AGENTKIT_CONFIG_PATH": ""}, clear=False):
# Prevent CWD agentkit.yaml from being auto-loaded
with patch("os.path.exists", return_value=False):
app = create_app()
# create_compressor should not be called (no server_config)
mock_create.assert_not_called()
@ -184,6 +188,7 @@ class TestConfigDrivenAgentCompression:
agent._skill_config = skill_config
agent._prompt_template = None
agent._tools = []
agent._tool_registry = None
agent._memory_retriever = None
agent._compressor = mock_compressor
agent._evolution_enabled = False
@ -230,6 +235,7 @@ class TestConfigDrivenAgentCompression:
agent._skill_config = skill_config
agent._prompt_template = None
agent._tools = []
agent._tool_registry = None
agent._memory_retriever = None
agent._compressor = None
agent._evolution_enabled = False

View File

@ -13,7 +13,6 @@ from agentkit.telemetry.tracer import (
get_tracer,
init_telemetry,
)
from agentkit.chat.skill_routing import CostAwareRouter, SkillRoutingResult
from agentkit.quality.alignment import AlignmentGuard, AlignmentConfig
@ -155,46 +154,6 @@ class TestGetTracer:
assert tracer1 is tracer2
# ── CostAwareRouter span 测试 ──────────────────────────────
class TestCostAwareRouterSpan:
"""CostAwareRouter 创建 span 并设置属性"""
@pytest.mark.asyncio
async def test_router_creates_span_with_attributes(self):
"""路由时创建 span 并设置 route.layer 和 route.target 属性"""
init_telemetry(TelemetryConfig(enabled=False))
tracer = get_tracer()
# 用 mock 替换 start_span 以验证调用
mock_span = MagicMock()
mock_span.__enter__ = MagicMock(return_value=mock_span)
mock_span.__exit__ = MagicMock(return_value=False)
mock_span.set_attribute = MagicMock()
with patch.object(tracer, "start_span", return_value=mock_span):
router = CostAwareRouter()
result = await router.route(
content="你好",
skill_registry=MagicMock(),
intent_router=MagicMock(),
default_tools=[],
default_system_prompt="You are helpful.",
)
tracer.start_span.assert_called_once_with("router.route")
# 验证 span 设置了 input.length 属性
mock_span.set_attribute.assert_any_call("input.length", len("你好"))
# 验证 span 设置了 route.layer 和 route.target
call_args_list = [
(call.args[0], call.args[1])
for call in mock_span.set_attribute.call_args_list
]
assert ("route.layer", "greeting") in call_args_list
assert ("route.target", "default") in call_args_list
# ── AlignmentGuard span 测试 ──────────────────────────────