fischer-agentkit/tests/unit/test_http_rag_service.py

885 lines
31 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Tests for HttpRAGService — HTTP 客户端调用业务系统知识库 API"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from agentkit.memory.http_rag import HttpRAGService
from agentkit.memory.semantic import SemanticMemory
from agentkit.memory.retriever import MemoryRetriever
# ---------------------------------------------------------------------------
# HttpRAGService unit tests
# ---------------------------------------------------------------------------
class TestHttpRAGServiceInit:
    """Construction of HttpRAGService and how it stores its arguments."""

    def test_basic_init(self):
        """Only ``base_url`` is passed; the other private fields hold their fallback values."""
        service = HttpRAGService(base_url="http://localhost:8000/api/knowledge")

        assert service._base_url == "http://localhost:8000/api/knowledge"
        assert service._api_key is None
        assert service._knowledge_base_ids == []
        assert service._timeout == 30

    def test_init_with_all_params(self):
        """Every constructor argument is passed explicitly and stored on the instance."""
        options = {
            "base_url": "http://geo:8000/api/knowledge/",
            "api_key": "sk-test",
            "knowledge_base_ids": ["kb-1", "kb-2"],
            "timeout": 60,
        }
        service = HttpRAGService(**options)

        # The trailing slash of the supplied URL is expected to be removed.
        assert service._base_url == "http://geo:8000/api/knowledge"
        assert service._api_key == "sk-test"
        assert service._knowledge_base_ids == ["kb-1", "kb-2"]
        assert service._timeout == 60

    def test_trailing_slash_stripped(self):
        """A ``base_url`` ending in ``/`` is stored without that slash."""
        service = HttpRAGService(base_url="http://host/api/")

        assert service._base_url == "http://host/api"
class TestHttpRAGServiceSearch:
    """HttpRAGService.search — semantic retrieval."""

    @pytest.fixture
    def svc(self):
        """Service configured with an API key and two knowledge base ids."""
        settings = {
            "base_url": "http://localhost:8000/api/knowledge",
            "api_key": "test-key",
            "knowledge_base_ids": ["kb-industry", "kb-enterprise"],
        }
        return HttpRAGService(**settings)

    @staticmethod
    def _json_response(body):
        """Return a mocked 200 response whose ``json()`` yields *body*."""
        response = MagicMock()
        response.status_code = 200
        response.raise_for_status = MagicMock()
        response.json.return_value = body
        return response

    @staticmethod
    def _install_client(service, **post_behaviour):
        """Swap the service's client factory for a mock and return the mock client.

        ``post_behaviour`` is forwarded to the ``AsyncMock`` standing in for
        ``client.post`` (``return_value=...`` or ``side_effect=...``).
        """
        client = AsyncMock()
        client.post = AsyncMock(**post_behaviour)
        service._get_client = MagicMock(return_value=client)
        return client

    @pytest.mark.asyncio
    async def test_search_standard_response(self, svc):
        """Standard SearchResponse shape: ``{"results": [...]}``."""
        first_hit = {
            "chunk_id": "c1",
            "content": "AI 行业趋势分析",
            "score": 0.92,
            "document_id": "d1",
            "document_title": "行业报告",
            "metadata": {"page": 1},
        }
        second_hit = {
            "chunk_id": "c2",
            "content": "企业数字化转型",
            "score": 0.85,
            "document_id": "d2",
            "document_title": "企业案例",
            "metadata": {},
        }
        body = {"results": [first_hit, second_hit], "total": 2, "latency_ms": 50}
        client = self._install_client(svc, return_value=self._json_response(body))

        hits = await svc.search("AI 趋势", top_k=5)

        assert len(hits) == 2
        assert hits[0]["id"] == "c1"
        assert hits[0]["content"] == "AI 行业趋势分析"
        assert hits[0]["score"] == 0.92
        assert hits[0]["document_id"] == "d1"
        assert hits[1]["content"] == "企业数字化转型"

        # Inspect what was actually posted to the API.
        sent = client.post.call_args
        assert sent.args[0] == "/search"
        posted = sent.kwargs["json"]
        assert posted["query"] == "AI 趋势"
        assert posted["knowledge_base_ids"] == ["kb-industry", "kb-enterprise"]
        assert posted["top_k"] == 5

    @pytest.mark.asyncio
    async def test_search_list_response(self, svc):
        """The API answers with a bare list: ``[...]``."""
        body = [{"chunk_id": "c1", "content": "test", "score": 0.8}]
        self._install_client(svc, return_value=self._json_response(body))

        hits = await svc.search("test")

        assert len(hits) == 1
        assert hits[0]["content"] == "test"

    @pytest.mark.asyncio
    async def test_search_custom_kb_ids(self, svc):
        """Explicit ``knowledge_base_ids`` take the place of the configured ones."""
        client = self._install_client(
            svc, return_value=self._json_response({"results": []})
        )

        await svc.search("test", knowledge_base_ids=["custom-kb"])

        posted = client.post.call_args.kwargs["json"]
        assert posted["knowledge_base_ids"] == ["custom-kb"]

    @pytest.mark.asyncio
    async def test_search_http_error_returns_empty(self, svc):
        """An HTTP error status yields an empty list."""
        import httpx

        failing = MagicMock()
        failing.status_code = 500
        failing.text = "Internal Server Error"
        failing.raise_for_status.side_effect = httpx.HTTPStatusError(
            "500", request=MagicMock(), response=failing
        )
        self._install_client(svc, return_value=failing)

        hits = await svc.search("test")

        assert hits == []

    @pytest.mark.asyncio
    async def test_search_connection_error_returns_empty(self, svc):
        """A connection failure yields an empty list."""
        import httpx

        self._install_client(svc, side_effect=httpx.ConnectError("Connection refused"))

        hits = await svc.search("test")

        assert hits == []

    @pytest.mark.asyncio
    async def test_search_unexpected_format_returns_empty(self, svc):
        """A response body in an unexpected shape yields an empty list."""
        self._install_client(
            svc, return_value=self._json_response({"error": "something"})
        )

        hits = await svc.search("test")

        assert hits == []
