234 lines
8.4 KiB
Python
234 lines
8.4 KiB
Python
"""ExpertTemplateRegistry 单元测试"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import tempfile
|
|
|
|
import pytest
|
|
import yaml
|
|
|
|
from agentkit.experts.config import ExpertConfig, ExpertTemplate
|
|
from agentkit.experts.registry import ExpertTemplateRegistry
|
|
|
|
|
|
# ── 辅助函数 ──────────────────────────────────────────────
|
|
|
|
|
|
def _make_template(
    name: str = "test_template",
    persona: str = "测试专家",
    description: str = "测试模板描述",
    is_builtin: bool = False,
    bound_skills: list[str] | None = None,
) -> ExpertTemplate:
    """Build an ``ExpertTemplate`` instance for use in tests.

    Args:
        name: Used both as the template name and as the config name.
        persona: Stored as the config persona and as the prompt ``identity``.
        description: Description attached to the template.
        is_builtin: Value of the template's ``is_builtin`` flag.
        bound_skills: Skills bound to the config; a falsy value (``None`` or
            an empty list) becomes a fresh empty list.

    Returns:
        The template wrapping a freshly built ``ExpertConfig``.
    """
    # Fall back to a new list so calls never share a mutable default.
    skills = bound_skills if bound_skills else []

    config_fields = {
        "name": name,
        "agent_type": "expert",
        "persona": persona,
        "task_mode": "llm_generate",
        "prompt": {"identity": persona},
        "bound_skills": skills,
    }
    expert_config = ExpertConfig(**config_fields)

    return ExpertTemplate(
        name=name,
        config=expert_config,
        is_builtin=is_builtin,
        description=description,
    )
|
|
|
|
|
|
def _write_yaml_file(directory: str, filename: str, data: dict) -> str:
    """Dump *data* as YAML into ``directory/filename`` and return that path.

    Args:
        directory: Directory in which the file is created.
        filename: Name of the file to write inside *directory*.
        data: Object handed to ``yaml.dump``.

    Returns:
        The joined path of the file that was written.
    """
    target_path = os.path.join(directory, filename)
    # allow_unicode keeps non-ASCII text readable instead of escaping it.
    with open(target_path, mode="w", encoding="utf-8") as stream:
        yaml.dump(data, stream=stream, allow_unicode=True)
    return target_path
|
|
|
|
|
|
# ── ExpertTemplateRegistry 测试 ───────────────────────────
|
|
|
|
|
|
class TestExpertTemplateRegistry:
    """Unit tests for ``ExpertTemplateRegistry``.

    The cases cover registering and fetching templates, listing, keyword
    search, overwriting on duplicate names, and loading templates from a
    single YAML file or from a directory.
    """

    def test_register_and_get(self):
        """A registered template can be fetched back by its name."""
        registry = ExpertTemplateRegistry()
        template = _make_template("analyst", persona="分析师")
        registry.register(template)

        result = registry.get("analyst")
        assert result is not None
        assert result.name == "analyst"
        assert result.config.persona == "分析师"

    def test_get_nonexistent_returns_none(self):
        """Fetching an unknown name yields ``None``."""
        registry = ExpertTemplateRegistry()
        assert registry.get("nonexistent") is None

    def test_list_all_templates(self):
        """``list()`` returns every registered template."""
        registry = ExpertTemplateRegistry()
        registry.register(_make_template("a", description="模板A"))
        registry.register(_make_template("b", description="模板B"))
        registry.register(_make_template("c", description="模板C"))

        templates = registry.list()
        # Compare as a set: the test does not depend on listing order.
        names = {t.name for t in templates}
        assert names == {"a", "b", "c"}

    def test_list_empty_registry(self):
        """An empty registry lists as an empty list."""
        registry = ExpertTemplateRegistry()
        assert registry.list() == []

    def test_search_by_name_case_insensitive(self):
        """Searching matches template names regardless of letter case."""
        registry = ExpertTemplateRegistry()
        registry.register(_make_template("DataAnalyst", description="数据分析师"))
        registry.register(_make_template("CodeReviewer", description="代码审查员"))

        # Lower-case fragment of a mixed-case name.
        results = registry.search("data")
        assert len(results) == 1
        assert results[0].name == "DataAnalyst"

        # Upper-case spelling of the full name.
        results = registry.search("CODEREVIEWER")
        assert len(results) == 1
        assert results[0].name == "CodeReviewer"

    def test_search_by_description(self):
        """Searching matches keywords found in template descriptions."""
        registry = ExpertTemplateRegistry()
        registry.register(_make_template("analyst", description="数据分析专家"))
        registry.register(_make_template("writer", description="内容创作专家"))

        results = registry.search("数据")
        assert len(results) == 1
        assert results[0].name == "analyst"

        results = registry.search("创作")
        assert len(results) == 1
        assert results[0].name == "writer"

    def test_search_no_matches_returns_empty(self):
        """A keyword that matches nothing yields an empty list."""
        registry = ExpertTemplateRegistry()
        registry.register(_make_template("analyst", description="数据分析师"))

        results = registry.search("nonexistent_keyword")
        assert results == []

    def test_register_overwrites_same_name(self):
        """Registering a second template under the same name replaces the first."""
        registry = ExpertTemplateRegistry()
        v1 = _make_template("expert_a", persona="版本1", description="旧版本")
        v2 = _make_template("expert_a", persona="版本2", description="新版本")

        registry.register(v1)
        registry.register(v2)

        result = registry.get("expert_a")
        assert result is not None
        assert result.config.persona == "版本2"
        assert result.description == "新版本"
        # The overwrite must not leave a duplicate entry behind.
        assert len(registry.list()) == 1

    def test_load_from_yaml(self):
        """A template is loaded from a single YAML file and registered."""
        yaml_data = {
            "name": "yaml_expert",
            "is_builtin": False,
            "description": "YAML 加载的专家",
            "config": {
                "name": "yaml_expert",
                "agent_type": "expert",
                "persona": "YAML 专家",
                "thinking_style": "结构化思维",
                "collaboration_strategy": "cooperative",
                "bound_skills": ["skill_a", "skill_b"],
                "avatar": "🤖",
                "color": "#fa8c16",
                "is_lead": False,
                "task_mode": "llm_generate",
                "prompt": {"identity": "YAML 专家"},
            },
        }

        with tempfile.TemporaryDirectory() as tmpdir:
            filepath = _write_yaml_file(tmpdir, "expert.yaml", yaml_data)
            registry = ExpertTemplateRegistry()
            template = registry.load_from_yaml(filepath)

            assert template.name == "yaml_expert"
            assert template.config.persona == "YAML 专家"
            assert template.config.thinking_style == "结构化思维"
            assert template.config.bound_skills == ["skill_a", "skill_b"]
            assert template.config.avatar == "🤖"
            assert template.config.color == "#fa8c16"
            # Loading must also register the very same object in the registry.
            assert registry.get("yaml_expert") is template

    def test_load_from_directory(self):
        """All ``.yaml`` / ``.yml`` files in a directory are loaded in bulk."""
        yaml_data_a = {
            "name": "dir_expert_a",
            "description": "目录专家A",
            "config": {
                "name": "dir_expert_a",
                "agent_type": "expert",
                "persona": "专家A",
                "task_mode": "llm_generate",
                "prompt": {"identity": "专家A"},
            },
        }
        yaml_data_b = {
            "name": "dir_expert_b",
            "description": "目录专家B",
            "config": {
                "name": "dir_expert_b",
                "agent_type": "expert",
                "persona": "专家B",
                "task_mode": "llm_generate",
                "prompt": {"identity": "专家B"},
            },
        }

        with tempfile.TemporaryDirectory() as tmpdir:
            # One file per supported extension.
            _write_yaml_file(tmpdir, "expert_a.yaml", yaml_data_a)
            _write_yaml_file(tmpdir, "expert_b.yml", yaml_data_b)
            # Non-YAML files are expected to be ignored. The encoding is
            # given explicitly, like every other file written in this module.
            with open(
                os.path.join(tmpdir, "readme.txt"), "w", encoding="utf-8"
            ) as f:
                f.write("not a yaml")

            registry = ExpertTemplateRegistry()
            loaded = registry.load_from_directory(tmpdir)

            assert len(loaded) == 2
            names = {t.name for t in loaded}
            assert "dir_expert_a" in names
            assert "dir_expert_b" in names
            # Loading must also register the templates in the registry.
            assert registry.get("dir_expert_a") is not None
            assert registry.get("dir_expert_b") is not None

    def test_load_from_directory_nonexistent(self):
        """Loading from a directory that does not exist yields an empty list."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # A child of a brand-new temporary directory cannot exist, which
            # keeps the test independent of the host filesystem layout.
            missing_dir = os.path.join(tmpdir, "nonexistent")
            registry = ExpertTemplateRegistry()
            loaded = registry.load_from_directory(missing_dir)

        assert loaded == []

    def test_load_from_yaml_invalid_format(self):
        """Loading a YAML document whose top level is not a mapping raises."""
        with tempfile.TemporaryDirectory() as tmpdir:
            filepath = os.path.join(tmpdir, "invalid.yaml")
            with open(filepath, "w", encoding="utf-8") as f:
                # The top-level value is a list, not a dict.
                yaml.dump(["not", "a", "dict"], f)

            registry = ExpertTemplateRegistry()
            # NOTE(review): ``Exception`` is very broad; narrow it to the
            # concrete type once the exception raised by ``load_from_yaml``
            # for non-dict content is confirmed.
            with pytest.raises(Exception):
                registry.load_from_yaml(filepath)
|