# fischer-agentkit/tests/e2e/test_simple_router_backtest.py
#
# 265 lines
# 11 KiB
# Python

"""E2E Agent Capability Tests — SimpleRouter Backtest (Real LLM).

Tests SimpleRouter.route() using the real LLM configuration loaded from
agentkit.yaml. Records the full SkillRoutingResult for precise analysis.

Key differences from the old CostAwareRouter backtest:

- No HeuristicClassifier complexity scoring
- No IntentRouter LLM classification
- No SemanticRouter embedding matching
- SimpleRouter: @skill prefix + greeting regex + default REACT
"""
import asyncio
import os
from pathlib import Path
import pytest
from agentkit.chat.simple_router import SimpleRouter
from agentkit.chat.skill_routing import ExecutionMode
from agentkit.server.app import _build_llm_gateway, _build_skill_registry
from agentkit.server.config import ServerConfig
from agentkit.skills.registry import SkillRegistry
# ═══════════════════════════════════════════════════════════════════════════
# Test cases — covering all known problem scenarios
# ═══════════════════════════════════════════════════════════════════════════
# Routing expectations, one dict per case with the keys "id" (pytest id),
# "input" (the user message handed to the router) and "expected_mode" (the
# expected ``execution_mode.value`` of the routing result).
# The rows below are (id, input, expected_mode) triples expanded into dicts.
ROUTING_TEST_CASES = [
    {"id": case_id, "input": user_input, "expected_mode": mode}
    for case_id, user_input, mode in (
        # --- Greeting/Chitchat → DIRECT_CHAT ---
        ("greeting_cn", "你好", "direct_chat"),
        ("greeting_en", "hello", "direct_chat"),
        ("chitchat_thanks", "谢谢", "direct_chat"),
        ("identity_who", "你是谁", "direct_chat"),
        # --- Tool-requiring queries → REACT ---
        # These are the core problem scenarios that CostAwareRouter failed on
        ("colloquial_ip_1", "查下ip", "react"),
        ("colloquial_ip_2", "查看当前ip", "react"),
        ("colloquial_ip_3", "获取ip地址", "react"),
        ("colloquial_ip_4", "看下ip", "react"),
        ("colloquial_ip_5", "帮我查一下ip", "react"),
        ("tool_search", "搜索golang教程", "react"),
        ("tool_shell", "执行ls命令", "react"),
        ("tool_file", "读一下配置文件", "react"),
        ("tool_monitor", "检查服务状态", "react"),
        ("tool_download", "下载这个文件", "react"),
        # --- Translation/knowledge → REACT (LLM decides no tool needed) ---
        ("translation", "翻译hello为中文", "react"),
        ("knowledge", "什么是机器学习", "react"),
        ("summarize", "帮我总结一下这段话", "react"),
        # --- Complex queries → REACT ---
        ("complex_analysis", "帮我分析一下这个数据并生成报告", "react"),
        ("complex_code", "重构这个函数使其更高效", "react"),
        ("complex_multi", "搜索最新的AI论文并总结关键发现", "react"),
        # --- @skill prefix → SKILL_REACT ---
        ("skill_prefix_shell", "@skill:react_agent 查看当前ip", "skill_react"),
    )
]
# Paraphrase consistency test cases — same intent, different expressions.
# Each entry holds:
#   "id"            - label for the case
#   "original"      - the reference phrasing
#   "paraphrases"   - alternative phrasings of the same request
#   "expected_mode" - the execution mode value every phrasing should route to
PARAPHRASE_CASES = [
    {
        "id": "ip_check_variants",
        "original": "查看当前ip",
        "paraphrases": ["查下ip", "获取ip地址", "看下ip", "帮我查一下ip", "ip是什么"],
        "expected_mode": "react",
    },
    {
        "id": "search_variants",
        "original": "搜索golang教程",
        "paraphrases": ["搜一下golang教程", "找下golang学习资料", "帮我搜golang入门"],
        "expected_mode": "react",
    },
]
# ═══════════════════════════════════════════════════════════════════════════
# Real component initialization
# ═══════════════════════════════════════════════════════════════════════════
def _find_config_path() -> str | None:
candidates = [
os.environ.get("AGENTKIT_CONFIG", ""),
str(Path.cwd() / "agentkit.yaml"),
str(Path.home() / ".agentkit" / "agentkit.yaml"),
]
for path in candidates:
if path and Path(path).is_file():
return path
return None
def _build_real_components() -> tuple[SimpleRouter, SkillRegistry]:
    """Build the SimpleRouter and SkillRegistry from the real configuration.

    The config file is located with ``_find_config_path()``. A ``.env`` file
    next to it is loaded first: through ``dotenv.load_dotenv`` when that
    package can be imported, otherwise through a minimal ``KEY=VALUE`` parser
    that never overrides variables already present in the environment.

    If the parsed ``ServerConfig`` reports no LLM provider, the
    ``DASHSCOPE_API_KEY`` environment variable (when set) is injected into a
    provider that has no API key, together with a default base URL if the
    provider has none.

    Returns:
        A ``(router, skill_registry)`` tuple.

    Note:
        Calls ``pytest.skip`` when no config file is found, or when the
        configuration still reports no LLM provider after the fallback.
    """
    config_path = _find_config_path()
    if not config_path:
        pytest.skip("No agentkit.yaml found")

    env_path = Path(config_path).parent / ".env"
    if env_path.exists():
        try:
            from dotenv import load_dotenv
        except ImportError:
            # Fallback parser. Decode explicitly as UTF-8 so the result does
            # not depend on the platform's locale encoding.
            with open(env_path, encoding="utf-8") as env_file:
                for raw_line in env_file:
                    entry = raw_line.strip()
                    if entry and not entry.startswith("#") and "=" in entry:
                        key, _, value = entry.partition("=")
                        # setdefault: values already in the environment win.
                        os.environ.setdefault(
                            key.strip(), value.strip().strip("'\"")
                        )
        else:
            load_dotenv(env_path)

    server_config = ServerConfig.from_yaml(config_path)

    if not server_config.has_llm_provider():
        dashscope_key = os.environ.get("DASHSCOPE_API_KEY", "")
        if dashscope_key:
            # NOTE(review): the original nesting of the base_url default and
            # the `break` could not be recovered from the flattened source.
            # This patches only the first provider lacking a key; confirm.
            for pconf in server_config.llm_config.providers.values():
                if not pconf.api_key:
                    pconf.api_key = dashscope_key
                    if not pconf.base_url:
                        # Keys with the "sk-sp-" prefix get the coding
                        # endpoint; all others the compatible-mode endpoint.
                        if dashscope_key.startswith("sk-sp-"):
                            pconf.base_url = "https://coding.dashscope.aliyuncs.com/v1"
                        else:
                            pconf.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
                    break

    if not server_config.has_llm_provider():
        pytest.skip("No LLM provider with valid API key")

    skill_registry = _build_skill_registry(server_config)
    router = SimpleRouter(skill_registry=skill_registry)
    return router, skill_registry
