439 lines
16 KiB
Python
439 lines
16 KiB
Python
"""U5: Configurable retrieval parameters and per-KB weights.

Tests for:

1. ReActEngine uses configurable top_k/token_budget from retrieval_config.
2. ConfigDrivenAgent passes retrieval_config from the memory config.
3. SemanticMemory applies per-KB weight multipliers to scores.
4. Improved token estimation for mixed Chinese/English text.
5. ServerConfig parsing of memory.retrieval and memory.semantic.kb_weights.
"""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from agentkit.core.react import ReActEngine, ReActResult
|
|
from agentkit.llm.gateway import LLMGateway
|
|
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
|
from agentkit.memory.base import MemoryItem
|
|
from agentkit.memory.retriever import MemoryRetriever, _estimate_tokens
|
|
from agentkit.memory.semantic import SemanticMemory
|
|
|
|
|
|
# ── Test Helpers ──────────────────────────────────────────
|
|
|
|
|
|
def make_mock_gateway(responses: list[LLMResponse]) -> LLMGateway:
    """Return an LLMGateway stand-in whose async ``chat`` yields *responses*.

    ``AsyncMock(side_effect=<iterable>)`` hands out the next element on each
    awaited call, so the list is consumed one response per ``chat`` call.
    """
    mock_gateway = MagicMock(spec=LLMGateway)
    mock_gateway.chat = AsyncMock(side_effect=responses)
    return mock_gateway
|
|
|
|
|
|
def make_response(
    content: str = "",
    prompt_tokens: int = 10,
    completion_tokens: int = 20,
) -> LLMResponse:
    """Build a canned LLMResponse from model "test-model" with no tool calls.

    Args:
        content: Text the fake model "answers" with.
        prompt_tokens: Prompt token count reported in the usage block.
        completion_tokens: Completion token count reported in the usage block.
    """
    usage = TokenUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    return LLMResponse(content=content, model="test-model", usage=usage, tool_calls=[])
|
|
|
|
|
|
def make_mock_memory_retriever(context_string: str = "past experience data"):
    """Return a MagicMock shaped like a memory retriever.

    * ``get_context_string`` is awaitable and always resolves to *context_string*.
    * ``store_episode`` is an awaitable no-op recorder.
    * ``_episodic`` is ``None`` (no episodic store attached).
    """
    stub = MagicMock()
    stub._episodic = None
    stub.get_context_string = AsyncMock(return_value=context_string)
    stub.store_episode = AsyncMock()
    return stub
|
|
|
|
|
|
# ── Test: Configurable Retrieval Parameters ──────────
|
|
|
|
|
|
class TestConfigurableRetrievalParameters:
    """ReActEngine forwards retrieval_config (top_k / token_budget) to the retriever.

    When no config (or a partial one) is given, the engine is expected to fall
    back to top_k=5 and token_budget=2000.
    """

    @staticmethod
    def _engine_and_retriever(answer="final answer"):
        """Create a 3-step ReActEngine over a one-shot gateway plus a stub retriever."""
        engine = ReActEngine(
            llm_gateway=make_mock_gateway([make_response(content=answer)]),
            max_steps=3,
        )
        return engine, make_mock_memory_retriever("context")

    @staticmethod
    def _call_kwargs(retriever, retrieval_config=None):
        """Keyword arguments for execute/execute_stream with a single "Hello" turn.

        ``retrieval_config`` is only included when supplied, so the
        "no config" tests exercise the engine's true default path.
        """
        kwargs = {
            "messages": [{"role": "user", "content": "Hello"}],
            "memory_retriever": retriever,
        }
        if retrieval_config is not None:
            kwargs["retrieval_config"] = retrieval_config
        return kwargs

    async def test_default_top_k_when_no_config(self):
        """Without retrieval_config the retriever is queried with top_k=5 / budget 2000."""
        engine, retriever = self._engine_and_retriever()

        await engine.execute(**self._call_kwargs(retriever))

        retriever.get_context_string.assert_awaited_once_with(
            query="Hello", top_k=5, token_budget=2000
        )

    async def test_configured_top_k(self):
        """Both configured values reach the retriever unchanged."""
        engine, retriever = self._engine_and_retriever()

        await engine.execute(
            **self._call_kwargs(retriever, {"top_k": 10, "token_budget": 4000})
        )

        retriever.get_context_string.assert_awaited_once_with(
            query="Hello", top_k=10, token_budget=4000
        )

    async def test_configured_token_budget(self):
        """A config carrying only token_budget still overrides the budget."""
        engine, retriever = self._engine_and_retriever()

        await engine.execute(**self._call_kwargs(retriever, {"token_budget": 5000}))

        last_call = retriever.get_context_string.call_args
        assert last_call.kwargs.get("token_budget") == 5000

    async def test_backward_compatibility_no_config(self):
        """Omitting the config reproduces the legacy values (5 / 2000)."""
        engine, retriever = self._engine_and_retriever()

        await engine.execute(**self._call_kwargs(retriever))

        sent = retriever.get_context_string.call_args.kwargs
        assert sent["top_k"] == 5
        assert sent["token_budget"] == 2000

    async def test_stream_uses_retrieval_config(self):
        """The streaming entry point honours retrieval_config as well."""
        engine, retriever = self._engine_and_retriever("streamed answer")

        # Drain the stream completely so retrieval is guaranteed to have run.
        events = [
            event
            async for event in engine.execute_stream(
                **self._call_kwargs(retriever, {"top_k": 8, "token_budget": 3000})
            )
        ]

        sent = retriever.get_context_string.call_args.kwargs
        assert sent["top_k"] == 8
        assert sent["token_budget"] == 3000

    async def test_partial_config_uses_defaults(self):
        """Only top_k configured: token_budget keeps its 2000 default."""
        engine, retriever = self._engine_and_retriever()

        await engine.execute(**self._call_kwargs(retriever, {"top_k": 3}))

        sent = retriever.get_context_string.call_args.kwargs
        assert sent["top_k"] == 3
        assert sent["token_budget"] == 2000  # default
|
|
|
|
|
|
class TestConfigDrivenAgentRetrievalConfig:
    """ConfigDrivenAgent keeps the ``memory.retrieval`` section of its SkillConfig."""

    async def test_retrieval_config_passed_to_react_engine(self):
        """The memory.retrieval section survives into the agent's stored config.

        NOTE(review): despite its name, this test does not run the agent and does
        not assert that ReActEngine receives the values. It only checks that
        ``agent._config.memory["retrieval"]`` holds what was configured.
        Strengthening it requires knowledge of how ConfigDrivenAgent invokes
        ReActEngine.
        """
        from agentkit.core.config_driven import ConfigDrivenAgent
        from agentkit.skills.base import SkillConfig

        config = SkillConfig(
            name="test-agent",
            agent_type="test",
            task_mode="llm_generate",
            execution_mode="react",
            prompt={"identity": "Test agent"},
            memory={
                "retrieval": {"top_k": 10, "token_budget": 5000},
                # Disable other memory layers so only retrieval config matters.
                "working": {"enabled": False},
                "episodic": {"enabled": False},
            },
        )

        gateway = MagicMock(spec=LLMGateway)
        gateway.chat = AsyncMock(return_value=make_response(content="done"))

        agent = ConfigDrivenAgent(config=config, llm_gateway=gateway)

        # Verify the agent has memory config
        assert agent._config.memory.get("retrieval") == {"top_k": 10, "token_budget": 5000}
