# Source: fischer-agentkit/tests/e2e/test_request_preprocessor_b... (266 lines, 11 KiB, Python)
#
# Note: this file contains Unicode characters that might be confused with other
# characters. If this is intentional, the warning can be safely ignored.

"""E2E Agent Capability Tests — RequestPreprocessor Backtest (Real LLM).

Tests RequestPreprocessor.preprocess() using the real LLM configuration loaded
from agentkit.yaml. Records the full SkillRoutingResult for precise analysis.

Key differences from the old CostAwareRouter backtest:

- No HeuristicClassifier complexity scoring
- No IntentRouter LLM classification
- No SemanticRouter embedding matching
- RequestPreprocessor: @skill prefix + greeting regex + default REACT
"""
import asyncio
import os
from pathlib import Path
import pytest
from agentkit.chat.request_preprocessor import RequestPreprocessor
from agentkit.chat.skill_routing import ExecutionMode
from agentkit.server.app import _build_llm_gateway, _build_skill_registry
from agentkit.server.config import ServerConfig
from agentkit.skills.registry import SkillRegistry
# ═══════════════════════════════════════════════════════════════════════════
# Test cases — covering all known problem scenarios
# ═══════════════════════════════════════════════════════════════════════════
# Each routing case is a dict with the keys "id" (pytest id), "input" (the user
# message handed to the preprocessor) and "expected_mode" (the execution mode
# value the test expects back). The table below lists them as
# (id, input, expected_mode) tuples, in that key order.
ROUTING_TEST_CASES = [
    {"id": case_id, "input": text, "expected_mode": mode}
    for case_id, text, mode in (
        # --- Greeting / chitchat / identity -> direct_chat ---
        ("greeting_cn", "你好", "direct_chat"),
        ("greeting_en", "hello", "direct_chat"),
        ("chitchat_thanks", "谢谢", "direct_chat"),
        ("identity_who", "你是谁", "direct_chat"),
        # --- Tool-requiring queries -> react ---
        # These are the core problem scenarios that CostAwareRouter failed on.
        ("colloquial_ip_1", "查下ip", "react"),
        ("colloquial_ip_2", "查看当前ip", "react"),
        ("colloquial_ip_3", "获取ip地址", "react"),
        ("colloquial_ip_4", "看下ip", "react"),
        ("colloquial_ip_5", "帮我查一下ip", "react"),
        ("tool_search", "搜索golang教程", "react"),
        ("tool_shell", "执行ls命令", "react"),
        ("tool_file", "读一下配置文件", "react"),
        ("tool_monitor", "检查服务状态", "react"),
        ("tool_download", "下载这个文件", "react"),
        # --- Translation / summarization -> react (LLM decides no tool needed) ---
        ("translation", "翻译hello为中文", "react"),
        # U5: pure knowledge Q&A (no tool context) -> direct_chat zero-cost fast path
        ("knowledge", "什么是机器学习", "direct_chat"),
        ("summarize", "帮我总结一下这段话", "react"),
        # --- Complex queries -> react ---
        ("complex_analysis", "帮我分析一下这个数据并生成报告", "react"),
        ("complex_code", "重构这个函数使其更高效", "react"),
        ("complex_multi", "搜索最新的AI论文并总结关键发现", "react"),
        # --- @skill prefix -> skill_react ---
        ("skill_prefix_shell", "@skill:react_agent 查看当前ip", "skill_react"),
    )
]
# Paraphrase consistency cases: one intent phrased several ways. Every entry
# carries the reference wording ("original"), alternative wordings
# ("paraphrases") and the single execution mode all of them must map to.
PARAPHRASE_CASES = [
    dict(
        id="ip_check_variants",
        original="查看当前ip",
        paraphrases=["查下ip", "获取ip地址", "看下ip", "帮我查一下ip", "ip是什么"],
        expected_mode="react",
    ),
    dict(
        id="search_variants",
        original="搜索golang教程",
        paraphrases=["搜一下golang教程", "找下golang学习资料", "帮我搜golang入门"],
        expected_mode="react",
    ),
]
# ═══════════════════════════════════════════════════════════════════════════
# Real component initialization
# ═══════════════════════════════════════════════════════════════════════════
def _find_config_path() -> str | None:
candidates = [
os.environ.get("AGENTKIT_CONFIG", ""),
str(Path.cwd() / "agentkit.yaml"),
str(Path.home() / ".agentkit" / "agentkit.yaml"),
]
for path in candidates:
if path and Path(path).is_file():
return path
return None
def _build_real_components() -> tuple[RequestPreprocessor, SkillRegistry]:
    """Build a RequestPreprocessor and SkillRegistry from the on-disk configuration.

    The function:

    1. locates the config file with ``_find_config_path()``;
    2. loads a ``.env`` file sitting next to it, if one exists (through
       python-dotenv when importable, otherwise through a minimal built-in
       ``KEY=VALUE`` parser);
    3. parses the config with ``ServerConfig.from_yaml``;
    4. if the config reports no usable LLM provider, fills a provider's missing
       API key (and missing base URL) from ``DASHSCOPE_API_KEY``;
    5. builds the skill registry and the preprocessor.

    Returns:
        ``(preprocessor, skill_registry)``.

    The calling test is skipped via ``pytest.skip`` when no config file is found
    or when ``server_config.has_llm_provider()`` is still false after step 4.
    """
    config_path = _find_config_path()
    if not config_path:
        pytest.skip("No agentkit.yaml found")

    env_path = Path(config_path).parent / ".env"
    if env_path.exists():
        # Only the import is guarded, so an error raised while loading the file
        # is not mistaken for "python-dotenv is not installed".
        try:
            from dotenv import load_dotenv
        except ImportError:
            # Fallback parser for plain KEY=VALUE lines. The encoding is given
            # explicitly so the result does not depend on the platform locale.
            with open(env_path, encoding="utf-8") as env_file:
                for raw_line in env_file:
                    entry = raw_line.strip()
                    if not entry or entry.startswith("#") or "=" not in entry:
                        continue
                    key, _, value = entry.partition("=")
                    # setdefault: variables already present in the environment
                    # are left untouched. Surrounding quotes are stripped.
                    os.environ.setdefault(key.strip(), value.strip().strip("'\""))
        else:
            load_dotenv(env_path)

    server_config = ServerConfig.from_yaml(config_path)

    if not server_config.has_llm_provider():
        dashscope_key = os.environ.get("DASHSCOPE_API_KEY", "")
        if dashscope_key:
            # NOTE(review): the nesting below (patch the first provider that has
            # no api_key, set its base_url only if missing, then stop) is
            # reconstructed from a copy of this file that had lost its
            # indentation — confirm against the repository version.
            for pconf in server_config.llm_config.providers.values():
                if pconf.api_key:
                    continue
                pconf.api_key = dashscope_key
                if not pconf.base_url:
                    # Keys starting with "sk-sp-" are pointed at the "coding"
                    # endpoint; every other key at the compatible-mode endpoint.
                    if dashscope_key.startswith("sk-sp-"):
                        pconf.base_url = "https://coding.dashscope.aliyuncs.com/v1"
                    else:
                        pconf.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
                break

    if not server_config.has_llm_provider():
        pytest.skip("No LLM provider with valid API key")

    skill_registry = _build_skill_registry(server_config)
    preprocessor = RequestPreprocessor(skill_registry=skill_registry)
    return preprocessor, skill_registry
# Module-level memo for the (preprocessor, skill_registry) pair; None until the
# first successful build.
_cached_components: tuple[RequestPreprocessor, SkillRegistry] | None = None


def _get_components() -> tuple[RequestPreprocessor, SkillRegistry]:
    """Return the shared (preprocessor, skill_registry) pair, building it on first use.

    The pair is stored in the module-level ``_cached_components`` so later calls
    reuse it. If ``_build_real_components()`` raises, nothing is stored and the
    next call tries again.
    """
    global _cached_components
    if _cached_components is not None:
        return _cached_components
    _cached_components = _build_real_components()
    return _cached_components
# ═══════════════════════════════════════════════════════════════════════════
# Test classes
# ═══════════════════════════════════════════════════════════════════════════
@pytest.mark.e2e_capability
class TestRequestPreprocessorBasic:
    """Check the execution mode RequestPreprocessor selects for each routing case."""

    @pytest.mark.parametrize(
        "case",
        ROUTING_TEST_CASES,
        ids=[entry["id"] for entry in ROUTING_TEST_CASES],
    )
    def test_routing(self, case: dict):
        """Preprocess ``case["input"]`` and compare its mode with ``case["expected_mode"]``."""
        preprocessor, skill_registry = _get_components()

        # Build the coroutine first, then drive it to completion on a fresh
        # event loop.
        pending = preprocessor.preprocess(
            content=case["input"],
            skill_registry=skill_registry,
            default_tools=["shell", "search", "file_read"],
        )
        result = asyncio.run(pending)

        expected_mode = case["expected_mode"]
        actual_mode = result.execution_mode.value
        assert actual_mode == expected_mode, (
            f"'{case['input']}': expected {expected_mode}, got {actual_mode} "
            f"(method={result.match_method}, confidence={result.match_confidence})"
        )
