feat(memory): add HttpRAGService for config-driven knowledge base integration

This commit is contained in:
chiguyong 2026-06-06 18:36:05 +08:00
parent 0456429beb
commit cd5b39087e
6 changed files with 711 additions and 1 deletions

View File

@ -313,9 +313,12 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
try:
from agentkit.memory.retriever import MemoryRetriever
from agentkit.memory.working import WorkingMemory
from agentkit.memory.semantic import SemanticMemory
from agentkit.memory.http_rag import HttpRAGService
working = None
episodic = None
semantic = None
if config.memory.get("working", {}).get("enabled"):
import redis.asyncio as aioredis
@ -328,9 +331,23 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
# Will be initialized externally when DB is available
pass
if config.memory.get("semantic", {}).get("enabled"):
sem_conf = config.memory["semantic"]
rag_service = HttpRAGService(
base_url=sem_conf["base_url"],
api_key=sem_conf.get("api_key"),
knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
timeout=sem_conf.get("timeout", 30),
)
semantic = SemanticMemory(
rag_service=rag_service,
knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
)
self._memory_retriever = MemoryRetriever(
working_memory=working,
episodic_memory=episodic,
semantic_memory=semantic,
)
# Inject into BaseAgent

View File

@ -4,6 +4,7 @@ from agentkit.memory.base import Memory, MemoryItem
from agentkit.memory.working import WorkingMemory
from agentkit.memory.episodic import EpisodicMemory
from agentkit.memory.semantic import SemanticMemory
from agentkit.memory.http_rag import HttpRAGService
from agentkit.memory.retriever import MemoryRetriever
__all__ = [
@ -12,5 +13,6 @@ __all__ = [
"WorkingMemory",
"EpisodicMemory",
"SemanticMemory",
"HttpRAGService",
"MemoryRetriever",
]

View File

@ -0,0 +1,193 @@
"""HTTP RAG Service - 通过 HTTP 调用业务系统知识库 API
配置驱动不直接依赖业务系统代码通过 base_url + api_key 连接
"""
import logging
from typing import Any
import httpx
logger = logging.getLogger(__name__)
class HttpRAGService:
    """HTTP client for a business system's knowledge-base retrieval API.

    Configuration driven: it does not import any business-system code and
    only needs ``base_url`` (plus an optional ``api_key``) to connect.

    Endpoints called, all relative to ``base_url``:

    - ``POST /search``: semantic retrieval
    - ``POST /bases/{kb_id}/documents``: document ingestion (optional)
    - ``GET /bases``: liveness probe used by :meth:`health_check`

    Typical configuration (agentkit.yaml)::

        memory:
          semantic:
            enabled: true
            base_url: "http://localhost:8000/api/knowledge"
            api_key: "${GEO_API_KEY}"
            knowledge_base_ids:
              - "industry-kb-id"
              - "enterprise-kb-id"
            timeout: 30
    """

    def __init__(
        self,
        base_url: str,
        api_key: str | None = None,
        knowledge_base_ids: list[str] | None = None,
        timeout: float = 30,
    ):
        """Store connection settings; the HTTP client is created lazily.

        Args:
            base_url: Base address of the knowledge-base API, e.g.
                ``http://localhost:8000/api/knowledge``. Trailing slashes
                are stripped.
            api_key: API key sent as ``Authorization: Bearer <key>``. The
                header is omitted when the key is falsy.
            knowledge_base_ids: Knowledge-base IDs used by default.
            timeout: HTTP request timeout in seconds.
        """
        self._base_url = base_url.rstrip("/")
        self._api_key = api_key
        self._knowledge_base_ids = knowledge_base_ids or []
        self._timeout = timeout
        self._client: httpx.AsyncClient | None = None

    def _get_client(self) -> httpx.AsyncClient:
        """Return the shared httpx client, (re)creating it if missing or closed."""
        if self._client is None or self._client.is_closed:
            headers: dict[str, str] = {"Content-Type": "application/json"}
            if self._api_key:
                headers["Authorization"] = f"Bearer {self._api_key}"
            self._client = httpx.AsyncClient(
                base_url=self._base_url,
                headers=headers,
                timeout=self._timeout,
            )
        return self._client

    @staticmethod
    def _coerce_score(value: Any) -> float:
        """Convert a raw score to float, falling back to 0.0 when unusable."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return 0.0

    @classmethod
    def _normalize_result(cls, raw: dict[str, Any]) -> dict[str, Any]:
        """Map one raw search hit onto the normalized result dict."""
        return {
            "id": raw.get("chunk_id", raw.get("id", "")),
            "content": raw.get("content", ""),
            # A null or non-numeric score must not discard the whole result set.
            "score": cls._coerce_score(raw.get("score", 0.0)),
            "source": raw.get("source", "rag"),
            "document_id": raw.get("document_id", ""),
            "document_title": raw.get("document_title", ""),
            # Treat an explicit null the same as a missing key.
            "metadata": raw.get("metadata") or {},
        }

    async def search(
        self,
        query: str,
        knowledge_base_ids: list[str] | None = None,
        top_k: int = 5,
    ) -> list[dict[str, Any]]:
        """Run a semantic search against the knowledge base.

        Never raises for HTTP, transport or parsing failures: those are
        logged and an empty list is returned.

        Args:
            query: Search query text.
            knowledge_base_ids: Knowledge-base IDs to search. ``None`` or an
                empty list falls back to the IDs given at construction.
            top_k: Number of results requested from the service.

        Returns:
            A list of dicts with the keys ``id``, ``content``, ``score``,
            ``source``, ``document_id``, ``document_title`` and ``metadata``.
            Entries of the response that are not dicts are skipped.
        """
        payload = {
            "query": query,
            "knowledge_base_ids": knowledge_base_ids or self._knowledge_base_ids,
            "top_k": top_k,
        }
        client = self._get_client()
        try:
            resp = await client.post("/search", json=payload)
            resp.raise_for_status()
            data = resp.json()

            # Two response shapes are accepted:
            #   1. {"results": [...]}  (GEO standard SearchResponse)
            #   2. [...]               (bare list)
            if isinstance(data, dict) and "results" in data:
                results = data["results"]
            elif isinstance(data, list):
                results = data
            else:
                logger.warning("Unexpected search response format: %s", type(data))
                return []

            if not isinstance(results, list):
                logger.warning("Unexpected search results type: %s", type(results))
                return []

            return [self._normalize_result(r) for r in results if isinstance(r, dict)]
        except httpx.HTTPStatusError as e:
            logger.error(
                "RAG search HTTP error: %s - %s",
                e.response.status_code,
                e.response.text[:200],
            )
            return []
        except httpx.RequestError as e:
            logger.error("RAG search request error: %s", e)
            return []
        except Exception as e:
            # Best-effort boundary: retrieval failures must not break the caller.
            logger.error("RAG search unexpected error: %s", e)
            return []

    async def ingest(
        self,
        key: str,
        value: Any,
        metadata: dict[str, Any] | None = None,
    ) -> dict[str, Any] | None:
        """Write a document into the knowledge base (optional operation).

        The document is posted to the first configured knowledge base.
        Failures are logged and reported as ``None`` rather than raised.

        Args:
            key: Document title or identifier.
            value: Document content; converted with ``str()`` before sending.
            metadata: Extra metadata to attach to the document.

        Returns:
            The decoded JSON response of the service, or ``None`` when no
            knowledge base is configured or the request failed.
        """
        kb_ids = self._knowledge_base_ids
        if not kb_ids:
            logger.warning("HttpRAGService.ingest: no knowledge_base_ids configured")
            return None

        payload = {
            "title": key,
            "content": str(value),
            "source_type": "text",
            "metadata": metadata or {},
        }
        client = self._get_client()
        try:
            # Always write to the first configured knowledge base.
            kb_id = kb_ids[0]
            resp = await client.post(f"/bases/{kb_id}/documents", json=payload)
            resp.raise_for_status()
            return resp.json()
        except httpx.HTTPStatusError as e:
            logger.error("RAG ingest HTTP error: %s", e.response.status_code)
            return None
        except httpx.RequestError as e:
            logger.error("RAG ingest request error: %s", e)
            return None
        except Exception as e:
            logger.error("RAG ingest error: %s", e)
            return None

    async def health_check(self) -> bool:
        """Report whether the knowledge-base service answers ``GET /bases``.

        Returns:
            ``True`` for HTTP 200 or 401 (401 means the service is up but
            requires authentication); ``False`` for any other status or when
            the request raises.
        """
        client = self._get_client()
        try:
            resp = await client.get("/bases")
            return resp.status_code in (200, 401)
        except Exception as e:
            logger.debug("RAG health check failed: %s", e)
            return False

    async def close(self) -> None:
        """Close the HTTP client if one is open and drop the reference."""
        if self._client and not self._client.is_closed:
            await self._client.aclose()
        self._client = None

    async def __aenter__(self) -> "HttpRAGService":
        """Enter the async context; returns the service itself."""
        return self

    async def __aexit__(self, *args: Any) -> None:
        """Leave the async context, closing the HTTP client."""
        await self.close()

View File

@ -165,15 +165,37 @@ def create_app(
try:
from agentkit.memory.retriever import MemoryRetriever
from agentkit.memory.working import WorkingMemory
from agentkit.memory.semantic import SemanticMemory
from agentkit.memory.http_rag import HttpRAGService
working = None
episodic = None
semantic = None
if server_config.memory.get("working", {}).get("enabled"):
import redis.asyncio as aioredis
redis_url = server_config.memory["working"].get("redis_url", "redis://localhost:6379")
redis_client = aioredis.from_url(redis_url, decode_responses=True)
working = WorkingMemory(redis=redis_client)
memory_retriever = MemoryRetriever(working_memory=working)
if server_config.memory.get("semantic", {}).get("enabled"):
sem_conf = server_config.memory["semantic"]
rag_service = HttpRAGService(
base_url=sem_conf["base_url"],
api_key=sem_conf.get("api_key"),
knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
timeout=sem_conf.get("timeout", 30),
)
semantic = SemanticMemory(
rag_service=rag_service,
knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
)
memory_retriever = MemoryRetriever(
working_memory=working,
episodic_memory=episodic,
semantic_memory=semantic,
)
app.state.memory_retriever = memory_retriever
except Exception as e:
import logging

View File

@ -62,6 +62,7 @@ class ServerConfig:
log_format: str = "text",
task_store: dict[str, Any] | None = None,
cors_origins: list[str] | None = None,
memory: dict[str, Any] | None = None,
):
self.host = host
self.port = port
@ -75,6 +76,7 @@ class ServerConfig:
self.log_format = log_format
self.task_store = task_store or {}
self.cors_origins = cors_origins or ["*"]
self.memory = memory or {}
@classmethod
def from_yaml(cls, path: str) -> "ServerConfig":
@ -95,6 +97,7 @@ class ServerConfig:
skills_data = data.get("skills", {})
logging_data = data.get("logging", {})
task_store_data = data.get("task_store", {})
memory_data = data.get("memory", {})
# Build LLMConfig
llm_config = cls._build_llm_config(llm_data)
@ -116,6 +119,7 @@ class ServerConfig:
log_format=logging_data.get("format", "text"),
task_store=task_store_data,
cors_origins=server.get("cors_origins"),
memory=memory_data,
)
@staticmethod

View File

@ -0,0 +1,472 @@
"""Tests for HttpRAGService — HTTP 客户端调用业务系统知识库 API"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from agentkit.memory.http_rag import HttpRAGService
from agentkit.memory.semantic import SemanticMemory
from agentkit.memory.retriever import MemoryRetriever
# ---------------------------------------------------------------------------
# HttpRAGService unit tests
# ---------------------------------------------------------------------------
class TestHttpRAGServiceInit:
    """Constructor behaviour of HttpRAGService."""

    def test_basic_init(self):
        """With only base_url given, every optional setting takes its default."""
        url = "http://localhost:8000/api/knowledge"
        service = HttpRAGService(base_url=url)

        assert service._timeout == 30
        assert service._knowledge_base_ids == []
        assert service._api_key is None
        assert service._base_url == "http://localhost:8000/api/knowledge"

    def test_init_with_all_params(self):
        """Every constructor argument is stored; the trailing slash is dropped."""
        settings = {
            "base_url": "http://geo:8000/api/knowledge/",
            "api_key": "sk-test",
            "knowledge_base_ids": ["kb-1", "kb-2"],
            "timeout": 60,
        }
        service = HttpRAGService(**settings)

        # The trailing slash of base_url must have been stripped.
        assert service._base_url == "http://geo:8000/api/knowledge"
        assert service._api_key == "sk-test"
        assert service._knowledge_base_ids == ["kb-1", "kb-2"]
        assert service._timeout == 60

    def test_trailing_slash_stripped(self):
        """A base_url ending in '/' is normalized to the slash-less form."""
        service = HttpRAGService(base_url="http://host/api/")
        assert service._base_url == "http://host/api"
class TestHttpRAGServiceSearch:
    """HttpRAGService.search: semantic retrieval."""

    @staticmethod
    def _ok_response(body):
        """Build a mocked HTTP 200 response whose ``json()`` returns *body*."""
        response = MagicMock()
        response.status_code = 200
        response.raise_for_status = MagicMock()
        response.json.return_value = body
        return response

    @staticmethod
    def _install_client(service, **post_behaviour):
        """Swap the service's HTTP client for a mock and return that mock.

        *post_behaviour* is forwarded to the ``AsyncMock`` standing in for
        ``post`` (``return_value=...`` or ``side_effect=...``).
        """
        client = AsyncMock()
        client.post = AsyncMock(**post_behaviour)
        service._get_client = MagicMock(return_value=client)
        return client

    @pytest.fixture
    def svc(self):
        """Service configured with an API key and two default knowledge bases."""
        return HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            api_key="test-key",
            knowledge_base_ids=["kb-industry", "kb-enterprise"],
        )

    @pytest.mark.asyncio
    async def test_search_standard_response(self, svc):
        """Standard SearchResponse shape: {"results": [...]}."""
        first_hit = {
            "chunk_id": "c1",
            "content": "AI 行业趋势分析",
            "score": 0.92,
            "document_id": "d1",
            "document_title": "行业报告",
            "metadata": {"page": 1},
        }
        second_hit = {
            "chunk_id": "c2",
            "content": "企业数字化转型",
            "score": 0.85,
            "document_id": "d2",
            "document_title": "企业案例",
            "metadata": {},
        }
        body = {"results": [first_hit, second_hit], "total": 2, "latency_ms": 50}
        client = self._install_client(svc, return_value=self._ok_response(body))

        hits = await svc.search("AI 趋势", top_k=5)

        assert len(hits) == 2
        top, runner_up = hits
        assert top["id"] == "c1"
        assert top["content"] == "AI 行业趋势分析"
        assert top["score"] == 0.92
        assert top["document_id"] == "d1"
        assert runner_up["content"] == "企业数字化转型"

        # The request must target /search and carry the configured defaults.
        sent = client.post.call_args
        assert sent.args[0] == "/search"
        request_body = sent.kwargs["json"]
        assert request_body["query"] == "AI 趋势"
        assert request_body["knowledge_base_ids"] == ["kb-industry", "kb-enterprise"]
        assert request_body["top_k"] == 5

    @pytest.mark.asyncio
    async def test_search_list_response(self, svc):
        """Bare list response shape: [...]."""
        body = [{"chunk_id": "c1", "content": "test", "score": 0.8}]
        self._install_client(svc, return_value=self._ok_response(body))

        hits = await svc.search("test")

        assert len(hits) == 1
        assert hits[0]["content"] == "test"

    @pytest.mark.asyncio
    async def test_search_custom_kb_ids(self, svc):
        """Explicit knowledge_base_ids override the configured defaults."""
        client = self._install_client(
            svc, return_value=self._ok_response({"results": []})
        )

        await svc.search("test", knowledge_base_ids=["custom-kb"])

        request_body = client.post.call_args.kwargs["json"]
        assert request_body["knowledge_base_ids"] == ["custom-kb"]

    @pytest.mark.asyncio
    async def test_search_http_error_returns_empty(self, svc):
        """An HTTP error status yields an empty list."""
        import httpx

        failing = MagicMock()
        failing.status_code = 500
        failing.text = "Internal Server Error"
        failing.raise_for_status.side_effect = httpx.HTTPStatusError(
            "500", request=MagicMock(), response=failing
        )
        self._install_client(svc, return_value=failing)

        hits = await svc.search("test")

        assert hits == []

    @pytest.mark.asyncio
    async def test_search_connection_error_returns_empty(self, svc):
        """A connection failure yields an empty list."""
        import httpx

        self._install_client(
            svc, side_effect=httpx.ConnectError("Connection refused")
        )

        hits = await svc.search("test")

        assert hits == []

    @pytest.mark.asyncio
    async def test_search_unexpected_format_returns_empty(self, svc):
        """An unrecognized response shape yields an empty list."""
        self._install_client(
            svc, return_value=self._ok_response({"error": "something"})
        )

        hits = await svc.search("test")

        assert hits == []
