diff --git a/configs/pipelines/geo_full_pipeline.yaml b/configs/pipelines/geo_full_pipeline.yaml index 2ed1e55..9d78d24 100644 --- a/configs/pipelines/geo_full_pipeline.yaml +++ b/configs/pipelines/geo_full_pipeline.yaml @@ -1,5 +1,5 @@ name: geo_full_pipeline -description: "GEO 端到端工作流:检测→分析→优化→追踪" +description: "GEO 端到端工作流:检测→分析→优化→Schema→内容生成→去AI化→追踪" steps: - name: detect @@ -35,6 +35,20 @@ steps: optimization: $.steps.optimize.output depends_on: [optimize] + - name: generate_content + skill: content_generator + input_mapping: + brand: $.input.brand + optimization: $.steps.optimize.output + schema: $.steps.schema.output + depends_on: [schema] + + - name: deai + skill: deai_agent + input_mapping: + content: $.steps.generate_content.output + depends_on: [generate_content] + - name: monitor skill: monitor input_mapping: diff --git a/configs/skills/citation_detector.yaml b/configs/skills/citation_detector.yaml index 25def00..285720b 100644 --- a/configs/skills/citation_detector.yaml +++ b/configs/skills/citation_detector.yaml @@ -47,6 +47,8 @@ output_schema: tools: - execute_single_platform - get_or_create_task + - baidu_search + - web_crawl memory: working: diff --git a/configs/skills/competitor_analyzer.yaml b/configs/skills/competitor_analyzer.yaml index 9397612..43368d2 100644 --- a/configs/skills/competitor_analyzer.yaml +++ b/configs/skills/competitor_analyzer.yaml @@ -47,6 +47,8 @@ output_schema: tools: - competitor_analyze - competitor_gap_analysis + - baidu_search + - web_crawl memory: working: diff --git a/configs/skills/content_generator.yaml b/configs/skills/content_generator.yaml index 3b88414..c8c6081 100644 --- a/configs/skills/content_generator.yaml +++ b/configs/skills/content_generator.yaml @@ -93,6 +93,7 @@ llm: tools: - retrieve_knowledge + - baidu_search quality_gate: required_fields: ["content"] diff --git a/configs/skills/geo_optimizer.yaml b/configs/skills/geo_optimizer.yaml index ceccb3e..389a73b 100644 --- 
"""BaiduSearchTool - Baidu search tool with graceful degradation.

Runs a keyword search through a Baidu search API and returns a list of
results.  When the API is not configured or fails, the tool falls back to
scraping the public Baidu result page; when that fails as well it returns an
error payload instead of raising.
"""

import asyncio
import json
import logging
import re
import urllib.parse
import urllib.request
from html import unescape
from typing import Any

from agentkit.tools.base import Tool

logger = logging.getLogger(__name__)

# Seconds to wait for either the API or the result page.
_REQUEST_TIMEOUT = 30
# How far past a result title we look for its snippet, in characters.
_SNIPPET_WINDOW = 2000
# Snippets are truncated to this many characters.
_SNIPPET_MAX_CHARS = 200

# Strips any remaining markup from captured title/snippet fragments.
_TAG_RE = re.compile(r"<[^>]+>")

# A result title: an <h3> whose class attribute contains "t", wrapping an
# <a href="...">title</a>.  Group 1 is the href, group 2 the anchor body.
_RESULT_RE = re.compile(
    r'<h3[^>]*class="[^"]*t[^"]*"[^>]*>.*?href="([^"]*)"[^>]*>(.*?)</a>',
    re.DOTALL,
)

# A result snippet: an element whose class contains "content-right_".
# NOTE(review): the wrapping tag is assumed to be <span> or <div> - confirm
# against the current Baidu result page markup.
_SNIPPET_RE = re.compile(
    r'<(?P<tag>span|div)[^>]*class="[^"]*content-right_[^"]*"[^>]*>'
    r"(?P<body>.*?)</(?P=tag)>",
    re.DOTALL,
)


