"""U5: Configurable Retrieval Parameters + Per-KB Weights.

Tests for:
1. ReActEngine uses configurable top_k/token_budget from retrieval_config
2. ConfigDrivenAgent passes retrieval_config from memory config
3. SemanticMemory applies per-KB weight multipliers to scores
4. Improved token estimation for mixed Chinese/English text
5. ServerConfig parsing with memory.retrieval and memory.semantic.kb_weights
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from agentkit.core.react import ReActEngine, ReActResult
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMResponse, TokenUsage
from agentkit.memory.base import MemoryItem
from agentkit.memory.retriever import MemoryRetriever, _estimate_tokens
from agentkit.memory.semantic import SemanticMemory

# Values ReActEngine is expected to fall back to when retrieval_config is
# absent or omits a key (pinned by the tests below).
DEFAULT_TOP_K = 5
DEFAULT_TOKEN_BUDGET = 2000


# ── Test Helpers ──────────────────────────────────────────


def make_mock_gateway(responses: list[LLMResponse]) -> LLMGateway:
    """Return an LLMGateway mock whose ``chat`` yields *responses* in order.

    ``chat`` is an ``AsyncMock`` with ``side_effect=responses``: each await
    returns the next item, and awaiting past the end raises
    ``StopAsyncIteration``.
    """
    gateway = MagicMock(spec=LLMGateway)
    gateway.chat = AsyncMock(side_effect=responses)
    return gateway


def make_response(
    content: str = "",
    prompt_tokens: int = 10,
    completion_tokens: int = 20,
) -> LLMResponse:
    """Build an LLMResponse from "test-model" with the given content and usage.

    The response carries no tool calls (``tool_calls=[]``).
    """
    return LLMResponse(
        content=content,
        model="test-model",
        usage=TokenUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        ),
        tool_calls=[],
    )


def make_mock_memory_retriever(context_string: str = "past experience data") -> MagicMock:
    """Return a memory-retriever stand-in for ReActEngine tests.

    ``get_context_string`` is awaitable and returns *context_string*, so tests
    can inspect the ``query``/``top_k``/``token_budget`` it was awaited with.
    ``_episodic`` is None and ``store_episode`` is an awaitable no-op.
    """
    retriever = MagicMock()
    retriever.get_context_string = AsyncMock(return_value=context_string)
    retriever._episodic = None
    retriever.store_episode = AsyncMock()
    return retriever


def make_mock_rag_service(hits: list[dict]) -> MagicMock:
    """Return a RAG-service stand-in whose awaitable ``search`` returns *hits*."""
    rag_service = MagicMock()
    rag_service.search = AsyncMock(return_value=hits)
    return rag_service


# ── Test: Configurable Retrieval Parameters ──────────


class TestConfigurableRetrievalParameters:
    """ReActEngine uses configurable top_k/token_budget from retrieval_config"""

    async def test_default_top_k_when_no_config(self):
        """ReActEngine uses default top_k=5 when no config provided"""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)
        retriever = make_mock_memory_retriever("context")

        await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=retriever,
        )

        retriever.get_context_string.assert_awaited_once_with(
            query="Hello",
            top_k=DEFAULT_TOP_K,
            token_budget=DEFAULT_TOKEN_BUDGET,
        )

    async def test_configured_top_k(self):
        """ReActEngine uses configured top_k from retrieval_config"""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)
        retriever = make_mock_memory_retriever("context")

        await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=retriever,
            retrieval_config={"top_k": 10, "token_budget": 4000},
        )

        retriever.get_context_string.assert_awaited_once_with(
            query="Hello",
            top_k=10,
            token_budget=4000,
        )

    async def test_configured_token_budget(self):
        """ReActEngine uses configured token_budget from retrieval_config"""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)
        retriever = make_mock_memory_retriever("context")

        await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=retriever,
            retrieval_config={"token_budget": 5000},
        )

        retriever.get_context_string.assert_awaited_once()
        call_kwargs = retriever.get_context_string.call_args.kwargs
        assert call_kwargs["token_budget"] == 5000
        # top_k was not configured, so it must fall back to the default
        # (mirror of test_partial_config_uses_defaults).
        assert call_kwargs["top_k"] == DEFAULT_TOP_K

    async def test_backward_compatibility_no_config(self):
        """No config = same behavior as before (top_k=5, token_budget=2000)"""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)
        retriever = make_mock_memory_retriever("context")

        await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=retriever,
        )

        call_kwargs = retriever.get_context_string.call_args.kwargs
        assert call_kwargs["top_k"] == DEFAULT_TOP_K
        assert call_kwargs["token_budget"] == DEFAULT_TOKEN_BUDGET

    async def test_stream_uses_retrieval_config(self):
        """execute_stream also uses retrieval_config"""
        gateway = make_mock_gateway([make_response(content="streamed answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)
        retriever = make_mock_memory_retriever("context")

        # Consume the whole stream so the engine runs to completion.
        events = []
        async for event in engine.execute_stream(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=retriever,
            retrieval_config={"top_k": 8, "token_budget": 3000},
        ):
            events.append(event)

        call_kwargs = retriever.get_context_string.call_args.kwargs
        assert call_kwargs["top_k"] == 8
        assert call_kwargs["token_budget"] == 3000

    async def test_partial_config_uses_defaults(self):
        """Partial config: only top_k specified, token_budget falls back to default"""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)
        retriever = make_mock_memory_retriever("context")

        await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=retriever,
            retrieval_config={"top_k": 3},
        )

        call_kwargs = retriever.get_context_string.call_args.kwargs
        assert call_kwargs["top_k"] == 3
        assert call_kwargs["token_budget"] == DEFAULT_TOKEN_BUDGET


class TestConfigDrivenAgentRetrievalConfig:
    """ConfigDrivenAgent passes retrieval_config from memory config"""

    async def test_retrieval_config_passed_to_react_engine(self):
        """ConfigDrivenAgent keeps memory.retrieval available for ReActEngine.

        NOTE(review): this only verifies that the ``retrieval`` section survives
        on the agent's config. It does not observe the value reaching
        ``ReActEngine.execute``. TODO: patch the engine and assert the
        ``retrieval_config`` kwarg once the agent's run entry point is pinned.
        """
        from agentkit.core.config_driven import ConfigDrivenAgent
        from agentkit.skills.base import SkillConfig

        config = SkillConfig(
            name="test-agent",
            agent_type="test",
            task_mode="llm_generate",
            execution_mode="react",
            prompt={"identity": "Test agent"},
            memory={
                "retrieval": {"top_k": 10, "token_budget": 5000},
                "working": {"enabled": False},
                "episodic": {"enabled": False},
            },
        )

        gateway = MagicMock(spec=LLMGateway)
        gateway.chat = AsyncMock(return_value=make_response(content="done"))

        agent = ConfigDrivenAgent(config=config, llm_gateway=gateway)

        # Verify the agent has memory config
        assert agent._config.memory.get("retrieval") == {"top_k": 10, "token_budget": 5000}


# ── Test: Per-KB Weights ──────────────────────────────────


class TestPerKBWeights:
    """SemanticMemory with kb_weights applies multipliers to scores"""

    async def test_kb_weights_applied_to_scores(self):
        """kb_weights multiplies scores for matching KB IDs"""
        rag_service = make_mock_rag_service([
            {"id": "1", "content": "Industry data", "score": 0.9, "source": "rag",
             "document_id": "d1", "knowledge_base_id": "industry-kb"},
            {"id": "2", "content": "Enterprise data", "score": 0.9, "source": "rag",
             "document_id": "d2", "knowledge_base_id": "enterprise-kb"},
        ])

        memory = SemanticMemory(
            rag_service=rag_service,
            knowledge_base_ids=["industry-kb", "enterprise-kb"],
            kb_weights={"industry-kb": 1.2, "enterprise-kb": 0.8},
        )

        results = await memory.search("test query")

        # Industry KB result should have higher score
        industry_item = next(
            r for r in results if r.metadata.get("knowledge_base_id") == "industry-kb"
        )
        enterprise_item = next(
            r for r in results if r.metadata.get("knowledge_base_id") == "enterprise-kb"
        )
        assert industry_item.score == pytest.approx(0.9 * 1.2)
        assert enterprise_item.score == pytest.approx(0.9 * 0.8)

    async def test_industry_kb_scores_higher_than_enterprise(self):
        """Industry KB (weight 1.2) results score higher than enterprise KB (weight 0.8)"""
        # Enterprise hit is listed first so the test proves results are re-sorted.
        rag_service = make_mock_rag_service([
            {"id": "1", "content": "Enterprise result", "score": 0.9, "source": "rag",
             "document_id": "d1", "knowledge_base_id": "enterprise-kb"},
            {"id": "2", "content": "Industry result", "score": 0.9, "source": "rag",
             "document_id": "d2", "knowledge_base_id": "industry-kb"},
        ])

        memory = SemanticMemory(
            rag_service=rag_service,
            knowledge_base_ids=["industry-kb", "enterprise-kb"],
            kb_weights={"industry-kb": 1.2, "enterprise-kb": 0.8},
        )

        results = await memory.search("test query")

        # After sorting by score, industry should be first
        assert results[0].metadata.get("knowledge_base_id") == "industry-kb"
        assert results[0].score > results[1].score

    async def test_unweighted_kb_gets_default_score(self):
        """Unweighted KBs get default score (1.0 multiplier)"""
        rag_service = make_mock_rag_service([
            {"id": "1", "content": "Unweighted result", "score": 0.8, "source": "rag",
             "document_id": "d1", "knowledge_base_id": "unweighted-kb"},
        ])

        memory = SemanticMemory(
            rag_service=rag_service,
            knowledge_base_ids=["unweighted-kb"],
            kb_weights={"industry-kb": 1.5},  # no weight for unweighted-kb
        )

        results = await memory.search("test query")

        assert len(results) == 1
        assert results[0].score == pytest.approx(0.8)  # unchanged

    async def test_kb_weights_none_no_modification(self):
        """kb_weights=None: no score modification"""
        rag_service = make_mock_rag_service([
            {"id": "1", "content": "Result", "score": 0.75, "source": "rag",
             "document_id": "d1", "knowledge_base_id": "some-kb"},
        ])

        memory = SemanticMemory(
            rag_service=rag_service,
            knowledge_base_ids=["some-kb"],
            kb_weights=None,
        )

        results = await memory.search("test query")
        assert results[0].score == pytest.approx(0.75)

    async def test_empty_kb_weights_no_modification(self):
        """Empty kb_weights dict: no score modification"""
        rag_service = make_mock_rag_service([
            {"id": "1", "content": "Result", "score": 0.75, "source": "rag",
             "document_id": "d1", "knowledge_base_id": "some-kb"},
        ])

        memory = SemanticMemory(
            rag_service=rag_service,
            knowledge_base_ids=["some-kb"],
            kb_weights={},
        )

        results = await memory.search("test query")
        assert results[0].score == pytest.approx(0.75)

    async def test_kb_id_propagated_to_metadata(self):
        """knowledge_base_id is propagated to MemoryItem metadata"""
        rag_service = make_mock_rag_service([
            {"id": "1", "content": "Result", "score": 0.9, "source": "rag",
             "document_id": "d1", "knowledge_base_id": "my-kb"},
        ])

        memory = SemanticMemory(
            rag_service=rag_service,
            knowledge_base_ids=["my-kb"],
        )

        results = await memory.search("test query")
        assert results[0].metadata["knowledge_base_id"] == "my-kb"


# ── Test: Token Estimation ────────────────────────────────


class TestTokenEstimation:
    """Improved token estimation for mixed Chinese/English text"""

    def test_pure_english_text(self):
        """Pure English text: ~1 token per word"""
        text = "Hello world this is a test"
        result = _estimate_tokens(text)
        # 6 words * 1 = 6 tokens
        assert result == 6

    def test_pure_chinese_text(self):
        """Pure Chinese text: ~2 tokens per character"""
        text = "你好世界测试"
        result = _estimate_tokens(text)
        # 6 CJK chars * 2 = 12 tokens
        assert result == 12

    def test_mixed_chinese_english_text(self):
        """Mixed Chinese/English text"""
        text = "你好world测试test"
        result = _estimate_tokens(text)
        # 4 CJK chars * 2 = 8, plus 2 English words = 2, total = 10
        assert result == 10

    def test_more_accurate_than_old_for_chinese(self):
        """New estimation is more accurate than len(text)//4 for Chinese text"""
        text = "人工智能技术在近年来取得了巨大突破"
        new_estimate = _estimate_tokens(text)
        old_estimate = len(text) // 4
        # For Chinese text, the old method underestimates:
        # 17 CJK chars * 2 = 34 tokens (new)
        # 17 chars // 4 = 4 tokens (old) — way too low
        assert new_estimate > old_estimate
        assert new_estimate == 34

    def test_empty_string(self):
        """Empty string: 0 tokens"""
        assert _estimate_tokens("") == 0

    def test_whitespace_only(self):
        """Whitespace only: 0 tokens"""
        assert _estimate_tokens(" ") == 0

    def test_english_with_punctuation(self):
        """English with punctuation"""
        text = "Hello, world! How are you?"
        result = _estimate_tokens(text)
        # "Hello," "world!" "How" "are" "you?" = 5 words
        assert result == 5


