885 lines
31 KiB
Python
885 lines
31 KiB
Python
"""Tests for HttpRAGService — HTTP 客户端调用业务系统知识库 API"""
|
||
|
||
import json
|
||
import pytest
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
from agentkit.memory.http_rag import HttpRAGService
|
||
from agentkit.memory.semantic import SemanticMemory
|
||
from agentkit.memory.retriever import MemoryRetriever
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# HttpRAGService unit tests
# ---------------------------------------------------------------------------


class TestHttpRAGServiceInit:
    """Construction of HttpRAGService and normalization of its settings."""

    def test_basic_init(self):
        # Only base_url is supplied, so every other setting takes its default.
        service = HttpRAGService(base_url="http://localhost:8000/api/knowledge")

        assert service._api_key is None
        assert (service._base_url, service._knowledge_base_ids, service._timeout) == (
            "http://localhost:8000/api/knowledge",
            [],
            30,
        )

    def test_init_with_all_params(self):
        service = HttpRAGService(
            base_url="http://geo:8000/api/knowledge/",
            api_key="sk-test",
            knowledge_base_ids=["kb-1", "kb-2"],
            timeout=60,
        )

        # The trailing "/" of base_url is removed; other values are stored as given.
        expected_attrs = {
            "_base_url": "http://geo:8000/api/knowledge",
            "_api_key": "sk-test",
            "_knowledge_base_ids": ["kb-1", "kb-2"],
            "_timeout": 60,
        }
        for attr_name, expected in expected_attrs.items():
            assert getattr(service, attr_name) == expected, attr_name

    def test_trailing_slash_stripped(self):
        service = HttpRAGService(base_url="http://host/api/")
        assert service._base_url == "http://host/api"


class TestHttpRAGServiceSearch:
    """HttpRAGService.search: semantic retrieval through the /search endpoint."""

    @pytest.fixture
    def svc(self):
        return HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            api_key="test-key",
            knowledge_base_ids=["kb-industry", "kb-enterprise"],
        )

    @staticmethod
    def _ok_response(body):
        """Return a 200 response stub whose ``json()`` yields *body*."""
        response = MagicMock()
        response.status_code = 200
        response.raise_for_status = MagicMock()
        response.json.return_value = body
        return response

    @staticmethod
    def _attach_client(service, response=None, error=None):
        """Stub ``service._get_client``.

        The stubbed client's ``post`` either returns *response* or raises *error*.
        The stub client is returned so tests can inspect the request.
        """
        client = AsyncMock()
        if error is not None:
            client.post = AsyncMock(side_effect=error)
        else:
            client.post = AsyncMock(return_value=response)
        service._get_client = MagicMock(return_value=client)
        return client

    @pytest.mark.asyncio
    async def test_search_standard_response(self, svc):
        """Standard SearchResponse shape: {"results": [...]}."""
        hits = [
            {
                "chunk_id": "c1",
                "content": "AI 行业趋势分析",
                "score": 0.92,
                "document_id": "d1",
                "document_title": "行业报告",
                "metadata": {"page": 1},
            },
            {
                "chunk_id": "c2",
                "content": "企业数字化转型",
                "score": 0.85,
                "document_id": "d2",
                "document_title": "企业案例",
                "metadata": {},
            },
        ]
        client = self._attach_client(
            svc, self._ok_response({"results": hits, "total": 2, "latency_ms": 50})
        )

        results = await svc.search("AI 趋势", top_k=5)

        assert len(results) == 2
        head = results[0]
        assert (head["id"], head["content"], head["score"], head["document_id"]) == (
            "c1",
            "AI 行业趋势分析",
            0.92,
            "d1",
        )
        assert results[1]["content"] == "企业数字化转型"

        # The request targets /search and carries the configured knowledge bases.
        request = client.post.call_args
        assert request.args[0] == "/search"
        sent = request.kwargs["json"]
        assert sent["query"] == "AI 趋势"
        assert sent["knowledge_base_ids"] == ["kb-industry", "kb-enterprise"]
        assert sent["top_k"] == 5

    @pytest.mark.asyncio
    async def test_search_list_response(self, svc):
        """A bare JSON list body: [...]."""
        self._attach_client(
            svc, self._ok_response([{"chunk_id": "c1", "content": "test", "score": 0.8}])
        )

        results = await svc.search("test")

        assert len(results) == 1
        assert results[0]["content"] == "test"

    @pytest.mark.asyncio
    async def test_search_custom_kb_ids(self, svc):
        """Explicit knowledge_base_ids override the configured defaults."""
        client = self._attach_client(svc, self._ok_response({"results": []}))

        await svc.search("test", knowledge_base_ids=["custom-kb"])

        assert client.post.call_args.kwargs["json"]["knowledge_base_ids"] == ["custom-kb"]

    @pytest.mark.asyncio
    async def test_search_http_error_returns_empty(self, svc):
        """An HTTP error status yields an empty result list."""
        import httpx

        failing = MagicMock()
        failing.status_code = 500
        failing.text = "Internal Server Error"
        failing.raise_for_status.side_effect = httpx.HTTPStatusError(
            "500", request=MagicMock(), response=failing
        )
        self._attach_client(svc, failing)

        assert await svc.search("test") == []

    @pytest.mark.asyncio
    async def test_search_connection_error_returns_empty(self, svc):
        """A transport-level connection error yields an empty result list."""
        import httpx

        self._attach_client(svc, error=httpx.ConnectError("Connection refused"))

        assert await svc.search("test") == []

    @pytest.mark.asyncio
    async def test_search_unexpected_format_returns_empty(self, svc):
        """A body of an unexpected shape yields an empty result list."""
        self._attach_client(svc, self._ok_response({"error": "something"}))

        assert await svc.search("test") == []