class BaiduSearchTool(Tool):
    """Baidu search tool - runs a keyword search and returns the results.

    Two modes are supported:
    1. Baidu search API (used when both ``api_key`` and ``api_url`` are set).
    2. Scraping the Baidu result page (fallback mode, no API key needed).

    When neither mode works, ``execute`` returns a payload with
    ``success=False`` and an ``error`` message rather than raising.
    """

    def __init__(
        self,
        name: str = "baidu_search",
        description: str = "执行百度搜索,返回搜索结果列表",
        input_schema: dict[str, Any] | None = None,
        output_schema: dict[str, Any] | None = None,
        version: str = "1.0.0",
        tags: list[str] | None = None,
        api_key: str | None = None,
        api_url: str | None = None,
    ):
        """Configure the tool metadata and the optional API credentials.

        Args:
            name: Tool name used for registry lookup.
            description: Human-readable tool description.
            input_schema: JSON schema for the input; a default is used if omitted.
            output_schema: JSON schema for the output; a default is used if omitted.
            version: Tool version string.
            tags: Tool tags; defaults to ``["search", "baidu"]``.
            api_key: Bearer token for the search API (enables API mode).
            api_url: Endpoint of the search API (enables API mode).
        """
        super().__init__(
            name=name,
            description=description,
            input_schema=input_schema or self._default_input_schema(),
            output_schema=output_schema or self._default_output_schema(),
            version=version,
            tags=tags or ["search", "baidu"],
        )
        self._api_key = api_key
        self._api_url = api_url

    @staticmethod
    def _default_input_schema() -> dict[str, Any]:
        """Return the default JSON schema describing the tool input."""
        return {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "搜索关键词",
                },
                "max_results": {
                    "type": "integer",
                    "description": "最大返回结果数",
                    "default": 5,
                },
            },
            "required": ["query"],
        }

    @staticmethod
    def _default_output_schema() -> dict[str, Any]:
        """Return the default JSON schema describing the tool output."""
        return {
            "type": "object",
            "properties": {
                "results": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "title": {"type": "string"},
                            "url": {"type": "string"},
                            "snippet": {"type": "string"},
                        },
                    },
                    "description": "搜索结果列表",
                },
                "total": {"type": "integer", "description": "结果总数"},
                "success": {"type": "boolean", "description": "是否成功"},
                "error": {"type": "string", "description": "错误信息(仅失败时)"},
            },
        }

    @staticmethod
    def _failure(message: str) -> dict:
        """Build the uniform failure payload carrying ``message``."""
        return {"error": message, "results": [], "total": 0, "success": False}

    @staticmethod
    def _fetch(req: urllib.request.Request, timeout: float = _REQUEST_TIMEOUT) -> bytes:
        """Perform a blocking HTTP request and return the raw response body.

        Kept synchronous on purpose: callers run it in a worker thread so the
        event loop is never blocked by network I/O.
        """
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            return resp.read()

    async def execute(self, **kwargs) -> dict:
        """Run a Baidu search.

        Args:
            query: Search keywords (required).
            max_results: Maximum number of results (default 5).  ``None`` is
                treated as the default; negative values are clamped to 0.

        Returns:
            A dict with a ``results`` list, a ``total`` count and a
            ``success`` flag; ``error`` is present only on failure.
        """
        query = kwargs.get("query")
        if not query:
            return self._failure("query 参数是必需的")

        requested = kwargs.get("max_results")
        if requested is None:
            requested = 5
        try:
            # A negative limit would otherwise turn into a "drop the last N"
            # slice in API mode, so clamp it.
            max_results = max(0, int(requested))
        except (TypeError, ValueError):
            return self._failure("max_results 必须是整数")

        # Prefer API mode when it is fully configured.
        if self._api_key and self._api_url:
            return await self._search_via_api(query, max_results)

        # Fallback: scrape the public result page.
        return await self._search_via_scrape(query, max_results)

    async def _search_via_api(self, query: str, max_results: int) -> dict:
        """Search through the configured API, falling back to scraping on any error."""
        try:
            params = {
                "query": query,
                "num": max_results,
            }
            # Respect an endpoint that already carries its own query string.
            separator = "&" if "?" in self._api_url else "?"
            url = f"{self._api_url}{separator}{urllib.parse.urlencode(params)}"
            req = urllib.request.Request(
                url,
                headers={
                    "User-Agent": "AgentKit/1.0",
                    "Authorization": f"Bearer {self._api_key}",
                },
            )
            payload = await asyncio.to_thread(self._fetch, req)
            data = json.loads(payload.decode("utf-8"))

            results = [
                {
                    "title": item.get("title", ""),
                    "url": item.get("url", ""),
                    "snippet": item.get("snippet", ""),
                }
                for item in data.get("results", [])[:max_results]
            ]
            return {"results": results, "total": len(results), "success": True}

        except Exception as e:
            # Deliberate best-effort boundary: any API problem (network,
            # auth, malformed payload) degrades to scrape mode.
            logger.error("BaiduSearchTool API 搜索失败: %s", e)
            return await self._search_via_scrape(query, max_results)

    async def _search_via_scrape(self, query: str, max_results: int) -> dict:
        """Search by fetching and parsing the Baidu result page (fallback mode)."""
        try:
            encoded_query = urllib.parse.quote(query)
            url = f"https://www.baidu.com/s?wd={encoded_query}&rn={max_results}"
            req = urllib.request.Request(
                url,
                headers={
                    "User-Agent": (
                        "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
                        "AppleWebKit/537.36 (KHTML, like Gecko) "
                        "Chrome/120.0.0.0 Safari/537.36"
                    ),
                },
            )
            payload = await asyncio.to_thread(self._fetch, req)
            page = payload.decode("utf-8", errors="replace")

            results = self._parse_baidu_html(page, max_results)
            return {"results": results, "total": len(results), "success": True}

        except Exception as e:
            # Last line of defence: report the failure instead of raising.
            logger.error("BaiduSearchTool 抓取搜索失败: %s", e)
            return self._failure(f"百度搜索不可用: {e}")

    @staticmethod
    def _parse_baidu_html(html: str, max_results: int) -> list[dict[str, str]]:
        """Extract title, URL and snippet for each result in a Baidu result page.

        Baidu's markup changes over time, so this parser is best-effort: it
        returns whatever it can recognise and an empty list otherwise.

        Args:
            html: Decoded HTML of the result page.
            max_results: Upper bound on the number of results returned.

        Returns:
            A list of ``{"title", "url", "snippet"}`` dicts, with HTML
            entities decoded and snippets truncated to 200 characters.
        """
        results: list[dict[str, str]] = []
        matches = list(_RESULT_RE.finditer(html))

        for index, match in enumerate(matches):
            if len(results) >= max_results:
                break

            url = unescape(match.group(1))
            title = unescape(_TAG_RE.sub("", match.group(2))).strip()

            # Keep Baidu redirect links and absolute http(s) links; drop
            # everything else (relative paths, javascript: pseudo-links).
            if "baidu.com/link?" not in url and not url.startswith("http"):
                continue

            # Look for the snippet right after the title, but never past the
            # next result's title, so a result without a snippet does not
            # pick up its neighbour's.
            window_end = match.end() + _SNIPPET_WINDOW
            if index + 1 < len(matches):
                window_end = min(window_end, matches[index + 1].start())

            snippet = ""
            snippet_match = _SNIPPET_RE.search(html[match.end():window_end])
            if snippet_match:
                snippet = unescape(
                    _TAG_RE.sub("", snippet_match.group("body"))
                ).strip()

            results.append({
                "title": title,
                "url": url,
                "snippet": snippet[:_SNIPPET_MAX_CHARS],
            })

        return results