class TestHttpRAGServiceIngest:
    """HttpRAGService.ingest — writing a document."""

    @staticmethod
    def _install_client(service, response):
        """Make the service hand out a mock client whose ``post`` returns *response*."""
        client = AsyncMock()
        client.post = AsyncMock(return_value=response)
        service._get_client = MagicMock(return_value=client)
        return client

    @pytest.mark.asyncio
    async def test_ingest_success(self):
        """A 201 answer is returned to the caller and the document goes to the KB endpoint."""
        service = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )
        created = MagicMock()
        created.status_code = 201
        created.raise_for_status = MagicMock()
        created.json.return_value = {"id": "doc-1", "status": "processing"}
        client = self._install_client(service, created)

        outcome = await service.ingest("测试文档", "文档内容")

        assert outcome["id"] == "doc-1"

        # Check the endpoint and the posted body.
        sent = client.post.call_args
        assert sent.args[0] == "/bases/kb-1/documents"
        posted = sent.kwargs["json"]
        assert posted["title"] == "测试文档"
        assert posted["content"] == "文档内容"

    @pytest.mark.asyncio
    async def test_ingest_no_kb_ids_returns_none(self):
        """Without any knowledge base id configured, ingest returns ``None``."""
        service = HttpRAGService(base_url="http://localhost:8000/api/knowledge")

        outcome = await service.ingest("test", "content")

        assert outcome is None

    @pytest.mark.asyncio
    async def test_ingest_http_error_returns_none(self):
        """An HTTP error status makes ingest return ``None``."""
        import httpx

        service = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )
        failing = MagicMock()
        failing.status_code = 500
        failing.raise_for_status.side_effect = httpx.HTTPStatusError(
            "500", request=MagicMock(), response=failing
        )
        self._install_client(service, failing)

        outcome = await service.ingest("test", "content")

        assert outcome is None
