feat(memory): add HttpRAGService for config-driven knowledge base integration
This commit is contained in:
parent
0456429beb
commit
cd5b39087e
|
|
@ -313,9 +313,12 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
||||||
try:
|
try:
|
||||||
from agentkit.memory.retriever import MemoryRetriever
|
from agentkit.memory.retriever import MemoryRetriever
|
||||||
from agentkit.memory.working import WorkingMemory
|
from agentkit.memory.working import WorkingMemory
|
||||||
|
from agentkit.memory.semantic import SemanticMemory
|
||||||
|
from agentkit.memory.http_rag import HttpRAGService
|
||||||
|
|
||||||
working = None
|
working = None
|
||||||
episodic = None
|
episodic = None
|
||||||
|
semantic = None
|
||||||
|
|
||||||
if config.memory.get("working", {}).get("enabled"):
|
if config.memory.get("working", {}).get("enabled"):
|
||||||
import redis.asyncio as aioredis
|
import redis.asyncio as aioredis
|
||||||
|
|
@ -328,9 +331,23 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
||||||
# Will be initialized externally when DB is available
|
# Will be initialized externally when DB is available
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if config.memory.get("semantic", {}).get("enabled"):
|
||||||
|
sem_conf = config.memory["semantic"]
|
||||||
|
rag_service = HttpRAGService(
|
||||||
|
base_url=sem_conf["base_url"],
|
||||||
|
api_key=sem_conf.get("api_key"),
|
||||||
|
knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
|
||||||
|
timeout=sem_conf.get("timeout", 30),
|
||||||
|
)
|
||||||
|
semantic = SemanticMemory(
|
||||||
|
rag_service=rag_service,
|
||||||
|
knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
|
||||||
|
)
|
||||||
|
|
||||||
self._memory_retriever = MemoryRetriever(
|
self._memory_retriever = MemoryRetriever(
|
||||||
working_memory=working,
|
working_memory=working,
|
||||||
episodic_memory=episodic,
|
episodic_memory=episodic,
|
||||||
|
semantic_memory=semantic,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Inject into BaseAgent
|
# Inject into BaseAgent
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ from agentkit.memory.base import Memory, MemoryItem
|
||||||
from agentkit.memory.working import WorkingMemory
|
from agentkit.memory.working import WorkingMemory
|
||||||
from agentkit.memory.episodic import EpisodicMemory
|
from agentkit.memory.episodic import EpisodicMemory
|
||||||
from agentkit.memory.semantic import SemanticMemory
|
from agentkit.memory.semantic import SemanticMemory
|
||||||
|
from agentkit.memory.http_rag import HttpRAGService
|
||||||
from agentkit.memory.retriever import MemoryRetriever
|
from agentkit.memory.retriever import MemoryRetriever
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
@ -12,5 +13,6 @@ __all__ = [
|
||||||
"WorkingMemory",
|
"WorkingMemory",
|
||||||
"EpisodicMemory",
|
"EpisodicMemory",
|
||||||
"SemanticMemory",
|
"SemanticMemory",
|
||||||
|
"HttpRAGService",
|
||||||
"MemoryRetriever",
|
"MemoryRetriever",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,193 @@
|
||||||
|
"""HTTP RAG Service - 通过 HTTP 调用业务系统知识库 API
|
||||||
|
|
||||||
|
配置驱动,不直接依赖业务系统代码,通过 base_url + api_key 连接。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HttpRAGService:
|
||||||
|
"""HTTP 客户端,调用业务系统的知识库检索 API
|
||||||
|
|
||||||
|
适配任意提供以下接口的知识库服务:
|
||||||
|
- POST {base_url}/search → 语义检索
|
||||||
|
- POST {base_url}/ingest → 文档写入(可选)
|
||||||
|
|
||||||
|
典型配置(agentkit.yaml)::
|
||||||
|
|
||||||
|
memory:
|
||||||
|
semantic:
|
||||||
|
enabled: true
|
||||||
|
base_url: "http://localhost:8000/api/knowledge"
|
||||||
|
api_key: "${GEO_API_KEY}"
|
||||||
|
knowledge_base_ids:
|
||||||
|
- "industry-kb-id"
|
||||||
|
- "enterprise-kb-id"
|
||||||
|
timeout: 30
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_url: str,
|
||||||
|
api_key: str | None = None,
|
||||||
|
knowledge_base_ids: list[str] | None = None,
|
||||||
|
timeout: int = 30,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
base_url: 知识库 API 基础地址,如 http://localhost:8000/api/knowledge
|
||||||
|
api_key: 认证 API Key(放在 Authorization: Bearer 头)
|
||||||
|
knowledge_base_ids: 默认检索的知识库 ID 列表
|
||||||
|
timeout: HTTP 请求超时秒数
|
||||||
|
"""
|
||||||
|
self._base_url = base_url.rstrip("/")
|
||||||
|
self._api_key = api_key
|
||||||
|
self._knowledge_base_ids = knowledge_base_ids or []
|
||||||
|
self._timeout = timeout
|
||||||
|
self._client: httpx.AsyncClient | None = None
|
||||||
|
|
||||||
|
def _get_client(self) -> httpx.AsyncClient:
|
||||||
|
"""懒初始化 httpx 客户端"""
|
||||||
|
if self._client is None or self._client.is_closed:
|
||||||
|
headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||||
|
if self._api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {self._api_key}"
|
||||||
|
self._client = httpx.AsyncClient(
|
||||||
|
base_url=self._base_url,
|
||||||
|
headers=headers,
|
||||||
|
timeout=self._timeout,
|
||||||
|
)
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
knowledge_base_ids: list[str] | None = None,
|
||||||
|
top_k: int = 5,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""语义检索知识库
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 检索查询
|
||||||
|
knowledge_base_ids: 知识库 ID 列表(默认使用配置值)
|
||||||
|
top_k: 返回结果数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
检索结果列表,每项包含 content/score/document_id 等字段
|
||||||
|
"""
|
||||||
|
kb_ids = knowledge_base_ids or self._knowledge_base_ids
|
||||||
|
payload = {
|
||||||
|
"query": query,
|
||||||
|
"knowledge_base_ids": kb_ids,
|
||||||
|
"top_k": top_k,
|
||||||
|
}
|
||||||
|
|
||||||
|
client = self._get_client()
|
||||||
|
try:
|
||||||
|
resp = await client.post("/search", json=payload)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
|
||||||
|
# 兼容两种响应格式:
|
||||||
|
# 1. {"results": [...]} — GEO 标准 SearchResponse
|
||||||
|
# 2. [...] — 直接返回列表
|
||||||
|
if isinstance(data, dict) and "results" in data:
|
||||||
|
results = data["results"]
|
||||||
|
elif isinstance(data, list):
|
||||||
|
results = data
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unexpected search response format: {type(data)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 标准化为 SemanticMemory 期望的格式
|
||||||
|
normalized = []
|
||||||
|
for r in results:
|
||||||
|
if isinstance(r, dict):
|
||||||
|
normalized.append({
|
||||||
|
"id": r.get("chunk_id", r.get("id", "")),
|
||||||
|
"content": r.get("content", ""),
|
||||||
|
"score": float(r.get("score", 0.0)),
|
||||||
|
"source": r.get("source", "rag"),
|
||||||
|
"document_id": r.get("document_id", ""),
|
||||||
|
"document_title": r.get("document_title", ""),
|
||||||
|
"metadata": r.get("metadata", {}),
|
||||||
|
})
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
logger.error(f"RAG search HTTP error: {e.response.status_code} — {e.response.text[:200]}")
|
||||||
|
return []
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logger.error(f"RAG search request error: {e}")
|
||||||
|
return []
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"RAG search unexpected error: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def ingest(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
value: Any,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""写入文档到知识库(可选操作)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: 文档标题或标识
|
||||||
|
value: 文档内容
|
||||||
|
metadata: 额外元数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
写入结果,或 None 表示写入不可用
|
||||||
|
"""
|
||||||
|
kb_ids = self._knowledge_base_ids
|
||||||
|
if not kb_ids:
|
||||||
|
logger.warning("HttpRAGService.ingest: no knowledge_base_ids configured")
|
||||||
|
return None
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"title": key,
|
||||||
|
"content": str(value),
|
||||||
|
"source_type": "text",
|
||||||
|
"metadata": metadata or {},
|
||||||
|
}
|
||||||
|
|
||||||
|
client = self._get_client()
|
||||||
|
try:
|
||||||
|
# 写入到第一个配置的知识库
|
||||||
|
kb_id = kb_ids[0]
|
||||||
|
resp = await client.post(f"/bases/{kb_id}/documents", json=payload)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
logger.error(f"RAG ingest HTTP error: {e.response.status_code}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"RAG ingest error: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def health_check(self) -> bool:
|
||||||
|
"""检查知识库服务是否可用"""
|
||||||
|
client = self._get_client()
|
||||||
|
try:
|
||||||
|
resp = await client.get("/bases")
|
||||||
|
return resp.status_code in (200, 401) # 401 = 服务在但需认证
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""关闭 HTTP 客户端"""
|
||||||
|
if self._client and not self._client.is_closed:
|
||||||
|
await self._client.aclose()
|
||||||
|
self._client = None
|
||||||
|
|
||||||
|
async def __aenter__(self) -> "HttpRAGService":
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *args: Any) -> None:
|
||||||
|
await self.close()
|
||||||
|
|
@ -165,15 +165,37 @@ def create_app(
|
||||||
try:
|
try:
|
||||||
from agentkit.memory.retriever import MemoryRetriever
|
from agentkit.memory.retriever import MemoryRetriever
|
||||||
from agentkit.memory.working import WorkingMemory
|
from agentkit.memory.working import WorkingMemory
|
||||||
|
from agentkit.memory.semantic import SemanticMemory
|
||||||
|
from agentkit.memory.http_rag import HttpRAGService
|
||||||
|
|
||||||
working = None
|
working = None
|
||||||
|
episodic = None
|
||||||
|
semantic = None
|
||||||
|
|
||||||
if server_config.memory.get("working", {}).get("enabled"):
|
if server_config.memory.get("working", {}).get("enabled"):
|
||||||
import redis.asyncio as aioredis
|
import redis.asyncio as aioredis
|
||||||
redis_url = server_config.memory["working"].get("redis_url", "redis://localhost:6379")
|
redis_url = server_config.memory["working"].get("redis_url", "redis://localhost:6379")
|
||||||
redis_client = aioredis.from_url(redis_url, decode_responses=True)
|
redis_client = aioredis.from_url(redis_url, decode_responses=True)
|
||||||
working = WorkingMemory(redis=redis_client)
|
working = WorkingMemory(redis=redis_client)
|
||||||
|
|
||||||
memory_retriever = MemoryRetriever(working_memory=working)
|
if server_config.memory.get("semantic", {}).get("enabled"):
|
||||||
|
sem_conf = server_config.memory["semantic"]
|
||||||
|
rag_service = HttpRAGService(
|
||||||
|
base_url=sem_conf["base_url"],
|
||||||
|
api_key=sem_conf.get("api_key"),
|
||||||
|
knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
|
||||||
|
timeout=sem_conf.get("timeout", 30),
|
||||||
|
)
|
||||||
|
semantic = SemanticMemory(
|
||||||
|
rag_service=rag_service,
|
||||||
|
knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
|
||||||
|
)
|
||||||
|
|
||||||
|
memory_retriever = MemoryRetriever(
|
||||||
|
working_memory=working,
|
||||||
|
episodic_memory=episodic,
|
||||||
|
semantic_memory=semantic,
|
||||||
|
)
|
||||||
app.state.memory_retriever = memory_retriever
|
app.state.memory_retriever = memory_retriever
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import logging
|
import logging
|
||||||
|
|
|
||||||
|
|
@ -62,6 +62,7 @@ class ServerConfig:
|
||||||
log_format: str = "text",
|
log_format: str = "text",
|
||||||
task_store: dict[str, Any] | None = None,
|
task_store: dict[str, Any] | None = None,
|
||||||
cors_origins: list[str] | None = None,
|
cors_origins: list[str] | None = None,
|
||||||
|
memory: dict[str, Any] | None = None,
|
||||||
):
|
):
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
|
|
@ -75,6 +76,7 @@ class ServerConfig:
|
||||||
self.log_format = log_format
|
self.log_format = log_format
|
||||||
self.task_store = task_store or {}
|
self.task_store = task_store or {}
|
||||||
self.cors_origins = cors_origins or ["*"]
|
self.cors_origins = cors_origins or ["*"]
|
||||||
|
self.memory = memory or {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_yaml(cls, path: str) -> "ServerConfig":
|
def from_yaml(cls, path: str) -> "ServerConfig":
|
||||||
|
|
@ -95,6 +97,7 @@ class ServerConfig:
|
||||||
skills_data = data.get("skills", {})
|
skills_data = data.get("skills", {})
|
||||||
logging_data = data.get("logging", {})
|
logging_data = data.get("logging", {})
|
||||||
task_store_data = data.get("task_store", {})
|
task_store_data = data.get("task_store", {})
|
||||||
|
memory_data = data.get("memory", {})
|
||||||
|
|
||||||
# Build LLMConfig
|
# Build LLMConfig
|
||||||
llm_config = cls._build_llm_config(llm_data)
|
llm_config = cls._build_llm_config(llm_data)
|
||||||
|
|
@ -116,6 +119,7 @@ class ServerConfig:
|
||||||
log_format=logging_data.get("format", "text"),
|
log_format=logging_data.get("format", "text"),
|
||||||
task_store=task_store_data,
|
task_store=task_store_data,
|
||||||
cors_origins=server.get("cors_origins"),
|
cors_origins=server.get("cors_origins"),
|
||||||
|
memory=memory_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,472 @@
|
||||||
|
"""Tests for HttpRAGService — HTTP 客户端调用业务系统知识库 API"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from agentkit.memory.http_rag import HttpRAGService
|
||||||
|
from agentkit.memory.semantic import SemanticMemory
|
||||||
|
from agentkit.memory.retriever import MemoryRetriever
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# HttpRAGService unit tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestHttpRAGServiceInit:
|
||||||
|
"""HttpRAGService 初始化"""
|
||||||
|
|
||||||
|
def test_basic_init(self):
|
||||||
|
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
|
||||||
|
assert svc._base_url == "http://localhost:8000/api/knowledge"
|
||||||
|
assert svc._api_key is None
|
||||||
|
assert svc._knowledge_base_ids == []
|
||||||
|
assert svc._timeout == 30
|
||||||
|
|
||||||
|
def test_init_with_all_params(self):
|
||||||
|
svc = HttpRAGService(
|
||||||
|
base_url="http://geo:8000/api/knowledge/",
|
||||||
|
api_key="sk-test",
|
||||||
|
knowledge_base_ids=["kb-1", "kb-2"],
|
||||||
|
timeout=60,
|
||||||
|
)
|
||||||
|
assert svc._base_url == "http://geo:8000/api/knowledge" # trailing slash stripped
|
||||||
|
assert svc._api_key == "sk-test"
|
||||||
|
assert svc._knowledge_base_ids == ["kb-1", "kb-2"]
|
||||||
|
assert svc._timeout == 60
|
||||||
|
|
||||||
|
def test_trailing_slash_stripped(self):
|
||||||
|
svc = HttpRAGService(base_url="http://host/api/")
|
||||||
|
assert svc._base_url == "http://host/api"
|
||||||
|
|
||||||
|
|
||||||
|
class TestHttpRAGServiceSearch:
|
||||||
|
"""HttpRAGService.search — 语义检索"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def svc(self):
|
||||||
|
return HttpRAGService(
|
||||||
|
base_url="http://localhost:8000/api/knowledge",
|
||||||
|
api_key="test-key",
|
||||||
|
knowledge_base_ids=["kb-industry", "kb-enterprise"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_standard_response(self, svc):
|
||||||
|
"""标准 SearchResponse 格式: {"results": [...]}"""
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status_code = 200
|
||||||
|
mock_resp.raise_for_status = MagicMock()
|
||||||
|
mock_resp.json.return_value = {
|
||||||
|
"results": [
|
||||||
|
{
|
||||||
|
"chunk_id": "c1",
|
||||||
|
"content": "AI 行业趋势分析",
|
||||||
|
"score": 0.92,
|
||||||
|
"document_id": "d1",
|
||||||
|
"document_title": "行业报告",
|
||||||
|
"metadata": {"page": 1},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"chunk_id": "c2",
|
||||||
|
"content": "企业数字化转型",
|
||||||
|
"score": 0.85,
|
||||||
|
"document_id": "d2",
|
||||||
|
"document_title": "企业案例",
|
||||||
|
"metadata": {},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"total": 2,
|
||||||
|
"latency_ms": 50,
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||||
|
svc._get_client = MagicMock(return_value=mock_client)
|
||||||
|
|
||||||
|
results = await svc.search("AI 趋势", top_k=5)
|
||||||
|
|
||||||
|
assert len(results) == 2
|
||||||
|
assert results[0]["id"] == "c1"
|
||||||
|
assert results[0]["content"] == "AI 行业趋势分析"
|
||||||
|
assert results[0]["score"] == 0.92
|
||||||
|
assert results[0]["document_id"] == "d1"
|
||||||
|
assert results[1]["content"] == "企业数字化转型"
|
||||||
|
|
||||||
|
# Verify payload
|
||||||
|
call_args = mock_client.post.call_args
|
||||||
|
assert call_args[0][0] == "/search"
|
||||||
|
payload = call_args[1]["json"]
|
||||||
|
assert payload["query"] == "AI 趋势"
|
||||||
|
assert payload["knowledge_base_ids"] == ["kb-industry", "kb-enterprise"]
|
||||||
|
assert payload["top_k"] == 5
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_list_response(self, svc):
|
||||||
|
"""直接返回列表格式: [...]"""
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status_code = 200
|
||||||
|
mock_resp.raise_for_status = MagicMock()
|
||||||
|
mock_resp.json.return_value = [
|
||||||
|
{"chunk_id": "c1", "content": "test", "score": 0.8},
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||||
|
svc._get_client = MagicMock(return_value=mock_client)
|
||||||
|
|
||||||
|
results = await svc.search("test")
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0]["content"] == "test"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_custom_kb_ids(self, svc):
|
||||||
|
"""传入自定义 knowledge_base_ids 覆盖默认值"""
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status_code = 200
|
||||||
|
mock_resp.raise_for_status = MagicMock()
|
||||||
|
mock_resp.json.return_value = {"results": []}
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||||
|
svc._get_client = MagicMock(return_value=mock_client)
|
||||||
|
|
||||||
|
await svc.search("test", knowledge_base_ids=["custom-kb"])
|
||||||
|
|
||||||
|
payload = mock_client.post.call_args[1]["json"]
|
||||||
|
assert payload["knowledge_base_ids"] == ["custom-kb"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_http_error_returns_empty(self, svc):
|
||||||
|
"""HTTP 错误返回空列表"""
|
||||||
|
import httpx
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status_code = 500
|
||||||
|
mock_resp.text = "Internal Server Error"
|
||||||
|
mock_resp.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||||
|
"500", request=MagicMock(), response=mock_resp
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||||
|
svc._get_client = MagicMock(return_value=mock_client)
|
||||||
|
|
||||||
|
results = await svc.search("test")
|
||||||
|
assert results == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_connection_error_returns_empty(self, svc):
|
||||||
|
"""连接错误返回空列表"""
|
||||||
|
import httpx
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(side_effect=httpx.ConnectError("Connection refused"))
|
||||||
|
svc._get_client = MagicMock(return_value=mock_client)
|
||||||
|
|
||||||
|
results = await svc.search("test")
|
||||||
|
assert results == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_unexpected_format_returns_empty(self, svc):
|
||||||
|
"""非预期响应格式返回空列表"""
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status_code = 200
|
||||||
|
mock_resp.raise_for_status = MagicMock()
|
||||||
|
mock_resp.json.return_value = {"error": "something"}
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||||
|
svc._get_client = MagicMock(return_value=mock_client)
|
||||||
|
|
||||||
|
results = await svc.search("test")
|
||||||
|
assert results == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestHttpRAGServiceIngest:
|
||||||
|
"""HttpRAGService.ingest — 文档写入"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ingest_success(self):
|
||||||
|
svc = HttpRAGService(
|
||||||
|
base_url="http://localhost:8000/api/knowledge",
|
||||||
|
knowledge_base_ids=["kb-1"],
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status_code = 201
|
||||||
|
mock_resp.raise_for_status = MagicMock()
|
||||||
|
mock_resp.json.return_value = {"id": "doc-1", "status": "processing"}
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||||
|
svc._get_client = MagicMock(return_value=mock_client)
|
||||||
|
|
||||||
|
result = await svc.ingest("测试文档", "文档内容")
|
||||||
|
assert result["id"] == "doc-1"
|
||||||
|
|
||||||
|
# Verify endpoint and payload
|
||||||
|
call_args = mock_client.post.call_args
|
||||||
|
assert call_args[0][0] == "/bases/kb-1/documents"
|
||||||
|
payload = call_args[1]["json"]
|
||||||
|
assert payload["title"] == "测试文档"
|
||||||
|
assert payload["content"] == "文档内容"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ingest_no_kb_ids_returns_none(self):
|
||||||
|
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
|
||||||
|
result = await svc.ingest("test", "content")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ingest_http_error_returns_none(self):
|
||||||
|
import httpx
|
||||||
|
svc = HttpRAGService(
|
||||||
|
base_url="http://localhost:8000/api/knowledge",
|
||||||
|
knowledge_base_ids=["kb-1"],
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status_code = 500
|
||||||
|
mock_resp.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||||
|
"500", request=MagicMock(), response=mock_resp
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||||
|
svc._get_client = MagicMock(return_value=mock_client)
|
||||||
|
|
||||||
|
result = await svc.ingest("test", "content")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestHttpRAGServiceHealthCheck:
|
||||||
|
"""HttpRAGService.health_check"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health_check_ok(self):
|
||||||
|
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
|
||||||
|
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status_code = 200
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.get = AsyncMock(return_value=mock_resp)
|
||||||
|
svc._get_client = MagicMock(return_value=mock_client)
|
||||||
|
|
||||||
|
assert await svc.health_check() is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health_check_401_still_healthy(self):
|
||||||
|
"""401 表示服务在运行,只是需要认证"""
|
||||||
|
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
|
||||||
|
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status_code = 401
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.get = AsyncMock(return_value=mock_resp)
|
||||||
|
svc._get_client = MagicMock(return_value=mock_client)
|
||||||
|
|
||||||
|
assert await svc.health_check() is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health_check_connection_error(self):
|
||||||
|
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.get = AsyncMock(side_effect=Exception("Connection refused"))
|
||||||
|
svc._get_client = MagicMock(return_value=mock_client)
|
||||||
|
|
||||||
|
assert await svc.health_check() is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestHttpRAGServiceClient:
|
||||||
|
"""HttpRAGService HTTP 客户端管理"""
|
||||||
|
|
||||||
|
def test_client_lazy_init(self):
|
||||||
|
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge", api_key="sk-test")
|
||||||
|
assert svc._client is None
|
||||||
|
|
||||||
|
client = svc._get_client()
|
||||||
|
assert client is not None
|
||||||
|
assert "Bearer sk-test" in str(client.headers.get("Authorization", ""))
|
||||||
|
|
||||||
|
def test_client_reuse(self):
|
||||||
|
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
|
||||||
|
c1 = svc._get_client()
|
||||||
|
c2 = svc._get_client()
|
||||||
|
assert c1 is c2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close(self):
|
||||||
|
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
|
||||||
|
svc._get_client() # init client
|
||||||
|
await svc.close()
|
||||||
|
assert svc._client is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_manager(self):
|
||||||
|
svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
|
||||||
|
async with svc as s:
|
||||||
|
s._get_client()
|
||||||
|
assert s._client is not None
|
||||||
|
assert svc._client is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# SemanticMemory + HttpRAGService integration
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSemanticMemoryWithHttpRAG:
|
||||||
|
"""SemanticMemory 通过 HttpRAGService 检索知识库"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_delegates_to_rag_service(self):
|
||||||
|
"""SemanticMemory.search 委托给 HttpRAGService.search"""
|
||||||
|
rag = HttpRAGService(
|
||||||
|
base_url="http://localhost:8000/api/knowledge",
|
||||||
|
knowledge_base_ids=["kb-1"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock the search
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status_code = 200
|
||||||
|
mock_resp.raise_for_status = MagicMock()
|
||||||
|
mock_resp.json.return_value = {
|
||||||
|
"results": [
|
||||||
|
{"chunk_id": "c1", "content": "行业知识", "score": 0.9, "document_id": "d1"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||||
|
rag._get_client = MagicMock(return_value=mock_client)
|
||||||
|
|
||||||
|
semantic = SemanticMemory(rag_service=rag, knowledge_base_ids=["kb-1"])
|
||||||
|
items = await semantic.search("行业趋势", top_k=3)
|
||||||
|
|
||||||
|
assert len(items) == 1
|
||||||
|
assert items[0].key == "c1"
|
||||||
|
assert items[0].value == "行业知识"
|
||||||
|
assert items[0].score == 0.9
|
||||||
|
assert items[0].metadata["source"] == "rag"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_no_rag_service_returns_empty(self):
|
||||||
|
"""无 RAG 服务时返回空列表"""
|
||||||
|
semantic = SemanticMemory()
|
||||||
|
items = await semantic.search("test")
|
||||||
|
assert items == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryRetrieverWithSemantic:
|
||||||
|
"""MemoryRetriever 集成 SemanticMemory + HttpRAGService"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retriever_queries_semantic_layer(self):
|
||||||
|
"""MemoryRetriever 查询 Semantic 层并融合结果"""
|
||||||
|
rag = HttpRAGService(
|
||||||
|
base_url="http://localhost:8000/api/knowledge",
|
||||||
|
knowledge_base_ids=["kb-1"],
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status_code = 200
|
||||||
|
mock_resp.raise_for_status = MagicMock()
|
||||||
|
mock_resp.json.return_value = {
|
||||||
|
"results": [
|
||||||
|
{"chunk_id": "c1", "content": "知识库内容", "score": 0.95, "document_id": "d1"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||||
|
rag._get_client = MagicMock(return_value=mock_client)
|
||||||
|
|
||||||
|
semantic = SemanticMemory(rag_service=rag, knowledge_base_ids=["kb-1"])
|
||||||
|
retriever = MemoryRetriever(semantic_memory=semantic)
|
||||||
|
|
||||||
|
items = await retriever.retrieve("知识查询", top_k=3)
|
||||||
|
|
||||||
|
assert len(items) >= 1
|
||||||
|
# Semantic weight is 0.4 by default
|
||||||
|
assert items[0].score == pytest.approx(0.95 * 0.4, abs=0.01)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Config-driven integration tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestServerConfigMemorySemantic:
|
||||||
|
"""ServerConfig 解析 memory.semantic 配置"""
|
||||||
|
|
||||||
|
def test_from_dict_with_semantic(self):
|
||||||
|
from agentkit.server.config import ServerConfig
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"memory": {
|
||||||
|
"semantic": {
|
||||||
|
"enabled": True,
|
||||||
|
"base_url": "http://geo:8000/api/knowledge",
|
||||||
|
"api_key": "sk-test",
|
||||||
|
"knowledge_base_ids": ["kb-1", "kb-2"],
|
||||||
|
"timeout": 60,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
config = ServerConfig.from_dict(data)
|
||||||
|
assert config.memory["semantic"]["enabled"] is True
|
||||||
|
assert config.memory["semantic"]["base_url"] == "http://geo:8000/api/knowledge"
|
||||||
|
assert config.memory["semantic"]["api_key"] == "sk-test"
|
||||||
|
assert config.memory["semantic"]["knowledge_base_ids"] == ["kb-1", "kb-2"]
|
||||||
|
assert config.memory["semantic"]["timeout"] == 60
|
||||||
|
|
||||||
|
def test_from_dict_without_memory(self):
|
||||||
|
from agentkit.server.config import ServerConfig
|
||||||
|
|
||||||
|
config = ServerConfig.from_dict({})
|
||||||
|
assert config.memory == {}
|
||||||
|
|
||||||
|
def test_from_yaml_with_env_var_resolution(self):
|
||||||
|
"""验证 from_yaml 路径的 ${VAR:-default} 环境变量解析"""
|
||||||
|
from agentkit.server.config import ServerConfig
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
os.environ["TEST_GEO_API_KEY"] = "sk-from-env"
|
||||||
|
|
||||||
|
yaml_content = """
|
||||||
|
memory:
|
||||||
|
semantic:
|
||||||
|
enabled: true
|
||||||
|
base_url: http://geo:8000/api/knowledge
|
||||||
|
api_key: "${TEST_GEO_API_KEY}"
|
||||||
|
"""
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||||
|
f.write(yaml_content)
|
||||||
|
f.flush()
|
||||||
|
config = ServerConfig.from_yaml(f.name)
|
||||||
|
|
||||||
|
assert config.memory["semantic"]["api_key"] == "sk-from-env"
|
||||||
|
del os.environ["TEST_GEO_API_KEY"]
|
||||||
|
|
||||||
|
def test_from_yaml_with_default_env_var(self):
|
||||||
|
"""验证 from_yaml 路径的 ${VAR:-default} 带默认值"""
|
||||||
|
from agentkit.server.config import ServerConfig
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
yaml_content = """
|
||||||
|
memory:
|
||||||
|
semantic:
|
||||||
|
enabled: true
|
||||||
|
base_url: http://geo:8000/api/knowledge
|
||||||
|
api_key: "${NONEXISTENT_KEY:-sk-default}"
|
||||||
|
"""
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||||
|
f.write(yaml_content)
|
||||||
|
f.flush()
|
||||||
|
config = ServerConfig.from_yaml(f.name)
|
||||||
|
|
||||||
|
assert config.memory["semantic"]["api_key"] == "sk-default"
|
||||||
Loading…
Reference in New Issue