class TestHttpRAGServiceIngest:
    """HttpRAGService.ingest: writing a document into a knowledge base."""

    @staticmethod
    def _service_with_post(response):
        """Build a service bound to kb-1 whose stub client's ``post`` returns *response*."""
        service = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )
        client = AsyncMock()
        client.post = AsyncMock(return_value=response)
        service._get_client = MagicMock(return_value=client)
        return service, client

    @pytest.mark.asyncio
    async def test_ingest_success(self):
        created = MagicMock()
        created.status_code = 201
        created.raise_for_status = MagicMock()
        created.json.return_value = {"id": "doc-1", "status": "processing"}
        service, client = self._service_with_post(created)

        result = await service.ingest("测试文档", "文档内容")
        assert result["id"] == "doc-1"

        # The document is posted to the configured knowledge base's documents endpoint.
        request = client.post.call_args
        assert request.args[0] == "/bases/kb-1/documents"
        body = request.kwargs["json"]
        assert (body["title"], body["content"]) == ("测试文档", "文档内容")

    @pytest.mark.asyncio
    async def test_ingest_no_kb_ids_returns_none(self):
        service = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        assert await service.ingest("test", "content") is None

    @pytest.mark.asyncio
    async def test_ingest_http_error_returns_none(self):
        import httpx

        failing = MagicMock()
        failing.status_code = 500
        failing.raise_for_status.side_effect = httpx.HTTPStatusError(
            "500", request=MagicMock(), response=failing
        )
        service, _client = self._service_with_post(failing)

        assert await service.ingest("test", "content") is None


class TestHttpRAGServiceHealthCheck:
    """HttpRAGService.health_check."""

    @staticmethod
    def _status(code):
        """Return a response stub carrying only *code* as its status."""
        response = MagicMock()
        response.status_code = code
        return response

    @staticmethod
    def _service_with_get(**get_behaviour):
        """Build a service whose stub client's ``get`` is ``AsyncMock(**get_behaviour)``."""
        service = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        client = AsyncMock()
        client.get = AsyncMock(**get_behaviour)
        service._get_client = MagicMock(return_value=client)
        return service

    @pytest.mark.asyncio
    async def test_health_check_ok(self):
        service = self._service_with_get(return_value=self._status(200))
        assert await service.health_check() is True

    @pytest.mark.asyncio
    async def test_health_check_401_still_healthy(self):
        """A 401 means the service is running and merely requires authentication."""
        service = self._service_with_get(return_value=self._status(401))
        assert await service.health_check() is True

    @pytest.mark.asyncio
    async def test_health_check_connection_error(self):
        service = self._service_with_get(side_effect=Exception("Connection refused"))
        assert await service.health_check() is False


class TestHttpRAGServiceClient:
    """Lifecycle of the HTTP client held by HttpRAGService."""

    def test_client_lazy_init(self):
        service = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge", api_key="sk-test"
        )
        # No client exists until one is requested.
        assert service._client is None

        client = service._get_client()
        assert client is not None
        auth_header = client.headers.get("Authorization", "")
        assert "Bearer sk-test" in str(auth_header)

    def test_client_reuse(self):
        service = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        first = service._get_client()
        assert service._get_client() is first

    @pytest.mark.asyncio
    async def test_close(self):
        service = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        service._get_client()  # force the client into existence
        await service.close()
        assert service._client is None

    @pytest.mark.asyncio
    async def test_context_manager(self):
        service = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        async with service as entered:
            entered._get_client()
            assert entered._client is not None
        # Leaving the async context clears the client.
        assert service._client is None


