266 lines
11 KiB
Python
266 lines
11 KiB
Python
"""E2E Agent Capability Tests — RequestPreprocessor Backtest (Real LLM).
|
||
|
||
Tests RequestPreprocessor.preprocess() using real LLM configuration loaded from
|
||
agentkit.yaml. Records full SkillRoutingResult for precise analysis.
|
||
|
||
Key differences from old CostAwareRouter backtest:
|
||
- No HeuristicClassifier complexity scoring
|
||
- No IntentRouter LLM classification
|
||
- No SemanticRouter embedding matching
|
||
- RequestPreprocessor: @skill prefix + greeting regex + default REACT
|
||
"""
|
||
|
||
import asyncio
|
||
import os
|
||
from pathlib import Path
|
||
|
||
import pytest
|
||
|
||
from agentkit.chat.request_preprocessor import RequestPreprocessor
|
||
from agentkit.chat.skill_routing import ExecutionMode
|
||
from agentkit.server.app import _build_llm_gateway, _build_skill_registry
|
||
from agentkit.server.config import ServerConfig
|
||
from agentkit.skills.registry import SkillRegistry
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════════
|
||
# Test cases — covering all known problem scenarios
|
||
# ═══════════════════════════════════════════════════════════════════════════
|
||
|
||
# Each row below is (id, input, expected_mode); rows are expanded into dicts
# with exactly those three keys, in that order.
ROUTING_TEST_CASES = [
    {"id": case_id, "input": text, "expected_mode": mode}
    for case_id, text, mode in (
        # --- Greeting / chit-chat → DIRECT_CHAT ---
        ("greeting_cn", "你好", "direct_chat"),
        ("greeting_en", "hello", "direct_chat"),
        ("chitchat_thanks", "谢谢", "direct_chat"),
        ("identity_who", "你是谁", "direct_chat"),
        # --- Tool-requiring queries → REACT ---
        # These are the core problem scenarios that CostAwareRouter failed on.
        ("colloquial_ip_1", "查下ip", "react"),
        ("colloquial_ip_2", "查看当前ip", "react"),
        ("colloquial_ip_3", "获取ip地址", "react"),
        ("colloquial_ip_4", "看下ip", "react"),
        ("colloquial_ip_5", "帮我查一下ip", "react"),
        ("tool_search", "搜索golang教程", "react"),
        ("tool_shell", "执行ls命令", "react"),
        ("tool_file", "读一下配置文件", "react"),
        ("tool_monitor", "检查服务状态", "react"),
        ("tool_download", "下载这个文件", "react"),
        # --- Translation / knowledge → REACT (LLM decides no tool needed) ---
        ("translation", "翻译hello为中文", "react"),
        # U5: pure knowledge Q&A (no tool context) → DIRECT_CHAT
        # (zero-cost fast path).
        ("knowledge", "什么是机器学习", "direct_chat"),
        ("summarize", "帮我总结一下这段话", "react"),
        # --- Complex queries → REACT ---
        ("complex_analysis", "帮我分析一下这个数据并生成报告", "react"),
        ("complex_code", "重构这个函数使其更高效", "react"),
        ("complex_multi", "搜索最新的AI论文并总结关键发现", "react"),
        # --- @skill prefix → SKILL_REACT ---
        ("skill_prefix_shell", "@skill:react_agent 查看当前ip", "skill_react"),
    )
]
|
||
|
||
# Paraphrase consistency test cases — same intent, different expressions
|
||
# Each entry pairs one "original" phrasing with alternative phrasings of the
# same request; all of them share a single expected execution mode.
PARAPHRASE_CASES = [
    dict(
        id="ip_check_variants",
        original="查看当前ip",
        paraphrases=["查下ip", "获取ip地址", "看下ip", "帮我查一下ip", "ip是什么"],
        expected_mode="react",
    ),
    dict(
        id="search_variants",
        original="搜索golang教程",
        paraphrases=["搜一下golang教程", "找下golang学习资料", "帮我搜golang入门"],
        expected_mode="react",
    ),
]
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════════
|
||
# Real component initialization
|
||
# ═══════════════════════════════════════════════════════════════════════════
|
||
|
||
|
||
def _find_config_path() -> str | None:
|
||
candidates = [
|
||
os.environ.get("AGENTKIT_CONFIG", ""),
|
||
str(Path.cwd() / "agentkit.yaml"),
|
||
str(Path.home() / ".agentkit" / "agentkit.yaml"),
|
||
]
|
||
for path in candidates:
|
||
if path and Path(path).is_file():
|
||
return path
|
||
return None
|
||
|
||
|
||
def _build_real_components() -> tuple[RequestPreprocessor, SkillRegistry]:
    """Build a RequestPreprocessor and SkillRegistry from the real config.

    Steps:

    1. Locate the config file via ``_find_config_path()``; skip the test
       when none is found.
    2. If a ``.env`` file sits next to the config, load it into
       ``os.environ`` (python-dotenv when importable, otherwise a minimal
       ``KEY=VALUE`` parser). Existing environment variables are not
       overwritten by the fallback parser.
    3. Load ``ServerConfig`` from the YAML file. If it reports no LLM
       provider, try to fill one in from ``DASHSCOPE_API_KEY``.
    4. Skip the test when there is still no LLM provider.

    Returns:
        ``(preprocessor, skill_registry)``.

    Raises:
        Whatever ``pytest.skip`` raises, when no config file or no usable
        LLM provider is available.
    """
    config_path = _find_config_path()
    if not config_path:
        pytest.skip("No agentkit.yaml found")

    env_path = Path(config_path).parent / ".env"
    if env_path.exists():
        try:
            from dotenv import load_dotenv
        except ImportError:
            # Fallback parser: plain KEY=VALUE lines, '#' comments ignored.
            # Decode explicitly as UTF-8 (python-dotenv's default) instead of
            # relying on the platform locale.
            with open(env_path, encoding="utf-8") as env_file:
                for raw_line in env_file:
                    line = raw_line.strip()
                    if line and not line.startswith("#") and "=" in line:
                        key, _, value = line.partition("=")
                        os.environ.setdefault(key.strip(), value.strip().strip("'\""))
        else:
            load_dotenv(env_path)

    server_config = ServerConfig.from_yaml(config_path)

    if not server_config.has_llm_provider():
        dashscope_key = os.environ.get("DASHSCOPE_API_KEY", "")
        if dashscope_key:
            # NOTE(review): the nesting below (base_url handled inside the
            # "no api_key" branch, break after the first provider filled in)
            # was reconstructed — confirm against the intended behavior.
            for _name, pconf in server_config.llm_config.providers.items():
                if not pconf.api_key:
                    pconf.api_key = dashscope_key
                    if not pconf.base_url:
                        if dashscope_key.startswith("sk-sp-"):
                            pconf.base_url = "https://coding.dashscope.aliyuncs.com/v1"
                        else:
                            pconf.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
                    break

    # Re-check: the fallback above may or may not have produced a provider.
    if not server_config.has_llm_provider():
        pytest.skip("No LLM provider with valid API key")

    skill_registry = _build_skill_registry(server_config)
    preprocessor = RequestPreprocessor(skill_registry=skill_registry)

    return preprocessor, skill_registry