class TestHttpRAGServiceIngest:
    """HttpRAGService.ingest: document ingestion."""

    @staticmethod
    def _wire_post(service, response):
        """Make the service's mocked client answer ``post`` with *response*."""
        client = AsyncMock()
        client.post = AsyncMock(return_value=response)
        service._get_client = MagicMock(return_value=client)
        return client

    @pytest.mark.asyncio
    async def test_ingest_success(self):
        """A successful write returns the service's JSON and hits the first KB."""
        service = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )
        created = MagicMock()
        created.status_code = 201
        created.raise_for_status = MagicMock()
        created.json.return_value = {"id": "doc-1", "status": "processing"}
        client = self._wire_post(service, created)

        outcome = await service.ingest("测试文档", "文档内容")

        assert outcome["id"] == "doc-1"
        # Check both the endpoint and the body that was sent.
        sent = client.post.call_args
        assert sent.args[0] == "/bases/kb-1/documents"
        body = sent.kwargs["json"]
        assert body["title"] == "测试文档"
        assert body["content"] == "文档内容"

    @pytest.mark.asyncio
    async def test_ingest_no_kb_ids_returns_none(self):
        """Without configured knowledge bases, ingest returns None."""
        service = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        outcome = await service.ingest("test", "content")
        assert outcome is None

    @pytest.mark.asyncio
    async def test_ingest_http_error_returns_none(self):
        """An HTTP error status makes ingest return None."""
        import httpx

        service = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )
        failing = MagicMock()
        failing.status_code = 500
        failing.raise_for_status.side_effect = httpx.HTTPStatusError(
            "500", request=MagicMock(), response=failing
        )
        self._wire_post(service, failing)

        outcome = await service.ingest("test", "content")

        assert outcome is None