# ---------------------------------------------------------------------------
# SemanticMemory + HttpRAGService integration
# ---------------------------------------------------------------------------


class TestSemanticMemoryWithHttpRAG:
    """SemanticMemory retrieving knowledge-base content through HttpRAGService."""

    @pytest.mark.asyncio
    async def test_search_delegates_to_rag_service(self):
        """SemanticMemory.search is delegated to HttpRAGService.search."""
        rag = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )

        # Stub the HTTP layer so the RAG service receives exactly one hit.
        hit = {"chunk_id": "c1", "content": "行业知识", "score": 0.9, "document_id": "d1"}
        reply = MagicMock()
        reply.status_code = 200
        reply.raise_for_status = MagicMock()
        reply.json.return_value = {"results": [hit]}
        http_client = AsyncMock()
        http_client.post = AsyncMock(return_value=reply)
        rag._get_client = MagicMock(return_value=http_client)

        memory = SemanticMemory(rag_service=rag, knowledge_base_ids=["kb-1"])
        items = await memory.search("行业趋势", top_k=3)

        assert len(items) == 1
        only = items[0]
        assert only.key == "c1"
        assert only.value == "行业知识"
        assert only.score == 0.9
        assert only.metadata["source"] == "rag"

    @pytest.mark.asyncio
    async def test_search_no_rag_service_returns_empty(self):
        """Without a RAG service, search returns an empty list."""
        assert await SemanticMemory().search("test") == []


class TestMemoryRetrieverWithSemantic:
    """MemoryRetriever wired to a SemanticMemory backed by HttpRAGService."""

    @pytest.mark.asyncio
    async def test_retriever_queries_semantic_layer(self):
        """MemoryRetriever queries the semantic layer and fuses its results."""
        rag = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )

        hit = {"chunk_id": "c1", "content": "知识库内容", "score": 0.95, "document_id": "d1"}
        reply = MagicMock()
        reply.status_code = 200
        reply.raise_for_status = MagicMock()
        reply.json.return_value = {"results": [hit]}
        http_client = AsyncMock()
        http_client.post = AsyncMock(return_value=reply)
        rag._get_client = MagicMock(return_value=http_client)

        retriever = MemoryRetriever(
            semantic_memory=SemanticMemory(rag_service=rag, knowledge_base_ids=["kb-1"])
        )
        items = await retriever.retrieve("知识查询", top_k=3)

        assert len(items) >= 1
        # The fused score is the raw semantic score scaled by the retriever's
        # semantic weight; this test assumes that weight defaults to 0.4.
        expected_score = 0.95 * 0.4
        assert items[0].score == pytest.approx(expected_score, abs=0.01)


# ---------------------------------------------------------------------------
# Config-driven integration tests
# ---------------------------------------------------------------------------


