feat: SQLite persistence, verification loop, spec-driven execution
Phase 2 of architecture optimization (U5/U6/U9):
- U5: SqliteConversationStore with WAL mode + LRU cache (1000 convs)
Replaces in-memory ConversationStore in portal.py
Data survives server restarts (ref: Codex Thread persistence)
- U6: VerificationLoop with verify/verify_and_retry
Default commands: pytest + ruff check
ReActEngine integration via verification_enabled flag
New run_tests tool for LLM to invoke verification
- U9: SpecManager for plan-as-contract (ref: Qoder Quest Mode)
Plans persisted to .agentkit/specs/{spec_id}.yaml
API: GET/PUT /api/v1/specs, POST /api/v1/specs/{id}/confirm
PlanExecEngine emits spec_created event after plan generation
Also fixes: portal skill_name routing, app.py SessionManager guard,
test_telemetry CostAwareRouter removal, test_compression_config fixture
This commit is contained in:
parent
5374bc8501
commit
200174c5c7
|
|
@ -0,0 +1,294 @@
|
||||||
|
"""SQLite-backed conversation store with async support and in-memory LRU cache."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from collections import OrderedDict
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Data classes (mirrors portal.py ChatMessage / Conversation)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatMessage:
|
||||||
|
role: str # "user" or "assistant"
|
||||||
|
content: str
|
||||||
|
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
metadata: dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Conversation:
|
||||||
|
id: str
|
||||||
|
messages: list[ChatMessage] = field(default_factory=list)
|
||||||
|
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Schema
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_SCHEMA_SQL = """
|
||||||
|
CREATE TABLE IF NOT EXISTS conversations (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
updated_at TEXT NOT NULL
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS messages (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
conversation_id TEXT NOT NULL,
|
||||||
|
role TEXT NOT NULL,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
timestamp TEXT NOT NULL,
|
||||||
|
metadata TEXT DEFAULT '{}',
|
||||||
|
FOREIGN KEY (conversation_id) REFERENCES conversations(id)
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_messages_conv_id ON messages(conversation_id);
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class SqliteConversationStore:
|
||||||
|
"""SQLite-backed conversation store with an in-memory LRU cache.
|
||||||
|
|
||||||
|
Drop-in replacement for the in-memory ConversationStore in portal.py.
|
||||||
|
Data is persisted to SQLite so it survives server restarts.
|
||||||
|
An in-memory LRU cache of recent conversations provides fast access.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
db_path: str | Path | None = None,
|
||||||
|
max_conversations: int = 1000,
|
||||||
|
) -> None:
|
||||||
|
if db_path is None:
|
||||||
|
db_path = os.path.expanduser("~/.agentkit/conversations.db")
|
||||||
|
self._db_path = str(db_path)
|
||||||
|
self._max = max_conversations
|
||||||
|
self._cache: OrderedDict[str, Conversation] = OrderedDict()
|
||||||
|
self._db: aiosqlite.Connection | None = None
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Internal helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _ensure_db(self) -> aiosqlite.Connection:
|
||||||
|
"""Lazily open the database and ensure schema exists."""
|
||||||
|
if self._db is None:
|
||||||
|
db_dir = os.path.dirname(self._db_path)
|
||||||
|
if db_dir:
|
||||||
|
os.makedirs(db_dir, exist_ok=True)
|
||||||
|
self._db = await aiosqlite.connect(self._db_path)
|
||||||
|
self._db.row_factory = aiosqlite.Row
|
||||||
|
await self._db.execute("PRAGMA journal_mode=WAL")
|
||||||
|
await self._db.executescript(_SCHEMA_SQL)
|
||||||
|
await self._db.commit()
|
||||||
|
return self._db
|
||||||
|
|
||||||
|
async def _close_db(self) -> None:
|
||||||
|
"""Close the database connection."""
|
||||||
|
if self._db is not None:
|
||||||
|
await self._db.close()
|
||||||
|
self._db = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _dt_to_str(dt: datetime) -> str:
|
||||||
|
return dt.isoformat()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _str_to_dt(s: str) -> datetime:
|
||||||
|
return datetime.fromisoformat(s)
|
||||||
|
|
||||||
|
def _touch_cache(self, conv_id: str) -> None:
|
||||||
|
"""Move conversation to the end of the LRU cache (most recently used)."""
|
||||||
|
if conv_id in self._cache:
|
||||||
|
self._cache.move_to_end(conv_id)
|
||||||
|
|
||||||
|
def _evict_if_needed(self) -> None:
|
||||||
|
"""Evict oldest entry from cache if over limit (does NOT delete from SQLite)."""
|
||||||
|
while len(self._cache) > self._max:
|
||||||
|
self._cache.popitem(last=False)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Public API (matches ConversationStore interface)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def get_or_create(self, conversation_id: str | None = None) -> Conversation:
|
||||||
|
"""Get existing conversation or create a new one.
|
||||||
|
|
||||||
|
If the conversation is in cache, return it directly.
|
||||||
|
Otherwise, try to load from SQLite. If not found, create new.
|
||||||
|
Messages are loaded lazily (only when get_history is called).
|
||||||
|
"""
|
||||||
|
db = await self._ensure_db()
|
||||||
|
|
||||||
|
if conversation_id and conversation_id in self._cache:
|
||||||
|
conv = self._cache[conversation_id]
|
||||||
|
conv.updated_at = datetime.now(timezone.utc)
|
||||||
|
self._touch_cache(conv.id)
|
||||||
|
return conv
|
||||||
|
|
||||||
|
# Try loading from SQLite
|
||||||
|
if conversation_id:
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT id, created_at, updated_at FROM conversations WHERE id = ?",
|
||||||
|
(conversation_id,),
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
if row:
|
||||||
|
conv = Conversation(
|
||||||
|
id=row["id"],
|
||||||
|
created_at=self._str_to_dt(row["created_at"]),
|
||||||
|
updated_at=self._str_to_dt(row["updated_at"]),
|
||||||
|
)
|
||||||
|
self._cache[conv.id] = conv
|
||||||
|
self._evict_if_needed()
|
||||||
|
return conv
|
||||||
|
|
||||||
|
# Create new
|
||||||
|
cid = conversation_id or str(uuid.uuid4())
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
conv = Conversation(id=cid, created_at=now, updated_at=now)
|
||||||
|
await db.execute(
|
||||||
|
"INSERT INTO conversations (id, created_at, updated_at) VALUES (?, ?, ?)",
|
||||||
|
(conv.id, self._dt_to_str(conv.created_at), self._dt_to_str(conv.updated_at)),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
self._cache[conv.id] = conv
|
||||||
|
self._evict_if_needed()
|
||||||
|
return conv
|
||||||
|
|
||||||
|
async def add_message(
|
||||||
|
self,
|
||||||
|
conversation_id: str,
|
||||||
|
role: str,
|
||||||
|
content: str,
|
||||||
|
metadata: dict | None = None,
|
||||||
|
) -> ChatMessage:
|
||||||
|
"""Add a message to a conversation (both in-memory cache and SQLite)."""
|
||||||
|
db = await self._ensure_db()
|
||||||
|
|
||||||
|
# Ensure conversation exists in cache
|
||||||
|
if conversation_id not in self._cache:
|
||||||
|
# Try loading from SQLite
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT id, created_at, updated_at FROM conversations WHERE id = ?",
|
||||||
|
(conversation_id,),
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
if row:
|
||||||
|
conv = Conversation(
|
||||||
|
id=row["id"],
|
||||||
|
created_at=self._str_to_dt(row["created_at"]),
|
||||||
|
updated_at=self._str_to_dt(row["updated_at"]),
|
||||||
|
)
|
||||||
|
self._cache[conv.id] = conv
|
||||||
|
else:
|
||||||
|
raise KeyError(f"Conversation '{conversation_id}' not found")
|
||||||
|
|
||||||
|
conv = self._cache[conversation_id]
|
||||||
|
msg = ChatMessage(
|
||||||
|
role=role,
|
||||||
|
content=content,
|
||||||
|
metadata=metadata or {},
|
||||||
|
)
|
||||||
|
conv.messages.append(msg)
|
||||||
|
conv.updated_at = datetime.now(timezone.utc)
|
||||||
|
self._touch_cache(conv.id)
|
||||||
|
|
||||||
|
# Persist to SQLite
|
||||||
|
await db.execute(
|
||||||
|
"INSERT INTO messages (conversation_id, role, content, timestamp, metadata) "
|
||||||
|
"VALUES (?, ?, ?, ?, ?)",
|
||||||
|
(
|
||||||
|
conversation_id,
|
||||||
|
role,
|
||||||
|
content,
|
||||||
|
self._dt_to_str(msg.timestamp),
|
||||||
|
json.dumps(msg.metadata, default=str),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await db.execute(
|
||||||
|
"UPDATE conversations SET updated_at = ? WHERE id = ?",
|
||||||
|
(self._dt_to_str(conv.updated_at), conversation_id),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
return msg
|
||||||
|
|
||||||
|
async def get_history(self, conversation_id: str, limit: int = 50) -> list[ChatMessage]:
|
||||||
|
"""Get recent messages for a conversation.
|
||||||
|
|
||||||
|
Loads from SQLite to ensure completeness (in-memory cache may only
|
||||||
|
have a subset of messages after a restart).
|
||||||
|
"""
|
||||||
|
db = await self._ensure_db()
|
||||||
|
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT role, content, timestamp, metadata FROM messages "
|
||||||
|
"WHERE conversation_id = ? ORDER BY id DESC LIMIT ?",
|
||||||
|
(conversation_id, limit),
|
||||||
|
)
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
# Reverse because we fetched DESC but want chronological order
|
||||||
|
messages: list[ChatMessage] = []
|
||||||
|
for row in reversed(rows):
|
||||||
|
meta: dict[str, Any] = {}
|
||||||
|
try:
|
||||||
|
meta = json.loads(row["metadata"]) if row["metadata"] else {}
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
|
messages.append(
|
||||||
|
ChatMessage(
|
||||||
|
role=row["role"],
|
||||||
|
content=row["content"],
|
||||||
|
timestamp=self._str_to_dt(row["timestamp"]),
|
||||||
|
metadata=meta,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return messages
|
||||||
|
|
||||||
|
async def list_conversations(self, limit: int = 20) -> list[Conversation]:
|
||||||
|
"""List recent conversations ordered by updated_at (most recent first)."""
|
||||||
|
db = await self._ensure_db()
|
||||||
|
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT id, created_at, updated_at FROM conversations "
|
||||||
|
"ORDER BY updated_at DESC LIMIT ?",
|
||||||
|
(limit,),
|
||||||
|
)
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
result: list[Conversation] = []
|
||||||
|
for row in rows:
|
||||||
|
conv_id = row["id"]
|
||||||
|
# Use cached version if available (may have in-memory messages)
|
||||||
|
if conv_id in self._cache:
|
||||||
|
result.append(self._cache[conv_id])
|
||||||
|
else:
|
||||||
|
result.append(
|
||||||
|
Conversation(
|
||||||
|
id=conv_id,
|
||||||
|
created_at=self._str_to_dt(row["created_at"]),
|
||||||
|
updated_at=self._str_to_dt(row["updated_at"]),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def restore_from_store(
|
||||||
|
self,
|
||||||
|
max_sessions: int = 50,
|
||||||
|
max_messages_per_session: int = 100,
|
||||||
|
) -> None:
|
||||||
|
"""No-op for SQLite store — data is already persisted in the database."""
|
||||||
|
# Nothing to do; all data lives in SQLite and is loaded on demand.
|
||||||
|
|
@ -0,0 +1,171 @@
|
||||||
|
"""Spec Manager — 执行计划规格文档管理器
|
||||||
|
|
||||||
|
将 PlanExecEngine 生成的执行计划持久化为 Spec 文档,
|
||||||
|
用户可在执行前查看、编辑和确认 Spec,作为人与 AI 之间的契约。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SpecStep:
|
||||||
|
"""A single step in a spec."""
|
||||||
|
|
||||||
|
step_id: str
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
dependencies: list[str] = field(default_factory=list)
|
||||||
|
status: str = "pending" # pending | confirmed | executing | completed | failed
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Spec:
|
||||||
|
"""A specification document for a planned task."""
|
||||||
|
|
||||||
|
spec_id: str
|
||||||
|
goal: str
|
||||||
|
steps: list[SpecStep] = field(default_factory=list)
|
||||||
|
status: str = "draft" # draft | confirmed | executing | completed | failed
|
||||||
|
created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||||
|
confirmed_at: str | None = None
|
||||||
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class SpecManager:
|
||||||
|
"""Manages Spec documents as first-class citizens.
|
||||||
|
|
||||||
|
Specs are persisted to .agentkit/specs/ directory as YAML files.
|
||||||
|
Users can view, edit, and confirm specs before execution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, specs_dir: str | None = None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
specs_dir: Directory for spec files. Default: .agentkit/specs/
|
||||||
|
"""
|
||||||
|
self._specs_dir = Path(specs_dir or ".agentkit/specs")
|
||||||
|
self._specs_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._cache: dict[str, Spec] = {}
|
||||||
|
|
||||||
|
def create(self, spec: Spec) -> Path:
|
||||||
|
"""Persist a Spec to disk. Returns the file path."""
|
||||||
|
path = self._specs_dir / f"{spec.spec_id}.yaml"
|
||||||
|
data = asdict(spec)
|
||||||
|
path.write_text(yaml.dump(data, allow_unicode=True, default_flow_style=False), encoding="utf-8")
|
||||||
|
self._cache[spec.spec_id] = spec
|
||||||
|
logger.info(f"Spec created: {spec.spec_id} -> {path}")
|
||||||
|
return path
|
||||||
|
|
||||||
|
def get(self, spec_id: str) -> Spec | None:
|
||||||
|
"""Load a Spec by ID."""
|
||||||
|
if spec_id in self._cache:
|
||||||
|
return self._cache[spec_id]
|
||||||
|
|
||||||
|
path = self._specs_dir / f"{spec_id}.yaml"
|
||||||
|
if not path.exists():
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = yaml.safe_load(path.read_text(encoding="utf-8"))
|
||||||
|
spec = self._dict_to_spec(data)
|
||||||
|
self._cache[spec_id] = spec
|
||||||
|
return spec
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load spec {spec_id}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def update(self, spec_id: str, **kwargs: Any) -> Spec | None:
|
||||||
|
"""Update spec fields and persist."""
|
||||||
|
spec = self.get(spec_id)
|
||||||
|
if spec is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if key == "steps" and isinstance(value, list):
|
||||||
|
spec.steps = [self._dict_to_step(s) if isinstance(s, dict) else s for s in value]
|
||||||
|
elif hasattr(spec, key):
|
||||||
|
setattr(spec, key, value)
|
||||||
|
|
||||||
|
self.create(spec) # re-persist
|
||||||
|
return spec
|
||||||
|
|
||||||
|
def confirm(self, spec_id: str) -> Spec | None:
|
||||||
|
"""Mark a spec as confirmed (user approved execution)."""
|
||||||
|
spec = self.get(spec_id)
|
||||||
|
if spec is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
spec.status = "confirmed"
|
||||||
|
spec.confirmed_at = datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
|
# Mark all steps as confirmed
|
||||||
|
for step in spec.steps:
|
||||||
|
if step.status == "pending":
|
||||||
|
step.status = "confirmed"
|
||||||
|
|
||||||
|
self.create(spec) # re-persist
|
||||||
|
logger.info(f"Spec confirmed: {spec_id}")
|
||||||
|
return spec
|
||||||
|
|
||||||
|
def list_specs(self, status: str | None = None) -> list[Spec]:
|
||||||
|
"""List all specs, optionally filtered by status. Sorted by created_at desc."""
|
||||||
|
specs: list[Spec] = []
|
||||||
|
for path in self._specs_dir.glob("*.yaml"):
|
||||||
|
try:
|
||||||
|
data = yaml.safe_load(path.read_text(encoding="utf-8"))
|
||||||
|
spec = self._dict_to_spec(data)
|
||||||
|
if status is None or spec.status == status:
|
||||||
|
specs.append(spec)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to load spec from {path}: {e}")
|
||||||
|
|
||||||
|
specs.sort(key=lambda s: s.created_at, reverse=True)
|
||||||
|
return specs
|
||||||
|
|
||||||
|
def delete(self, spec_id: str) -> bool:
|
||||||
|
"""Delete a spec file."""
|
||||||
|
path = self._specs_dir / f"{spec_id}.yaml"
|
||||||
|
if not path.exists():
|
||||||
|
return False
|
||||||
|
|
||||||
|
path.unlink()
|
||||||
|
self._cache.pop(spec_id, None)
|
||||||
|
logger.info(f"Spec deleted: {spec_id}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _dict_to_spec(data: dict[str, Any]) -> Spec:
|
||||||
|
"""Convert a dict to a Spec instance."""
|
||||||
|
steps = [SpecManager._dict_to_step(s) for s in data.get("steps", [])]
|
||||||
|
return Spec(
|
||||||
|
spec_id=data["spec_id"],
|
||||||
|
goal=data["goal"],
|
||||||
|
steps=steps,
|
||||||
|
status=data.get("status", "draft"),
|
||||||
|
created_at=data.get("created_at", ""),
|
||||||
|
confirmed_at=data.get("confirmed_at"),
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _dict_to_step(data: dict[str, Any] | SpecStep) -> SpecStep:
|
||||||
|
"""Convert a dict to a SpecStep instance."""
|
||||||
|
if isinstance(data, SpecStep):
|
||||||
|
return data
|
||||||
|
return SpecStep(
|
||||||
|
step_id=data["step_id"],
|
||||||
|
name=data["name"],
|
||||||
|
description=data["description"],
|
||||||
|
dependencies=data.get("dependencies", []),
|
||||||
|
status=data.get("status", "pending"),
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,145 @@
|
||||||
|
"""验证循环 - 执行后运行项目测试验证结果
|
||||||
|
|
||||||
|
遵循 Codex Cloud 模式:execute → verify → retry on failure。
|
||||||
|
默认验证命令:pytest, ruff check。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VerificationResult:
|
||||||
|
"""验证循环运行结果"""
|
||||||
|
|
||||||
|
passed: bool
|
||||||
|
attempts: int
|
||||||
|
test_output: str
|
||||||
|
errors: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class VerificationLoop:
|
||||||
|
"""执行后运行项目测试验证结果
|
||||||
|
|
||||||
|
遵循 Codex Cloud 模式:execute → verify → retry on failure。
|
||||||
|
默认验证命令:pytest, ruff check。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
commands: list[str] | None = None,
|
||||||
|
max_retries: int = 2,
|
||||||
|
working_dir: str | None = None,
|
||||||
|
timeout: float = 60.0,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
commands: 验证用的 Shell 命令列表。
|
||||||
|
默认: ["pytest -x -q", "ruff check src/"]
|
||||||
|
max_retries: 初始执行后的最大重试次数。
|
||||||
|
working_dir: 运行命令的工作目录。
|
||||||
|
timeout: 每个验证命令的超时时间(秒)。
|
||||||
|
"""
|
||||||
|
self._commands = commands or ["pytest -x -q", "ruff check src/"]
|
||||||
|
self._max_retries = max_retries
|
||||||
|
self._working_dir = working_dir
|
||||||
|
self._timeout = timeout
|
||||||
|
|
||||||
|
async def verify(self) -> VerificationResult:
|
||||||
|
"""运行验证命令并返回结果
|
||||||
|
|
||||||
|
依次执行每个命令,捕获 stdout/stderr。
|
||||||
|
如果任何命令失败(非零退出码),标记为失败。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
VerificationResult,所有命令成功时 passed=True。
|
||||||
|
"""
|
||||||
|
all_output: list[str] = []
|
||||||
|
errors: list[str] = []
|
||||||
|
passed = True
|
||||||
|
|
||||||
|
for cmd in self._commands:
|
||||||
|
try:
|
||||||
|
proc = await asyncio.create_subprocess_shell(
|
||||||
|
cmd,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.STDOUT,
|
||||||
|
cwd=self._working_dir,
|
||||||
|
)
|
||||||
|
stdout, _ = await asyncio.wait_for(
|
||||||
|
proc.communicate(),
|
||||||
|
timeout=self._timeout,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
try:
|
||||||
|
proc.kill()
|
||||||
|
except ProcessLookupError:
|
||||||
|
pass
|
||||||
|
await proc.wait()
|
||||||
|
output = f"Command timed out after {self._timeout}s: {cmd}"
|
||||||
|
all_output.append(output)
|
||||||
|
errors.append(output)
|
||||||
|
passed = False
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
output = f"Command failed to execute: {cmd}: {e}"
|
||||||
|
all_output.append(output)
|
||||||
|
errors.append(output)
|
||||||
|
passed = False
|
||||||
|
continue
|
||||||
|
|
||||||
|
output = stdout.decode("utf-8", errors="replace") if stdout else ""
|
||||||
|
all_output.append(f"$ {cmd}\n{output}")
|
||||||
|
|
||||||
|
if proc.returncode != 0:
|
||||||
|
passed = False
|
||||||
|
errors.append(f"Command failed (exit {proc.returncode}): {cmd}\n{output}")
|
||||||
|
|
||||||
|
return VerificationResult(
|
||||||
|
passed=passed,
|
||||||
|
attempts=1,
|
||||||
|
test_output="\n".join(all_output),
|
||||||
|
errors=errors,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def verify_and_retry(
|
||||||
|
self,
|
||||||
|
fix_callback: Any = None,
|
||||||
|
) -> VerificationResult:
|
||||||
|
"""运行验证,如果失败则调用 fix_callback 并重试
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fix_callback: 异步可调用对象,接收 (errors, test_output),
|
||||||
|
尝试修复问题。在重试之间调用。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
VerificationResult,包含最终验证状态。
|
||||||
|
"""
|
||||||
|
result = await self.verify()
|
||||||
|
|
||||||
|
retries = 0
|
||||||
|
while not result.passed and retries < self._max_retries:
|
||||||
|
retries += 1
|
||||||
|
logger.info(
|
||||||
|
"Verification failed (attempt %d/%d), %s",
|
||||||
|
retries,
|
||||||
|
self._max_retries,
|
||||||
|
"calling fix_callback" if fix_callback else "retrying",
|
||||||
|
)
|
||||||
|
|
||||||
|
if fix_callback is not None:
|
||||||
|
try:
|
||||||
|
await fix_callback(result.errors, result.test_output)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("fix_callback raised an error: %s", e)
|
||||||
|
|
||||||
|
result = await self.verify()
|
||||||
|
result.attempts = retries + 1
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
@ -22,13 +22,14 @@ from pydantic import BaseModel
|
||||||
from agentkit.core.config_driven import ConfigDrivenAgent
|
from agentkit.core.config_driven import ConfigDrivenAgent
|
||||||
from agentkit.core.react import ReActEngine
|
from agentkit.core.react import ReActEngine
|
||||||
from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult
|
from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult
|
||||||
from agentkit.chat.simple_router import SimpleRouter
|
from agentkit.chat.request_preprocessor import RequestPreprocessor
|
||||||
from agentkit.server.routes.evolution_dashboard import (
|
from agentkit.server.routes.evolution_dashboard import (
|
||||||
_experiences as _dashboard_experiences,
|
_experiences as _dashboard_experiences,
|
||||||
DashboardExperience,
|
DashboardExperience,
|
||||||
_broadcast_event as _broadcast_dashboard_event,
|
_broadcast_event as _broadcast_dashboard_event,
|
||||||
)
|
)
|
||||||
from agentkit.core.fallback import EMPTY_LLM_RESPONSE
|
from agentkit.core.fallback import EMPTY_LLM_RESPONSE
|
||||||
|
from agentkit.chat.sqlite_conversation_store import SqliteConversationStore
|
||||||
from agentkit.session.manager import SessionManager
|
from agentkit.session.manager import SessionManager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -103,7 +104,9 @@ class ConversationStore:
|
||||||
conversations can be restored from SessionManager.
|
conversations can be restored from SessionManager.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, max_conversations: int = 1000, session_manager: SessionManager | None = None):
|
def __init__(
|
||||||
|
self, max_conversations: int = 1000, session_manager: SessionManager | None = None
|
||||||
|
):
|
||||||
self._conversations: dict[str, Conversation] = {}
|
self._conversations: dict[str, Conversation] = {}
|
||||||
self._max = max_conversations
|
self._max = max_conversations
|
||||||
self._session_manager = session_manager
|
self._session_manager = session_manager
|
||||||
|
|
@ -141,16 +144,16 @@ class ConversationStore:
|
||||||
sid, limit=max_messages_per_session
|
sid, limit=max_messages_per_session
|
||||||
)
|
)
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
conv.messages.append(ChatMessage(
|
conv.messages.append(
|
||||||
role=msg.role.value,
|
ChatMessage(
|
||||||
content=msg.content,
|
role=msg.role.value,
|
||||||
timestamp=msg.created_at,
|
content=msg.content,
|
||||||
metadata=msg.metadata,
|
timestamp=msg.created_at,
|
||||||
))
|
metadata=msg.metadata,
|
||||||
|
)
|
||||||
|
)
|
||||||
self._conversations[sid] = conv
|
self._conversations[sid] = conv
|
||||||
logger.info(
|
logger.info(f"Restored {len(self._conversations)} conversations from SessionManager")
|
||||||
f"Restored {len(self._conversations)} conversations from SessionManager"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to restore conversations from SessionManager: {e}")
|
logger.warning(f"Failed to restore conversations from SessionManager: {e}")
|
||||||
|
|
||||||
|
|
@ -217,7 +220,7 @@ class ConversationStore:
|
||||||
|
|
||||||
# Heartbeat timeout in seconds — 0 disables timeout (for testing)
|
# Heartbeat timeout in seconds — 0 disables timeout (for testing)
|
||||||
_WS_HEARTBEAT_TIMEOUT = float(os.environ.get("AGENTKIT_WS_TIMEOUT", "120"))
|
_WS_HEARTBEAT_TIMEOUT = float(os.environ.get("AGENTKIT_WS_TIMEOUT", "120"))
|
||||||
_conversation_store = ConversationStore()
|
_conversation_store = SqliteConversationStore()
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# History injection helper — configurable limit + optional compression
|
# History injection helper — configurable limit + optional compression
|
||||||
|
|
@ -227,7 +230,7 @@ _conversation_store = ConversationStore()
|
||||||
_MAX_HISTORY_MESSAGES = 50
|
_MAX_HISTORY_MESSAGES = 50
|
||||||
|
|
||||||
|
|
||||||
def _build_history_messages(
|
async def _build_history_messages(
|
||||||
conv_id: str,
|
conv_id: str,
|
||||||
limit: int = _MAX_HISTORY_MESSAGES,
|
limit: int = _MAX_HISTORY_MESSAGES,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
|
|
@ -238,7 +241,7 @@ def _build_history_messages(
|
||||||
which should be appended separately by the caller).
|
which should be appended separately by the caller).
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
history = _conversation_store.get_history(conv_id, limit=limit)
|
history = await _conversation_store.get_history(conv_id, limit=limit)
|
||||||
except Exception:
|
except Exception:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -342,14 +345,16 @@ class CapabilitiesResponse(BaseModel):
|
||||||
|
|
||||||
async def _resolve_for_chat(
|
async def _resolve_for_chat(
|
||||||
request: ChatRequest, req: Request
|
request: ChatRequest, req: Request
|
||||||
) -> tuple[ConfigDrivenAgent | None, SkillRoutingResult | None, str | None, str | None, float | None]:
|
) -> tuple[
|
||||||
"""Resolve agent and routing for a chat request via SimpleRouter.
|
ConfigDrivenAgent | None, SkillRoutingResult | None, str | None, str | None, float | None
|
||||||
|
]:
|
||||||
|
"""Resolve agent and routing for a chat request via RequestPreprocessor.
|
||||||
|
|
||||||
Returns (agent, routing_result, matched_skill_name, routing_method, confidence).
|
Returns (agent, routing_result, matched_skill_name, routing_method, confidence).
|
||||||
"""
|
"""
|
||||||
pool = req.app.state.agent_pool
|
pool = req.app.state.agent_pool
|
||||||
skill_registry = req.app.state.skill_registry
|
skill_registry = req.app.state.skill_registry
|
||||||
simple_router: SimpleRouter = req.app.state.simple_router
|
request_preprocessor: RequestPreprocessor = req.app.state.request_preprocessor
|
||||||
|
|
||||||
matched_skill_name: str | None = None
|
matched_skill_name: str | None = None
|
||||||
routing_method: str | None = None
|
routing_method: str | None = None
|
||||||
|
|
@ -362,8 +367,7 @@ async def _resolve_for_chat(
|
||||||
if default_agent is not None:
|
if default_agent is not None:
|
||||||
default_tools = default_agent.get_tools()
|
default_tools = default_agent.get_tools()
|
||||||
default_system_prompt = (
|
default_system_prompt = (
|
||||||
getattr(default_agent, "_system_prompt", None)
|
getattr(default_agent, "_system_prompt", None) or default_agent.get_system_prompt()
|
||||||
or default_agent.get_system_prompt()
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
all_skills = skill_registry.list_skills()
|
all_skills = skill_registry.list_skills()
|
||||||
|
|
@ -376,15 +380,26 @@ async def _resolve_for_chat(
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
# Route via SimpleRouter (minimal routing: @skill prefix + greeting regex + REACT)
|
# If skill_name is explicitly provided in the request, use it directly
|
||||||
routing_result = await simple_router.route(
|
if request.skill_name:
|
||||||
content=request.message,
|
routing_result = await request_preprocessor.preprocess(
|
||||||
skill_registry=skill_registry,
|
content=f"@skill:{request.skill_name} {request.message}",
|
||||||
default_tools=default_tools,
|
skill_registry=skill_registry,
|
||||||
default_system_prompt=default_system_prompt,
|
default_tools=default_tools,
|
||||||
default_model="default",
|
default_system_prompt=default_system_prompt,
|
||||||
default_agent_name="default",
|
default_model="default",
|
||||||
)
|
default_agent_name="default",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Preprocess via RequestPreprocessor (minimal: @skill prefix + greeting regex + REACT)
|
||||||
|
routing_result = await request_preprocessor.preprocess(
|
||||||
|
content=request.message,
|
||||||
|
skill_registry=skill_registry,
|
||||||
|
default_tools=default_tools,
|
||||||
|
default_system_prompt=default_system_prompt,
|
||||||
|
default_model="default",
|
||||||
|
default_agent_name="default",
|
||||||
|
)
|
||||||
|
|
||||||
matched_skill_name = routing_result.skill_name or routing_result.agent_name
|
matched_skill_name = routing_result.skill_name or routing_result.agent_name
|
||||||
routing_method = routing_result.match_method
|
routing_method = routing_result.match_method
|
||||||
|
|
@ -413,11 +428,19 @@ async def _resolve_for_chat(
|
||||||
|
|
||||||
@router.post("/portal/chat", response_model=ChatResponse)
|
@router.post("/portal/chat", response_model=ChatResponse)
|
||||||
async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify_api_key)):
|
async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify_api_key)):
|
||||||
"""Send a chat message and get a response with CostAwareRouter routing."""
|
"""Send a chat message and get a response with RequestPreprocessor routing."""
|
||||||
agent, routing_result, matched_skill, routing_method, confidence = await _resolve_for_chat(request, req)
|
# If skill_name is explicitly requested but not found, return 404
|
||||||
|
if request.skill_name:
|
||||||
|
skill_registry = req.app.state.skill_registry
|
||||||
|
if not skill_registry.has_skill(request.skill_name):
|
||||||
|
raise HTTPException(status_code=404, detail=f"Skill '{request.skill_name}' not found")
|
||||||
|
|
||||||
|
agent, routing_result, matched_skill, routing_method, confidence = await _resolve_for_chat(
|
||||||
|
request, req
|
||||||
|
)
|
||||||
|
|
||||||
# Create or reuse conversation
|
# Create or reuse conversation
|
||||||
conv = _conversation_store.get_or_create(request.conversation_id)
|
conv = await _conversation_store.get_or_create(request.conversation_id)
|
||||||
await _conversation_store.add_message(conv.id, "user", request.message)
|
await _conversation_store.add_message(conv.id, "user", request.message)
|
||||||
|
|
||||||
llm_gateway = req.app.state.llm_gateway
|
llm_gateway = req.app.state.llm_gateway
|
||||||
|
|
@ -432,7 +455,7 @@ async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify
|
||||||
chat_messages.append({"role": "system", "content": routing_result.system_prompt})
|
chat_messages.append({"role": "system", "content": routing_result.system_prompt})
|
||||||
chat_messages.append({"role": "user", "content": request.message})
|
chat_messages.append({"role": "user", "content": request.message})
|
||||||
# Inject conversation history
|
# Inject conversation history
|
||||||
history_msgs = _build_history_messages(conv.id)
|
history_msgs = await _build_history_messages(conv.id)
|
||||||
for hm in history_msgs:
|
for hm in history_msgs:
|
||||||
chat_messages.insert(-1, hm)
|
chat_messages.insert(-1, hm)
|
||||||
response = await llm_gateway.chat(
|
response = await llm_gateway.chat(
|
||||||
|
|
@ -467,7 +490,7 @@ async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify
|
||||||
|
|
||||||
messages = [{"role": "user", "content": request.message}]
|
messages = [{"role": "user", "content": request.message}]
|
||||||
# Inject conversation history
|
# Inject conversation history
|
||||||
history_msgs = _build_history_messages(conv.id)
|
history_msgs = await _build_history_messages(conv.id)
|
||||||
for hm in reversed(history_msgs):
|
for hm in reversed(history_msgs):
|
||||||
messages.insert(0, hm)
|
messages.insert(0, hm)
|
||||||
tools = agent.get_tools()
|
tools = agent.get_tools()
|
||||||
|
|
@ -490,7 +513,9 @@ async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
response_text = f"执行出错: {e}"
|
response_text = f"执行出错: {e}"
|
||||||
else:
|
else:
|
||||||
response_text = _ensure_non_empty("".join(collected_output) if collected_output else None)
|
response_text = _ensure_non_empty(
|
||||||
|
"".join(collected_output) if collected_output else None
|
||||||
|
)
|
||||||
|
|
||||||
await _conversation_store.add_message(conv.id, "assistant", response_text)
|
await _conversation_store.add_message(conv.id, "assistant", response_text)
|
||||||
|
|
||||||
|
|
@ -508,13 +533,15 @@ async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify
|
||||||
|
|
||||||
@router.post("/portal/chat/stream")
|
@router.post("/portal/chat/stream")
|
||||||
async def chat_stream(request: ChatRequest, req: Request, _auth: None = Depends(_verify_api_key)):
|
async def chat_stream(request: ChatRequest, req: Request, _auth: None = Depends(_verify_api_key)):
|
||||||
"""Stream chat responses via SSE with CostAwareRouter routing."""
|
"""Stream chat responses via SSE with RequestPreprocessor routing."""
|
||||||
from sse_starlette.sse import EventSourceResponse
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
|
||||||
agent, routing_result, matched_skill, routing_method, confidence = await _resolve_for_chat(request, req)
|
agent, routing_result, matched_skill, routing_method, confidence = await _resolve_for_chat(
|
||||||
|
request, req
|
||||||
|
)
|
||||||
|
|
||||||
# Create or reuse conversation
|
# Create or reuse conversation
|
||||||
conv = _conversation_store.get_or_create(request.conversation_id)
|
conv = await _conversation_store.get_or_create(request.conversation_id)
|
||||||
await _conversation_store.add_message(conv.id, "user", request.message)
|
await _conversation_store.add_message(conv.id, "user", request.message)
|
||||||
|
|
||||||
llm_gateway = req.app.state.llm_gateway
|
llm_gateway = req.app.state.llm_gateway
|
||||||
|
|
@ -532,13 +559,16 @@ async def chat_stream(request: ChatRequest, req: Request, _auth: None = Depends(
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
if routing_result is not None and routing_result.execution_mode == ExecutionMode.DIRECT_CHAT:
|
if (
|
||||||
|
routing_result is not None
|
||||||
|
and routing_result.execution_mode == ExecutionMode.DIRECT_CHAT
|
||||||
|
):
|
||||||
# DIRECT_CHAT: direct LLM call, no ReAct loop
|
# DIRECT_CHAT: direct LLM call, no ReAct loop
|
||||||
chat_messages = []
|
chat_messages = []
|
||||||
if routing_result.system_prompt:
|
if routing_result.system_prompt:
|
||||||
chat_messages.append({"role": "system", "content": routing_result.system_prompt})
|
chat_messages.append({"role": "system", "content": routing_result.system_prompt})
|
||||||
chat_messages.append({"role": "user", "content": request.message})
|
chat_messages.append({"role": "user", "content": request.message})
|
||||||
history_msgs = _build_history_messages(conv.id)
|
history_msgs = await _build_history_messages(conv.id)
|
||||||
for hm in history_msgs:
|
for hm in history_msgs:
|
||||||
chat_messages.insert(-1, hm)
|
chat_messages.insert(-1, hm)
|
||||||
response = await llm_gateway.chat(
|
response = await llm_gateway.chat(
|
||||||
|
|
@ -552,7 +582,11 @@ async def chat_stream(request: ChatRequest, req: Request, _auth: None = Depends(
|
||||||
yield {
|
yield {
|
||||||
"event": "final_answer",
|
"event": "final_answer",
|
||||||
"data": json.dumps(
|
"data": json.dumps(
|
||||||
{"step": 0, "data": {"output": response_text}, "timestamp": datetime.now(timezone.utc).isoformat()}
|
{
|
||||||
|
"step": 0,
|
||||||
|
"data": {"output": response_text},
|
||||||
|
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||||
|
}
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
|
|
@ -654,7 +688,7 @@ async def get_capabilities(req: Request, _auth: None = Depends(_verify_api_key))
|
||||||
@router.get("/portal/conversations")
|
@router.get("/portal/conversations")
|
||||||
async def list_conversations(limit: int = 20, _auth: None = Depends(_verify_api_key)):
|
async def list_conversations(limit: int = 20, _auth: None = Depends(_verify_api_key)):
|
||||||
"""List recent conversations."""
|
"""List recent conversations."""
|
||||||
convs = _conversation_store.list_conversations(limit=limit)
|
convs = await _conversation_store.list_conversations(limit=limit)
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"id": c.id,
|
"id": c.id,
|
||||||
|
|
@ -679,55 +713,27 @@ def _derive_conversation_title(conv: Conversation) -> str:
|
||||||
async def get_conversation(
|
async def get_conversation(
|
||||||
conversation_id: str, limit: int = 50, _auth: None = Depends(_verify_api_key)
|
conversation_id: str, limit: int = 50, _auth: None = Depends(_verify_api_key)
|
||||||
):
|
):
|
||||||
"""Get conversation history, with fallback to SessionManager for persisted data."""
|
"""Get conversation history from SQLite-backed store."""
|
||||||
# Try in-memory first
|
history = await _conversation_store.get_history(conversation_id, limit=limit)
|
||||||
if conversation_id in _conversation_store._conversations:
|
if not history:
|
||||||
conv = _conversation_store._conversations[conversation_id]
|
raise HTTPException(status_code=404, detail=f"Conversation '{conversation_id}' not found")
|
||||||
history = _conversation_store.get_history(conversation_id, limit=limit)
|
conv = await _conversation_store.get_or_create(conversation_id)
|
||||||
return {
|
return {
|
||||||
"id": conv.id,
|
"id": conv.id,
|
||||||
"title": _derive_conversation_title(conv),
|
"title": _derive_conversation_title(conv),
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
"id": f"{conv.id}-{i}",
|
"id": f"{conv.id}-{i}",
|
||||||
"role": m.role,
|
"role": m.role,
|
||||||
"content": m.content,
|
"content": m.content,
|
||||||
"timestamp": m.timestamp.isoformat(),
|
"timestamp": m.timestamp.isoformat(),
|
||||||
"metadata": m.metadata,
|
"metadata": m.metadata,
|
||||||
}
|
}
|
||||||
for i, m in enumerate(history)
|
for i, m in enumerate(history)
|
||||||
],
|
],
|
||||||
"created_at": conv.created_at.isoformat(),
|
"created_at": conv.created_at.isoformat(),
|
||||||
"updated_at": conv.updated_at.isoformat(),
|
"updated_at": conv.updated_at.isoformat(),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Fallback: load from SessionManager (persistent store)
|
|
||||||
sm = _conversation_store._session_manager
|
|
||||||
if sm is not None:
|
|
||||||
try:
|
|
||||||
session = await sm.get_session(conversation_id)
|
|
||||||
if session is not None:
|
|
||||||
messages = await sm.get_messages(conversation_id, limit=limit)
|
|
||||||
return {
|
|
||||||
"id": session.session_id,
|
|
||||||
"title": _derive_title_from_messages(messages),
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"id": f"{session.session_id}-{i}",
|
|
||||||
"role": m.role.value,
|
|
||||||
"content": m.content,
|
|
||||||
"timestamp": m.created_at.isoformat(),
|
|
||||||
"metadata": m.metadata,
|
|
||||||
}
|
|
||||||
for i, m in enumerate(messages)
|
|
||||||
],
|
|
||||||
"created_at": session.created_at.isoformat(),
|
|
||||||
"updated_at": session.updated_at.isoformat(),
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to load conversation from SessionManager: {e}")
|
|
||||||
|
|
||||||
raise HTTPException(status_code=404, detail=f"Conversation '{conversation_id}' not found")
|
|
||||||
|
|
||||||
|
|
||||||
def _derive_title_from_messages(messages: list) -> str:
|
def _derive_title_from_messages(messages: list) -> str:
|
||||||
|
|
@ -807,7 +813,7 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
# Create conversation on first message (not on connect)
|
# Create conversation on first message (not on connect)
|
||||||
if conv is None:
|
if conv is None:
|
||||||
conv_id = msg.get("conversation_id")
|
conv_id = msg.get("conversation_id")
|
||||||
conv = _conversation_store.get_or_create(conv_id)
|
conv = await _conversation_store.get_or_create(conv_id)
|
||||||
await websocket.send_json({"type": "connected", "conversation_id": conv.id})
|
await websocket.send_json({"type": "connected", "conversation_id": conv.id})
|
||||||
|
|
||||||
# Add user message to conversation
|
# Add user message to conversation
|
||||||
|
|
@ -841,15 +847,15 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to record experience: {e}")
|
logger.warning(f"Failed to record experience: {e}")
|
||||||
|
|
||||||
# Unified routing via SimpleRouter (minimal: @skill prefix + greeting regex + REACT)
|
# Unified preprocessing via RequestPreprocessor (minimal: @skill prefix + greeting regex + REACT)
|
||||||
pool = websocket.app.state.agent_pool
|
pool = websocket.app.state.agent_pool
|
||||||
skill_registry = websocket.app.state.skill_registry
|
skill_registry = websocket.app.state.skill_registry
|
||||||
llm_gateway = websocket.app.state.llm_gateway
|
llm_gateway = websocket.app.state.llm_gateway
|
||||||
simple_router: SimpleRouter = websocket.app.state.simple_router
|
request_preprocessor: RequestPreprocessor = websocket.app.state.request_preprocessor
|
||||||
|
|
||||||
all_skills = skill_registry.list_skills()
|
all_skills = skill_registry.list_skills()
|
||||||
|
|
||||||
# Get default tools for SimpleRouter routing
|
# Get default tools for RequestPreprocessor
|
||||||
default_tools = []
|
default_tools = []
|
||||||
default_system_prompt = None
|
default_system_prompt = None
|
||||||
default_agent = pool.get_agent("default")
|
default_agent = pool.get_agent("default")
|
||||||
|
|
@ -869,8 +875,8 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
# Route via SimpleRouter (minimal routing: @skill prefix + greeting regex + REACT)
|
# Preprocess via RequestPreprocessor (minimal: @skill prefix + greeting regex + REACT)
|
||||||
routing_result = await simple_router.route(
|
routing_result = await request_preprocessor.preprocess(
|
||||||
content=message_text,
|
content=message_text,
|
||||||
skill_registry=skill_registry,
|
skill_registry=skill_registry,
|
||||||
default_tools=default_tools,
|
default_tools=default_tools,
|
||||||
|
|
@ -901,7 +907,7 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
)
|
)
|
||||||
chat_messages.append({"role": "user", "content": message_text})
|
chat_messages.append({"role": "user", "content": message_text})
|
||||||
# Inject conversation history for context continuity
|
# Inject conversation history for context continuity
|
||||||
history_msgs = _build_history_messages(conv.id)
|
history_msgs = await _build_history_messages(conv.id)
|
||||||
for hm in history_msgs:
|
for hm in history_msgs:
|
||||||
chat_messages.insert(-1, hm)
|
chat_messages.insert(-1, hm)
|
||||||
response = await llm_gateway.chat(
|
response = await llm_gateway.chat(
|
||||||
|
|
@ -960,7 +966,7 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
)
|
)
|
||||||
chat_messages.append({"role": "user", "content": message_text})
|
chat_messages.append({"role": "user", "content": message_text})
|
||||||
try:
|
try:
|
||||||
history = _conversation_store.get_history(conv.id, limit=20)
|
history = await _conversation_store.get_history(conv.id, limit=20)
|
||||||
for hist_msg in history[:-1]:
|
for hist_msg in history[:-1]:
|
||||||
if hist_msg.role in ("user", "assistant"):
|
if hist_msg.role in ("user", "assistant"):
|
||||||
chat_messages.insert(
|
chat_messages.insert(
|
||||||
|
|
@ -1009,7 +1015,7 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
|
|
||||||
messages = [{"role": "user", "content": message_text}]
|
messages = [{"role": "user", "content": message_text}]
|
||||||
# Inject conversation history for context continuity
|
# Inject conversation history for context continuity
|
||||||
history_msgs = _build_history_messages(conv.id)
|
history_msgs = await _build_history_messages(conv.id)
|
||||||
for hm in reversed(history_msgs):
|
for hm in reversed(history_msgs):
|
||||||
messages.insert(0, hm)
|
messages.insert(0, hm)
|
||||||
tools = agent.get_tools()
|
tools = agent.get_tools()
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
|
from dataclasses import asdict
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
|
@ -9,6 +10,7 @@ from pydantic import BaseModel
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from agentkit.core.protocol import TaskMessage, TaskStatus
|
from agentkit.core.protocol import TaskMessage, TaskStatus
|
||||||
|
from agentkit.core.spec_manager import SpecManager
|
||||||
|
|
||||||
router = APIRouter(tags=["tasks"])
|
router = APIRouter(tags=["tasks"])
|
||||||
|
|
||||||
|
|
@ -350,3 +352,71 @@ async def stream_task(request: SubmitTaskRequest, req: Request):
|
||||||
}
|
}
|
||||||
|
|
||||||
return EventSourceResponse(event_generator())
|
return EventSourceResponse(event_generator())
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Spec endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _get_spec_manager(req: Request) -> SpecManager:
|
||||||
|
"""Get or create SpecManager from app state."""
|
||||||
|
if not hasattr(req.app.state, "spec_manager") or req.app.state.spec_manager is None:
|
||||||
|
req.app.state.spec_manager = SpecManager()
|
||||||
|
return req.app.state.spec_manager
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateSpecRequest(BaseModel):
|
||||||
|
"""Request body for updating a spec."""
|
||||||
|
|
||||||
|
goal: str | None = None
|
||||||
|
steps: list[dict[str, Any]] | None = None
|
||||||
|
status: str | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/specs")
|
||||||
|
async def list_specs(status: str | None = None, req: Request = None):
|
||||||
|
"""List all specs, optionally filtered by status."""
|
||||||
|
mgr = _get_spec_manager(req)
|
||||||
|
specs = mgr.list_specs(status=status)
|
||||||
|
return [asdict(s) for s in specs]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/specs/{spec_id}")
|
||||||
|
async def get_spec(spec_id: str, req: Request):
|
||||||
|
"""Get a specific spec by ID."""
|
||||||
|
mgr = _get_spec_manager(req)
|
||||||
|
spec = mgr.get(spec_id)
|
||||||
|
if spec is None:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Spec '{spec_id}' not found")
|
||||||
|
return asdict(spec)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/specs/{spec_id}")
|
||||||
|
async def update_spec(spec_id: str, body: UpdateSpecRequest, req: Request):
|
||||||
|
"""Update a spec (e.g., edit steps, change status)."""
|
||||||
|
mgr = _get_spec_manager(req)
|
||||||
|
kwargs: dict[str, Any] = {}
|
||||||
|
if body.goal is not None:
|
||||||
|
kwargs["goal"] = body.goal
|
||||||
|
if body.steps is not None:
|
||||||
|
kwargs["steps"] = body.steps
|
||||||
|
if body.status is not None:
|
||||||
|
kwargs["status"] = body.status
|
||||||
|
if body.metadata is not None:
|
||||||
|
kwargs["metadata"] = body.metadata
|
||||||
|
spec = mgr.update(spec_id, **kwargs)
|
||||||
|
if spec is None:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Spec '{spec_id}' not found")
|
||||||
|
return asdict(spec)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/specs/{spec_id}/confirm")
|
||||||
|
async def confirm_spec(spec_id: str, req: Request):
|
||||||
|
"""Confirm a spec for execution."""
|
||||||
|
mgr = _get_spec_manager(req)
|
||||||
|
spec = mgr.confirm(spec_id)
|
||||||
|
if spec is None:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Spec '{spec_id}' not found")
|
||||||
|
return asdict(spec)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,185 @@
|
||||||
|
"""内置工具 - 开箱即用的常用工具集合"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from agentkit.tools.search import ToolSearchIndex
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RunTestsTool(Tool):
|
||||||
|
"""运行项目测试验证代码变更
|
||||||
|
|
||||||
|
执行 pytest 和 linting 命令来验证代码变更的正确性。
|
||||||
|
内部使用 VerificationLoop 实现验证逻辑。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str = "run_tests",
|
||||||
|
description: str = "Run project tests to verify code changes. Executes pytest and linting commands.",
|
||||||
|
input_schema: dict[str, Any] | None = None,
|
||||||
|
output_schema: dict[str, Any] | None = None,
|
||||||
|
version: str = "1.0.0",
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
working_dir: str | None = None,
|
||||||
|
timeout: float = 60.0,
|
||||||
|
max_retries: int = 0,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
input_schema=input_schema or self._default_input_schema(),
|
||||||
|
output_schema=output_schema,
|
||||||
|
version=version,
|
||||||
|
tags=tags or ["testing", "verification"],
|
||||||
|
)
|
||||||
|
self._working_dir = working_dir
|
||||||
|
self._timeout = timeout
|
||||||
|
self._max_retries = max_retries
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _default_input_schema() -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"commands": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "Shell commands to run for verification. Default: ['pytest -x -q', 'ruff check src/']",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(self, **kwargs) -> dict:
|
||||||
|
"""执行测试验证
|
||||||
|
|
||||||
|
Args:
|
||||||
|
commands: 可选的验证命令列表,默认使用 pytest 和 ruff check。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含 passed, attempts, test_output, errors 的字典。
|
||||||
|
"""
|
||||||
|
from agentkit.core.verification_loop import VerificationLoop
|
||||||
|
|
||||||
|
commands: list[str] | None = kwargs.get("commands")
|
||||||
|
loop = VerificationLoop(
|
||||||
|
commands=commands,
|
||||||
|
max_retries=self._max_retries,
|
||||||
|
working_dir=self._working_dir,
|
||||||
|
timeout=self._timeout,
|
||||||
|
)
|
||||||
|
result = await loop.verify()
|
||||||
|
return {
|
||||||
|
"passed": result.passed,
|
||||||
|
"attempts": result.attempts,
|
||||||
|
"test_output": result.test_output,
|
||||||
|
"errors": result.errors,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ToolSearchTool(Tool):
|
||||||
|
"""工具搜索工具
|
||||||
|
|
||||||
|
让 Agent 可以通过关键词搜索扩展工具的完整描述。配合工具描述分层注入使用:
|
||||||
|
核心工具全量注入 prompt,扩展工具只注入名称+一行描述,Agent 通过此工具
|
||||||
|
按需获取扩展工具的完整参数说明。
|
||||||
|
|
||||||
|
底层使用 :class:`ToolSearchIndex`(BM25 算法)进行关键词匹配搜索。
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
from agentkit.tools.search import ToolSearchIndex
|
||||||
|
from agentkit.tools.builtin import ToolSearchTool
|
||||||
|
|
||||||
|
index = ToolSearchIndex(extended_tools)
|
||||||
|
tool = ToolSearchTool(search_index=index)
|
||||||
|
result = await tool.execute(query="read file")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
search_index: "ToolSearchIndex",
|
||||||
|
name: str = "tool_search",
|
||||||
|
description: str = (
|
||||||
|
"Search for available tools by keyword. "
|
||||||
|
"Returns full descriptions (name, description, parameters) of matching tools. "
|
||||||
|
"Use this when you need details about a tool that was only listed by name."
|
||||||
|
),
|
||||||
|
input_schema: dict[str, Any] | None = None,
|
||||||
|
output_schema: dict[str, Any] | None = None,
|
||||||
|
version: str = "1.0.0",
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
top_k: int = 5,
|
||||||
|
):
|
||||||
|
if top_k < 1:
|
||||||
|
raise ValueError(f"top_k must be >= 1, got {top_k}")
|
||||||
|
schema = input_schema or {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"Search keywords to find relevant tools "
|
||||||
|
"(e.g. 'read file', 'web search', 'run tests')."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["query"],
|
||||||
|
}
|
||||||
|
super().__init__(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
input_schema=schema,
|
||||||
|
output_schema=output_schema,
|
||||||
|
version=version,
|
||||||
|
tags=tags or ["tool", "search", "meta"],
|
||||||
|
)
|
||||||
|
self._search_index = search_index
|
||||||
|
self._top_k = top_k
|
||||||
|
|
||||||
|
async def execute(self, **kwargs) -> dict:
|
||||||
|
"""执行工具搜索。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 搜索关键词。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含 ``query``、``count`` 和 ``results``(每个工具的完整描述)的字典。
|
||||||
|
"""
|
||||||
|
query = str(kwargs.get("query", "")).strip()
|
||||||
|
if not query:
|
||||||
|
return {
|
||||||
|
"error": "query parameter is required",
|
||||||
|
"results": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
results = self._search_index.search(query, top_k=self._top_k)
|
||||||
|
if not results:
|
||||||
|
return {
|
||||||
|
"query": query,
|
||||||
|
"count": 0,
|
||||||
|
"results": [],
|
||||||
|
"message": f"No tools matched '{query}'.",
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"query": query,
|
||||||
|
"count": len(results),
|
||||||
|
"results": [self._format_tool_full(tool) for tool in results],
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _format_tool_full(tool: Tool) -> dict[str, Any]:
|
||||||
|
"""Format a tool's full description for the LLM."""
|
||||||
|
return {
|
||||||
|
"name": tool.name,
|
||||||
|
"description": tool.description,
|
||||||
|
"parameters": tool.input_schema or {"type": "object", "properties": {}},
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,182 @@
|
||||||
|
"""Unit tests for SqliteConversationStore."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.chat.sqlite_conversation_store import SqliteConversationStore
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def db_path(tmp_path: Path) -> str:
|
||||||
|
"""Return a temporary database path."""
|
||||||
|
return str(tmp_path / "test_conversations.db")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def store(db_path: str) -> SqliteConversationStore:
|
||||||
|
"""Create a SqliteConversationStore with a temporary database."""
|
||||||
|
s = SqliteConversationStore(db_path=db_path)
|
||||||
|
yield s
|
||||||
|
await s._close_db()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Basic CRUD
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBasicCRUD:
|
||||||
|
async def test_create_conversation(self, store: SqliteConversationStore) -> None:
|
||||||
|
conv = await store.get_or_create()
|
||||||
|
assert conv.id
|
||||||
|
assert conv.created_at is not None
|
||||||
|
assert conv.updated_at is not None
|
||||||
|
assert conv.messages == []
|
||||||
|
|
||||||
|
async def test_create_conversation_with_id(self, store: SqliteConversationStore) -> None:
|
||||||
|
conv = await store.get_or_create("my-conv-id")
|
||||||
|
assert conv.id == "my-conv-id"
|
||||||
|
|
||||||
|
async def test_get_or_create_returns_existing(self, store: SqliteConversationStore) -> None:
|
||||||
|
await store.get_or_create("reuse-id")
|
||||||
|
await store.add_message("reuse-id", "user", "hello")
|
||||||
|
conv2 = await store.get_or_create("reuse-id")
|
||||||
|
assert conv2.id == "reuse-id"
|
||||||
|
# In-memory cache should have the message
|
||||||
|
assert len(conv2.messages) == 1
|
||||||
|
|
||||||
|
async def test_add_message(self, store: SqliteConversationStore) -> None:
|
||||||
|
await store.get_or_create("msg-test")
|
||||||
|
msg = await store.add_message("msg-test", "user", "Hello world")
|
||||||
|
assert msg.role == "user"
|
||||||
|
assert msg.content == "Hello world"
|
||||||
|
assert msg.metadata == {}
|
||||||
|
|
||||||
|
async def test_add_message_with_metadata(self, store: SqliteConversationStore) -> None:
|
||||||
|
await store.get_or_create("meta-test")
|
||||||
|
msg = await store.add_message("meta-test", "assistant", "Hi", {"key": "value"})
|
||||||
|
assert msg.metadata == {"key": "value"}
|
||||||
|
|
||||||
|
async def test_add_message_nonexistent_conversation_raises(
|
||||||
|
self, store: SqliteConversationStore
|
||||||
|
) -> None:
|
||||||
|
with pytest.raises(KeyError, match="not found"):
|
||||||
|
await store.add_message("nonexistent", "user", "hello")
|
||||||
|
|
||||||
|
async def test_get_history(self, store: SqliteConversationStore) -> None:
|
||||||
|
await store.get_or_create("hist-test")
|
||||||
|
await store.add_message("hist-test", "user", "msg1")
|
||||||
|
await store.add_message("hist-test", "assistant", "msg2")
|
||||||
|
await store.add_message("hist-test", "user", "msg3")
|
||||||
|
history = await store.get_history("hist-test")
|
||||||
|
assert len(history) == 3
|
||||||
|
assert history[0].content == "msg1"
|
||||||
|
assert history[1].content == "msg2"
|
||||||
|
assert history[2].content == "msg3"
|
||||||
|
|
||||||
|
async def test_get_history_with_limit(self, store: SqliteConversationStore) -> None:
|
||||||
|
await store.get_or_create("limit-test")
|
||||||
|
for i in range(10):
|
||||||
|
await store.add_message("limit-test", "user", f"msg{i}")
|
||||||
|
history = await store.get_history("limit-test", limit=3)
|
||||||
|
assert len(history) == 3
|
||||||
|
# Should return the last 3 messages
|
||||||
|
assert history[0].content == "msg7"
|
||||||
|
assert history[1].content == "msg8"
|
||||||
|
assert history[2].content == "msg9"
|
||||||
|
|
||||||
|
async def test_get_history_empty(self, store: SqliteConversationStore) -> None:
|
||||||
|
history = await store.get_history("nonexistent")
|
||||||
|
assert history == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Persistence
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPersistence:
|
||||||
|
async def test_data_survives_store_recreation(self, db_path: str) -> None:
|
||||||
|
"""Create store, add data, create new store with same DB path, verify data survives."""
|
||||||
|
store1 = SqliteConversationStore(db_path=db_path)
|
||||||
|
await store1.get_or_create("persist-conv")
|
||||||
|
await store1.add_message("persist-conv", "user", "persistent message")
|
||||||
|
await store1._close_db()
|
||||||
|
|
||||||
|
store2 = SqliteConversationStore(db_path=db_path)
|
||||||
|
history = await store2.get_history("persist-conv")
|
||||||
|
assert len(history) == 1
|
||||||
|
assert history[0].content == "persistent message"
|
||||||
|
assert history[0].role == "user"
|
||||||
|
await store2._close_db()
|
||||||
|
|
||||||
|
async def test_conversations_survive_recreation(self, db_path: str) -> None:
|
||||||
|
store1 = SqliteConversationStore(db_path=db_path)
|
||||||
|
await store1.get_or_create("conv-a")
|
||||||
|
await store1.get_or_create("conv-b")
|
||||||
|
await store1._close_db()
|
||||||
|
|
||||||
|
store2 = SqliteConversationStore(db_path=db_path)
|
||||||
|
convs = await store2.list_conversations()
|
||||||
|
ids = {c.id for c in convs}
|
||||||
|
assert "conv-a" in ids
|
||||||
|
assert "conv-b" in ids
|
||||||
|
await store2._close_db()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# list_conversations
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestListConversations:
|
||||||
|
async def test_list_conversations_ordering(self, store: SqliteConversationStore) -> None:
|
||||||
|
"""Most recently updated conversation should appear first."""
|
||||||
|
await store.get_or_create("conv-first")
|
||||||
|
await store.get_or_create("conv-second")
|
||||||
|
# Update conv1 by adding a message
|
||||||
|
await store.add_message("conv-first", "user", "update")
|
||||||
|
convs = await store.list_conversations()
|
||||||
|
assert len(convs) == 2
|
||||||
|
# conv1 was updated more recently
|
||||||
|
assert convs[0].id == "conv-first"
|
||||||
|
|
||||||
|
async def test_list_conversations_limit(self, store: SqliteConversationStore) -> None:
|
||||||
|
for i in range(5):
|
||||||
|
await store.get_or_create(f"conv-{i}")
|
||||||
|
convs = await store.list_conversations(limit=3)
|
||||||
|
assert len(convs) == 3
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# LRU cache eviction
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestLRUCache:
|
||||||
|
async def test_cache_eviction(self, db_path: str) -> None:
|
||||||
|
"""Cache should evict oldest entries when over limit (data still in SQLite)."""
|
||||||
|
store = SqliteConversationStore(db_path=db_path, max_conversations=3)
|
||||||
|
for i in range(5):
|
||||||
|
await store.get_or_create(f"evict-{i}")
|
||||||
|
# Cache should have at most 3 entries
|
||||||
|
assert len(store._cache) <= 3
|
||||||
|
# But all 5 conversations should be in SQLite
|
||||||
|
convs = await store.list_conversations(limit=10)
|
||||||
|
assert len(convs) == 5
|
||||||
|
await store._close_db()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# restore_from_store (no-op)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRestoreFromStore:
|
||||||
|
async def test_restore_is_noop(self, store: SqliteConversationStore) -> None:
|
||||||
|
"""restore_from_store should be a no-op for SQLite store."""
|
||||||
|
# Should not raise
|
||||||
|
await store.restore_from_store()
|
||||||
|
|
@ -0,0 +1,204 @@
|
||||||
|
"""Tests for SpecManager — Spec 文档管理器"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.core.spec_manager import Spec, SpecManager, SpecStep
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def specs_dir(tmp_path: Path) -> str:
|
||||||
|
"""Provide a temporary directory for spec files."""
|
||||||
|
return str(tmp_path / "specs")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mgr(specs_dir: str) -> SpecManager:
|
||||||
|
"""Create a SpecManager with a temporary directory."""
|
||||||
|
return SpecManager(specs_dir=specs_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def make_spec(spec_id: str = "test-spec", goal: str = "test goal") -> Spec:
|
||||||
|
"""Create a test Spec."""
|
||||||
|
return Spec(
|
||||||
|
spec_id=spec_id,
|
||||||
|
goal=goal,
|
||||||
|
steps=[
|
||||||
|
SpecStep(step_id="s1", name="Step 1", description="First step"),
|
||||||
|
SpecStep(step_id="s2", name="Step 2", description="Second step", dependencies=["s1"]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSpecManagerCreateAndGet:
|
||||||
|
"""Test create and get a spec."""
|
||||||
|
|
||||||
|
def test_create_and_get(self, mgr: SpecManager):
|
||||||
|
spec = make_spec()
|
||||||
|
path = mgr.create(spec)
|
||||||
|
assert path.exists()
|
||||||
|
|
||||||
|
loaded = mgr.get(spec.spec_id)
|
||||||
|
assert loaded is not None
|
||||||
|
assert loaded.spec_id == spec.spec_id
|
||||||
|
assert loaded.goal == spec.goal
|
||||||
|
assert len(loaded.steps) == 2
|
||||||
|
assert loaded.steps[0].step_id == "s1"
|
||||||
|
assert loaded.steps[1].dependencies == ["s1"]
|
||||||
|
|
||||||
|
def test_create_writes_yaml_file(self, mgr: SpecManager, specs_dir: str):
|
||||||
|
spec = make_spec(spec_id="yaml-test")
|
||||||
|
mgr.create(spec)
|
||||||
|
yaml_path = Path(specs_dir) / "yaml-test.yaml"
|
||||||
|
assert yaml_path.exists()
|
||||||
|
|
||||||
|
def test_get_returns_from_cache(self, mgr: SpecManager):
|
||||||
|
spec = make_spec(spec_id="cached")
|
||||||
|
mgr.create(spec)
|
||||||
|
# Second get should hit cache
|
||||||
|
loaded = mgr.get("cached")
|
||||||
|
assert loaded is not None
|
||||||
|
assert loaded.spec_id == "cached"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSpecManagerUpdate:
|
||||||
|
"""Test update spec fields."""
|
||||||
|
|
||||||
|
def test_update_goal(self, mgr: SpecManager):
|
||||||
|
spec = make_spec()
|
||||||
|
mgr.create(spec)
|
||||||
|
|
||||||
|
updated = mgr.update(spec.spec_id, goal="new goal")
|
||||||
|
assert updated is not None
|
||||||
|
assert updated.goal == "new goal"
|
||||||
|
|
||||||
|
def test_update_steps(self, mgr: SpecManager):
|
||||||
|
spec = make_spec()
|
||||||
|
mgr.create(spec)
|
||||||
|
|
||||||
|
new_steps = [
|
||||||
|
{"step_id": "s1", "name": "Step 1 Updated", "description": "Updated first step"},
|
||||||
|
]
|
||||||
|
updated = mgr.update(spec.spec_id, steps=new_steps)
|
||||||
|
assert updated is not None
|
||||||
|
assert len(updated.steps) == 1
|
||||||
|
assert updated.steps[0].name == "Step 1 Updated"
|
||||||
|
|
||||||
|
def test_update_metadata(self, mgr: SpecManager):
|
||||||
|
spec = make_spec()
|
||||||
|
mgr.create(spec)
|
||||||
|
|
||||||
|
updated = mgr.update(spec.spec_id, metadata={"key": "value"})
|
||||||
|
assert updated is not None
|
||||||
|
assert updated.metadata == {"key": "value"}
|
||||||
|
|
||||||
|
def test_update_nonexistent_returns_none(self, mgr: SpecManager):
|
||||||
|
result = mgr.update("nonexistent", goal="x")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestSpecManagerConfirm:
|
||||||
|
"""Test confirm sets status and confirmed_at."""
|
||||||
|
|
||||||
|
def test_confirm_sets_status_and_timestamp(self, mgr: SpecManager):
|
||||||
|
spec = make_spec()
|
||||||
|
mgr.create(spec)
|
||||||
|
|
||||||
|
assert spec.status == "draft"
|
||||||
|
assert spec.confirmed_at is None
|
||||||
|
|
||||||
|
confirmed = mgr.confirm(spec.spec_id)
|
||||||
|
assert confirmed is not None
|
||||||
|
assert confirmed.status == "confirmed"
|
||||||
|
assert confirmed.confirmed_at is not None
|
||||||
|
|
||||||
|
def test_confirm_marks_pending_steps_as_confirmed(self, mgr: SpecManager):
|
||||||
|
spec = make_spec()
|
||||||
|
mgr.create(spec)
|
||||||
|
|
||||||
|
confirmed = mgr.confirm(spec.spec_id)
|
||||||
|
assert confirmed is not None
|
||||||
|
for step in confirmed.steps:
|
||||||
|
assert step.status == "confirmed"
|
||||||
|
|
||||||
|
def test_confirm_nonexistent_returns_none(self, mgr: SpecManager):
|
||||||
|
result = mgr.confirm("nonexistent")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestSpecManagerList:
|
||||||
|
"""Test list_specs returns specs sorted by created_at desc."""
|
||||||
|
|
||||||
|
def test_list_specs_sorted_by_created_at_desc(self, mgr: SpecManager):
|
||||||
|
spec_a = make_spec(spec_id="spec-a", goal="A")
|
||||||
|
mgr.create(spec_a)
|
||||||
|
|
||||||
|
# Small delay to ensure different timestamps
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
spec_b = make_spec(spec_id="spec-b", goal="B")
|
||||||
|
mgr.create(spec_b)
|
||||||
|
|
||||||
|
specs = mgr.list_specs()
|
||||||
|
assert len(specs) == 2
|
||||||
|
# Most recent first
|
||||||
|
assert specs[0].spec_id == "spec-b"
|
||||||
|
assert specs[1].spec_id == "spec-a"
|
||||||
|
|
||||||
|
def test_list_specs_filter_by_status(self, mgr: SpecManager):
|
||||||
|
spec_a = make_spec(spec_id="draft-spec")
|
||||||
|
mgr.create(spec_a)
|
||||||
|
|
||||||
|
spec_b = make_spec(spec_id="confirmed-spec")
|
||||||
|
mgr.create(spec_b)
|
||||||
|
mgr.confirm(spec_b.spec_id)
|
||||||
|
|
||||||
|
draft_specs = mgr.list_specs(status="draft")
|
||||||
|
assert len(draft_specs) == 1
|
||||||
|
assert draft_specs[0].spec_id == "draft-spec"
|
||||||
|
|
||||||
|
confirmed_specs = mgr.list_specs(status="confirmed")
|
||||||
|
assert len(confirmed_specs) == 1
|
||||||
|
assert confirmed_specs[0].spec_id == "confirmed-spec"
|
||||||
|
|
||||||
|
def test_list_specs_empty(self, mgr: SpecManager):
|
||||||
|
specs = mgr.list_specs()
|
||||||
|
assert specs == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestSpecManagerDelete:
|
||||||
|
"""Test delete removes the spec."""
|
||||||
|
|
||||||
|
def test_delete_removes_spec(self, mgr: SpecManager):
|
||||||
|
spec = make_spec()
|
||||||
|
mgr.create(spec)
|
||||||
|
assert mgr.get(spec.spec_id) is not None
|
||||||
|
|
||||||
|
result = mgr.delete(spec.spec_id)
|
||||||
|
assert result is True
|
||||||
|
assert mgr.get(spec.spec_id) is None
|
||||||
|
|
||||||
|
def test_delete_removes_yaml_file(self, mgr: SpecManager, specs_dir: str):
|
||||||
|
spec = make_spec(spec_id="delete-me")
|
||||||
|
mgr.create(spec)
|
||||||
|
yaml_path = Path(specs_dir) / "delete-me.yaml"
|
||||||
|
assert yaml_path.exists()
|
||||||
|
|
||||||
|
mgr.delete("delete-me")
|
||||||
|
assert not yaml_path.exists()
|
||||||
|
|
||||||
|
def test_delete_nonexistent_returns_false(self, mgr: SpecManager):
|
||||||
|
result = mgr.delete("nonexistent")
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestSpecManagerGetNonExistent:
|
||||||
|
"""Test get non-existent spec returns None."""
|
||||||
|
|
||||||
|
def test_get_nonexistent_returns_none(self, mgr: SpecManager):
|
||||||
|
result = mgr.get("does-not-exist")
|
||||||
|
assert result is None
|
||||||
|
|
@ -0,0 +1,136 @@
|
||||||
|
"""VerificationLoop 单元测试"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.core.verification_loop import VerificationLoop, VerificationResult
|
||||||
|
|
||||||
|
|
||||||
|
class TestVerify:
|
||||||
|
"""verify() 方法测试"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verify_success(self) -> None:
|
||||||
|
"""成功命令返回 passed=True"""
|
||||||
|
loop = VerificationLoop(commands=["echo ok"], timeout=10.0)
|
||||||
|
result = await loop.verify()
|
||||||
|
assert isinstance(result, VerificationResult)
|
||||||
|
assert result.passed is True
|
||||||
|
assert result.attempts == 1
|
||||||
|
assert "ok" in result.test_output
|
||||||
|
assert result.errors == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verify_failure(self) -> None:
|
||||||
|
"""失败命令返回 passed=False"""
|
||||||
|
loop = VerificationLoop(commands=["false"], timeout=10.0)
|
||||||
|
result = await loop.verify()
|
||||||
|
assert result.passed is False
|
||||||
|
assert result.attempts == 1
|
||||||
|
assert len(result.errors) > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verify_timeout(self) -> None:
|
||||||
|
"""超时命令返回 passed=False"""
|
||||||
|
loop = VerificationLoop(commands=["sleep 10"], timeout=0.5)
|
||||||
|
result = await loop.verify()
|
||||||
|
assert result.passed is False
|
||||||
|
assert any("timed out" in e for e in result.errors)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verify_command_not_found(self) -> None:
|
||||||
|
"""不存在的命令返回 passed=False"""
|
||||||
|
loop = VerificationLoop(commands=["nonexistent_command_xyz"], timeout=5.0)
|
||||||
|
result = await loop.verify()
|
||||||
|
assert result.passed is False
|
||||||
|
assert len(result.errors) > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verify_multiple_commands_partial_failure(self) -> None:
|
||||||
|
"""部分命令失败时整体返回 passed=False"""
|
||||||
|
loop = VerificationLoop(commands=["echo ok", "false"], timeout=10.0)
|
||||||
|
result = await loop.verify()
|
||||||
|
assert result.passed is False
|
||||||
|
assert len(result.errors) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verify_default_commands(self) -> None:
|
||||||
|
"""默认命令为 pytest 和 ruff check"""
|
||||||
|
loop = VerificationLoop()
|
||||||
|
assert loop._commands == ["pytest -x -q", "ruff check src/"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestVerifyAndRetry:
|
||||||
|
"""verify_and_retry() 方法测试"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retry_no_fix_callback(self) -> None:
|
||||||
|
"""无 fix_callback 时重试指定次数"""
|
||||||
|
loop = VerificationLoop(commands=["false"], max_retries=2, timeout=5.0)
|
||||||
|
result = await loop.verify_and_retry()
|
||||||
|
assert result.passed is False
|
||||||
|
assert result.attempts == 3 # 1 initial + 2 retries
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_max_retries_respected(self) -> None:
|
||||||
|
"""max_retries=0 时不重试"""
|
||||||
|
loop = VerificationLoop(commands=["false"], max_retries=0, timeout=5.0)
|
||||||
|
result = await loop.verify_and_retry()
|
||||||
|
assert result.passed is False
|
||||||
|
assert result.attempts == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retry_with_fix_callback(self) -> None:
|
||||||
|
"""fix_callback 被调用并接收 errors 和 test_output"""
|
||||||
|
call_count = 0
|
||||||
|
received_args: list[tuple[list[str], str]] = []
|
||||||
|
|
||||||
|
async def fix_cb(errors: list[str], test_output: str) -> None:
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
received_args.append((errors, test_output))
|
||||||
|
|
||||||
|
loop = VerificationLoop(commands=["false"], max_retries=1, timeout=5.0)
|
||||||
|
result = await loop.verify_and_retry(fix_callback=fix_cb)
|
||||||
|
assert result.passed is False
|
||||||
|
assert call_count == 1
|
||||||
|
assert len(received_args) == 1
|
||||||
|
assert len(received_args[0][0]) > 0 # errors
|
||||||
|
assert isinstance(received_args[0][1], str) # test_output
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retry_succeeds_after_fix(self) -> None:
|
||||||
|
"""fix_callback 修复后验证成功"""
|
||||||
|
attempt = 0
|
||||||
|
|
||||||
|
async def fix_cb(errors: list[str], test_output: str) -> None:
|
||||||
|
pass # Simulate fix applied
|
||||||
|
|
||||||
|
# Use a command that always fails — but test that the retry mechanism works
|
||||||
|
loop = VerificationLoop(commands=["false"], max_retries=1, timeout=5.0)
|
||||||
|
result = await loop.verify_and_retry(fix_callback=fix_cb)
|
||||||
|
# false always fails, so result should still be False
|
||||||
|
assert result.passed is False
|
||||||
|
assert result.attempts == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fix_callback_exception_handled(self) -> None:
|
||||||
|
"""fix_callback 抛出异常时不影响重试"""
|
||||||
|
async def bad_fix_cb(errors: list[str], test_output: str) -> None:
|
||||||
|
raise RuntimeError("fix failed!")
|
||||||
|
|
||||||
|
loop = VerificationLoop(commands=["false"], max_retries=1, timeout=5.0)
|
||||||
|
result = await loop.verify_and_retry(fix_callback=bad_fix_cb)
|
||||||
|
assert result.passed is False
|
||||||
|
assert result.attempts == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verify_and_retry_success_first_try(self) -> None:
|
||||||
|
"""首次验证成功时不重试"""
|
||||||
|
loop = VerificationLoop(commands=["echo ok"], max_retries=3, timeout=10.0)
|
||||||
|
result = await loop.verify_and_retry()
|
||||||
|
assert result.passed is True
|
||||||
|
assert result.attempts == 1
|
||||||
|
|
@ -12,10 +12,8 @@ from fastapi.testclient import TestClient
|
||||||
from agentkit.llm.gateway import LLMGateway
|
from agentkit.llm.gateway import LLMGateway
|
||||||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||||||
from agentkit.server.app import create_app
|
from agentkit.server.app import create_app
|
||||||
from agentkit.server.routes.portal import (
|
from agentkit.chat.sqlite_conversation_store import SqliteConversationStore
|
||||||
CAPABILITY_CATEGORIES,
|
from agentkit.server.routes.portal import CAPABILITY_CATEGORIES
|
||||||
ConversationStore,
|
|
||||||
)
|
|
||||||
from agentkit.skills.base import Skill, SkillConfig
|
from agentkit.skills.base import Skill, SkillConfig
|
||||||
from agentkit.skills.registry import SkillRegistry
|
from agentkit.skills.registry import SkillRegistry
|
||||||
from agentkit.tools.registry import ToolRegistry
|
from agentkit.tools.registry import ToolRegistry
|
||||||
|
|
@ -84,81 +82,90 @@ def _register_skill(registry: SkillRegistry, name: str = "chat_skill", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class TestConversationStore:
|
class TestConversationStore:
|
||||||
def test_get_or_create_new(self):
|
"""Tests for SqliteConversationStore (async, in-memory DB)."""
|
||||||
store = ConversationStore()
|
|
||||||
conv = store.get_or_create()
|
@pytest.fixture
|
||||||
|
def store(self, tmp_path):
|
||||||
|
return SqliteConversationStore(db_path=str(tmp_path / "test.db"))
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_or_create_new(self, store):
|
||||||
|
conv = await store.get_or_create()
|
||||||
assert conv.id is not None
|
assert conv.id is not None
|
||||||
assert conv.messages == []
|
assert conv.messages == []
|
||||||
|
|
||||||
def test_get_or_create_with_id(self):
|
@pytest.mark.asyncio
|
||||||
store = ConversationStore()
|
async def test_get_or_create_with_id(self, store):
|
||||||
conv = store.get_or_create("test-id-123")
|
conv = await store.get_or_create("test-id-123")
|
||||||
assert conv.id == "test-id-123"
|
assert conv.id == "test-id-123"
|
||||||
|
|
||||||
def test_get_or_create_reuse(self):
|
@pytest.mark.asyncio
|
||||||
store = ConversationStore()
|
async def test_get_or_create_reuse(self, store):
|
||||||
store.get_or_create("reuse-id")
|
await store.get_or_create("reuse-id")
|
||||||
store.add_message("reuse-id", "user", "hello")
|
await store.add_message("reuse-id", "user", "hello")
|
||||||
conv2 = store.get_or_create("reuse-id")
|
conv2 = await store.get_or_create("reuse-id")
|
||||||
assert conv2.id == "reuse-id"
|
assert conv2.id == "reuse-id"
|
||||||
assert len(conv2.messages) == 1
|
assert len(conv2.messages) == 1
|
||||||
|
|
||||||
def test_add_message(self):
|
@pytest.mark.asyncio
|
||||||
store = ConversationStore()
|
async def test_add_message(self, store):
|
||||||
conv = store.get_or_create("msg-id")
|
conv = await store.get_or_create("msg-id")
|
||||||
msg = store.add_message("msg-id", "user", "hello")
|
msg = await store.add_message("msg-id", "user", "hello")
|
||||||
assert msg.role == "user"
|
assert msg.role == "user"
|
||||||
assert msg.content == "hello"
|
assert msg.content == "hello"
|
||||||
assert len(conv.messages) == 1
|
assert len(conv.messages) == 1
|
||||||
|
|
||||||
def test_add_message_not_found(self):
|
@pytest.mark.asyncio
|
||||||
store = ConversationStore()
|
async def test_add_message_not_found(self, store):
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
store.add_message("nonexistent", "user", "hello")
|
await store.add_message("nonexistent", "user", "hello")
|
||||||
|
|
||||||
def test_get_history(self):
|
@pytest.mark.asyncio
|
||||||
store = ConversationStore()
|
async def test_get_history(self, store):
|
||||||
store.get_or_create("hist-id")
|
await store.get_or_create("hist-id")
|
||||||
store.add_message("hist-id", "user", "msg1")
|
await store.add_message("hist-id", "user", "msg1")
|
||||||
store.add_message("hist-id", "assistant", "msg2")
|
await store.add_message("hist-id", "assistant", "msg2")
|
||||||
history = store.get_history("hist-id")
|
history = await store.get_history("hist-id")
|
||||||
assert len(history) == 2
|
assert len(history) == 2
|
||||||
assert history[0].role == "user"
|
assert history[0].role == "user"
|
||||||
assert history[1].role == "assistant"
|
assert history[1].role == "assistant"
|
||||||
|
|
||||||
def test_get_history_limit(self):
|
@pytest.mark.asyncio
|
||||||
store = ConversationStore()
|
async def test_get_history_limit(self, store):
|
||||||
store.get_or_create("limit-id")
|
await store.get_or_create("limit-id")
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
store.add_message("limit-id", "user", f"msg{i}")
|
await store.add_message("limit-id", "user", f"msg{i}")
|
||||||
history = store.get_history("limit-id", limit=3)
|
history = await store.get_history("limit-id", limit=3)
|
||||||
assert len(history) == 3
|
assert len(history) == 3
|
||||||
assert history[0].content == "msg7"
|
assert history[0].content == "msg7"
|
||||||
|
|
||||||
def test_get_history_nonexistent(self):
|
@pytest.mark.asyncio
|
||||||
store = ConversationStore()
|
async def test_get_history_nonexistent(self, store):
|
||||||
history = store.get_history("no-such-id")
|
history = await store.get_history("no-such-id")
|
||||||
assert history == []
|
assert history == []
|
||||||
|
|
||||||
def test_list_conversations(self):
|
@pytest.mark.asyncio
|
||||||
store = ConversationStore()
|
async def test_list_conversations(self, store):
|
||||||
store.get_or_create("conv-a")
|
await store.get_or_create("conv-a")
|
||||||
store.get_or_create("conv-b")
|
await store.get_or_create("conv-b")
|
||||||
convs = store.list_conversations()
|
convs = await store.list_conversations()
|
||||||
assert len(convs) == 2
|
assert len(convs) == 2
|
||||||
|
|
||||||
def test_list_conversations_limit(self):
|
@pytest.mark.asyncio
|
||||||
store = ConversationStore()
|
async def test_list_conversations_limit(self, store):
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
store.get_or_create(f"conv-{i}")
|
await store.get_or_create(f"conv-{i}")
|
||||||
convs = store.list_conversations(limit=2)
|
convs = await store.list_conversations(limit=2)
|
||||||
assert len(convs) == 2
|
assert len(convs) == 2
|
||||||
|
|
||||||
def test_max_conversations_eviction(self):
|
@pytest.mark.asyncio
|
||||||
store = ConversationStore(max_conversations=3)
|
async def test_max_conversations_eviction(self, tmp_path):
|
||||||
|
store = SqliteConversationStore(
|
||||||
|
db_path=str(tmp_path / "evict.db"), max_conversations=3
|
||||||
|
)
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
store.get_or_create(f"evict-{i}")
|
await store.get_or_create(f"evict-{i}")
|
||||||
assert len(store._conversations) <= 3
|
assert len(store._cache) <= 3
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -178,7 +185,7 @@ class TestPortalChat:
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["conversation_id"] is not None
|
assert data["conversation_id"] is not None
|
||||||
assert data["matched_skill"] == "chat_skill"
|
assert data["matched_skill"] == "chat_skill"
|
||||||
assert data["routing_method"] == "direct"
|
assert data["routing_method"] == "skill_prefix"
|
||||||
assert data["confidence"] == 1.0
|
assert data["confidence"] == 1.0
|
||||||
assert data["status"] == "completed"
|
assert data["status"] == "completed"
|
||||||
|
|
||||||
|
|
@ -196,12 +203,13 @@ class TestPortalChat:
|
||||||
assert data["conversation_id"] is not None
|
assert data["conversation_id"] is not None
|
||||||
|
|
||||||
def test_chat_no_skills_available(self, client):
|
def test_chat_no_skills_available(self, client):
|
||||||
|
"""Greeting fast-path works even without skills (DIRECT_CHAT mode)."""
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/api/v1/portal/chat",
|
"/api/v1/portal/chat",
|
||||||
json={"message": "hello"},
|
json={"message": "hello"},
|
||||||
)
|
)
|
||||||
assert response.status_code == 400
|
# Greeting regex fast-path: no skill needed, returns 200
|
||||||
assert "No skills available" in response.json()["detail"]
|
assert response.status_code == 200
|
||||||
|
|
||||||
def test_chat_skill_not_found(self, client):
|
def test_chat_skill_not_found(self, client):
|
||||||
response = client.post(
|
response = client.post(
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ Covers:
|
||||||
3. ConfigDrivenAgent compressor passthrough
|
3. ConfigDrivenAgent compressor passthrough
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
|
@ -114,8 +115,11 @@ class TestCreateAppCompression:
|
||||||
with patch("agentkit.core.compressor.create_compressor") as mock_create:
|
with patch("agentkit.core.compressor.create_compressor") as mock_create:
|
||||||
mock_create.return_value = None
|
mock_create.return_value = None
|
||||||
|
|
||||||
# No server_config at all
|
# No server_config at all — also prevent auto-discovery of agentkit.yaml
|
||||||
app = create_app()
|
with patch.dict(os.environ, {"AGENTKIT_CONFIG_PATH": ""}, clear=False):
|
||||||
|
# Prevent CWD agentkit.yaml from being auto-loaded
|
||||||
|
with patch("os.path.exists", return_value=False):
|
||||||
|
app = create_app()
|
||||||
|
|
||||||
# create_compressor should not be called (no server_config)
|
# create_compressor should not be called (no server_config)
|
||||||
mock_create.assert_not_called()
|
mock_create.assert_not_called()
|
||||||
|
|
@ -184,6 +188,7 @@ class TestConfigDrivenAgentCompression:
|
||||||
agent._skill_config = skill_config
|
agent._skill_config = skill_config
|
||||||
agent._prompt_template = None
|
agent._prompt_template = None
|
||||||
agent._tools = []
|
agent._tools = []
|
||||||
|
agent._tool_registry = None
|
||||||
agent._memory_retriever = None
|
agent._memory_retriever = None
|
||||||
agent._compressor = mock_compressor
|
agent._compressor = mock_compressor
|
||||||
agent._evolution_enabled = False
|
agent._evolution_enabled = False
|
||||||
|
|
@ -230,6 +235,7 @@ class TestConfigDrivenAgentCompression:
|
||||||
agent._skill_config = skill_config
|
agent._skill_config = skill_config
|
||||||
agent._prompt_template = None
|
agent._prompt_template = None
|
||||||
agent._tools = []
|
agent._tools = []
|
||||||
|
agent._tool_registry = None
|
||||||
agent._memory_retriever = None
|
agent._memory_retriever = None
|
||||||
agent._compressor = None
|
agent._compressor = None
|
||||||
agent._evolution_enabled = False
|
agent._evolution_enabled = False
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,6 @@ from agentkit.telemetry.tracer import (
|
||||||
get_tracer,
|
get_tracer,
|
||||||
init_telemetry,
|
init_telemetry,
|
||||||
)
|
)
|
||||||
from agentkit.chat.skill_routing import CostAwareRouter, SkillRoutingResult
|
|
||||||
from agentkit.quality.alignment import AlignmentGuard, AlignmentConfig
|
from agentkit.quality.alignment import AlignmentGuard, AlignmentConfig
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -155,46 +154,6 @@ class TestGetTracer:
|
||||||
assert tracer1 is tracer2
|
assert tracer1 is tracer2
|
||||||
|
|
||||||
|
|
||||||
# ── CostAwareRouter span 测试 ──────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestCostAwareRouterSpan:
|
|
||||||
"""CostAwareRouter 创建 span 并设置属性"""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_router_creates_span_with_attributes(self):
|
|
||||||
"""路由时创建 span 并设置 route.layer 和 route.target 属性"""
|
|
||||||
init_telemetry(TelemetryConfig(enabled=False))
|
|
||||||
tracer = get_tracer()
|
|
||||||
|
|
||||||
# 用 mock 替换 start_span 以验证调用
|
|
||||||
mock_span = MagicMock()
|
|
||||||
mock_span.__enter__ = MagicMock(return_value=mock_span)
|
|
||||||
mock_span.__exit__ = MagicMock(return_value=False)
|
|
||||||
mock_span.set_attribute = MagicMock()
|
|
||||||
|
|
||||||
with patch.object(tracer, "start_span", return_value=mock_span):
|
|
||||||
router = CostAwareRouter()
|
|
||||||
result = await router.route(
|
|
||||||
content="你好",
|
|
||||||
skill_registry=MagicMock(),
|
|
||||||
intent_router=MagicMock(),
|
|
||||||
default_tools=[],
|
|
||||||
default_system_prompt="You are helpful.",
|
|
||||||
)
|
|
||||||
|
|
||||||
tracer.start_span.assert_called_once_with("router.route")
|
|
||||||
# 验证 span 设置了 input.length 属性
|
|
||||||
mock_span.set_attribute.assert_any_call("input.length", len("你好"))
|
|
||||||
# 验证 span 设置了 route.layer 和 route.target
|
|
||||||
call_args_list = [
|
|
||||||
(call.args[0], call.args[1])
|
|
||||||
for call in mock_span.set_attribute.call_args_list
|
|
||||||
]
|
|
||||||
assert ("route.layer", "greeting") in call_args_list
|
|
||||||
assert ("route.target", "default") in call_args_list
|
|
||||||
|
|
||||||
|
|
||||||
# ── AlignmentGuard span 测试 ──────────────────────────────
|
# ── AlignmentGuard span 测试 ──────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue