fischer-agentkit/tests/unit/test_query_transformer.py

336 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""QueryTransformer 单元测试"""
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from agentkit.memory.base import Memory, MemoryItem
from agentkit.memory.retriever import MemoryRetriever
from agentkit.memory.query_transformer import (
LLMQueryTransformer,
NoOpQueryTransformer,
QueryTransformerBase,
RuleQueryTransformer,
TransformedQuery,
create_query_transformer,
)
# ── In-memory Memory implementation (for tests) ─────────────
class InMemoryMemory(Memory):
    """Dictionary-backed ``Memory`` implementation used as a test double."""

    def __init__(self):
        # Maps each key to the item most recently stored under it.
        self._store: dict[str, MemoryItem] = {}

    async def store(self, key: str, value, metadata=None) -> None:
        """Save ``value`` under ``key``, replacing any earlier entry.

        A missing or falsy ``metadata`` is stored as an empty dict, and every
        item is given a fixed score of 1.0.
        """
        entry = MemoryItem(key=key, value=value, metadata=metadata or {}, score=1.0)
        self._store[key] = entry

    async def retrieve(self, key: str) -> MemoryItem | None:
        """Return the item stored under ``key``, or ``None`` when absent."""
        return self._store.get(key)

    async def search(self, query: str, top_k: int = 5, filters=None) -> list[MemoryItem]:
        """Return at most ``top_k`` items whose value or key contains ``query``.

        Matching is a case-insensitive substring test against the stringified
        value and against the key. ``filters`` is accepted but not used.
        """

        def _is_hit(item: MemoryItem) -> bool:
            needle = query.lower()
            return needle in str(item.value).lower() or needle in item.key.lower()

        hits = [item for item in self._store.values() if _is_hit(item)]
        return hits[:top_k]

    async def delete(self, key: str) -> bool:
        """Remove ``key`` and report whether an entry was actually removed."""
        removed = self._store.pop(key, None)
        return removed is not None
# ── TestTransformedQuery ──────────────────────────────────
class TestTransformedQuery:
    """Tests for the ``TransformedQuery`` data container."""

    def test_creation_and_field_access(self):
        """Values passed to the constructor are readable from the fields."""
        original = "帮我分析一下SEO策略"
        transformed = TransformedQuery(
            original_query=original,
            main_query="SEO策略",
            sub_queries=["搜索引擎优化策略"],
        )

        assert transformed.main_query == "SEO策略"
        assert transformed.sub_queries == ["搜索引擎优化策略"]
        assert transformed.original_query == original

    def test_empty_sub_queries(self):
        """An empty ``sub_queries`` list is kept as an empty list."""
        query = "AI趋势"
        transformed = TransformedQuery(
            main_query=query,
            sub_queries=[],
            original_query=query,
        )

        assert transformed.sub_queries == []
# ── TestLLMQueryTransformer ───────────────────────────────
class TestLLMQueryTransformer:
    """Tests for ``LLMQueryTransformer`` driven by a mocked LLM gateway."""

    @staticmethod
    def _gateway_replying(content: str) -> AsyncMock:
        """Build a gateway mock whose ``chat`` call resolves to a reply with ``content``.

        Args:
            content: Text exposed as the ``content`` attribute of the mocked reply.

        Returns:
            An ``AsyncMock`` standing in for the LLM gateway.
        """
        gateway = AsyncMock()
        gateway.chat.return_value = MagicMock(content=content)
        return gateway

    async def test_successful_transformation(self):
        """A valid JSON reply populates ``main_query`` and ``sub_queries``."""
        gateway = self._gateway_replying(
            json.dumps({
                "main_query": "SEO optimization strategies",
                "sub_queries": ["search engine ranking", "keyword research"],
            })
        )
        transformer = LLMQueryTransformer(gateway)

        result = await transformer.transform("How to improve SEO?")

        assert result.main_query == "SEO optimization strategies"
        assert len(result.sub_queries) == 2
        assert "search engine ranking" in result.sub_queries
        assert result.original_query == "How to improve SEO?"

    async def test_llm_error_fallback(self):
        """When the gateway raises, the original query is returned unchanged."""
        gateway = AsyncMock()
        gateway.chat.side_effect = Exception("LLM service unavailable")
        transformer = LLMQueryTransformer(gateway)

        result = await transformer.transform("test query")

        assert result.main_query == "test query"
        assert result.sub_queries == []
        assert result.original_query == "test query"

    async def test_invalid_json_response(self):
        """A non-JSON reply falls back to the original query."""
        gateway = self._gateway_replying("This is not JSON")
        transformer = LLMQueryTransformer(gateway)

        result = await transformer.transform("test query")

        assert result.main_query == "test query"
        assert result.sub_queries == []

    async def test_max_sub_queries_limit(self):
        """Five returned sub-queries are cut to the first three when ``max_sub_queries=3``."""
        gateway = self._gateway_replying(
            json.dumps({
                "main_query": "query",
                "sub_queries": ["sq1", "sq2", "sq3", "sq4", "sq5"],
            })
        )
        transformer = LLMQueryTransformer(gateway, max_sub_queries=3)

        result = await transformer.transform("test")

        assert len(result.sub_queries) == 3
        assert result.sub_queries == ["sq1", "sq2", "sq3"]

    async def test_prompt_contains_original_query(self):
        """The prompt sent to the gateway contains the original query text."""
        gateway = self._gateway_replying(
            json.dumps({"main_query": "q", "sub_queries": []})
        )
        transformer = LLMQueryTransformer(gateway)

        await transformer.transform("my original query")

        call = gateway.chat.call_args
        # ``call.kwargs`` already holds every keyword argument, so a single
        # lookup is enough; fall back to the first positional argument only
        # when ``messages`` was not passed by keyword.
        messages = call.kwargs.get("messages")
        if not messages:
            messages = call.args[0]

        # NOTE(review): assumes the first message carries the user query;
        # confirm if a system message is ever prepended.
        prompt_text = messages[0]["content"]
        assert "my original query" in prompt_text
# ── TestRuleQueryTransformer ──────────────────────────────
class TestRuleQueryTransformer:
    """Tests for ``RuleQueryTransformer``."""

    async def test_chinese_filler_word_removal(self):
        """Chinese filler words are stripped while the topic 'SEO策略' is kept."""
        raw_query = "帮我分析一下SEO策略"

        outcome = await RuleQueryTransformer().transform(raw_query)

        assert "SEO策略" in outcome.main_query
        for filler in ("帮我", "一下"):
            assert filler not in outcome.main_query
        assert outcome.original_query == raw_query

    async def test_english_filler_word_removal(self):
        """English filler words are stripped while 'analyze' is kept."""
        outcome = await RuleQueryTransformer().transform("Please help me analyze")

        assert "analyze" in outcome.main_query
        for filler in ("Please", "help me"):
            assert filler not in outcome.main_query

    async def test_synonym_expansion(self):
        """Each configured synonym of 'SEO' shows up in some sub-query."""
        expansions = ("搜索引擎优化", "Search Engine Optimization")
        rule_transformer = RuleQueryTransformer(synonyms={"SEO": list(expansions)})

        outcome = await rule_transformer.transform("SEO策略")

        assert "SEO策略" in outcome.main_query
        assert len(outcome.sub_queries) == 2
        for expansion in expansions:
            assert any(expansion in sub_query for sub_query in outcome.sub_queries)

    async def test_no_op_for_clean_query(self):
        """A query without filler words or synonyms is returned as is."""
        clean_query = "AI行业趋势"

        outcome = await RuleQueryTransformer().transform(clean_query)

        assert outcome.main_query == clean_query
        assert outcome.sub_queries == []

    async def test_max_sub_queries_limit(self):
        """Synonym expansion yields no more than ``max_sub_queries`` sub-queries."""
        limit = 2
        rule_transformer = RuleQueryTransformer(
            synonyms={
                "AI": ["人工智能", "Artificial Intelligence", "machine intelligence", "ML"],
            },
            max_sub_queries=limit,
        )

        outcome = await rule_transformer.transform("AI trends")

        assert len(outcome.sub_queries) <= limit
# ── TestNoOpQueryTransformer ──────────────────────────────
class TestNoOpQueryTransformer:
    """Tests for ``NoOpQueryTransformer``."""

    async def test_returns_original_query_unchanged(self):
        """The query passes through untouched and no sub-queries are produced."""
        raw_query = "帮我分析一下SEO策略"

        outcome = await NoOpQueryTransformer().transform(raw_query)

        assert outcome.main_query == raw_query
        assert outcome.sub_queries == []
        assert outcome.original_query == raw_query
# ── TestCreateQueryTransformer ────────────────────────────
class TestCreateQueryTransformer:
    """Tests for the ``create_query_transformer`` factory function."""

    def test_llm_strategy(self):
        """``strategy='llm'`` with a gateway yields an ``LLMQueryTransformer``."""
        built = create_query_transformer(strategy="llm", llm_gateway=AsyncMock())
        assert isinstance(built, LLMQueryTransformer)

    def test_rule_strategy(self):
        """``strategy='rule'`` yields a ``RuleQueryTransformer``."""
        built = create_query_transformer(strategy="rule")
        assert isinstance(built, RuleQueryTransformer)

    def test_none_strategy(self):
        """``strategy='none'`` yields a ``NoOpQueryTransformer``."""
        built = create_query_transformer(strategy="none")
        assert isinstance(built, NoOpQueryTransformer)

    def test_unknown_strategy_defaults_to_noop(self):
        """An unrecognised strategy name yields a ``NoOpQueryTransformer``."""
        built = create_query_transformer(strategy="unknown")
        assert isinstance(built, NoOpQueryTransformer)

    def test_llm_strategy_without_gateway_falls_back(self):
        """``strategy='llm'`` without a gateway yields a ``NoOpQueryTransformer``."""
        built = create_query_transformer(strategy="llm", llm_gateway=None)
        assert isinstance(built, NoOpQueryTransformer)
# ── TestMemoryRetrieverWithTransformer ────────────────────
class TestMemoryRetrieverWithTransformer:
    """Tests for ``MemoryRetriever`` when a query transformer is plugged in."""

    @staticmethod
    def _stub_transformer(main_query, sub_queries, original_query):
        """Return a transformer mock whose ``transform`` yields the given query parts."""
        stub = AsyncMock(spec=QueryTransformerBase)
        stub.transform.return_value = TransformedQuery(
            main_query=main_query,
            sub_queries=sub_queries,
            original_query=original_query,
        )
        return stub

    async def test_retrieve_calls_transformer_before_search(self):
        """``retrieve()`` calls the transformer exactly once with the raw query."""
        raw_query = "帮我分析一下SEO"
        working = InMemoryMemory()
        await working.store("k1", "SEO optimization content")
        stub = self._stub_transformer("SEO optimization", [], raw_query)
        retriever = MemoryRetriever(working_memory=working, query_transformer=stub)

        found = await retriever.retrieve(raw_query)

        stub.transform.assert_called_once_with(raw_query)
        assert len(found) >= 1

    async def test_sub_queries_searched_in_parallel(self):
        """Retrieval with one sub-query returns at least one result."""
        working = InMemoryMemory()
        await working.store("k1", "SEO optimization content")
        await working.store("k2", "Search engine ranking factors")
        stub = self._stub_transformer(
            "SEO optimization", ["search engine ranking"], "SEO"
        )
        retriever = MemoryRetriever(working_memory=working, query_transformer=stub)

        found = await retriever.retrieve("SEO")

        # Both main query and sub-query results should be present.
        # NOTE(review): the assertion below only requires one result, so it
        # does not prove the sub-query hit ("k2") was returned or that the
        # searches ran in parallel; tighten once the retriever's merge
        # behaviour is confirmed.
        assert len(found) >= 1

    async def test_results_deduplicated_by_key(self):
        """Results contain each key at most once, even if main and sub-query match the same item."""
        working = InMemoryMemory()
        await working.store("k1", "SEO optimization content")
        # The sub-query repeats the main query, so both searches hit "k1".
        stub = self._stub_transformer("SEO", ["SEO"], "SEO")
        retriever = MemoryRetriever(working_memory=working, query_transformer=stub)

        found = await retriever.retrieve("SEO")

        found_keys = [entry.key for entry in found]
        assert len(found_keys) == len(set(found_keys))

    async def test_without_transformer_backward_compatible(self):
        """Without a transformer, retrieval still finds the stored item."""
        working = InMemoryMemory()
        await working.store("k1", "AI research content")
        retriever = MemoryRetriever(working_memory=working)

        found = await retriever.retrieve("AI")

        assert len(found) >= 1
        assert found[0].key == "k1"