# Module-wide cache so the real components are built at most once per session.
_cached_components: tuple[SimpleRouter, SkillRegistry] | None = None


def _get_components() -> tuple[SimpleRouter, SkillRegistry]:
    """Return the shared ``(router, skill_registry)`` pair, building it on first use.

    The pair comes from ``_build_real_components()`` and is stored in the
    module-level ``_cached_components`` so later calls reuse it.
    """
    global _cached_components
    if _cached_components is not None:
        return _cached_components
    _cached_components = _build_real_components()
    return _cached_components
# ═══════════════════════════════════════════════════════════════════════════
# Test classes
# ═══════════════════════════════════════════════════════════════════════════
@pytest.mark.e2e_capability
class TestSimpleRouterBasic:
    """Route each entry of ROUTING_TEST_CASES and check the resulting mode."""

    @pytest.mark.parametrize(
        "case",
        [pytest.param(entry, id=entry["id"]) for entry in ROUTING_TEST_CASES],
    )
    def test_routing(self, case: dict):
        """Route ``case["input"]`` and compare the mode with ``case["expected_mode"]``."""
        router, skill_registry = _get_components()

        # route() is a coroutine; each test drives it with its own asyncio.run().
        routing_call = router.route(
            content=case["input"],
            skill_registry=skill_registry,
            default_tools=["shell", "search", "file_read"],
        )
        result = asyncio.run(routing_call)

        expected_mode = case["expected_mode"]
        actual_mode = result.execution_mode.value
        assert actual_mode == expected_mode, (
            f"'{case['input']}': expected {expected_mode}, got {actual_mode} "
            f"(method={result.match_method}, confidence={result.match_confidence})"
        )