import os
from unittest.mock import AsyncMock

import pytest
import yaml

from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
from agentkit.skills.base import Skill, SkillConfig
from agentkit.skills.loader import SkillLoader
from agentkit.skills.registry import SkillRegistry
from agentkit.tools.baidu_search import BaiduSearchTool
from agentkit.tools.base import Tool
from agentkit.tools.function_tool import FunctionTool
from agentkit.tools.registry import ToolRegistry
from agentkit.tools.schema_tools import SchemaExtractTool, SchemaGenerateTool
from agentkit.tools.web_crawl import WebCrawlTool


# ── Fixtures and helpers ────────────────────────────────────

CONFIGS_DIR = os.path.join(
    os.path.dirname(__file__), "..", "..", "configs"
)
SKILLS_DIR = os.path.join(CONFIGS_DIR, "skills")
PIPELINES_DIR = os.path.join(CONFIGS_DIR, "pipelines")


def _skill_config(skill_name):
    """Load ``<skill_name>.yaml`` from the skills config directory."""
    return SkillConfig.from_yaml(os.path.join(SKILLS_DIR, f"{skill_name}.yaml"))


def _bound_tool_names(skill_name, registry):
    """Build a ConfigDrivenAgent for the skill and return its bound tool names."""
    agent = ConfigDrivenAgent(
        config=_skill_config(skill_name),
        tool_registry=registry,
    )
    return [tool.name for tool in agent.get_tools()]


