319 lines
11 KiB
Python
319 lines
11 KiB
Python
"""GenericHTTPAdapter - 通用 HTTP 知识库适配器
|
||
|
||
配置 API endpoint + auth 即可对接任意 HTTP 知识库服务。
|
||
实现 KnowledgeBase 协议,提供统一的检索接口。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import ipaddress
|
||
import logging
|
||
from typing import Any
|
||
from urllib.parse import urlparse
|
||
|
||
import httpx
|
||
|
||
from agentkit.memory.adapters.base import KBAdapter
|
||
from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def _is_safe_url(url: str) -> bool:
    """Return True when ``url`` is an http(s) URL that does not target an internal host.

    This is a syntactic SSRF guard. It rejects:

    - non-string input and schemes other than ``http`` / ``https``;
    - URLs without a hostname;
    - ``localhost`` (and any ``*.localhost`` name) plus well-known cloud
      metadata hostnames, with or without a trailing dot;
    - IP literals that are private, loopback, link-local, reserved,
      multicast or unspecified, including legacy IPv4 spellings such as
      ``2130706433``, ``127.1`` or ``0x7f.0.0.1`` that resolvers accept.

    Limitation: hostnames are NOT resolved through DNS, so a public-looking
    domain that resolves to a private address is not detected here.

    Args:
        url: Candidate endpoint URL.

    Returns:
        True if the URL passes every check, False otherwise. Never raises.
    """
    import socket  # function-scope import: only this validator needs it

    if not isinstance(url, str):
        return False

    try:
        parsed = urlparse(url)
        hostname = parsed.hostname
    except ValueError:
        # urlparse rejects some malformed netlocs (e.g. broken IPv6 brackets).
        return False

    if parsed.scheme not in ("http", "https") or not hostname:
        return False

    # "localhost." and "127.0.0.1." are equivalent to their dot-less forms
    # for resolvers, so normalise before comparing.
    host = hostname.rstrip(".")
    if not host:
        return False

    # Block common internal hostnames.
    blocked_names = ("localhost", "metadata.google.internal", "metadata.internal")
    if host in blocked_names or host.endswith(".localhost"):
        return False

    try:
        ip = ipaddress.ip_address(host)
    except ValueError:
        # Not a canonical IP literal. It may still be a legacy numeric IPv4
        # form that the OS resolver maps to an address; inet_aton parses those
        # without doing any DNS lookup.
        try:
            ip = ipaddress.IPv4Address(socket.inet_aton(host))
        except (OSError, ValueError):
            # An ordinary domain name: accepted (see DNS limitation above).
            return True

    return not (
        ip.is_private
        or ip.is_loopback
        or ip.is_reserved
        or ip.is_link_local
        or ip.is_multicast
        or ip.is_unspecified
    )
|
||
|
||
|
||
class GenericHTTPAdapter(KBAdapter):
    """Generic HTTP knowledge-base adapter.

    Connects to any knowledge-base service that exposes an HTTP API, given an
    endpoint URL and optional authentication settings.

    Expected API surface (paths are relative to ``endpoint_url``):

    - ``POST   /search``           semantic search
    - ``POST   /ingest``           document ingestion (optional)
    - ``GET    /sources``          list sources (optional)
    - ``GET    /health``           health check (optional)
    - ``GET    /documents/{id}``   fetch one document
    - ``DELETE /documents/{id}``   delete one document

    Typical configuration::

        adapter = GenericHTTPAdapter(
            endpoint_url="https://kb.example.com/api/knowledge",
            auth_config={"type": "bearer", "token": "sk-xxx"},
            headers={"X-Custom-Header": "value"},
        )

    Note:
        ``endpoint_url`` is validated with ``_is_safe_url``; loopback and
        private targets such as ``http://localhost:8000`` raise ``ValueError``.
        ``self._get_client()``, ``self._timeout``, ``self._source_id``,
        ``self._source_name``, ``self._source_type`` and
        ``self._authenticated`` are not defined in this class; presumably they
        come from ``KBAdapter`` (confirm against the base class).
    """

    def __init__(
        self,
        endpoint_url: str,
        auth_config: dict[str, str] | None = None,
        headers: dict[str, str] | None = None,
        source_id: str | None = None,
        source_name: str = "HTTP Knowledge Base",
        timeout: int = 30,
    ):
        """Validate the endpoint and store connection settings.

        Args:
            endpoint_url: Base URL of the knowledge-base API; trailing slashes
                are stripped.
            auth_config: Optional auth settings. ``type`` selects the scheme:
                ``"bearer"`` (``token``), ``"basic"`` (``username`` /
                ``password``) or ``"api_key"`` (``api_key`` and optional
                ``header_name``, default ``X-API-Key``).
            headers: Extra headers sent with every request.
            source_id: Explicit source id; defaults to ``"http-"`` plus the
                last path segment of the endpoint URL.
            source_name: Human-readable source name.
            timeout: Request timeout forwarded to the base class.

        Raises:
            ValueError: If ``endpoint_url`` fails the ``_is_safe_url`` check.
        """
        normalized_url = endpoint_url.rstrip("/")
        # Validate before touching the base class so a rejected URL never
        # leaves a half-initialised adapter behind.
        if not _is_safe_url(normalized_url):
            raise ValueError(f"Unsafe endpoint_url: {normalized_url}. Private/internal URLs are not allowed.")

        super().__init__(
            source_id=source_id or f"http-{normalized_url.split('/')[-1]}",
            source_name=source_name,
            source_type="generic_http",
            timeout=timeout,
        )
        self._endpoint_url = normalized_url
        self._auth_config = auth_config or {}
        self._extra_headers = headers or {}

    def _make_client(self) -> httpx.AsyncClient:
        """Build an ``httpx.AsyncClient`` bound to the endpoint with auth headers.

        Extra headers may override ``Content-Type``; the auth header derived
        from ``auth_config`` is applied last and therefore wins over an extra
        header of the same name.
        """
        headers: dict[str, str] = {
            "Content-Type": "application/json",
            **self._extra_headers,
        }

        # Configure authentication.
        auth_type = self._auth_config.get("type", "")
        if auth_type == "bearer":
            token = self._auth_config.get("token", "")
            if token:
                headers["Authorization"] = f"Bearer {token}"
        elif auth_type == "basic":
            import base64

            username = self._auth_config.get("username", "")
            password = self._auth_config.get("password", "")
            if username and password:
                credentials = base64.b64encode(
                    f"{username}:{password}".encode()
                ).decode()
                headers["Authorization"] = f"Basic {credentials}"
        elif auth_type == "api_key":
            key_name = self._auth_config.get("header_name", "X-API-Key")
            key_value = self._auth_config.get("api_key", "")
            if key_value:
                headers[key_name] = key_value
        elif auth_type:
            # An unrecognised type previously resulted in silently
            # unauthenticated requests; make the misconfiguration visible.
            logger.warning("GenericHTTP unsupported auth type: %s", auth_type)

        return httpx.AsyncClient(
            base_url=self._endpoint_url,
            headers=headers,
            timeout=self._timeout,
        )

    @staticmethod
    def _coerce_score(raw: Any) -> float:
        """Convert a response ``score`` to float, falling back to 0.0."""
        try:
            return float(raw)
        except (TypeError, ValueError):
            return 0.0

    @staticmethod
    def _document_path(doc_id: str) -> str | None:
        """Return the request path for ``doc_id``, or None if it is unusable.

        The id is percent-encoded as a single path segment so characters such
        as ``/``, ``?`` or ``#`` cannot change the request target. Empty ids
        and the dot segments ``.`` / ``..`` are rejected because they would
        address the collection or a parent path instead of one document.
        """
        from urllib.parse import quote

        segment = "" if doc_id is None else str(doc_id)
        if segment in ("", ".", ".."):
            return None
        return f"/documents/{quote(segment, safe='')}"

    def _fallback_sources(self) -> list[SourceInfo]:
        """Return a single-entry source list describing this adapter itself."""
        return [
            SourceInfo(
                source_id=self._source_id,
                source_name=self._source_name,
                source_type=self._source_type,
            )
        ]

    async def search(self, query: str, top_k: int = 5) -> list[QueryResult]:
        """Search the HTTP knowledge base.

        Sends ``POST {endpoint_url}/search`` with body
        ``{"query": ..., "top_k": ...}``. Two response shapes are accepted:
        ``{"results": [...]}`` or a bare list. Non-dict entries are skipped.

        Args:
            query: Search text.
            top_k: Maximum number of results; values <= 0 yield ``[]`` without
                issuing a request.

        Returns:
            At most ``top_k`` results; ``[]`` on any HTTP or parsing error
            (errors are logged, never raised).
        """
        if top_k <= 0:
            return []

        client = self._get_client()
        try:
            payload = {"query": query, "top_k": top_k}
            resp = await client.post("/search", json=payload)
            resp.raise_for_status()
            data = resp.json()

            if isinstance(data, dict) and "results" in data:
                items = data["results"]
            elif isinstance(data, list):
                items = data
            else:
                logger.warning("Unexpected search response format: %s", type(data))
                return []

            if not isinstance(items, list):
                logger.warning("Unexpected search response format: %s", type(items))
                return []

            results: list[QueryResult] = [
                QueryResult(
                    content=item.get("content") or "",
                    source_id=self._source_id,
                    source_name=self._source_name,
                    # A malformed score must not discard the whole result set.
                    score=self._coerce_score(item.get("score", 0.0)),
                    metadata=item.get("metadata") or {},
                    doc_id=item.get("doc_id", item.get("id", "")),
                    title=item.get("title") or "",
                )
                for item in items
                if isinstance(item, dict)
            ]
            return results[:top_k]

        except httpx.HTTPStatusError as e:
            logger.error(
                "GenericHTTP search HTTP error: %s — %s",
                e.response.status_code,
                e.response.text[:200],
            )
            return []
        except Exception as e:
            logger.error("GenericHTTP search error: %s", e)
            return []

    async def ingest(self, documents: list[Document]) -> list[str]:
        """Write documents to the HTTP knowledge base.

        Sends ``POST {endpoint_url}/ingest`` with body ``{"documents": [...]}``.

        Args:
            documents: Documents to ingest; an empty list returns ``[]``
                without issuing a request.

        Returns:
            Ids reported by the service (``{"ids": [...]}`` or a bare list),
            stringified. If the response carries no id list, the ids of the
            submitted documents are returned. ``[]`` on any error.
        """
        if not documents:
            return []

        client = self._get_client()
        try:
            payload = {
                "documents": [
                    {
                        "doc_id": doc.doc_id,
                        "content": doc.content,
                        "title": doc.title,
                        "source_id": doc.source_id,
                        "metadata": doc.metadata,
                    }
                    for doc in documents
                ]
            }
            resp = await client.post("/ingest", json=payload)
            resp.raise_for_status()
            data = resp.json()

            # Accept either response shape and always return list[str].
            if isinstance(data, dict) and isinstance(data.get("ids"), list):
                return [str(item) for item in data["ids"]]
            if isinstance(data, list):
                return [str(item) for item in data]
            return [doc.doc_id for doc in documents]

        except httpx.HTTPStatusError as e:
            logger.error(
                "GenericHTTP ingest HTTP error: %s — %s",
                e.response.status_code,
                e.response.text[:200],
            )
            return []
        except Exception as e:
            logger.error("GenericHTTP ingest error: %s", e)
            return []

    async def delete_by_id(self, id: str) -> bool:
        """Delete a document by id.

        Sends ``DELETE {endpoint_url}/documents/{id}`` with the id
        percent-encoded as one path segment. The parameter name ``id`` shadows
        the builtin but is kept because it is part of the public signature.

        Returns:
            True when the service answers 200 or 204; False otherwise,
            including for an empty or dot-segment id and on any error.
        """
        path = self._document_path(id)
        if path is None:
            logger.warning("GenericHTTP delete_by_id: invalid id %r", id)
            return False

        client = self._get_client()
        try:
            resp = await client.delete(path)
            return resp.status_code in (200, 204)
        except Exception as e:
            logger.error("GenericHTTP delete_by_id error: %s", e)
            return False

    async def get_document(self, doc_id: str) -> Document | None:
        """Fetch a single document.

        Sends ``GET {endpoint_url}/documents/{doc_id}`` with the id
        percent-encoded as one path segment.

        Returns:
            The document, or None when the id is unusable, the request fails,
            or the response is not a JSON object.
        """
        path = self._document_path(doc_id)
        if path is None:
            logger.warning("GenericHTTP get_document: invalid doc_id %r", doc_id)
            return None

        client = self._get_client()
        try:
            resp = await client.get(path)
            resp.raise_for_status()
            data = resp.json()
            if not isinstance(data, dict):
                logger.error(
                    "GenericHTTP get_document error: unexpected response format %s",
                    type(data),
                )
                return None

            return Document(
                doc_id=data.get("doc_id", data.get("id", doc_id)),
                content=data.get("content", ""),
                title=data.get("title", ""),
                source_id=data.get("source_id", self._source_id),
                metadata=data.get("metadata", {}),
            )
        except Exception as e:
            logger.error("GenericHTTP get_document error: %s", e)
            return None

    async def list_sources(self) -> list[SourceInfo]:
        """List the sources exposed by the service.

        Sends ``GET {endpoint_url}/sources``. When the endpoint is missing,
        fails, or returns no usable entries, a single ``SourceInfo`` describing
        this adapter is returned instead.
        """
        client = self._get_client()
        try:
            resp = await client.get("/sources")
            resp.raise_for_status()
            data = resp.json()

            if isinstance(data, list):
                sources = [
                    SourceInfo(
                        source_id=item.get("source_id", ""),
                        source_name=item.get("source_name", ""),
                        source_type=item.get("source_type", "generic_http"),
                        document_count=item.get("document_count", 0),
                    )
                    for item in data
                    if isinstance(item, dict)
                ]
                if sources:
                    return sources
        except Exception as e:
            # The endpoint is optional, so a failure is only worth a debug line.
            logger.debug("GenericHTTP list_sources error (endpoint may not exist): %s", e)

        return self._fallback_sources()

    async def health_check(self) -> bool:
        """Check connectivity to the HTTP knowledge-base service.

        Tries ``GET /health`` first. If that does not answer 200, falls back
        to ``GET /`` and treats 200, 401 and 403 as "reachable" (the server
        responded even if it refused the credentials).
        """
        client = self._get_client()
        try:
            resp = await client.get("/health")
            if resp.status_code == 200:
                return True
        except Exception as e:
            # Best-effort probe: record the reason, then try the fallback.
            logger.debug("GenericHTTP health_check /health probe failed: %s", e)

        # Fallback: try the base endpoint.
        try:
            resp = await client.get("/")
            return resp.status_code in (200, 401, 403)
        except Exception as e:
            logger.debug("GenericHTTP health_check base endpoint probe failed: %s", e)
            return False

    async def authenticate(self) -> bool:
        """Verify authentication.

        For this adapter the result of ``health_check`` is used as the
        authentication status and stored in ``self._authenticated``.
        """
        try:
            self._authenticated = await self.health_check()
        except Exception:
            self._authenticated = False
        return self._authenticated
|