"""Unit tests for QueryTransformer implementations and their MemoryRetriever integration."""

import json
from unittest.mock import AsyncMock, MagicMock

import pytest

from agentkit.memory.base import Memory, MemoryItem
from agentkit.memory.retriever import MemoryRetriever
from agentkit.memory.query_transformer import (
    LLMQueryTransformer,
    NoOpQueryTransformer,
    QueryTransformerBase,
    RuleQueryTransformer,
    TransformedQuery,
    create_query_transformer,
)


# ── In-memory Memory implementation (test double) ────────


class InMemoryMemory(Memory):
    """Dict-backed Memory implementation used as a test double.

    ``search`` is a case-insensitive substring match against each item's key
    and stringified value; ``filters`` is accepted for interface compatibility
    but ignored.
    """

    def __init__(self):
        self._store: dict[str, MemoryItem] = {}

    async def store(self, key: str, value, metadata=None) -> None:
        # Every stored item gets the same fixed score of 1.0.
        self._store[key] = MemoryItem(
            key=key, value=value, metadata=metadata or {}, score=1.0
        )

    async def retrieve(self, key: str) -> MemoryItem | None:
        return self._store.get(key)

    async def search(self, query: str, top_k: int = 5, filters=None) -> list[MemoryItem]:
        needle = query.lower()  # lowercase once, not once per stored item
        matches = [
            item
            for item in self._store.values()
            if needle in str(item.value).lower() or needle in item.key.lower()
        ]
        return matches[:top_k]

    async def delete(self, key: str) -> bool:
        return self._store.pop(key, None) is not None


# ── TestTransformedQuery ──────────────────────────────────


class TestTransformedQuery:
    """Tests for the TransformedQuery dataclass."""

    def test_creation_and_field_access(self):
        tq = TransformedQuery(
            main_query="SEO策略",
            sub_queries=["搜索引擎优化策略"],
            original_query="帮我分析一下SEO策略",
        )
        assert tq.main_query == "SEO策略"
        assert tq.sub_queries == ["搜索引擎优化策略"]
        assert tq.original_query == "帮我分析一下SEO策略"

    def test_empty_sub_queries(self):
        tq = TransformedQuery(main_query="AI趋势", sub_queries=[], original_query="AI趋势")
        assert tq.sub_queries == []


# ── TestLLMQueryTransformer ───────────────────────────────


class TestLLMQueryTransformer:
    """Tests for LLMQueryTransformer."""

    async def test_successful_transformation(self):
        """The LLM returns valid JSON; main_query and sub_queries are taken from it."""
        gateway = AsyncMock()
        gateway.chat.return_value = MagicMock(
            content=json.dumps({
                "main_query": "SEO optimization strategies",
                "sub_queries": ["search engine ranking", "keyword research"],
            })
        )
        transformer = LLMQueryTransformer(gateway)
        result = await transformer.transform("How to improve SEO?")

        assert result.main_query == "SEO optimization strategies"
        assert len(result.sub_queries) == 2
        assert "search engine ranking" in result.sub_queries
        assert result.original_query == "How to improve SEO?"

    async def test_llm_error_fallback(self):
        """When the LLM call raises, fall back to the original query."""
        gateway = AsyncMock()
        gateway.chat.side_effect = Exception("LLM service unavailable")
        transformer = LLMQueryTransformer(gateway)
        result = await transformer.transform("test query")

        assert result.main_query == "test query"
        assert result.sub_queries == []
        assert result.original_query == "test query"

    async def test_invalid_json_response(self):
        """When the LLM returns non-JSON text, fall back to the original query."""
        gateway = AsyncMock()
        gateway.chat.return_value = MagicMock(content="This is not JSON")
        transformer = LLMQueryTransformer(gateway)
        result = await transformer.transform("test query")

        assert result.main_query == "test query"
        assert result.sub_queries == []

    async def test_max_sub_queries_limit(self):
        """The LLM returns 5 sub_queries but max_sub_queries=3, so only the first 3 are kept."""
        gateway = AsyncMock()
        gateway.chat.return_value = MagicMock(
            content=json.dumps({
                "main_query": "query",
                "sub_queries": ["sq1", "sq2", "sq3", "sq4", "sq5"],
            })
        )
        transformer = LLMQueryTransformer(gateway, max_sub_queries=3)
        result = await transformer.transform("test")

        assert len(result.sub_queries) == 3
        assert result.sub_queries == ["sq1", "sq2", "sq3"]

    async def test_prompt_contains_original_query(self):
        """The prompt sent to the LLM includes the original query."""
        gateway = AsyncMock()
        gateway.chat.return_value = MagicMock(
            content=json.dumps({"main_query": "q", "sub_queries": []})
        )
        transformer = LLMQueryTransformer(gateway)
        await transformer.transform("my original query")

        call_args = gateway.chat.call_args
        # Accept ``messages`` passed either by keyword or as the first positional argument.
        messages = call_args.kwargs.get("messages") or call_args.args[0]
        # Check every message, not just the first, so the test does not depend on
        # whether the transformer prepends e.g. a system message.
        assert any("my original query" in message["content"] for message in messages)