class TestHttpRAGServiceHealthCheck:
    """HttpRAGService.health_check."""

    @staticmethod
    def _service_with_get(**get_behaviour):
        """Build a service whose mock client's ``get`` behaves per *get_behaviour*.

        The keyword arguments go straight to ``AsyncMock`` (``return_value`` or
        ``side_effect``).
        """
        service = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        client = AsyncMock()
        client.get = AsyncMock(**get_behaviour)
        service._get_client = MagicMock(return_value=client)
        return service

    @staticmethod
    def _response_with_status(code):
        """Return a mock response carrying only a ``status_code``."""
        response = MagicMock()
        response.status_code = code
        return response

    @pytest.mark.asyncio
    async def test_health_check_ok(self):
        """A 200 answer is reported as healthy."""
        service = self._service_with_get(return_value=self._response_with_status(200))

        healthy = await service.health_check()

        assert healthy is True

    @pytest.mark.asyncio
    async def test_health_check_401_still_healthy(self):
        """A 401 means the service is running and merely wants authentication."""
        service = self._service_with_get(return_value=self._response_with_status(401))

        healthy = await service.health_check()

        assert healthy is True

    @pytest.mark.asyncio
    async def test_health_check_connection_error(self):
        """An exception raised by the request is reported as unhealthy."""
        service = self._service_with_get(side_effect=Exception("Connection refused"))

        healthy = await service.health_check()

        assert healthy is False
class TestHttpRAGServiceClient:
    """HttpRAGService HTTP client management.

    Every test that lets the service build a real client also closes it, so
    no open client outlives the test that created it.
    """

    @pytest.mark.asyncio
    async def test_client_lazy_init(self):
        """No client exists until ``_get_client`` is called; it then carries the bearer token."""
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge", api_key="sk-test")
        assert svc._client is None
        try:
            client = svc._get_client()
            assert client is not None
            assert "Bearer sk-test" in str(client.headers.get("Authorization", ""))
        finally:
            # Release the client even when an assertion above fails.
            await svc.close()

    @pytest.mark.asyncio
    async def test_client_reuse(self):
        """Repeated ``_get_client`` calls hand back the very same object."""
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        try:
            c1 = svc._get_client()
            c2 = svc._get_client()
            assert c1 is c2
        finally:
            await svc.close()

    @pytest.mark.asyncio
    async def test_close(self):
        """``close`` resets the cached client to ``None``."""
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        svc._get_client()  # force client creation
        await svc.close()
        assert svc._client is None

    @pytest.mark.asyncio
    async def test_context_manager(self):
        """Leaving ``async with`` drops the client that was created inside it."""
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        async with svc as s:
            s._get_client()
            assert s._client is not None
        assert svc._client is None
# ---------------------------------------------------------------------------
# SemanticMemory + HttpRAGService integration
# ---------------------------------------------------------------------------
class TestSemanticMemoryWithHttpRAG:
    """SemanticMemory retrieving from the knowledge base through HttpRAGService."""

    @pytest.mark.asyncio
    async def test_search_delegates_to_rag_service(self):
        """``SemanticMemory.search`` goes through ``HttpRAGService.search``."""
        rag_service = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )

        # Stub the HTTP layer underneath the RAG service.
        hit = {"chunk_id": "c1", "content": "行业知识", "score": 0.9, "document_id": "d1"}
        response = MagicMock()
        response.status_code = 200
        response.raise_for_status = MagicMock()
        response.json.return_value = {"results": [hit]}
        client = AsyncMock()
        client.post = AsyncMock(return_value=response)
        rag_service._get_client = MagicMock(return_value=client)

        memory = SemanticMemory(rag_service=rag_service, knowledge_base_ids=["kb-1"])
        found = await memory.search("行业趋势", top_k=3)

        assert len(found) == 1
        only = found[0]
        assert only.key == "c1"
        assert only.value == "行业知识"
        assert only.score == 0.9
        assert only.metadata["source"] == "rag"

    @pytest.mark.asyncio
    async def test_search_no_rag_service_returns_empty(self):
        """With no RAG service attached, search returns an empty list."""
        memory = SemanticMemory()

        found = await memory.search("test")

        assert found == []