@pytest.mark.e2e_capability
class TestSimpleRouterParaphraseConsistency:
    """Check that an utterance and its paraphrases all route to one mode."""

    @staticmethod
    def _routed_mode(router, skill_registry, text: str) -> str:
        """Route ``text`` and return the resulting ``execution_mode.value``."""
        outcome = asyncio.run(
            router.route(
                content=text,
                skill_registry=skill_registry,
                default_tools=["shell", "search", "file_read"],
            )
        )
        return outcome.execution_mode.value

    @pytest.mark.parametrize(
        "case",
        [pytest.param(entry, id=entry["id"]) for entry in PARAPHRASE_CASES],
    )
    def test_paraphrase_consistency(self, case: dict):
        """Route the original first, then every paraphrase, against one expected mode."""
        router, skill_registry = _get_components()
        expected_mode = case["expected_mode"]

        # The reference phrasing must route correctly before paraphrases are tried.
        original_mode = self._routed_mode(router, skill_registry, case["original"])
        assert original_mode == expected_mode, (
            f"Original '{case['original']}': expected {expected_mode}, got {original_mode}"
        )

        # The first paraphrase that diverges fails the test.
        for para in case["paraphrases"]:
            para_mode = self._routed_mode(router, skill_registry, para)
            assert para_mode == expected_mode, (
                f"Paraphrase '{para}': expected {expected_mode}, got {para_mode}"
            )
@pytest.mark.e2e_capability
class TestSimpleRouterMetrics:
    """Compute and report routing accuracy metrics."""

    def test_accuracy_report(self):
        """Route every case, print a per-case report and enforce 85% accuracy.

        Each entry of ROUTING_TEST_CASES is routed once. The report lists the
        case id, input, actual and expected mode and the match method, each
        marked PASS or FAIL. The test fails when overall accuracy is below 85%.
        """
        router, skill_registry = _get_components()

        results = []
        for case in ROUTING_TEST_CASES:
            result = asyncio.run(
                router.route(
                    content=case["input"],
                    skill_registry=skill_registry,
                    default_tools=["shell", "search", "file_read"],
                )
            )
            actual_mode = result.execution_mode.value
            results.append({
                "id": case["id"],
                "input": case["input"],
                "expected": case["expected_mode"],
                "actual": actual_mode,
                "method": result.match_method,
                "correct": actual_mode == case["expected_mode"],
            })

        # Guard the division below; an empty case list is a test-data error.
        assert results, "ROUTING_TEST_CASES is empty"

        total = len(results)
        correct = sum(1 for r in results if r["correct"])
        accuracy = correct / total * 100

        print(f"\n{'='*60}")
        print("SimpleRouter Accuracy Report")
        print(f"{'='*60}")
        print(f"Total: {total}, Correct: {correct}, Accuracy: {accuracy:.1f}%")
        print(f"{'-'*60}")
        for r in results:
            # Explicit ASCII markers: the original had two empty strings here,
            # so passing and failing rows were indistinguishable.
            status = "PASS" if r["correct"] else "FAIL"
            print(
                f"  [{status}] {r['id']}: '{r['input']}' -> {r['actual']} "
                f"(expected {r['expected']}, method={r['method']})"
            )
        print(f"{'='*60}")

        # Assert minimum accuracy threshold
        assert accuracy >= 85.0, f"Accuracy {accuracy:.1f}% is below 85% threshold"