def _load_geo_skills(tool_registry):
    """Load every skill in SKILLS_DIR; return ``(skill_registry, loaded_skills)``."""
    skill_registry = SkillRegistry()
    loader = SkillLoader(
        skill_registry=skill_registry,
        tool_registry=tool_registry,
    )
    return skill_registry, loader.load_from_directory(SKILLS_DIR)


def _pipeline_config():
    """Parse geo_full_pipeline.yaml.

    The file contains non-ASCII text, so it is read explicitly as UTF-8
    instead of relying on the platform's locale encoding.
    """
    path = os.path.join(PIPELINES_DIR, "geo_full_pipeline.yaml")
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def _assert_contains_all(expected, actual):
    """Assert that every name in ``expected`` is present in ``actual``."""
    missing = [name for name in expected if name not in actual]
    assert not missing, f"missing entries: {missing}"


@pytest.fixture
def tool_registry_with_infra_tools():
    """A ToolRegistry pre-populated with the infrastructure tools."""
    registry = ToolRegistry()
    for tool in (
        BaiduSearchTool(),
        WebCrawlTool(),
        SchemaExtractTool(),
        SchemaGenerateTool(),
    ):
        registry.register(tool)
    return registry


@pytest.fixture
def tool_registry_empty():
    """An empty ToolRegistry (used to exercise the missing-tool fallback)."""
    return ToolRegistry()


# ── Test: SkillConfig tools field loading ───────────────────


class TestSkillConfigToolsField:
    """The tools field of each skill YAML is loaded into SkillConfig."""

    def test_citation_detector_tools_loaded(self):
        """citation_detector declares baidu_search + web_crawl and keeps its business tools."""
        config = _skill_config("citation_detector")
        _assert_contains_all(
            [
                "baidu_search",
                "web_crawl",
                "execute_single_platform",
                "get_or_create_task",
            ],
            config.tools,
        )

    def test_competitor_analyzer_tools_loaded(self):
        """competitor_analyzer declares baidu_search + web_crawl."""
        config = _skill_config("competitor_analyzer")
        _assert_contains_all(
            ["baidu_search", "web_crawl", "competitor_analyze"], config.tools
        )

    def test_geo_optimizer_tools_loaded(self):
        """geo_optimizer declares schema_generate."""
        config = _skill_config("geo_optimizer")
        assert "schema_generate" in config.tools

    def test_schema_advisor_tools_loaded(self):
        """schema_advisor declares schema_extract + schema_generate."""
        config = _skill_config("schema_advisor")
        _assert_contains_all(
            ["schema_extract", "schema_generate", "fill_schema_with_llm"],
            config.tools,
        )

    def test_monitor_tools_loaded(self):
        """monitor declares baidu_search."""
        config = _skill_config("monitor")
        _assert_contains_all(
            ["baidu_search", "monitor_check_and_compare"], config.tools
        )

    def test_trend_agent_tools_loaded(self):
        """trend_agent declares baidu_search + web_crawl."""
        config = _skill_config("trend_agent")
        _assert_contains_all(["baidu_search", "web_crawl"], config.tools)

    def test_content_generator_tools_loaded(self):
        """content_generator declares baidu_search."""
        config = _skill_config("content_generator")
        _assert_contains_all(["baidu_search", "retrieve_knowledge"], config.tools)

    def test_deai_agent_tools_loaded(self):
        """deai_agent declares detect_ai_patterns."""
        config = _skill_config("deai_agent")
        assert "detect_ai_patterns" in config.tools

    def test_all_skills_load_without_error(self):
        """Every GEO skill YAML loads successfully."""
        yaml_files = [
            "citation_detector.yaml",
            "competitor_analyzer.yaml",
            "geo_optimizer.yaml",
            "monitor.yaml",
            "schema_advisor.yaml",
            "trend_agent.yaml",
            "content_generator.yaml",
            "deai_agent.yaml",
        ]
        for filename in yaml_files:
            config = SkillConfig.from_yaml(
                os.path.join(SKILLS_DIR, filename)
            )
            assert config.name, f"{filename} should have a name"
            assert config.tools is not None, f"{filename} should have tools field"


