336 lines
12 KiB
Python
336 lines
12 KiB
Python
"""QueryTransformer 单元测试"""
|
||
|
||
import json
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
import pytest
|
||
|
||
from agentkit.memory.base import Memory, MemoryItem
|
||
from agentkit.memory.retriever import MemoryRetriever
|
||
from agentkit.memory.query_transformer import (
|
||
LLMQueryTransformer,
|
||
NoOpQueryTransformer,
|
||
QueryTransformerBase,
|
||
RuleQueryTransformer,
|
||
TransformedQuery,
|
||
create_query_transformer,
|
||
)
|
||
|
||
|
||
# ── In-Memory Memory 实现(用于测试) ────────────────────
|
||
|
||
|
||
class InMemoryMemory(Memory):
    """Dict-backed ``Memory`` implementation used only as a test double.

    Items are kept in a plain dict keyed by ``key``. Every stored item gets a
    fixed score of 1.0, and ``search`` is a case-insensitive substring match.
    """

    def __init__(self):
        self._store: dict[str, MemoryItem] = {}

    async def store(self, key: str, value, metadata=None) -> None:
        """Save *value* under *key*, replacing any item already stored there."""
        item = MemoryItem(key=key, value=value, metadata=metadata or {}, score=1.0)
        self._store[key] = item

    async def retrieve(self, key: str) -> MemoryItem | None:
        """Return the item stored under *key*, or ``None`` when absent."""
        return self._store.get(key)

    async def search(self, query: str, top_k: int = 5, filters=None) -> list[MemoryItem]:
        """Return at most *top_k* items whose value or key contains *query*.

        Matching is case-insensitive substring containment, in dict order.
        ``filters`` is accepted for interface compatibility but ignored.
        """
        matches = [
            item
            for item in self._store.values()
            if query.lower() in str(item.value).lower() or query.lower() in item.key.lower()
        ]
        return matches[:top_k]

    async def delete(self, key: str) -> bool:
        """Remove *key*; return ``True`` only if an item was actually removed."""
        removed = self._store.pop(key, None)
        return removed is not None
|
||
|
||
|
||
# ── TestTransformedQuery ──────────────────────────────────
|
||
|
||
|
||
class TestTransformedQuery:
    """Tests for the ``TransformedQuery`` dataclass."""

    def test_creation_and_field_access(self):
        """All three constructor fields are readable back as attributes."""
        original = "帮我分析一下SEO策略"
        tq = TransformedQuery(
            main_query="SEO策略",
            sub_queries=["搜索引擎优化策略"],
            original_query=original,
        )

        assert (tq.main_query, tq.sub_queries, tq.original_query) == (
            "SEO策略",
            ["搜索引擎优化策略"],
            original,
        )

    def test_empty_sub_queries(self):
        """An empty ``sub_queries`` list is stored as-is."""
        tq = TransformedQuery(
            original_query="AI趋势",
            main_query="AI趋势",
            sub_queries=[],
        )
        assert tq.sub_queries == []
|
||
|
||
|
||
# ── TestLLMQueryTransformer ───────────────────────────────
|
||
|
||
|
||
class TestLLMQueryTransformer:
    """Tests for ``LLMQueryTransformer``.

    NOTE(review): these ``async def`` tests carry no ``@pytest.mark.asyncio``
    marker; presumably the project runs pytest-asyncio in ``auto`` mode —
    confirm, otherwise they are silently skipped.
    """

    @staticmethod
    def _gateway_returning(content: str) -> AsyncMock:
        """Build a mock gateway whose ``chat`` coroutine returns *content*."""
        gateway = AsyncMock()
        gateway.chat.return_value = MagicMock(content=content)
        return gateway

    async def test_successful_transformation(self):
        """A valid JSON reply populates main_query and sub_queries verbatim."""
        gateway = self._gateway_returning(json.dumps({
            "main_query": "SEO optimization strategies",
            "sub_queries": ["search engine ranking", "keyword research"],
        }))

        transformer = LLMQueryTransformer(gateway)
        result = await transformer.transform("How to improve SEO?")

        assert result.main_query == "SEO optimization strategies"
        # Sub-query order is preserved (see test_max_sub_queries_limit), so pin
        # the whole list rather than just one member.
        assert result.sub_queries == ["search engine ranking", "keyword research"]
        assert result.original_query == "How to improve SEO?"

    async def test_llm_error_fallback(self):
        """If the LLM call raises, the original query is returned unchanged."""
        gateway = AsyncMock()
        gateway.chat.side_effect = Exception("LLM service unavailable")

        transformer = LLMQueryTransformer(gateway)
        result = await transformer.transform("test query")

        assert result.main_query == "test query"
        assert result.sub_queries == []
        assert result.original_query == "test query"

    async def test_invalid_json_response(self):
        """A non-JSON reply falls back to the original query."""
        gateway = self._gateway_returning("This is not JSON")

        transformer = LLMQueryTransformer(gateway)
        result = await transformer.transform("test query")

        assert result.main_query == "test query"
        assert result.sub_queries == []
        assert result.original_query == "test query"

    async def test_max_sub_queries_limit(self):
        """With 5 sub_queries returned and max_sub_queries=3, only the first 3 are kept."""
        gateway = self._gateway_returning(json.dumps({
            "main_query": "query",
            "sub_queries": ["sq1", "sq2", "sq3", "sq4", "sq5"],
        }))

        transformer = LLMQueryTransformer(gateway, max_sub_queries=3)
        result = await transformer.transform("test")

        assert len(result.sub_queries) == 3
        assert result.sub_queries == ["sq1", "sq2", "sq3"]

    async def test_prompt_contains_original_query(self):
        """The messages sent to the LLM include the original query text."""
        gateway = self._gateway_returning(
            json.dumps({"main_query": "q", "sub_queries": []})
        )

        transformer = LLMQueryTransformer(gateway)
        await transformer.transform("my original query")

        # Fail with a clear message (not an AttributeError on None) if chat was never called.
        gateway.chat.assert_called_once()
        call = gateway.chat.call_args
        # Accept ``messages`` passed either by keyword or as the first positional argument.
        messages = call.kwargs.get("messages")
        if messages is None:
            messages = call.args[0]
        # The query may live in any message (e.g. a user turn after a system prompt),
        # not necessarily the first one.
        assert any("my original query" in str(m["content"]) for m in messages)
