fischer-agentkit/src/agentkit/memory/adapters/generic_http.py

319 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""GenericHTTPAdapter - 通用 HTTP 知识库适配器
配置 API endpoint + auth 即可对接任意 HTTP 知识库服务。
实现 KnowledgeBase 协议,提供统一的检索接口。
"""
from __future__ import annotations
import ipaddress
import logging
from typing import Any
from urllib.parse import urlparse
import httpx
from agentkit.memory.adapters.base import KBAdapter
from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
def _is_safe_url(url: str) -> bool:
    """Return True if *url* looks safe to use as a knowledge-base endpoint.

    This is a best-effort SSRF guard that inspects the URL text only:

    - only ``http`` / ``https`` schemes are accepted;
    - a hostname must be present;
    - ``localhost`` (and any ``*.localhost`` subdomain) and well-known cloud
      metadata hostnames are rejected, ignoring a trailing dot;
    - literal IP addresses are rejected when private, loopback, link-local,
      reserved, multicast, unspecified or not globally routable. This also
      covers legacy numeric IPv4 spellings such as ``127.1``, ``0x7f000001``
      or ``2130706433``, which the platform resolver still accepts.

    Domain names are NOT resolved here, so a public name that resolves to an
    internal address (or DNS rebinding) is not caught by this check.

    Args:
        url: The URL to validate.

    Returns:
        True if the URL passes every check; False otherwise, including when
        the URL cannot be parsed.
    """
    import socket  # local import, matching the file's style for rarely used modules

    try:
        parsed = urlparse(url)
        if parsed.scheme not in ("http", "https"):
            return False
        # ``hostname`` is already lower-cased, with IPv6 brackets removed.
        hostname = parsed.hostname
        if not hostname:
            return False
        # "localhost." is the same fully-qualified name as "localhost"; strip
        # trailing dots so the blocklist cannot be sidestepped with them.
        hostname = hostname.rstrip(".")
        if not hostname:
            return False
        # Block common internal hostnames.
        if hostname == "localhost" or hostname.endswith(".localhost"):
            return False
        if hostname in ("metadata.google.internal", "metadata.internal"):
            return False
        try:
            # Drop an IPv6 zone id ("fe80::1%eth0") so the address still parses.
            ip = ipaddress.ip_address(hostname.split("%", 1)[0])
        except ValueError:
            # Not a canonical IP literal. It may still be a legacy numeric
            # IPv4 form that the OS resolver maps to an address; inet_aton
            # applies those same parsing rules.
            try:
                ip = ipaddress.IPv4Address(socket.inet_aton(hostname))
            except (OSError, ValueError):
                return True  # an ordinary domain name
        return not (
            ip.is_private
            or ip.is_loopback
            or ip.is_reserved
            or ip.is_link_local
            or ip.is_multicast
            or ip.is_unspecified
            or not ip.is_global  # e.g. carrier-grade NAT 100.64.0.0/10
        )
    except Exception:
        return False
class GenericHTTPAdapter(KBAdapter):
    """Generic HTTP knowledge-base adapter.

    Connects to any knowledge-base service that offers an HTTP API. Configure
    the endpoint URL plus optional authentication settings and extra headers.

    Expected API (paths relative to ``endpoint_url``):

    - ``POST /search``            -> semantic search
    - ``POST /ingest``            -> document ingestion (optional)
    - ``GET /sources``            -> list information sources (optional)
    - ``GET /health``             -> health check (optional)
    - ``GET /documents/{id}``     -> fetch a single document (optional)
    - ``DELETE /documents/{id}``  -> delete a single document (optional)

    Typical configuration::

        adapter = GenericHTTPAdapter(
            endpoint_url="https://kb.example.com/api/knowledge",
            auth_config={"type": "bearer", "token": "sk-xxx"},
            headers={"X-Custom-Header": "value"},
        )

    ``endpoint_url`` is checked with ``_is_safe_url``. Loopback, private and
    other internal addresses (e.g. ``localhost``) are rejected with
    ``ValueError``.
    """

    def __init__(
        self,
        endpoint_url: str,
        auth_config: dict[str, str] | None = None,
        headers: dict[str, str] | None = None,
        source_id: str | None = None,
        source_name: str = "HTTP Knowledge Base",
        timeout: int = 30,
    ):
        """Create the adapter.

        Args:
            endpoint_url: Base URL of the knowledge-base API. A trailing ``/``
                is stripped.
            auth_config: Authentication settings. Supported forms:
                ``{"type": "bearer", "token": ...}``;
                ``{"type": "basic", "username": ..., "password": ...}``;
                ``{"type": "api_key", "api_key": ..., "header_name": ...}``
                (the header defaults to ``X-API-Key``). Any other type sends
                no auth header.
            headers: Extra headers sent with every request.
            source_id: Identifier of this source. Defaults to ``"http-"``
                plus the last ``/``-separated component of ``endpoint_url``.
            source_name: Human-readable name of this source.
            timeout: Request timeout in seconds, forwarded to the base adapter
                and used for the HTTP client.

        Raises:
            ValueError: If ``endpoint_url`` is rejected by ``_is_safe_url``.
        """
        endpoint_url = endpoint_url.rstrip("/")
        # Validate before initialising base-class state, so an unsafe URL never
        # produces a half-initialised adapter.
        if not _is_safe_url(endpoint_url):
            raise ValueError(
                f"Unsafe endpoint_url: {endpoint_url}. "
                "Private/internal URLs are not allowed."
            )
        super().__init__(
            source_id=source_id or f"http-{endpoint_url.split('/')[-1]}",
            source_name=source_name,
            source_type="generic_http",
            timeout=timeout,
        )
        self._endpoint_url = endpoint_url
        self._auth_config = auth_config or {}
        self._extra_headers = headers or {}

    def _make_client(self) -> httpx.AsyncClient:
        """Build the ``httpx.AsyncClient`` used for all requests.

        Headers are layered in this order: the JSON content type, then the
        user-supplied extra headers (which may override it), then the
        authentication header derived from ``auth_config``.
        """
        import base64

        headers: dict[str, str] = {
            "Content-Type": "application/json",
            **self._extra_headers,
        }
        auth_type = self._auth_config.get("type", "")
        if auth_type == "bearer":
            token = self._auth_config.get("token", "")
            if token:
                headers["Authorization"] = f"Bearer {token}"
        elif auth_type == "basic":
            username = self._auth_config.get("username", "")
            password = self._auth_config.get("password", "")
            if username and password:
                credentials = base64.b64encode(
                    f"{username}:{password}".encode()
                ).decode()
                headers["Authorization"] = f"Basic {credentials}"
        elif auth_type == "api_key":
            key_name = self._auth_config.get("header_name", "X-API-Key")
            key_value = self._auth_config.get("api_key", "")
            if key_value:
                headers[key_name] = key_value
        return httpx.AsyncClient(
            base_url=self._endpoint_url,
            headers=headers,
            timeout=self._timeout,
        )

    @staticmethod
    def _extract_items(data: Any, key: str) -> list[Any] | None:
        """Return the item list from a ``[...]`` or ``{key: [...]}`` response.

        Returns None when *data* has neither shape.
        """
        if isinstance(data, list):
            return data
        if isinstance(data, dict) and isinstance(data.get(key), list):
            return data[key]
        return None

    @staticmethod
    def _path_segment(value: Any) -> str:
        """Percent-encode *value* so it forms exactly one URL path segment."""
        from urllib.parse import quote

        segment = quote(str(value), safe="")
        # quote() leaves "." unescaped, and the dot-segments "." / ".." may be
        # collapsed by URL normalisation, escaping the /documents/ prefix.
        if segment in (".", ".."):
            segment = segment.replace(".", "%2E")
        return segment

    def _default_sources(self) -> list[SourceInfo]:
        """Return a one-element list describing this adapter as the source."""
        return [
            SourceInfo(
                source_id=self._source_id,
                source_name=self._source_name,
                source_type=self._source_type,
            )
        ]

    async def search(self, query: str, top_k: int = 5) -> list[QueryResult]:
        """Search the HTTP knowledge base.

        Sends ``POST {endpoint_url}/search`` with body
        ``{"query": query, "top_k": top_k}``. The response may be
        ``{"results": [...]}`` or a bare ``[...]``. Each dict item is mapped to
        a ``QueryResult``:

        - ``doc_id`` falls back to ``id``;
        - a missing or null ``score`` becomes 0.0;
        - null ``metadata`` becomes ``{}``.

        Non-dict items are ignored. Items that cannot be converted (e.g. a
        non-numeric score) are skipped with a warning.

        Returns:
            At most ``top_k`` results. HTTP, network and decoding errors are
            logged and yield an empty list.
        """
        client = self._get_client()
        try:
            resp = await client.post(
                "/search", json={"query": query, "top_k": top_k}
            )
            resp.raise_for_status()
            data = resp.json()
            items = self._extract_items(data, "results")
            if items is None:
                logger.warning("Unexpected search response format: %s", type(data))
                return []
            results: list[QueryResult] = []
            for item in items:
                if not isinstance(item, dict):
                    continue
                try:
                    score = item.get("score")
                    results.append(
                        QueryResult(
                            content=item.get("content", ""),
                            source_id=self._source_id,
                            source_name=self._source_name,
                            score=float(score) if score is not None else 0.0,
                            metadata=item.get("metadata") or {},
                            doc_id=item.get("doc_id", item.get("id", "")),
                            title=item.get("title", ""),
                        )
                    )
                except (TypeError, ValueError) as e:
                    # One bad item should not discard every other result.
                    logger.warning("Skipping malformed GenericHTTP search result: %s", e)
            return results[:top_k]
        except httpx.HTTPStatusError as e:
            logger.error(
                "GenericHTTP search HTTP error: %s - %s",
                e.response.status_code,
                e.response.text[:200],
            )
            return []
        except Exception as e:
            logger.error("GenericHTTP search error: %s", e)
            return []

    async def ingest(self, documents: list[Document]) -> list[str]:
        """Write documents to the HTTP knowledge base.

        Sends ``POST {endpoint_url}/ingest`` with body ``{"documents": [...]}``.
        Each document is serialised as ``doc_id``, ``content``, ``title``,
        ``source_id`` and ``metadata``.

        Returns:
            - the IDs reported by the server (``{"ids": [...]}`` or a bare
              list), converted to ``str``;
            - the submitted documents' IDs if the response has neither shape;
            - ``[]`` for an empty input, without sending a request;
            - ``[]`` after an HTTP, network or decoding error, which is logged.
        """
        if not documents:
            return []
        client = self._get_client()
        try:
            payload = {
                "documents": [
                    {
                        "doc_id": doc.doc_id,
                        "content": doc.content,
                        "title": doc.title,
                        "source_id": doc.source_id,
                        "metadata": doc.metadata,
                    }
                    for doc in documents
                ]
            }
            resp = await client.post("/ingest", json=payload)
            resp.raise_for_status()
            ids = self._extract_items(resp.json(), "ids")
            if ids is None:
                return [doc.doc_id for doc in documents]
            return [str(doc_id) for doc_id in ids]
        except httpx.HTTPStatusError as e:
            logger.error(
                "GenericHTTP ingest HTTP error: %s - %s",
                e.response.status_code,
                e.response.text[:200],
            )
            return []
        except Exception as e:
            logger.error("GenericHTTP ingest error: %s", e)
            return []

    async def delete_by_id(self, id: str) -> bool:  # noqa: A002 - public keyword name
        """Delete a document via ``DELETE {endpoint_url}/documents/{id}``.

        The ID is percent-encoded as a single path segment, so IDs containing
        ``/``, ``?``, ``#`` or dot segments cannot address a different URL.

        Returns:
            True if the server answered 200 or 204. False for any other status,
            for an empty ID, or after a request error (which is logged).
        """
        if not id:
            return False
        client = self._get_client()
        try:
            resp = await client.delete(f"/documents/{self._path_segment(id)}")
        except Exception as e:
            logger.error("GenericHTTP delete_by_id error: %s", e)
            return False
        return resp.status_code in (200, 204)

    async def get_document(self, doc_id: str) -> Document | None:
        """Fetch a single document via ``GET {endpoint_url}/documents/{doc_id}``.

        In the returned document, ``doc_id`` falls back to ``id`` and then to
        the requested ID. ``source_id`` falls back to this adapter's ID.

        Returns:
            The document, or None in any of these cases:

            - the ID is empty;
            - the server answers 404;
            - the response is not a JSON object;
            - any other error occurs (logged).
        """
        if not doc_id:
            return None
        client = self._get_client()
        try:
            resp = await client.get(f"/documents/{self._path_segment(doc_id)}")
            if resp.status_code == 404:
                # A missing document is an expected outcome, not an error.
                logger.debug("GenericHTTP document not found: %s", doc_id)
                return None
            resp.raise_for_status()
            data = resp.json()
            if not isinstance(data, dict):
                logger.warning(
                    "Unexpected get_document response format: %s", type(data)
                )
                return None
            return Document(
                doc_id=data.get("doc_id", data.get("id", doc_id)),
                content=data.get("content", ""),
                title=data.get("title", ""),
                source_id=data.get("source_id", self._source_id),
                metadata=data.get("metadata") or {},
            )
        except Exception as e:
            logger.error("GenericHTTP get_document error: %s", e)
            return None

    async def list_sources(self) -> list[SourceInfo]:
        """List the information sources exposed by the service.

        Calls ``GET {endpoint_url}/sources`` and accepts a bare ``[...]`` or a
        ``{"sources": [...]}`` response. In these cases a single entry
        describing this adapter is returned instead, so the result is never
        empty:

        - the request fails (the endpoint is optional);
        - the response has another shape;
        - the response lists no dict items.
        """
        client = self._get_client()
        try:
            resp = await client.get("/sources")
            resp.raise_for_status()
            items = self._extract_items(resp.json(), "sources") or []
            sources = [
                SourceInfo(
                    source_id=item.get("source_id", ""),
                    source_name=item.get("source_name", ""),
                    source_type=item.get("source_type", "generic_http"),
                    document_count=item.get("document_count", 0),
                )
                for item in items
                if isinstance(item, dict)
            ]
        except Exception as e:
            logger.debug(
                "GenericHTTP list_sources error (endpoint may not exist): %s", e
            )
            return self._default_sources()
        return sources or self._default_sources()

    async def _probe(self) -> int | None:
        """Return the status code used to judge service reachability.

        ``GET /health`` is tried first; if it answers 200, 200 is returned.
        Otherwise (non-200 status or request error), ``GET`` on the base
        endpoint is tried and its status returned. Returns None if that
        request fails too.
        """
        client = self._get_client()
        try:
            resp = await client.get("/health")
            if resp.status_code == 200:
                return 200
        except Exception as e:
            logger.debug("GenericHTTP /health probe failed: %s", e)
        try:
            resp = await client.get("/")
        except Exception as e:
            logger.debug("GenericHTTP base endpoint probe failed: %s", e)
            return None
        return resp.status_code

    async def health_check(self) -> bool:
        """Return True if the HTTP knowledge-base service is reachable.

        Counts as healthy: a 200 from ``/health``, or a 200/401/403 from the
        base endpoint. A 401/403 still shows a live server that merely rejects
        the request's credentials.
        """
        return await self._probe() in (200, 401, 403)

    async def authenticate(self) -> bool:
        """Verify that the service accepts the configured credentials.

        Uses the same probes as ``health_check``, but treats a 401/403 from
        the base endpoint as an authentication failure. If ``/health`` does
        not check credentials, a 200 there only proves reachability. The
        outcome is stored in ``self._authenticated`` and returned; any error
        yields False.
        """
        try:
            self._authenticated = await self._probe() == 200
        except Exception:
            self._authenticated = False
        return self._authenticated