class TestMemoryRetrieverWithSemantic:
    """MemoryRetriever combined with SemanticMemory and HttpRAGService."""

    @pytest.mark.asyncio
    async def test_retriever_queries_semantic_layer(self):
        """The retriever queries the semantic layer and returns its weighted result."""
        rag_service = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )
        hit = {"chunk_id": "c1", "content": "知识库内容", "score": 0.95, "document_id": "d1"}
        response = MagicMock()
        response.status_code = 200
        response.raise_for_status = MagicMock()
        response.json.return_value = {"results": [hit]}
        client = AsyncMock()
        client.post = AsyncMock(return_value=response)
        rag_service._get_client = MagicMock(return_value=client)

        memory = SemanticMemory(rag_service=rag_service, knowledge_base_ids=["kb-1"])
        retriever = MemoryRetriever(semantic_memory=memory)

        found = await retriever.retrieve("知识查询", top_k=3)

        assert len(found) >= 1
        # The raw 0.95 score is expected scaled by 0.4, the default semantic weight.
        expected_score = pytest.approx(0.95 * 0.4, abs=0.01)
        assert found[0].score == expected_score
# ---------------------------------------------------------------------------
# Config-driven integration tests
# ---------------------------------------------------------------------------
class TestServerConfigMemorySemantic:
    """ServerConfig parsing of the ``memory.semantic`` section.

    The YAML-based tests use pytest's ``tmp_path`` and ``monkeypatch``
    fixtures so the temporary file and every environment change are undone
    automatically, even when an assertion fails.
    """

    def test_from_dict_with_semantic(self):
        """A ``memory.semantic`` mapping passed to ``from_dict`` is exposed unchanged."""
        from agentkit.server.config import ServerConfig

        data = {
            "memory": {
                "semantic": {
                    "enabled": True,
                    "base_url": "http://geo:8000/api/knowledge",
                    "api_key": "sk-test",
                    "knowledge_base_ids": ["kb-1", "kb-2"],
                    "timeout": 60,
                },
            },
        }
        config = ServerConfig.from_dict(data)
        semantic = config.memory["semantic"]
        assert semantic["enabled"] is True
        assert semantic["base_url"] == "http://geo:8000/api/knowledge"
        assert semantic["api_key"] == "sk-test"
        assert semantic["knowledge_base_ids"] == ["kb-1", "kb-2"]
        assert semantic["timeout"] == 60

    def test_from_dict_without_memory(self):
        """With no ``memory`` key, ``config.memory`` is an empty dict."""
        from agentkit.server.config import ServerConfig

        config = ServerConfig.from_dict({})
        assert config.memory == {}

    def test_from_yaml_with_env_var_resolution(self, tmp_path, monkeypatch):
        """``from_yaml`` replaces ``${VAR}`` with the value of the environment variable."""
        from agentkit.server.config import ServerConfig

        # monkeypatch restores the environment after the test, pass or fail.
        monkeypatch.setenv("TEST_GEO_API_KEY", "sk-from-env")

        # Nesting is spelled out line by line so it cannot be lost to reformatting.
        yaml_content = (
            "memory:\n"
            "  semantic:\n"
            "    enabled: true\n"
            "    base_url: http://geo:8000/api/knowledge\n"
            '    api_key: "${TEST_GEO_API_KEY}"\n'
        )
        config_file = tmp_path / "config.yaml"
        # The file is fully written and closed before the loader opens it.
        config_file.write_text(yaml_content, encoding="utf-8")

        config = ServerConfig.from_yaml(str(config_file))
        assert config.memory["semantic"]["api_key"] == "sk-from-env"

    def test_from_yaml_with_default_env_var(self, tmp_path, monkeypatch):
        """``from_yaml`` falls back to the default in ``${VAR:-default}`` when VAR is unset."""
        from agentkit.server.config import ServerConfig

        # Guarantee the variable is absent no matter what the host environment defines.
        monkeypatch.delenv("NONEXISTENT_KEY", raising=False)

        yaml_content = (
            "memory:\n"
            "  semantic:\n"
            "    enabled: true\n"
            "    base_url: http://geo:8000/api/knowledge\n"
            '    api_key: "${NONEXISTENT_KEY:-sk-default}"\n'
        )
        config_file = tmp_path / "config.yaml"
        config_file.write_text(yaml_content, encoding="utf-8")

        config = ServerConfig.from_yaml(str(config_file))
        assert config.memory["semantic"]["api_key"] == "sk-default"