# ── TestRuleQueryTransformer ──────────────────────────────


class TestRuleQueryTransformer:
    """Tests for RuleQueryTransformer."""

    async def test_chinese_filler_word_removal(self):
        """Chinese filler words are stripped: '帮我分析一下SEO策略' → main_query contains 'SEO策略'."""
        transformer = RuleQueryTransformer()
        result = await transformer.transform("帮我分析一下SEO策略")

        assert "SEO策略" in result.main_query
        assert "帮我" not in result.main_query
        assert "一下" not in result.main_query
        assert result.original_query == "帮我分析一下SEO策略"

    async def test_english_filler_word_removal(self):
        """English filler words are stripped: 'Please help me analyze' → main_query contains 'analyze'."""
        transformer = RuleQueryTransformer()
        result = await transformer.transform("Please help me analyze")

        assert "analyze" in result.main_query
        assert "Please" not in result.main_query
        assert "help me" not in result.main_query

    async def test_synonym_expansion(self):
        """Synonym expansion: SEO → 搜索引擎优化, Search Engine Optimization."""
        synonyms = {"SEO": ["搜索引擎优化", "Search Engine Optimization"]}
        transformer = RuleQueryTransformer(synonyms=synonyms)
        result = await transformer.transform("SEO策略")

        assert "SEO策略" in result.main_query
        assert len(result.sub_queries) == 2
        assert any("搜索引擎优化" in sq for sq in result.sub_queries)
        assert any("Search Engine Optimization" in sq for sq in result.sub_queries)

    async def test_no_op_for_clean_query(self):
        """A query with nothing to strip or expand is returned unchanged: 'AI行业趋势'."""
        transformer = RuleQueryTransformer()
        result = await transformer.transform("AI行业趋势")

        assert result.main_query == "AI行业趋势"
        assert result.sub_queries == []

    async def test_max_sub_queries_limit(self):
        """Synonym expansion is capped at max_sub_queries."""
        synonyms = {"AI": ["人工智能", "Artificial Intelligence", "machine intelligence", "ML"]}
        transformer = RuleQueryTransformer(synonyms=synonyms, max_sub_queries=2)
        result = await transformer.transform("AI trends")

        # Four synonyms are available, so the cap (not a lack of expansion) must
        # be what limits the result; `<= 2` alone would also pass with zero.
        assert len(result.sub_queries) == 2


# ── TestNoOpQueryTransformer ──────────────────────────────


class TestNoOpQueryTransformer:
    """Tests for NoOpQueryTransformer."""

    async def test_returns_original_query_unchanged(self):
        """The original query is returned unchanged with no sub-queries."""
        transformer = NoOpQueryTransformer()
        result = await transformer.transform("帮我分析一下SEO策略")

        assert result.main_query == "帮我分析一下SEO策略"
        assert result.sub_queries == []
        assert result.original_query == "帮我分析一下SEO策略"


# ── TestCreateQueryTransformer ────────────────────────────