# ── Test: ConfigDrivenAgent tool binding ────────────────────


class TestConfigDrivenAgentToolBinding:
    """ConfigDrivenAgent binds the declared tools found in the ToolRegistry."""

    def test_citation_detector_binds_search_and_crawl(
        self, tool_registry_with_infra_tools
    ):
        """citation_detector binds baidu_search + web_crawl."""
        tool_names = _bound_tool_names(
            "citation_detector", tool_registry_with_infra_tools
        )
        _assert_contains_all(["baidu_search", "web_crawl"], tool_names)

    def test_competitor_analyzer_binds_search_and_crawl(
        self, tool_registry_with_infra_tools
    ):
        """competitor_analyzer binds baidu_search + web_crawl."""
        tool_names = _bound_tool_names(
            "competitor_analyzer", tool_registry_with_infra_tools
        )
        _assert_contains_all(["baidu_search", "web_crawl"], tool_names)

    def test_geo_optimizer_binds_schema_generate(
        self, tool_registry_with_infra_tools
    ):
        """geo_optimizer binds schema_generate."""
        tool_names = _bound_tool_names(
            "geo_optimizer", tool_registry_with_infra_tools
        )
        assert "schema_generate" in tool_names

    def test_schema_advisor_binds_extract_and_generate(
        self, tool_registry_with_infra_tools
    ):
        """schema_advisor binds schema_extract + schema_generate."""
        tool_names = _bound_tool_names(
            "schema_advisor", tool_registry_with_infra_tools
        )
        _assert_contains_all(["schema_extract", "schema_generate"], tool_names)

    def test_monitor_binds_search(
        self, tool_registry_with_infra_tools
    ):
        """monitor binds baidu_search."""
        tool_names = _bound_tool_names("monitor", tool_registry_with_infra_tools)
        assert "baidu_search" in tool_names

    def test_trend_agent_binds_search_and_crawl(
        self, tool_registry_with_infra_tools
    ):
        """trend_agent binds baidu_search + web_crawl."""
        tool_names = _bound_tool_names(
            "trend_agent", tool_registry_with_infra_tools
        )
        _assert_contains_all(["baidu_search", "web_crawl"], tool_names)


# ── Test: graceful degradation when tools are missing ───────


class TestToolNotFoundGracefulDegradation:
    """Declared tools that are absent from the ToolRegistry are skipped."""

    def test_missing_tool_does_not_crash(self, tool_registry_empty):
        """The agent is still created when none of its tools are registered."""
        # citation_detector declares baidu_search, web_crawl,
        # execute_single_platform and get_or_create_task; none of them is in
        # the empty registry.
        agent = ConfigDrivenAgent(
            config=_skill_config("citation_detector"),
            tool_registry=tool_registry_empty,
        )
        # Construction succeeds; nothing is bound.
        assert agent is not None
        assert len(agent.get_tools()) == 0

    def test_partial_tool_binding(self):
        """Only the tools actually present in the registry are bound."""
        registry = ToolRegistry()
        # Register baidu_search only; web_crawl is deliberately left out.
        registry.register(BaiduSearchTool())

        tool_names = _bound_tool_names("citation_detector", registry)
        assert "baidu_search" in tool_names
        assert "web_crawl" not in tool_names