# ---------------------------------------------------------------------------
# HttpRAGService enhanced_search tests
# ---------------------------------------------------------------------------
class TestHttpRAGServiceEnhancedSearch:
    """HttpRAGService.enhanced_search — enhanced semantic retrieval.

    The tests drive the service through a mocked HTTP client; the helpers
    below build the mock responses and wire the client into the service.
    """

    @pytest.fixture
    def svc(self):
        """Service configured with an API key and two knowledge bases."""
        return HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            api_key="test-key",
            knowledge_base_ids=["kb-1", "kb-2"],
        )

    @staticmethod
    def _ok_response(results):
        """Return a mocked 200 response whose JSON body is ``{"results": results}``."""
        response = MagicMock()
        response.status_code = 200
        response.raise_for_status = MagicMock()
        response.json.return_value = {"results": results}
        return response

    @staticmethod
    def _error_response(status_code):
        """Return a mocked response whose ``raise_for_status`` raises ``HTTPStatusError``.

        The response text is the standard reason phrase for *status_code*
        (``"Not Found"`` for 404, ``"Internal Server Error"`` for 500).
        """
        from http import HTTPStatus

        import httpx

        response = MagicMock()
        response.status_code = status_code
        response.text = HTTPStatus(status_code).phrase
        response.raise_for_status.side_effect = httpx.HTTPStatusError(
            str(status_code), request=MagicMock(), response=response
        )
        return response

    @staticmethod
    def _install_client(service, **post_behaviour):
        """Replace the service's client factory with a mock and return the mock client.

        ``post_behaviour`` is forwarded to the ``AsyncMock`` used for
        ``client.post`` (``return_value=...`` or ``side_effect=[...]``).
        """
        client = AsyncMock()
        client.post = AsyncMock(**post_behaviour)
        service._get_client = MagicMock(return_value=client)
        return client

    @pytest.mark.asyncio
    async def test_enhanced_search_single_kb(self, svc):
        """Single KB: the payload carries ``use_rerank`` and ``use_compression``."""
        svc._knowledge_base_ids = ["kb-1"]
        client = self._install_client(
            svc,
            return_value=self._ok_response(
                [{"chunk_id": "c1", "content": "AI 趋势", "score": 0.95, "document_id": "d1"}]
            ),
        )

        results = await svc.enhanced_search("AI 趋势", top_k=5)

        assert len(results) == 1
        assert results[0]["content"] == "AI 趋势"
        assert results[0]["score"] == 0.95

        # Verify endpoint and payload.
        call_args = client.post.call_args
        assert call_args.args[0] == "/bases/kb-1/retrieve"
        payload = call_args.kwargs["json"]
        assert payload["query"] == "AI 趋势"
        assert payload["top_k"] == 5
        assert payload["use_rerank"] is True
        assert payload["use_compression"] is False

    @pytest.mark.asyncio
    async def test_enhanced_search_multiple_kbs(self, svc):
        """Several KBs: results are merged and ordered by descending score."""
        kb1 = self._ok_response(
            [{"chunk_id": "c1", "content": "KB1 结果", "score": 0.8, "document_id": "d1"}]
        )
        # The second KB answers with the higher score.
        kb2 = self._ok_response(
            [{"chunk_id": "c2", "content": "KB2 结果", "score": 0.95, "document_id": "d2"}]
        )
        client = self._install_client(svc, side_effect=[kb1, kb2])

        results = await svc.enhanced_search("test query", top_k=5)

        assert len(results) == 2
        # Merged results sorted by score, highest first.
        assert results[0]["content"] == "KB2 结果"
        assert results[0]["score"] == 0.95
        assert results[1]["content"] == "KB1 结果"
        assert results[1]["score"] == 0.8

        # Both KB endpoints were called.
        calls = client.post.call_args_list
        assert calls[0].args[0] == "/bases/kb-1/retrieve"
        assert calls[1].args[0] == "/bases/kb-2/retrieve"

    @pytest.mark.asyncio
    async def test_enhanced_search_404_fallback_single_kb(self, svc):
        """A 404 falls back to the standard ``search`` method (single-KB case)."""
        svc._knowledge_base_ids = ["kb-1"]
        self._install_client(svc, return_value=self._error_response(404))
        # Stand-in for the standard search used as fallback.
        svc.search = AsyncMock(
            return_value=[{"id": "fallback", "content": "fallback result", "score": 0.5}]
        )

        results = await svc.enhanced_search("test query")

        # The fallback targets this KB only.
        svc.search.assert_called_once_with("test query", knowledge_base_ids=["kb-1"], top_k=5)
        assert len(results) == 1
        assert results[0]["id"] == "fallback"

    @pytest.mark.asyncio
    async def test_enhanced_search_partial_fallback_one_kb_404(self, svc):
        """KB1 supports enhanced retrieval, KB2 answers 404: only KB2 falls back to ``search``."""
        kb1 = self._ok_response(
            [{"chunk_id": "c1", "content": "KB1 enhanced", "score": 0.9, "document_id": "d1"}]
        )
        kb2 = self._error_response(404)
        self._install_client(svc, side_effect=[kb1, kb2])
        # Standard search is expected for the KB2 fallback only.
        svc.search = AsyncMock(return_value=[
            {"id": "c2", "content": "KB2 standard fallback", "score": 0.7, "source": "rag", "document_id": "d2"},
        ])

        results = await svc.enhanced_search("test query", top_k=5)

        svc.search.assert_called_once_with("test query", knowledge_base_ids=["kb-2"], top_k=5)
        assert len(results) == 2
        # Sorted by score, highest first.
        assert results[0]["content"] == "KB1 enhanced"
        assert results[0]["score"] == 0.9
        assert results[1]["content"] == "KB2 standard fallback"
        assert results[1]["score"] == 0.7

    @pytest.mark.asyncio
    async def test_enhanced_search_all_kbs_404_fallback(self, svc):
        """Every KB answers 404: each one falls back to the standard ``search``."""
        self._install_client(svc, return_value=self._error_response(404))
        # Called once per KB.
        svc.search = AsyncMock(return_value=[
            {"id": "c1", "content": "standard result", "score": 0.6, "source": "rag", "document_id": "d1"},
        ])

        results = await svc.enhanced_search("test query", top_k=5)

        assert svc.search.call_count == 2
        svc.search.assert_any_call("test query", knowledge_base_ids=["kb-1"], top_k=5)
        svc.search.assert_any_call("test query", knowledge_base_ids=["kb-2"], top_k=5)
        assert len(results) == 2

    @pytest.mark.asyncio
    async def test_enhanced_search_500_raises_exception(self, svc):
        """A 500 raises and must NOT fall back to the standard ``search``."""
        import httpx

        self._install_client(svc, return_value=self._error_response(500))
        # Replace search so any fallback attempt can be detected.
        svc.search = AsyncMock(return_value=[])

        with pytest.raises(httpx.HTTPStatusError):
            await svc.enhanced_search("test query")

        # Raising alone is not enough: the fallback path must stay untouched.
        svc.search.assert_not_called()

    @pytest.mark.asyncio
    @pytest.mark.parametrize("status_code", [401, 403, 500, 503])
    async def test_enhanced_search_http_error_raises(self, svc, status_code):
        """Non-404 HTTP errors raise.

        NOTE(review): the status codes other than 500 assume that 404 is the
        only status handled as a fallback — confirm against the implementation.
        """
        import httpx

        self._install_client(svc, return_value=self._error_response(status_code))

        with pytest.raises(httpx.HTTPStatusError):
            await svc.enhanced_search("test query")

    @pytest.mark.asyncio
    async def test_enhanced_search_with_compression(self, svc):
        """``use_compression=True`` is forwarded in the payload."""
        svc._knowledge_base_ids = ["kb-1"]
        client = self._install_client(svc, return_value=self._ok_response([]))

        await svc.enhanced_search("test", use_compression=True)

        payload = client.post.call_args.kwargs["json"]
        assert payload["use_compression"] is True

    @pytest.mark.asyncio
    async def test_enhanced_search_without_rerank(self, svc):
        """``use_rerank=False`` is forwarded in the payload."""
        svc._knowledge_base_ids = ["kb-1"]
        client = self._install_client(svc, return_value=self._ok_response([]))

        await svc.enhanced_search("test", use_rerank=False)

        payload = client.post.call_args.kwargs["json"]
        assert payload["use_rerank"] is False