class TestServerConfigMemorySemantic:
    """ServerConfig parsing of the ``memory.semantic`` section."""

    def test_from_dict_with_semantic(self):
        from agentkit.server.config import ServerConfig

        data = {
            "memory": {
                "semantic": {
                    "enabled": True,
                    "base_url": "http://geo:8000/api/knowledge",
                    "api_key": "sk-test",
                    "knowledge_base_ids": ["kb-1", "kb-2"],
                    "timeout": 60,
                },
            },
        }
        config = ServerConfig.from_dict(data)

        semantic = config.memory["semantic"]
        assert semantic["enabled"] is True
        assert semantic["base_url"] == "http://geo:8000/api/knowledge"
        assert semantic["api_key"] == "sk-test"
        assert semantic["knowledge_base_ids"] == ["kb-1", "kb-2"]
        assert semantic["timeout"] == 60

    def test_from_dict_without_memory(self):
        from agentkit.server.config import ServerConfig

        config = ServerConfig.from_dict({})
        assert config.memory == {}

    def test_from_yaml_with_env_var_resolution(self, monkeypatch, tmp_path):
        """from_yaml substitutes a plain ``${VAR}`` reference from the environment.

        ``monkeypatch`` restores os.environ even when an assertion fails, so the
        variable cannot leak into later tests. ``tmp_path`` is cleaned up by
        pytest, and the file is closed before from_yaml reopens it, which is
        portable to Windows.
        """
        from agentkit.server.config import ServerConfig

        monkeypatch.setenv("TEST_GEO_API_KEY", "sk-from-env")

        config_file = tmp_path / "config.yaml"
        config_file.write_text(
            "memory:\n"
            "  semantic:\n"
            "    enabled: true\n"
            "    base_url: http://geo:8000/api/knowledge\n"
            '    api_key: "${TEST_GEO_API_KEY}"\n',
            encoding="utf-8",
        )

        config = ServerConfig.from_yaml(str(config_file))

        assert config.memory["semantic"]["api_key"] == "sk-from-env"

    def test_from_yaml_with_default_env_var(self, monkeypatch, tmp_path):
        """from_yaml uses the default of ``${VAR:-default}`` when VAR is unset."""
        from agentkit.server.config import ServerConfig

        # Make the "unset" precondition explicit instead of trusting the host env.
        monkeypatch.delenv("NONEXISTENT_KEY", raising=False)

        config_file = tmp_path / "config.yaml"
        config_file.write_text(
            "memory:\n"
            "  semantic:\n"
            "    enabled: true\n"
            "    base_url: http://geo:8000/api/knowledge\n"
            '    api_key: "${NONEXISTENT_KEY:-sk-default}"\n',
            encoding="utf-8",
        )

        config = ServerConfig.from_yaml(str(config_file))

        assert config.memory["semantic"]["api_key"] == "sk-default"


# ---------------------------------------------------------------------------
# HttpRAGService enhanced_search tests
# ---------------------------------------------------------------------------


