# fischer-agentkit/tests/unit/test_http_rag_service.py
#
# 791 lines
# 27 KiB
# Python

"""Tests for HttpRAGService — HTTP 客户端调用业务系统知识库 API"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from agentkit.memory.http_rag import HttpRAGService
from agentkit.memory.semantic import SemanticMemory
from agentkit.memory.retriever import MemoryRetriever
# ---------------------------------------------------------------------------
# HttpRAGService unit tests
# ---------------------------------------------------------------------------
class TestHttpRAGServiceInit:
    """Construction of HttpRAGService and normalisation of its arguments."""

    def test_basic_init(self):
        service = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        # Only base_url supplied: every other setting takes its default.
        assert service._timeout == 30
        assert service._knowledge_base_ids == []
        assert service._api_key is None
        assert service._base_url == "http://localhost:8000/api/knowledge"

    def test_init_with_all_params(self):
        options = dict(
            base_url="http://geo:8000/api/knowledge/",
            api_key="sk-test",
            knowledge_base_ids=["kb-1", "kb-2"],
            timeout=60,
        )
        service = HttpRAGService(**options)
        # The trailing "/" on base_url is dropped; the rest is stored verbatim.
        assert service._base_url == "http://geo:8000/api/knowledge"
        assert service._api_key == "sk-test"
        assert service._knowledge_base_ids == ["kb-1", "kb-2"]
        assert service._timeout == 60

    def test_trailing_slash_stripped(self):
        service = HttpRAGService(base_url="http://host/api/")
        assert service._base_url == "http://host/api"
class TestHttpRAGServiceSearch:
    """HttpRAGService.search — semantic retrieval through the ``/search`` endpoint."""

    @pytest.fixture
    def svc(self):
        return HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            api_key="test-key",
            knowledge_base_ids=["kb-industry", "kb-enterprise"],
        )

    @staticmethod
    def _json_response(body):
        """Build a fake 200 response whose ``json()`` returns *body*."""
        response = MagicMock()
        response.status_code = 200
        response.raise_for_status = MagicMock()
        response.json.return_value = body
        return response

    @staticmethod
    def _install_client(service, **post_behaviour):
        """Swap *service*'s HTTP client for an AsyncMock and return that mock.

        ``post_behaviour`` is forwarded to the ``AsyncMock`` used as ``client.post``
        (e.g. ``return_value=...`` or ``side_effect=...``).
        """
        client = AsyncMock()
        client.post = AsyncMock(**post_behaviour)
        service._get_client = MagicMock(return_value=client)
        return client

    @pytest.mark.asyncio
    async def test_search_standard_response(self, svc):
        """Standard SearchResponse body: {"results": [...]}."""
        body = {
            "results": [
                {
                    "chunk_id": "c1",
                    "content": "AI 行业趋势分析",
                    "score": 0.92,
                    "document_id": "d1",
                    "document_title": "行业报告",
                    "metadata": {"page": 1},
                },
                {
                    "chunk_id": "c2",
                    "content": "企业数字化转型",
                    "score": 0.85,
                    "document_id": "d2",
                    "document_title": "企业案例",
                    "metadata": {},
                },
            ],
            "total": 2,
            "latency_ms": 50,
        }
        client = self._install_client(svc, return_value=self._json_response(body))

        results = await svc.search("AI 趋势", top_k=5)

        assert len(results) == 2
        first, second = results[0], results[1]
        assert first["id"] == "c1"
        assert first["content"] == "AI 行业趋势分析"
        assert first["score"] == 0.92
        assert first["document_id"] == "d1"
        assert second["content"] == "企业数字化转型"

        # The request is a POST to /search carrying the query, default KB ids and top_k.
        args, kwargs = client.post.call_args
        assert args[0] == "/search"
        sent = kwargs["json"]
        assert sent["query"] == "AI 趋势"
        assert sent["knowledge_base_ids"] == ["kb-industry", "kb-enterprise"]
        assert sent["top_k"] == 5

    @pytest.mark.asyncio
    async def test_search_list_response(self, svc):
        """A bare JSON list body: [...]."""
        body = [
            {"chunk_id": "c1", "content": "test", "score": 0.8},
        ]
        self._install_client(svc, return_value=self._json_response(body))

        results = await svc.search("test")

        assert len(results) == 1
        assert results[0]["content"] == "test"

    @pytest.mark.asyncio
    async def test_search_custom_kb_ids(self, svc):
        """Explicit knowledge_base_ids override the instance defaults."""
        client = self._install_client(svc, return_value=self._json_response({"results": []}))

        await svc.search("test", knowledge_base_ids=["custom-kb"])

        sent = client.post.call_args[1]["json"]
        assert sent["knowledge_base_ids"] == ["custom-kb"]

    @pytest.mark.asyncio
    async def test_search_http_error_returns_empty(self, svc):
        """An HTTP error status yields an empty result list."""
        import httpx

        failing = MagicMock()
        failing.status_code = 500
        failing.text = "Internal Server Error"
        failing.raise_for_status.side_effect = httpx.HTTPStatusError(
            "500", request=MagicMock(), response=failing
        )
        self._install_client(svc, return_value=failing)

        assert await svc.search("test") == []

    @pytest.mark.asyncio
    async def test_search_connection_error_returns_empty(self, svc):
        """A connection failure yields an empty result list."""
        import httpx

        self._install_client(svc, side_effect=httpx.ConnectError("Connection refused"))

        assert await svc.search("test") == []

    @pytest.mark.asyncio
    async def test_search_unexpected_format_returns_empty(self, svc):
        """A body in an unrecognised shape yields an empty result list."""
        self._install_client(svc, return_value=self._json_response({"error": "something"}))

        assert await svc.search("test") == []
class TestHttpRAGServiceIngest:
    """HttpRAGService.ingest — writing a document into a knowledge base."""

    @pytest.mark.asyncio
    async def test_ingest_success(self):
        """A successful response is returned and the document is POSTed to /bases/{kb_id}/documents."""
        svc = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )
        mock_resp = MagicMock()
        mock_resp.status_code = 201
        mock_resp.raise_for_status = MagicMock()
        mock_resp.json.return_value = {"id": "doc-1", "status": "processing"}
        mock_client = AsyncMock()
        mock_client.post = AsyncMock(return_value=mock_resp)
        svc._get_client = MagicMock(return_value=mock_client)

        result = await svc.ingest("测试文档", "文档内容")

        assert result["id"] == "doc-1"
        # Verify endpoint and payload
        call_args = mock_client.post.call_args
        assert call_args[0][0] == "/bases/kb-1/documents"
        payload = call_args[1]["json"]
        assert payload["title"] == "测试文档"
        assert payload["content"] == "文档内容"

    @pytest.mark.asyncio
    async def test_ingest_no_kb_ids_returns_none(self):
        """With no knowledge_base_ids configured, ingest returns None without sending any request."""
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        # Mock the client. Without this, a stray real POST to localhost could fail
        # and be swallowed into None (test_ingest_http_error_returns_none shows that
        # ingest maps HTTP errors to None). The test could then not distinguish the
        # "no KB configured" guard from a network failure.
        mock_client = AsyncMock()
        mock_client.post = AsyncMock()
        svc._get_client = MagicMock(return_value=mock_client)

        result = await svc.ingest("test", "content")

        assert result is None
        mock_client.post.assert_not_called()

    @pytest.mark.asyncio
    async def test_ingest_http_error_returns_none(self):
        """An HTTPStatusError raised by raise_for_status is reported as None."""
        import httpx

        svc = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )
        mock_resp = MagicMock()
        mock_resp.status_code = 500
        mock_resp.raise_for_status.side_effect = httpx.HTTPStatusError(
            "500", request=MagicMock(), response=mock_resp
        )
        mock_client = AsyncMock()
        mock_client.post = AsyncMock(return_value=mock_resp)
        svc._get_client = MagicMock(return_value=mock_client)

        result = await svc.ingest("test", "content")

        assert result is None
class TestHttpRAGServiceHealthCheck:
    """HttpRAGService.health_check — reachability probe."""

    BASE_URL = "http://localhost:8000/api/knowledge"

    @classmethod
    def _service_with_get(cls, **get_behaviour):
        """Return a service whose client ``get`` is an AsyncMock built from *get_behaviour*."""
        service = HttpRAGService(base_url=cls.BASE_URL)
        client = AsyncMock()
        client.get = AsyncMock(**get_behaviour)
        service._get_client = MagicMock(return_value=client)
        return service

    @staticmethod
    def _status(code):
        """A fake response carrying only a status code."""
        response = MagicMock()
        response.status_code = code
        return response

    @pytest.mark.asyncio
    async def test_health_check_ok(self):
        service = self._service_with_get(return_value=self._status(200))
        assert await service.health_check() is True

    @pytest.mark.asyncio
    async def test_health_check_401_still_healthy(self):
        """A 401 means the service is up and merely requires authentication."""
        service = self._service_with_get(return_value=self._status(401))
        assert await service.health_check() is True

    @pytest.mark.asyncio
    async def test_health_check_connection_error(self):
        service = self._service_with_get(side_effect=Exception("Connection refused"))
        assert await service.health_check() is False
class TestHttpRAGServiceClient:
    """HttpRAGService HTTP client lifecycle: lazy creation, reuse and closing."""

    @pytest.mark.asyncio
    async def test_client_lazy_init(self):
        """The client is created on first use and carries the Bearer token."""
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge", api_key="sk-test")
        assert svc._client is None
        client = svc._get_client()
        try:
            assert client is not None
            assert "Bearer sk-test" in str(client.headers.get("Authorization", ""))
        finally:
            # A real client was created here; close it so it does not outlive the test.
            await svc.close()

    @pytest.mark.asyncio
    async def test_client_reuse(self):
        """Repeated _get_client calls return the same client object."""
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        try:
            c1 = svc._get_client()
            c2 = svc._get_client()
            assert c1 is c2
        finally:
            # Release the real client created by _get_client.
            await svc.close()

    @pytest.mark.asyncio
    async def test_close(self):
        """close() releases the client and resets the cached reference."""
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        svc._get_client()  # init client
        await svc.close()
        assert svc._client is None

    @pytest.mark.asyncio
    async def test_context_manager(self):
        """``async with`` closes the client on exit."""
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        async with svc as s:
            s._get_client()
            assert s._client is not None
        assert svc._client is None
# ---------------------------------------------------------------------------
# SemanticMemory + HttpRAGService integration
# ---------------------------------------------------------------------------
class TestSemanticMemoryWithHttpRAG:
    """SemanticMemory retrieving knowledge-base content through HttpRAGService."""

    @pytest.mark.asyncio
    async def test_search_delegates_to_rag_service(self):
        """SemanticMemory.search is forwarded to HttpRAGService.search."""
        rag = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )

        # Canned HTTP layer underneath the real HttpRAGService.search.
        response = MagicMock()
        response.status_code = 200
        response.raise_for_status = MagicMock()
        response.json.return_value = {
            "results": [
                {"chunk_id": "c1", "content": "行业知识", "score": 0.9, "document_id": "d1"},
            ]
        }
        client = AsyncMock()
        client.post = AsyncMock(return_value=response)
        rag._get_client = MagicMock(return_value=client)

        memory = SemanticMemory(rag_service=rag, knowledge_base_ids=["kb-1"])
        items = await memory.search("行业趋势", top_k=3)

        assert len(items) == 1
        hit = items[0]
        assert hit.key == "c1"
        assert hit.value == "行业知识"
        assert hit.score == 0.9
        assert hit.metadata["source"] == "rag"

    @pytest.mark.asyncio
    async def test_search_no_rag_service_returns_empty(self):
        """Without a RAG service, search yields an empty list."""
        assert await SemanticMemory().search("test") == []
class TestMemoryRetrieverWithSemantic:
    """MemoryRetriever wired to a SemanticMemory backed by HttpRAGService."""

    # Default weight MemoryRetriever applies to the semantic layer.
    SEMANTIC_WEIGHT = 0.4

    @pytest.mark.asyncio
    async def test_retriever_queries_semantic_layer(self):
        """MemoryRetriever queries the semantic layer and fuses its results."""
        rag = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )
        fake_response = MagicMock()
        fake_response.status_code = 200
        fake_response.raise_for_status = MagicMock()
        fake_response.json.return_value = {
            "results": [
                {"chunk_id": "c1", "content": "知识库内容", "score": 0.95, "document_id": "d1"},
            ]
        }
        fake_client = AsyncMock()
        fake_client.post = AsyncMock(return_value=fake_response)
        rag._get_client = MagicMock(return_value=fake_client)

        retriever = MemoryRetriever(
            semantic_memory=SemanticMemory(rag_service=rag, knowledge_base_ids=["kb-1"])
        )
        items = await retriever.retrieve("知识查询", top_k=3)

        assert len(items) >= 1
        expected_score = 0.95 * self.SEMANTIC_WEIGHT
        assert items[0].score == pytest.approx(expected_score, abs=0.01)
# ---------------------------------------------------------------------------
# Config-driven integration tests
# ---------------------------------------------------------------------------
class TestServerConfigMemorySemantic:
    """ServerConfig parsing of the ``memory.semantic`` section."""

    def test_from_dict_with_semantic(self):
        from agentkit.server.config import ServerConfig

        data = {
            "memory": {
                "semantic": {
                    "enabled": True,
                    "base_url": "http://geo:8000/api/knowledge",
                    "api_key": "sk-test",
                    "knowledge_base_ids": ["kb-1", "kb-2"],
                    "timeout": 60,
                },
            },
        }
        config = ServerConfig.from_dict(data)
        assert config.memory["semantic"]["enabled"] is True
        assert config.memory["semantic"]["base_url"] == "http://geo:8000/api/knowledge"
        assert config.memory["semantic"]["api_key"] == "sk-test"
        assert config.memory["semantic"]["knowledge_base_ids"] == ["kb-1", "kb-2"]
        assert config.memory["semantic"]["timeout"] == 60

    def test_from_dict_without_memory(self):
        from agentkit.server.config import ServerConfig

        config = ServerConfig.from_dict({})
        assert config.memory == {}

    def test_from_yaml_with_env_var_resolution(self, monkeypatch, tmp_path):
        """from_yaml resolves a plain ``${VAR}`` reference from the environment."""
        import textwrap

        from agentkit.server.config import ServerConfig

        # monkeypatch undoes the env change even if an assertion fails, and
        # tmp_path is a per-test directory managed by pytest, so neither the
        # variable nor the file leaks into other tests.
        monkeypatch.setenv("TEST_GEO_API_KEY", "sk-from-env")
        config_path = tmp_path / "config.yaml"
        config_path.write_text(
            textwrap.dedent(
                """
                memory:
                  semantic:
                    enabled: true
                    base_url: http://geo:8000/api/knowledge
                    api_key: "${TEST_GEO_API_KEY}"
                """
            ),
            encoding="utf-8",
        )

        config = ServerConfig.from_yaml(str(config_path))

        assert config.memory["semantic"]["api_key"] == "sk-from-env"

    def test_from_yaml_with_default_env_var(self, monkeypatch, tmp_path):
        """from_yaml falls back to the default of ``${VAR:-default}`` when VAR is unset."""
        import textwrap

        from agentkit.server.config import ServerConfig

        # Make sure the variable is really absent, whatever the outer environment holds.
        monkeypatch.delenv("NONEXISTENT_KEY", raising=False)
        config_path = tmp_path / "config.yaml"
        config_path.write_text(
            textwrap.dedent(
                """
                memory:
                  semantic:
                    enabled: true
                    base_url: http://geo:8000/api/knowledge
                    api_key: "${NONEXISTENT_KEY:-sk-default}"
                """
            ),
            encoding="utf-8",
        )

        config = ServerConfig.from_yaml(str(config_path))

        assert config.memory["semantic"]["api_key"] == "sk-default"
# ---------------------------------------------------------------------------
# HttpRAGService enhanced_search tests
# ---------------------------------------------------------------------------
class TestHttpRAGServiceEnhancedSearch:
    """HttpRAGService.enhanced_search — per-KB retrieval with rerank/compression flags."""

    @pytest.fixture
    def svc(self):
        return HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            api_key="test-key",
            knowledge_base_ids=["kb-1", "kb-2"],
        )

    @staticmethod
    def _ok(body):
        """Fake 200 response whose ``json()`` returns *body*."""
        resp = MagicMock()
        resp.status_code = 200
        resp.raise_for_status = MagicMock()
        resp.json.return_value = body
        return resp

    @staticmethod
    def _failing(status_code, text):
        """Fake response whose ``raise_for_status`` raises ``httpx.HTTPStatusError``."""
        import httpx

        resp = MagicMock()
        resp.status_code = status_code
        resp.text = text
        resp.raise_for_status.side_effect = httpx.HTTPStatusError(
            str(status_code), request=MagicMock(), response=resp
        )
        return resp

    @staticmethod
    def _wire(service, **post_behaviour):
        """Replace *service*'s client with an AsyncMock whose ``post`` uses *post_behaviour*."""
        client = AsyncMock()
        client.post = AsyncMock(**post_behaviour)
        service._get_client = MagicMock(return_value=client)
        return client

    @pytest.mark.asyncio
    async def test_enhanced_search_single_kb(self, svc):
        """Single KB: the payload carries use_rerank and use_compression."""
        svc._knowledge_base_ids = ["kb-1"]
        body = {
            "results": [
                {"chunk_id": "c1", "content": "AI 趋势", "score": 0.95, "document_id": "d1"},
            ]
        }
        client = self._wire(svc, return_value=self._ok(body))

        results = await svc.enhanced_search("AI 趋势", top_k=5)

        assert len(results) == 1
        hit = results[0]
        assert hit["content"] == "AI 趋势"
        assert hit["score"] == 0.95

        # The KB-specific retrieve endpoint receives the query and both flags.
        args, kwargs = client.post.call_args
        assert args[0] == "/bases/kb-1/retrieve"
        sent = kwargs["json"]
        assert sent["query"] == "AI 趋势"
        assert sent["top_k"] == 5
        assert sent["use_rerank"] is True
        assert sent["use_compression"] is False

    @pytest.mark.asyncio
    async def test_enhanced_search_multiple_kbs(self, svc):
        """Several KBs: results are merged and ordered by descending score."""
        from_kb1 = self._ok({
            "results": [
                {"chunk_id": "c1", "content": "KB1 结果", "score": 0.8, "document_id": "d1"},
            ]
        })
        from_kb2 = self._ok({
            "results": [
                {"chunk_id": "c2", "content": "KB2 结果", "score": 0.95, "document_id": "d2"},
            ]
        })
        client = self._wire(svc, side_effect=[from_kb1, from_kb2])

        results = await svc.enhanced_search("test query", top_k=5)

        assert len(results) == 2
        ranked = [(r["content"], r["score"]) for r in results]
        assert ranked == [("KB2 结果", 0.95), ("KB1 结果", 0.8)]

        # Each configured KB was queried, in configuration order.
        endpoints = [c[0][0] for c in client.post.call_args_list[:2]]
        assert endpoints == ["/bases/kb-1/retrieve", "/bases/kb-2/retrieve"]

    @pytest.mark.asyncio
    async def test_enhanced_search_404_fallback(self, svc):
        """A 404 from the retrieve endpoint falls back to the standard search()."""
        self._wire(svc, return_value=self._failing(404, "Not Found"))
        svc.search = AsyncMock(
            return_value=[{"id": "fallback", "content": "fallback result", "score": 0.5}]
        )

        results = await svc.enhanced_search("test query")

        svc.search.assert_called_once_with("test query", knowledge_base_ids=["kb-1", "kb-2"], top_k=5)
        assert len(results) == 1
        assert results[0]["id"] == "fallback"

    @pytest.mark.asyncio
    async def test_enhanced_search_http_error(self, svc):
        """A non-404 HTTP error yields an empty list."""
        self._wire(svc, return_value=self._failing(500, "Internal Server Error"))

        assert await svc.enhanced_search("test query") == []

    @pytest.mark.asyncio
    async def test_enhanced_search_with_compression(self, svc):
        """use_compression=True is forwarded in the payload."""
        svc._knowledge_base_ids = ["kb-1"]
        client = self._wire(svc, return_value=self._ok({"results": []}))

        await svc.enhanced_search("test", use_compression=True)

        assert client.post.call_args[1]["json"]["use_compression"] is True

    @pytest.mark.asyncio
    async def test_enhanced_search_without_rerank(self, svc):
        """use_rerank=False is forwarded in the payload."""
        svc._knowledge_base_ids = ["kb-1"]
        client = self._wire(svc, return_value=self._ok({"results": []}))

        await svc.enhanced_search("test", use_rerank=False)

        assert client.post.call_args[1]["json"]["use_rerank"] is False
