791 lines
27 KiB
Python
791 lines
27 KiB
Python
"""Tests for HttpRAGService — HTTP 客户端调用业务系统知识库 API"""
|
|
|
|
import json
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
from agentkit.memory.http_rag import HttpRAGService
|
|
from agentkit.memory.semantic import SemanticMemory
|
|
from agentkit.memory.retriever import MemoryRetriever
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# HttpRAGService unit tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestHttpRAGServiceInit:
    """HttpRAGService construction and the attribute values it records."""

    def test_basic_init(self):
        url = "http://localhost:8000/api/knowledge"
        service = HttpRAGService(base_url=url)

        # Only base_url is supplied; every other attribute takes its default.
        assert service._base_url == url
        assert service._api_key is None
        assert service._knowledge_base_ids == []
        assert service._timeout == 30

    def test_init_with_all_params(self):
        service = HttpRAGService(
            base_url="http://geo:8000/api/knowledge/",
            api_key="sk-test",
            knowledge_base_ids=["kb-1", "kb-2"],
            timeout=60,
        )

        expected = {
            "_base_url": "http://geo:8000/api/knowledge",  # trailing slash removed
            "_api_key": "sk-test",
            "_knowledge_base_ids": ["kb-1", "kb-2"],
            "_timeout": 60,
        }
        for attr_name, wanted in expected.items():
            assert getattr(service, attr_name) == wanted, attr_name

    def test_trailing_slash_stripped(self):
        service = HttpRAGService(base_url="http://host/api/")
        assert service._base_url == "http://host/api"
|
|
|
|
|
|
class TestHttpRAGServiceSearch:
    """HttpRAGService.search — semantic retrieval through the /search endpoint."""

    @pytest.fixture
    def svc(self):
        return HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            api_key="test-key",
            knowledge_base_ids=["kb-industry", "kb-enterprise"],
        )

    @staticmethod
    def _ok(body):
        """Build a fake 200 response whose ``.json()`` returns *body*."""
        reply = MagicMock()
        reply.status_code = 200
        reply.raise_for_status = MagicMock()
        reply.json.return_value = body
        return reply

    @staticmethod
    def _wire(service, **post_kwargs):
        """Replace the service's HTTP client with an AsyncMock; return that mock."""
        client = AsyncMock()
        client.post = AsyncMock(**post_kwargs)
        service._get_client = MagicMock(return_value=client)
        return client

    @pytest.mark.asyncio
    async def test_search_standard_response(self, svc):
        """Standard SearchResponse shape: {"results": [...]}."""
        body = {
            "results": [
                {
                    "chunk_id": "c1",
                    "content": "AI 行业趋势分析",
                    "score": 0.92,
                    "document_id": "d1",
                    "document_title": "行业报告",
                    "metadata": {"page": 1},
                },
                {
                    "chunk_id": "c2",
                    "content": "企业数字化转型",
                    "score": 0.85,
                    "document_id": "d2",
                    "document_title": "企业案例",
                    "metadata": {},
                },
            ],
            "total": 2,
            "latency_ms": 50,
        }
        client = self._wire(svc, return_value=self._ok(body))

        results = await svc.search("AI 趋势", top_k=5)

        assert len(results) == 2
        first, second = results
        assert first["id"] == "c1"
        assert first["content"] == "AI 行业趋势分析"
        assert first["score"] == 0.92
        assert first["document_id"] == "d1"
        assert second["content"] == "企业数字化转型"

        # The request goes to /search with the default KB ids and the given top_k.
        args, kwargs = client.post.call_args
        assert args[0] == "/search"
        sent = kwargs["json"]
        assert sent["query"] == "AI 趋势"
        assert sent["knowledge_base_ids"] == ["kb-industry", "kb-enterprise"]
        assert sent["top_k"] == 5

    @pytest.mark.asyncio
    async def test_search_list_response(self, svc):
        """A bare JSON list response: [...]."""
        self._wire(
            svc,
            return_value=self._ok([{"chunk_id": "c1", "content": "test", "score": 0.8}]),
        )

        results = await svc.search("test")

        assert len(results) == 1
        assert results[0]["content"] == "test"

    @pytest.mark.asyncio
    async def test_search_custom_kb_ids(self, svc):
        """Explicit knowledge_base_ids override the configured defaults."""
        client = self._wire(svc, return_value=self._ok({"results": []}))

        await svc.search("test", knowledge_base_ids=["custom-kb"])

        assert client.post.call_args[1]["json"]["knowledge_base_ids"] == ["custom-kb"]

    @pytest.mark.asyncio
    async def test_search_http_error_returns_empty(self, svc):
        """An HTTP error status yields an empty result list."""
        import httpx

        failing = MagicMock()
        failing.status_code = 500
        failing.text = "Internal Server Error"
        failing.raise_for_status.side_effect = httpx.HTTPStatusError(
            "500", request=MagicMock(), response=failing
        )
        self._wire(svc, return_value=failing)

        assert await svc.search("test") == []

    @pytest.mark.asyncio
    async def test_search_connection_error_returns_empty(self, svc):
        """A connection failure yields an empty result list."""
        import httpx

        self._wire(svc, side_effect=httpx.ConnectError("Connection refused"))

        assert await svc.search("test") == []

    @pytest.mark.asyncio
    async def test_search_unexpected_format_returns_empty(self, svc):
        """An unrecognised response body yields an empty result list."""
        self._wire(svc, return_value=self._ok({"error": "something"}))

        assert await svc.search("test") == []