class TestHttpRAGServiceEnhancedSearch:
    """HttpRAGService.enhanced_search: per-KB retrieval via /bases/{kb_id}/retrieve."""

    @pytest.fixture
    def svc(self):
        return HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            api_key="test-key",
            knowledge_base_ids=["kb-1", "kb-2"],
        )

    @staticmethod
    def _ok(results):
        """Return a 200 response stub whose JSON body is ``{"results": results}``."""
        response = MagicMock()
        response.status_code = 200
        response.raise_for_status = MagicMock()
        response.json.return_value = {"results": results}
        return response

    @staticmethod
    def _error(status, text):
        """Return a response stub whose ``raise_for_status()`` raises HTTPStatusError."""
        import httpx

        response = MagicMock()
        response.status_code = status
        response.text = text
        response.raise_for_status.side_effect = httpx.HTTPStatusError(
            str(status), request=MagicMock(), response=response
        )
        return response

    @staticmethod
    def _install(service, *responses):
        """Stub ``service._get_client`` and return the stub client.

        With one response, ``post`` returns it on every call. With several
        responses, ``post`` returns them in call order.
        """
        client = AsyncMock()
        if len(responses) == 1:
            client.post = AsyncMock(return_value=responses[0])
        else:
            client.post = AsyncMock(side_effect=list(responses))
        service._get_client = MagicMock(return_value=client)
        return client

    @pytest.mark.asyncio
    async def test_enhanced_search_single_kb(self, svc):
        """Single KB: the payload carries use_rerank and use_compression."""
        svc._knowledge_base_ids = ["kb-1"]
        client = self._install(
            svc,
            self._ok(
                [{"chunk_id": "c1", "content": "AI 趋势", "score": 0.95, "document_id": "d1"}]
            ),
        )

        results = await svc.enhanced_search("AI 趋势", top_k=5)

        assert len(results) == 1
        assert results[0]["content"] == "AI 趋势"
        assert results[0]["score"] == 0.95

        request = client.post.call_args
        assert request.args[0] == "/bases/kb-1/retrieve"
        payload = request.kwargs["json"]
        assert payload["query"] == "AI 趋势"
        assert payload["top_k"] == 5
        assert payload["use_rerank"] is True
        assert payload["use_compression"] is False

    @pytest.mark.asyncio
    async def test_enhanced_search_multiple_kbs(self, svc):
        """Several KBs: results are merged and sorted by score, descending."""
        client = self._install(
            svc,
            self._ok(
                [{"chunk_id": "c1", "content": "KB1 结果", "score": 0.8, "document_id": "d1"}]
            ),
            # The second KB returns a higher-scoring hit.
            self._ok(
                [{"chunk_id": "c2", "content": "KB2 结果", "score": 0.95, "document_id": "d2"}]
            ),
        )

        results = await svc.enhanced_search("test query", top_k=5)

        assert len(results) == 2
        assert results[0]["content"] == "KB2 结果"
        assert results[0]["score"] == 0.95
        assert results[1]["content"] == "KB1 结果"
        assert results[1]["score"] == 0.8

        # Each configured KB gets its own retrieve call.
        calls = client.post.call_args_list
        assert calls[0].args[0] == "/bases/kb-1/retrieve"
        assert calls[1].args[0] == "/bases/kb-2/retrieve"

    @pytest.mark.asyncio
    async def test_enhanced_search_404_fallback_single_kb(self, svc):
        """A 404 falls back to the standard search method (single-KB case)."""
        svc._knowledge_base_ids = ["kb-1"]
        self._install(svc, self._error(404, "Not Found"))
        svc.search = AsyncMock(
            return_value=[{"id": "fallback", "content": "fallback result", "score": 0.5}]
        )

        results = await svc.enhanced_search("test query")

        # The fallback is scoped to the KB that returned 404.
        svc.search.assert_called_once_with("test query", knowledge_base_ids=["kb-1"], top_k=5)
        assert len(results) == 1
        assert results[0]["id"] == "fallback"

    @pytest.mark.asyncio
    async def test_enhanced_search_partial_fallback_one_kb_404(self, svc):
        """KB1 succeeds with enhanced retrieval; KB2 returns 404 and falls back to search."""
        self._install(
            svc,
            self._ok(
                [{"chunk_id": "c1", "content": "KB1 enhanced", "score": 0.9, "document_id": "d1"}]
            ),
            self._error(404, "Not Found"),
        )
        svc.search = AsyncMock(
            return_value=[
                {
                    "id": "c2",
                    "content": "KB2 standard fallback",
                    "score": 0.7,
                    "source": "rag",
                    "document_id": "d2",
                },
            ]
        )

        results = await svc.enhanced_search("test query", top_k=5)

        svc.search.assert_called_once_with("test query", knowledge_base_ids=["kb-2"], top_k=5)
        assert len(results) == 2
        # Sorted by score, descending.
        assert results[0]["content"] == "KB1 enhanced"
        assert results[0]["score"] == 0.9
        assert results[1]["content"] == "KB2 standard fallback"
        assert results[1]["score"] == 0.7

    @pytest.mark.asyncio
    async def test_enhanced_search_all_kbs_404_fallback(self, svc):
        """Every KB returns 404, so every KB falls back to the standard search."""
        self._install(svc, self._error(404, "Not Found"))
        svc.search = AsyncMock(
            return_value=[
                {
                    "id": "c1",
                    "content": "standard result",
                    "score": 0.6,
                    "source": "rag",
                    "document_id": "d1",
                },
            ]
        )

        results = await svc.enhanced_search("test query", top_k=5)

        # search() is called once per KB (kb-1 and kb-2).
        assert svc.search.call_count == 2
        svc.search.assert_any_call("test query", knowledge_base_ids=["kb-1"], top_k=5)
        svc.search.assert_any_call("test query", knowledge_base_ids=["kb-2"], top_k=5)
        assert len(results) == 2

    @pytest.mark.asyncio
    async def test_enhanced_search_500_raises_exception(self, svc):
        """A 500 propagates as an exception and does NOT fall back to search."""
        import httpx

        self._install(svc, self._error(500, "Internal Server Error"))
        # Stub search so a wrongful fallback is observable, not silently absorbed.
        svc.search = AsyncMock(return_value=[])

        with pytest.raises(httpx.HTTPStatusError):
            await svc.enhanced_search("test query")

        svc.search.assert_not_called()

    @pytest.mark.asyncio
    @pytest.mark.parametrize("status", [400, 502, 503])
    async def test_enhanced_search_http_error_raises(self, svc, status):
        """Non-404 HTTP errors other than 500 also propagate; only 404 triggers the fallback."""
        import httpx

        self._install(svc, self._error(status, "error"))
        svc.search = AsyncMock(return_value=[])

        with pytest.raises(httpx.HTTPStatusError):
            await svc.enhanced_search("test query")

        svc.search.assert_not_called()

    @pytest.mark.asyncio
    async def test_enhanced_search_with_compression(self, svc):
        """use_compression=True is forwarded in the payload."""
        svc._knowledge_base_ids = ["kb-1"]
        client = self._install(svc, self._ok([]))

        await svc.enhanced_search("test", use_compression=True)

        assert client.post.call_args.kwargs["json"]["use_compression"] is True

    @pytest.mark.asyncio
    async def test_enhanced_search_without_rerank(self, svc):
        """use_rerank=False is forwarded in the payload."""
        svc._knowledge_base_ids = ["kb-1"]
        client = self._install(svc, self._ok([]))

        await svc.enhanced_search("test", use_rerank=False)

        assert client.post.call_args.kwargs["json"]["use_rerank"] is False


