559 lines
21 KiB
Python
559 lines
21 KiB
Python
"""GEO Skill 工具绑定与端到端验证 — U4 集成测试

验证:
- SkillConfig 的 tools 字段加载正确
- ConfigDrivenAgent 从 ToolRegistry 注册声明的工具
- citation_detector 绑定 search + crawl 工具
- competitor_analyzer 绑定 search + crawl 工具
- geo_optimizer 绑定 schema_generate 工具
- schema_advisor 绑定 extract + generate 工具
- monitor / trend_agent 绑定 search(及 crawl)工具
- Tool 不在 ToolRegistry 中时优雅降级(log warning, skip)
- SkillLoader 从目录批量加载 Skill 并绑定工具
- GEO Pipeline 配置加载正确
- 基础设施工具实例化与 AgentConfig tools 字段向后兼容
"""
|
||
|
||
import os
|
||
from unittest.mock import AsyncMock
|
||
|
||
import pytest
|
||
import yaml
|
||
|
||
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
|
||
from agentkit.skills.base import Skill, SkillConfig
|
||
from agentkit.skills.loader import SkillLoader
|
||
from agentkit.skills.registry import SkillRegistry
|
||
from agentkit.tools.baidu_search import BaiduSearchTool
|
||
from agentkit.tools.base import Tool
|
||
from agentkit.tools.function_tool import FunctionTool
|
||
from agentkit.tools.registry import ToolRegistry
|
||
from agentkit.tools.schema_tools import SchemaExtractTool, SchemaGenerateTool
|
||
from agentkit.tools.web_crawl import WebCrawlTool
|
||
|
||
|
||
# ── Fixtures ────────────────────────────────────────────────
|
||
|
||
# Configuration roots used by every test below. They are resolved relative to
# this file: two directories up, then "configs".
CONFIGS_DIR = os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, "configs")
SKILLS_DIR = os.path.join(CONFIGS_DIR, "skills")
PIPELINES_DIR = os.path.join(CONFIGS_DIR, "pipelines")
|
||
|
||
|
||
@pytest.fixture
def tool_registry_with_infra_tools():
    """Return a ToolRegistry pre-populated with the infrastructure tools.

    One fresh instance of each is registered, in this order: BaiduSearchTool,
    WebCrawlTool, SchemaExtractTool, SchemaGenerateTool.
    """
    infra_registry = ToolRegistry()
    for tool_class in (
        BaiduSearchTool,
        WebCrawlTool,
        SchemaExtractTool,
        SchemaGenerateTool,
    ):
        infra_registry.register(tool_class())
    return infra_registry
|
||
|
||
|
||
@pytest.fixture
def tool_registry_empty():
    """Return a ToolRegistry with nothing registered.

    Used to exercise graceful degradation when declared tools are unavailable.
    """
    registry_without_tools = ToolRegistry()
    return registry_without_tools