|
||
|
||
|
||
# Module-level memo for the components built by _build_real_components();
# stays None until the first successful build.
_cached_components: tuple[RequestPreprocessor, SkillRegistry] | None = None


def _get_components() -> tuple[RequestPreprocessor, SkillRegistry]:
    """Return the shared (preprocessor, skill_registry) pair, building it once.

    The pair is stored in the module-level ``_cached_components`` after the
    first call to ``_build_real_components()`` that returns; later calls
    reuse it. If the build does not return (for example because it skips the
    test), nothing is cached and the next call tries again.
    """
    global _cached_components
    if _cached_components is not None:
        return _cached_components
    _cached_components = _build_real_components()
    return _cached_components
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════════════
|
||
# Test classes
|
||
# ═══════════════════════════════════════════════════════════════════════════
|
||
|
||
|
||
@pytest.mark.e2e_capability
class TestRequestPreprocessorBasic:
    """Check the execution mode chosen for every entry in ROUTING_TEST_CASES."""

    @pytest.mark.parametrize(
        "case",
        ROUTING_TEST_CASES,
        ids=[entry["id"] for entry in ROUTING_TEST_CASES],
    )
    def test_routing(self, case: dict):
        """Preprocess ``case["input"]`` and compare the mode with the expectation."""
        preprocessor, skill_registry = _get_components()
        expected_mode = case["expected_mode"]

        # preprocess() is a coroutine; drive it to completion synchronously.
        pending = preprocessor.preprocess(
            content=case["input"],
            skill_registry=skill_registry,
            default_tools=["shell", "search", "file_read"],
        )
        result = asyncio.run(pending)
        actual_mode = result.execution_mode.value

        assert actual_mode == expected_mode, (
            f"'{case['input']}': expected {expected_mode}, got {actual_mode} "
            f"(method={result.match_method}, confidence={result.match_confidence})"
        )
|
||
|
||
|
||
@pytest.mark.e2e_capability
class TestRequestPreprocessorParaphraseConsistency:
    """Check that an input and its paraphrases all get the same execution mode."""

    @pytest.mark.parametrize(
        "case",
        PARAPHRASE_CASES,
        ids=[entry["id"] for entry in PARAPHRASE_CASES],
    )
    def test_paraphrase_consistency(self, case: dict):
        """Preprocess the original first, then each paraphrase, in order."""
        preprocessor, skill_registry = _get_components()
        expected_mode = case["expected_mode"]

        # The label only affects the failure message; the check is the same
        # for the original phrasing and for every paraphrase.
        labelled_inputs = [("Original", case["original"])]
        labelled_inputs.extend(("Paraphrase", text) for text in case["paraphrases"])

        for label, text in labelled_inputs:
            result = asyncio.run(
                preprocessor.preprocess(
                    content=text,
                    skill_registry=skill_registry,
                    default_tools=["shell", "search", "file_read"],
                )
            )
            assert result.execution_mode.value == expected_mode, (
                f"{label} '{text}': expected {expected_mode}, got {result.execution_mode.value}"
            )
|
||
|
||
|
||
@pytest.mark.e2e_capability
class TestRequestPreprocessorMetrics:
    """Aggregate accuracy over ROUTING_TEST_CASES and print a report."""

    def test_accuracy_report(self):
        """Run every routing case, print per-case results, require >= 85% accuracy."""
        preprocessor, skill_registry = _get_components()
        results = []

        for case in ROUTING_TEST_CASES:
            outcome = asyncio.run(
                preprocessor.preprocess(
                    content=case["input"],
                    skill_registry=skill_registry,
                    default_tools=["shell", "search", "file_read"],
                )
            )
            actual_mode = outcome.execution_mode.value
            results.append({
                "id": case["id"],
                "input": case["input"],
                "expected": case["expected_mode"],
                "actual": actual_mode,
                "method": outcome.match_method,
                "correct": actual_mode == case["expected_mode"],
            })

        total = len(ROUTING_TEST_CASES)
        correct = sum(1 for row in results if row["correct"])
        accuracy = correct / total * 100

        heavy_rule = "=" * 60
        light_rule = "-" * 60
        print(f"\n{heavy_rule}")
        print("RequestPreprocessor Accuracy Report")
        print(heavy_rule)
        print(f"Total: {total}, Correct: {correct}, Accuracy: {accuracy:.1f}%")
        print(light_rule)
        for row in results:
            status = "OK" if row["correct"] else "FAIL"
            print(f" {status} {row['id']}: '{row['input']}' → {row['actual']} (expected {row['expected']})")
        print(heavy_rule)

        # Assert minimum accuracy threshold
        assert accuracy >= 85.0, f"Accuracy {accuracy:.1f}% is below 85% threshold"
|