|
||
|
||
|
||
# ── TestRuleQueryTransformer ──────────────────────────────
|
||
|
||
|
||
class TestRuleQueryTransformer:
    """Tests for ``RuleQueryTransformer``."""

    async def test_chinese_filler_word_removal(self):
        """Chinese filler words are stripped: '帮我分析一下SEO策略' keeps 'SEO策略'."""
        transformer = RuleQueryTransformer()
        result = await transformer.transform("帮我分析一下SEO策略")

        assert "SEO策略" in result.main_query
        assert "帮我" not in result.main_query
        assert "一下" not in result.main_query
        assert result.original_query == "帮我分析一下SEO策略"

    async def test_english_filler_word_removal(self):
        """English filler words are stripped: 'Please help me analyze' keeps 'analyze'."""
        transformer = RuleQueryTransformer()
        result = await transformer.transform("Please help me analyze")

        assert "analyze" in result.main_query
        # Check filler absence case-insensitively so a lower-cased leftover
        # ("please") cannot slip past a capitalised-only check.
        normalized = result.main_query.lower()
        assert "please" not in normalized
        assert "help me" not in normalized
        assert result.original_query == "Please help me analyze"

    async def test_synonym_expansion(self):
        """Synonym expansion: SEO -> 搜索引擎优化, Search Engine Optimization."""
        synonyms = {"SEO": ["搜索引擎优化", "Search Engine Optimization"]}
        transformer = RuleQueryTransformer(synonyms=synonyms)

        result = await transformer.transform("SEO策略")

        assert "SEO策略" in result.main_query
        assert len(result.sub_queries) == 2
        assert any("搜索引擎优化" in sq for sq in result.sub_queries)
        assert any("Search Engine Optimization" in sq for sq in result.sub_queries)

    async def test_no_op_for_clean_query(self):
        """A clean query is returned unchanged: 'AI行业趋势' stays as-is."""
        transformer = RuleQueryTransformer()
        result = await transformer.transform("AI行业趋势")

        assert result.main_query == "AI行业趋势"
        assert result.sub_queries == []

    async def test_max_sub_queries_limit(self):
        """Synonym expansion is capped by max_sub_queries."""
        synonyms = {"AI": ["人工智能", "Artificial Intelligence", "machine intelligence", "ML"]}
        transformer = RuleQueryTransformer(synonyms=synonyms, max_sub_queries=2)

        result = await transformer.transform("AI trends")

        # Four synonyms are available, so the cap must be what limits the output;
        # a bare `<= 2` would also pass if no expansion happened at all.
        assert len(result.sub_queries) == 2
|
||
|
||
|
||
# ── TestNoOpQueryTransformer ──────────────────────────────
|
||
|
||
|
||
class TestNoOpQueryTransformer:
    """Tests for ``NoOpQueryTransformer``."""

    async def test_returns_original_query_unchanged(self):
        """The query is passed straight through with no sub-queries."""
        query = "帮我分析一下SEO策略"
        result = await NoOpQueryTransformer().transform(query)

        assert result.main_query == query
        assert result.original_query == query
        assert result.sub_queries == []
