feat(memory): add HttpRAGService for config-driven knowledge base integration
This commit is contained in:
parent
0456429beb
commit
cd5b39087e
|
|
@ -313,9 +313,12 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
|||
try:
|
||||
from agentkit.memory.retriever import MemoryRetriever
|
||||
from agentkit.memory.working import WorkingMemory
|
||||
from agentkit.memory.semantic import SemanticMemory
|
||||
from agentkit.memory.http_rag import HttpRAGService
|
||||
|
||||
working = None
|
||||
episodic = None
|
||||
semantic = None
|
||||
|
||||
if config.memory.get("working", {}).get("enabled"):
|
||||
import redis.asyncio as aioredis
|
||||
|
|
@ -328,9 +331,23 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
|||
# Will be initialized externally when DB is available
|
||||
pass
|
||||
|
||||
if config.memory.get("semantic", {}).get("enabled"):
|
||||
sem_conf = config.memory["semantic"]
|
||||
rag_service = HttpRAGService(
|
||||
base_url=sem_conf["base_url"],
|
||||
api_key=sem_conf.get("api_key"),
|
||||
knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
|
||||
timeout=sem_conf.get("timeout", 30),
|
||||
)
|
||||
semantic = SemanticMemory(
|
||||
rag_service=rag_service,
|
||||
knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
|
||||
)
|
||||
|
||||
self._memory_retriever = MemoryRetriever(
|
||||
working_memory=working,
|
||||
episodic_memory=episodic,
|
||||
semantic_memory=semantic,
|
||||
)
|
||||
|
||||
# Inject into BaseAgent
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from agentkit.memory.base import Memory, MemoryItem
|
|||
from agentkit.memory.working import WorkingMemory
|
||||
from agentkit.memory.episodic import EpisodicMemory
|
||||
from agentkit.memory.semantic import SemanticMemory
|
||||
from agentkit.memory.http_rag import HttpRAGService
|
||||
from agentkit.memory.retriever import MemoryRetriever
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -12,5 +13,6 @@ __all__ = [
|
|||
"WorkingMemory",
|
||||
"EpisodicMemory",
|
||||
"SemanticMemory",
|
||||
"HttpRAGService",
|
||||
"MemoryRetriever",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,193 @@
|
|||
"""HTTP RAG Service - 通过 HTTP 调用业务系统知识库 API
|
||||
|
||||
配置驱动,不直接依赖业务系统代码,通过 base_url + api_key 连接。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HttpRAGService:
|
||||
"""HTTP 客户端,调用业务系统的知识库检索 API
|
||||
|
||||
适配任意提供以下接口的知识库服务:
|
||||
- POST {base_url}/search → 语义检索
|
||||
- POST {base_url}/ingest → 文档写入(可选)
|
||||
|
||||
典型配置(agentkit.yaml)::
|
||||
|
||||
memory:
|
||||
semantic:
|
||||
enabled: true
|
||||
base_url: "http://localhost:8000/api/knowledge"
|
||||
api_key: "${GEO_API_KEY}"
|
||||
knowledge_base_ids:
|
||||
- "industry-kb-id"
|
||||
- "enterprise-kb-id"
|
||||
timeout: 30
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
api_key: str | None = None,
|
||||
knowledge_base_ids: list[str] | None = None,
|
||||
timeout: int = 30,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
base_url: 知识库 API 基础地址,如 http://localhost:8000/api/knowledge
|
||||
api_key: 认证 API Key(放在 Authorization: Bearer 头)
|
||||
knowledge_base_ids: 默认检索的知识库 ID 列表
|
||||
timeout: HTTP 请求超时秒数
|
||||
"""
|
||||
self._base_url = base_url.rstrip("/")
|
||||
self._api_key = api_key
|
||||
self._knowledge_base_ids = knowledge_base_ids or []
|
||||
self._timeout = timeout
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
|
||||
def _get_client(self) -> httpx.AsyncClient:
|
||||
"""懒初始化 httpx 客户端"""
|
||||
if self._client is None or self._client.is_closed:
|
||||
headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||
if self._api_key:
|
||||
headers["Authorization"] = f"Bearer {self._api_key}"
|
||||
self._client = httpx.AsyncClient(
|
||||
base_url=self._base_url,
|
||||
headers=headers,
|
||||
timeout=self._timeout,
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
knowledge_base_ids: list[str] | None = None,
|
||||
top_k: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""语义检索知识库
|
||||
|
||||
Args:
|
||||
query: 检索查询
|
||||
knowledge_base_ids: 知识库 ID 列表(默认使用配置值)
|
||||
top_k: 返回结果数量
|
||||
|
||||
Returns:
|
||||
检索结果列表,每项包含 content/score/document_id 等字段
|
||||
"""
|
||||
kb_ids = knowledge_base_ids or self._knowledge_base_ids
|
||||
payload = {
|
||||
"query": query,
|
||||
"knowledge_base_ids": kb_ids,
|
||||
"top_k": top_k,
|
||||
}
|
||||
|
||||
client = self._get_client()
|
||||
try:
|
||||
resp = await client.post("/search", json=payload)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
# 兼容两种响应格式:
|
||||
# 1. {"results": [...]} — GEO 标准 SearchResponse
|
||||
# 2. [...] — 直接返回列表
|
||||
if isinstance(data, dict) and "results" in data:
|
||||
results = data["results"]
|
||||
elif isinstance(data, list):
|
||||
results = data
|
||||
else:
|
||||
logger.warning(f"Unexpected search response format: {type(data)}")
|
||||
return []
|
||||
|
||||
# 标准化为 SemanticMemory 期望的格式
|
||||
normalized = []
|
||||
for r in results:
|
||||
if isinstance(r, dict):
|
||||
normalized.append({
|
||||
"id": r.get("chunk_id", r.get("id", "")),
|
||||
"content": r.get("content", ""),
|
||||
"score": float(r.get("score", 0.0)),
|
||||
"source": r.get("source", "rag"),
|
||||
"document_id": r.get("document_id", ""),
|
||||
"document_title": r.get("document_title", ""),
|
||||
"metadata": r.get("metadata", {}),
|
||||
})
|
||||
return normalized
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"RAG search HTTP error: {e.response.status_code} — {e.response.text[:200]}")
|
||||
return []
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"RAG search request error: {e}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"RAG search unexpected error: {e}")
|
||||
return []
|
||||
|
||||
async def ingest(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""写入文档到知识库(可选操作)
|
||||
|
||||
Args:
|
||||
key: 文档标题或标识
|
||||
value: 文档内容
|
||||
metadata: 额外元数据
|
||||
|
||||
Returns:
|
||||
写入结果,或 None 表示写入不可用
|
||||
"""
|
||||
kb_ids = self._knowledge_base_ids
|
||||
if not kb_ids:
|
||||
logger.warning("HttpRAGService.ingest: no knowledge_base_ids configured")
|
||||
return None
|
||||
|
||||
payload = {
|
||||
"title": key,
|
||||
"content": str(value),
|
||||
"source_type": "text",
|
||||
"metadata": metadata or {},
|
||||
}
|
||||
|
||||
client = self._get_client()
|
||||
try:
|
||||
# 写入到第一个配置的知识库
|
||||
kb_id = kb_ids[0]
|
||||
resp = await client.post(f"/bases/{kb_id}/documents", json=payload)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"RAG ingest HTTP error: {e.response.status_code}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"RAG ingest error: {e}")
|
||||
return None
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""检查知识库服务是否可用"""
|
||||
client = self._get_client()
|
||||
try:
|
||||
resp = await client.get("/bases")
|
||||
return resp.status_code in (200, 401) # 401 = 服务在但需认证
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭 HTTP 客户端"""
|
||||
if self._client and not self._client.is_closed:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
async def __aenter__(self) -> "HttpRAGService":
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args: Any) -> None:
|
||||
await self.close()
|
||||
|
|
@ -165,15 +165,37 @@ def create_app(
|
|||
try:
|
||||
from agentkit.memory.retriever import MemoryRetriever
|
||||
from agentkit.memory.working import WorkingMemory
|
||||
from agentkit.memory.semantic import SemanticMemory
|
||||
from agentkit.memory.http_rag import HttpRAGService
|
||||
|
||||
working = None
|
||||
episodic = None
|
||||
semantic = None
|
||||
|
||||
if server_config.memory.get("working", {}).get("enabled"):
|
||||
import redis.asyncio as aioredis
|
||||
redis_url = server_config.memory["working"].get("redis_url", "redis://localhost:6379")
|
||||
redis_client = aioredis.from_url(redis_url, decode_responses=True)
|
||||
working = WorkingMemory(redis=redis_client)
|
||||
|
||||
memory_retriever = MemoryRetriever(working_memory=working)
|
||||
if server_config.memory.get("semantic", {}).get("enabled"):
|
||||
sem_conf = server_config.memory["semantic"]
|
||||
rag_service = HttpRAGService(
|
||||
base_url=sem_conf["base_url"],
|
||||
api_key=sem_conf.get("api_key"),
|
||||
knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
|
||||
timeout=sem_conf.get("timeout", 30),
|
||||
)
|
||||
semantic = SemanticMemory(
|
||||
rag_service=rag_service,
|
||||
knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
|
||||
)
|
||||
|
||||
memory_retriever = MemoryRetriever(
|
||||
working_memory=working,
|
||||
episodic_memory=episodic,
|
||||
semantic_memory=semantic,
|
||||
)
|
||||
app.state.memory_retriever = memory_retriever
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
|
|
|||
|
|
@ -62,6 +62,7 @@ class ServerConfig:
|
|||
log_format: str = "text",
|
||||
task_store: dict[str, Any] | None = None,
|
||||
cors_origins: list[str] | None = None,
|
||||
memory: dict[str, Any] | None = None,
|
||||
):
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
|
@ -75,6 +76,7 @@ class ServerConfig:
|
|||
self.log_format = log_format
|
||||
self.task_store = task_store or {}
|
||||
self.cors_origins = cors_origins or ["*"]
|
||||
self.memory = memory or {}
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, path: str) -> "ServerConfig":
|
||||
|
|
@ -95,6 +97,7 @@ class ServerConfig:
|
|||
skills_data = data.get("skills", {})
|
||||
logging_data = data.get("logging", {})
|
||||
task_store_data = data.get("task_store", {})
|
||||
memory_data = data.get("memory", {})
|
||||
|
||||
# Build LLMConfig
|
||||
llm_config = cls._build_llm_config(llm_data)
|
||||
|
|
@ -116,6 +119,7 @@ class ServerConfig:
|
|||
log_format=logging_data.get("format", "text"),
|
||||
task_store=task_store_data,
|
||||
cors_origins=server.get("cors_origins"),
|
||||
memory=memory_data,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -0,0 +1,472 @@
|
|||
"""Tests for HttpRAGService — HTTP 客户端调用业务系统知识库 API"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from agentkit.memory.http_rag import HttpRAGService
|
||||
from agentkit.memory.semantic import SemanticMemory
|
||||
from agentkit.memory.retriever import MemoryRetriever
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HttpRAGService unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHttpRAGServiceInit:
|
||||
"""HttpRAGService 初始化"""
|
||||
|
||||
def test_basic_init(self):
|
||||
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
|
||||
assert svc._base_url == "http://localhost:8000/api/knowledge"
|
||||
assert svc._api_key is None
|
||||
assert svc._knowledge_base_ids == []
|
||||
assert svc._timeout == 30
|
||||
|
||||
def test_init_with_all_params(self):
|
||||
svc = HttpRAGService(
|
||||
base_url="http://geo:8000/api/knowledge/",
|
||||
api_key="sk-test",
|
||||
knowledge_base_ids=["kb-1", "kb-2"],
|
||||
timeout=60,
|
||||
)
|
||||
assert svc._base_url == "http://geo:8000/api/knowledge" # trailing slash stripped
|
||||
assert svc._api_key == "sk-test"
|
||||
assert svc._knowledge_base_ids == ["kb-1", "kb-2"]
|
||||
assert svc._timeout == 60
|
||||
|
||||
def test_trailing_slash_stripped(self):
|
||||
svc = HttpRAGService(base_url="http://host/api/")
|
||||
assert svc._base_url == "http://host/api"
|
||||
|
||||
|
||||
class TestHttpRAGServiceSearch:
|
||||
"""HttpRAGService.search — 语义检索"""
|
||||
|
||||
@pytest.fixture
|
||||
def svc(self):
|
||||
return HttpRAGService(
|
||||
base_url="http://localhost:8000/api/knowledge",
|
||||
api_key="test-key",
|
||||
knowledge_base_ids=["kb-industry", "kb-enterprise"],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_standard_response(self, svc):
|
||||
"""标准 SearchResponse 格式: {"results": [...]}"""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
mock_resp.json.return_value = {
|
||||
"results": [
|
||||
{
|
||||
"chunk_id": "c1",
|
||||
"content": "AI 行业趋势分析",
|
||||
"score": 0.92,
|
||||
"document_id": "d1",
|
||||
"document_title": "行业报告",
|
||||
"metadata": {"page": 1},
|
||||
},
|
||||
{
|
||||
"chunk_id": "c2",
|
||||
"content": "企业数字化转型",
|
||||
"score": 0.85,
|
||||
"document_id": "d2",
|
||||
"document_title": "企业案例",
|
||||
"metadata": {},
|
||||
},
|
||||
],
|
||||
"total": 2,
|
||||
"latency_ms": 50,
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
svc._get_client = MagicMock(return_value=mock_client)
|
||||
|
||||
results = await svc.search("AI 趋势", top_k=5)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]["id"] == "c1"
|
||||
assert results[0]["content"] == "AI 行业趋势分析"
|
||||
assert results[0]["score"] == 0.92
|
||||
assert results[0]["document_id"] == "d1"
|
||||
assert results[1]["content"] == "企业数字化转型"
|
||||
|
||||
# Verify payload
|
||||
call_args = mock_client.post.call_args
|
||||
assert call_args[0][0] == "/search"
|
||||
payload = call_args[1]["json"]
|
||||
assert payload["query"] == "AI 趋势"
|
||||
assert payload["knowledge_base_ids"] == ["kb-industry", "kb-enterprise"]
|
||||
assert payload["top_k"] == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_list_response(self, svc):
|
||||
"""直接返回列表格式: [...]"""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
mock_resp.json.return_value = [
|
||||
{"chunk_id": "c1", "content": "test", "score": 0.8},
|
||||
]
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
svc._get_client = MagicMock(return_value=mock_client)
|
||||
|
||||
results = await svc.search("test")
|
||||
assert len(results) == 1
|
||||
assert results[0]["content"] == "test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_custom_kb_ids(self, svc):
|
||||
"""传入自定义 knowledge_base_ids 覆盖默认值"""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
mock_resp.json.return_value = {"results": []}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
svc._get_client = MagicMock(return_value=mock_client)
|
||||
|
||||
await svc.search("test", knowledge_base_ids=["custom-kb"])
|
||||
|
||||
payload = mock_client.post.call_args[1]["json"]
|
||||
assert payload["knowledge_base_ids"] == ["custom-kb"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_http_error_returns_empty(self, svc):
|
||||
"""HTTP 错误返回空列表"""
|
||||
import httpx
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 500
|
||||
mock_resp.text = "Internal Server Error"
|
||||
mock_resp.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
"500", request=MagicMock(), response=mock_resp
|
||||
)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
svc._get_client = MagicMock(return_value=mock_client)
|
||||
|
||||
results = await svc.search("test")
|
||||
assert results == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_connection_error_returns_empty(self, svc):
|
||||
"""连接错误返回空列表"""
|
||||
import httpx
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(side_effect=httpx.ConnectError("Connection refused"))
|
||||
svc._get_client = MagicMock(return_value=mock_client)
|
||||
|
||||
results = await svc.search("test")
|
||||
assert results == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_unexpected_format_returns_empty(self, svc):
|
||||
"""非预期响应格式返回空列表"""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
mock_resp.json.return_value = {"error": "something"}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
svc._get_client = MagicMock(return_value=mock_client)
|
||||
|
||||
results = await svc.search("test")
|
||||
assert results == []
|
||||
|
||||
|
||||
class TestHttpRAGServiceIngest:
|
||||
"""HttpRAGService.ingest — 文档写入"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ingest_success(self):
|
||||
svc = HttpRAGService(
|
||||
base_url="http://localhost:8000/api/knowledge",
|
||||
knowledge_base_ids=["kb-1"],
|
||||
)
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 201
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
mock_resp.json.return_value = {"id": "doc-1", "status": "processing"}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
svc._get_client = MagicMock(return_value=mock_client)
|
||||
|
||||
result = await svc.ingest("测试文档", "文档内容")
|
||||
assert result["id"] == "doc-1"
|
||||
|
||||
# Verify endpoint and payload
|
||||
call_args = mock_client.post.call_args
|
||||
assert call_args[0][0] == "/bases/kb-1/documents"
|
||||
payload = call_args[1]["json"]
|
||||
assert payload["title"] == "测试文档"
|
||||
assert payload["content"] == "文档内容"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ingest_no_kb_ids_returns_none(self):
|
||||
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
|
||||
result = await svc.ingest("test", "content")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ingest_http_error_returns_none(self):
|
||||
import httpx
|
||||
svc = HttpRAGService(
|
||||
base_url="http://localhost:8000/api/knowledge",
|
||||
knowledge_base_ids=["kb-1"],
|
||||
)
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 500
|
||||
mock_resp.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
"500", request=MagicMock(), response=mock_resp
|
||||
)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
svc._get_client = MagicMock(return_value=mock_client)
|
||||
|
||||
result = await svc.ingest("test", "content")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestHttpRAGServiceHealthCheck:
|
||||
"""HttpRAGService.health_check"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_ok(self):
|
||||
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_resp)
|
||||
svc._get_client = MagicMock(return_value=mock_client)
|
||||
|
||||
assert await svc.health_check() is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_401_still_healthy(self):
|
||||
"""401 表示服务在运行,只是需要认证"""
|
||||
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 401
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_resp)
|
||||
svc._get_client = MagicMock(return_value=mock_client)
|
||||
|
||||
assert await svc.health_check() is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_connection_error(self):
|
||||
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(side_effect=Exception("Connection refused"))
|
||||
svc._get_client = MagicMock(return_value=mock_client)
|
||||
|
||||
assert await svc.health_check() is False
|
||||
|
||||
|
||||
class TestHttpRAGServiceClient:
|
||||
"""HttpRAGService HTTP 客户端管理"""
|
||||
|
||||
def test_client_lazy_init(self):
|
||||
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge", api_key="sk-test")
|
||||
assert svc._client is None
|
||||
|
||||
client = svc._get_client()
|
||||
assert client is not None
|
||||
assert "Bearer sk-test" in str(client.headers.get("Authorization", ""))
|
||||
|
||||
def test_client_reuse(self):
|
||||
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
|
||||
c1 = svc._get_client()
|
||||
c2 = svc._get_client()
|
||||
assert c1 is c2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close(self):
|
||||
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
|
||||
svc._get_client() # init client
|
||||
await svc.close()
|
||||
assert svc._client is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_manager(self):
|
||||
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
|
||||
async with svc as s:
|
||||
s._get_client()
|
||||
assert s._client is not None
|
||||
assert svc._client is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SemanticMemory + HttpRAGService integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSemanticMemoryWithHttpRAG:
|
||||
"""SemanticMemory 通过 HttpRAGService 检索知识库"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_delegates_to_rag_service(self):
|
||||
"""SemanticMemory.search 委托给 HttpRAGService.search"""
|
||||
rag = HttpRAGService(
|
||||
base_url="http://localhost:8000/api/knowledge",
|
||||
knowledge_base_ids=["kb-1"],
|
||||
)
|
||||
|
||||
# Mock the search
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
mock_resp.json.return_value = {
|
||||
"results": [
|
||||
{"chunk_id": "c1", "content": "行业知识", "score": 0.9, "document_id": "d1"},
|
||||
]
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
rag._get_client = MagicMock(return_value=mock_client)
|
||||
|
||||
semantic = SemanticMemory(rag_service=rag, knowledge_base_ids=["kb-1"])
|
||||
items = await semantic.search("行业趋势", top_k=3)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].key == "c1"
|
||||
assert items[0].value == "行业知识"
|
||||
assert items[0].score == 0.9
|
||||
assert items[0].metadata["source"] == "rag"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_no_rag_service_returns_empty(self):
|
||||
"""无 RAG 服务时返回空列表"""
|
||||
semantic = SemanticMemory()
|
||||
items = await semantic.search("test")
|
||||
assert items == []
|
||||
|
||||
|
||||
class TestMemoryRetrieverWithSemantic:
|
||||
"""MemoryRetriever 集成 SemanticMemory + HttpRAGService"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retriever_queries_semantic_layer(self):
|
||||
"""MemoryRetriever 查询 Semantic 层并融合结果"""
|
||||
rag = HttpRAGService(
|
||||
base_url="http://localhost:8000/api/knowledge",
|
||||
knowledge_base_ids=["kb-1"],
|
||||
)
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
mock_resp.json.return_value = {
|
||||
"results": [
|
||||
{"chunk_id": "c1", "content": "知识库内容", "score": 0.95, "document_id": "d1"},
|
||||
]
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
rag._get_client = MagicMock(return_value=mock_client)
|
||||
|
||||
semantic = SemanticMemory(rag_service=rag, knowledge_base_ids=["kb-1"])
|
||||
retriever = MemoryRetriever(semantic_memory=semantic)
|
||||
|
||||
items = await retriever.retrieve("知识查询", top_k=3)
|
||||
|
||||
assert len(items) >= 1
|
||||
# Semantic weight is 0.4 by default
|
||||
assert items[0].score == pytest.approx(0.95 * 0.4, abs=0.01)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config-driven integration tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestServerConfigMemorySemantic:
|
||||
"""ServerConfig 解析 memory.semantic 配置"""
|
||||
|
||||
def test_from_dict_with_semantic(self):
|
||||
from agentkit.server.config import ServerConfig
|
||||
|
||||
data = {
|
||||
"memory": {
|
||||
"semantic": {
|
||||
"enabled": True,
|
||||
"base_url": "http://geo:8000/api/knowledge",
|
||||
"api_key": "sk-test",
|
||||
"knowledge_base_ids": ["kb-1", "kb-2"],
|
||||
"timeout": 60,
|
||||
},
|
||||
},
|
||||
}
|
||||
config = ServerConfig.from_dict(data)
|
||||
assert config.memory["semantic"]["enabled"] is True
|
||||
assert config.memory["semantic"]["base_url"] == "http://geo:8000/api/knowledge"
|
||||
assert config.memory["semantic"]["api_key"] == "sk-test"
|
||||
assert config.memory["semantic"]["knowledge_base_ids"] == ["kb-1", "kb-2"]
|
||||
assert config.memory["semantic"]["timeout"] == 60
|
||||
|
||||
def test_from_dict_without_memory(self):
|
||||
from agentkit.server.config import ServerConfig
|
||||
|
||||
config = ServerConfig.from_dict({})
|
||||
assert config.memory == {}
|
||||
|
||||
def test_from_yaml_with_env_var_resolution(self):
|
||||
"""验证 from_yaml 路径的 ${VAR:-default} 环境变量解析"""
|
||||
from agentkit.server.config import ServerConfig
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
os.environ["TEST_GEO_API_KEY"] = "sk-from-env"
|
||||
|
||||
yaml_content = """
|
||||
memory:
|
||||
semantic:
|
||||
enabled: true
|
||||
base_url: http://geo:8000/api/knowledge
|
||||
api_key: "${TEST_GEO_API_KEY}"
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||
f.write(yaml_content)
|
||||
f.flush()
|
||||
config = ServerConfig.from_yaml(f.name)
|
||||
|
||||
assert config.memory["semantic"]["api_key"] == "sk-from-env"
|
||||
del os.environ["TEST_GEO_API_KEY"]
|
||||
|
||||
def test_from_yaml_with_default_env_var(self):
|
||||
"""验证 from_yaml 路径的 ${VAR:-default} 带默认值"""
|
||||
from agentkit.server.config import ServerConfig
|
||||
import tempfile
|
||||
|
||||
yaml_content = """
|
||||
memory:
|
||||
semantic:
|
||||
enabled: true
|
||||
base_url: http://geo:8000/api/knowledge
|
||||
api_key: "${NONEXISTENT_KEY:-sk-default}"
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||
f.write(yaml_content)
|
||||
f.flush()
|
||||
config = ServerConfig.from_yaml(f.name)
|
||||
|
||||
assert config.memory["semantic"]["api_key"] == "sk-default"
|
||||
Loading…
Reference in New Issue