class TestHttpRAGServiceHealthCheck:
    """HttpRAGService.health_check."""

    @staticmethod
    def _service_with_get(**get_behaviour):
        """Return a service whose mocked client's ``get`` behaves as given."""
        service = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        client = AsyncMock()
        client.get = AsyncMock(**get_behaviour)
        service._get_client = MagicMock(return_value=client)
        return service

    @pytest.mark.asyncio
    async def test_health_check_ok(self):
        """HTTP 200 counts as healthy."""
        reply = MagicMock()
        reply.status_code = 200
        service = self._service_with_get(return_value=reply)

        healthy = await service.health_check()

        assert healthy is True

    @pytest.mark.asyncio
    async def test_health_check_401_still_healthy(self):
        """HTTP 401 means the service is running but requires authentication."""
        reply = MagicMock()
        reply.status_code = 401
        service = self._service_with_get(return_value=reply)

        healthy = await service.health_check()

        assert healthy is True

    @pytest.mark.asyncio
    async def test_health_check_connection_error(self):
        """An exception raised by the request counts as unhealthy."""
        service = self._service_with_get(side_effect=Exception("Connection refused"))

        healthy = await service.health_check()

        assert healthy is False
class TestHttpRAGServiceClient:
    """Lifecycle management of HttpRAGService's HTTP client.

    Every test that creates a real ``httpx.AsyncClient`` closes it again so
    no open client outlives the test.
    """

    @pytest.mark.asyncio
    async def test_client_lazy_init(self):
        """The client is created on first use and carries the bearer token."""
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge", api_key="sk-test")
        assert svc._client is None
        try:
            client = svc._get_client()
            assert client is not None
            assert "Bearer sk-test" in str(client.headers.get("Authorization", ""))
        finally:
            await svc.close()

    @pytest.mark.asyncio
    async def test_client_reuse(self):
        """Repeated calls hand back the same client instance."""
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        try:
            c1 = svc._get_client()
            c2 = svc._get_client()
            assert c1 is c2
        finally:
            await svc.close()

    @pytest.mark.asyncio
    async def test_close(self):
        """close() shuts the client down and clears the cached reference."""
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        client = svc._get_client()  # force client creation
        await svc.close()
        assert svc._client is None
        assert client.is_closed

    @pytest.mark.asyncio
    async def test_context_manager(self):
        """Leaving the async context closes the client."""
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        async with svc as s:
            s._get_client()
            assert s._client is not None
        assert svc._client is None