|
||
|
||
|
||
# ── Test: SkillConfig tools 字段加载 ────────────────────────
|
||
|
||
|
||
class TestSkillConfigToolsField:
    """Check that SkillConfig loads the ``tools`` field from each GEO skill YAML."""

    # Every GEO skill YAML this suite expects to find under SKILLS_DIR.
    _GEO_SKILL_FILES = (
        "citation_detector.yaml",
        "competitor_analyzer.yaml",
        "geo_optimizer.yaml",
        "monitor.yaml",
        "schema_advisor.yaml",
        "trend_agent.yaml",
        "content_generator.yaml",
        "deai_agent.yaml",
    )

    @staticmethod
    def _load(filename):
        """Return the SkillConfig parsed from ``SKILLS_DIR/filename``."""
        return SkillConfig.from_yaml(os.path.join(SKILLS_DIR, filename))

    def test_citation_detector_tools_loaded(self):
        """citation_detector declares baidu_search + web_crawl next to its business tools."""
        config = self._load("citation_detector.yaml")
        assert "baidu_search" in config.tools
        assert "web_crawl" in config.tools
        # The original business tools are still declared.
        assert "execute_single_platform" in config.tools
        assert "get_or_create_task" in config.tools

    def test_competitor_analyzer_tools_loaded(self):
        """competitor_analyzer declares baidu_search + web_crawl + competitor_analyze."""
        config = self._load("competitor_analyzer.yaml")
        assert "baidu_search" in config.tools
        assert "web_crawl" in config.tools
        assert "competitor_analyze" in config.tools

    def test_geo_optimizer_tools_loaded(self):
        """geo_optimizer declares schema_generate."""
        config = self._load("geo_optimizer.yaml")
        assert "schema_generate" in config.tools

    def test_schema_advisor_tools_loaded(self):
        """schema_advisor declares schema_extract + schema_generate + fill_schema_with_llm."""
        config = self._load("schema_advisor.yaml")
        assert "schema_extract" in config.tools
        assert "schema_generate" in config.tools
        assert "fill_schema_with_llm" in config.tools

    def test_monitor_tools_loaded(self):
        """monitor declares baidu_search + monitor_check_and_compare."""
        config = self._load("monitor.yaml")
        assert "baidu_search" in config.tools
        assert "monitor_check_and_compare" in config.tools

    def test_trend_agent_tools_loaded(self):
        """trend_agent declares baidu_search + web_crawl."""
        config = self._load("trend_agent.yaml")
        assert "baidu_search" in config.tools
        assert "web_crawl" in config.tools

    def test_content_generator_tools_loaded(self):
        """content_generator declares baidu_search + retrieve_knowledge."""
        config = self._load("content_generator.yaml")
        assert "baidu_search" in config.tools
        assert "retrieve_knowledge" in config.tools

    def test_deai_agent_tools_loaded(self):
        """deai_agent declares detect_ai_patterns."""
        config = self._load("deai_agent.yaml")
        assert "detect_ai_patterns" in config.tools

    def test_all_skills_load_without_error(self):
        """Every GEO skill YAML loads, has a name and declares at least one tool."""
        for filename in self._GEO_SKILL_FILES:
            config = self._load(filename)
            assert config.name, f"{filename} should have a name"
            # ``tools`` defaults to an empty list when omitted, so an
            # ``is not None`` check could never fail. Each GEO skill is expected
            # to declare at least one tool (see the per-skill tests above).
            assert config.tools, f"{filename} should declare at least one tool"
|
||
|
||
|
||
# ── Test: ConfigDrivenAgent 工具绑定 ────────────────────────
|
||
|
||
|
||
class TestConfigDrivenAgentToolBinding:
    """ConfigDrivenAgent picks the tools its skill YAML declares out of a ToolRegistry."""

    @staticmethod
    def _bound_tool_names(yaml_name, registry):
        """Build a ConfigDrivenAgent for ``SKILLS_DIR/yaml_name`` and list its tool names."""
        skill_config = SkillConfig.from_yaml(os.path.join(SKILLS_DIR, yaml_name))
        agent = ConfigDrivenAgent(config=skill_config, tool_registry=registry)
        return [tool.name for tool in agent.get_tools()]

    def test_citation_detector_binds_search_and_crawl(
        self, tool_registry_with_infra_tools
    ):
        """citation_detector gets baidu_search and web_crawl bound."""
        bound = self._bound_tool_names(
            "citation_detector.yaml", tool_registry_with_infra_tools
        )
        for expected in ("baidu_search", "web_crawl"):
            assert expected in bound

    def test_competitor_analyzer_binds_search_and_crawl(
        self, tool_registry_with_infra_tools
    ):
        """competitor_analyzer gets baidu_search and web_crawl bound."""
        bound = self._bound_tool_names(
            "competitor_analyzer.yaml", tool_registry_with_infra_tools
        )
        for expected in ("baidu_search", "web_crawl"):
            assert expected in bound

    def test_geo_optimizer_binds_schema_generate(
        self, tool_registry_with_infra_tools
    ):
        """geo_optimizer gets schema_generate bound."""
        bound = self._bound_tool_names(
            "geo_optimizer.yaml", tool_registry_with_infra_tools
        )
        assert "schema_generate" in bound

    def test_schema_advisor_binds_extract_and_generate(
        self, tool_registry_with_infra_tools
    ):
        """schema_advisor gets schema_extract and schema_generate bound."""
        bound = self._bound_tool_names(
            "schema_advisor.yaml", tool_registry_with_infra_tools
        )
        for expected in ("schema_extract", "schema_generate"):
            assert expected in bound

    def test_monitor_binds_search(
        self, tool_registry_with_infra_tools
    ):
        """monitor gets baidu_search bound."""
        bound = self._bound_tool_names(
            "monitor.yaml", tool_registry_with_infra_tools
        )
        assert "baidu_search" in bound

    def test_trend_agent_binds_search_and_crawl(
        self, tool_registry_with_infra_tools
    ):
        """trend_agent gets baidu_search and web_crawl bound."""
        bound = self._bound_tool_names(
            "trend_agent.yaml", tool_registry_with_infra_tools
        )
        for expected in ("baidu_search", "web_crawl"):
            assert expected in bound