|
||
|
||
|
||
# ── TestCreateQueryTransformer ────────────────────────────
|
||
|
||
|
||
class TestCreateQueryTransformer:
    """Tests for the ``create_query_transformer`` factory function."""

    def test_llm_strategy(self):
        """strategy='llm' with a gateway builds an LLMQueryTransformer."""
        created = create_query_transformer(strategy="llm", llm_gateway=AsyncMock())
        assert isinstance(created, LLMQueryTransformer)

    def test_rule_strategy(self):
        """strategy='rule' builds a RuleQueryTransformer."""
        assert isinstance(create_query_transformer(strategy="rule"), RuleQueryTransformer)

    def test_none_strategy(self):
        """strategy='none' builds a NoOpQueryTransformer."""
        assert isinstance(create_query_transformer(strategy="none"), NoOpQueryTransformer)

    def test_unknown_strategy_defaults_to_noop(self):
        """An unrecognised strategy falls back to a NoOpQueryTransformer."""
        assert isinstance(create_query_transformer(strategy="unknown"), NoOpQueryTransformer)

    def test_llm_strategy_without_gateway_falls_back(self):
        """strategy='llm' without a gateway falls back to a NoOpQueryTransformer."""
        created = create_query_transformer(strategy="llm", llm_gateway=None)
        assert isinstance(created, NoOpQueryTransformer)
|
||
|
||
|
||
# ── TestMemoryRetrieverWithTransformer ────────────────────
|
||
|
||
|
||
class TestMemoryRetrieverWithTransformer:
    """Integration tests for ``MemoryRetriever`` wired to a query transformer.

    NOTE(review): these ``async def`` tests carry no ``@pytest.mark.asyncio``
    marker; presumably the project runs pytest-asyncio in ``auto`` mode —
    confirm, otherwise they are silently skipped.
    """

    @staticmethod
    def _stub_transformer(main_query, sub_queries, original_query):
        """Return a mock transformer whose ``transform`` yields a fixed TransformedQuery."""
        transformer = AsyncMock(spec=QueryTransformerBase)
        transformer.transform.return_value = TransformedQuery(
            main_query=main_query,
            sub_queries=sub_queries,
            original_query=original_query,
        )
        return transformer

    async def test_retrieve_calls_transformer_before_search(self):
        """retrieve() runs the transformer and searches with the transformed query."""
        memory = InMemoryMemory()
        await memory.store("k1", "SEO optimization content")

        transformer = self._stub_transformer("SEO optimization", [], "帮我分析一下SEO")
        retriever = MemoryRetriever(
            working_memory=memory,
            query_transformer=transformer,
        )

        results = await retriever.retrieve("帮我分析一下SEO")

        transformer.transform.assert_called_once_with("帮我分析一下SEO")
        # The raw query is not a substring of k1's key or value, so finding k1
        # shows the transformed main_query is what was searched.
        assert "k1" in [r.key for r in results]

    async def test_sub_queries_searched_in_parallel(self):
        """Sub-queries are searched in addition to the main query."""
        memory = InMemoryMemory()
        await memory.store("k1", "SEO optimization content")
        await memory.store("k2", "Search engine ranking factors")

        transformer = self._stub_transformer(
            "SEO optimization", ["search engine ranking"], "SEO"
        )
        retriever = MemoryRetriever(
            working_memory=memory,
            query_transformer=transformer,
        )

        results = await retriever.retrieve("SEO")

        # k1 matches only the main query and k2 matches only the sub-query, so
        # both being present proves the sub-query was actually searched.
        keys = {r.key for r in results}
        assert {"k1", "k2"} <= keys

    async def test_results_deduplicated_by_key(self):
        """Results from main and sub-queries are deduplicated by key."""
        memory = InMemoryMemory()
        await memory.store("k1", "SEO optimization content")

        # The same query for main and sub-query makes k1 match twice.
        transformer = self._stub_transformer("SEO", ["SEO"], "SEO")
        retriever = MemoryRetriever(
            working_memory=memory,
            query_transformer=transformer,
        )

        results = await retriever.retrieve("SEO")

        # Require exactly one k1: a pure uniqueness check would pass vacuously
        # on an empty result list.
        keys = [r.key for r in results]
        assert keys == ["k1"]

    async def test_without_transformer_backward_compatible(self):
        """Without a transformer, retrieval behaves as before (backward compatible)."""
        memory = InMemoryMemory()
        await memory.store("k1", "AI research content")

        retriever = MemoryRetriever(working_memory=memory)
        results = await retriever.retrieve("AI")

        assert len(results) >= 1
        assert results[0].key == "k1"
|