|
|
|
|
|
|
class TestHttpRAGServiceIngest:
    """HttpRAGService.ingest — writing a document into a knowledge base."""

    BASE_URL = "http://localhost:8000/api/knowledge"

    @staticmethod
    def _install_client(service, response):
        """Make every POST from *service* return *response*; return the fake client."""
        client = AsyncMock()
        client.post = AsyncMock(return_value=response)
        service._get_client = MagicMock(return_value=client)
        return client

    @pytest.mark.asyncio
    async def test_ingest_success(self):
        service = HttpRAGService(base_url=self.BASE_URL, knowledge_base_ids=["kb-1"])

        created = MagicMock()
        created.status_code = 201
        created.raise_for_status = MagicMock()
        created.json.return_value = {"id": "doc-1", "status": "processing"}
        client = self._install_client(service, created)

        result = await service.ingest("测试文档", "文档内容")
        assert result["id"] == "doc-1"

        # The document is posted to the configured knowledge base's documents route.
        endpoint = client.post.call_args[0][0]
        body = client.post.call_args[1]["json"]
        assert endpoint == "/bases/kb-1/documents"
        assert body["title"] == "测试文档"
        assert body["content"] == "文档内容"

    @pytest.mark.asyncio
    async def test_ingest_no_kb_ids_returns_none(self):
        service = HttpRAGService(base_url=self.BASE_URL)
        assert await service.ingest("test", "content") is None

    @pytest.mark.asyncio
    async def test_ingest_http_error_returns_none(self):
        import httpx

        service = HttpRAGService(base_url=self.BASE_URL, knowledge_base_ids=["kb-1"])

        broken = MagicMock()
        broken.status_code = 500
        broken.raise_for_status.side_effect = httpx.HTTPStatusError(
            "500", request=MagicMock(), response=broken
        )
        self._install_client(service, broken)

        assert await service.ingest("test", "content") is None
|
|
|
|
|
|
class TestHttpRAGServiceHealthCheck:
    """HttpRAGService.health_check."""

    @staticmethod
    def _service_with_get(**get_kwargs):
        """Build a service whose client GET is an AsyncMock configured by *get_kwargs*."""
        service = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        client = AsyncMock()
        client.get = AsyncMock(**get_kwargs)
        service._get_client = MagicMock(return_value=client)
        return service

    @staticmethod
    def _status(code):
        """Fake response carrying only a status code."""
        reply = MagicMock()
        reply.status_code = code
        return reply

    @pytest.mark.asyncio
    async def test_health_check_ok(self):
        service = self._service_with_get(return_value=self._status(200))
        assert await service.health_check() is True

    @pytest.mark.asyncio
    async def test_health_check_401_still_healthy(self):
        """A 401 means the service is running and merely requires authentication."""
        service = self._service_with_get(return_value=self._status(401))
        assert await service.health_check() is True

    @pytest.mark.asyncio
    async def test_health_check_connection_error(self):
        service = self._service_with_get(side_effect=Exception("Connection refused"))
        assert await service.health_check() is False