|
||
|
||
|
||
# ── Test: 工具不可用时优雅降级 ──────────────────────────────
|
||
|
||
|
||
class TestToolNotFoundGracefulDegradation:
    """Declared tools that are missing from the ToolRegistry are skipped, not fatal."""

    def test_missing_tool_does_not_crash(self, tool_registry_empty):
        """With an empty registry the agent is still created, just with no tools."""
        config = SkillConfig.from_yaml(
            os.path.join(SKILLS_DIR, "citation_detector.yaml")
        )
        # citation_detector declares baidu_search, web_crawl,
        # execute_single_platform and get_or_create_task; none of them are in
        # the empty registry.
        agent = ConfigDrivenAgent(
            config=config,
            tool_registry=tool_registry_empty,
        )

        # Construction must succeed; there is simply nothing to bind.
        assert agent is not None
        assert len(agent.get_tools()) == 0

    def test_partial_tool_binding(self):
        """When only some declared tools are registered, exactly those are bound."""
        registry = ToolRegistry()
        # Only baidu_search is registered; web_crawl and the business tools are not.
        registry.register(BaiduSearchTool())

        config = SkillConfig.from_yaml(
            os.path.join(SKILLS_DIR, "citation_detector.yaml")
        )
        agent = ConfigDrivenAgent(
            config=config,
            tool_registry=registry,
        )
        tool_names = [t.name for t in agent.get_tools()]

        # baidu_search is available and must be bound.
        assert "baidu_search" in tool_names
        # web_crawl is not in the registry and must not be bound.
        assert "web_crawl" not in tool_names
        # Nothing beyond the single registered tool may be bound. The
        # empty-registry test shows the agent contributes no tools of its own.
        assert set(tool_names) == {"baidu_search"}
|
||
|
||
|
||
# ── Test: SkillLoader 批量加载 ──────────────────────────────
|
||
|
||
|
||
class TestSkillLoaderBatchLoad:
    """SkillLoader loads every skill in a directory and binds the declared tools."""

    @staticmethod
    def _load_skills(tool_registry):
        """Load all skills in SKILLS_DIR through a fresh SkillLoader.

        Returns:
            tuple: ``(skill_registry, skills)`` - the SkillRegistry the loader
            populated and the value returned by ``load_from_directory``.
        """
        skill_registry = SkillRegistry()
        loader = SkillLoader(
            skill_registry=skill_registry,
            tool_registry=tool_registry,
        )
        skills = loader.load_from_directory(SKILLS_DIR)
        return skill_registry, skills

    @classmethod
    def _bound_tool_names(cls, tool_registry, skill_name):
        """Load SKILLS_DIR and return the names of the tools bound to ``skill_name``."""
        skill_registry, _ = cls._load_skills(tool_registry)
        skill = skill_registry.get(skill_name)
        # If the skill was not registered, fail with a clear message rather
        # than an AttributeError on ``None.tools``.
        assert skill is not None, f"Skill '{skill_name}' was not registered"
        return [t.name for t in skill.tools]

    def test_load_all_geo_skills(self, tool_registry_with_infra_tools):
        """Every GEO skill is loaded from the skills directory."""
        _, skills = self._load_skills(tool_registry_with_infra_tools)
        skill_names = {s.name for s in skills}
        expected = {
            "citation_detector",
            "competitor_analyzer",
            "geo_optimizer",
            "monitor",
            "schema_advisor",
            "trend_agent",
            "content_generator",
            "deai_agent",
        }
        # Report every missing skill at once instead of stopping at the first.
        missing = expected - skill_names
        assert not missing, f"GEO skills not loaded: {sorted(missing)}"

    def test_citation_detector_skill_has_tools(
        self, tool_registry_with_infra_tools
    ):
        """The citation_detector skill has baidu_search and web_crawl bound."""
        tool_names = self._bound_tool_names(
            tool_registry_with_infra_tools, "citation_detector"
        )
        assert "baidu_search" in tool_names
        assert "web_crawl" in tool_names

    def test_schema_advisor_skill_has_tools(
        self, tool_registry_with_infra_tools
    ):
        """The schema_advisor skill has schema_extract and schema_generate bound."""
        tool_names = self._bound_tool_names(
            tool_registry_with_infra_tools, "schema_advisor"
        )
        assert "schema_extract" in tool_names
        assert "schema_generate" in tool_names

    def test_geo_optimizer_skill_has_schema_generate(
        self, tool_registry_with_infra_tools
    ):
        """The geo_optimizer skill has schema_generate bound."""
        tool_names = self._bound_tool_names(
            tool_registry_with_infra_tools, "geo_optimizer"
        )
        assert "schema_generate" in tool_names