|
|
|
|
|
|
# ── Test: Per-KB Weights ──────────────────────────────────
|
|
|
|
|
|
class TestPerKBWeights:
    """SemanticMemory.search scales each hit's score by its knowledge base's weight.

    KBs missing from ``kb_weights`` (or an absent/empty mapping) keep their raw
    score, i.e. an implicit multiplier of 1.0.
    """

    @staticmethod
    def _hit(hit_id, content, score, document_id, kb_id):
        """One fake RAG search hit in the dict shape the mocked service returns."""
        return {
            "id": hit_id,
            "content": content,
            "score": score,
            "source": "rag",
            "document_id": document_id,
            "knowledge_base_id": kb_id,
        }

    @staticmethod
    def _memory_over(hits, knowledge_base_ids, **extra):
        """SemanticMemory backed by a mocked rag_service whose async search returns *hits*.

        ``extra`` is forwarded verbatim (e.g. ``kb_weights``), so tests that omit
        it construct SemanticMemory without that argument at all.
        """
        rag = MagicMock()
        rag.search = AsyncMock(return_value=hits)
        return SemanticMemory(rag_service=rag, knowledge_base_ids=knowledge_base_ids, **extra)

    @staticmethod
    def _first_from_kb(results, kb_id):
        """First result whose metadata names *kb_id* as its knowledge base."""
        return next(item for item in results if item.metadata.get("knowledge_base_id") == kb_id)

    async def test_kb_weights_applied_to_scores(self):
        """Each hit's score is multiplied by the weight of its own KB."""
        memory = self._memory_over(
            [
                self._hit("1", "Industry data", 0.9, "d1", "industry-kb"),
                self._hit("2", "Enterprise data", 0.9, "d2", "enterprise-kb"),
            ],
            ["industry-kb", "enterprise-kb"],
            kb_weights={"industry-kb": 1.2, "enterprise-kb": 0.8},
        )

        results = await memory.search("test query")

        industry = self._first_from_kb(results, "industry-kb")
        enterprise = self._first_from_kb(results, "enterprise-kb")
        assert industry.score == pytest.approx(0.9 * 1.2)
        assert enterprise.score == pytest.approx(0.9 * 0.8)

    async def test_industry_kb_scores_higher_than_enterprise(self):
        """With equal raw scores, the heavier-weighted KB is ranked first."""
        # Enterprise is returned first by the service; weighting must reorder it.
        memory = self._memory_over(
            [
                self._hit("1", "Enterprise result", 0.9, "d1", "enterprise-kb"),
                self._hit("2", "Industry result", 0.9, "d2", "industry-kb"),
            ],
            ["industry-kb", "enterprise-kb"],
            kb_weights={"industry-kb": 1.2, "enterprise-kb": 0.8},
        )

        results = await memory.search("test query")

        assert results[0].metadata.get("knowledge_base_id") == "industry-kb"
        assert results[0].score > results[1].score

    async def test_unweighted_kb_gets_default_score(self):
        """A KB absent from kb_weights keeps its raw score."""
        memory = self._memory_over(
            [self._hit("1", "Unweighted result", 0.8, "d1", "unweighted-kb")],
            ["unweighted-kb"],
            kb_weights={"industry-kb": 1.5},  # nothing configured for unweighted-kb
        )

        results = await memory.search("test query")

        assert len(results) == 1
        assert results[0].score == pytest.approx(0.8)  # unchanged

    async def test_kb_weights_none_no_modification(self):
        """kb_weights=None leaves scores untouched."""
        memory = self._memory_over(
            [self._hit("1", "Result", 0.75, "d1", "some-kb")],
            ["some-kb"],
            kb_weights=None,
        )

        results = await memory.search("test query")

        assert results[0].score == pytest.approx(0.75)

    async def test_empty_kb_weights_no_modification(self):
        """An empty kb_weights mapping leaves scores untouched."""
        memory = self._memory_over(
            [self._hit("1", "Result", 0.75, "d1", "some-kb")],
            ["some-kb"],
            kb_weights={},
        )

        results = await memory.search("test query")

        assert results[0].score == pytest.approx(0.75)

    async def test_kb_id_propagated_to_metadata(self):
        """The hit's knowledge_base_id is copied into MemoryItem.metadata."""
        memory = self._memory_over(
            [self._hit("1", "Result", 0.9, "d1", "my-kb")],
            ["my-kb"],
        )

        results = await memory.search("test query")

        assert results[0].metadata["knowledge_base_id"] == "my-kb"