# ── Test: SkillLoader batch loading ─────────────────────────


class TestSkillLoaderBatchLoad:
    """SkillLoader loads all skills from a directory and binds their tools."""

    def test_load_all_geo_skills(self, tool_registry_with_infra_tools):
        """Every GEO skill in the skills directory is loaded."""
        _, skills = _load_geo_skills(tool_registry_with_infra_tools)

        skill_names = [skill.name for skill in skills]
        _assert_contains_all(
            [
                "citation_detector",
                "competitor_analyzer",
                "geo_optimizer",
                "monitor",
                "schema_advisor",
                "trend_agent",
                "content_generator",
                "deai_agent",
            ],
            skill_names,
        )

    def test_citation_detector_skill_has_tools(
        self, tool_registry_with_infra_tools
    ):
        """The loaded citation_detector skill has search + crawl bound."""
        skill_registry, _ = _load_geo_skills(tool_registry_with_infra_tools)

        skill = skill_registry.get("citation_detector")
        tool_names = [tool.name for tool in skill.tools]
        _assert_contains_all(["baidu_search", "web_crawl"], tool_names)

    def test_schema_advisor_skill_has_tools(
        self, tool_registry_with_infra_tools
    ):
        """The loaded schema_advisor skill has extract + generate bound."""
        skill_registry, _ = _load_geo_skills(tool_registry_with_infra_tools)

        skill = skill_registry.get("schema_advisor")
        tool_names = [tool.name for tool in skill.tools]
        _assert_contains_all(["schema_extract", "schema_generate"], tool_names)

    def test_geo_optimizer_skill_has_schema_generate(
        self, tool_registry_with_infra_tools
    ):
        """The loaded geo_optimizer skill has schema_generate bound."""
        skill_registry, _ = _load_geo_skills(tool_registry_with_infra_tools)

        skill = skill_registry.get("geo_optimizer")
        tool_names = [tool.name for tool in skill.tools]
        assert "schema_generate" in tool_names


# ── Test: GEO pipeline config loading ───────────────────────


class TestGEOPipelineConfig:
    """The GEO pipeline YAML loads and is internally consistent."""

    def test_pipeline_config_loads(self):
        """geo_full_pipeline.yaml parses and has at least one step."""
        config = _pipeline_config()
        assert config["name"] == "geo_full_pipeline"
        assert len(config["steps"]) > 0

    def test_pipeline_has_all_steps(self):
        """The pipeline contains every GEO step, including the new ones."""
        step_names = [step["name"] for step in _pipeline_config()["steps"]]
        core_steps = [
            "detect",
            "analyze_competitor",
            "analyze_trend",
            "optimize",
            "schema",
            "monitor",
        ]
        added_steps = ["generate_content", "deai"]
        _assert_contains_all(core_steps + added_steps, step_names)

    def test_pipeline_step_skills_match_yaml_names(self):
        """Each step's skill refers to an existing skill YAML file."""
        for step in _pipeline_config()["steps"]:
            skill_name = step["skill"]
            yaml_path = os.path.join(SKILLS_DIR, f"{skill_name}.yaml")
            assert os.path.exists(yaml_path), (
                f"Pipeline step '{step['name']}' references skill "
                f"'{skill_name}' but {yaml_path} does not exist"
            )

    def test_pipeline_dependency_graph_is_valid(self):
        """The depends_on graph has no cycles."""
        step_map = {
            step["name"]: step.get("depends_on", [])
            for step in _pipeline_config()["steps"]
        }

        # Depth-first search with an "on the current path" set: reaching a
        # node that is already on the path means a cycle.
        visited = set()
        in_stack = set()

        def dfs(name):
            if name in in_stack:
                return False  # cycle
            if name in visited:
                return True
            in_stack.add(name)
            for dep in step_map.get(name, []):
                if not dfs(dep):
                    return False
            in_stack.discard(name)
            visited.add(name)
            return True

        for name in step_map:
            assert dfs(name), f"Circular dependency detected involving '{name}'"

    def test_pipeline_from_config_creates_pipeline(self):
        """GEOPipeline.from_config builds a pipeline from the YAML config."""
        from agentkit.skills.geo_pipeline import GEOPipeline

        config = _pipeline_config()
        pipeline = GEOPipeline.from_config(config)
        assert pipeline.name == "geo_full_pipeline"
        assert len(pipeline._steps) == len(config["steps"])

    def test_pipeline_execution_order_respects_dependencies(self):
        """Every step runs only after all of its dependencies."""
        from agentkit.skills.geo_pipeline import GEOPipeline

        pipeline = GEOPipeline.from_config(_pipeline_config())
        groups = pipeline._build_execution_groups()

        # Walk the groups in order, tracking which steps have already run.
        executed = set()
        for group in groups:
            for step_name in group:
                step = pipeline._step_map[step_name]
                for dep in step.depends_on:
                    assert dep in executed, (
                        f"Step '{step_name}' depends on '{dep}' "
                        f"but '{dep}' hasn't been executed yet"
                    )
                executed.add(step_name)

        # No step may be left out of the execution plan.
        assert executed == set(pipeline._step_map.keys())