# ---------------------------------------------------------------------------
# SemanticMemory enhanced search mode tests
# ---------------------------------------------------------------------------
class TestSemanticMemoryEnhancedSearch:
    """SemanticMemory ``search_mode`` — selecting enhanced or standard retrieval."""

    @pytest.mark.asyncio
    async def test_search_mode_enhanced(self):
        """``search_mode="enhanced"`` routes the query to ``enhanced_search``."""
        rag_service = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )
        # Replace enhanced_search so the forwarded arguments can be inspected.
        canned = [
            {"id": "c1", "content": "enhanced result", "score": 0.9, "source": "rag", "document_id": "d1"},
        ]
        rag_service.enhanced_search = AsyncMock(return_value=canned)

        memory_options = {
            "rag_service": rag_service,
            "knowledge_base_ids": ["kb-1"],
            "search_mode": "enhanced",
            "use_rerank": True,
            "use_compression": False,
        }
        memory = SemanticMemory(**memory_options)

        found = await memory.search("test query", top_k=3)

        rag_service.enhanced_search.assert_called_once_with(
            "test query",
            knowledge_base_ids=["kb-1"],
            top_k=3,
            use_rerank=True,
            use_compression=False,
        )
        assert len(found) == 1
        assert found[0].value == "enhanced result"

    @pytest.mark.asyncio
    async def test_search_mode_standard(self):
        """``search_mode="standard"`` routes the query to the standard ``search``."""
        rag_service = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )
        # Stub the HTTP layer used by the standard search.
        hit = {"chunk_id": "c1", "content": "standard result", "score": 0.8, "document_id": "d1"}
        response = MagicMock()
        response.status_code = 200
        response.raise_for_status = MagicMock()
        response.json.return_value = {"results": [hit]}
        client = AsyncMock()
        client.post = AsyncMock(return_value=response)
        rag_service._get_client = MagicMock(return_value=client)

        memory = SemanticMemory(
            rag_service=rag_service,
            knowledge_base_ids=["kb-1"],
            search_mode="standard",
        )

        found = await memory.search("test query", top_k=3)

        assert len(found) == 1
        assert found[0].value == "standard result"
        # The standard /search endpoint was hit, not /bases/{kb_id}/retrieve.
        endpoint = client.post.call_args.args[0]
        assert endpoint == "/search"

    @pytest.mark.asyncio
    async def test_search_mode_enhanced_fallback(self):
        """Enhanced mode with a service lacking ``enhanced_search`` falls back to ``search``."""

        class SimpleRAGService:
            """A RAG service that offers ``search`` only."""

            async def search(self, query, knowledge_base_ids=None, top_k=5):
                return [{"id": "c1", "content": "simple result", "score": 0.7, "source": "rag", "document_id": "d1"}]

        memory = SemanticMemory(
            rag_service=SimpleRAGService(),
            knowledge_base_ids=["kb-1"],
            search_mode="enhanced",
        )

        found = await memory.search("test query", top_k=3)

        assert len(found) == 1
        assert found[0].value == "simple result"