# ---------------------------------------------------------------------------
# SemanticMemory enhanced search mode tests
# ---------------------------------------------------------------------------


class TestSemanticMemoryEnhancedSearch:
    """SemanticMemory ``search_mode``: choosing between enhanced and standard retrieval."""

    @staticmethod
    def _rag():
        """Return an HttpRAGService bound to kb-1 at the local test URL."""
        return HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )

    @pytest.mark.asyncio
    async def test_search_mode_enhanced(self):
        """search_mode="enhanced" routes the query through enhanced_search."""
        rag = self._rag()
        rag.enhanced_search = AsyncMock(
            return_value=[
                {
                    "id": "c1",
                    "content": "enhanced result",
                    "score": 0.9,
                    "source": "rag",
                    "document_id": "d1",
                },
            ]
        )
        memory = SemanticMemory(
            rag_service=rag,
            knowledge_base_ids=["kb-1"],
            search_mode="enhanced",
            use_rerank=True,
            use_compression=False,
        )

        items = await memory.search("test query", top_k=3)

        rag.enhanced_search.assert_called_once_with(
            "test query",
            knowledge_base_ids=["kb-1"],
            top_k=3,
            use_rerank=True,
            use_compression=False,
        )
        assert [item.value for item in items] == ["enhanced result"]

    @pytest.mark.asyncio
    async def test_search_mode_standard(self):
        """search_mode="standard" uses the plain search method."""
        rag = self._rag()
        reply = MagicMock()
        reply.status_code = 200
        reply.raise_for_status = MagicMock()
        reply.json.return_value = {
            "results": [
                {"chunk_id": "c1", "content": "standard result", "score": 0.8, "document_id": "d1"},
            ]
        }
        http_client = AsyncMock()
        http_client.post = AsyncMock(return_value=reply)
        rag._get_client = MagicMock(return_value=http_client)

        memory = SemanticMemory(
            rag_service=rag,
            knowledge_base_ids=["kb-1"],
            search_mode="standard",
        )
        items = await memory.search("test query", top_k=3)

        assert [item.value for item in items] == ["standard result"]
        # The /search endpoint was used rather than /bases/{kb_id}/retrieve.
        assert http_client.post.call_args.args[0] == "/search"

    @pytest.mark.asyncio
    async def test_search_mode_enhanced_fallback(self):
        """search_mode="enhanced" falls back to search when the service lacks enhanced_search."""

        class SimpleRAGService:
            """RAG service stub that deliberately has no enhanced_search."""

            async def search(self, query, knowledge_base_ids=None, top_k=5):
                return [
                    {
                        "id": "c1",
                        "content": "simple result",
                        "score": 0.7,
                        "source": "rag",
                        "document_id": "d1",
                    }
                ]

        memory = SemanticMemory(
            rag_service=SimpleRAGService(),
            knowledge_base_ids=["kb-1"],
            search_mode="enhanced",
        )

        items = await memory.search("test query", top_k=3)

        assert [item.value for item in items] == ["simple result"]


# ---------------------------------------------------------------------------
# Config enhanced search tests
# ---------------------------------------------------------------------------


class TestConfigEnhancedSearch:
    """ServerConfig parsing of the enhanced-search options under memory.semantic."""

    @staticmethod
    def _parse_semantic(**extra):
        """Parse a config made of the shared semantic base plus *extra*.

        Returns the resulting ``memory["semantic"]`` mapping.
        """
        from agentkit.server.config import ServerConfig

        semantic = {
            "enabled": True,
            "base_url": "http://geo:8000/api/knowledge",
            "api_key": "sk-test",
            "knowledge_base_ids": ["kb-1"],
        }
        semantic.update(extra)
        config = ServerConfig.from_dict({"memory": {"semantic": semantic}})
        return config.memory["semantic"]

    def test_config_search_mode(self):
        parsed = self._parse_semantic(search_mode="enhanced")
        assert parsed["search_mode"] == "enhanced"

    def test_config_use_rerank(self):
        parsed = self._parse_semantic(use_rerank=False, use_compression=True)
        assert parsed["use_rerank"] is False
        assert parsed["use_compression"] is True