# ── Test: infrastructure tool instantiation ─────────────────


class TestInfrastructureToolsInstantiation:
    """The infrastructure tools can be instantiated with their defaults."""

    def test_baidu_search_tool_instantiation(self):
        """BaiduSearchTool has the expected name and tag."""
        tool = BaiduSearchTool()
        assert tool.name == "baidu_search"
        assert "search" in tool.tags

    def test_web_crawl_tool_instantiation(self):
        """WebCrawlTool has the expected name and tag."""
        tool = WebCrawlTool()
        assert tool.name == "web_crawl"
        assert "crawl" in tool.tags

    def test_schema_extract_tool_instantiation(self):
        """SchemaExtractTool has the expected name and tag."""
        tool = SchemaExtractTool()
        assert tool.name == "schema_extract"
        assert "extraction" in tool.tags

    def test_schema_generate_tool_instantiation(self):
        """SchemaGenerateTool has the expected name and tag."""
        tool = SchemaGenerateTool()
        assert tool.name == "schema_generate"
        assert "generation" in tool.tags

    def test_all_infra_tools_registered_in_registry(self):
        """All infrastructure tools can be registered and looked up."""
        registry = ToolRegistry()
        for tool in (
            BaiduSearchTool(),
            WebCrawlTool(),
            SchemaExtractTool(),
            SchemaGenerateTool(),
        ):
            registry.register(tool)

        for tool_name in (
            "baidu_search",
            "web_crawl",
            "schema_extract",
            "schema_generate",
        ):
            assert registry.has_tool(tool_name)


# ── Test: AgentConfig tools field backward compatibility ────


class TestAgentConfigToolsBackwardCompat:
    """The tools field on AgentConfig / SkillConfig stays backward compatible."""

    def test_agent_config_with_tools_list(self):
        """AgentConfig accepts a tools list."""
        config = AgentConfig(
            name="test",
            agent_type="test",
            task_mode="tool_call",
            tools=["baidu_search", "web_crawl"],
        )
        assert config.tools == ["baidu_search", "web_crawl"]

    def test_agent_config_without_tools(self):
        """AgentConfig defaults tools to an empty list."""
        config = AgentConfig(
            name="test",
            agent_type="test",
            task_mode="llm_generate",
            prompt={"identity": "test", "instructions": "test"},
        )
        assert config.tools == []

    def test_skill_config_inherits_tools(self):
        """SkillConfig accepts the tools field like AgentConfig does."""
        config = SkillConfig(
            name="test",
            agent_type="test",
            task_mode="tool_call",
            tools=["baidu_search"],
        )
        assert config.tools == ["baidu_search"]

    def test_skill_config_from_dict_with_tools(self):
        """SkillConfig.from_dict parses the tools field."""
        data = {
            "name": "test",
            "agent_type": "test",
            "task_mode": "tool_call",
            "tools": ["baidu_search", "web_crawl", "schema_generate"],
        }
        config = SkillConfig.from_dict(data)
        assert config.tools == ["baidu_search", "web_crawl", "schema_generate"]