# ---------------------------------------------------------------------------
# Config enhanced search tests
# ---------------------------------------------------------------------------
class TestConfigEnhancedSearch:
    """ServerConfig parsing of the enhanced-search settings."""

    def test_config_search_mode(self):
        """``search_mode`` under ``memory.semantic`` is readable from the parsed config."""
        from agentkit.server.config import ServerConfig

        semantic_section = {
            "enabled": True,
            "base_url": "http://geo:8000/api/knowledge",
            "api_key": "sk-test",
            "knowledge_base_ids": ["kb-1"],
            "search_mode": "enhanced",
        }
        parsed = ServerConfig.from_dict({"memory": {"semantic": semantic_section}})

        assert parsed.memory["semantic"]["search_mode"] == "enhanced"

    def test_config_use_rerank(self):
        """``use_rerank`` and ``use_compression`` keep their boolean values."""
        from agentkit.server.config import ServerConfig

        semantic_section = {
            "enabled": True,
            "base_url": "http://geo:8000/api/knowledge",
            "api_key": "sk-test",
            "knowledge_base_ids": ["kb-1"],
            "use_rerank": False,
            "use_compression": True,
        }
        parsed = ServerConfig.from_dict({"memory": {"semantic": semantic_section}})

        assert parsed.memory["semantic"]["use_rerank"] is False
        assert parsed.memory["semantic"]["use_compression"] is True