feat(geo): U4 GEO skill tool binding with BaiduSearch and E2E tests
Add BaiduSearchTool (API mode + scraping fallback), bind tools to GEO skill YAML configs (baidu_search, web_crawl, schema_extract, schema_generate), extend geo_full_pipeline with generate_content and deai steps, add 36 E2E integration tests.
This commit is contained in:
parent
9ec1740047
commit
2e547e345a
|
|
@ -1,5 +1,5 @@
|
||||||
name: geo_full_pipeline
|
name: geo_full_pipeline
|
||||||
description: "GEO 端到端工作流:检测→分析→优化→追踪"
|
description: "GEO 端到端工作流:检测→分析→优化→Schema→内容生成→去AI化→追踪"
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: detect
|
- name: detect
|
||||||
|
|
@ -35,6 +35,20 @@ steps:
|
||||||
optimization: $.steps.optimize.output
|
optimization: $.steps.optimize.output
|
||||||
depends_on: [optimize]
|
depends_on: [optimize]
|
||||||
|
|
||||||
|
- name: generate_content
|
||||||
|
skill: content_generator
|
||||||
|
input_mapping:
|
||||||
|
brand: $.input.brand
|
||||||
|
optimization: $.steps.optimize.output
|
||||||
|
schema: $.steps.schema.output
|
||||||
|
depends_on: [schema]
|
||||||
|
|
||||||
|
- name: deai
|
||||||
|
skill: deai_agent
|
||||||
|
input_mapping:
|
||||||
|
content: $.steps.generate_content.output
|
||||||
|
depends_on: [generate_content]
|
||||||
|
|
||||||
- name: monitor
|
- name: monitor
|
||||||
skill: monitor
|
skill: monitor
|
||||||
input_mapping:
|
input_mapping:
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,8 @@ output_schema:
|
||||||
tools:
|
tools:
|
||||||
- execute_single_platform
|
- execute_single_platform
|
||||||
- get_or_create_task
|
- get_or_create_task
|
||||||
|
- baidu_search
|
||||||
|
- web_crawl
|
||||||
|
|
||||||
memory:
|
memory:
|
||||||
working:
|
working:
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,8 @@ output_schema:
|
||||||
tools:
|
tools:
|
||||||
- competitor_analyze
|
- competitor_analyze
|
||||||
- competitor_gap_analysis
|
- competitor_gap_analysis
|
||||||
|
- baidu_search
|
||||||
|
- web_crawl
|
||||||
|
|
||||||
memory:
|
memory:
|
||||||
working:
|
working:
|
||||||
|
|
|
||||||
|
|
@ -93,6 +93,7 @@ llm:
|
||||||
|
|
||||||
tools:
|
tools:
|
||||||
- retrieve_knowledge
|
- retrieve_knowledge
|
||||||
|
- baidu_search
|
||||||
|
|
||||||
quality_gate:
|
quality_gate:
|
||||||
required_fields: ["content"]
|
required_fields: ["content"]
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,8 @@ llm:
|
||||||
temperature: 0.5
|
temperature: 0.5
|
||||||
max_tokens: 8000
|
max_tokens: 8000
|
||||||
|
|
||||||
tools: []
|
tools:
|
||||||
|
- schema_generate
|
||||||
|
|
||||||
quality_gate:
|
quality_gate:
|
||||||
required_fields: ["optimized_content"]
|
required_fields: ["optimized_content"]
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,7 @@ tools:
|
||||||
- monitor_check_and_compare
|
- monitor_check_and_compare
|
||||||
- monitor_generate_report
|
- monitor_generate_report
|
||||||
- monitor_create_record
|
- monitor_create_record
|
||||||
|
- baidu_search
|
||||||
|
|
||||||
memory:
|
memory:
|
||||||
working:
|
working:
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,8 @@ output_schema:
|
||||||
|
|
||||||
tools:
|
tools:
|
||||||
- fill_schema_with_llm
|
- fill_schema_with_llm
|
||||||
|
- schema_extract
|
||||||
|
- schema_generate
|
||||||
|
|
||||||
memory:
|
memory:
|
||||||
working:
|
working:
|
||||||
|
|
|
||||||
|
|
@ -52,6 +52,8 @@ output_schema:
|
||||||
tools:
|
tools:
|
||||||
- trend_insight
|
- trend_insight
|
||||||
- trend_hotspot
|
- trend_hotspot
|
||||||
|
- baidu_search
|
||||||
|
- web_crawl
|
||||||
|
|
||||||
memory:
|
memory:
|
||||||
working:
|
working:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,223 @@
|
||||||
|
"""BaiduSearchTool - 百度搜索工具,支持优雅降级
|
||||||
|
|
||||||
|
通过百度搜索 API 执行关键词搜索,返回搜索结果列表。
|
||||||
|
当百度搜索 API 不可用时,返回包含降级提示的错误信息。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import urllib.parse
|
||||||
|
import urllib.request
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BaiduSearchTool(Tool):
|
||||||
|
"""百度搜索工具 - 执行关键词搜索,返回搜索结果
|
||||||
|
|
||||||
|
支持两种模式:
|
||||||
|
1. 百度搜索 API(需要 API key 配置)
|
||||||
|
2. 直接抓取百度搜索结果页(降级模式,无需 API key)
|
||||||
|
|
||||||
|
当两种模式都不可用时,返回包含降级提示的错误信息。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str = "baidu_search",
|
||||||
|
description: str = "执行百度搜索,返回搜索结果列表",
|
||||||
|
input_schema: dict[str, Any] | None = None,
|
||||||
|
output_schema: dict[str, Any] | None = None,
|
||||||
|
version: str = "1.0.0",
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
api_key: str | None = None,
|
||||||
|
api_url: str | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
input_schema=input_schema or self._default_input_schema(),
|
||||||
|
output_schema=output_schema or self._default_output_schema(),
|
||||||
|
version=version,
|
||||||
|
tags=tags or ["search", "baidu"],
|
||||||
|
)
|
||||||
|
self._api_key = api_key
|
||||||
|
self._api_url = api_url
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _default_input_schema() -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "搜索关键词",
|
||||||
|
},
|
||||||
|
"max_results": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "最大返回结果数",
|
||||||
|
"default": 5,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["query"],
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _default_output_schema() -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"results": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"title": {"type": "string"},
|
||||||
|
"url": {"type": "string"},
|
||||||
|
"snippet": {"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"description": "搜索结果列表",
|
||||||
|
},
|
||||||
|
"total": {"type": "integer", "description": "结果总数"},
|
||||||
|
"success": {"type": "boolean", "description": "是否成功"},
|
||||||
|
"error": {"type": "string", "description": "错误信息(仅失败时)"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(self, **kwargs) -> dict:
|
||||||
|
"""执行百度搜索
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 搜索关键词(必需)
|
||||||
|
max_results: 最大返回结果数(默认 5)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含 results 列表和 success 布尔值的字典
|
||||||
|
"""
|
||||||
|
query = kwargs.get("query")
|
||||||
|
if not query:
|
||||||
|
return {"error": "query 参数是必需的", "results": [], "total": 0, "success": False}
|
||||||
|
|
||||||
|
max_results = kwargs.get("max_results", 5)
|
||||||
|
|
||||||
|
# 优先使用 API 模式
|
||||||
|
if self._api_key and self._api_url:
|
||||||
|
return await self._search_via_api(query, max_results)
|
||||||
|
|
||||||
|
# 降级:直接抓取百度搜索结果页
|
||||||
|
return await self._search_via_scrape(query, max_results)
|
||||||
|
|
||||||
|
async def _search_via_api(self, query: str, max_results: int) -> dict:
|
||||||
|
"""通过百度搜索 API 执行搜索"""
|
||||||
|
try:
|
||||||
|
params = {
|
||||||
|
"query": query,
|
||||||
|
"num": max_results,
|
||||||
|
}
|
||||||
|
url = f"{self._api_url}?{urllib.parse.urlencode(params)}"
|
||||||
|
req = urllib.request.Request(
|
||||||
|
url,
|
||||||
|
headers={
|
||||||
|
"User-Agent": "AgentKit/1.0",
|
||||||
|
"Authorization": f"Bearer {self._api_key}",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
with urllib.request.urlopen(req, timeout=30) as resp:
|
||||||
|
data = json.loads(resp.read().decode("utf-8"))
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for item in data.get("results", [])[:max_results]:
|
||||||
|
results.append({
|
||||||
|
"title": item.get("title", ""),
|
||||||
|
"url": item.get("url", ""),
|
||||||
|
"snippet": item.get("snippet", ""),
|
||||||
|
})
|
||||||
|
|
||||||
|
return {"results": results, "total": len(results), "success": True}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"BaiduSearchTool API 搜索失败: {e}")
|
||||||
|
# 降级到抓取模式
|
||||||
|
return await self._search_via_scrape(query, max_results)
|
||||||
|
|
||||||
|
async def _search_via_scrape(self, query: str, max_results: int) -> dict:
|
||||||
|
"""通过直接抓取百度搜索结果页执行搜索(降级模式)"""
|
||||||
|
try:
|
||||||
|
encoded_query = urllib.parse.quote(query)
|
||||||
|
url = f"https://www.baidu.com/s?wd={encoded_query}&rn={max_results}"
|
||||||
|
req = urllib.request.Request(
|
||||||
|
url,
|
||||||
|
headers={
|
||||||
|
"User-Agent": (
|
||||||
|
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
|
||||||
|
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||||
|
"Chrome/120.0.0.0 Safari/537.36"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
with urllib.request.urlopen(req, timeout=30) as resp:
|
||||||
|
html = resp.read().decode("utf-8", errors="replace")
|
||||||
|
|
||||||
|
# 简单解析搜索结果(基于百度搜索结果页 HTML 结构)
|
||||||
|
results = self._parse_baidu_html(html, max_results)
|
||||||
|
|
||||||
|
return {"results": results, "total": len(results), "success": True}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"BaiduSearchTool 抓取搜索失败: {e}")
|
||||||
|
return {
|
||||||
|
"error": f"百度搜索不可用: {e}",
|
||||||
|
"results": [],
|
||||||
|
"total": 0,
|
||||||
|
"success": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_baidu_html(html: str, max_results: int) -> list[dict[str, str]]:
|
||||||
|
"""解析百度搜索结果页 HTML,提取标题、URL、摘要
|
||||||
|
|
||||||
|
注意:百度 HTML 结构可能变化,此解析器尽力提取关键信息。
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
results: list[dict[str, str]] = []
|
||||||
|
|
||||||
|
# 匹配百度搜索结果块
|
||||||
|
# 百度搜索结果通常在 <div class="result c-container"> 中
|
||||||
|
pattern = re.compile(
|
||||||
|
r'<h3[^>]*class="[^"]*t[^"]*"[^>]*>.*?href="([^"]*)"[^>]*>(.*?)</a>',
|
||||||
|
re.DOTALL,
|
||||||
|
)
|
||||||
|
snippet_pattern = re.compile(
|
||||||
|
r'<span[^>]*class="[^"]*content-right_[^"]*"[^>]*>(.*?)</span>',
|
||||||
|
re.DOTALL,
|
||||||
|
)
|
||||||
|
|
||||||
|
for match in pattern.finditer(html):
|
||||||
|
if len(results) >= max_results:
|
||||||
|
break
|
||||||
|
|
||||||
|
url = match.group(1)
|
||||||
|
title = re.sub(r"<[^>]+>", "", match.group(2)).strip()
|
||||||
|
|
||||||
|
# 跳过百度内部链接
|
||||||
|
if "baidu.com/link?" not in url and not url.startswith("http"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 尝试提取摘要
|
||||||
|
snippet = ""
|
||||||
|
snippet_match = snippet_pattern.search(html[match.end():match.end() + 2000])
|
||||||
|
if snippet_match:
|
||||||
|
snippet = re.sub(r"<[^>]+>", "", snippet_match.group(1)).strip()
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
"title": title,
|
||||||
|
"url": url,
|
||||||
|
"snippet": snippet[:200] if snippet else "",
|
||||||
|
})
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
@ -0,0 +1,558 @@
|
||||||
|
"""GEO Skill 工具绑定与端到端验证 — U4 集成测试
|
||||||
|
|
||||||
|
验证:
|
||||||
|
- SkillConfig with tools 字段加载正确
|
||||||
|
- ConfigDrivenAgent 从 ToolRegistry 注册声明的工具
|
||||||
|
- citation_detector 绑定 search + crawl 工具
|
||||||
|
- competitor_analyzer 绑定 search + crawl 工具
|
||||||
|
- geo_optimizer 绑定 schema_generate 工具
|
||||||
|
- schema_advisor 绑定 extract + generate 工具
|
||||||
|
- Tool 不在 ToolRegistry 中时优雅降级(log warning, skip)
|
||||||
|
- GEO Pipeline 配置加载正确
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
|
||||||
|
from agentkit.skills.base import Skill, SkillConfig
|
||||||
|
from agentkit.skills.loader import SkillLoader
|
||||||
|
from agentkit.skills.registry import SkillRegistry
|
||||||
|
from agentkit.tools.baidu_search import BaiduSearchTool
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
from agentkit.tools.function_tool import FunctionTool
|
||||||
|
from agentkit.tools.registry import ToolRegistry
|
||||||
|
from agentkit.tools.schema_tools import SchemaExtractTool, SchemaGenerateTool
|
||||||
|
from agentkit.tools.web_crawl import WebCrawlTool
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
CONFIGS_DIR = os.path.join(
|
||||||
|
os.path.dirname(__file__), "..", "..", "configs"
|
||||||
|
)
|
||||||
|
SKILLS_DIR = os.path.join(CONFIGS_DIR, "skills")
|
||||||
|
PIPELINES_DIR = os.path.join(CONFIGS_DIR, "pipelines")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tool_registry_with_infra_tools():
|
||||||
|
"""创建包含基础设施工具的 ToolRegistry"""
|
||||||
|
registry = ToolRegistry()
|
||||||
|
registry.register(BaiduSearchTool())
|
||||||
|
registry.register(WebCrawlTool())
|
||||||
|
registry.register(SchemaExtractTool())
|
||||||
|
registry.register(SchemaGenerateTool())
|
||||||
|
return registry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tool_registry_empty():
|
||||||
|
"""创建空的 ToolRegistry(用于测试工具不可用时的降级)"""
|
||||||
|
return ToolRegistry()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test: SkillConfig tools 字段加载 ────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestSkillConfigToolsField:
|
||||||
|
"""验证 SkillConfig 的 tools 字段正确加载"""
|
||||||
|
|
||||||
|
def test_citation_detector_tools_loaded(self):
|
||||||
|
"""citation_detector YAML 加载后 tools 包含 baidu_search + web_crawl"""
|
||||||
|
config = SkillConfig.from_yaml(
|
||||||
|
os.path.join(SKILLS_DIR, "citation_detector.yaml")
|
||||||
|
)
|
||||||
|
assert "baidu_search" in config.tools
|
||||||
|
assert "web_crawl" in config.tools
|
||||||
|
# 原有业务工具也保留
|
||||||
|
assert "execute_single_platform" in config.tools
|
||||||
|
assert "get_or_create_task" in config.tools
|
||||||
|
|
||||||
|
def test_competitor_analyzer_tools_loaded(self):
|
||||||
|
"""competitor_analyzer YAML 加载后 tools 包含 baidu_search + web_crawl"""
|
||||||
|
config = SkillConfig.from_yaml(
|
||||||
|
os.path.join(SKILLS_DIR, "competitor_analyzer.yaml")
|
||||||
|
)
|
||||||
|
assert "baidu_search" in config.tools
|
||||||
|
assert "web_crawl" in config.tools
|
||||||
|
assert "competitor_analyze" in config.tools
|
||||||
|
|
||||||
|
def test_geo_optimizer_tools_loaded(self):
|
||||||
|
"""geo_optimizer YAML 加载后 tools 包含 schema_generate"""
|
||||||
|
config = SkillConfig.from_yaml(
|
||||||
|
os.path.join(SKILLS_DIR, "geo_optimizer.yaml")
|
||||||
|
)
|
||||||
|
assert "schema_generate" in config.tools
|
||||||
|
|
||||||
|
def test_schema_advisor_tools_loaded(self):
|
||||||
|
"""schema_advisor YAML 加载后 tools 包含 schema_extract + schema_generate"""
|
||||||
|
config = SkillConfig.from_yaml(
|
||||||
|
os.path.join(SKILLS_DIR, "schema_advisor.yaml")
|
||||||
|
)
|
||||||
|
assert "schema_extract" in config.tools
|
||||||
|
assert "schema_generate" in config.tools
|
||||||
|
assert "fill_schema_with_llm" in config.tools
|
||||||
|
|
||||||
|
def test_monitor_tools_loaded(self):
|
||||||
|
"""monitor YAML 加载后 tools 包含 baidu_search"""
|
||||||
|
config = SkillConfig.from_yaml(
|
||||||
|
os.path.join(SKILLS_DIR, "monitor.yaml")
|
||||||
|
)
|
||||||
|
assert "baidu_search" in config.tools
|
||||||
|
assert "monitor_check_and_compare" in config.tools
|
||||||
|
|
||||||
|
def test_trend_agent_tools_loaded(self):
|
||||||
|
"""trend_agent YAML 加载后 tools 包含 baidu_search + web_crawl"""
|
||||||
|
config = SkillConfig.from_yaml(
|
||||||
|
os.path.join(SKILLS_DIR, "trend_agent.yaml")
|
||||||
|
)
|
||||||
|
assert "baidu_search" in config.tools
|
||||||
|
assert "web_crawl" in config.tools
|
||||||
|
|
||||||
|
def test_content_generator_tools_loaded(self):
|
||||||
|
"""content_generator YAML 加载后 tools 包含 baidu_search"""
|
||||||
|
config = SkillConfig.from_yaml(
|
||||||
|
os.path.join(SKILLS_DIR, "content_generator.yaml")
|
||||||
|
)
|
||||||
|
assert "baidu_search" in config.tools
|
||||||
|
assert "retrieve_knowledge" in config.tools
|
||||||
|
|
||||||
|
def test_deai_agent_tools_loaded(self):
|
||||||
|
"""deai_agent YAML 加载后 tools 包含 detect_ai_patterns"""
|
||||||
|
config = SkillConfig.from_yaml(
|
||||||
|
os.path.join(SKILLS_DIR, "deai_agent.yaml")
|
||||||
|
)
|
||||||
|
assert "detect_ai_patterns" in config.tools
|
||||||
|
|
||||||
|
def test_all_skills_load_without_error(self):
|
||||||
|
"""所有 GEO Skill YAML 都能成功加载"""
|
||||||
|
yaml_files = [
|
||||||
|
"citation_detector.yaml",
|
||||||
|
"competitor_analyzer.yaml",
|
||||||
|
"geo_optimizer.yaml",
|
||||||
|
"monitor.yaml",
|
||||||
|
"schema_advisor.yaml",
|
||||||
|
"trend_agent.yaml",
|
||||||
|
"content_generator.yaml",
|
||||||
|
"deai_agent.yaml",
|
||||||
|
]
|
||||||
|
for filename in yaml_files:
|
||||||
|
config = SkillConfig.from_yaml(
|
||||||
|
os.path.join(SKILLS_DIR, filename)
|
||||||
|
)
|
||||||
|
assert config.name, f"{filename} should have a name"
|
||||||
|
assert config.tools is not None, f"{filename} should have tools field"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test: ConfigDrivenAgent 工具绑定 ────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfigDrivenAgentToolBinding:
|
||||||
|
"""验证 ConfigDrivenAgent 从 ToolRegistry 注册声明的工具"""
|
||||||
|
|
||||||
|
def test_citation_detector_binds_search_and_crawl(
|
||||||
|
self, tool_registry_with_infra_tools
|
||||||
|
):
|
||||||
|
"""citation_detector 绑定 baidu_search + web_crawl 工具"""
|
||||||
|
config = SkillConfig.from_yaml(
|
||||||
|
os.path.join(SKILLS_DIR, "citation_detector.yaml")
|
||||||
|
)
|
||||||
|
agent = ConfigDrivenAgent(
|
||||||
|
config=config,
|
||||||
|
tool_registry=tool_registry_with_infra_tools,
|
||||||
|
)
|
||||||
|
tool_names = [t.name for t in agent.get_tools()]
|
||||||
|
assert "baidu_search" in tool_names
|
||||||
|
assert "web_crawl" in tool_names
|
||||||
|
|
||||||
|
def test_competitor_analyzer_binds_search_and_crawl(
|
||||||
|
self, tool_registry_with_infra_tools
|
||||||
|
):
|
||||||
|
"""competitor_analyzer 绑定 baidu_search + web_crawl 工具"""
|
||||||
|
config = SkillConfig.from_yaml(
|
||||||
|
os.path.join(SKILLS_DIR, "competitor_analyzer.yaml")
|
||||||
|
)
|
||||||
|
agent = ConfigDrivenAgent(
|
||||||
|
config=config,
|
||||||
|
tool_registry=tool_registry_with_infra_tools,
|
||||||
|
)
|
||||||
|
tool_names = [t.name for t in agent.get_tools()]
|
||||||
|
assert "baidu_search" in tool_names
|
||||||
|
assert "web_crawl" in tool_names
|
||||||
|
|
||||||
|
def test_geo_optimizer_binds_schema_generate(
|
||||||
|
self, tool_registry_with_infra_tools
|
||||||
|
):
|
||||||
|
"""geo_optimizer 绑定 schema_generate 工具"""
|
||||||
|
config = SkillConfig.from_yaml(
|
||||||
|
os.path.join(SKILLS_DIR, "geo_optimizer.yaml")
|
||||||
|
)
|
||||||
|
agent = ConfigDrivenAgent(
|
||||||
|
config=config,
|
||||||
|
tool_registry=tool_registry_with_infra_tools,
|
||||||
|
)
|
||||||
|
tool_names = [t.name for t in agent.get_tools()]
|
||||||
|
assert "schema_generate" in tool_names
|
||||||
|
|
||||||
|
def test_schema_advisor_binds_extract_and_generate(
|
||||||
|
self, tool_registry_with_infra_tools
|
||||||
|
):
|
||||||
|
"""schema_advisor 绑定 schema_extract + schema_generate 工具"""
|
||||||
|
config = SkillConfig.from_yaml(
|
||||||
|
os.path.join(SKILLS_DIR, "schema_advisor.yaml")
|
||||||
|
)
|
||||||
|
agent = ConfigDrivenAgent(
|
||||||
|
config=config,
|
||||||
|
tool_registry=tool_registry_with_infra_tools,
|
||||||
|
)
|
||||||
|
tool_names = [t.name for t in agent.get_tools()]
|
||||||
|
assert "schema_extract" in tool_names
|
||||||
|
assert "schema_generate" in tool_names
|
||||||
|
|
||||||
|
def test_monitor_binds_search(
|
||||||
|
self, tool_registry_with_infra_tools
|
||||||
|
):
|
||||||
|
"""monitor 绑定 baidu_search 工具"""
|
||||||
|
config = SkillConfig.from_yaml(
|
||||||
|
os.path.join(SKILLS_DIR, "monitor.yaml")
|
||||||
|
)
|
||||||
|
agent = ConfigDrivenAgent(
|
||||||
|
config=config,
|
||||||
|
tool_registry=tool_registry_with_infra_tools,
|
||||||
|
)
|
||||||
|
tool_names = [t.name for t in agent.get_tools()]
|
||||||
|
assert "baidu_search" in tool_names
|
||||||
|
|
||||||
|
def test_trend_agent_binds_search_and_crawl(
|
||||||
|
self, tool_registry_with_infra_tools
|
||||||
|
):
|
||||||
|
"""trend_agent 绑定 baidu_search + web_crawl 工具"""
|
||||||
|
config = SkillConfig.from_yaml(
|
||||||
|
os.path.join(SKILLS_DIR, "trend_agent.yaml")
|
||||||
|
)
|
||||||
|
agent = ConfigDrivenAgent(
|
||||||
|
config=config,
|
||||||
|
tool_registry=tool_registry_with_infra_tools,
|
||||||
|
)
|
||||||
|
tool_names = [t.name for t in agent.get_tools()]
|
||||||
|
assert "baidu_search" in tool_names
|
||||||
|
assert "web_crawl" in tool_names
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test: 工具不可用时优雅降级 ──────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolNotFoundGracefulDegradation:
|
||||||
|
"""验证 Tool 不在 ToolRegistry 中时优雅降级"""
|
||||||
|
|
||||||
|
def test_missing_tool_does_not_crash(self, tool_registry_empty):
|
||||||
|
"""Tool 不在 ToolRegistry 中时 Agent 不会崩溃"""
|
||||||
|
config = SkillConfig.from_yaml(
|
||||||
|
os.path.join(SKILLS_DIR, "citation_detector.yaml")
|
||||||
|
)
|
||||||
|
# citation_detector 声明了 baidu_search, web_crawl, execute_single_platform, get_or_create_task
|
||||||
|
# 这些工具都不在空 registry 中
|
||||||
|
agent = ConfigDrivenAgent(
|
||||||
|
config=config,
|
||||||
|
tool_registry=tool_registry_empty,
|
||||||
|
)
|
||||||
|
# Agent 应该成功创建,只是没有绑定任何工具
|
||||||
|
assert agent is not None
|
||||||
|
assert len(agent.get_tools()) == 0
|
||||||
|
|
||||||
|
def test_partial_tool_binding(self):
|
||||||
|
"""部分工具在 Registry 中时,只绑定可用的工具"""
|
||||||
|
registry = ToolRegistry()
|
||||||
|
registry.register(BaiduSearchTool())
|
||||||
|
# 只注册了 baidu_search,没有 web_crawl
|
||||||
|
|
||||||
|
config = SkillConfig.from_yaml(
|
||||||
|
os.path.join(SKILLS_DIR, "citation_detector.yaml")
|
||||||
|
)
|
||||||
|
agent = ConfigDrivenAgent(
|
||||||
|
config=config,
|
||||||
|
tool_registry=registry,
|
||||||
|
)
|
||||||
|
tool_names = [t.name for t in agent.get_tools()]
|
||||||
|
# baidu_search 应该绑定成功
|
||||||
|
assert "baidu_search" in tool_names
|
||||||
|
# web_crawl 不在 registry 中,不应该绑定
|
||||||
|
assert "web_crawl" not in tool_names
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test: SkillLoader 批量加载 ──────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestSkillLoaderBatchLoad:
|
||||||
|
"""验证 SkillLoader 从目录批量加载 Skill 并绑定工具"""
|
||||||
|
|
||||||
|
def test_load_all_geo_skills(self, tool_registry_with_infra_tools):
|
||||||
|
"""从 skills 目录加载所有 GEO Skill"""
|
||||||
|
skill_registry = SkillRegistry()
|
||||||
|
loader = SkillLoader(
|
||||||
|
skill_registry=skill_registry,
|
||||||
|
tool_registry=tool_registry_with_infra_tools,
|
||||||
|
)
|
||||||
|
skills = loader.load_from_directory(SKILLS_DIR)
|
||||||
|
|
||||||
|
# 验证所有 Skill 都加载成功
|
||||||
|
skill_names = [s.name for s in skills]
|
||||||
|
assert "citation_detector" in skill_names
|
||||||
|
assert "competitor_analyzer" in skill_names
|
||||||
|
assert "geo_optimizer" in skill_names
|
||||||
|
assert "monitor" in skill_names
|
||||||
|
assert "schema_advisor" in skill_names
|
||||||
|
assert "trend_agent" in skill_names
|
||||||
|
assert "content_generator" in skill_names
|
||||||
|
assert "deai_agent" in skill_names
|
||||||
|
|
||||||
|
def test_citation_detector_skill_has_tools(
|
||||||
|
self, tool_registry_with_infra_tools
|
||||||
|
):
|
||||||
|
"""citation_detector Skill 绑定了 search + crawl 工具"""
|
||||||
|
skill_registry = SkillRegistry()
|
||||||
|
loader = SkillLoader(
|
||||||
|
skill_registry=skill_registry,
|
||||||
|
tool_registry=tool_registry_with_infra_tools,
|
||||||
|
)
|
||||||
|
loader.load_from_directory(SKILLS_DIR)
|
||||||
|
|
||||||
|
skill = skill_registry.get("citation_detector")
|
||||||
|
tool_names = [t.name for t in skill.tools]
|
||||||
|
assert "baidu_search" in tool_names
|
||||||
|
assert "web_crawl" in tool_names
|
||||||
|
|
||||||
|
def test_schema_advisor_skill_has_tools(
|
||||||
|
self, tool_registry_with_infra_tools
|
||||||
|
):
|
||||||
|
"""schema_advisor Skill 绑定了 extract + generate 工具"""
|
||||||
|
skill_registry = SkillRegistry()
|
||||||
|
loader = SkillLoader(
|
||||||
|
skill_registry=skill_registry,
|
||||||
|
tool_registry=tool_registry_with_infra_tools,
|
||||||
|
)
|
||||||
|
loader.load_from_directory(SKILLS_DIR)
|
||||||
|
|
||||||
|
skill = skill_registry.get("schema_advisor")
|
||||||
|
tool_names = [t.name for t in skill.tools]
|
||||||
|
assert "schema_extract" in tool_names
|
||||||
|
assert "schema_generate" in tool_names
|
||||||
|
|
||||||
|
def test_geo_optimizer_skill_has_schema_generate(
|
||||||
|
self, tool_registry_with_infra_tools
|
||||||
|
):
|
||||||
|
"""geo_optimizer Skill 绑定了 schema_generate 工具"""
|
||||||
|
skill_registry = SkillRegistry()
|
||||||
|
loader = SkillLoader(
|
||||||
|
skill_registry=skill_registry,
|
||||||
|
tool_registry=tool_registry_with_infra_tools,
|
||||||
|
)
|
||||||
|
loader.load_from_directory(SKILLS_DIR)
|
||||||
|
|
||||||
|
skill = skill_registry.get("geo_optimizer")
|
||||||
|
tool_names = [t.name for t in skill.tools]
|
||||||
|
assert "schema_generate" in tool_names
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test: GEO Pipeline 配置加载 ──────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestGEOPipelineConfig:
|
||||||
|
"""验证 GEO Pipeline 配置加载正确"""
|
||||||
|
|
||||||
|
def test_pipeline_config_loads(self):
|
||||||
|
"""geo_full_pipeline.yaml 能成功加载"""
|
||||||
|
with open(os.path.join(PIPELINES_DIR, "geo_full_pipeline.yaml"), "r") as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
assert config["name"] == "geo_full_pipeline"
|
||||||
|
assert len(config["steps"]) > 0
|
||||||
|
|
||||||
|
def test_pipeline_has_all_steps(self):
|
||||||
|
"""Pipeline 包含所有 GEO 步骤"""
|
||||||
|
with open(os.path.join(PIPELINES_DIR, "geo_full_pipeline.yaml"), "r") as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
|
||||||
|
step_names = [s["name"] for s in config["steps"]]
|
||||||
|
# 核心步骤
|
||||||
|
assert "detect" in step_names
|
||||||
|
assert "analyze_competitor" in step_names
|
||||||
|
assert "analyze_trend" in step_names
|
||||||
|
assert "optimize" in step_names
|
||||||
|
assert "schema" in step_names
|
||||||
|
assert "monitor" in step_names
|
||||||
|
# 新增步骤
|
||||||
|
assert "generate_content" in step_names
|
||||||
|
assert "deai" in step_names
|
||||||
|
|
||||||
|
def test_pipeline_step_skills_match_yaml_names(self):
|
||||||
|
"""Pipeline 步骤的 skill 字段与 YAML 文件中的 name 一致"""
|
||||||
|
with open(os.path.join(PIPELINES_DIR, "geo_full_pipeline.yaml"), "r") as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
|
||||||
|
for step in config["steps"]:
|
||||||
|
skill_name = step["skill"]
|
||||||
|
yaml_path = os.path.join(SKILLS_DIR, f"{skill_name}.yaml")
|
||||||
|
assert os.path.exists(yaml_path), (
|
||||||
|
f"Pipeline step '{step['name']}' references skill "
|
||||||
|
f"'{skill_name}' but {yaml_path} does not exist"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_pipeline_dependency_graph_is_valid(self):
|
||||||
|
"""Pipeline 依赖关系有效(无循环依赖)"""
|
||||||
|
with open(os.path.join(PIPELINES_DIR, "geo_full_pipeline.yaml"), "r") as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
|
||||||
|
step_map = {s["name"]: s.get("depends_on", []) for s in config["steps"]}
|
||||||
|
|
||||||
|
# 拓扑排序检测循环依赖
|
||||||
|
visited = set()
|
||||||
|
in_stack = set()
|
||||||
|
|
||||||
|
def dfs(name):
|
||||||
|
if name in in_stack:
|
||||||
|
return False # 循环依赖
|
||||||
|
if name in visited:
|
||||||
|
return True
|
||||||
|
in_stack.add(name)
|
||||||
|
for dep in step_map.get(name, []):
|
||||||
|
if not dfs(dep):
|
||||||
|
return False
|
||||||
|
in_stack.discard(name)
|
||||||
|
visited.add(name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
for name in step_map:
|
||||||
|
assert dfs(name), f"Circular dependency detected involving '{name}'"
|
||||||
|
|
||||||
|
def test_pipeline_from_config_creates_pipeline(self):
|
||||||
|
"""GEOPipeline.from_config 能从 YAML 配置创建 Pipeline"""
|
||||||
|
from agentkit.skills.geo_pipeline import GEOPipeline
|
||||||
|
|
||||||
|
with open(os.path.join(PIPELINES_DIR, "geo_full_pipeline.yaml"), "r") as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
|
||||||
|
pipeline = GEOPipeline.from_config(config)
|
||||||
|
assert pipeline.name == "geo_full_pipeline"
|
||||||
|
assert len(pipeline._steps) == len(config["steps"])
|
||||||
|
|
||||||
|
def test_pipeline_execution_order_respects_dependencies(self):
|
||||||
|
"""Pipeline 执行顺序尊重依赖关系"""
|
||||||
|
from agentkit.skills.geo_pipeline import GEOPipeline
|
||||||
|
|
||||||
|
with open(os.path.join(PIPELINES_DIR, "geo_full_pipeline.yaml"), "r") as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
|
||||||
|
pipeline = GEOPipeline.from_config(config)
|
||||||
|
groups = pipeline._build_execution_groups()
|
||||||
|
|
||||||
|
# 展平执行顺序
|
||||||
|
executed = set()
|
||||||
|
for group in groups:
|
||||||
|
for step_name in group:
|
||||||
|
step = pipeline._step_map[step_name]
|
||||||
|
# 所有依赖必须已经执行
|
||||||
|
for dep in step.depends_on:
|
||||||
|
assert dep in executed, (
|
||||||
|
f"Step '{step_name}' depends on '{dep}' "
|
||||||
|
f"but '{dep}' hasn't been executed yet"
|
||||||
|
)
|
||||||
|
executed.add(step_name)
|
||||||
|
|
||||||
|
# 所有步骤都应该被执行
|
||||||
|
assert executed == set(pipeline._step_map.keys())
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test: 基础设施工具实例化 ──────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestInfrastructureToolsInstantiation:
|
||||||
|
"""验证基础设施工具能正确实例化"""
|
||||||
|
|
||||||
|
def test_baidu_search_tool_instantiation(self):
|
||||||
|
"""BaiduSearchTool 能正确实例化"""
|
||||||
|
tool = BaiduSearchTool()
|
||||||
|
assert tool.name == "baidu_search"
|
||||||
|
assert "search" in tool.tags
|
||||||
|
|
||||||
|
def test_web_crawl_tool_instantiation(self):
|
||||||
|
"""WebCrawlTool 能正确实例化"""
|
||||||
|
tool = WebCrawlTool()
|
||||||
|
assert tool.name == "web_crawl"
|
||||||
|
assert "crawl" in tool.tags
|
||||||
|
|
||||||
|
def test_schema_extract_tool_instantiation(self):
|
||||||
|
"""SchemaExtractTool 能正确实例化"""
|
||||||
|
tool = SchemaExtractTool()
|
||||||
|
assert tool.name == "schema_extract"
|
||||||
|
assert "extraction" in tool.tags
|
||||||
|
|
||||||
|
def test_schema_generate_tool_instantiation(self):
|
||||||
|
"""SchemaGenerateTool 能正确实例化"""
|
||||||
|
tool = SchemaGenerateTool()
|
||||||
|
assert tool.name == "schema_generate"
|
||||||
|
assert "generation" in tool.tags
|
||||||
|
|
||||||
|
def test_all_infra_tools_registered_in_registry(self):
|
||||||
|
"""所有基础设施工具都能注册到 ToolRegistry"""
|
||||||
|
registry = ToolRegistry()
|
||||||
|
registry.register(BaiduSearchTool())
|
||||||
|
registry.register(WebCrawlTool())
|
||||||
|
registry.register(SchemaExtractTool())
|
||||||
|
registry.register(SchemaGenerateTool())
|
||||||
|
|
||||||
|
assert registry.has_tool("baidu_search")
|
||||||
|
assert registry.has_tool("web_crawl")
|
||||||
|
assert registry.has_tool("schema_extract")
|
||||||
|
assert registry.has_tool("schema_generate")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test: AgentConfig tools 字段向后兼容 ──────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentConfigToolsBackwardCompat:
|
||||||
|
"""验证 AgentConfig 的 tools 字段向后兼容"""
|
||||||
|
|
||||||
|
def test_agent_config_with_tools_list(self):
|
||||||
|
"""AgentConfig 接受 tools 列表"""
|
||||||
|
config = AgentConfig(
|
||||||
|
name="test",
|
||||||
|
agent_type="test",
|
||||||
|
task_mode="tool_call",
|
||||||
|
tools=["baidu_search", "web_crawl"],
|
||||||
|
)
|
||||||
|
assert config.tools == ["baidu_search", "web_crawl"]
|
||||||
|
|
||||||
|
def test_agent_config_without_tools(self):
|
||||||
|
"""AgentConfig 不提供 tools 时默认为空列表"""
|
||||||
|
config = AgentConfig(
|
||||||
|
name="test",
|
||||||
|
agent_type="test",
|
||||||
|
task_mode="llm_generate",
|
||||||
|
prompt={"identity": "test", "instructions": "test"},
|
||||||
|
)
|
||||||
|
assert config.tools == []
|
||||||
|
|
||||||
|
def test_skill_config_inherits_tools(self):
|
||||||
|
"""SkillConfig 继承 AgentConfig 的 tools 字段"""
|
||||||
|
config = SkillConfig(
|
||||||
|
name="test",
|
||||||
|
agent_type="test",
|
||||||
|
task_mode="tool_call",
|
||||||
|
tools=["baidu_search"],
|
||||||
|
)
|
||||||
|
assert config.tools == ["baidu_search"]
|
||||||
|
|
||||||
|
def test_skill_config_from_dict_with_tools(self):
|
||||||
|
"""SkillConfig.from_dict 正确解析 tools 字段"""
|
||||||
|
data = {
|
||||||
|
"name": "test",
|
||||||
|
"agent_type": "test",
|
||||||
|
"task_mode": "tool_call",
|
||||||
|
"tools": ["baidu_search", "web_crawl", "schema_generate"],
|
||||||
|
}
|
||||||
|
config = SkillConfig.from_dict(data)
|
||||||
|
assert config.tools == ["baidu_search", "web_crawl", "schema_generate"]
|
||||||
Loading…
Reference in New Issue