From 1ccaf56b9a65f9d79810bffca7045978fe23eaa2 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Thu, 25 Jun 2026 23:54:14 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20ce-simplify-code=20=E5=AE=A1?= =?UTF-8?q?=E6=9F=A5=E4=BF=AE=E5=A4=8D=20=E2=80=94=20=E5=8E=BB=E9=87=8D=20?= =?UTF-8?q?+=20=E6=95=88=E7=8E=87=20+=20=E6=AD=BB=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E6=B8=85=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 3 个审查代理(复用/质量/效率)发现 15 个问题,全部修复: 效率与安全(6 项): - MCPClient 缓存 MultiServerMCPClient 单例 + aclose(),修复连接/子进程泄漏 - _rate_limits 清理空 IP 条目,修复 X-Forwarded-For 欺骗下内存泄漏 - _seen_nonces 改用 OrderedDict,O(1) 摊销过期清理 - webhook 后台任务加 Semaphore(20) + 任务引用追踪,限制无界并发 - _build_adapter 用 asyncio.gather 并行解密 secrets - 适配器实例缓存(_adapter_cache),token TTL 缓存跨请求命中 去重(4 项): - header_get 提取到 channels/base.py,4 个适配器统一 import - _get_client/close() 移入 MessageAdapter 基类,子类继承 - URLVerificationChallenge 统一到 base.py,feishu/slack/wecom 共用 - Transport ABC 添加 endpoint_url 属性,from_transport 不再访问私有字段 死代码与类型安全(5 项): - detect_cache_hit 死方法替换为 record_cache_result 公开 API - execution_mode.value == "direct_chat" 改用枚举比较 - 删除 yielded_any 死变量、重复 from fastapi import Request、 多余 getattr 防御 453 tests passed, ruff clean(预存 F841 非本次引入) --- src/agentkit/channels/base.py | 56 +++- src/agentkit/channels/dingtalk.py | 45 +-- src/agentkit/channels/feishu.py | 59 +--- src/agentkit/channels/slack.py | 56 +--- src/agentkit/channels/wecom.py | 49 +--- src/agentkit/llm/cache.py | 18 +- src/agentkit/llm/gateway.py | 6 +- .../llm/providers/litellm_provider.py | 8 +- src/agentkit/mcp/client.py | 36 ++- src/agentkit/mcp/server.py | 1 - src/agentkit/mcp/transport.py | 10 + src/agentkit/server/routes/channels.py | 270 ++++++++++++------ tests/unit/llm/test_cache.py | 41 +-- 13 files changed, 307 insertions(+), 348 deletions(-) diff --git a/src/agentkit/channels/base.py b/src/agentkit/channels/base.py index ebda830..450f307 100644 --- a/src/agentkit/channels/base.py +++ 
b/src/agentkit/channels/base.py @@ -11,6 +11,8 @@ from dataclasses import dataclass, field from enum import Enum from typing import Any +import httpx + class ChannelType(str, Enum): """支持的消息平台渠道。""" @@ -51,17 +53,42 @@ class OutgoingMessage: reply_to_message_id: str | None = None +class URLVerificationChallenge(Exception): + """URL 验证事件 — webhook 端点需返回 ``{"challenge": ...}`` 响应。 + + 飞书 / Slack 在配置 webhook 时发送一次 ``url_verification`` 事件, + 要求服务端原样返回 ``challenge`` 字段以验证 URL 可达。 + """ + + def __init__(self, challenge: str) -> None: + super().__init__(f"URL verification challenge: {challenge}") + self.challenge = challenge + + class MessageAdapter(abc.ABC): """消息适配器 ABC。 生命周期: __init__ → verify_signature() → receive_message() → send_message() → close() - 子类必须实现全部抽象方法。verify_signature 失败时调用方应拒绝处理 - (webhook 端点 fail-closed:Redis 不可用或签名校验失败均返回 503/401, - 不可跳过 nonce dedup 直接处理消息)。 + 子类必须实现 verify_signature / receive_message / send_message 抽象方法。 + ``_get_client`` / ``close`` 已在基类提供(懒构造 httpx 客户端 + 释放资源), + 子类只需在 ``__init__`` 中调用 ``super().__init__()``。 + + verify_signature 失败时调用方应拒绝处理(webhook 端点 fail-closed: + Redis 不可用或签名校验失败均返回 503/401,不可跳过 nonce dedup 直接处理消息)。 """ + def __init__(self) -> None: + # 懒加载 httpx 客户端 — 避免未使用的适配器持有连接池 + self._client: httpx.AsyncClient | None = None + + def _get_client(self) -> httpx.AsyncClient: + """懒构造 httpx 客户端(子类共享此实现)。""" + if self._client is None: + self._client = httpx.AsyncClient(timeout=10.0) + return self._client + @abc.abstractmethod async def verify_signature(self, headers: dict[str, str], body: bytes) -> bool: """验证平台签名/token。返回 True 表示请求可信。""" @@ -74,6 +101,25 @@ class MessageAdapter(abc.ABC): async def send_message(self, message: OutgoingMessage) -> bool: """向平台发送消息。返回 True 表示发送成功。""" - @abc.abstractmethod async def close(self) -> None: - """释放资源(HTTP 客户端、连接池等)。""" + """关闭 httpx 客户端(如已创建)。子类可覆盖以释放额外资源。""" + if self._client is not None: + await self._client.aclose() + self._client = None + + +# 
--------------------------------------------------------------------------- +# 辅助函数 +# --------------------------------------------------------------------------- + + +def header_get(headers: dict[str, str], name: str) -> str | None: + """大小写不敏感的 header 查找。""" + # 直接命中 + if name in headers: + return headers[name] + lower = name.lower() + for k, v in headers.items(): + if k.lower() == lower: + return v + return None diff --git a/src/agentkit/channels/dingtalk.py b/src/agentkit/channels/dingtalk.py index 33a43c9..9b657b4 100644 --- a/src/agentkit/channels/dingtalk.py +++ b/src/agentkit/channels/dingtalk.py @@ -28,6 +28,7 @@ from agentkit.channels.base import ( IncomingMessage, MessageAdapter, OutgoingMessage, + header_get, ) logger = logging.getLogger(__name__) @@ -68,25 +69,15 @@ class DingTalkMessageAdapter(MessageAdapter): robot_code: str, token: str | None = None, ) -> None: + super().__init__() self.app_key = app_key self.app_secret = app_secret self.robot_code = robot_code self.token = token - # 懒加载 httpx 客户端 - self._client: httpx.AsyncClient | None = None # ponytail: 简单 TTL 缓存 (token, expiry)。天花板:单实例内存; # 升级路径:Redis 缓存共享给多实例。 self._token_cache: tuple[str, float] | None = None - # ------------------------------------------------------------------ - # httpx 客户端懒加载 - # ------------------------------------------------------------------ - - def _get_client(self) -> httpx.AsyncClient: - if self._client is None: - self._client = httpx.AsyncClient(timeout=10.0) - return self._client - # ------------------------------------------------------------------ # 签名验证 # ------------------------------------------------------------------ @@ -107,12 +98,12 @@ class DingTalkMessageAdapter(MessageAdapter): """ # Token 校验(若配置) if self.token is not None: - token_header = _header_get(headers, "Token") + token_header = header_get(headers, "Token") if token_header is None or not hmac.compare_digest(token_header, self.token): return False - sign = _header_get(headers, "Sign") - 
timestamp_str = _header_get(headers, "Timestamp") + sign = header_get(headers, "Sign") + timestamp_str = header_get(headers, "Timestamp") if sign is None and timestamp_str is None: # 无签名头:仅当 token 已校验通过才放行 @@ -253,29 +244,3 @@ class DingTalkMessageAdapter(MessageAdapter): except httpx.HTTPError as exc: logger.error("钉钉 accessToken 网络错误: %s", exc) return None - - # ------------------------------------------------------------------ - # 资源释放 - # ------------------------------------------------------------------ - - async def close(self) -> None: - """关闭 httpx 客户端(如已创建)。""" - if self._client is not None: - await self._client.aclose() - self._client = None - - -# --------------------------------------------------------------------------- -# 辅助函数 -# --------------------------------------------------------------------------- - - -def _header_get(headers: dict[str, str], name: str) -> str | None: - """大小写不敏感的 header 查找。""" - if name in headers: - return headers[name] - lower = name.lower() - for k, v in headers.items(): - if k.lower() == lower: - return v - return None diff --git a/src/agentkit/channels/feishu.py b/src/agentkit/channels/feishu.py index 42d94f1..c9d4355 100644 --- a/src/agentkit/channels/feishu.py +++ b/src/agentkit/channels/feishu.py @@ -28,6 +28,8 @@ from agentkit.channels.base import ( IncomingMessage, MessageAdapter, OutgoingMessage, + URLVerificationChallenge, + header_get, ) logger = logging.getLogger(__name__) @@ -45,18 +47,6 @@ _SEND_MESSAGE_URL = "https://open.feishu.cn/open-apis/im/v1/messages" _MENTION_RE = re.compile(r"@_user_\d+\s*") -class URLVerificationChallenge(Exception): - """飞书 URL 验证事件 — webhook 端点需返回 ``{"challenge": ...}`` 响应。 - - 飞书在配置 webhook 时会发送一次 ``url_verification`` 事件,要求服务端 - 原样返回 ``challenge`` 字段以验证 URL 可达。 - """ - - def __init__(self, challenge: str) -> None: - super().__init__(f"URL verification challenge: {challenge}") - self.challenge = challenge - - class SignatureVerificationError(Exception): """事件 ``verification_token`` 
校验失败 — 拒绝处理。""" @@ -82,25 +72,15 @@ class FeishuMessageAdapter(MessageAdapter): encrypt_key: str | None = None, verification_token: str | None = None, ) -> None: + super().__init__() self.app_id = app_id self.app_secret = app_secret self.encrypt_key = encrypt_key self.verification_token = verification_token - # 懒加载 httpx 客户端 — 避免未使用的适配器持有连接池 - self._client: httpx.AsyncClient | None = None # ponytail: 简单 TTL 缓存 (token, expiry)。天花板:单实例内存; # 升级路径:Redis 缓存共享给多实例。 self._token_cache: tuple[str, float] | None = None - # ------------------------------------------------------------------ - # httpx 客户端懒加载 - # ------------------------------------------------------------------ - - def _get_client(self) -> httpx.AsyncClient: - if self._client is None: - self._client = httpx.AsyncClient(timeout=10.0) - return self._client - # ------------------------------------------------------------------ # 签名验证 # ------------------------------------------------------------------ @@ -122,12 +102,12 @@ class FeishuMessageAdapter(MessageAdapter): logger.warning("飞书适配器未配置 encrypt_key — 拒绝所有 webhook 请求") return False - signature = _header_get(headers, "X-Lark-Signature") + signature = header_get(headers, "X-Lark-Signature") if not signature: return False - timestamp_str = _header_get(headers, "X-Lark-Request-Timestamp") - nonce = _header_get(headers, "X-Lark-Request-Nonce") + timestamp_str = header_get(headers, "X-Lark-Request-Timestamp") + nonce = header_get(headers, "X-Lark-Request-Nonce") if not timestamp_str or not nonce: return False @@ -335,30 +315,3 @@ class FeishuMessageAdapter(MessageAdapter): except httpx.HTTPError as exc: logger.error("飞书 tenant_token 网络错误: %s", exc) return None - - # ------------------------------------------------------------------ - # 资源释放 - # ------------------------------------------------------------------ - - async def close(self) -> None: - """关闭 httpx 客户端(如已创建)。""" - if self._client is not None: - await self._client.aclose() - self._client = None - - -# 
--------------------------------------------------------------------------- -# 辅助函数 -# --------------------------------------------------------------------------- - - -def _header_get(headers: dict[str, str], name: str) -> str | None: - """大小写不敏感的 header 查找。""" - # 直接命中 - if name in headers: - return headers[name] - lower = name.lower() - for k, v in headers.items(): - if k.lower() == lower: - return v - return None diff --git a/src/agentkit/channels/slack.py b/src/agentkit/channels/slack.py index e9cf6fd..2087aee 100644 --- a/src/agentkit/channels/slack.py +++ b/src/agentkit/channels/slack.py @@ -30,6 +30,8 @@ from agentkit.channels.base import ( IncomingMessage, MessageAdapter, OutgoingMessage, + URLVerificationChallenge, + header_get, ) logger = logging.getLogger(__name__) @@ -44,18 +46,6 @@ _SEND_MESSAGE_URL = "https://slack.com/api/chat.postMessage" _MENTION_RE = re.compile(r"<@[^>]+>\s*") -class URLVerificationChallenge(Exception): - """Slack URL 验证事件 — webhook 端点需返回 ``{"challenge": ...}`` 响应。 - - Slack 在配置 Events API 时发送 ``url_verification`` 事件,要求服务端 - 原样返回 ``challenge`` 字段以验证 URL 可达。 - """ - - def __init__(self, challenge: str) -> None: - super().__init__(f"Slack URL verification challenge: {challenge}") - self.challenge = challenge - - class SignatureVerificationError(Exception): """``verification_token`` 校验失败 — 拒绝处理。""" @@ -79,20 +69,10 @@ class SlackMessageAdapter(MessageAdapter): signing_secret: str, verification_token: str | None = None, ) -> None: + super().__init__() self.bot_token = bot_token self.signing_secret = signing_secret self.verification_token = verification_token - # 懒加载 httpx 客户端 - self._client: httpx.AsyncClient | None = None - - # ------------------------------------------------------------------ - # httpx 客户端懒加载 - # ------------------------------------------------------------------ - - def _get_client(self) -> httpx.AsyncClient: - if self._client is None: - self._client = httpx.AsyncClient(timeout=10.0) - return self._client # 
------------------------------------------------------------------ # 签名验证 @@ -111,8 +91,8 @@ class SlackMessageAdapter(MessageAdapter): Returns: True 表示签名校验通过。 """ - signature = _header_get(headers, "X-Slack-Signature") - timestamp_str = _header_get(headers, "X-Slack-Request-Timestamp") + signature = header_get(headers, "X-Slack-Signature") + timestamp_str = header_get(headers, "X-Slack-Request-Timestamp") if not signature or not timestamp_str: return False @@ -259,29 +239,3 @@ class SlackMessageAdapter(MessageAdapter): except httpx.HTTPError as exc: logger.error("Slack send_message 网络错误: %s", exc) return False - - # ------------------------------------------------------------------ - # 资源释放 - # ------------------------------------------------------------------ - - async def close(self) -> None: - """关闭 httpx 客户端(如已创建)。""" - if self._client is not None: - await self._client.aclose() - self._client = None - - -# --------------------------------------------------------------------------- -# 辅助函数 -# --------------------------------------------------------------------------- - - -def _header_get(headers: dict[str, str], name: str) -> str | None: - """大小写不敏感的 header 查找。""" - if name in headers: - return headers[name] - lower = name.lower() - for k, v in headers.items(): - if k.lower() == lower: - return v - return None diff --git a/src/agentkit/channels/wecom.py b/src/agentkit/channels/wecom.py index 49aa1e7..429ca60 100644 --- a/src/agentkit/channels/wecom.py +++ b/src/agentkit/channels/wecom.py @@ -29,6 +29,7 @@ from agentkit.channels.base import ( IncomingMessage, MessageAdapter, OutgoingMessage, + header_get, ) logger = logging.getLogger(__name__) @@ -80,25 +81,15 @@ class WeComMessageAdapter(MessageAdapter): token: str, encoding_aes_key: str, ) -> None: + super().__init__() self.corp_id = corp_id self.agent_id = agent_id self.corp_secret = corp_secret self.token = token self.encoding_aes_key = encoding_aes_key - # 懒加载 httpx 客户端 - self._client: httpx.AsyncClient | 
None = None # ponytail: 简单 TTL 缓存。天花板:单实例内存;升级路径:Redis 共享。 self._token_cache: tuple[str, float] | None = None - # ------------------------------------------------------------------ - # httpx 客户端懒加载 - # ------------------------------------------------------------------ - - def _get_client(self) -> httpx.AsyncClient: - if self._client is None: - self._client = httpx.AsyncClient(timeout=10.0) - return self._client - # ------------------------------------------------------------------ # AES 密钥 / 加解密 # ------------------------------------------------------------------ @@ -209,9 +200,9 @@ class WeComMessageAdapter(MessageAdapter): Returns: True 表示签名校验通过。 """ - msg_signature = _header_get(headers, "msg_signature") - timestamp = _header_get(headers, "timestamp") - nonce = _header_get(headers, "nonce") + msg_signature = header_get(headers, "msg_signature") + timestamp = header_get(headers, "timestamp") + nonce = header_get(headers, "nonce") if not msg_signature or not timestamp or not nonce: return False @@ -247,8 +238,8 @@ class WeComMessageAdapter(MessageAdapter): # URL 验证流程 — 内部 XML 包含 EchoStr if "EchoStr" in inner: - timestamp = _header_get(headers, "timestamp") or "" - nonce = _header_get(headers, "nonce") or "" + timestamp = header_get(headers, "timestamp") or "" + nonce = header_get(headers, "nonce") or "" encrypted_echo = self._encrypt(inner["EchoStr"]) # 计算响应签名 sig_parts = sorted([self.token, timestamp, nonce, encrypted_echo]) @@ -362,29 +353,3 @@ class WeComMessageAdapter(MessageAdapter): except httpx.HTTPError as exc: logger.error("企微 access_token 网络错误: %s", exc) return None - - # ------------------------------------------------------------------ - # 资源释放 - # ------------------------------------------------------------------ - - async def close(self) -> None: - """关闭 httpx 客户端(如已创建)。""" - if self._client is not None: - await self._client.aclose() - self._client = None - - -# --------------------------------------------------------------------------- -# 辅助函数 -# 
--------------------------------------------------------------------------- - - -def _header_get(headers: dict[str, str], name: str) -> str | None: - """大小写不敏感的 header 查找。""" - if name in headers: - return headers[name] - lower = name.lower() - for k, v in headers.items(): - if k.lower() == lower: - return v - return None diff --git a/src/agentkit/llm/cache.py b/src/agentkit/llm/cache.py index 108d209..c58e724 100644 --- a/src/agentkit/llm/cache.py +++ b/src/agentkit/llm/cache.py @@ -803,18 +803,16 @@ class LitellmCacheManager: """返回 litellm acompletion 的 cache 参数(禁用缓存)。""" return {"no-cache": True} - def detect_cache_hit(self, response: Any) -> bool: - """检测 LiteLLM 响应是否为缓存命中。 + def record_cache_result(self, is_hit: bool) -> None: + """记录单次 LLM 调用的缓存命中/未命中(用于 usage tracking 统计)。 - LiteLLM 在缓存命中时设置 ``response._hidden_params["cache_key"]``。 + 命中判定由调用方完成(gateway 通过 ``response.cache_hit`` 判定), + 本方法只负责更新计数器,避免重复检测逻辑。 """ - hidden = getattr(response, "_hidden_params", None) - if isinstance(hidden, dict): - if "cache_key" in hidden or hidden.get("cache_hit"): - self._hits += 1 - return True - self._misses += 1 - return False + if is_hit: + self._hits += 1 + else: + self._misses += 1 def stats(self) -> dict[str, int]: """返回缓存统计。""" diff --git a/src/agentkit/llm/gateway.py b/src/agentkit/llm/gateway.py index 48ea170..691eca9 100644 --- a/src/agentkit/llm/gateway.py +++ b/src/agentkit/llm/gateway.py @@ -224,10 +224,8 @@ class LLMGateway: # U17 — 检测 LiteLLM 缓存命中(用于 usage tracking cost=0) is_cache_hit = getattr(response, "cache_hit", False) - if is_cache_hit and self._cache_manager is not None: - self._cache_manager._hits += 1 - elif self._cache_manager is not None: - self._cache_manager._misses += 1 + if self._cache_manager is not None: + self._cache_manager.record_cache_result(is_cache_hit) # 计算成本(缓存命中时 cost=0) cost = 0.0 if is_cache_hit else self._calculate_cost(response.model, response.usage) diff --git a/src/agentkit/llm/providers/litellm_provider.py 
b/src/agentkit/llm/providers/litellm_provider.py index 56adce4..45d144a 100644 --- a/src/agentkit/llm/providers/litellm_provider.py +++ b/src/agentkit/llm/providers/litellm_provider.py @@ -119,7 +119,6 @@ class LitellmProvider(LLMProvider): accumulated_tool_calls: dict[int, dict[str, Any]] = {} final_usage: TokenUsage | None = None final_model: str = request.model - yielded_any = False try: # litellm.acompletion(stream=True) 的返回类型取决于版本 / 调用方式: @@ -129,7 +128,6 @@ class LitellmProvider(LLMProvider): raw = litellm.acompletion(**kwargs) stream = await raw if inspect.isawaitable(raw) else raw async for chunk in stream: - yielded_any = True parsed = self._parse_stream_chunk( chunk, request.model, @@ -156,10 +154,6 @@ class LitellmProvider(LLMProvider): except Exception as e: raise LLMProviderError(self._provider_type, str(e)) from e - # ponytail: 若流完全为空(yielded_any=False),上面仍会 yield 一个 - # is_final=True 的空 chunk,调用方据此判断空响应。无需额外分支。 - _ = yielded_any # 标记保留(调试 / 未来扩展) - # ------------------------------------------------------------------ # 内部辅助 # ------------------------------------------------------------------ @@ -184,7 +178,7 @@ class LitellmProvider(LLMProvider): if request.timeout is not None: kwargs["timeout"] = request.timeout # U17 — 透传 LiteLLM cache 参数(cache_key 或 no-cache)到 litellm.acompletion - cache_params = getattr(request, "_cache", None) + cache_params = request._cache if cache_params is not None: kwargs["cache"] = cache_params # 合并构造时传入的默认 kwargs(如 max_connections 等provider特定参数) diff --git a/src/agentkit/mcp/client.py b/src/agentkit/mcp/client.py index 202fae3..b912ab9 100644 --- a/src/agentkit/mcp/client.py +++ b/src/agentkit/mcp/client.py @@ -63,6 +63,9 @@ class MCPClient: self._timeout = timeout self._tools_cache: list[dict] | None = None self._transport = transport + # U10 — 懒构造并缓存的 langchain client,避免每次 list_tools/call_tool + # 都新建 MultiServerMCPClient(stdio 传输下会反复 spawn 子进程)。 + self._lc_client: Any = None if transport is not None: # 旧 Transport 路径 — 
发出 DeprecationWarning,但保持原有行为 @@ -138,15 +141,34 @@ class MCPClient: 发出 DeprecationWarning。建议迁移到 URL scheme 自动检测。 """ if isinstance(transport, HTTPTransport): - server_url = transport._endpoint + server_url = transport.endpoint_url elif isinstance(transport, SSETransport): - server_url = transport._endpoint + server_url = transport.endpoint_url elif isinstance(transport, StdioTransport): server_url = f"stdio://{transport._command}" else: server_url = "" return cls(server_url=server_url, transport=transport) + async def _get_lc_client(self) -> Any: + """懒构造并缓存 langchain ``MultiServerMCPClient`` 实例。 + + 首次调用时创建,后续返回缓存,避免每次 list_tools/call_tool 都新建 + client(stdio 传输下会 spawn 新子进程,造成连接/进程泄漏)。 + """ + if self._lc_client is None: + client_cls = _import_langchain_client() + self._lc_client = client_cls({"server": self._langchain_config}) + return self._lc_client + + async def aclose(self) -> None: + """关闭缓存的 langchain client(如果它提供 ``aclose`` 方法)。""" + if self._lc_client is not None: + aclose = getattr(self._lc_client, "aclose", None) + if aclose is not None: + await aclose() + self._lc_client = None + async def list_tools(self) -> list[dict]: """列出远程 MCP Server 上的工具 @@ -165,9 +187,8 @@ class MCPClient: self._tools_cache = tools return self._tools_cache - # 新 langchain 路径 - client_cls = _import_langchain_client() - client = client_cls({"server": self._langchain_config}) + # 新 langchain 路径 — 复用缓存的 client + client = await self._get_lc_client() lc_tools = await client.get_tools() tools = [ { @@ -218,9 +239,8 @@ class MCPClient: params={"name": tool_name, "arguments": arguments}, ) - # 新 langchain 路径 - client_cls = _import_langchain_client() - client = client_cls({"server": self._langchain_config}) + # 新 langchain 路径 — 复用缓存的 client + client = await self._get_lc_client() lc_tools = await client.get_tools() for tool in lc_tools: if tool.name == tool_name: diff --git a/src/agentkit/mcp/server.py b/src/agentkit/mcp/server.py index 5907d29..4195d5e 100644 --- 
a/src/agentkit/mcp/server.py +++ b/src/agentkit/mcp/server.py @@ -230,7 +230,6 @@ class MCPServer: try: from fastapi import FastAPI - from fastapi import Request # noqa: F401 — 用于 jsonrpc_endpoint 的类型注解 except ImportError: raise ImportError("MCP Server requires fastapi: pip install fischer-agentkit[mcp]") diff --git a/src/agentkit/mcp/transport.py b/src/agentkit/mcp/transport.py index 1c7c705..163181f 100644 --- a/src/agentkit/mcp/transport.py +++ b/src/agentkit/mcp/transport.py @@ -98,6 +98,11 @@ class HTTPTransport(Transport): def is_connected(self) -> bool: return self._client is not None and not self._client.is_closed + @property + def endpoint_url(self) -> str: + """已配置的端点 URL(去掉尾部斜杠)。""" + return self._endpoint + async def connect(self) -> None: """建立 HTTP 连接""" if self.is_connected: @@ -220,6 +225,11 @@ class SSETransport(Transport): def is_connected(self) -> bool: return self._connected and self._client is not None and not self._client.is_closed + @property + def endpoint_url(self) -> str: + """已配置的端点 URL(去掉尾部斜杠)。""" + return self._endpoint + async def connect(self) -> None: """建立 SSE 连接 diff --git a/src/agentkit/server/routes/channels.py b/src/agentkit/server/routes/channels.py index 941a169..71b46c2 100644 --- a/src/agentkit/server/routes/channels.py +++ b/src/agentkit/server/routes/channels.py @@ -21,21 +21,25 @@ import asyncio import logging import re import time +from collections import OrderedDict from typing import Any from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import Response from pydantic import BaseModel, Field -from agentkit.channels.base import ChannelType, MessageAdapter, OutgoingMessage -from agentkit.channels.dingtalk import DingTalkMessageAdapter -from agentkit.channels.feishu import FeishuMessageAdapter, URLVerificationChallenge -from agentkit.channels.secrets import SecretsStore -from agentkit.channels.slack import ( - SlackMessageAdapter, - URLVerificationChallenge as 
SlackURLVerificationChallenge, +from agentkit.channels.base import ( + ChannelType, + MessageAdapter, + OutgoingMessage, + URLVerificationChallenge, ) +from agentkit.channels.dingtalk import DingTalkMessageAdapter +from agentkit.channels.feishu import FeishuMessageAdapter +from agentkit.channels.secrets import SecretsStore +from agentkit.channels.slack import SlackMessageAdapter from agentkit.channels.wecom import WeComMessageAdapter, WeComURLVerification +from agentkit.chat.skill_routing import ExecutionMode from agentkit.server.auth.dependencies import require_permission from agentkit.server.auth.permissions import Permission @@ -89,16 +93,40 @@ _RATE_LIMIT_WINDOW = 60.0 # 窗口大小(秒) _RATE_LIMIT_MAX = 100 # 窗口内最大请求数 # nonce -> 过期时间戳(与飞书签名时间戳窗口一致) -_seen_nonces: dict[str, float] = {} +# OrderedDict:按插入顺序遍历,清理过期项时从头部弹出,O(1) 摊销。 +_seen_nonces: OrderedDict[str, float] = OrderedDict() _NONCE_TTL = 300.0 +# Webhook 后台处理并发上限 — 防止高流量下 LLM 调用无界并发 +_WEBHOOK_MAX_CONCURRENT = 20 +_webhook_semaphore: asyncio.Semaphore | None = None +# 持有后台任务引用,防止 GC 回收正在运行的 task +_pending_webhook_tasks: set[asyncio.Task[None]] = set() + +# 适配器缓存 — channel_id -> adapter。避免 per-request 重建导致 token TTL 缓存失效。 +# 配置变更(PUT/DELETE channel)时通过 _invalidate_adapter_cache 清除对应条目。 +_adapter_cache: dict[str, MessageAdapter] = {} + + +def _get_webhook_semaphore() -> asyncio.Semaphore: + """懒构造 webhook 并发信号量(事件循环首次需要时创建)。""" + global _webhook_semaphore + if _webhook_semaphore is None: + _webhook_semaphore = asyncio.Semaphore(_WEBHOOK_MAX_CONCURRENT) + return _webhook_semaphore + def _check_rate_limit(client_ip: str) -> bool: - """滑动窗口限流。返回 True 表示放行,False 表示超限。""" + """滑动窗口限流。返回 True 表示放行,False 表示超限。 + + ponytail: 过滤后 timestamps 为空时清除 IP 条目,避免 _rate_limits 无界增长。 + """ now = time.monotonic() - timestamps = _rate_limits.get(client_ip, []) cutoff = now - _RATE_LIMIT_WINDOW - timestamps = [t for t in timestamps if t > cutoff] + timestamps = [t for t in _rate_limits.get(client_ip, []) if t > cutoff] + if not 
timestamps: + # 清理过期空条目 — 不活跃 IP 不再占用 dict 槽位 + _rate_limits.pop(client_ip, None) if len(timestamps) >= _RATE_LIMIT_MAX: _rate_limits[client_ip] = timestamps return False @@ -108,12 +136,20 @@ def _check_rate_limit(client_ip: str) -> bool: def _check_nonce_dedup(nonce: str) -> bool: - """Nonce 去重。返回 True 表示新 nonce(应处理),False 表示重复(跳过)。""" + """Nonce 去重。返回 True 表示新 nonce(应处理),False 表示重复(跳过)。 + + 使用 OrderedDict 按插入顺序清理过期项:从头部弹出已过期条目, + O(1) 摊销(旧实现遍历整个字典为 O(N))。过期条目清理后,同一 nonce 视为新 nonce(与旧实现一致)。 + """ now = time.monotonic() - # 惰性清理过期项 - expired = [k for k, v in _seen_nonces.items() if v < now] - for k in expired: - del _seen_nonces[k] + # 从头部弹出过期项(nonce 按单调时间插入,头部最旧即最先过期)。 + # ponytail: 摊销 O(1) 清理。旧实现遍历整个 dict 为 O(N)。 + while _seen_nonces: + _oldest_key, oldest_expiry = next(iter(_seen_nonces.items())) + if oldest_expiry >= now: + break + _seen_nonces.popitem(last=False) + # 过期 nonce 已被清理 — 视为新 nonce(与旧实现先删全部过期项再查一致) if nonce in _seen_nonces: return False _seen_nonces[nonce] = now + _NONCE_TTL @@ -124,6 +160,8 @@ def _reset_webhook_state() -> None: """重置限流与 nonce 状态(仅供测试使用)。""" _rate_limits.clear() _seen_nonces.clear() + _adapter_cache.clear() + _pending_webhook_tasks.clear() def _validate_channel_id(channel_id: str) -> str: @@ -288,6 +326,9 @@ async def update_channel( if payload.config is not None: cfg["config"] = payload.config + # 配置变更 — 清除缓存适配器,下次 webhook 重建(凭证可能已更新) + await _invalidate_adapter_cache(channel_id) + return ChannelInfo( channel_id=channel_id, channel_type=ChannelType(cfg["channel_type"]), @@ -311,6 +352,9 @@ async def delete_channel( for name in cfg.get("secret_keys", []): await store.delete_secret(f"{channel_id}:{name}") + # 清除缓存适配器并关闭旧实例 + await _invalidate_adapter_cache(channel_id) + return {"deleted": channel_id} @@ -320,13 +364,20 @@ async def delete_channel( async def _build_adapter(channel_id: str) -> MessageAdapter: - """根据渠道配置与 secrets 构造适配器实例。 + """根据渠道配置与 secrets 构造适配器实例(带缓存)。 支持飞书 / 钉钉 / 企微 / Slack 四种渠道类型,按 ``channel_type`` 分发。 + 首次构造后缓存到 
``_adapter_cache``,后续请求复用同一实例(token TTL 缓存命中)。 + secrets 获取使用 ``asyncio.gather`` 并行,避免串行 await。 Raises: HTTPException: 渠道不存在(404)、渠道类型不支持(400)、缺少必要凭证(500)。 """ + # 命中缓存直接返回(per-request 重建会使 token TTL 缓存失效) + cached = _adapter_cache.get(channel_id) + if cached is not None: + return cached + cfg = _channels.get(channel_id) if cfg is None: raise HTTPException(status_code=404, detail=f"渠道 '{channel_id}' 不存在") @@ -335,43 +386,53 @@ async def _build_adapter(channel_id: str) -> MessageAdapter: channel_type = cfg["channel_type"] if channel_type == ChannelType.FEISHU.value: - app_id = await store.get_secret(f"{channel_id}:app_id") - app_secret = await store.get_secret(f"{channel_id}:app_secret") - encrypt_key = await store.get_secret(f"{channel_id}:encrypt_key") - verification_token = await store.get_secret(f"{channel_id}:verification_token") + app_id, app_secret, encrypt_key, verification_token = await asyncio.gather( + store.get_secret(f"{channel_id}:app_id"), + store.get_secret(f"{channel_id}:app_secret"), + store.get_secret(f"{channel_id}:encrypt_key"), + store.get_secret(f"{channel_id}:verification_token"), + ) if not app_id or not app_secret: raise HTTPException( status_code=500, detail=f"渠道 '{channel_id}' 缺少 app_id 或 app_secret" ) - return FeishuMessageAdapter( + adapter = FeishuMessageAdapter( app_id=app_id, app_secret=app_secret, encrypt_key=encrypt_key, verification_token=verification_token, ) + _adapter_cache[channel_id] = adapter + return adapter if channel_type == ChannelType.DINGTALK.value: - app_key = await store.get_secret(f"{channel_id}:app_key") - app_secret = await store.get_secret(f"{channel_id}:app_secret") - robot_code = await store.get_secret(f"{channel_id}:robot_code") - token = await store.get_secret(f"{channel_id}:token") + app_key, app_secret, robot_code, token = await asyncio.gather( + store.get_secret(f"{channel_id}:app_key"), + store.get_secret(f"{channel_id}:app_secret"), + store.get_secret(f"{channel_id}:robot_code"), + 
store.get_secret(f"{channel_id}:token"), + ) if not all([app_key, app_secret, robot_code]): raise HTTPException( status_code=500, detail=f"渠道 '{channel_id}' 缺少 dingtalk 凭证" ) - return DingTalkMessageAdapter( + adapter = DingTalkMessageAdapter( app_key=app_key, app_secret=app_secret, robot_code=robot_code, token=token, ) + _adapter_cache[channel_id] = adapter + return adapter if channel_type == ChannelType.WECOM.value: - corp_id = await store.get_secret(f"{channel_id}:corp_id") - corp_secret = await store.get_secret(f"{channel_id}:corp_secret") - token = await store.get_secret(f"{channel_id}:token") - encoding_aes_key = await store.get_secret(f"{channel_id}:encoding_aes_key") - agent_id_raw = await store.get_secret(f"{channel_id}:agent_id") + corp_id, corp_secret, token, encoding_aes_key, agent_id_raw = await asyncio.gather( + store.get_secret(f"{channel_id}:corp_id"), + store.get_secret(f"{channel_id}:corp_secret"), + store.get_secret(f"{channel_id}:token"), + store.get_secret(f"{channel_id}:encoding_aes_key"), + store.get_secret(f"{channel_id}:agent_id"), + ) if not all([corp_id, corp_secret, token, encoding_aes_key, agent_id_raw]): raise HTTPException( status_code=500, detail=f"渠道 '{channel_id}' 缺少 wecom 凭证" @@ -383,98 +444,124 @@ async def _build_adapter(channel_id: str) -> MessageAdapter: status_code=500, detail=f"渠道 '{channel_id}' agent_id 不是合法整数", ) from exc - return WeComMessageAdapter( + adapter = WeComMessageAdapter( corp_id=corp_id, agent_id=agent_id, corp_secret=corp_secret, token=token, encoding_aes_key=encoding_aes_key, ) + _adapter_cache[channel_id] = adapter + return adapter if channel_type == ChannelType.SLACK.value: - bot_token = await store.get_secret(f"{channel_id}:bot_token") - signing_secret = await store.get_secret(f"{channel_id}:signing_secret") - verification_token = await store.get_secret(f"{channel_id}:verification_token") + bot_token, signing_secret, verification_token = await asyncio.gather( + store.get_secret(f"{channel_id}:bot_token"), 
+ store.get_secret(f"{channel_id}:signing_secret"), + store.get_secret(f"{channel_id}:verification_token"), + ) if not bot_token or not signing_secret: raise HTTPException( status_code=500, detail=f"渠道 '{channel_id}' 缺少 slack 凭证" ) - return SlackMessageAdapter( + adapter = SlackMessageAdapter( bot_token=bot_token, signing_secret=signing_secret, verification_token=verification_token, ) + _adapter_cache[channel_id] = adapter + return adapter raise HTTPException(status_code=400, detail=f"不支持的渠道类型: {channel_type}") +async def _invalidate_adapter_cache(channel_id: str) -> None: + """清除指定渠道的缓存适配器并关闭旧实例(配置变更/删除时调用)。""" + old = _adapter_cache.pop(channel_id, None) + if old is not None: + try: + await old.close() + except Exception: # noqa: BLE001 — 关闭异常不应阻塞配置变更 + logger.debug("关闭旧适配器异常已忽略: channel_id=%s", channel_id) + + +async def close_all_adapters() -> None: + """关闭所有缓存的适配器(供 app shutdown 调用)。""" + for channel_id, adapter in list(_adapter_cache.items()): + try: + await adapter.close() + except Exception: # noqa: BLE001 + logger.debug("关闭适配器异常已忽略: channel_id=%s", channel_id) + _adapter_cache.clear() + + async def _process_inbound_message( app_state: Any, adapter: MessageAdapter, message: Any ) -> None: """后台处理入站消息 — 调用 chat 链路并通过适配器回复。 整个流程 try/except 包裹,任何异常仅记录日志,不向上抛出 - (webhook 必须保持响应能力)。``adapter.close()`` 在 finally 中调用。 + (webhook 必须保持响应能力)。处理逻辑受全局信号量限流 + (``_WEBHOOK_MAX_CONCURRENT``),防止高流量下 LLM 调用无界并发。 + 适配器由 ``_adapter_cache`` 管理,不在 per-request 关闭(关闭会清空 token 缓存)。 适配器类型不限 — 出站消息的 ``channel`` 取自入站消息以匹配平台。 """ - try: - request_preprocessor = getattr(app_state, "request_preprocessor", None) - llm_gateway = getattr(app_state, "llm_gateway", None) - if request_preprocessor is None or llm_gateway is None: - logger.warning("app.state 缺少 request_preprocessor 或 llm_gateway — 跳过消息处理") - return - - # 路由预处理 — IM 场景使用默认 agent,无需技能注册表 - routing = await request_preprocessor.preprocess( - content=message.content, default_agent_name="default" - ) - - final_content = "" - 
execution_mode = getattr(routing, "execution_mode", None) - # DIRECT_CHAT 模式 — 直接调用 LLM - if execution_mode is not None and execution_mode.value == "direct_chat": - response = await llm_gateway.chat( - messages=[{"role": "user", "content": message.content}], - model=routing.model or "default", - ) - final_content = response.content - else: - # REACT 或其他模式 — 优先使用 ReActEngine,失败回退到 DIRECT_CHAT - try: - from agentkit.core.react import ReActEngine - - engine = ReActEngine(llm_gateway=llm_gateway) - result = await engine.execute( - messages=[{"role": "user", "content": message.content}], - tools=getattr(routing, "tools", None) or None, - model=routing.model or "default", + async with _get_webhook_semaphore(): + try: + request_preprocessor = getattr(app_state, "request_preprocessor", None) + llm_gateway = getattr(app_state, "llm_gateway", None) + if request_preprocessor is None or llm_gateway is None: + logger.warning( + "app.state 缺少 request_preprocessor 或 llm_gateway — 跳过消息处理" ) - final_content = getattr(result, "content", "") or "" - except Exception as exc: # noqa: BLE001 — 回退路径需捕获全部异常 - logger.warning("ReActEngine 执行失败,回退到 DIRECT_CHAT: %s", exc) + return + + # 路由预处理 — IM 场景使用默认 agent,无需技能注册表 + routing = await request_preprocessor.preprocess( + content=message.content, default_agent_name="default" + ) + + final_content = "" + execution_mode = getattr(routing, "execution_mode", None) + # DIRECT_CHAT 模式 — 直接调用 LLM + if execution_mode == ExecutionMode.DIRECT_CHAT: response = await llm_gateway.chat( messages=[{"role": "user", "content": message.content}], model=routing.model or "default", ) final_content = response.content + else: + # REACT 或其他模式 — 优先使用 ReActEngine,失败回退到 DIRECT_CHAT + try: + from agentkit.core.react import ReActEngine - if not final_content: - logger.warning("消息处理未产生内容 — 不发送回复") - return + engine = ReActEngine(llm_gateway=llm_gateway) + result = await engine.execute( + messages=[{"role": "user", "content": message.content}], + tools=getattr(routing, 
"tools", None) or None, + model=routing.model or "default", + ) + final_content = getattr(result, "content", "") or "" + except Exception as exc: # noqa: BLE001 — 回退路径需捕获全部异常 + logger.warning("ReActEngine 执行失败,回退到 DIRECT_CHAT: %s", exc) + response = await llm_gateway.chat( + messages=[{"role": "user", "content": message.content}], + model=routing.model or "default", + ) + final_content = response.content - outgoing = OutgoingMessage( - channel=message.channel, - chat_id=message.chat_id, - content=final_content, - ) - await adapter.send_message(outgoing) - except Exception as exc: # noqa: BLE001 — webhook 必须保持响应能力 - logger.exception("处理入站消息失败: %s", exc) - finally: - try: - await adapter.close() - except Exception: # noqa: BLE001 - logger.debug("adapter.close() 异常已忽略") + if not final_content: + logger.warning("消息处理未产生内容 — 不发送回复") + return + + outgoing = OutgoingMessage( + channel=message.channel, + chat_id=message.chat_id, + content=final_content, + ) + await adapter.send_message(outgoing) + except Exception as exc: # noqa: BLE001 — webhook 必须保持响应能力 + logger.exception("处理入站消息失败: %s", exc) def _get_client_ip(request: Request) -> str: @@ -527,7 +614,7 @@ async def channel_webhook(channel_id: str, request: Request) -> Any: try: message = await adapter.receive_message(headers_dict, body) - except (URLVerificationChallenge, SlackURLVerificationChallenge) as e: + except URLVerificationChallenge as e: # URL 验证流程 — 飞书 / Slack 配置 webhook 时发送 return {"challenge": e.challenge} except WeComURLVerification as e: @@ -535,6 +622,9 @@ async def channel_webhook(channel_id: str, request: Request) -> Any: return Response(content=e.response_xml, media_type="application/xml") # 异步处理 — 不阻塞 webhook 响应(平台要求快速返回 200) - asyncio.create_task(_process_inbound_message(request.app.state, adapter, message)) + # 持有 task 引用防止 GC 回收正在运行的后台任务 + task = asyncio.create_task(_process_inbound_message(request.app.state, adapter, message)) + _pending_webhook_tasks.add(task) + 
task.add_done_callback(_pending_webhook_tasks.discard) return {"code": 0} diff --git a/tests/unit/llm/test_cache.py b/tests/unit/llm/test_cache.py index bc3f142..45cefe8 100644 --- a/tests/unit/llm/test_cache.py +++ b/tests/unit/llm/test_cache.py @@ -9,7 +9,7 @@ 6. kb_acl_hash 隔离 — 不同 ACL hash 产生不同 key 7. kb_caching_disabled 禁用缓存(安全要求 c) 8. cache_params_for_hit / no_cache — 返回正确 dict -9. detect_cache_hit — _hidden_params 含 cache_key 时返回 True +9. record_cache_result — 记录命中/未命中到 stats 计数器 10. LitellmCacheConfig.from_cache_config — 转换正确,similarity_threshold=0.87 11. LitellmCacheManager.enable/disable — litellm.cache 正确设置/清除 12. generate_cache_key 向后兼容 — user_id=None, kb_acl_hash=None 时与旧版相同 @@ -163,10 +163,10 @@ class TestCacheStats: manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) # 2 hits - manager.detect_cache_hit(_make_litellm_response(cache_key="k1")) - manager.detect_cache_hit(_make_litellm_response(cache_key="k2")) + manager.record_cache_result(True) + manager.record_cache_result(True) # 1 miss - manager.detect_cache_hit(_make_litellm_response()) # 无 cache_key + manager.record_cache_result(False) stats = manager.stats() assert stats["total_hits"] == 2 @@ -302,39 +302,6 @@ class TestCacheParams: assert params == {"no-cache": True} -# --------------------------------------------------------------------------- -# 9. 
detect_cache_hit -# --------------------------------------------------------------------------- - - -class TestDetectCacheHit: - def test_hit_with_cache_key(self): - manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) - resp = _make_litellm_response(cache_key="some_key") - assert manager.detect_cache_hit(resp) is True - - def test_hit_with_cache_hit_flag(self): - manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) - resp = _make_litellm_response() - resp._hidden_params = {"cache_hit": True} - assert manager.detect_cache_hit(resp) is True - - def test_miss_without_cache_key(self): - manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) - resp = _make_litellm_response() - assert manager.detect_cache_hit(resp) is False - - def test_miss_with_no_hidden_params(self): - manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) - resp = SimpleNamespace(_hidden_params=None) - assert manager.detect_cache_hit(resp) is False - - def test_miss_with_no_hidden_params_attr(self): - manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) - resp = SimpleNamespace() - assert manager.detect_cache_hit(resp) is False - - # --------------------------------------------------------------------------- # 10. LitellmCacheConfig.from_cache_config # ---------------------------------------------------------------------------