|
|
|
|
|
|
# ── Test: Token Estimation ────────────────────────────────
|
|
|
|
|
|
class TestTokenEstimation:
    """_estimate_tokens heuristic for mixed CJK/English text.

    The expected values below encode the rule these tests pin down: each CJK
    character counts as 2 tokens, and each English word counts as 1.
    """

    def test_pure_english_text(self):
        """Six whitespace-separated English words -> 6 tokens."""
        assert _estimate_tokens("Hello world this is a test") == 6

    def test_pure_chinese_text(self):
        """Six CJK characters -> 6 * 2 = 12 tokens."""
        assert _estimate_tokens("你好世界测试") == 12

    def test_mixed_chinese_english_text(self):
        """4 CJK characters (8 tokens) + 2 English words (2 tokens) -> 10."""
        assert _estimate_tokens("你好world测试test") == 10

    def test_more_accurate_than_old_for_chinese(self):
        """The new heuristic no longer under-counts Chinese like len(text) // 4 did."""
        sentence = "人工智能技术在近年来取得了巨大突破"  # 17 CJK characters

        estimate = _estimate_tokens(sentence)  # 17 * 2 == 34
        legacy = len(sentence) // 4  # 17 // 4 == 4, far too low

        assert estimate > legacy
        assert estimate == 34

    def test_empty_string(self):
        """Nothing to count -> 0 tokens."""
        assert _estimate_tokens("") == 0

    def test_whitespace_only(self):
        """Whitespace alone carries no tokens."""
        assert _estimate_tokens(" ") == 0

    def test_english_with_punctuation(self):
        """Punctuation attached to words does not add tokens: 5 words -> 5."""
        # "Hello," "world!" "How" "are" "you?"
        assert _estimate_tokens("Hello, world! How are you?") == 5
|
|
|
|
|
|
# ── Test: Config Parsing ──────────────────────────────────
|
|
|
|
|
|
class TestConfigParsing:
    """ServerConfig.from_dict passes memory.retrieval and memory.semantic.kb_weights through."""

    @staticmethod
    def _parse(data):
        """Parse *data* with ServerConfig.from_dict.

        ServerConfig is imported lazily, inside the test run rather than at
        module import time.
        """
        from agentkit.server.config import ServerConfig

        return ServerConfig.from_dict(data)

    def test_memory_retrieval_section(self):
        """memory.retrieval values are kept verbatim."""
        config = self._parse(
            {"memory": {"retrieval": {"top_k": 10, "token_budget": 5000}}}
        )

        retrieval = config.memory["retrieval"]
        assert retrieval["top_k"] == 10
        assert retrieval["token_budget"] == 5000

    def test_memory_semantic_kb_weights_section(self):
        """memory.semantic.kb_weights values are kept verbatim."""
        semantic = {
            "enabled": True,
            "base_url": "http://localhost:8000",
            "kb_weights": {"industry-kb": 1.2, "enterprise-kb": 0.8},
        }
        config = self._parse({"memory": {"semantic": semantic}})

        weights = config.memory["semantic"]["kb_weights"]
        assert weights["industry-kb"] == 1.2
        assert weights["enterprise-kb"] == 0.8

    def test_memory_config_without_retrieval(self):
        """A memory section with no retrieval key yields no retrieval entry."""
        config = self._parse({"memory": {"semantic": {"enabled": False}}})

        assert config.memory.get("retrieval") is None

    def test_memory_config_without_kb_weights(self):
        """A semantic section with no kb_weights key yields no kb_weights entry."""
        config = self._parse(
            {"memory": {"semantic": {"enabled": True, "base_url": "http://localhost:8000"}}}
        )

        assert config.memory["semantic"].get("kb_weights") is None
|