|
|
|
|
|
|
class TestHttpRAGServiceClient:
    """HttpRAGService HTTP client lifecycle management.

    Every test that materialises a real client releases it again, either via
    ``async with`` or an explicit ``close()``, so no client outlives its test.
    """

    @pytest.mark.asyncio
    async def test_client_lazy_init(self):
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge", api_key="sk-test")
        # No client exists until one is first requested.
        assert svc._client is None

        async with svc:
            client = svc._get_client()
            assert client is not None
            # The API key is sent as a bearer token on the shared client.
            assert "Bearer sk-test" in str(client.headers.get("Authorization", ""))

    @pytest.mark.asyncio
    async def test_client_reuse(self):
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        async with svc:
            first = svc._get_client()
            second = svc._get_client()
            assert first is second

    @pytest.mark.asyncio
    async def test_close(self):
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        svc._get_client()  # materialise the client
        # Guard: without a live client the post-close assertion would be vacuous.
        assert svc._client is not None

        await svc.close()

        assert svc._client is None

    @pytest.mark.asyncio
    async def test_context_manager(self):
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        async with svc as s:
            s._get_client()
            assert s._client is not None
        # Leaving the context releases the client.
        assert svc._client is None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# SemanticMemory + HttpRAGService integration
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSemanticMemoryWithHttpRAG:
    """SemanticMemory querying a knowledge base through HttpRAGService."""

    @pytest.mark.asyncio
    async def test_search_delegates_to_rag_service(self):
        """SemanticMemory.search forwards the query to HttpRAGService.search."""
        rag = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )

        # Stub the transport so the real HttpRAGService.search runs on a canned reply.
        reply = MagicMock()
        reply.status_code = 200
        reply.raise_for_status = MagicMock()
        reply.json.return_value = {
            "results": [
                {"chunk_id": "c1", "content": "行业知识", "score": 0.9, "document_id": "d1"},
            ]
        }
        transport = AsyncMock()
        transport.post = AsyncMock(return_value=reply)
        rag._get_client = MagicMock(return_value=transport)

        semantic = SemanticMemory(rag_service=rag, knowledge_base_ids=["kb-1"])
        items = await semantic.search("行业趋势", top_k=3)

        assert len(items) == 1
        (item,) = items
        assert item.key == "c1"
        assert item.value == "行业知识"
        assert item.score == 0.9
        assert item.metadata["source"] == "rag"

    @pytest.mark.asyncio
    async def test_search_no_rag_service_returns_empty(self):
        """Without a RAG service the search yields an empty list."""
        assert await SemanticMemory().search("test") == []