# ---------------------------------------------------------------------------
# SemanticMemory enhanced search mode tests
# ---------------------------------------------------------------------------
class TestSemanticMemoryEnhancedSearch:
    """SemanticMemory ``search_mode``: choosing between standard and enhanced retrieval."""

    @staticmethod
    def _rag():
        """An HttpRAGService aimed at the local knowledge API with a single KB."""
        return HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )

    @pytest.mark.asyncio
    async def test_search_mode_enhanced(self):
        """search_mode="enhanced" dispatches to enhanced_search with the configured flags."""
        rag = self._rag()
        rag.enhanced_search = AsyncMock(
            return_value=[
                {"id": "c1", "content": "enhanced result", "score": 0.9, "source": "rag", "document_id": "d1"},
            ]
        )
        semantic = SemanticMemory(
            rag_service=rag,
            knowledge_base_ids=["kb-1"],
            search_mode="enhanced",
            use_rerank=True,
            use_compression=False,
        )

        items = await semantic.search("test query", top_k=3)

        rag.enhanced_search.assert_called_once_with(
            "test query",
            knowledge_base_ids=["kb-1"],
            top_k=3,
            use_rerank=True,
            use_compression=False,
        )
        assert len(items) == 1
        assert items[0].value == "enhanced result"

    @pytest.mark.asyncio
    async def test_search_mode_standard(self):
        """search_mode="standard" uses the plain search() path."""
        rag = self._rag()
        response = MagicMock()
        response.status_code = 200
        response.raise_for_status = MagicMock()
        response.json.return_value = {
            "results": [
                {"chunk_id": "c1", "content": "standard result", "score": 0.8, "document_id": "d1"},
            ]
        }
        client = AsyncMock()
        client.post = AsyncMock(return_value=response)
        rag._get_client = MagicMock(return_value=client)
        semantic = SemanticMemory(
            rag_service=rag,
            knowledge_base_ids=["kb-1"],
            search_mode="standard",
        )

        items = await semantic.search("test query", top_k=3)

        assert len(items) == 1
        assert items[0].value == "standard result"
        # The /search endpoint is hit, not /bases/{kb_id}/retrieve.
        endpoint = client.post.call_args[0][0]
        assert endpoint == "/search"

    @pytest.mark.asyncio
    async def test_search_mode_enhanced_fallback(self):
        """search_mode="enhanced" falls back to search() when the service lacks enhanced_search."""

        class PlainRAGService:
            """Minimal RAG service that exposes only search()."""

            async def search(self, query, knowledge_base_ids=None, top_k=5):
                return [{"id": "c1", "content": "simple result", "score": 0.7, "source": "rag", "document_id": "d1"}]

        semantic = SemanticMemory(
            rag_service=PlainRAGService(),
            knowledge_base_ids=["kb-1"],
            search_mode="enhanced",
        )

        items = await semantic.search("test query", top_k=3)

        assert len(items) == 1
        assert items[0].value == "simple result"
# ---------------------------------------------------------------------------
# Config enhanced search tests
# ---------------------------------------------------------------------------
class TestConfigEnhancedSearch:
    """ServerConfig parsing of the enhanced-search options in ``memory.semantic``."""

    @staticmethod
    def _parse_semantic(**extra):
        """Parse a config whose memory.semantic has the common fields plus *extra*; return that section."""
        from agentkit.server.config import ServerConfig

        semantic = {
            "enabled": True,
            "base_url": "http://geo:8000/api/knowledge",
            "api_key": "sk-test",
            "knowledge_base_ids": ["kb-1"],
            **extra,
        }
        return ServerConfig.from_dict({"memory": {"semantic": semantic}}).memory["semantic"]

    def test_config_search_mode(self):
        section = self._parse_semantic(search_mode="enhanced")
        assert section["search_mode"] == "enhanced"

    def test_config_use_rerank(self):
        section = self._parse_semantic(use_rerank=False, use_compression=True)
        assert section["use_rerank"] is False
        assert section["use_compression"] is True