@pytest.mark.e2e_capability
class TestRequestPreprocessorParaphraseConsistency:
    """Check that differently worded requests with one intent get the same execution mode."""

    @pytest.mark.parametrize(
        "case",
        PARAPHRASE_CASES,
        ids=[entry["id"] for entry in PARAPHRASE_CASES],
    )
    def test_paraphrase_consistency(self, case: dict):
        """Assert the original wording and every paraphrase yield ``case["expected_mode"]``."""
        preprocessor, skill_registry = _get_components()
        expected_mode = case["expected_mode"]

        def preprocess_text(text):
            # One fresh event loop per message, each with the same tool list.
            return asyncio.run(
                preprocessor.preprocess(
                    content=text,
                    skill_registry=skill_registry,
                    default_tools=["shell", "search", "file_read"],
                )
            )

        # The reference wording is checked first.
        result = preprocess_text(case["original"])
        assert result.execution_mode.value == expected_mode, (
            f"Original '{case['original']}': expected {expected_mode}, got {result.execution_mode.value}"
        )

        # Then each alternative wording; the first mismatch fails the test.
        for para in case["paraphrases"]:
            result = preprocess_text(para)
            assert result.execution_mode.value == expected_mode, (
                f"Paraphrase '{para}': expected {expected_mode}, got {result.execution_mode.value}"
            )
@pytest.mark.e2e_capability
class TestRequestPreprocessorMetrics:
    """Compute and report preprocessing accuracy metrics."""

    def test_accuracy_report(self):
        """Run every routing case, print a per-case report, and enforce an accuracy floor.

        Each entry of ``ROUTING_TEST_CASES`` is preprocessed and its execution
        mode compared with the expected one. A report listing every case with a
        PASS/FAIL marker is printed, and the test fails when fewer than 85% of
        the cases match.
        """
        preprocessor, skill_registry = _get_components()

        total = len(ROUTING_TEST_CASES)
        # Without this guard an empty case table would surface as a
        # ZeroDivisionError in the percentage computation below.
        assert total > 0, "ROUTING_TEST_CASES is empty"

        results = []
        for case in ROUTING_TEST_CASES:
            result = asyncio.run(
                preprocessor.preprocess(
                    content=case["input"],
                    skill_registry=skill_registry,
                    default_tools=["shell", "search", "file_read"],
                )
            )
            actual_mode = result.execution_mode.value
            results.append({
                "id": case["id"],
                "input": case["input"],
                "expected": case["expected_mode"],
                "actual": actual_mode,
                "method": result.match_method,
                "correct": actual_mode == case["expected_mode"],
            })

        correct = sum(1 for r in results if r["correct"])
        accuracy = correct / total * 100

        heavy_rule = "=" * 60
        light_rule = "-" * 60
        print(f"\n{heavy_rule}")
        print("RequestPreprocessor Accuracy Report")
        print(heavy_rule)
        print(f"Total: {total}, Correct: {correct}, Accuracy: {accuracy:.1f}%")
        print(light_rule)
        for r in results:
            # Distinct markers, so failing cases stand out in the report.
            status = "PASS" if r["correct"] else "FAIL"
            print(
                f"  {status} {r['id']}: '{r['input']}' -> {r['actual']} "
                f"(expected {r['expected']}, method={r['method']})"
            )
        print(heavy_rule)

        # Assert minimum accuracy threshold
        assert accuracy >= 85.0, f"Accuracy {accuracy:.1f}% is below 85% threshold"