|
|
|
|
|
|
class TestMemoryRetrieverWithSemantic:
    """MemoryRetriever fusing results from a SemanticMemory backed by HttpRAGService."""

    @pytest.mark.asyncio
    async def test_retriever_queries_semantic_layer(self):
        """MemoryRetriever queries the semantic layer and fuses its results."""
        rag = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )

        canned = MagicMock()
        canned.status_code = 200
        canned.raise_for_status = MagicMock()
        canned.json.return_value = {
            "results": [
                {"chunk_id": "c1", "content": "知识库内容", "score": 0.95, "document_id": "d1"},
            ]
        }
        http = AsyncMock()
        http.post = AsyncMock(return_value=canned)
        rag._get_client = MagicMock(return_value=http)

        retriever = MemoryRetriever(
            semantic_memory=SemanticMemory(rag_service=rag, knowledge_base_ids=["kb-1"]),
        )
        items = await retriever.retrieve("知识查询", top_k=3)

        assert len(items) >= 1
        # Semantic weight is 0.4 by default, so the fused score is 0.95 * 0.4.
        assert items[0].score == pytest.approx(0.95 * 0.4, abs=0.01)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Config-driven integration tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestServerConfigMemorySemantic:
    """ServerConfig parsing of the ``memory.semantic`` section.

    The YAML tests use pytest's ``monkeypatch`` and ``tmp_path`` fixtures so that
    environment changes are undone and temporary files are cleaned up even when
    an assertion fails.
    """

    def test_from_dict_with_semantic(self):
        from agentkit.server.config import ServerConfig

        data = {
            "memory": {
                "semantic": {
                    "enabled": True,
                    "base_url": "http://geo:8000/api/knowledge",
                    "api_key": "sk-test",
                    "knowledge_base_ids": ["kb-1", "kb-2"],
                    "timeout": 60,
                },
            },
        }
        config = ServerConfig.from_dict(data)

        semantic = config.memory["semantic"]
        assert semantic["enabled"] is True
        assert semantic["base_url"] == "http://geo:8000/api/knowledge"
        assert semantic["api_key"] == "sk-test"
        assert semantic["knowledge_base_ids"] == ["kb-1", "kb-2"]
        assert semantic["timeout"] == 60

    def test_from_dict_without_memory(self):
        from agentkit.server.config import ServerConfig

        config = ServerConfig.from_dict({})
        assert config.memory == {}

    def test_from_yaml_with_env_var_resolution(self, monkeypatch, tmp_path):
        """from_yaml substitutes a plain ``${VAR}`` reference with the variable's value."""
        from agentkit.server.config import ServerConfig

        # monkeypatch restores the environment even if the assertion below fails.
        monkeypatch.setenv("TEST_GEO_API_KEY", "sk-from-env")

        config_path = tmp_path / "config.yaml"
        config_path.write_text(
            "memory:\n"
            "  semantic:\n"
            "    enabled: true\n"
            "    base_url: http://geo:8000/api/knowledge\n"
            '    api_key: "${TEST_GEO_API_KEY}"\n',
            encoding="utf-8",
        )

        # The file is closed before parsing, which also keeps this portable to Windows.
        config = ServerConfig.from_yaml(str(config_path))

        assert config.memory["semantic"]["api_key"] == "sk-from-env"

    def test_from_yaml_with_default_env_var(self, monkeypatch, tmp_path):
        """from_yaml falls back to the default in ``${VAR:-default}`` when VAR is unset."""
        from agentkit.server.config import ServerConfig

        # Guarantee the variable is really absent, whatever the host environment holds.
        monkeypatch.delenv("NONEXISTENT_KEY", raising=False)

        config_path = tmp_path / "config.yaml"
        config_path.write_text(
            "memory:\n"
            "  semantic:\n"
            "    enabled: true\n"
            "    base_url: http://geo:8000/api/knowledge\n"
            '    api_key: "${NONEXISTENT_KEY:-sk-default}"\n',
            encoding="utf-8",
        )

        config = ServerConfig.from_yaml(str(config_path))

        assert config.memory["semantic"]["api_key"] == "sk-default"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# HttpRAGService enhanced_search tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestHttpRAGServiceEnhancedSearch:
    """HttpRAGService.enhanced_search — per-knowledge-base retrieval with rerank/compression."""

    @pytest.fixture
    def svc(self):
        return HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            api_key="test-key",
            knowledge_base_ids=["kb-1", "kb-2"],
        )

    @staticmethod
    def _ok(body):
        """Fake 200 response whose ``.json()`` returns *body*."""
        reply = MagicMock()
        reply.status_code = 200
        reply.raise_for_status = MagicMock()
        reply.json.return_value = body
        return reply

    @staticmethod
    def _failing(status, text):
        """Fake response whose ``raise_for_status()`` raises HTTPStatusError."""
        import httpx

        reply = MagicMock()
        reply.status_code = status
        reply.text = text
        reply.raise_for_status.side_effect = httpx.HTTPStatusError(
            str(status), request=MagicMock(), response=reply
        )
        return reply

    @staticmethod
    def _wire(service, **post_kwargs):
        """Replace the service's HTTP client with an AsyncMock; return that mock."""
        client = AsyncMock()
        client.post = AsyncMock(**post_kwargs)
        service._get_client = MagicMock(return_value=client)
        return client

    @pytest.mark.asyncio
    async def test_enhanced_search_single_kb(self, svc):
        """Single KB: the payload carries use_rerank and use_compression."""
        svc._knowledge_base_ids = ["kb-1"]
        client = self._wire(
            svc,
            return_value=self._ok({
                "results": [
                    {"chunk_id": "c1", "content": "AI 趋势", "score": 0.95, "document_id": "d1"},
                ]
            }),
        )

        results = await svc.enhanced_search("AI 趋势", top_k=5)

        assert len(results) == 1
        hit = results[0]
        assert hit["content"] == "AI 趋势"
        assert hit["score"] == 0.95

        args, kwargs = client.post.call_args
        assert args[0] == "/bases/kb-1/retrieve"
        sent = kwargs["json"]
        assert sent["query"] == "AI 趋势"
        assert sent["top_k"] == 5
        # With no flags passed: reranking on, compression off.
        assert sent["use_rerank"] is True
        assert sent["use_compression"] is False

    @pytest.mark.asyncio
    async def test_enhanced_search_multiple_kbs(self, svc):
        """Multiple KBs: results are merged and sorted by descending score."""
        from_kb1 = self._ok({
            "results": [
                {"chunk_id": "c1", "content": "KB1 结果", "score": 0.8, "document_id": "d1"},
            ]
        })
        # The second KB's hit scores higher, so it must come first after merging.
        from_kb2 = self._ok({
            "results": [
                {"chunk_id": "c2", "content": "KB2 结果", "score": 0.95, "document_id": "d2"},
            ]
        })
        client = self._wire(svc, side_effect=[from_kb1, from_kb2])

        results = await svc.enhanced_search("test query", top_k=5)

        assert [(r["content"], r["score"]) for r in results] == [
            ("KB2 结果", 0.95),
            ("KB1 结果", 0.8),
        ]

        # One retrieve call per configured knowledge base, in configuration order.
        endpoints = [c[0][0] for c in client.post.call_args_list]
        assert endpoints[:2] == ["/bases/kb-1/retrieve", "/bases/kb-2/retrieve"]

    @pytest.mark.asyncio
    async def test_enhanced_search_404_fallback(self, svc):
        """A 404 from the retrieve endpoint falls back to the standard search()."""
        self._wire(svc, return_value=self._failing(404, "Not Found"))
        svc.search = AsyncMock(
            return_value=[{"id": "fallback", "content": "fallback result", "score": 0.5}]
        )

        results = await svc.enhanced_search("test query")

        svc.search.assert_called_once_with(
            "test query", knowledge_base_ids=["kb-1", "kb-2"], top_k=5
        )
        assert len(results) == 1
        assert results[0]["id"] == "fallback"

    @pytest.mark.asyncio
    async def test_enhanced_search_http_error(self, svc):
        """A non-404 HTTP error yields an empty result list."""
        self._wire(svc, return_value=self._failing(500, "Internal Server Error"))

        assert await svc.enhanced_search("test query") == []

    @pytest.mark.asyncio
    async def test_enhanced_search_with_compression(self, svc):
        """use_compression=True is forwarded in the payload."""
        svc._knowledge_base_ids = ["kb-1"]
        client = self._wire(svc, return_value=self._ok({"results": []}))

        await svc.enhanced_search("test", use_compression=True)

        assert client.post.call_args.kwargs["json"]["use_compression"] is True

    @pytest.mark.asyncio
    async def test_enhanced_search_without_rerank(self, svc):
        """use_rerank=False is forwarded in the payload."""
        svc._knowledge_base_ids = ["kb-1"]
        client = self._wire(svc, return_value=self._ok({"results": []}))

        await svc.enhanced_search("test", use_rerank=False)

        assert client.post.call_args.kwargs["json"]["use_rerank"] is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# SemanticMemory enhanced search mode tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSemanticMemoryEnhancedSearch:
    """SemanticMemory search_mode — choosing between enhanced and standard retrieval."""

    @staticmethod
    def _rag():
        """Fresh HttpRAGService bound to a single knowledge base."""
        return HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )

    @pytest.mark.asyncio
    async def test_search_mode_enhanced(self):
        """search_mode="enhanced" routes the query through enhanced_search."""
        rag = self._rag()
        rag.enhanced_search = AsyncMock(return_value=[
            {"id": "c1", "content": "enhanced result", "score": 0.9, "source": "rag", "document_id": "d1"},
        ])
        semantic = SemanticMemory(
            rag_service=rag,
            knowledge_base_ids=["kb-1"],
            search_mode="enhanced",
            use_rerank=True,
            use_compression=False,
        )

        items = await semantic.search("test query", top_k=3)

        # The rerank/compression settings are passed through unchanged.
        rag.enhanced_search.assert_called_once_with(
            "test query",
            knowledge_base_ids=["kb-1"],
            top_k=3,
            use_rerank=True,
            use_compression=False,
        )
        assert len(items) == 1
        assert items[0].value == "enhanced result"

    @pytest.mark.asyncio
    async def test_search_mode_standard(self):
        """search_mode="standard" routes the query through the plain search()."""
        rag = self._rag()

        reply = MagicMock()
        reply.status_code = 200
        reply.raise_for_status = MagicMock()
        reply.json.return_value = {
            "results": [
                {"chunk_id": "c1", "content": "standard result", "score": 0.8, "document_id": "d1"},
            ]
        }
        transport = AsyncMock()
        transport.post = AsyncMock(return_value=reply)
        rag._get_client = MagicMock(return_value=transport)

        semantic = SemanticMemory(
            rag_service=rag,
            knowledge_base_ids=["kb-1"],
            search_mode="standard",
        )
        items = await semantic.search("test query", top_k=3)

        assert len(items) == 1
        assert items[0].value == "standard result"
        # The standard /search endpoint is used, not /bases/{kb_id}/retrieve.
        assert transport.post.call_args[0][0] == "/search"

    @pytest.mark.asyncio
    async def test_search_mode_enhanced_fallback(self):
        """search_mode="enhanced" falls back to search() when enhanced_search is absent."""

        class SimpleRAGService:
            """Minimal RAG service that only implements search()."""

            async def search(self, query, knowledge_base_ids=None, top_k=5):
                return [{"id": "c1", "content": "simple result", "score": 0.7, "source": "rag", "document_id": "d1"}]

        semantic = SemanticMemory(
            rag_service=SimpleRAGService(),
            knowledge_base_ids=["kb-1"],
            search_mode="enhanced",
        )

        items = await semantic.search("test query", top_k=3)

        assert len(items) == 1
        assert items[0].value == "simple result"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Config enhanced search tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestConfigEnhancedSearch:
    """ServerConfig parsing of the enhanced-search options under memory.semantic."""

    @staticmethod
    def _parse_semantic(**extra):
        """Parse a config whose memory.semantic section carries *extra* keys.

        Returns the parsed ``memory["semantic"]`` mapping.
        """
        from agentkit.server.config import ServerConfig

        semantic = {
            "enabled": True,
            "base_url": "http://geo:8000/api/knowledge",
            "api_key": "sk-test",
            "knowledge_base_ids": ["kb-1"],
        }
        semantic.update(extra)
        config = ServerConfig.from_dict({"memory": {"semantic": semantic}})
        return config.memory["semantic"]

    def test_config_search_mode(self):
        parsed = self._parse_semantic(search_mode="enhanced")
        assert parsed["search_mode"] == "enhanced"

    def test_config_use_rerank(self):
        parsed = self._parse_semantic(use_rerank=False, use_compression=True)
        assert parsed["use_rerank"] is False
        assert parsed["use_compression"] is True
|