class TestCreateQueryTransformer:
    """Tests for the create_query_transformer factory function."""

    def test_llm_strategy(self):
        """strategy='llm' creates an LLMQueryTransformer."""
        gateway = AsyncMock()
        transformer = create_query_transformer(strategy="llm", llm_gateway=gateway)
        assert isinstance(transformer, LLMQueryTransformer)

    def test_rule_strategy(self):
        """strategy='rule' creates a RuleQueryTransformer."""
        transformer = create_query_transformer(strategy="rule")
        assert isinstance(transformer, RuleQueryTransformer)

    def test_none_strategy(self):
        """strategy='none' creates a NoOpQueryTransformer."""
        transformer = create_query_transformer(strategy="none")
        assert isinstance(transformer, NoOpQueryTransformer)

    def test_unknown_strategy_defaults_to_noop(self):
        """An unknown strategy falls back to NoOpQueryTransformer."""
        transformer = create_query_transformer(strategy="unknown")
        assert isinstance(transformer, NoOpQueryTransformer)

    def test_llm_strategy_without_gateway_falls_back(self):
        """strategy='llm' without a gateway falls back to NoOpQueryTransformer."""
        transformer = create_query_transformer(strategy="llm", llm_gateway=None)
        assert isinstance(transformer, NoOpQueryTransformer)


# ── TestMemoryRetrieverWithTransformer ────────────────────


class TestMemoryRetrieverWithTransformer:
    """Tests for MemoryRetriever integrated with a QueryTransformer."""

    async def test_retrieve_calls_transformer_before_search(self):
        """retrieve() passes the raw query to the transformer and returns search results."""
        memory = InMemoryMemory()
        await memory.store("k1", "SEO optimization content")

        transformer = AsyncMock(spec=QueryTransformerBase)
        transformer.transform.return_value = TransformedQuery(
            main_query="SEO optimization",
            sub_queries=[],
            original_query="帮我分析一下SEO",
        )

        retriever = MemoryRetriever(
            working_memory=memory,
            query_transformer=transformer,
        )
        results = await retriever.retrieve("帮我分析一下SEO")

        transformer.transform.assert_called_once_with("帮我分析一下SEO")
        assert len(results) >= 1

    async def test_sub_queries_searched_in_parallel(self):
        """Sub-queries are searched alongside the main query and their hits are merged."""
        memory = InMemoryMemory()
        await memory.store("k1", "SEO optimization content")
        await memory.store("k2", "Search engine ranking factors")

        transformer = AsyncMock(spec=QueryTransformerBase)
        transformer.transform.return_value = TransformedQuery(
            main_query="SEO optimization",
            sub_queries=["search engine ranking"],
            original_query="SEO",
        )

        retriever = MemoryRetriever(
            working_memory=memory,
            query_transformer=transformer,
        )
        results = await retriever.retrieve("SEO")

        # Both main query and sub-query results should be present:
        # k1 matches only the main query, k2 matches only the sub-query.
        keys = {r.key for r in results}
        assert {"k1", "k2"} <= keys

    async def test_results_deduplicated_by_key(self):
        """Hits from the main query and sub-queries are deduplicated by key.

        The retriever is expected to keep the highest-scoring copy; this test
        only checks that each key appears once.
        """
        memory = InMemoryMemory()
        await memory.store("k1", "SEO optimization content")

        # The same key appears in both main and sub-query results
        transformer = AsyncMock(spec=QueryTransformerBase)
        transformer.transform.return_value = TransformedQuery(
            main_query="SEO",
            sub_queries=["SEO"],  # Same query → same key match
            original_query="SEO",
        )

        retriever = MemoryRetriever(
            working_memory=memory,
            query_transformer=transformer,
        )
        results = await retriever.retrieve("SEO")

        keys = [r.key for r in results]
        # Should not have duplicate keys ...
        assert len(keys) == len(set(keys))
        # ... and the check must not pass vacuously on an empty result.
        assert keys.count("k1") == 1

    async def test_without_transformer_backward_compatible(self):
        """Without a transformer, retrieve() behaves as before (backward compatible)."""
        memory = InMemoryMemory()
        await memory.store("k1", "AI research content")

        retriever = MemoryRetriever(working_memory=memory)
        results = await retriever.retrieve("AI")

        assert len(results) >= 1
        assert results[0].key == "k1"