# ---------------------------------------------------------------------------
# SemanticMemory + HttpRAGService integration
# ---------------------------------------------------------------------------
class TestSemanticMemoryWithHttpRAG:
    """SemanticMemory retrieving from the knowledge base via HttpRAGService."""

    @pytest.mark.asyncio
    async def test_search_delegates_to_rag_service(self):
        """SemanticMemory.search delegates to HttpRAGService.search."""
        # Mocked HTTP layer: one hit in the standard {"results": [...]} shape.
        reply = MagicMock()
        reply.status_code = 200
        reply.raise_for_status = MagicMock()
        reply.json.return_value = {
            "results": [
                {"chunk_id": "c1", "content": "行业知识", "score": 0.9, "document_id": "d1"},
            ]
        }
        http_client = AsyncMock()
        http_client.post = AsyncMock(return_value=reply)

        rag_service = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )
        rag_service._get_client = MagicMock(return_value=http_client)
        memory = SemanticMemory(rag_service=rag_service, knowledge_base_ids=["kb-1"])

        found = await memory.search("行业趋势", top_k=3)

        assert len(found) == 1
        only = found[0]
        assert only.key == "c1"
        assert only.value == "行业知识"
        assert only.score == 0.9
        assert only.metadata["source"] == "rag"

    @pytest.mark.asyncio
    async def test_search_no_rag_service_returns_empty(self):
        """Without a RAG service, search returns an empty list."""
        memory = SemanticMemory()
        found = await memory.search("test")
        assert found == []
