329 lines
12 KiB
Python
329 lines
12 KiB
Python
"""HTTP RAG Service - 通过 HTTP 调用业务系统知识库 API
|
||
|
||
配置驱动,不直接依赖业务系统代码,通过 base_url + api_key 连接。
|
||
"""
|
||
|
||
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, TypeAlias

import httpx

from agentkit.memory.base import MetadataDict

if TYPE_CHECKING:
    # Only evaluated by static type checkers (TYPE_CHECKING is False at
    # runtime). LLMGateway appears in annotations only, which are lazy here
    # because of ``from __future__ import annotations``. Presumably guarded
    # to avoid a runtime dependency / import cycle — confirm.
    from agentkit.llm.gateway import LLMGateway
||
logger = logging.getLogger(__name__)
|
||
|
||
# 标准化检索结果:id/content/score/source/document_id/document_title/
|
||
# knowledge_base_id/metadata — 值为原始标量或嵌套 dict。
|
||
RAGSearchResult: TypeAlias = dict[str, object]
|
||
# ingest() 写入的文档负载:title/content/source_type/metadata。
|
||
RAGIngestPayload: TypeAlias = dict[str, object]
|
||
|
||
|
||
class HttpRAGService:
    """HTTP client for a business system's knowledge-base retrieval API.

    Works with any knowledge-base service that exposes these endpoints
    (all relative to ``base_url``):

    - ``POST /search``                    → semantic search
    - ``POST /bases/{kb_id}/retrieve``    → enhanced search (rerank / compression)
    - ``POST /bases/{kb_id}/documents``   → document ingestion (optional)
    - ``GET  /bases``                     → health check

    Can be used as an async context manager; the underlying HTTP client is
    closed on exit.

    Typical configuration (agentkit.yaml)::

        memory:
          semantic:
            enabled: true
            base_url: "http://localhost:8000/api/knowledge"
            api_key: "${GEO_API_KEY}"
            knowledge_base_ids:
              - "industry-kb-id"
              - "enterprise-kb-id"
            timeout: 30
            contextual_chunking: false
    """

    def __init__(
        self,
        base_url: str,
        api_key: str | None = None,
        knowledge_base_ids: list[str] | None = None,
        timeout: int = 30,
        contextual_chunking: bool = False,
        llm_gateway: LLMGateway | None = None,
    ):
        """Store configuration; no network connection is made here.

        Args:
            base_url: Base URL of the knowledge-base API, e.g.
                ``http://localhost:8000/api/knowledge``. A trailing ``/`` is
                stripped.
            api_key: API key, sent as ``Authorization: Bearer <key>`` when set.
            knowledge_base_ids: Default knowledge-base IDs to query. The list
                is copied, so later changes to the caller's list have no effect.
            timeout: HTTP request timeout in seconds.
            contextual_chunking: When True and ``llm_gateway`` is provided,
                ``ingest`` enhances each paragraph chunk through
                ``ContextualChunker`` before upload.
            llm_gateway: Gateway handed to ``ContextualChunker``.
        """
        self._base_url = base_url.rstrip("/")
        self._api_key = api_key
        # Defensive copy: avoid aliasing a list the caller may mutate later.
        self._knowledge_base_ids: list[str] = list(knowledge_base_ids or [])
        self._timeout = timeout
        self._client: httpx.AsyncClient | None = None
        self._contextual_chunking = contextual_chunking
        self._llm_gateway = llm_gateway

    def _get_client(self) -> httpx.AsyncClient:
        """Return the shared ``httpx.AsyncClient``, creating it lazily.

        A fresh client is built on first use and whenever the previous one
        has been closed (e.g. after ``close()``).
        """
        if self._client is None or self._client.is_closed:
            headers: dict[str, str] = {"Content-Type": "application/json"}
            if self._api_key:
                headers["Authorization"] = f"Bearer {self._api_key}"
            self._client = httpx.AsyncClient(
                base_url=self._base_url,
                headers=headers,
                timeout=self._timeout,
            )
        return self._client

    # ------------------------------------------------------------------
    # Response parsing helpers (shared by search and enhanced_search)
    # ------------------------------------------------------------------

    @staticmethod
    def _extract_results(data: object) -> list[object] | None:
        """Return the raw hit list from a decoded search response, or None.

        Two response shapes are accepted:

        1. ``{"results": [...]}`` — GEO standard ``SearchResponse``
        2. ``[...]``              — a bare list

        Anything else, including a ``"results"`` value that is not a list,
        yields None so the caller can log it instead of crashing while
        iterating.
        """
        if isinstance(data, dict) and "results" in data:
            data = data["results"]
        return data if isinstance(data, list) else None

    @staticmethod
    def _coerce_score(value: object) -> float:
        """Convert a raw score to float; None or non-numeric values become 0.0.

        A single malformed score must not discard the whole result set.
        """
        if value is None:
            return 0.0
        try:
            return float(value)  # type: ignore[arg-type]
        except (TypeError, ValueError):
            return 0.0

    @classmethod
    def _normalize_result(
        cls, raw: dict[str, object], kb_id: str | None = None
    ) -> RAGSearchResult:
        """Map one raw hit onto the standard result shape SemanticMemory expects.

        ``id`` prefers ``chunk_id`` over ``id``; ``metadata`` falls back to
        ``{}`` when missing or null. ``knowledge_base_id`` is set only when
        ``kb_id`` is given.
        """
        result: RAGSearchResult = {
            "id": raw.get("chunk_id", raw.get("id", "")),
            "content": raw.get("content", ""),
            "score": cls._coerce_score(raw.get("score", 0.0)),
            "source": raw.get("source", "rag"),
            "document_id": raw.get("document_id", ""),
            "document_title": raw.get("document_title", ""),
        }
        if kb_id is not None:
            result["knowledge_base_id"] = kb_id
        result["metadata"] = raw.get("metadata") or {}
        return result

    @classmethod
    def _normalize_results(
        cls, raw_results: list[object], kb_id: str | None = None
    ) -> list[RAGSearchResult]:
        """Normalize every dict hit in ``raw_results``; non-dict entries are skipped."""
        return [
            cls._normalize_result(raw, kb_id=kb_id)
            for raw in raw_results
            if isinstance(raw, dict)
        ]

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def search(
        self,
        query: str,
        knowledge_base_ids: list[str] | None = None,
        top_k: int = 5,
    ) -> list[RAGSearchResult]:
        """Semantic search via ``POST /search``.

        Best-effort: HTTP, transport and decoding errors are logged and an
        empty list is returned instead of raising.

        Args:
            query: Search query.
            knowledge_base_ids: Knowledge bases to search; the configured IDs
                are used when this is None or empty.
            top_k: Number of results requested from the server.

        Returns:
            Normalized results (id/content/score/source/document_id/
            document_title/metadata) in the order the server returned them.
        """
        kb_ids = knowledge_base_ids or self._knowledge_base_ids
        payload = {
            "query": query,
            "knowledge_base_ids": kb_ids,
            "top_k": top_k,
        }

        client = self._get_client()
        try:
            resp = await client.post("/search", json=payload)
            resp.raise_for_status()
            data = resp.json()

            raw_results = self._extract_results(data)
            if raw_results is None:
                logger.warning("Unexpected search response format: %s", type(data))
                return []
            return self._normalize_results(raw_results)

        except httpx.HTTPStatusError as e:
            logger.error(
                "RAG search HTTP error: %s — %s",
                e.response.status_code,
                e.response.text[:200],
            )
            return []
        except httpx.RequestError as e:
            logger.error("RAG search request error: %s", e)
            return []
        except Exception as e:
            # Stay best-effort, but keep the traceback for debugging.
            logger.exception("RAG search unexpected error: %s", e)
            return []

    async def enhanced_search(
        self,
        query: str,
        knowledge_base_ids: list[str] | None = None,
        top_k: int = 5,
        use_rerank: bool = True,
        use_compression: bool = False,
    ) -> list[RAGSearchResult]:
        """Enhanced semantic search with optional rerank and compression.

        Calls ``POST /bases/{kb_id}/retrieve`` once per knowledge base, one
        after another. The hits are merged and the ``top_k`` highest-scoring
        ones are returned.

        A knowledge base that answers 404 is treated as not supporting
        enhanced retrieval and is queried through ``search`` instead, for that
        knowledge base only. Every returned hit, from either path, carries a
        ``knowledge_base_id`` key.

        Args:
            query: Search query.
            knowledge_base_ids: Knowledge bases to search; the configured IDs
                are used when this is None or empty.
            top_k: Results requested per knowledge base, and the size of the
                merged result.
            use_rerank: Ask the server to rerank results.
            use_compression: Ask the server to compress context.

        Returns:
            Up to ``top_k`` normalized results sorted by descending score.
            Empty when no knowledge-base IDs are given or configured.

        Raises:
            httpx.HTTPStatusError: A knowledge base returned a non-404 error.
            httpx.RequestError: A transport-level failure occurred.
            Exception: Any other unexpected error is logged and re-raised.
        """
        kb_ids = knowledge_base_ids or self._knowledge_base_ids
        if not kb_ids:
            return []

        payload = {
            "query": query,
            "top_k": top_k,
            "use_rerank": use_rerank,
            "use_compression": use_compression,
        }

        client = self._get_client()
        all_results: list[RAGSearchResult] = []

        for kb_id in kb_ids:
            try:
                resp = await client.post(f"/bases/{kb_id}/retrieve", json=payload)
                resp.raise_for_status()
                data = resp.json()
            except httpx.HTTPStatusError as e:
                if e.response.status_code != 404:
                    logger.error(
                        "RAG enhanced_search HTTP error for KB %s: %s — %s",
                        kb_id,
                        e.response.status_code,
                        e.response.text[:200],
                    )
                    raise
                # This KB doesn't support enhanced search — fall back to
                # standard search for THIS KB only, not all KBs.
                logger.info(
                    "Enhanced search not available for KB %s, using standard search",
                    kb_id,
                )
                fallback = await self.search(
                    query, knowledge_base_ids=[kb_id], top_k=top_k
                )
                # search() does not tag hits with their KB; do it here so
                # fallback hits match the shape of enhanced hits.
                for hit in fallback:
                    hit["knowledge_base_id"] = kb_id
                all_results.extend(fallback)
                continue
            except httpx.RequestError as e:
                logger.error(
                    "RAG enhanced_search request error for KB %s: %s", kb_id, e
                )
                raise
            except Exception as e:
                logger.exception(
                    "RAG enhanced_search unexpected error for KB %s: %s", kb_id, e
                )
                raise

            raw_results = self._extract_results(data)
            if raw_results is None:
                logger.warning(
                    "Unexpected enhanced_search response format: %s", type(data)
                )
                continue
            all_results.extend(self._normalize_results(raw_results, kb_id=kb_id))

        # Merge across KBs: highest score first, then keep top_k.
        all_results.sort(key=lambda hit: hit["score"], reverse=True)
        return all_results[:top_k]

    async def ingest(
        self,
        key: str,
        value: object,
        metadata: MetadataDict | None = None,
    ) -> dict[str, object] | None:
        """Write a document to the knowledge base (optional operation).

        The document goes only to the FIRST configured knowledge base, via
        ``POST /bases/{kb_id}/documents``.

        When contextual chunking is enabled and an LLM gateway is configured,
        the content is split on blank lines, each paragraph is enhanced by
        ``ContextualChunker``, and the enhanced paragraphs are rejoined before
        upload. Errors raised during contextual chunking are not caught here.

        Args:
            key: Document title or identifier.
            value: Document content; converted with ``str()``.
            metadata: Extra metadata sent with the document.

        Returns:
            The decoded JSON response, or None when no knowledge base is
            configured or the write failed (the failure is logged).
        """
        kb_ids = self._knowledge_base_ids
        if not kb_ids:
            logger.warning("HttpRAGService.ingest: no knowledge_base_ids configured")
            return None

        content = str(value)

        if self._contextual_chunking and self._llm_gateway:
            # Imported lazily: only needed when contextual chunking is enabled.
            from agentkit.memory.contextual_retrieval import ContextualChunker

            chunker = ContextualChunker(llm_gateway=self._llm_gateway)
            # Simple chunking: one chunk per blank-line-separated paragraph.
            raw_chunks = [c.strip() for c in content.split("\n\n") if c.strip()]
            if raw_chunks:
                enhanced = await chunker.enhance_chunks(
                    document=content, chunks=raw_chunks, metadata=metadata
                )
                content = "\n\n".join(chunk.enhanced_content for chunk in enhanced)

        payload: RAGIngestPayload = {
            "title": key,
            "content": content,
            "source_type": "text",
            "metadata": metadata or {},
        }

        client = self._get_client()
        kb_id = kb_ids[0]
        try:
            resp = await client.post(f"/bases/{kb_id}/documents", json=payload)
            resp.raise_for_status()
            return resp.json()
        except httpx.HTTPStatusError as e:
            logger.error(
                "RAG ingest HTTP error: %s — %s",
                e.response.status_code,
                e.response.text[:200],
            )
            return None
        except Exception as e:
            logger.exception("RAG ingest error: %s", e)
            return None

    async def health_check(self) -> bool:
        """Return True if the knowledge-base service answers ``GET /bases``.

        Status 200 or 401 counts as healthy (401 means the service is up but
        requires authentication). Any exception while sending the request
        yields False.
        """
        client = self._get_client()
        try:
            resp = await client.get("/bases")
        except Exception as e:
            logger.debug("RAG health check failed: %s", e)
            return False
        return resp.status_code in (200, 401)

    async def close(self) -> None:
        """Close the HTTP client, if one is open, and drop the reference."""
        if self._client and not self._client.is_closed:
            await self._client.aclose()
        self._client = None

    async def __aenter__(self) -> HttpRAGService:
        return self

    async def __aexit__(self, *args: object) -> None:
        await self.close()
|