|
||
|
||
|
||
# ── Test: GEO Pipeline 配置加载 ──────────────────────────────
|
||
|
||
|
||
class TestGEOPipelineConfig:
    """Check that the GEO pipeline configuration loads and is internally consistent."""

    @staticmethod
    def _load_pipeline_config():
        """Parse ``PIPELINES_DIR/geo_full_pipeline.yaml`` with ``yaml.safe_load``.

        The file is opened as UTF-8 explicitly. Without ``encoding`` the locale
        encoding is used, which can fail on non-ASCII content (e.g. on Windows).
        This assumes the config files are saved as UTF-8.
        """
        path = os.path.join(PIPELINES_DIR, "geo_full_pipeline.yaml")
        with open(path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f)

    def test_pipeline_config_loads(self):
        """geo_full_pipeline.yaml parses, carries the right name and has steps."""
        config = self._load_pipeline_config()
        assert config["name"] == "geo_full_pipeline"
        assert len(config["steps"]) > 0

    def test_pipeline_has_all_steps(self):
        """The pipeline contains every expected GEO step."""
        config = self._load_pipeline_config()
        step_names = {s["name"] for s in config["steps"]}
        expected = {
            # Core steps.
            "detect",
            "analyze_competitor",
            "analyze_trend",
            "optimize",
            "schema",
            "monitor",
            # Steps added later.
            "generate_content",
            "deai",
        }
        missing = expected - step_names
        assert not missing, f"Pipeline is missing steps: {sorted(missing)}"

    def test_pipeline_step_skills_match_yaml_names(self):
        """Each step's ``skill`` names an existing skill YAML whose ``name`` matches."""
        config = self._load_pipeline_config()
        for step in config["steps"]:
            skill_name = step["skill"]
            yaml_path = os.path.join(SKILLS_DIR, f"{skill_name}.yaml")
            assert os.path.exists(yaml_path), (
                f"Pipeline step '{step['name']}' references skill "
                f"'{skill_name}' but {yaml_path} does not exist"
            )
            # The file existing is not enough: the skill it declares must carry
            # the same name the pipeline refers to.
            declared_name = SkillConfig.from_yaml(yaml_path).name
            assert declared_name == skill_name, (
                f"Pipeline step '{step['name']}' references skill "
                f"'{skill_name}' but {yaml_path} declares name '{declared_name}'"
            )

    def test_pipeline_dependency_graph_is_valid(self):
        """Every dependency names a known step and there are no circular dependencies."""
        config = self._load_pipeline_config()
        # ``or []`` also covers an explicit ``depends_on:`` left empty (null).
        step_map = {s["name"]: s.get("depends_on") or [] for s in config["steps"]}

        # A dependency on an unknown step would otherwise be treated as a leaf
        # by the search below and pass silently.
        for name, deps in step_map.items():
            unknown = [dep for dep in deps if dep not in step_map]
            assert not unknown, f"Step '{name}' depends on unknown step(s) {unknown}"

        # Depth-first search: reaching a node that is still on the current
        # path means there is a cycle.
        visited = set()
        in_stack = set()

        def dfs(name):
            if name in in_stack:
                return False  # circular dependency
            if name in visited:
                return True
            in_stack.add(name)
            for dep in step_map.get(name, []):
                if not dfs(dep):
                    return False
            in_stack.discard(name)
            visited.add(name)
            return True

        for name in step_map:
            assert dfs(name), f"Circular dependency detected involving '{name}'"

    def test_pipeline_from_config_creates_pipeline(self):
        """GEOPipeline.from_config builds a pipeline with one entry per configured step."""
        from agentkit.skills.geo_pipeline import GEOPipeline

        config = self._load_pipeline_config()
        pipeline = GEOPipeline.from_config(config)
        assert pipeline.name == "geo_full_pipeline"
        assert len(pipeline._steps) == len(config["steps"])

    def test_pipeline_execution_order_respects_dependencies(self):
        """Execution groups schedule each step once and only after its dependencies."""
        from agentkit.skills.geo_pipeline import GEOPipeline

        config = self._load_pipeline_config()
        pipeline = GEOPipeline.from_config(config)
        groups = pipeline._build_execution_groups()

        # NOTE(review): steps in one group are presumably run concurrently, so a
        # dependency must be satisfied by an EARLIER group, not by a sibling
        # listed before it in the same group. Confirm against GEOPipeline.
        completed = set()
        for group in groups:
            assert len(set(group)) == len(group), (
                f"Duplicate step in execution group {group}"
            )
            for step_name in group:
                assert step_name not in completed, (
                    f"Step '{step_name}' appears in more than one execution group"
                )
                step = pipeline._step_map[step_name]
                for dep in step.depends_on:
                    assert dep in completed, (
                        f"Step '{step_name}' depends on '{dep}' "
                        f"but '{dep}' hasn't been executed yet"
                    )
            completed.update(group)

        # Every step must be scheduled.
        assert completed == set(pipeline._step_map.keys())