# ── Test: Config Parsing ──────────────────────────────────


class TestConfigParsing:
    """ServerConfig.from_dict() with memory.retrieval and memory.semantic.kb_weights"""

    def test_memory_retrieval_section(self):
        """ServerConfig.from_dict() preserves memory.retrieval section"""
        from agentkit.server.config import ServerConfig

        data = {
            "memory": {
                "retrieval": {
                    "top_k": 10,
                    "token_budget": 5000,
                },
            },
        }
        config = ServerConfig.from_dict(data)
        assert config.memory["retrieval"]["top_k"] == 10
        assert config.memory["retrieval"]["token_budget"] == 5000

    def test_memory_semantic_kb_weights_section(self):
        """ServerConfig.from_dict() preserves memory.semantic.kb_weights section"""
        from agentkit.server.config import ServerConfig

        data = {
            "memory": {
                "semantic": {
                    "enabled": True,
                    "base_url": "http://localhost:8000",
                    "kb_weights": {
                        "industry-kb": 1.2,
                        "enterprise-kb": 0.8,
                    },
                },
            },
        }
        config = ServerConfig.from_dict(data)
        assert config.memory["semantic"]["kb_weights"]["industry-kb"] == 1.2
        assert config.memory["semantic"]["kb_weights"]["enterprise-kb"] == 0.8

    def test_memory_config_without_retrieval(self):
        """ServerConfig.from_dict() works without memory.retrieval section"""
        from agentkit.server.config import ServerConfig

        data = {
            "memory": {
                "semantic": {"enabled": False},
            },
        }
        config = ServerConfig.from_dict(data)
        assert config.memory.get("retrieval") is None

    def test_memory_config_without_kb_weights(self):
        """ServerConfig.from_dict() works without kb_weights section"""
        from agentkit.server.config import ServerConfig

        data = {
            "memory": {
                "semantic": {
                    "enabled": True,
                    "base_url": "http://localhost:8000",
                },
            },
        }
        config = ServerConfig.from_dict(data)
        assert config.memory["semantic"].get("kb_weights") is None