From 5a90824c77c8c6ddf9903984601d9ab51c4fd7df Mon Sep 17 00:00:00 2001 From: chiguyong Date: Thu, 4 Jun 2026 22:39:25 +0800 Subject: [PATCH] feat(core): add ConfigDrivenAgent with YAML-driven agent definition - AgentConfig: YAML/dict config model with validation - ConfigDrivenAgent: 3 task modes (llm_generate, tool_call, custom) - StandaloneRunner: auto-discover YAML configs and build agents - 25 new tests covering all modes and edge cases - Total: 56 tests passing --- src/agentkit/__init__.py | 3 + src/agentkit/core/__init__.py | 3 + src/agentkit/core/config_driven.py | 421 +++++++++++++++++++++++++++++ src/agentkit/core/standalone.py | 173 ++++++++++++ tests/unit/test_config_driven.py | 356 ++++++++++++++++++++++++ 5 files changed, 956 insertions(+) create mode 100644 src/agentkit/core/config_driven.py create mode 100644 src/agentkit/core/standalone.py create mode 100644 tests/unit/test_config_driven.py diff --git a/src/agentkit/__init__.py b/src/agentkit/__init__.py index eda5b75..bf91674 100644 --- a/src/agentkit/__init__.py +++ b/src/agentkit/__init__.py @@ -1,6 +1,7 @@ """Fischer AgentKit - Unified Agent Framework""" from agentkit.core.base import BaseAgent +from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent from agentkit.core.protocol import ( AgentCapability, AgentStatus, @@ -15,6 +16,8 @@ __version__ = "0.1.0" __all__ = [ "BaseAgent", + "AgentConfig", + "ConfigDrivenAgent", "AgentCapability", "AgentStatus", "HandoffMessage", diff --git a/src/agentkit/core/__init__.py b/src/agentkit/core/__init__.py index 6617fdc..d05711f 100644 --- a/src/agentkit/core/__init__.py +++ b/src/agentkit/core/__init__.py @@ -1,6 +1,7 @@ """AgentKit Core - 基础组件""" from agentkit.core.base import BaseAgent +from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent from agentkit.core.exceptions import ( AgentAlreadyRegisteredError, AgentFrameworkError, @@ -33,6 +34,8 @@ from agentkit.core.protocol import ( __all__ = [ "BaseAgent", + "AgentConfig", + "ConfigDrivenAgent", "AgentCapability", "AgentStatus", "AgentFrameworkError", diff --git a/src/agentkit/core/config_driven.py b/src/agentkit/core/config_driven.py new file mode 100644 index 0000000..1b9d766 --- /dev/null +++ b/src/agentkit/core/config_driven.py @@ -0,0 +1,421 @@ +"""ConfigDrivenAgent - 配置驱动的 Agent 定义 + +核心设计: +- 从 YAML/Dict 配置自动组装 Agent(Prompt + LLM + Tool + Memory) +- 支持三种任务模式:llm_generate / tool_call / custom +- 新增 Agent 从写 150 行代码降为 10-20 行配置 +""" + +import logging +from typing import Any, Callable, Coroutine + +import yaml + +from agentkit.core.base import BaseAgent +from agentkit.core.exceptions import ConfigValidationError +from agentkit.core.protocol import AgentCapability, TaskMessage +from agentkit.prompts.section import PromptSection +from agentkit.prompts.template import PromptTemplate +from agentkit.tools.base import Tool +from agentkit.tools.registry import ToolRegistry + +logger = logging.getLogger(__name__) + + +class AgentConfig: + """Agent 配置模型,从 YAML 或 Dict 构建""" + + def __init__( + self, + name: str, + agent_type: str, + version: str = "1.0.0", + description: str = "", + task_mode: str = "llm_generate", + supported_tasks: list[str] | None = None, + max_concurrency: int = 1, + input_schema: dict[str, Any] | None = None, + output_schema: dict[str, Any] | None = None, + prompt: dict[str, str] | None = None, + llm: dict[str, Any] | None = None, + tools: list[str] | None = None, + memory: dict[str, Any] | None = None, + custom_handler: str | None = None, + ): + self.name = name + self.agent_type = agent_type + self.version = version + self.description = description + self.task_mode = task_mode + self.supported_tasks = supported_tasks or [agent_type] + self.max_concurrency = max_concurrency + self.input_schema = input_schema + self.output_schema = output_schema + self.prompt = prompt or {} + self.llm = llm or {} + self.tools = tools or [] + self.memory = memory or {} + self.custom_handler = custom_handler + + self._validate() + + def _validate(self) -> None: + """校验配置合法性""" + valid_modes = {"llm_generate", "tool_call", "custom"} + if self.task_mode not in valid_modes: + raise ConfigValidationError( + agent_name=self.name, + key="task_mode", + reason=f"Invalid task_mode '{self.task_mode}', must be one of {valid_modes}", + ) + + if self.task_mode == "llm_generate" and not self.prompt: + raise ConfigValidationError( + agent_name=self.name, + key="prompt", + reason="llm_generate mode requires 'prompt' configuration", + ) + + if self.task_mode == "tool_call" and not self.tools: + raise ConfigValidationError( + agent_name=self.name, + key="tools", + reason="tool_call mode requires at least one tool in 'tools' list", + ) + + if self.task_mode == "custom" and not self.custom_handler: + raise ConfigValidationError( + agent_name=self.name, + key="custom_handler", + reason="custom mode requires 'custom_handler' (dotted path to callable)", + ) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "AgentConfig": + """从字典创建配置""" + return cls( + name=data["name"], + agent_type=data["agent_type"], + version=data.get("version", "1.0.0"), + description=data.get("description", ""), + task_mode=data.get("task_mode", "llm_generate"), + supported_tasks=data.get("supported_tasks"), + max_concurrency=data.get("max_concurrency", 1), + input_schema=data.get("input_schema"), + output_schema=data.get("output_schema"), + prompt=data.get("prompt"), + llm=data.get("llm"), + tools=data.get("tools"), + memory=data.get("memory"), + custom_handler=data.get("custom_handler"), + ) + + @classmethod + def from_yaml(cls, path: str) -> "AgentConfig": + """从 YAML 文件加载配置""" + with open(path, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) + if not isinstance(data, dict): + raise ConfigValidationError( + agent_name="unknown", + key="config", + reason=f"YAML config must be a mapping, got {type(data)}", + ) + return cls.from_dict(data) + + def to_dict(self) -> dict[str, Any]: + """序列化为字典""" + d = { + "name": self.name, + "agent_type": self.agent_type, + "version": self.version, + "description": self.description, + "task_mode": self.task_mode, + "supported_tasks": self.supported_tasks, + "max_concurrency": self.max_concurrency, + } + if self.input_schema: + d["input_schema"] = self.input_schema + if self.output_schema: + d["output_schema"] = self.output_schema + if self.prompt: + d["prompt"] = self.prompt + if self.llm: + d["llm"] = self.llm + if self.tools: + d["tools"] = self.tools + if self.memory: + d["memory"] = self.memory + if self.custom_handler: + d["custom_handler"] = self.custom_handler + return d + + +class ConfigDrivenAgent(BaseAgent): + """配置驱动的 Agent + + 从 YAML/Dict 配置自动组装,支持三种任务模式: + - llm_generate: 渲染 Prompt → 调用 LLM → 解析 JSON 输出 + - tool_call: 调用注册的 Tool 并返回结果 + - custom: 自定义 handler 函数 + + 示例 YAML 配置:: + + name: content_generator + agent_type: content_generation + task_mode: llm_generate + description: "内容生成 Agent" + prompt: + identity: "你是一个专业的内容生成助手" + instructions: "根据用户需求生成高质量内容" + output_format: "以 JSON 格式输出" + llm: + model: "gpt-4" + temperature: 0.7 + tools: + - retrieve_knowledge + """ + + def __init__( + self, + config: AgentConfig, + tool_registry: ToolRegistry | None = None, + llm_client: Any = None, + custom_handlers: dict[str, Callable[..., Coroutine]] | None = None, + ): + super().__init__( + name=config.name, + agent_type=config.agent_type, + version=config.version, + ) + self._config = config + self._tool_registry = tool_registry or ToolRegistry() + self._llm_client = llm_client + self._custom_handlers = custom_handlers or {} + self._prompt_template: PromptTemplate | None = None + + # 从配置构建 Prompt 模板 + if config.prompt: + sections = PromptSection( + identity=config.prompt.get("identity", ""), + context=config.prompt.get("context", ""), + instructions=config.prompt.get("instructions", ""), + constraints=config.prompt.get("constraints", ""), + output_format=config.prompt.get("output_format", ""), + examples=config.prompt.get("examples", ""), + ) + self._prompt_template = PromptTemplate( + sections=sections, + name=config.name, + version=config.version, + ) + + # 从配置绑定 Tool + self._bind_tools() + + @property + def config(self) -> AgentConfig: + return self._config + + @property + def prompt_template(self) -> PromptTemplate | None: + return self._prompt_template + + def _bind_tools(self) -> None: + """根据配置绑定工具""" + for tool_name in self._config.tools: + try: + tool = self._tool_registry.get(tool_name) + self.use_tool(tool) + logger.info(f"ConfigDrivenAgent '{self.name}' bound tool '{tool_name}'") + except Exception as e: + logger.warning( + f"ConfigDrivenAgent '{self.name}' failed to bind tool '{tool_name}': {e}" + ) + + def get_capabilities(self) -> AgentCapability: + return AgentCapability( + agent_name=self.name, + agent_type=self.agent_type, + version=self.version, + supported_tasks=self._config.supported_tasks, + max_concurrency=self._config.max_concurrency, + description=self._config.description, + input_schema=self._config.input_schema, + output_schema=self._config.output_schema, + ) + + async def handle_task(self, task: TaskMessage) -> dict: + """根据 task_mode 执行任务""" + if self._config.task_mode == "llm_generate": + return await self._handle_llm_generate(task) + elif self._config.task_mode == "tool_call": + return await self._handle_tool_call(task) + elif self._config.task_mode == "custom": + return await self._handle_custom(task) + else: + raise ConfigValidationError( + agent_name=self.name, + key="task_mode", + reason=f"Unknown task_mode: {self._config.task_mode}", + ) + + async def _handle_llm_generate(self, task: TaskMessage) -> dict: + """LLM 生成模式:渲染 Prompt → 调用 LLM → 解析输出""" + if not self._prompt_template: + raise ConfigValidationError( + agent_name=self.name, + key="prompt", + reason="llm_generate mode requires prompt configuration", + ) + + # 渲染 Prompt + variables = task.input_data.copy() + variables["task_type"] = task.task_type + messages = self._prompt_template.render(variables=variables) + + # 调用 LLM + if self._llm_client is None: + # 无 LLM 客户端时返回渲染后的 Prompt(降级模式) + return { + "mode": "llm_generate_no_client", + "messages": messages, + "note": "No LLM client configured, returning rendered prompt", + } + + # 使用配置的 LLM 参数 + llm_params = self._config.llm.copy() + response = await self._call_llm(messages, **llm_params) + return self._parse_llm_response(response) + + async def _handle_tool_call(self, task: TaskMessage) -> dict: + """工具调用模式:调用指定 Tool 并返回结果""" + if not self._tools: + raise ConfigValidationError( + agent_name=self.name, + key="tools", + reason="tool_call mode requires at least one bound tool", + ) + + # 使用第一个绑定的工具(或根据 task_type 匹配) + tool = self._resolve_tool(task) + result = await tool.safe_execute(**task.input_data) + return result + + async def _handle_custom(self, task: TaskMessage) -> dict: + """自定义模式:调用注册的 handler 函数""" + handler_name = self._config.custom_handler + if handler_name not in self._custom_handlers: + # 尝试动态导入 + handler = self._import_handler(handler_name) + self._custom_handlers[handler_name] = handler + + handler = self._custom_handlers[handler_name] + result = await handler(task) + if not isinstance(result, dict): + raise ConfigValidationError( + agent_name=self.name, + key="custom_handler", + reason=f"Custom handler '{handler_name}' must return dict, got {type(result)}", + ) + return result + + def _resolve_tool(self, task: TaskMessage) -> Tool: + """根据任务类型解析要使用的工具""" + # 优先匹配 task_type 与 tool name + for tool in self._tools: + if task.task_type in tool.name or task.task_type.replace("_", "-") in tool.name: + return tool + + # 回退到第一个工具 + return self._tools[0] + + async def _call_llm(self, messages: list[dict[str, str]], **kwargs) -> str: + """调用 LLM 客户端""" + model = kwargs.pop("model", "gpt-4") + temperature = kwargs.pop("temperature", 0.7) + max_tokens = kwargs.pop("max_tokens", 2000) + + # 标准化 LLM 客户端调用接口 + if hasattr(self._llm_client, "chat"): + response = await self._llm_client.chat( + messages=messages, + model=model, + temperature=temperature, + max_tokens=max_tokens, + **kwargs, + ) + elif hasattr(self._llm_client, "create"): + response = await self._llm_client.create( + messages=messages, + model=model, + temperature=temperature, + max_tokens=max_tokens, + **kwargs, + ) + elif callable(self._llm_client): + response = await self._llm_client( + messages=messages, + model=model, + temperature=temperature, + max_tokens=max_tokens, + **kwargs, + ) + else: + raise ConfigValidationError( + agent_name=self.name, + key="llm_client", + reason="LLM client must have 'chat'/'create' method or be callable", + ) + + # 提取文本内容 + if isinstance(response, str): + return response + if isinstance(response, dict): + return response.get("content", str(response)) + if hasattr(response, "content"): + return response.content + return str(response) + + def _parse_llm_response(self, response: str) -> dict: + """解析 LLM 响应为 dict""" + import json + + # 尝试直接解析 JSON + try: + return json.loads(response) + except (json.JSONDecodeError, TypeError): + pass + + # 尝试提取 JSON 块 + import re + json_match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", response, re.DOTALL) + if json_match: + try: + return json.loads(json_match.group(1)) + except (json.JSONDecodeError, TypeError): + pass + + # 降级:包装为文本结果 + return {"text": response} + + def _import_handler(self, dotted_path: str) -> Callable[..., Coroutine]: + """动态导入自定义 handler""" + try: + module_path, func_name = dotted_path.rsplit(".", 1) + import importlib + module = importlib.import_module(module_path) + handler = getattr(module, func_name) + if not callable(handler): + raise ConfigValidationError( + agent_name=self.name, + key="custom_handler", + reason=f"'{dotted_path}' is not callable", + ) + return handler + except (ImportError, AttributeError, ValueError) as e: + raise ConfigValidationError( + agent_name=self.name, + key="custom_handler", + reason=f"Failed to import custom handler '{dotted_path}': {e}", + ) diff --git a/src/agentkit/core/standalone.py b/src/agentkit/core/standalone.py new file mode 100644 index 0000000..b08aa92 --- /dev/null +++ b/src/agentkit/core/standalone.py @@ -0,0 +1,173 @@ +"""Standalone Runner - 自动发现并启动配置驱动的 Agent + +扫描 agent_configs/ 目录下的 YAML 文件,自动注册和启动 Agent。 +支持命令行启动:python -m agentkit.core.standalone +""" + +import asyncio +import logging +import os +import sys +from pathlib import Path + +import yaml + +from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent +from agentkit.tools.function_tool import FunctionTool +from agentkit.tools.registry import ToolRegistry + +logger = logging.getLogger(__name__) + +DEFAULT_CONFIG_DIR = "agent_configs" + + +class StandaloneRunner: + """自动发现并启动配置驱动的 Agent + + 用法:: + + runner = StandaloneRunner(config_dir="agent_configs") + runner.add_tool(FunctionTool.from_func(my_tool_func)) + await runner.start_all(redis_url="redis://localhost:6379") + """ + + def __init__( + self, + config_dir: str = DEFAULT_CONFIG_DIR, + tool_registry: ToolRegistry | None = None, + llm_client=None, + custom_handlers: dict | None = None, + ): + self._config_dir = config_dir + self._tool_registry = tool_registry or ToolRegistry() + self._llm_client = llm_client + self._custom_handlers = custom_handlers or {} + self._agents: dict[str, ConfigDrivenAgent] = {} + + @property + def agents(self) -> dict[str, ConfigDrivenAgent]: + return self._agents + + def add_tool(self, tool) -> "StandaloneRunner": + """添加工具到注册中心""" + self._tool_registry.register(tool) + return self + + def add_custom_handler(self, name: str, handler) -> "StandaloneRunner": + """注册自定义 handler""" + self._custom_handlers[name] = handler + return self + + def discover_configs(self) -> list[AgentConfig]: + """扫描配置目录,发现所有 YAML 配置""" + configs = [] + config_path = Path(self._config_dir) + + if not config_path.exists(): + logger.warning(f"Config directory '{self._config_dir}' not found") + return configs + + for yaml_file in sorted(config_path.glob("*.yaml")): + try: + config = AgentConfig.from_yaml(str(yaml_file)) + configs.append(config) + logger.info(f"Discovered agent config: {config.name} from {yaml_file.name}") + except Exception as e: + logger.error(f"Failed to load config '{yaml_file}': {e}") + + for yml_file in sorted(config_path.glob("*.yml")): + try: + config = AgentConfig.from_yaml(str(yml_file)) + configs.append(config) + logger.info(f"Discovered agent config: {config.name} from {yml_file.name}") + except Exception as e: + logger.error(f"Failed to load config '{yml_file}': {e}") + + return configs + + def build_agents(self) -> dict[str, ConfigDrivenAgent]: + """从发现的配置构建 Agent 实例""" + configs = self.discover_configs() + self._agents.clear() + + for config in configs: + try: + agent = ConfigDrivenAgent( + config=config, + tool_registry=self._tool_registry, + llm_client=self._llm_client, + custom_handlers=self._custom_handlers, + ) + self._agents[config.name] = agent + logger.info(f"Built agent: {config.name} (mode={config.task_mode})") + except Exception as e: + logger.error(f"Failed to build agent '{config.name}': {e}") + + return self._agents + + async def start_all(self, redis_url: str = "") -> None: + """启动所有已构建的 Agent""" + if not self._agents: + self.build_agents() + + for name, agent in self._agents.items(): + try: + await agent.start(redis_url=redis_url) + logger.info(f"Agent '{name}' started") + except Exception as e: + logger.error(f"Failed to start agent '{name}': {e}") + + async def stop_all(self) -> None: + """停止所有 Agent""" + for name, agent in self._agents.items(): + try: + await agent.stop() + logger.info(f"Agent '{name}' stopped") + except Exception as e: + logger.error(f"Failed to stop agent '{name}': {e}") + + async def execute_task(self, agent_name: str, task) -> dict | None: + """在指定 Agent 上执行任务(本地模式)""" + if agent_name not in self._agents: + logger.error(f"Agent '{agent_name}' not found") + return None + + agent = self._agents[agent_name] + result = await agent.execute(task) + return result.to_dict() + + +async def main(): + """命令行入口""" + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(name)s] %(levelname)s: %(message)s", + ) + + config_dir = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_CONFIG_DIR + redis_url = os.environ.get("REDIS_URL", "") + + runner = StandaloneRunner(config_dir=config_dir) + agents = runner.build_agents() + + if not agents: + logger.error("No agents discovered. Check your config directory.") + sys.exit(1) + + logger.info(f"Discovered {len(agents)} agent(s): {list(agents.keys())}") + + try: + await runner.start_all(redis_url=redis_url) + logger.info("All agents started. Press Ctrl+C to stop.") + + # 保持运行 + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + logger.info("Shutting down...") + finally: + await runner.stop_all() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/unit/test_config_driven.py b/tests/unit/test_config_driven.py new file mode 100644 index 0000000..13b958f --- /dev/null +++ b/tests/unit/test_config_driven.py @@ -0,0 +1,356 @@ +"""U2 测试: ConfigDrivenAgent + YAML 配置驱动""" + +import json +import tempfile +from pathlib import Path + +import pytest +import yaml + +from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent +from agentkit.core.protocol import TaskMessage, TaskStatus +from agentkit.core.standalone import StandaloneRunner +from agentkit.tools.function_tool import FunctionTool +from agentkit.tools.registry import ToolRegistry + + +# ── Fixtures ────────────────────────────────────────────── + + +def _make_task(**overrides) -> TaskMessage: + defaults = dict( + task_id="test-task-001", + agent_name="test_agent", + task_type="generate", + priority=1, + input_data={"query": "hello"}, + callback_url=None, + created_at=None, + ) + defaults.update(overrides) + return TaskMessage.from_dict(defaults) + + +def _sample_llm_config() -> dict: + return { + "name": "content_generator", + "agent_type": "content_generation", + "version": "1.0.0", + "description": "内容生成 Agent", + "task_mode": "llm_generate", + "supported_tasks": ["content_generation"], + "max_concurrency": 2, + "prompt": { + "identity": "你是一个专业的内容生成助手", + "instructions": "根据用户需求生成高质量内容", + "output_format": "以 JSON 格式输出 {title, content}", + }, + "llm": { + "model": "gpt-4", + "temperature": 0.7, + }, + } + + +def _sample_tool_call_config() -> dict: + return { + "name": "citation_detector", + "agent_type": "citation_detection", + "task_mode": "tool_call", + "description": "引用检测 Agent", + "tools": ["check_citation"], + } + + +def _sample_custom_config() -> dict: + return { + "name": "monitor", + "agent_type": "monitoring", + "task_mode": "custom", + "description": "监控 Agent", + "custom_handler": "my_handler", + } + + +# ── AgentConfig 测试 ────────────────────────────────────── + + +class TestAgentConfig: + def test_from_dict_llm_generate(self): + config = AgentConfig.from_dict(_sample_llm_config()) + assert config.name == "content_generator" + assert config.task_mode == "llm_generate" + assert config.prompt["identity"] == "你是一个专业的内容生成助手" + assert config.llm["model"] == "gpt-4" + + def test_from_dict_tool_call(self): + config = AgentConfig.from_dict(_sample_tool_call_config()) + assert config.task_mode == "tool_call" + assert config.tools == ["check_citation"] + + def test_from_dict_custom(self): + config = AgentConfig.from_dict(_sample_custom_config()) + assert config.task_mode == "custom" + assert config.custom_handler == "my_handler" + + def test_invalid_task_mode(self): + with pytest.raises(Exception, match="Invalid task_mode"): + AgentConfig(name="x", agent_type="x", task_mode="invalid_mode") + + def test_llm_generate_requires_prompt(self): + with pytest.raises(Exception, match="llm_generate mode requires"): + AgentConfig(name="x", agent_type="x", task_mode="llm_generate", prompt=None) + + def test_tool_call_requires_tools(self): + with pytest.raises(Exception, match="tool_call mode requires"): + AgentConfig(name="x", agent_type="x", task_mode="tool_call", tools=[]) + + def test_custom_requires_handler(self): + with pytest.raises(Exception, match="custom mode requires"): + AgentConfig(name="x", agent_type="x", task_mode="custom", custom_handler=None) + + def test_from_yaml(self): + with tempfile.NamedTemporaryFile( + mode="w", suffix=".yaml", delete=False + ) as f: + yaml.dump(_sample_llm_config(), f) + f.flush() + config = AgentConfig.from_yaml(f.name) + + assert config.name == "content_generator" + assert config.task_mode == "llm_generate" + + def test_to_dict_roundtrip(self): + original = _sample_llm_config() + config = AgentConfig.from_dict(original) + result = config.to_dict() + assert result["name"] == original["name"] + assert result["task_mode"] == original["task_mode"] + assert result["prompt"] == original["prompt"] + + +# ── ConfigDrivenAgent 测试 ──────────────────────────────── + + +class TestConfigDrivenAgent: + async def test_llm_generate_no_client(self): + """无 LLM 客户端时降级返回渲染后的 Prompt""" + config = AgentConfig.from_dict(_sample_llm_config()) + agent = ConfigDrivenAgent(config=config) + + task = _make_task() + result = await agent.handle_task(task) + + assert result["mode"] == "llm_generate_no_client" + assert len(result["messages"]) == 2 # system + user + + async def test_llm_generate_with_client(self): + """有 LLM 客户端时调用 LLM 并解析结果""" + + class MockLLMClient: + async def chat(self, messages, **kwargs): + return json.dumps({"title": "Test", "content": "Hello world"}) + + config = AgentConfig.from_dict(_sample_llm_config()) + agent = ConfigDrivenAgent(config=config, llm_client=MockLLMClient()) + + task = _make_task() + result = await agent.handle_task(task) + + assert result["title"] == "Test" + assert result["content"] == "Hello world" + + async def test_llm_generate_with_markdown_json(self): + """LLM 返回 markdown 代码块包裹的 JSON""" + + class MockLLMClient: + async def chat(self, messages, **kwargs): + return '```json\n{"title": "Wrapped", "content": "In markdown"}\n```' + + config = AgentConfig.from_dict(_sample_llm_config()) + agent = ConfigDrivenAgent(config=config, llm_client=MockLLMClient()) + + task = _make_task() + result = await agent.handle_task(task) + + assert result["title"] == "Wrapped" + + async def test_llm_generate_fallback_text(self): + """LLM 返回非 JSON 时降级为文本""" + + class MockLLMClient: + async def chat(self, messages, **kwargs): + return "This is plain text response" + + config = AgentConfig.from_dict(_sample_llm_config()) + agent = ConfigDrivenAgent(config=config, llm_client=MockLLMClient()) + + task = _make_task() + result = await agent.handle_task(task) + + assert result["text"] == "This is plain text response" + + async def test_tool_call_mode(self): + """tool_call 模式调用注册的 Tool""" + registry = ToolRegistry() + + async def check_citation(url: str, **kwargs) -> dict: + return {"found": True, "url": url} + + tool = FunctionTool( + name="check_citation", + description="Check citation", + func=check_citation, + ) + registry.register(tool) + + config = AgentConfig.from_dict(_sample_tool_call_config()) + agent = ConfigDrivenAgent(config=config, tool_registry=registry) + + task = _make_task(input_data={"url": "https://example.com"}) + result = await agent.handle_task(task) + + assert result["found"] is True + assert result["url"] == "https://example.com" + + async def test_custom_mode(self): + """custom 模式调用自定义 handler""" + config = AgentConfig.from_dict(_sample_custom_config()) + + async def my_handler(task): + return {"status": "monitored", "task_id": task.task_id} + + agent = ConfigDrivenAgent( + config=config, + custom_handlers={"my_handler": my_handler}, + ) + + task = _make_task() + result = await agent.handle_task(task) + + assert result["status"] == "monitored" + assert result["task_id"] == "test-task-001" + + async def test_execute_wraps_task_result(self): + """execute() 自动包装 handle_task 结果为 TaskResult""" + config = AgentConfig.from_dict(_sample_llm_config()) + agent = ConfigDrivenAgent(config=config) + + task = _make_task() + result = await agent.execute(task) + + assert result.status == TaskStatus.COMPLETED + assert result.output_data is not None + assert result.metrics["elapsed_seconds"] >= 0 + + def test_get_capabilities(self): + """能力声明从配置正确构建""" + config = AgentConfig.from_dict(_sample_llm_config()) + agent = ConfigDrivenAgent(config=config) + cap = agent.get_capabilities() + + assert cap.agent_name == "content_generator" + assert cap.agent_type == "content_generation" + assert cap.max_concurrency == 2 + assert "content_generation" in cap.supported_tasks + + def test_prompt_template_rendering(self): + """Prompt 模板正确渲染""" + config = AgentConfig.from_dict(_sample_llm_config()) + agent = ConfigDrivenAgent(config=config) + + assert agent.prompt_template is not None + messages = agent.prompt_template.render(variables={"query": "test"}) + assert len(messages) == 2 + assert "专业的内容生成助手" in messages[0]["content"] + + async def test_callable_llm_client(self): + """LLM 客户端为可调用对象""" + + async def mock_llm(messages, **kwargs): + return '{"result": "from_callable"}' + + config = AgentConfig.from_dict(_sample_llm_config()) + agent = ConfigDrivenAgent(config=config, llm_client=mock_llm) + + task = _make_task() + result = await agent.handle_task(task) + + assert result["result"] == "from_callable" + + +# ── StandaloneRunner 测试 ───────────────────────────────── + + +class TestStandaloneRunner: + def test_discover_configs(self): + """自动发现 YAML 配置""" + with tempfile.TemporaryDirectory() as tmpdir: + for name in ["agent_a.yaml", "agent_b.yml"]: + config = { + "name": name.replace(".", "_"), + "agent_type": "test", + "task_mode": "llm_generate", + "prompt": {"identity": "test", "instructions": "test"}, + } + with open(Path(tmpdir) / name, "w") as f: + yaml.dump(config, f) + + runner = StandaloneRunner(config_dir=tmpdir) + configs = runner.discover_configs() + + assert len(configs) == 2 + + def test_build_agents(self): + """从配置构建 Agent 实例""" + with tempfile.TemporaryDirectory() as tmpdir: + config = _sample_llm_config() + with open(Path(tmpdir) / "test.yaml", "w") as f: + yaml.dump(config, f) + + runner = StandaloneRunner(config_dir=tmpdir) + agents = runner.build_agents() + + assert "content_generator" in agents + assert agents["content_generator"].config.task_mode == "llm_generate" + + def test_add_tool(self): + """添加工具到注册中心""" + runner = StandaloneRunner() + + async def my_tool(x: int) -> dict: + return {"doubled": x * 2} + + tool = FunctionTool(name="my_tool", description="Test tool", func=my_tool) + runner.add_tool(tool) + + assert runner._tool_registry.has_tool("my_tool") + + async def test_execute_task_local(self): + """本地模式执行任务""" + with tempfile.TemporaryDirectory() as tmpdir: + config = _sample_llm_config() + with open(Path(tmpdir) / "test.yaml", "w") as f: + yaml.dump(config, f) + + runner = StandaloneRunner(config_dir=tmpdir) + runner.build_agents() + + task = _make_task(agent_name="content_generator") + result = await runner.execute_task("content_generator", task) + + assert result is not None + assert result["status"] == "completed" + + def test_empty_config_dir(self): + """空配置目录不报错""" + with tempfile.TemporaryDirectory() as tmpdir: + runner = StandaloneRunner(config_dir=tmpdir) + configs = runner.discover_configs() + assert len(configs) == 0 + + def test_nonexistent_config_dir(self): + """不存在的配置目录不报错""" + runner = StandaloneRunner(config_dir="/nonexistent/path") + configs = runner.discover_configs() + assert len(configs) == 0