fischer-agentkit/src/agentkit/memory/http_rag.py

329 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""HTTP RAG Service - 通过 HTTP 调用业务系统知识库 API
配置驱动,不直接依赖业务系统代码,通过 base_url + api_key 连接。
"""
from __future__ import annotations

import logging
import math
from typing import TYPE_CHECKING, TypeAlias

import httpx

from agentkit.memory.base import MetadataDict

if TYPE_CHECKING:
    from agentkit.llm.gateway import LLMGateway

logger = logging.getLogger(__name__)
# 标准化检索结果id/content/score/source/document_id/document_title/
# knowledge_base_id/metadata — 值为原始标量或嵌套 dict。
RAGSearchResult: TypeAlias = dict[str, object]
# ingest() 写入的文档负载title/content/source_type/metadata。
RAGIngestPayload: TypeAlias = dict[str, object]
class HttpRAGService:
    """HTTP client for a business system's knowledge-base retrieval API.

    Configuration-driven: it does not import business-system code and only
    needs ``base_url`` + ``api_key``. Works with any knowledge-base service
    that exposes:

    - ``POST {base_url}/search`` -> semantic search
    - ``POST {base_url}/bases/{kb_id}/retrieve`` -> enhanced search
      (optional; a 404 falls back to ``/search`` for that knowledge base)
    - ``POST {base_url}/bases/{kb_id}/documents`` -> document ingestion (optional)
    - ``GET {base_url}/bases`` -> used by :meth:`health_check`

    Typical configuration (``agentkit.yaml``)::

        memory:
          semantic:
            enabled: true
            base_url: "http://localhost:8000/api/knowledge"
            api_key: "${GEO_API_KEY}"
            knowledge_base_ids:
              - "industry-kb-id"
              - "enterprise-kb-id"
            timeout: 30
            contextual_chunking: false
    """

    def __init__(
        self,
        base_url: str,
        api_key: str | None = None,
        knowledge_base_ids: list[str] | None = None,
        timeout: float = 30,
        contextual_chunking: bool = False,
        llm_gateway: LLMGateway | None = None,
    ) -> None:
        """Store connection settings; the HTTP client is created lazily.

        Args:
            base_url: Knowledge-base API root, e.g.
                ``http://localhost:8000/api/knowledge``. A trailing ``/`` is stripped.
            api_key: API key sent as ``Authorization: Bearer <key>``; no auth
                header is sent when it is empty.
            knowledge_base_ids: Default knowledge-base IDs used by search and ingest.
            timeout: HTTP request timeout in seconds.
            contextual_chunking: When true and ``llm_gateway`` is set,
                :meth:`ingest` enhances the content with contextual prefixes
                (via ``ContextualChunker``) before upload.
            llm_gateway: Gateway handed to ``ContextualChunker``.
        """
        self._base_url = base_url.rstrip("/")
        self._api_key = api_key
        self._knowledge_base_ids = knowledge_base_ids or []
        self._timeout = timeout
        self._client: httpx.AsyncClient | None = None
        self._contextual_chunking = contextual_chunking
        self._llm_gateway = llm_gateway

    def _get_client(self) -> httpx.AsyncClient:
        """Return the shared AsyncClient, (re)creating it if missing or closed."""
        if self._client is None or self._client.is_closed:
            headers: dict[str, str] = {"Content-Type": "application/json"}
            if self._api_key:
                headers["Authorization"] = f"Bearer {self._api_key}"
            self._client = httpx.AsyncClient(
                base_url=self._base_url,
                headers=headers,
                timeout=self._timeout,
            )
        return self._client

    @staticmethod
    def _extract_results(data: object) -> list[object] | None:
        """Pull the hit list out of a search/retrieve response body.

        Accepts ``{"results": [...]}`` (GEO's standard SearchResponse) or a
        bare JSON list. Returns ``None`` for any other shape, including a
        ``results`` value that is not a list (e.g. ``null``).
        """
        if isinstance(data, dict) and "results" in data:
            results = data["results"]
        elif isinstance(data, list):
            results = data
        else:
            return None
        return results if isinstance(results, list) else None

    @staticmethod
    def _coerce_score(value: object) -> float:
        """Convert a raw score to ``float``; ``None``, unparseable values and NaN become 0.0.

        This keeps one malformed hit from discarding (``search``) or aborting
        (``enhanced_search``) the whole batch. It also keeps NaN out of the
        descending score sort, where it would break the ordering.
        """
        try:
            score = float(value)  # type: ignore[arg-type]
        except (TypeError, ValueError):
            return 0.0
        return 0.0 if math.isnan(score) else score

    @classmethod
    def _normalize_result(
        cls,
        raw: dict[str, object],
        knowledge_base_id: str | None = None,
    ) -> RAGSearchResult:
        """Map one raw hit onto the standard result schema.

        ``id`` prefers ``chunk_id`` over ``id``, ``score`` goes through
        :meth:`_coerce_score`, and a missing or null ``metadata`` becomes ``{}``.
        ``knowledge_base_id`` is included only when given.
        """
        normalized: RAGSearchResult = {
            "id": raw.get("chunk_id", raw.get("id", "")),
            "content": raw.get("content", ""),
            "score": cls._coerce_score(raw.get("score", 0.0)),
            "source": raw.get("source", "rag"),
            "document_id": raw.get("document_id", ""),
            "document_title": raw.get("document_title", ""),
        }
        if knowledge_base_id is not None:
            normalized["knowledge_base_id"] = knowledge_base_id
        normalized["metadata"] = raw.get("metadata") or {}
        return normalized

    async def search(
        self,
        query: str,
        knowledge_base_ids: list[str] | None = None,
        top_k: int = 5,
    ) -> list[RAGSearchResult]:
        """Run a semantic search via ``POST /search``.

        Args:
            query: Search query text.
            knowledge_base_ids: Knowledge bases to search. Falls back to the
                configured defaults when ``None`` or empty.
            top_k: Number of results requested from the service.

        Returns:
            Normalized hits (content/score/document_id/...). Transport, HTTP
            and parsing failures are logged and yield ``[]``; this method does
            not raise for them.
        """
        payload: dict[str, object] = {
            "query": query,
            "knowledge_base_ids": knowledge_base_ids or self._knowledge_base_ids,
            "top_k": top_k,
        }
        client = self._get_client()
        try:
            resp = await client.post("/search", json=payload)
            resp.raise_for_status()
            data = resp.json()
            results = self._extract_results(data)
            if results is None:
                logger.warning("Unexpected search response format: %s", type(data))
                return []
            return [self._normalize_result(r) for r in results if isinstance(r, dict)]
        except httpx.HTTPStatusError as e:
            logger.error(
                "RAG search HTTP error: %s - %s",
                e.response.status_code,
                e.response.text[:200],
            )
            return []
        except httpx.RequestError as e:
            logger.error("RAG search request error: %s", e)
            return []
        except Exception as e:
            # Best-effort: swallow, but keep the traceback for debugging.
            logger.exception("RAG search unexpected error: %s", e)
            return []

    async def enhanced_search(
        self,
        query: str,
        knowledge_base_ids: list[str] | None = None,
        top_k: int = 5,
        use_rerank: bool = True,
        use_compression: bool = False,
    ) -> list[RAGSearchResult]:
        """Enhanced semantic search (optional rerank and context compression).

        Calls ``POST /bases/{kb_id}/retrieve`` once per knowledge base, merges
        the hits, and returns the ``top_k`` highest-scoring ones. A knowledge
        base that answers 404 is queried through :meth:`search` instead; those
        hits are tagged with the same ``knowledge_base_id`` as the others.

        Args:
            query: Search query text.
            knowledge_base_ids: Knowledge bases to search. Falls back to the
                configured defaults when ``None`` or empty.
            top_k: Number of results to return (also sent to each endpoint).
            use_rerank: Ask the service to rerank results.
            use_compression: Ask the service to compress context.

        Returns:
            Normalized hits sorted by descending score, at most ``top_k`` of them.
            Returns ``[]`` when no knowledge base is configured.

        Raises:
            httpx.HTTPStatusError: A knowledge base answered with a non-404 error.
            httpx.RequestError: Transport failure for a knowledge base.
            Exception: Any other unexpected failure is logged and re-raised.
        """
        kb_ids = knowledge_base_ids or self._knowledge_base_ids
        if not kb_ids:
            return []

        payload: dict[str, object] = {
            "query": query,
            "top_k": top_k,
            "use_rerank": use_rerank,
            "use_compression": use_compression,
        }
        client = self._get_client()
        all_results: list[RAGSearchResult] = []

        for kb_id in kb_ids:
            try:
                resp = await client.post(f"/bases/{kb_id}/retrieve", json=payload)
                resp.raise_for_status()
                data = resp.json()
                results = self._extract_results(data)
                if results is None:
                    logger.warning(
                        "Unexpected enhanced_search response format: %s", type(data)
                    )
                    continue
                all_results.extend(
                    self._normalize_result(r, knowledge_base_id=kb_id)
                    for r in results
                    if isinstance(r, dict)
                )
            except httpx.HTTPStatusError as e:
                if e.response.status_code != 404:
                    logger.error(
                        "RAG enhanced_search HTTP error for KB %s: %s - %s",
                        kb_id,
                        e.response.status_code,
                        e.response.text[:200],
                    )
                    raise
                # This KB doesn't support enhanced search: fall back to the
                # standard search for THIS KB only, not for all KBs.
                logger.info(
                    "Enhanced search not available for KB %s, using standard search",
                    kb_id,
                )
                fallback = await self.search(query, knowledge_base_ids=[kb_id], top_k=top_k)
                for item in fallback:
                    # Keep the schema uniform with the enhanced hits.
                    item.setdefault("knowledge_base_id", kb_id)
                all_results.extend(fallback)
            except httpx.RequestError as e:
                logger.error("RAG enhanced_search request error for KB %s: %s", kb_id, e)
                raise
            except Exception as e:
                logger.error("RAG enhanced_search unexpected error for KB %s: %s", kb_id, e)
                raise

        all_results.sort(key=lambda item: item["score"], reverse=True)
        return all_results[:top_k]

    async def _contextualize(self, content: str, metadata: MetadataDict | None) -> str:
        """Return *content* enhanced with contextual prefixes, when enabled.

        The content is split into paragraphs on blank lines and passed to
        ``ContextualChunker.enhance_chunks``; the enhanced chunks are rejoined
        with blank lines. The unmodified content is returned when chunking is
        disabled, there are no non-empty paragraphs, the chunker returns
        nothing, or the enhancement step raises. Enhancement is optional and
        must not block ingestion.
        """
        if not (self._contextual_chunking and self._llm_gateway):
            return content
        raw_chunks = [c.strip() for c in content.split("\n\n") if c.strip()]
        if not raw_chunks:
            return content

        # Imported lazily so the dependency is only loaded when the feature is used.
        from agentkit.memory.contextual_retrieval import ContextualChunker

        try:
            chunker = ContextualChunker(llm_gateway=self._llm_gateway)
            enhanced = await chunker.enhance_chunks(
                document=content, chunks=raw_chunks, metadata=metadata
            )
            if not enhanced:
                return content
            return "\n\n".join(chunk.enhanced_content for chunk in enhanced)
        except Exception:
            logger.warning(
                "Contextual chunking failed; ingesting raw content instead",
                exc_info=True,
            )
            return content

    async def ingest(
        self,
        key: str,
        value: object,
        metadata: MetadataDict | None = None,
    ) -> dict[str, object] | None:
        """Write a document into the first configured knowledge base (optional).

        Posts to ``/bases/{kb_id}/documents`` using the first entry of the
        configured ``knowledge_base_ids``. When ``contextual_chunking`` is
        enabled and an ``llm_gateway`` is configured, the content is first
        enhanced with contextual prefixes. If that step fails, the raw content
        is ingested instead.

        Args:
            key: Document title or identifier (sent as ``title``).
            value: Document content; converted with ``str()``.
            metadata: Extra metadata sent with the document.

        Returns:
            The service's JSON response, or ``None`` when no knowledge base is
            configured or the write fails.
        """
        kb_ids = self._knowledge_base_ids
        if not kb_ids:
            logger.warning("HttpRAGService.ingest: no knowledge_base_ids configured")
            return None

        content = await self._contextualize(str(value), metadata)
        payload: RAGIngestPayload = {
            "title": key,
            "content": content,
            "source_type": "text",
            "metadata": metadata or {},
        }
        client = self._get_client()
        try:
            kb_id = kb_ids[0]
            resp = await client.post(f"/bases/{kb_id}/documents", json=payload)
            resp.raise_for_status()
            return resp.json()
        except httpx.HTTPStatusError as e:
            logger.error(
                "RAG ingest HTTP error: %s - %s",
                e.response.status_code,
                e.response.text[:200],
            )
            return None
        except Exception as e:
            logger.exception("RAG ingest error: %s", e)
            return None

    async def health_check(self) -> bool:
        """Return whether the service answers ``GET /bases`` with 200 or 401.

        401 still counts as healthy: the service is up but requires authentication.
        Any exception while sending the request yields ``False``.
        """
        client = self._get_client()
        try:
            resp = await client.get("/bases")
            return resp.status_code in (200, 401)
        except Exception:
            return False

    async def close(self) -> None:
        """Close the underlying HTTP client if it is open."""
        if self._client and not self._client.is_closed:
            await self._client.aclose()
            self._client = None

    async def __aenter__(self) -> HttpRAGService:
        return self

    async def __aexit__(self, *args: object) -> None:
        await self.close()