class TestMemoryRetrieverWithSemantic:
    """MemoryRetriever combined with SemanticMemory and HttpRAGService."""

    @pytest.mark.asyncio
    async def test_retriever_queries_semantic_layer(self):
        """MemoryRetriever queries the semantic layer and fuses its results."""
        reply = MagicMock()
        reply.status_code = 200
        reply.raise_for_status = MagicMock()
        reply.json.return_value = {
            "results": [
                {"chunk_id": "c1", "content": "知识库内容", "score": 0.95, "document_id": "d1"},
            ]
        }
        http_client = AsyncMock()
        http_client.post = AsyncMock(return_value=reply)

        rag_service = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )
        rag_service._get_client = MagicMock(return_value=http_client)

        retriever = MemoryRetriever(
            semantic_memory=SemanticMemory(
                rag_service=rag_service, knowledge_base_ids=["kb-1"]
            )
        )

        found = await retriever.retrieve("知识查询", top_k=3)

        assert len(found) >= 1
        # Semantic weight is 0.4 by default
        expected_score = 0.95 * 0.4
        assert found[0].score == pytest.approx(expected_score, abs=0.01)
# ---------------------------------------------------------------------------
# Config-driven integration tests
# ---------------------------------------------------------------------------
class TestServerConfigMemorySemantic:
    """ServerConfig parsing of the ``memory.semantic`` section.

    The YAML tests use pytest's ``tmp_path`` and ``monkeypatch`` fixtures so
    that temporary files are removed and environment variables are restored
    even when an assertion fails.
    """

    @staticmethod
    def _write_yaml(directory, text):
        """Write *text* to a ``.yaml`` file in *directory*; return its path as str."""
        config_file = directory / "agentkit.yaml"
        config_file.write_text(text, encoding="utf-8")
        return str(config_file)

    def test_from_dict_with_semantic(self):
        """from_dict keeps the memory.semantic settings unchanged."""
        from agentkit.server.config import ServerConfig

        data = {
            "memory": {
                "semantic": {
                    "enabled": True,
                    "base_url": "http://geo:8000/api/knowledge",
                    "api_key": "sk-test",
                    "knowledge_base_ids": ["kb-1", "kb-2"],
                    "timeout": 60,
                },
            },
        }
        config = ServerConfig.from_dict(data)
        semantic = config.memory["semantic"]
        assert semantic["enabled"] is True
        assert semantic["base_url"] == "http://geo:8000/api/knowledge"
        assert semantic["api_key"] == "sk-test"
        assert semantic["knowledge_base_ids"] == ["kb-1", "kb-2"]
        assert semantic["timeout"] == 60

    def test_from_dict_without_memory(self):
        """A config without a memory section yields an empty memory dict."""
        from agentkit.server.config import ServerConfig

        config = ServerConfig.from_dict({})
        assert config.memory == {}

    def test_from_yaml_with_env_var_resolution(self, monkeypatch, tmp_path):
        """from_yaml resolves ``${VAR}`` from the environment."""
        from agentkit.server.config import ServerConfig

        # monkeypatch restores the environment after the test, pass or fail.
        monkeypatch.setenv("TEST_GEO_API_KEY", "sk-from-env")
        yaml_content = (
            "memory:\n"
            "  semantic:\n"
            "    enabled: true\n"
            "    base_url: http://geo:8000/api/knowledge\n"
            '    api_key: "${TEST_GEO_API_KEY}"\n'
        )
        # The file is fully written and closed before from_yaml reads it.
        config = ServerConfig.from_yaml(self._write_yaml(tmp_path, yaml_content))
        assert config.memory["semantic"]["api_key"] == "sk-from-env"

    def test_from_yaml_with_default_env_var(self, monkeypatch, tmp_path):
        """from_yaml falls back to the default in ``${VAR:-default}``."""
        from agentkit.server.config import ServerConfig

        # Guarantee the variable is unset regardless of the ambient environment.
        monkeypatch.delenv("NONEXISTENT_KEY", raising=False)
        yaml_content = (
            "memory:\n"
            "  semantic:\n"
            "    enabled: true\n"
            "    base_url: http://geo:8000/api/knowledge\n"
            '    api_key: "${NONEXISTENT_KEY:-sk-default}"\n'
        )
        config = ServerConfig.from_yaml(self._write_yaml(tmp_path, yaml_content))
        assert config.memory["semantic"]["api_key"] == "sk-default"