feat: SQLite persistence, verification loop, spec-driven execution
Phase 2 of architecture optimization (U5/U6/U9):
- U5: SqliteConversationStore with WAL mode + LRU cache (1000 convs)
Replaces in-memory ConversationStore in portal.py
Data survives server restarts (ref: Codex Thread persistence)
- U6: VerificationLoop with verify/verify_and_retry
Default commands: pytest + ruff check
ReActEngine integration via verification_enabled flag
New run_tests tool for LLM to invoke verification
- U9: SpecManager for plan-as-contract (ref: Qoder Quest Mode)
Plans persisted to .agentkit/specs/{spec_id}.yaml
API: GET/PUT /api/v1/specs, POST /api/v1/specs/{id}/confirm
PlanExecEngine emits spec_created event after plan generation
Also fixes: portal skill_name routing, app.py SessionManager guard,
test_telemetry CostAwareRouter removal, test_compression_config fixture
This commit is contained in:
parent
5374bc8501
commit
200174c5c7
|
|
@ -0,0 +1,294 @@
|
|||
"""SQLite-backed conversation store with async support and in-memory LRU cache."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiosqlite
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data classes (mirrors portal.py ChatMessage / Conversation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
role: str # "user" or "assistant"
|
||||
content: str
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
metadata: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Conversation:
|
||||
id: str
|
||||
messages: list[ChatMessage] = field(default_factory=list)
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SCHEMA_SQL = """
|
||||
CREATE TABLE IF NOT EXISTS conversations (
|
||||
id TEXT PRIMARY KEY,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
conversation_id TEXT NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
timestamp TEXT NOT NULL,
|
||||
metadata TEXT DEFAULT '{}',
|
||||
FOREIGN KEY (conversation_id) REFERENCES conversations(id)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_conv_id ON messages(conversation_id);
|
||||
"""
|
||||
|
||||
|
||||
class SqliteConversationStore:
|
||||
"""SQLite-backed conversation store with an in-memory LRU cache.
|
||||
|
||||
Drop-in replacement for the in-memory ConversationStore in portal.py.
|
||||
Data is persisted to SQLite so it survives server restarts.
|
||||
An in-memory LRU cache of recent conversations provides fast access.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_path: str | Path | None = None,
|
||||
max_conversations: int = 1000,
|
||||
) -> None:
|
||||
if db_path is None:
|
||||
db_path = os.path.expanduser("~/.agentkit/conversations.db")
|
||||
self._db_path = str(db_path)
|
||||
self._max = max_conversations
|
||||
self._cache: OrderedDict[str, Conversation] = OrderedDict()
|
||||
self._db: aiosqlite.Connection | None = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _ensure_db(self) -> aiosqlite.Connection:
|
||||
"""Lazily open the database and ensure schema exists."""
|
||||
if self._db is None:
|
||||
db_dir = os.path.dirname(self._db_path)
|
||||
if db_dir:
|
||||
os.makedirs(db_dir, exist_ok=True)
|
||||
self._db = await aiosqlite.connect(self._db_path)
|
||||
self._db.row_factory = aiosqlite.Row
|
||||
await self._db.execute("PRAGMA journal_mode=WAL")
|
||||
await self._db.executescript(_SCHEMA_SQL)
|
||||
await self._db.commit()
|
||||
return self._db
|
||||
|
||||
async def _close_db(self) -> None:
|
||||
"""Close the database connection."""
|
||||
if self._db is not None:
|
||||
await self._db.close()
|
||||
self._db = None
|
||||
|
||||
@staticmethod
|
||||
def _dt_to_str(dt: datetime) -> str:
|
||||
return dt.isoformat()
|
||||
|
||||
@staticmethod
|
||||
def _str_to_dt(s: str) -> datetime:
|
||||
return datetime.fromisoformat(s)
|
||||
|
||||
def _touch_cache(self, conv_id: str) -> None:
|
||||
"""Move conversation to the end of the LRU cache (most recently used)."""
|
||||
if conv_id in self._cache:
|
||||
self._cache.move_to_end(conv_id)
|
||||
|
||||
def _evict_if_needed(self) -> None:
|
||||
"""Evict oldest entry from cache if over limit (does NOT delete from SQLite)."""
|
||||
while len(self._cache) > self._max:
|
||||
self._cache.popitem(last=False)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API (matches ConversationStore interface)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_or_create(self, conversation_id: str | None = None) -> Conversation:
|
||||
"""Get existing conversation or create a new one.
|
||||
|
||||
If the conversation is in cache, return it directly.
|
||||
Otherwise, try to load from SQLite. If not found, create new.
|
||||
Messages are loaded lazily (only when get_history is called).
|
||||
"""
|
||||
db = await self._ensure_db()
|
||||
|
||||
if conversation_id and conversation_id in self._cache:
|
||||
conv = self._cache[conversation_id]
|
||||
conv.updated_at = datetime.now(timezone.utc)
|
||||
self._touch_cache(conv.id)
|
||||
return conv
|
||||
|
||||
# Try loading from SQLite
|
||||
if conversation_id:
|
||||
cursor = await db.execute(
|
||||
"SELECT id, created_at, updated_at FROM conversations WHERE id = ?",
|
||||
(conversation_id,),
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
conv = Conversation(
|
||||
id=row["id"],
|
||||
created_at=self._str_to_dt(row["created_at"]),
|
||||
updated_at=self._str_to_dt(row["updated_at"]),
|
||||
)
|
||||
self._cache[conv.id] = conv
|
||||
self._evict_if_needed()
|
||||
return conv
|
||||
|
||||
# Create new
|
||||
cid = conversation_id or str(uuid.uuid4())
|
||||
now = datetime.now(timezone.utc)
|
||||
conv = Conversation(id=cid, created_at=now, updated_at=now)
|
||||
await db.execute(
|
||||
"INSERT INTO conversations (id, created_at, updated_at) VALUES (?, ?, ?)",
|
||||
(conv.id, self._dt_to_str(conv.created_at), self._dt_to_str(conv.updated_at)),
|
||||
)
|
||||
await db.commit()
|
||||
self._cache[conv.id] = conv
|
||||
self._evict_if_needed()
|
||||
return conv
|
||||
|
||||
async def add_message(
|
||||
self,
|
||||
conversation_id: str,
|
||||
role: str,
|
||||
content: str,
|
||||
metadata: dict | None = None,
|
||||
) -> ChatMessage:
|
||||
"""Add a message to a conversation (both in-memory cache and SQLite)."""
|
||||
db = await self._ensure_db()
|
||||
|
||||
# Ensure conversation exists in cache
|
||||
if conversation_id not in self._cache:
|
||||
# Try loading from SQLite
|
||||
cursor = await db.execute(
|
||||
"SELECT id, created_at, updated_at FROM conversations WHERE id = ?",
|
||||
(conversation_id,),
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
conv = Conversation(
|
||||
id=row["id"],
|
||||
created_at=self._str_to_dt(row["created_at"]),
|
||||
updated_at=self._str_to_dt(row["updated_at"]),
|
||||
)
|
||||
self._cache[conv.id] = conv
|
||||
else:
|
||||
raise KeyError(f"Conversation '{conversation_id}' not found")
|
||||
|
||||
conv = self._cache[conversation_id]
|
||||
msg = ChatMessage(
|
||||
role=role,
|
||||
content=content,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
conv.messages.append(msg)
|
||||
conv.updated_at = datetime.now(timezone.utc)
|
||||
self._touch_cache(conv.id)
|
||||
|
||||
# Persist to SQLite
|
||||
await db.execute(
|
||||
"INSERT INTO messages (conversation_id, role, content, timestamp, metadata) "
|
||||
"VALUES (?, ?, ?, ?, ?)",
|
||||
(
|
||||
conversation_id,
|
||||
role,
|
||||
content,
|
||||
self._dt_to_str(msg.timestamp),
|
||||
json.dumps(msg.metadata, default=str),
|
||||
),
|
||||
)
|
||||
await db.execute(
|
||||
"UPDATE conversations SET updated_at = ? WHERE id = ?",
|
||||
(self._dt_to_str(conv.updated_at), conversation_id),
|
||||
)
|
||||
await db.commit()
|
||||
return msg
|
||||
|
||||
async def get_history(self, conversation_id: str, limit: int = 50) -> list[ChatMessage]:
|
||||
"""Get recent messages for a conversation.
|
||||
|
||||
Loads from SQLite to ensure completeness (in-memory cache may only
|
||||
have a subset of messages after a restart).
|
||||
"""
|
||||
db = await self._ensure_db()
|
||||
|
||||
cursor = await db.execute(
|
||||
"SELECT role, content, timestamp, metadata FROM messages "
|
||||
"WHERE conversation_id = ? ORDER BY id DESC LIMIT ?",
|
||||
(conversation_id, limit),
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
# Reverse because we fetched DESC but want chronological order
|
||||
messages: list[ChatMessage] = []
|
||||
for row in reversed(rows):
|
||||
meta: dict[str, Any] = {}
|
||||
try:
|
||||
meta = json.loads(row["metadata"]) if row["metadata"] else {}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
messages.append(
|
||||
ChatMessage(
|
||||
role=row["role"],
|
||||
content=row["content"],
|
||||
timestamp=self._str_to_dt(row["timestamp"]),
|
||||
metadata=meta,
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
async def list_conversations(self, limit: int = 20) -> list[Conversation]:
|
||||
"""List recent conversations ordered by updated_at (most recent first)."""
|
||||
db = await self._ensure_db()
|
||||
|
||||
cursor = await db.execute(
|
||||
"SELECT id, created_at, updated_at FROM conversations "
|
||||
"ORDER BY updated_at DESC LIMIT ?",
|
||||
(limit,),
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
result: list[Conversation] = []
|
||||
for row in rows:
|
||||
conv_id = row["id"]
|
||||
# Use cached version if available (may have in-memory messages)
|
||||
if conv_id in self._cache:
|
||||
result.append(self._cache[conv_id])
|
||||
else:
|
||||
result.append(
|
||||
Conversation(
|
||||
id=conv_id,
|
||||
created_at=self._str_to_dt(row["created_at"]),
|
||||
updated_at=self._str_to_dt(row["updated_at"]),
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
async def restore_from_store(
|
||||
self,
|
||||
max_sessions: int = 50,
|
||||
max_messages_per_session: int = 100,
|
||||
) -> None:
|
||||
"""No-op for SQLite store — data is already persisted in the database."""
|
||||
# Nothing to do; all data lives in SQLite and is loaded on demand.
|
||||
|
|
@ -0,0 +1,171 @@
|
|||
"""Spec Manager — 执行计划规格文档管理器
|
||||
|
||||
将 PlanExecEngine 生成的执行计划持久化为 Spec 文档,
|
||||
用户可在执行前查看、编辑和确认 Spec,作为人与 AI 之间的契约。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpecStep:
|
||||
"""A single step in a spec."""
|
||||
|
||||
step_id: str
|
||||
name: str
|
||||
description: str
|
||||
dependencies: list[str] = field(default_factory=list)
|
||||
status: str = "pending" # pending | confirmed | executing | completed | failed
|
||||
|
||||
|
||||
@dataclass
|
||||
class Spec:
|
||||
"""A specification document for a planned task."""
|
||||
|
||||
spec_id: str
|
||||
goal: str
|
||||
steps: list[SpecStep] = field(default_factory=list)
|
||||
status: str = "draft" # draft | confirmed | executing | completed | failed
|
||||
created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
confirmed_at: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class SpecManager:
|
||||
"""Manages Spec documents as first-class citizens.
|
||||
|
||||
Specs are persisted to .agentkit/specs/ directory as YAML files.
|
||||
Users can view, edit, and confirm specs before execution.
|
||||
"""
|
||||
|
||||
def __init__(self, specs_dir: str | None = None):
|
||||
"""
|
||||
Args:
|
||||
specs_dir: Directory for spec files. Default: .agentkit/specs/
|
||||
"""
|
||||
self._specs_dir = Path(specs_dir or ".agentkit/specs")
|
||||
self._specs_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._cache: dict[str, Spec] = {}
|
||||
|
||||
def create(self, spec: Spec) -> Path:
|
||||
"""Persist a Spec to disk. Returns the file path."""
|
||||
path = self._specs_dir / f"{spec.spec_id}.yaml"
|
||||
data = asdict(spec)
|
||||
path.write_text(yaml.dump(data, allow_unicode=True, default_flow_style=False), encoding="utf-8")
|
||||
self._cache[spec.spec_id] = spec
|
||||
logger.info(f"Spec created: {spec.spec_id} -> {path}")
|
||||
return path
|
||||
|
||||
def get(self, spec_id: str) -> Spec | None:
|
||||
"""Load a Spec by ID."""
|
||||
if spec_id in self._cache:
|
||||
return self._cache[spec_id]
|
||||
|
||||
path = self._specs_dir / f"{spec_id}.yaml"
|
||||
if not path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
data = yaml.safe_load(path.read_text(encoding="utf-8"))
|
||||
spec = self._dict_to_spec(data)
|
||||
self._cache[spec_id] = spec
|
||||
return spec
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load spec {spec_id}: {e}")
|
||||
return None
|
||||
|
||||
def update(self, spec_id: str, **kwargs: Any) -> Spec | None:
|
||||
"""Update spec fields and persist."""
|
||||
spec = self.get(spec_id)
|
||||
if spec is None:
|
||||
return None
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if key == "steps" and isinstance(value, list):
|
||||
spec.steps = [self._dict_to_step(s) if isinstance(s, dict) else s for s in value]
|
||||
elif hasattr(spec, key):
|
||||
setattr(spec, key, value)
|
||||
|
||||
self.create(spec) # re-persist
|
||||
return spec
|
||||
|
||||
def confirm(self, spec_id: str) -> Spec | None:
|
||||
"""Mark a spec as confirmed (user approved execution)."""
|
||||
spec = self.get(spec_id)
|
||||
if spec is None:
|
||||
return None
|
||||
|
||||
spec.status = "confirmed"
|
||||
spec.confirmed_at = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
# Mark all steps as confirmed
|
||||
for step in spec.steps:
|
||||
if step.status == "pending":
|
||||
step.status = "confirmed"
|
||||
|
||||
self.create(spec) # re-persist
|
||||
logger.info(f"Spec confirmed: {spec_id}")
|
||||
return spec
|
||||
|
||||
def list_specs(self, status: str | None = None) -> list[Spec]:
|
||||
"""List all specs, optionally filtered by status. Sorted by created_at desc."""
|
||||
specs: list[Spec] = []
|
||||
for path in self._specs_dir.glob("*.yaml"):
|
||||
try:
|
||||
data = yaml.safe_load(path.read_text(encoding="utf-8"))
|
||||
spec = self._dict_to_spec(data)
|
||||
if status is None or spec.status == status:
|
||||
specs.append(spec)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load spec from {path}: {e}")
|
||||
|
||||
specs.sort(key=lambda s: s.created_at, reverse=True)
|
||||
return specs
|
||||
|
||||
def delete(self, spec_id: str) -> bool:
|
||||
"""Delete a spec file."""
|
||||
path = self._specs_dir / f"{spec_id}.yaml"
|
||||
if not path.exists():
|
||||
return False
|
||||
|
||||
path.unlink()
|
||||
self._cache.pop(spec_id, None)
|
||||
logger.info(f"Spec deleted: {spec_id}")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _dict_to_spec(data: dict[str, Any]) -> Spec:
|
||||
"""Convert a dict to a Spec instance."""
|
||||
steps = [SpecManager._dict_to_step(s) for s in data.get("steps", [])]
|
||||
return Spec(
|
||||
spec_id=data["spec_id"],
|
||||
goal=data["goal"],
|
||||
steps=steps,
|
||||
status=data.get("status", "draft"),
|
||||
created_at=data.get("created_at", ""),
|
||||
confirmed_at=data.get("confirmed_at"),
|
||||
metadata=data.get("metadata", {}),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _dict_to_step(data: dict[str, Any] | SpecStep) -> SpecStep:
|
||||
"""Convert a dict to a SpecStep instance."""
|
||||
if isinstance(data, SpecStep):
|
||||
return data
|
||||
return SpecStep(
|
||||
step_id=data["step_id"],
|
||||
name=data["name"],
|
||||
description=data["description"],
|
||||
dependencies=data.get("dependencies", []),
|
||||
status=data.get("status", "pending"),
|
||||
)
|
||||
|
|
@ -0,0 +1,145 @@
|
|||
"""验证循环 - 执行后运行项目测试验证结果
|
||||
|
||||
遵循 Codex Cloud 模式:execute → verify → retry on failure。
|
||||
默认验证命令:pytest, ruff check。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VerificationResult:
|
||||
"""验证循环运行结果"""
|
||||
|
||||
passed: bool
|
||||
attempts: int
|
||||
test_output: str
|
||||
errors: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class VerificationLoop:
|
||||
"""执行后运行项目测试验证结果
|
||||
|
||||
遵循 Codex Cloud 模式:execute → verify → retry on failure。
|
||||
默认验证命令:pytest, ruff check。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
commands: list[str] | None = None,
|
||||
max_retries: int = 2,
|
||||
working_dir: str | None = None,
|
||||
timeout: float = 60.0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
commands: 验证用的 Shell 命令列表。
|
||||
默认: ["pytest -x -q", "ruff check src/"]
|
||||
max_retries: 初始执行后的最大重试次数。
|
||||
working_dir: 运行命令的工作目录。
|
||||
timeout: 每个验证命令的超时时间(秒)。
|
||||
"""
|
||||
self._commands = commands or ["pytest -x -q", "ruff check src/"]
|
||||
self._max_retries = max_retries
|
||||
self._working_dir = working_dir
|
||||
self._timeout = timeout
|
||||
|
||||
async def verify(self) -> VerificationResult:
|
||||
"""运行验证命令并返回结果
|
||||
|
||||
依次执行每个命令,捕获 stdout/stderr。
|
||||
如果任何命令失败(非零退出码),标记为失败。
|
||||
|
||||
Returns:
|
||||
VerificationResult,所有命令成功时 passed=True。
|
||||
"""
|
||||
all_output: list[str] = []
|
||||
errors: list[str] = []
|
||||
passed = True
|
||||
|
||||
for cmd in self._commands:
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.STDOUT,
|
||||
cwd=self._working_dir,
|
||||
)
|
||||
stdout, _ = await asyncio.wait_for(
|
||||
proc.communicate(),
|
||||
timeout=self._timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
try:
|
||||
proc.kill()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
await proc.wait()
|
||||
output = f"Command timed out after {self._timeout}s: {cmd}"
|
||||
all_output.append(output)
|
||||
errors.append(output)
|
||||
passed = False
|
||||
continue
|
||||
except Exception as e:
|
||||
output = f"Command failed to execute: {cmd}: {e}"
|
||||
all_output.append(output)
|
||||
errors.append(output)
|
||||
passed = False
|
||||
continue
|
||||
|
||||
output = stdout.decode("utf-8", errors="replace") if stdout else ""
|
||||
all_output.append(f"$ {cmd}\n{output}")
|
||||
|
||||
if proc.returncode != 0:
|
||||
passed = False
|
||||
errors.append(f"Command failed (exit {proc.returncode}): {cmd}\n{output}")
|
||||
|
||||
return VerificationResult(
|
||||
passed=passed,
|
||||
attempts=1,
|
||||
test_output="\n".join(all_output),
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
async def verify_and_retry(
|
||||
self,
|
||||
fix_callback: Any = None,
|
||||
) -> VerificationResult:
|
||||
"""运行验证,如果失败则调用 fix_callback 并重试
|
||||
|
||||
Args:
|
||||
fix_callback: 异步可调用对象,接收 (errors, test_output),
|
||||
尝试修复问题。在重试之间调用。
|
||||
|
||||
Returns:
|
||||
VerificationResult,包含最终验证状态。
|
||||
"""
|
||||
result = await self.verify()
|
||||
|
||||
retries = 0
|
||||
while not result.passed and retries < self._max_retries:
|
||||
retries += 1
|
||||
logger.info(
|
||||
"Verification failed (attempt %d/%d), %s",
|
||||
retries,
|
||||
self._max_retries,
|
||||
"calling fix_callback" if fix_callback else "retrying",
|
||||
)
|
||||
|
||||
if fix_callback is not None:
|
||||
try:
|
||||
await fix_callback(result.errors, result.test_output)
|
||||
except Exception as e:
|
||||
logger.warning("fix_callback raised an error: %s", e)
|
||||
|
||||
result = await self.verify()
|
||||
result.attempts = retries + 1
|
||||
|
||||
return result
|
||||
|
|
@ -22,13 +22,14 @@ from pydantic import BaseModel
|
|||
from agentkit.core.config_driven import ConfigDrivenAgent
|
||||
from agentkit.core.react import ReActEngine
|
||||
from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult
|
||||
from agentkit.chat.simple_router import SimpleRouter
|
||||
from agentkit.chat.request_preprocessor import RequestPreprocessor
|
||||
from agentkit.server.routes.evolution_dashboard import (
|
||||
_experiences as _dashboard_experiences,
|
||||
DashboardExperience,
|
||||
_broadcast_event as _broadcast_dashboard_event,
|
||||
)
|
||||
from agentkit.core.fallback import EMPTY_LLM_RESPONSE
|
||||
from agentkit.chat.sqlite_conversation_store import SqliteConversationStore
|
||||
from agentkit.session.manager import SessionManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -103,7 +104,9 @@ class ConversationStore:
|
|||
conversations can be restored from SessionManager.
|
||||
"""
|
||||
|
||||
def __init__(self, max_conversations: int = 1000, session_manager: SessionManager | None = None):
|
||||
def __init__(
|
||||
self, max_conversations: int = 1000, session_manager: SessionManager | None = None
|
||||
):
|
||||
self._conversations: dict[str, Conversation] = {}
|
||||
self._max = max_conversations
|
||||
self._session_manager = session_manager
|
||||
|
|
@ -141,16 +144,16 @@ class ConversationStore:
|
|||
sid, limit=max_messages_per_session
|
||||
)
|
||||
for msg in messages:
|
||||
conv.messages.append(ChatMessage(
|
||||
conv.messages.append(
|
||||
ChatMessage(
|
||||
role=msg.role.value,
|
||||
content=msg.content,
|
||||
timestamp=msg.created_at,
|
||||
metadata=msg.metadata,
|
||||
))
|
||||
self._conversations[sid] = conv
|
||||
logger.info(
|
||||
f"Restored {len(self._conversations)} conversations from SessionManager"
|
||||
)
|
||||
)
|
||||
self._conversations[sid] = conv
|
||||
logger.info(f"Restored {len(self._conversations)} conversations from SessionManager")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to restore conversations from SessionManager: {e}")
|
||||
|
||||
|
|
@ -217,7 +220,7 @@ class ConversationStore:
|
|||
|
||||
# Heartbeat timeout in seconds — 0 disables timeout (for testing)
|
||||
_WS_HEARTBEAT_TIMEOUT = float(os.environ.get("AGENTKIT_WS_TIMEOUT", "120"))
|
||||
_conversation_store = ConversationStore()
|
||||
_conversation_store = SqliteConversationStore()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# History injection helper — configurable limit + optional compression
|
||||
|
|
@ -227,7 +230,7 @@ _conversation_store = ConversationStore()
|
|||
_MAX_HISTORY_MESSAGES = 50
|
||||
|
||||
|
||||
def _build_history_messages(
|
||||
async def _build_history_messages(
|
||||
conv_id: str,
|
||||
limit: int = _MAX_HISTORY_MESSAGES,
|
||||
) -> list[dict]:
|
||||
|
|
@ -238,7 +241,7 @@ def _build_history_messages(
|
|||
which should be appended separately by the caller).
|
||||
"""
|
||||
try:
|
||||
history = _conversation_store.get_history(conv_id, limit=limit)
|
||||
history = await _conversation_store.get_history(conv_id, limit=limit)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
|
@ -342,14 +345,16 @@ class CapabilitiesResponse(BaseModel):
|
|||
|
||||
async def _resolve_for_chat(
|
||||
request: ChatRequest, req: Request
|
||||
) -> tuple[ConfigDrivenAgent | None, SkillRoutingResult | None, str | None, str | None, float | None]:
|
||||
"""Resolve agent and routing for a chat request via SimpleRouter.
|
||||
) -> tuple[
|
||||
ConfigDrivenAgent | None, SkillRoutingResult | None, str | None, str | None, float | None
|
||||
]:
|
||||
"""Resolve agent and routing for a chat request via RequestPreprocessor.
|
||||
|
||||
Returns (agent, routing_result, matched_skill_name, routing_method, confidence).
|
||||
"""
|
||||
pool = req.app.state.agent_pool
|
||||
skill_registry = req.app.state.skill_registry
|
||||
simple_router: SimpleRouter = req.app.state.simple_router
|
||||
request_preprocessor: RequestPreprocessor = req.app.state.request_preprocessor
|
||||
|
||||
matched_skill_name: str | None = None
|
||||
routing_method: str | None = None
|
||||
|
|
@ -362,8 +367,7 @@ async def _resolve_for_chat(
|
|||
if default_agent is not None:
|
||||
default_tools = default_agent.get_tools()
|
||||
default_system_prompt = (
|
||||
getattr(default_agent, "_system_prompt", None)
|
||||
or default_agent.get_system_prompt()
|
||||
getattr(default_agent, "_system_prompt", None) or default_agent.get_system_prompt()
|
||||
)
|
||||
else:
|
||||
all_skills = skill_registry.list_skills()
|
||||
|
|
@ -376,8 +380,19 @@ async def _resolve_for_chat(
|
|||
)
|
||||
break
|
||||
|
||||
# Route via SimpleRouter (minimal routing: @skill prefix + greeting regex + REACT)
|
||||
routing_result = await simple_router.route(
|
||||
# If skill_name is explicitly provided in the request, use it directly
|
||||
if request.skill_name:
|
||||
routing_result = await request_preprocessor.preprocess(
|
||||
content=f"@skill:{request.skill_name} {request.message}",
|
||||
skill_registry=skill_registry,
|
||||
default_tools=default_tools,
|
||||
default_system_prompt=default_system_prompt,
|
||||
default_model="default",
|
||||
default_agent_name="default",
|
||||
)
|
||||
else:
|
||||
# Preprocess via RequestPreprocessor (minimal: @skill prefix + greeting regex + REACT)
|
||||
routing_result = await request_preprocessor.preprocess(
|
||||
content=request.message,
|
||||
skill_registry=skill_registry,
|
||||
default_tools=default_tools,
|
||||
|
|
@ -413,11 +428,19 @@ async def _resolve_for_chat(
|
|||
|
||||
@router.post("/portal/chat", response_model=ChatResponse)
|
||||
async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify_api_key)):
|
||||
"""Send a chat message and get a response with CostAwareRouter routing."""
|
||||
agent, routing_result, matched_skill, routing_method, confidence = await _resolve_for_chat(request, req)
|
||||
"""Send a chat message and get a response with RequestPreprocessor routing."""
|
||||
# If skill_name is explicitly requested but not found, return 404
|
||||
if request.skill_name:
|
||||
skill_registry = req.app.state.skill_registry
|
||||
if not skill_registry.has_skill(request.skill_name):
|
||||
raise HTTPException(status_code=404, detail=f"Skill '{request.skill_name}' not found")
|
||||
|
||||
agent, routing_result, matched_skill, routing_method, confidence = await _resolve_for_chat(
|
||||
request, req
|
||||
)
|
||||
|
||||
# Create or reuse conversation
|
||||
conv = _conversation_store.get_or_create(request.conversation_id)
|
||||
conv = await _conversation_store.get_or_create(request.conversation_id)
|
||||
await _conversation_store.add_message(conv.id, "user", request.message)
|
||||
|
||||
llm_gateway = req.app.state.llm_gateway
|
||||
|
|
@ -432,7 +455,7 @@ async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify
|
|||
chat_messages.append({"role": "system", "content": routing_result.system_prompt})
|
||||
chat_messages.append({"role": "user", "content": request.message})
|
||||
# Inject conversation history
|
||||
history_msgs = _build_history_messages(conv.id)
|
||||
history_msgs = await _build_history_messages(conv.id)
|
||||
for hm in history_msgs:
|
||||
chat_messages.insert(-1, hm)
|
||||
response = await llm_gateway.chat(
|
||||
|
|
@ -467,7 +490,7 @@ async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify
|
|||
|
||||
messages = [{"role": "user", "content": request.message}]
|
||||
# Inject conversation history
|
||||
history_msgs = _build_history_messages(conv.id)
|
||||
history_msgs = await _build_history_messages(conv.id)
|
||||
for hm in reversed(history_msgs):
|
||||
messages.insert(0, hm)
|
||||
tools = agent.get_tools()
|
||||
|
|
@ -490,7 +513,9 @@ async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify
|
|||
except Exception as e:
|
||||
response_text = f"执行出错: {e}"
|
||||
else:
|
||||
response_text = _ensure_non_empty("".join(collected_output) if collected_output else None)
|
||||
response_text = _ensure_non_empty(
|
||||
"".join(collected_output) if collected_output else None
|
||||
)
|
||||
|
||||
await _conversation_store.add_message(conv.id, "assistant", response_text)
|
||||
|
||||
|
|
@ -508,13 +533,15 @@ async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify
|
|||
|
||||
@router.post("/portal/chat/stream")
|
||||
async def chat_stream(request: ChatRequest, req: Request, _auth: None = Depends(_verify_api_key)):
|
||||
"""Stream chat responses via SSE with CostAwareRouter routing."""
|
||||
"""Stream chat responses via SSE with RequestPreprocessor routing."""
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
agent, routing_result, matched_skill, routing_method, confidence = await _resolve_for_chat(request, req)
|
||||
agent, routing_result, matched_skill, routing_method, confidence = await _resolve_for_chat(
|
||||
request, req
|
||||
)
|
||||
|
||||
# Create or reuse conversation
|
||||
conv = _conversation_store.get_or_create(request.conversation_id)
|
||||
conv = await _conversation_store.get_or_create(request.conversation_id)
|
||||
await _conversation_store.add_message(conv.id, "user", request.message)
|
||||
|
||||
llm_gateway = req.app.state.llm_gateway
|
||||
|
|
@ -532,13 +559,16 @@ async def chat_stream(request: ChatRequest, req: Request, _auth: None = Depends(
|
|||
),
|
||||
}
|
||||
|
||||
if routing_result is not None and routing_result.execution_mode == ExecutionMode.DIRECT_CHAT:
|
||||
if (
|
||||
routing_result is not None
|
||||
and routing_result.execution_mode == ExecutionMode.DIRECT_CHAT
|
||||
):
|
||||
# DIRECT_CHAT: direct LLM call, no ReAct loop
|
||||
chat_messages = []
|
||||
if routing_result.system_prompt:
|
||||
chat_messages.append({"role": "system", "content": routing_result.system_prompt})
|
||||
chat_messages.append({"role": "user", "content": request.message})
|
||||
history_msgs = _build_history_messages(conv.id)
|
||||
history_msgs = await _build_history_messages(conv.id)
|
||||
for hm in history_msgs:
|
||||
chat_messages.insert(-1, hm)
|
||||
response = await llm_gateway.chat(
|
||||
|
|
@ -552,7 +582,11 @@ async def chat_stream(request: ChatRequest, req: Request, _auth: None = Depends(
|
|||
yield {
|
||||
"event": "final_answer",
|
||||
"data": json.dumps(
|
||||
{"step": 0, "data": {"output": response_text}, "timestamp": datetime.now(timezone.utc).isoformat()}
|
||||
{
|
||||
"step": 0,
|
||||
"data": {"output": response_text},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
),
|
||||
}
|
||||
else:
|
||||
|
|
@ -654,7 +688,7 @@ async def get_capabilities(req: Request, _auth: None = Depends(_verify_api_key))
|
|||
@router.get("/portal/conversations")
|
||||
async def list_conversations(limit: int = 20, _auth: None = Depends(_verify_api_key)):
|
||||
"""List recent conversations."""
|
||||
convs = _conversation_store.list_conversations(limit=limit)
|
||||
convs = await _conversation_store.list_conversations(limit=limit)
|
||||
return [
|
||||
{
|
||||
"id": c.id,
|
||||
|
|
@ -679,11 +713,11 @@ def _derive_conversation_title(conv: Conversation) -> str:
|
|||
async def get_conversation(
|
||||
conversation_id: str, limit: int = 50, _auth: None = Depends(_verify_api_key)
|
||||
):
|
||||
"""Get conversation history, with fallback to SessionManager for persisted data."""
|
||||
# Try in-memory first
|
||||
if conversation_id in _conversation_store._conversations:
|
||||
conv = _conversation_store._conversations[conversation_id]
|
||||
history = _conversation_store.get_history(conversation_id, limit=limit)
|
||||
"""Get conversation history from SQLite-backed store."""
|
||||
history = await _conversation_store.get_history(conversation_id, limit=limit)
|
||||
if not history:
|
||||
raise HTTPException(status_code=404, detail=f"Conversation '{conversation_id}' not found")
|
||||
conv = await _conversation_store.get_or_create(conversation_id)
|
||||
return {
|
||||
"id": conv.id,
|
||||
"title": _derive_conversation_title(conv),
|
||||
|
|
@ -701,34 +735,6 @@ async def get_conversation(
|
|||
"updated_at": conv.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
# Fallback: load from SessionManager (persistent store)
|
||||
sm = _conversation_store._session_manager
|
||||
if sm is not None:
|
||||
try:
|
||||
session = await sm.get_session(conversation_id)
|
||||
if session is not None:
|
||||
messages = await sm.get_messages(conversation_id, limit=limit)
|
||||
return {
|
||||
"id": session.session_id,
|
||||
"title": _derive_title_from_messages(messages),
|
||||
"messages": [
|
||||
{
|
||||
"id": f"{session.session_id}-{i}",
|
||||
"role": m.role.value,
|
||||
"content": m.content,
|
||||
"timestamp": m.created_at.isoformat(),
|
||||
"metadata": m.metadata,
|
||||
}
|
||||
for i, m in enumerate(messages)
|
||||
],
|
||||
"created_at": session.created_at.isoformat(),
|
||||
"updated_at": session.updated_at.isoformat(),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load conversation from SessionManager: {e}")
|
||||
|
||||
raise HTTPException(status_code=404, detail=f"Conversation '{conversation_id}' not found")
|
||||
|
||||
|
||||
def _derive_title_from_messages(messages: list) -> str:
|
||||
"""Derive title from a list of Message objects (SessionManager format)."""
|
||||
|
|
@ -807,7 +813,7 @@ async def portal_websocket(websocket: WebSocket):
|
|||
# Create conversation on first message (not on connect)
|
||||
if conv is None:
|
||||
conv_id = msg.get("conversation_id")
|
||||
conv = _conversation_store.get_or_create(conv_id)
|
||||
conv = await _conversation_store.get_or_create(conv_id)
|
||||
await websocket.send_json({"type": "connected", "conversation_id": conv.id})
|
||||
|
||||
# Add user message to conversation
|
||||
|
|
@ -841,15 +847,15 @@ async def portal_websocket(websocket: WebSocket):
|
|||
except Exception as e:
|
||||
logger.warning(f"Failed to record experience: {e}")
|
||||
|
||||
# Unified routing via SimpleRouter (minimal: @skill prefix + greeting regex + REACT)
|
||||
# Unified preprocessing via RequestPreprocessor (minimal: @skill prefix + greeting regex + REACT)
|
||||
pool = websocket.app.state.agent_pool
|
||||
skill_registry = websocket.app.state.skill_registry
|
||||
llm_gateway = websocket.app.state.llm_gateway
|
||||
simple_router: SimpleRouter = websocket.app.state.simple_router
|
||||
request_preprocessor: RequestPreprocessor = websocket.app.state.request_preprocessor
|
||||
|
||||
all_skills = skill_registry.list_skills()
|
||||
|
||||
# Get default tools for SimpleRouter routing
|
||||
# Get default tools for RequestPreprocessor
|
||||
default_tools = []
|
||||
default_system_prompt = None
|
||||
default_agent = pool.get_agent("default")
|
||||
|
|
@ -869,8 +875,8 @@ async def portal_websocket(websocket: WebSocket):
|
|||
)
|
||||
break
|
||||
|
||||
# Route via SimpleRouter (minimal routing: @skill prefix + greeting regex + REACT)
|
||||
routing_result = await simple_router.route(
|
||||
# Preprocess via RequestPreprocessor (minimal: @skill prefix + greeting regex + REACT)
|
||||
routing_result = await request_preprocessor.preprocess(
|
||||
content=message_text,
|
||||
skill_registry=skill_registry,
|
||||
default_tools=default_tools,
|
||||
|
|
@ -901,7 +907,7 @@ async def portal_websocket(websocket: WebSocket):
|
|||
)
|
||||
chat_messages.append({"role": "user", "content": message_text})
|
||||
# Inject conversation history for context continuity
|
||||
history_msgs = _build_history_messages(conv.id)
|
||||
history_msgs = await _build_history_messages(conv.id)
|
||||
for hm in history_msgs:
|
||||
chat_messages.insert(-1, hm)
|
||||
response = await llm_gateway.chat(
|
||||
|
|
@ -960,7 +966,7 @@ async def portal_websocket(websocket: WebSocket):
|
|||
)
|
||||
chat_messages.append({"role": "user", "content": message_text})
|
||||
try:
|
||||
history = _conversation_store.get_history(conv.id, limit=20)
|
||||
history = await _conversation_store.get_history(conv.id, limit=20)
|
||||
for hist_msg in history[:-1]:
|
||||
if hist_msg.role in ("user", "assistant"):
|
||||
chat_messages.insert(
|
||||
|
|
@ -1009,7 +1015,7 @@ async def portal_websocket(websocket: WebSocket):
|
|||
|
||||
messages = [{"role": "user", "content": message_text}]
|
||||
# Inject conversation history for context continuity
|
||||
history_msgs = _build_history_messages(conv.id)
|
||||
history_msgs = await _build_history_messages(conv.id)
|
||||
for hm in reversed(history_msgs):
|
||||
messages.insert(0, hm)
|
||||
tools = agent.get_tools()
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import asdict
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
|
|
@ -9,6 +10,7 @@ from pydantic import BaseModel
|
|||
from typing import Any
|
||||
|
||||
from agentkit.core.protocol import TaskMessage, TaskStatus
|
||||
from agentkit.core.spec_manager import SpecManager
|
||||
|
||||
router = APIRouter(tags=["tasks"])
|
||||
|
||||
|
|
@ -350,3 +352,71 @@ async def stream_task(request: SubmitTaskRequest, req: Request):
|
|||
}
|
||||
|
||||
return EventSourceResponse(event_generator())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Spec endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_spec_manager(req: Request) -> SpecManager:
|
||||
"""Get or create SpecManager from app state."""
|
||||
if not hasattr(req.app.state, "spec_manager") or req.app.state.spec_manager is None:
|
||||
req.app.state.spec_manager = SpecManager()
|
||||
return req.app.state.spec_manager
|
||||
|
||||
|
||||
class UpdateSpecRequest(BaseModel):
|
||||
"""Request body for updating a spec."""
|
||||
|
||||
goal: str | None = None
|
||||
steps: list[dict[str, Any]] | None = None
|
||||
status: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@router.get("/specs")
|
||||
async def list_specs(status: str | None = None, req: Request = None):
|
||||
"""List all specs, optionally filtered by status."""
|
||||
mgr = _get_spec_manager(req)
|
||||
specs = mgr.list_specs(status=status)
|
||||
return [asdict(s) for s in specs]
|
||||
|
||||
|
||||
@router.get("/specs/{spec_id}")
|
||||
async def get_spec(spec_id: str, req: Request):
|
||||
"""Get a specific spec by ID."""
|
||||
mgr = _get_spec_manager(req)
|
||||
spec = mgr.get(spec_id)
|
||||
if spec is None:
|
||||
raise HTTPException(status_code=404, detail=f"Spec '{spec_id}' not found")
|
||||
return asdict(spec)
|
||||
|
||||
|
||||
@router.put("/specs/{spec_id}")
|
||||
async def update_spec(spec_id: str, body: UpdateSpecRequest, req: Request):
|
||||
"""Update a spec (e.g., edit steps, change status)."""
|
||||
mgr = _get_spec_manager(req)
|
||||
kwargs: dict[str, Any] = {}
|
||||
if body.goal is not None:
|
||||
kwargs["goal"] = body.goal
|
||||
if body.steps is not None:
|
||||
kwargs["steps"] = body.steps
|
||||
if body.status is not None:
|
||||
kwargs["status"] = body.status
|
||||
if body.metadata is not None:
|
||||
kwargs["metadata"] = body.metadata
|
||||
spec = mgr.update(spec_id, **kwargs)
|
||||
if spec is None:
|
||||
raise HTTPException(status_code=404, detail=f"Spec '{spec_id}' not found")
|
||||
return asdict(spec)
|
||||
|
||||
|
||||
@router.post("/specs/{spec_id}/confirm")
|
||||
async def confirm_spec(spec_id: str, req: Request):
|
||||
"""Confirm a spec for execution."""
|
||||
mgr = _get_spec_manager(req)
|
||||
spec = mgr.confirm(spec_id)
|
||||
if spec is None:
|
||||
raise HTTPException(status_code=404, detail=f"Spec '{spec_id}' not found")
|
||||
return asdict(spec)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,185 @@
|
|||
"""内置工具 - 开箱即用的常用工具集合"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agentkit.tools.search import ToolSearchIndex
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RunTestsTool(Tool):
|
||||
"""运行项目测试验证代码变更
|
||||
|
||||
执行 pytest 和 linting 命令来验证代码变更的正确性。
|
||||
内部使用 VerificationLoop 实现验证逻辑。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str = "run_tests",
|
||||
description: str = "Run project tests to verify code changes. Executes pytest and linting commands.",
|
||||
input_schema: dict[str, Any] | None = None,
|
||||
output_schema: dict[str, Any] | None = None,
|
||||
version: str = "1.0.0",
|
||||
tags: list[str] | None = None,
|
||||
working_dir: str | None = None,
|
||||
timeout: float = 60.0,
|
||||
max_retries: int = 0,
|
||||
):
|
||||
super().__init__(
|
||||
name=name,
|
||||
description=description,
|
||||
input_schema=input_schema or self._default_input_schema(),
|
||||
output_schema=output_schema,
|
||||
version=version,
|
||||
tags=tags or ["testing", "verification"],
|
||||
)
|
||||
self._working_dir = working_dir
|
||||
self._timeout = timeout
|
||||
self._max_retries = max_retries
|
||||
|
||||
@staticmethod
|
||||
def _default_input_schema() -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"commands": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Shell commands to run for verification. Default: ['pytest -x -q', 'ruff check src/']",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(self, **kwargs) -> dict:
|
||||
"""执行测试验证
|
||||
|
||||
Args:
|
||||
commands: 可选的验证命令列表,默认使用 pytest 和 ruff check。
|
||||
|
||||
Returns:
|
||||
包含 passed, attempts, test_output, errors 的字典。
|
||||
"""
|
||||
from agentkit.core.verification_loop import VerificationLoop
|
||||
|
||||
commands: list[str] | None = kwargs.get("commands")
|
||||
loop = VerificationLoop(
|
||||
commands=commands,
|
||||
max_retries=self._max_retries,
|
||||
working_dir=self._working_dir,
|
||||
timeout=self._timeout,
|
||||
)
|
||||
result = await loop.verify()
|
||||
return {
|
||||
"passed": result.passed,
|
||||
"attempts": result.attempts,
|
||||
"test_output": result.test_output,
|
||||
"errors": result.errors,
|
||||
}
|
||||
|
||||
|
||||
class ToolSearchTool(Tool):
|
||||
"""工具搜索工具
|
||||
|
||||
让 Agent 可以通过关键词搜索扩展工具的完整描述。配合工具描述分层注入使用:
|
||||
核心工具全量注入 prompt,扩展工具只注入名称+一行描述,Agent 通过此工具
|
||||
按需获取扩展工具的完整参数说明。
|
||||
|
||||
底层使用 :class:`ToolSearchIndex`(BM25 算法)进行关键词匹配搜索。
|
||||
|
||||
Usage::
|
||||
|
||||
from agentkit.tools.search import ToolSearchIndex
|
||||
from agentkit.tools.builtin import ToolSearchTool
|
||||
|
||||
index = ToolSearchIndex(extended_tools)
|
||||
tool = ToolSearchTool(search_index=index)
|
||||
result = await tool.execute(query="read file")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
search_index: "ToolSearchIndex",
|
||||
name: str = "tool_search",
|
||||
description: str = (
|
||||
"Search for available tools by keyword. "
|
||||
"Returns full descriptions (name, description, parameters) of matching tools. "
|
||||
"Use this when you need details about a tool that was only listed by name."
|
||||
),
|
||||
input_schema: dict[str, Any] | None = None,
|
||||
output_schema: dict[str, Any] | None = None,
|
||||
version: str = "1.0.0",
|
||||
tags: list[str] | None = None,
|
||||
top_k: int = 5,
|
||||
):
|
||||
if top_k < 1:
|
||||
raise ValueError(f"top_k must be >= 1, got {top_k}")
|
||||
schema = input_schema or {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Search keywords to find relevant tools "
|
||||
"(e.g. 'read file', 'web search', 'run tests')."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
super().__init__(
|
||||
name=name,
|
||||
description=description,
|
||||
input_schema=schema,
|
||||
output_schema=output_schema,
|
||||
version=version,
|
||||
tags=tags or ["tool", "search", "meta"],
|
||||
)
|
||||
self._search_index = search_index
|
||||
self._top_k = top_k
|
||||
|
||||
async def execute(self, **kwargs) -> dict:
|
||||
"""执行工具搜索。
|
||||
|
||||
Args:
|
||||
query: 搜索关键词。
|
||||
|
||||
Returns:
|
||||
包含 ``query``、``count`` 和 ``results``(每个工具的完整描述)的字典。
|
||||
"""
|
||||
query = str(kwargs.get("query", "")).strip()
|
||||
if not query:
|
||||
return {
|
||||
"error": "query parameter is required",
|
||||
"results": [],
|
||||
}
|
||||
|
||||
results = self._search_index.search(query, top_k=self._top_k)
|
||||
if not results:
|
||||
return {
|
||||
"query": query,
|
||||
"count": 0,
|
||||
"results": [],
|
||||
"message": f"No tools matched '{query}'.",
|
||||
}
|
||||
|
||||
return {
|
||||
"query": query,
|
||||
"count": len(results),
|
||||
"results": [self._format_tool_full(tool) for tool in results],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _format_tool_full(tool: Tool) -> dict[str, Any]:
|
||||
"""Format a tool's full description for the LLM."""
|
||||
return {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.input_schema or {"type": "object", "properties": {}},
|
||||
}
|
||||
|
|
@ -0,0 +1,182 @@
|
|||
"""Unit tests for SqliteConversationStore."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.chat.sqlite_conversation_store import SqliteConversationStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_path(tmp_path: Path) -> str:
|
||||
"""Return a temporary database path."""
|
||||
return str(tmp_path / "test_conversations.db")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def store(db_path: str) -> SqliteConversationStore:
|
||||
"""Create a SqliteConversationStore with a temporary database."""
|
||||
s = SqliteConversationStore(db_path=db_path)
|
||||
yield s
|
||||
await s._close_db()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Basic CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBasicCRUD:
|
||||
async def test_create_conversation(self, store: SqliteConversationStore) -> None:
|
||||
conv = await store.get_or_create()
|
||||
assert conv.id
|
||||
assert conv.created_at is not None
|
||||
assert conv.updated_at is not None
|
||||
assert conv.messages == []
|
||||
|
||||
async def test_create_conversation_with_id(self, store: SqliteConversationStore) -> None:
|
||||
conv = await store.get_or_create("my-conv-id")
|
||||
assert conv.id == "my-conv-id"
|
||||
|
||||
async def test_get_or_create_returns_existing(self, store: SqliteConversationStore) -> None:
|
||||
await store.get_or_create("reuse-id")
|
||||
await store.add_message("reuse-id", "user", "hello")
|
||||
conv2 = await store.get_or_create("reuse-id")
|
||||
assert conv2.id == "reuse-id"
|
||||
# In-memory cache should have the message
|
||||
assert len(conv2.messages) == 1
|
||||
|
||||
async def test_add_message(self, store: SqliteConversationStore) -> None:
|
||||
await store.get_or_create("msg-test")
|
||||
msg = await store.add_message("msg-test", "user", "Hello world")
|
||||
assert msg.role == "user"
|
||||
assert msg.content == "Hello world"
|
||||
assert msg.metadata == {}
|
||||
|
||||
async def test_add_message_with_metadata(self, store: SqliteConversationStore) -> None:
|
||||
await store.get_or_create("meta-test")
|
||||
msg = await store.add_message("meta-test", "assistant", "Hi", {"key": "value"})
|
||||
assert msg.metadata == {"key": "value"}
|
||||
|
||||
async def test_add_message_nonexistent_conversation_raises(
|
||||
self, store: SqliteConversationStore
|
||||
) -> None:
|
||||
with pytest.raises(KeyError, match="not found"):
|
||||
await store.add_message("nonexistent", "user", "hello")
|
||||
|
||||
async def test_get_history(self, store: SqliteConversationStore) -> None:
|
||||
await store.get_or_create("hist-test")
|
||||
await store.add_message("hist-test", "user", "msg1")
|
||||
await store.add_message("hist-test", "assistant", "msg2")
|
||||
await store.add_message("hist-test", "user", "msg3")
|
||||
history = await store.get_history("hist-test")
|
||||
assert len(history) == 3
|
||||
assert history[0].content == "msg1"
|
||||
assert history[1].content == "msg2"
|
||||
assert history[2].content == "msg3"
|
||||
|
||||
async def test_get_history_with_limit(self, store: SqliteConversationStore) -> None:
|
||||
await store.get_or_create("limit-test")
|
||||
for i in range(10):
|
||||
await store.add_message("limit-test", "user", f"msg{i}")
|
||||
history = await store.get_history("limit-test", limit=3)
|
||||
assert len(history) == 3
|
||||
# Should return the last 3 messages
|
||||
assert history[0].content == "msg7"
|
||||
assert history[1].content == "msg8"
|
||||
assert history[2].content == "msg9"
|
||||
|
||||
async def test_get_history_empty(self, store: SqliteConversationStore) -> None:
|
||||
history = await store.get_history("nonexistent")
|
||||
assert history == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPersistence:
|
||||
async def test_data_survives_store_recreation(self, db_path: str) -> None:
|
||||
"""Create store, add data, create new store with same DB path, verify data survives."""
|
||||
store1 = SqliteConversationStore(db_path=db_path)
|
||||
await store1.get_or_create("persist-conv")
|
||||
await store1.add_message("persist-conv", "user", "persistent message")
|
||||
await store1._close_db()
|
||||
|
||||
store2 = SqliteConversationStore(db_path=db_path)
|
||||
history = await store2.get_history("persist-conv")
|
||||
assert len(history) == 1
|
||||
assert history[0].content == "persistent message"
|
||||
assert history[0].role == "user"
|
||||
await store2._close_db()
|
||||
|
||||
async def test_conversations_survive_recreation(self, db_path: str) -> None:
|
||||
store1 = SqliteConversationStore(db_path=db_path)
|
||||
await store1.get_or_create("conv-a")
|
||||
await store1.get_or_create("conv-b")
|
||||
await store1._close_db()
|
||||
|
||||
store2 = SqliteConversationStore(db_path=db_path)
|
||||
convs = await store2.list_conversations()
|
||||
ids = {c.id for c in convs}
|
||||
assert "conv-a" in ids
|
||||
assert "conv-b" in ids
|
||||
await store2._close_db()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_conversations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListConversations:
|
||||
async def test_list_conversations_ordering(self, store: SqliteConversationStore) -> None:
|
||||
"""Most recently updated conversation should appear first."""
|
||||
await store.get_or_create("conv-first")
|
||||
await store.get_or_create("conv-second")
|
||||
# Update conv1 by adding a message
|
||||
await store.add_message("conv-first", "user", "update")
|
||||
convs = await store.list_conversations()
|
||||
assert len(convs) == 2
|
||||
# conv1 was updated more recently
|
||||
assert convs[0].id == "conv-first"
|
||||
|
||||
async def test_list_conversations_limit(self, store: SqliteConversationStore) -> None:
|
||||
for i in range(5):
|
||||
await store.get_or_create(f"conv-{i}")
|
||||
convs = await store.list_conversations(limit=3)
|
||||
assert len(convs) == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LRU cache eviction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLRUCache:
|
||||
async def test_cache_eviction(self, db_path: str) -> None:
|
||||
"""Cache should evict oldest entries when over limit (data still in SQLite)."""
|
||||
store = SqliteConversationStore(db_path=db_path, max_conversations=3)
|
||||
for i in range(5):
|
||||
await store.get_or_create(f"evict-{i}")
|
||||
# Cache should have at most 3 entries
|
||||
assert len(store._cache) <= 3
|
||||
# But all 5 conversations should be in SQLite
|
||||
convs = await store.list_conversations(limit=10)
|
||||
assert len(convs) == 5
|
||||
await store._close_db()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# restore_from_store (no-op)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRestoreFromStore:
|
||||
async def test_restore_is_noop(self, store: SqliteConversationStore) -> None:
|
||||
"""restore_from_store should be a no-op for SQLite store."""
|
||||
# Should not raise
|
||||
await store.restore_from_store()
|
||||
|
|
@ -0,0 +1,204 @@
|
|||
"""Tests for SpecManager — Spec 文档管理器"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.spec_manager import Spec, SpecManager, SpecStep
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def specs_dir(tmp_path: Path) -> str:
|
||||
"""Provide a temporary directory for spec files."""
|
||||
return str(tmp_path / "specs")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mgr(specs_dir: str) -> SpecManager:
|
||||
"""Create a SpecManager with a temporary directory."""
|
||||
return SpecManager(specs_dir=specs_dir)
|
||||
|
||||
|
||||
def make_spec(spec_id: str = "test-spec", goal: str = "test goal") -> Spec:
|
||||
"""Create a test Spec."""
|
||||
return Spec(
|
||||
spec_id=spec_id,
|
||||
goal=goal,
|
||||
steps=[
|
||||
SpecStep(step_id="s1", name="Step 1", description="First step"),
|
||||
SpecStep(step_id="s2", name="Step 2", description="Second step", dependencies=["s1"]),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestSpecManagerCreateAndGet:
|
||||
"""Test create and get a spec."""
|
||||
|
||||
def test_create_and_get(self, mgr: SpecManager):
|
||||
spec = make_spec()
|
||||
path = mgr.create(spec)
|
||||
assert path.exists()
|
||||
|
||||
loaded = mgr.get(spec.spec_id)
|
||||
assert loaded is not None
|
||||
assert loaded.spec_id == spec.spec_id
|
||||
assert loaded.goal == spec.goal
|
||||
assert len(loaded.steps) == 2
|
||||
assert loaded.steps[0].step_id == "s1"
|
||||
assert loaded.steps[1].dependencies == ["s1"]
|
||||
|
||||
def test_create_writes_yaml_file(self, mgr: SpecManager, specs_dir: str):
|
||||
spec = make_spec(spec_id="yaml-test")
|
||||
mgr.create(spec)
|
||||
yaml_path = Path(specs_dir) / "yaml-test.yaml"
|
||||
assert yaml_path.exists()
|
||||
|
||||
def test_get_returns_from_cache(self, mgr: SpecManager):
|
||||
spec = make_spec(spec_id="cached")
|
||||
mgr.create(spec)
|
||||
# Second get should hit cache
|
||||
loaded = mgr.get("cached")
|
||||
assert loaded is not None
|
||||
assert loaded.spec_id == "cached"
|
||||
|
||||
|
||||
class TestSpecManagerUpdate:
|
||||
"""Test update spec fields."""
|
||||
|
||||
def test_update_goal(self, mgr: SpecManager):
|
||||
spec = make_spec()
|
||||
mgr.create(spec)
|
||||
|
||||
updated = mgr.update(spec.spec_id, goal="new goal")
|
||||
assert updated is not None
|
||||
assert updated.goal == "new goal"
|
||||
|
||||
def test_update_steps(self, mgr: SpecManager):
|
||||
spec = make_spec()
|
||||
mgr.create(spec)
|
||||
|
||||
new_steps = [
|
||||
{"step_id": "s1", "name": "Step 1 Updated", "description": "Updated first step"},
|
||||
]
|
||||
updated = mgr.update(spec.spec_id, steps=new_steps)
|
||||
assert updated is not None
|
||||
assert len(updated.steps) == 1
|
||||
assert updated.steps[0].name == "Step 1 Updated"
|
||||
|
||||
def test_update_metadata(self, mgr: SpecManager):
|
||||
spec = make_spec()
|
||||
mgr.create(spec)
|
||||
|
||||
updated = mgr.update(spec.spec_id, metadata={"key": "value"})
|
||||
assert updated is not None
|
||||
assert updated.metadata == {"key": "value"}
|
||||
|
||||
def test_update_nonexistent_returns_none(self, mgr: SpecManager):
|
||||
result = mgr.update("nonexistent", goal="x")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestSpecManagerConfirm:
|
||||
"""Test confirm sets status and confirmed_at."""
|
||||
|
||||
def test_confirm_sets_status_and_timestamp(self, mgr: SpecManager):
|
||||
spec = make_spec()
|
||||
mgr.create(spec)
|
||||
|
||||
assert spec.status == "draft"
|
||||
assert spec.confirmed_at is None
|
||||
|
||||
confirmed = mgr.confirm(spec.spec_id)
|
||||
assert confirmed is not None
|
||||
assert confirmed.status == "confirmed"
|
||||
assert confirmed.confirmed_at is not None
|
||||
|
||||
def test_confirm_marks_pending_steps_as_confirmed(self, mgr: SpecManager):
|
||||
spec = make_spec()
|
||||
mgr.create(spec)
|
||||
|
||||
confirmed = mgr.confirm(spec.spec_id)
|
||||
assert confirmed is not None
|
||||
for step in confirmed.steps:
|
||||
assert step.status == "confirmed"
|
||||
|
||||
def test_confirm_nonexistent_returns_none(self, mgr: SpecManager):
|
||||
result = mgr.confirm("nonexistent")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestSpecManagerList:
|
||||
"""Test list_specs returns specs sorted by created_at desc."""
|
||||
|
||||
def test_list_specs_sorted_by_created_at_desc(self, mgr: SpecManager):
|
||||
spec_a = make_spec(spec_id="spec-a", goal="A")
|
||||
mgr.create(spec_a)
|
||||
|
||||
# Small delay to ensure different timestamps
|
||||
time.sleep(0.01)
|
||||
|
||||
spec_b = make_spec(spec_id="spec-b", goal="B")
|
||||
mgr.create(spec_b)
|
||||
|
||||
specs = mgr.list_specs()
|
||||
assert len(specs) == 2
|
||||
# Most recent first
|
||||
assert specs[0].spec_id == "spec-b"
|
||||
assert specs[1].spec_id == "spec-a"
|
||||
|
||||
def test_list_specs_filter_by_status(self, mgr: SpecManager):
|
||||
spec_a = make_spec(spec_id="draft-spec")
|
||||
mgr.create(spec_a)
|
||||
|
||||
spec_b = make_spec(spec_id="confirmed-spec")
|
||||
mgr.create(spec_b)
|
||||
mgr.confirm(spec_b.spec_id)
|
||||
|
||||
draft_specs = mgr.list_specs(status="draft")
|
||||
assert len(draft_specs) == 1
|
||||
assert draft_specs[0].spec_id == "draft-spec"
|
||||
|
||||
confirmed_specs = mgr.list_specs(status="confirmed")
|
||||
assert len(confirmed_specs) == 1
|
||||
assert confirmed_specs[0].spec_id == "confirmed-spec"
|
||||
|
||||
def test_list_specs_empty(self, mgr: SpecManager):
|
||||
specs = mgr.list_specs()
|
||||
assert specs == []
|
||||
|
||||
|
||||
class TestSpecManagerDelete:
|
||||
"""Test delete removes the spec."""
|
||||
|
||||
def test_delete_removes_spec(self, mgr: SpecManager):
|
||||
spec = make_spec()
|
||||
mgr.create(spec)
|
||||
assert mgr.get(spec.spec_id) is not None
|
||||
|
||||
result = mgr.delete(spec.spec_id)
|
||||
assert result is True
|
||||
assert mgr.get(spec.spec_id) is None
|
||||
|
||||
def test_delete_removes_yaml_file(self, mgr: SpecManager, specs_dir: str):
|
||||
spec = make_spec(spec_id="delete-me")
|
||||
mgr.create(spec)
|
||||
yaml_path = Path(specs_dir) / "delete-me.yaml"
|
||||
assert yaml_path.exists()
|
||||
|
||||
mgr.delete("delete-me")
|
||||
assert not yaml_path.exists()
|
||||
|
||||
def test_delete_nonexistent_returns_false(self, mgr: SpecManager):
|
||||
result = mgr.delete("nonexistent")
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestSpecManagerGetNonExistent:
|
||||
"""Test get non-existent spec returns None."""
|
||||
|
||||
def test_get_nonexistent_returns_none(self, mgr: SpecManager):
|
||||
result = mgr.get("does-not-exist")
|
||||
assert result is None
|
||||
|
|
@ -0,0 +1,136 @@
|
|||
"""VerificationLoop 单元测试"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.verification_loop import VerificationLoop, VerificationResult
|
||||
|
||||
|
||||
class TestVerify:
|
||||
"""verify() 方法测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_success(self) -> None:
|
||||
"""成功命令返回 passed=True"""
|
||||
loop = VerificationLoop(commands=["echo ok"], timeout=10.0)
|
||||
result = await loop.verify()
|
||||
assert isinstance(result, VerificationResult)
|
||||
assert result.passed is True
|
||||
assert result.attempts == 1
|
||||
assert "ok" in result.test_output
|
||||
assert result.errors == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_failure(self) -> None:
|
||||
"""失败命令返回 passed=False"""
|
||||
loop = VerificationLoop(commands=["false"], timeout=10.0)
|
||||
result = await loop.verify()
|
||||
assert result.passed is False
|
||||
assert result.attempts == 1
|
||||
assert len(result.errors) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_timeout(self) -> None:
|
||||
"""超时命令返回 passed=False"""
|
||||
loop = VerificationLoop(commands=["sleep 10"], timeout=0.5)
|
||||
result = await loop.verify()
|
||||
assert result.passed is False
|
||||
assert any("timed out" in e for e in result.errors)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_command_not_found(self) -> None:
|
||||
"""不存在的命令返回 passed=False"""
|
||||
loop = VerificationLoop(commands=["nonexistent_command_xyz"], timeout=5.0)
|
||||
result = await loop.verify()
|
||||
assert result.passed is False
|
||||
assert len(result.errors) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_multiple_commands_partial_failure(self) -> None:
|
||||
"""部分命令失败时整体返回 passed=False"""
|
||||
loop = VerificationLoop(commands=["echo ok", "false"], timeout=10.0)
|
||||
result = await loop.verify()
|
||||
assert result.passed is False
|
||||
assert len(result.errors) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_default_commands(self) -> None:
|
||||
"""默认命令为 pytest 和 ruff check"""
|
||||
loop = VerificationLoop()
|
||||
assert loop._commands == ["pytest -x -q", "ruff check src/"]
|
||||
|
||||
|
||||
class TestVerifyAndRetry:
|
||||
"""verify_and_retry() 方法测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_no_fix_callback(self) -> None:
|
||||
"""无 fix_callback 时重试指定次数"""
|
||||
loop = VerificationLoop(commands=["false"], max_retries=2, timeout=5.0)
|
||||
result = await loop.verify_and_retry()
|
||||
assert result.passed is False
|
||||
assert result.attempts == 3 # 1 initial + 2 retries
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_retries_respected(self) -> None:
|
||||
"""max_retries=0 时不重试"""
|
||||
loop = VerificationLoop(commands=["false"], max_retries=0, timeout=5.0)
|
||||
result = await loop.verify_and_retry()
|
||||
assert result.passed is False
|
||||
assert result.attempts == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_with_fix_callback(self) -> None:
|
||||
"""fix_callback 被调用并接收 errors 和 test_output"""
|
||||
call_count = 0
|
||||
received_args: list[tuple[list[str], str]] = []
|
||||
|
||||
async def fix_cb(errors: list[str], test_output: str) -> None:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
received_args.append((errors, test_output))
|
||||
|
||||
loop = VerificationLoop(commands=["false"], max_retries=1, timeout=5.0)
|
||||
result = await loop.verify_and_retry(fix_callback=fix_cb)
|
||||
assert result.passed is False
|
||||
assert call_count == 1
|
||||
assert len(received_args) == 1
|
||||
assert len(received_args[0][0]) > 0 # errors
|
||||
assert isinstance(received_args[0][1], str) # test_output
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_succeeds_after_fix(self) -> None:
|
||||
"""fix_callback 修复后验证成功"""
|
||||
attempt = 0
|
||||
|
||||
async def fix_cb(errors: list[str], test_output: str) -> None:
|
||||
pass # Simulate fix applied
|
||||
|
||||
# Use a command that always fails — but test that the retry mechanism works
|
||||
loop = VerificationLoop(commands=["false"], max_retries=1, timeout=5.0)
|
||||
result = await loop.verify_and_retry(fix_callback=fix_cb)
|
||||
# false always fails, so result should still be False
|
||||
assert result.passed is False
|
||||
assert result.attempts == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fix_callback_exception_handled(self) -> None:
|
||||
"""fix_callback 抛出异常时不影响重试"""
|
||||
async def bad_fix_cb(errors: list[str], test_output: str) -> None:
|
||||
raise RuntimeError("fix failed!")
|
||||
|
||||
loop = VerificationLoop(commands=["false"], max_retries=1, timeout=5.0)
|
||||
result = await loop.verify_and_retry(fix_callback=bad_fix_cb)
|
||||
assert result.passed is False
|
||||
assert result.attempts == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_and_retry_success_first_try(self) -> None:
|
||||
"""首次验证成功时不重试"""
|
||||
loop = VerificationLoop(commands=["echo ok"], max_retries=3, timeout=10.0)
|
||||
result = await loop.verify_and_retry()
|
||||
assert result.passed is True
|
||||
assert result.attempts == 1
|
||||
|
|
@ -12,10 +12,8 @@ from fastapi.testclient import TestClient
|
|||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||||
from agentkit.server.app import create_app
|
||||
from agentkit.server.routes.portal import (
|
||||
CAPABILITY_CATEGORIES,
|
||||
ConversationStore,
|
||||
)
|
||||
from agentkit.chat.sqlite_conversation_store import SqliteConversationStore
|
||||
from agentkit.server.routes.portal import CAPABILITY_CATEGORIES
|
||||
from agentkit.skills.base import Skill, SkillConfig
|
||||
from agentkit.skills.registry import SkillRegistry
|
||||
from agentkit.tools.registry import ToolRegistry
|
||||
|
|
@ -84,81 +82,90 @@ def _register_skill(registry: SkillRegistry, name: str = "chat_skill", **kwargs)
|
|||
|
||||
|
||||
class TestConversationStore:
|
||||
def test_get_or_create_new(self):
|
||||
store = ConversationStore()
|
||||
conv = store.get_or_create()
|
||||
"""Tests for SqliteConversationStore (async, in-memory DB)."""
|
||||
|
||||
@pytest.fixture
|
||||
def store(self, tmp_path):
|
||||
return SqliteConversationStore(db_path=str(tmp_path / "test.db"))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_create_new(self, store):
|
||||
conv = await store.get_or_create()
|
||||
assert conv.id is not None
|
||||
assert conv.messages == []
|
||||
|
||||
def test_get_or_create_with_id(self):
|
||||
store = ConversationStore()
|
||||
conv = store.get_or_create("test-id-123")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_create_with_id(self, store):
|
||||
conv = await store.get_or_create("test-id-123")
|
||||
assert conv.id == "test-id-123"
|
||||
|
||||
def test_get_or_create_reuse(self):
|
||||
store = ConversationStore()
|
||||
store.get_or_create("reuse-id")
|
||||
store.add_message("reuse-id", "user", "hello")
|
||||
conv2 = store.get_or_create("reuse-id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_create_reuse(self, store):
|
||||
await store.get_or_create("reuse-id")
|
||||
await store.add_message("reuse-id", "user", "hello")
|
||||
conv2 = await store.get_or_create("reuse-id")
|
||||
assert conv2.id == "reuse-id"
|
||||
assert len(conv2.messages) == 1
|
||||
|
||||
def test_add_message(self):
|
||||
store = ConversationStore()
|
||||
conv = store.get_or_create("msg-id")
|
||||
msg = store.add_message("msg-id", "user", "hello")
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_message(self, store):
|
||||
conv = await store.get_or_create("msg-id")
|
||||
msg = await store.add_message("msg-id", "user", "hello")
|
||||
assert msg.role == "user"
|
||||
assert msg.content == "hello"
|
||||
assert len(conv.messages) == 1
|
||||
|
||||
def test_add_message_not_found(self):
|
||||
store = ConversationStore()
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_message_not_found(self, store):
|
||||
with pytest.raises(KeyError):
|
||||
store.add_message("nonexistent", "user", "hello")
|
||||
await store.add_message("nonexistent", "user", "hello")
|
||||
|
||||
def test_get_history(self):
|
||||
store = ConversationStore()
|
||||
store.get_or_create("hist-id")
|
||||
store.add_message("hist-id", "user", "msg1")
|
||||
store.add_message("hist-id", "assistant", "msg2")
|
||||
history = store.get_history("hist-id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_history(self, store):
|
||||
await store.get_or_create("hist-id")
|
||||
await store.add_message("hist-id", "user", "msg1")
|
||||
await store.add_message("hist-id", "assistant", "msg2")
|
||||
history = await store.get_history("hist-id")
|
||||
assert len(history) == 2
|
||||
assert history[0].role == "user"
|
||||
assert history[1].role == "assistant"
|
||||
|
||||
def test_get_history_limit(self):
|
||||
store = ConversationStore()
|
||||
store.get_or_create("limit-id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_history_limit(self, store):
|
||||
await store.get_or_create("limit-id")
|
||||
for i in range(10):
|
||||
store.add_message("limit-id", "user", f"msg{i}")
|
||||
history = store.get_history("limit-id", limit=3)
|
||||
await store.add_message("limit-id", "user", f"msg{i}")
|
||||
history = await store.get_history("limit-id", limit=3)
|
||||
assert len(history) == 3
|
||||
assert history[0].content == "msg7"
|
||||
|
||||
def test_get_history_nonexistent(self):
|
||||
store = ConversationStore()
|
||||
history = store.get_history("no-such-id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_history_nonexistent(self, store):
|
||||
history = await store.get_history("no-such-id")
|
||||
assert history == []
|
||||
|
||||
def test_list_conversations(self):
|
||||
store = ConversationStore()
|
||||
store.get_or_create("conv-a")
|
||||
store.get_or_create("conv-b")
|
||||
convs = store.list_conversations()
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_conversations(self, store):
|
||||
await store.get_or_create("conv-a")
|
||||
await store.get_or_create("conv-b")
|
||||
convs = await store.list_conversations()
|
||||
assert len(convs) == 2
|
||||
|
||||
def test_list_conversations_limit(self):
|
||||
store = ConversationStore()
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_conversations_limit(self, store):
|
||||
for i in range(5):
|
||||
store.get_or_create(f"conv-{i}")
|
||||
convs = store.list_conversations(limit=2)
|
||||
await store.get_or_create(f"conv-{i}")
|
||||
convs = await store.list_conversations(limit=2)
|
||||
assert len(convs) == 2
|
||||
|
||||
def test_max_conversations_eviction(self):
|
||||
store = ConversationStore(max_conversations=3)
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_conversations_eviction(self, tmp_path):
|
||||
store = SqliteConversationStore(
|
||||
db_path=str(tmp_path / "evict.db"), max_conversations=3
|
||||
)
|
||||
for i in range(5):
|
||||
store.get_or_create(f"evict-{i}")
|
||||
assert len(store._conversations) <= 3
|
||||
await store.get_or_create(f"evict-{i}")
|
||||
assert len(store._cache) <= 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -178,7 +185,7 @@ class TestPortalChat:
|
|||
data = response.json()
|
||||
assert data["conversation_id"] is not None
|
||||
assert data["matched_skill"] == "chat_skill"
|
||||
assert data["routing_method"] == "direct"
|
||||
assert data["routing_method"] == "skill_prefix"
|
||||
assert data["confidence"] == 1.0
|
||||
assert data["status"] == "completed"
|
||||
|
||||
|
|
@ -196,12 +203,13 @@ class TestPortalChat:
|
|||
assert data["conversation_id"] is not None
|
||||
|
||||
def test_chat_no_skills_available(self, client):
|
||||
"""Greeting fast-path works even without skills (DIRECT_CHAT mode)."""
|
||||
response = client.post(
|
||||
"/api/v1/portal/chat",
|
||||
json={"message": "hello"},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "No skills available" in response.json()["detail"]
|
||||
# Greeting regex fast-path: no skill needed, returns 200
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_chat_skill_not_found(self, client):
|
||||
response = client.post(
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ Covers:
|
|||
3. ConfigDrivenAgent compressor passthrough
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
|
@ -114,7 +115,10 @@ class TestCreateAppCompression:
|
|||
with patch("agentkit.core.compressor.create_compressor") as mock_create:
|
||||
mock_create.return_value = None
|
||||
|
||||
# No server_config at all
|
||||
# No server_config at all — also prevent auto-discovery of agentkit.yaml
|
||||
with patch.dict(os.environ, {"AGENTKIT_CONFIG_PATH": ""}, clear=False):
|
||||
# Prevent CWD agentkit.yaml from being auto-loaded
|
||||
with patch("os.path.exists", return_value=False):
|
||||
app = create_app()
|
||||
|
||||
# create_compressor should not be called (no server_config)
|
||||
|
|
@ -184,6 +188,7 @@ class TestConfigDrivenAgentCompression:
|
|||
agent._skill_config = skill_config
|
||||
agent._prompt_template = None
|
||||
agent._tools = []
|
||||
agent._tool_registry = None
|
||||
agent._memory_retriever = None
|
||||
agent._compressor = mock_compressor
|
||||
agent._evolution_enabled = False
|
||||
|
|
@ -230,6 +235,7 @@ class TestConfigDrivenAgentCompression:
|
|||
agent._skill_config = skill_config
|
||||
agent._prompt_template = None
|
||||
agent._tools = []
|
||||
agent._tool_registry = None
|
||||
agent._memory_retriever = None
|
||||
agent._compressor = None
|
||||
agent._evolution_enabled = False
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ from agentkit.telemetry.tracer import (
|
|||
get_tracer,
|
||||
init_telemetry,
|
||||
)
|
||||
from agentkit.chat.skill_routing import CostAwareRouter, SkillRoutingResult
|
||||
from agentkit.quality.alignment import AlignmentGuard, AlignmentConfig
|
||||
|
||||
|
||||
|
|
@ -155,46 +154,6 @@ class TestGetTracer:
|
|||
assert tracer1 is tracer2
|
||||
|
||||
|
||||
# ── CostAwareRouter span 测试 ──────────────────────────────
|
||||
|
||||
|
||||
class TestCostAwareRouterSpan:
|
||||
"""CostAwareRouter 创建 span 并设置属性"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_creates_span_with_attributes(self):
|
||||
"""路由时创建 span 并设置 route.layer 和 route.target 属性"""
|
||||
init_telemetry(TelemetryConfig(enabled=False))
|
||||
tracer = get_tracer()
|
||||
|
||||
# 用 mock 替换 start_span 以验证调用
|
||||
mock_span = MagicMock()
|
||||
mock_span.__enter__ = MagicMock(return_value=mock_span)
|
||||
mock_span.__exit__ = MagicMock(return_value=False)
|
||||
mock_span.set_attribute = MagicMock()
|
||||
|
||||
with patch.object(tracer, "start_span", return_value=mock_span):
|
||||
router = CostAwareRouter()
|
||||
result = await router.route(
|
||||
content="你好",
|
||||
skill_registry=MagicMock(),
|
||||
intent_router=MagicMock(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
|
||||
tracer.start_span.assert_called_once_with("router.route")
|
||||
# 验证 span 设置了 input.length 属性
|
||||
mock_span.set_attribute.assert_any_call("input.length", len("你好"))
|
||||
# 验证 span 设置了 route.layer 和 route.target
|
||||
call_args_list = [
|
||||
(call.args[0], call.args[1])
|
||||
for call in mock_span.set_attribute.call_args_list
|
||||
]
|
||||
assert ("route.layer", "greeting") in call_args_list
|
||||
assert ("route.target", "default") in call_args_list
|
||||
|
||||
|
||||
# ── AlignmentGuard span 测试 ──────────────────────────────
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue