265 lines
11 KiB
Python
265 lines
11 KiB
Python
"""E2E Agent Capability Tests — RequestPreprocessor Backtest (Real LLM).

Tests RequestPreprocessor.preprocess() using real LLM configuration loaded from
agentkit.yaml. Records full SkillRoutingResult for precise analysis.

Key differences from old CostAwareRouter backtest:
- No HeuristicClassifier complexity scoring
- No IntentRouter LLM classification
- No SemanticRouter embedding matching
- RequestPreprocessor: @skill prefix + greeting regex + default REACT
"""
|
|
|
|
import asyncio
|
|
import os
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from agentkit.chat.request_preprocessor import RequestPreprocessor
|
|
from agentkit.chat.skill_routing import ExecutionMode
|
|
from agentkit.server.app import _build_llm_gateway, _build_skill_registry
|
|
from agentkit.server.config import ServerConfig
|
|
from agentkit.skills.registry import SkillRegistry
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════════
|
|
# Test cases — covering all known problem scenarios
|
|
# ═══════════════════════════════════════════════════════════════════════════
|
|
|
|
# Routing expectations, written as (case id, user input, expected mode) rows
# and expanded into the dict records that the parametrized tests consume.
# Row order is significant: it determines the pytest parameter ids order.
ROUTING_TEST_CASES = [
    {"id": case_id, "input": user_input, "expected_mode": mode}
    for case_id, user_input, mode in (
        # --- Greeting/Chitchat → DIRECT_CHAT ---
        ("greeting_cn", "你好", "direct_chat"),
        ("greeting_en", "hello", "direct_chat"),
        ("chitchat_thanks", "谢谢", "direct_chat"),
        ("identity_who", "你是谁", "direct_chat"),
        # --- Tool-requiring queries → REACT ---
        # These are the core problem scenarios that CostAwareRouter failed on
        ("colloquial_ip_1", "查下ip", "react"),
        ("colloquial_ip_2", "查看当前ip", "react"),
        ("colloquial_ip_3", "获取ip地址", "react"),
        ("colloquial_ip_4", "看下ip", "react"),
        ("colloquial_ip_5", "帮我查一下ip", "react"),
        ("tool_search", "搜索golang教程", "react"),
        ("tool_shell", "执行ls命令", "react"),
        ("tool_file", "读一下配置文件", "react"),
        ("tool_monitor", "检查服务状态", "react"),
        ("tool_download", "下载这个文件", "react"),
        # --- Translation/knowledge → REACT (LLM decides no tool needed) ---
        ("translation", "翻译hello为中文", "react"),
        ("knowledge", "什么是机器学习", "react"),
        ("summarize", "帮我总结一下这段话", "react"),
        # --- Complex queries → REACT ---
        ("complex_analysis", "帮我分析一下这个数据并生成报告", "react"),
        ("complex_code", "重构这个函数使其更高效", "react"),
        ("complex_multi", "搜索最新的AI论文并总结关键发现", "react"),
        # --- @skill prefix → SKILL_REACT ---
        ("skill_prefix_shell", "@skill:react_agent 查看当前ip", "skill_react"),
    )
]
|
|
|
|
# Paraphrase consistency test cases — same intent, different expressions
|
|
# Each row is (group id, canonical phrasing, alternative phrasings, expected mode).
# Every phrasing within a group is expected to resolve to the same execution mode.
PARAPHRASE_CASES = [
    {
        "id": group_id,
        "original": canonical,
        "paraphrases": list(variants),
        "expected_mode": mode,
    }
    for group_id, canonical, variants, mode in (
        (
            "ip_check_variants",
            "查看当前ip",
            ("查下ip", "获取ip地址", "看下ip", "帮我查一下ip", "ip是什么"),
            "react",
        ),
        (
            "search_variants",
            "搜索golang教程",
            ("搜一下golang教程", "找下golang学习资料", "帮我搜golang入门"),
            "react",
        ),
    )
]
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════════
|
|
# Real component initialization
|
|
# ═══════════════════════════════════════════════════════════════════════════
|
|
|
|
|
|
def _find_config_path() -> str | None:
|
|
candidates = [
|
|
os.environ.get("AGENTKIT_CONFIG", ""),
|
|
str(Path.cwd() / "agentkit.yaml"),
|
|
str(Path.home() / ".agentkit" / "agentkit.yaml"),
|
|
]
|
|
for path in candidates:
|
|
if path and Path(path).is_file():
|
|
return path
|
|
return None
|
|
|
|
|
|
def _load_env_file(env_path: Path) -> None:
    """Export the KEY=VALUE pairs found in *env_path* into ``os.environ``.

    Uses python-dotenv when it is installed; otherwise falls back to a minimal
    parser that skips blank lines, ``#`` comments and lines without ``=``,
    strips surrounding quotes from values, and never overrides a variable that
    is already set.
    """
    try:
        from dotenv import load_dotenv
    except ImportError:
        load_dotenv = None

    if load_dotenv is not None:
        load_dotenv(env_path)
        return

    with open(env_path, encoding="utf-8") as env_file:
        for raw_line in env_file:
            entry = raw_line.strip()
            if not entry or entry.startswith("#") or "=" not in entry:
                continue
            key, _, value = entry.partition("=")
            key = key.strip()
            if not key:
                # A line such as "=value" has no usable name; setting an empty
                # environment variable name would raise.
                continue
            os.environ.setdefault(key, value.strip().strip("'\""))


def _apply_dashscope_fallback(server_config: ServerConfig, dashscope_key: str) -> None:
    """Give the first provider lacking an API key the DashScope key.

    If that provider also has no base URL, one is chosen from the key prefix:
    keys starting with ``sk-sp-`` get the "coding" endpoint, all others get
    the compatible-mode endpoint.
    """
    for pconf in server_config.llm_config.providers.values():
        if pconf.api_key:
            continue
        pconf.api_key = dashscope_key
        if not pconf.base_url:
            if dashscope_key.startswith("sk-sp-"):
                pconf.base_url = "https://coding.dashscope.aliyuncs.com/v1"
            else:
                pconf.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
        # NOTE(review): only the first key-less provider is patched — confirm
        # this matches the intended scope of the fallback.
        break


def _build_real_components() -> tuple[RequestPreprocessor, SkillRegistry]:
    """Build a RequestPreprocessor and SkillRegistry from the real agentkit.yaml.

    Steps:

    1. Locate the config file; skip the calling test if none is found.
    2. Load a ``.env`` file sitting next to the config, if one exists.
    3. Parse the config; when it reports no usable LLM provider, try the
       ``DASHSCOPE_API_KEY`` environment variable as a fallback.
    4. Skip the calling test if there is still no usable provider.

    Returns:
        The ``(preprocessor, skill_registry)`` pair.
    """
    config_path = _find_config_path()
    if not config_path:
        pytest.skip("No agentkit.yaml found")

    env_path = Path(config_path).parent / ".env"
    if env_path.is_file():
        _load_env_file(env_path)

    server_config = ServerConfig.from_yaml(config_path)

    if not server_config.has_llm_provider():
        dashscope_key = os.environ.get("DASHSCOPE_API_KEY", "")
        if dashscope_key:
            _apply_dashscope_fallback(server_config, dashscope_key)

    if not server_config.has_llm_provider():
        pytest.skip("No LLM provider with valid API key")

    skill_registry = _build_skill_registry(server_config)
    preprocessor = RequestPreprocessor(skill_registry=skill_registry)

    return preprocessor, skill_registry
|
|
|
|
|
|
# Module-wide memo of the (preprocessor, registry) pair; None until first built.
_cached_components: tuple[RequestPreprocessor, SkillRegistry] | None = None


def _get_components() -> tuple[RequestPreprocessor, SkillRegistry]:
    """Return the shared components, building them on first use.

    The pair is stored only after ``_build_real_components()`` returns, so if
    the build raises, nothing is cached and the next call tries again.
    """
    global _cached_components
    if _cached_components is not None:
        return _cached_components
    _cached_components = _build_real_components()
    return _cached_components
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════════
|
|
# Test classes
|
|
# ═══════════════════════════════════════════════════════════════════════════
|
|
|
|
|
|
@pytest.mark.e2e_capability
class TestRequestPreprocessorBasic:
    """Check the execution mode chosen for each entry of ROUTING_TEST_CASES."""

    @pytest.mark.parametrize(
        "case",
        ROUTING_TEST_CASES,
        ids=[entry["id"] for entry in ROUTING_TEST_CASES],
    )
    def test_routing(self, case: dict):
        """Preprocess ``case["input"]`` and compare the mode with ``case["expected_mode"]``."""
        preprocessor, skill_registry = _get_components()

        # Build the coroutine first, then drive it to completion on a fresh loop.
        routing_call = preprocessor.preprocess(
            content=case["input"],
            skill_registry=skill_registry,
            default_tools=["shell", "search", "file_read"],
        )
        result = asyncio.run(routing_call)

        actual_mode = result.execution_mode.value
        expected_mode = case["expected_mode"]
        # The failure message carries the match method and confidence for triage.
        assert actual_mode == expected_mode, (
            f"'{case['input']}': expected {expected_mode}, got {actual_mode} "
            f"(method={result.match_method}, confidence={result.match_confidence})"
        )
|
|
|
|
|
|
@pytest.mark.e2e_capability
class TestRequestPreprocessorParaphraseConsistency:
    """Check that rephrasings of one request resolve to the same execution mode."""

    @pytest.mark.parametrize(
        "case",
        PARAPHRASE_CASES,
        ids=[group["id"] for group in PARAPHRASE_CASES],
    )
    def test_paraphrase_consistency(self, case: dict):
        """Route the original phrasing, then each paraphrase, against one expected mode."""
        preprocessor, skill_registry = _get_components()
        expected_mode = case["expected_mode"]

        def routed_mode(text: str):
            """Preprocess *text* and return the resulting execution mode value."""
            outcome = asyncio.run(
                preprocessor.preprocess(
                    content=text,
                    skill_registry=skill_registry,
                    default_tools=["shell", "search", "file_read"],
                )
            )
            return outcome.execution_mode.value

        # The canonical phrasing goes first.
        original_mode = routed_mode(case["original"])
        assert original_mode == expected_mode, (
            f"Original '{case['original']}': expected {expected_mode}, got {original_mode}"
        )

        # Then every alternative wording; the first mismatch fails the test.
        for para in case["paraphrases"]:
            para_mode = routed_mode(para)
            assert para_mode == expected_mode, (
                f"Paraphrase '{para}': expected {expected_mode}, got {para_mode}"
            )
|
|
|
|
|
|
@pytest.mark.e2e_capability
class TestRequestPreprocessorMetrics:
    """Compute and report preprocessing accuracy metrics."""

    # Minimum share of correctly routed cases, in percent, for the report to pass.
    MIN_ACCURACY_PCT = 85.0

    def test_accuracy_report(self):
        """Run every routing case, print a per-case report, and enforce the threshold.

        Unlike the parametrized routing test, individual mismatches do not fail
        here; only the aggregate accuracy is asserted.
        """
        preprocessor, skill_registry = _get_components()
        total = len(ROUTING_TEST_CASES)
        # Fail with a clear message instead of a ZeroDivisionError below.
        assert total > 0, "ROUTING_TEST_CASES is empty; accuracy is undefined"

        results = []
        for case in ROUTING_TEST_CASES:
            result = asyncio.run(
                preprocessor.preprocess(
                    content=case["input"],
                    skill_registry=skill_registry,
                    default_tools=["shell", "search", "file_read"],
                )
            )
            actual_mode = result.execution_mode.value
            results.append({
                "id": case["id"],
                "input": case["input"],
                "expected": case["expected_mode"],
                "actual": actual_mode,
                "method": result.match_method,
                "correct": actual_mode == case["expected_mode"],
            })

        # Booleans sum as 0/1, so this counts the matching rows.
        correct = sum(row["correct"] for row in results)
        accuracy = correct / total * 100

        heavy_rule = "=" * 60
        light_rule = "-" * 60
        print(f"\n{heavy_rule}")
        print("RequestPreprocessor Accuracy Report")
        print(heavy_rule)
        print(f"Total: {total}, Correct: {correct}, Accuracy: {accuracy:.1f}%")
        print(light_rule)
        for r in results:
            status = "✓" if r["correct"] else "✗"
            print(f" {status} {r['id']}: '{r['input']}' → {r['actual']} (expected {r['expected']})")
        print(heavy_rule)

        # Assert minimum accuracy threshold
        assert accuracy >= self.MIN_ACCURACY_PCT, (
            f"Accuracy {accuracy:.1f}% is below {self.MIN_ACCURACY_PCT:g}% threshold"
        )
|