|
||
|
||
|
||
# ── Test: 基础设施工具实例化 ──────────────────────────────────
|
||
|
||
|
||
class TestInfrastructureToolsInstantiation:
    """The infrastructure tools construct with their expected names and tags."""

    @staticmethod
    def _assert_identity(tool, expected_name, expected_tag):
        """Assert ``tool`` is called ``expected_name`` and carries ``expected_tag``."""
        assert tool.name == expected_name
        assert expected_tag in tool.tags

    def test_baidu_search_tool_instantiation(self):
        """BaiduSearchTool() is named baidu_search and tagged 'search'."""
        self._assert_identity(BaiduSearchTool(), "baidu_search", "search")

    def test_web_crawl_tool_instantiation(self):
        """WebCrawlTool() is named web_crawl and tagged 'crawl'."""
        self._assert_identity(WebCrawlTool(), "web_crawl", "crawl")

    def test_schema_extract_tool_instantiation(self):
        """SchemaExtractTool() is named schema_extract and tagged 'extraction'."""
        self._assert_identity(SchemaExtractTool(), "schema_extract", "extraction")

    def test_schema_generate_tool_instantiation(self):
        """SchemaGenerateTool() is named schema_generate and tagged 'generation'."""
        self._assert_identity(SchemaGenerateTool(), "schema_generate", "generation")

    def test_all_infra_tools_registered_in_registry(self):
        """All four infrastructure tools can be registered and then looked up by name."""
        registry = ToolRegistry()
        for tool_class in (
            BaiduSearchTool,
            WebCrawlTool,
            SchemaExtractTool,
            SchemaGenerateTool,
        ):
            registry.register(tool_class())

        for tool_name in ("baidu_search", "web_crawl", "schema_extract", "schema_generate"):
            assert registry.has_tool(tool_name)
|
||
|
||
|
||
# ── Test: AgentConfig tools 字段向后兼容 ──────────────────────
|
||
|
||
|
||
class TestAgentConfigToolsBackwardCompat:
    """The ``tools`` field on AgentConfig / SkillConfig stays backward compatible."""

    def test_agent_config_with_tools_list(self):
        """An explicit ``tools`` list is kept as-is on AgentConfig."""
        declared_tools = ["baidu_search", "web_crawl"]
        config = AgentConfig(
            name="test",
            agent_type="test",
            task_mode="tool_call",
            tools=declared_tools,
        )
        assert config.tools == ["baidu_search", "web_crawl"]

    def test_agent_config_without_tools(self):
        """Omitting ``tools`` leaves AgentConfig with an empty list."""
        prompt_spec = {"identity": "test", "instructions": "test"}
        config = AgentConfig(
            name="test",
            agent_type="test",
            task_mode="llm_generate",
            prompt=prompt_spec,
        )
        assert config.tools == []

    def test_skill_config_inherits_tools(self):
        """SkillConfig accepts the ``tools`` keyword inherited from AgentConfig."""
        config = SkillConfig(
            name="test",
            agent_type="test",
            task_mode="tool_call",
            tools=["baidu_search"],
        )
        assert config.tools == ["baidu_search"]

    def test_skill_config_from_dict_with_tools(self):
        """SkillConfig.from_dict carries the ``tools`` list through unchanged."""
        expected_tools = ["baidu_search", "web_crawl", "schema_generate"]
        raw = {
            "name": "test",
            "agent_type": "test",
            "task_mode": "tool_call",
            "tools": list(expected_tools),
        }
        config = SkillConfig.from_dict(raw)
        assert config.tools == expected_tools
|