"""GenericHTTPAdapter - 通用 HTTP 知识库适配器 配置 API endpoint + auth 即可对接任意 HTTP 知识库服务。 实现 KnowledgeBase 协议,提供统一的检索接口。 """ from __future__ import annotations import ipaddress import logging from typing import Any from urllib.parse import urlparse import httpx from agentkit.memory.adapters.base import KBAdapter from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo logger = logging.getLogger(__name__) def _is_safe_url(url: str) -> bool: """Check if URL is safe (not pointing to private/internal networks).""" try: parsed = urlparse(url) if parsed.scheme not in ("http", "https"): return False hostname = parsed.hostname if not hostname: return False # Block common internal hostnames if hostname in ("localhost", "metadata.google.internal", "metadata.internal"): return False # Try to resolve as IP and check for private ranges try: ip = ipaddress.ip_address(hostname) if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local: return False except ValueError: # Not an IP address, that's OK (it's a domain name) pass return True except Exception: return False class GenericHTTPAdapter(KBAdapter): """通用 HTTP 知识库适配器 通过配置 API endpoint 和认证信息,对接任意提供 HTTP API 的知识库服务。 期望的 API 接口: - POST {endpoint_url}/search → 语义检索 - POST {endpoint_url}/ingest → 文档写入(可选) - GET {endpoint_url}/sources → 列出信息源(可选) - GET {endpoint_url}/health → 健康检查(可选) 典型配置:: adapter = GenericHTTPAdapter( endpoint_url="http://localhost:8000/api/knowledge", auth_config={"type": "bearer", "token": "sk-xxx"}, headers={"X-Custom-Header": "value"}, ) """ def __init__( self, endpoint_url: str, auth_config: dict[str, str] | None = None, headers: dict[str, str] | None = None, source_id: str | None = None, source_name: str = "HTTP Knowledge Base", timeout: int = 30, ): super().__init__( source_id=source_id or f"http-{endpoint_url.rstrip('/').split('/')[-1]}", source_name=source_name, source_type="generic_http", timeout=timeout, ) self._endpoint_url = endpoint_url.rstrip("/") if not _is_safe_url(self._endpoint_url): raise ValueError(f"Unsafe endpoint_url: {self._endpoint_url}. Private/internal URLs are not allowed.") self._auth_config = auth_config or {} self._extra_headers = headers or {} def _make_client(self) -> httpx.AsyncClient: """创建通用 HTTP 客户端""" headers: dict[str, str] = { "Content-Type": "application/json", **self._extra_headers, } # 配置认证 auth_type = self._auth_config.get("type", "") if auth_type == "bearer": token = self._auth_config.get("token", "") if token: headers["Authorization"] = f"Bearer {token}" elif auth_type == "basic": import base64 username = self._auth_config.get("username", "") password = self._auth_config.get("password", "") if username and password: credentials = base64.b64encode( f"{username}:{password}".encode() ).decode() headers["Authorization"] = f"Basic {credentials}" elif auth_type == "api_key": key_name = self._auth_config.get("header_name", "X-API-Key") key_value = self._auth_config.get("api_key", "") if key_value: headers[key_name] = key_value return httpx.AsyncClient( base_url=self._endpoint_url, headers=headers, timeout=self._timeout, ) async def search(self, query: str, top_k: int = 5) -> list[QueryResult]: """搜索 HTTP 知识库 POST {endpoint_url}/search Body: {"query": ..., "top_k": ...} """ client = self._get_client() try: payload = {"query": query, "top_k": top_k} resp = await client.post("/search", json=payload) resp.raise_for_status() data = resp.json() # 兼容两种响应格式: # 1. {"results": [...]} # 2. [...] if isinstance(data, dict) and "results" in data: items = data["results"] elif isinstance(data, list): items = data else: logger.warning(f"Unexpected search response format: {type(data)}") return [] results: list[QueryResult] = [] for item in items: if isinstance(item, dict): results.append( QueryResult( content=item.get("content", ""), source_id=self._source_id, source_name=self._source_name, score=float(item.get("score", 0.0)), metadata=item.get("metadata", {}), doc_id=item.get("doc_id", item.get("id", "")), title=item.get("title", ""), ) ) return results[:top_k] except httpx.HTTPStatusError as e: logger.error( f"GenericHTTP search HTTP error: {e.response.status_code} — " f"{e.response.text[:200]}" ) return [] except Exception as e: logger.error(f"GenericHTTP search error: {e}") return [] async def ingest(self, documents: list[Document]) -> list[str]: """写入文档到 HTTP 知识库 POST {endpoint_url}/ingest Body: {"documents": [...]} """ client = self._get_client() try: payload = { "documents": [ { "doc_id": doc.doc_id, "content": doc.content, "title": doc.title, "source_id": doc.source_id, "metadata": doc.metadata, } for doc in documents ] } resp = await client.post("/ingest", json=payload) resp.raise_for_status() data = resp.json() # 兼容响应格式 if isinstance(data, dict) and "ids" in data: return data["ids"] elif isinstance(data, list): return [str(item) for item in data] else: return [doc.doc_id for doc in documents] except httpx.HTTPStatusError as e: logger.error( f"GenericHTTP ingest HTTP error: {e.response.status_code} — " f"{e.response.text[:200]}" ) return [] except Exception as e: logger.error(f"GenericHTTP ingest error: {e}") return [] async def delete_by_id(self, id: str) -> bool: """按文档 ID 删除 DELETE {endpoint_url}/documents/{id} """ client = self._get_client() try: resp = await client.delete(f"/documents/{id}") if resp.status_code in (200, 204): return True return False except Exception as e: logger.error(f"GenericHTTP delete_by_id error: {e}") return False async def get_document(self, doc_id: str) -> Document | None: """获取单个文档 GET {endpoint_url}/documents/{doc_id} """ client = self._get_client() try: resp = await client.get(f"/documents/{doc_id}") resp.raise_for_status() data = resp.json() return Document( doc_id=data.get("doc_id", data.get("id", doc_id)), content=data.get("content", ""), title=data.get("title", ""), source_id=data.get("source_id", self._source_id), metadata=data.get("metadata", {}), ) except Exception as e: logger.error(f"GenericHTTP get_document error: {e}") return None async def list_sources(self) -> list[SourceInfo]: """列出信息源 GET {endpoint_url}/sources """ client = self._get_client() try: resp = await client.get("/sources") resp.raise_for_status() data = resp.json() if isinstance(data, list): sources: list[SourceInfo] = [] for item in data: if isinstance(item, dict): sources.append( SourceInfo( source_id=item.get("source_id", ""), source_name=item.get("source_name", ""), source_type=item.get("source_type", "generic_http"), document_count=item.get("document_count", 0), ) ) return sources if sources else [ SourceInfo( source_id=self._source_id, source_name=self._source_name, source_type=self._source_type, ) ] except Exception as e: logger.debug(f"GenericHTTP list_sources error (endpoint may not exist): {e}") return [ SourceInfo( source_id=self._source_id, source_name=self._source_name, source_type=self._source_type, ) ] async def health_check(self) -> bool: """检查 HTTP 知识库服务连接状态""" client = self._get_client() try: resp = await client.get("/health") if resp.status_code == 200: return True except Exception: pass # Fallback: try the base endpoint try: resp = await client.get("/") return resp.status_code in (200, 401, 403) except Exception: return False async def authenticate(self) -> bool: """认证验证 对于 GenericHTTPAdapter,认证通过 health_check 验证。 """ try: self._authenticated = await self.health_check() except Exception: self._authenticated = False return self._authenticated