refactor: ce-simplify-code 审查修复 — 去重 + 效率 + 死代码清理

3 个审查代理(复用/质量/效率)发现 15 个问题,全部修复:

效率与安全(6 项):
- MCPClient 缓存 MultiServerMCPClient 单例 + aclose(),修复连接/子进程泄漏
- _rate_limits 清理空 IP 条目,修复 X-Forwarded-For 欺骗下内存泄漏
- _seen_nonces 改用 OrderedDict,O(1) 摊销过期清理
- webhook 后台任务加 Semaphore(20) + 任务引用追踪,限制无界并发
- _build_adapter 用 asyncio.gather 并行解密 secrets
- 适配器实例缓存(_adapter_cache),token TTL 缓存跨请求命中

去重(4 项):
- header_get 提取到 channels/base.py,4 个适配器统一 import
- _get_client/close() 移入 MessageAdapter 基类,子类继承
- URLVerificationChallenge 统一到 base.py,feishu/slack 共用(wecom 仍使用自己的 WeComURLVerification)
- Transport ABC 添加 endpoint_url 属性,from_transport 不再访问 _endpoint 私有字段(stdio 分支仍读取 _command)

死代码与类型安全(5 项):
- detect_cache_hit 死方法替换为 record_cache_result 公开 API
- execution_mode.value == "direct_chat" 改用枚举比较
- 删除 yielded_any 死变量、重复 from fastapi import Request、
  多余 getattr 防御

453 tests passed, ruff clean(预存 F841 非本次引入)
This commit is contained in:
chiguyong 2026-06-25 23:54:14 +08:00
parent 793476cafa
commit 1ccaf56b9a
13 changed files with 307 additions and 348 deletions

View File

@ -11,6 +11,8 @@ from dataclasses import dataclass, field
from enum import Enum
from typing import Any
import httpx
class ChannelType(str, Enum):
"""支持的消息平台渠道。"""
@ -51,17 +53,42 @@ class OutgoingMessage:
reply_to_message_id: str | None = None
class URLVerificationChallenge(Exception):
    """Signal that an inbound webhook is a URL-verification handshake.

    Feishu and Slack send a one-off ``url_verification`` event when a webhook
    is configured and expect the server to echo the ``challenge`` value back
    unchanged to prove the URL is reachable. The webhook endpoint that catches
    this exception should respond with ``{"challenge": ...}``.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, challenge: str) -> None:
        """Store the challenge token and build the exception message.

        Args:
            challenge: The opaque token supplied by the platform, to be
                echoed back verbatim.
        """
        # Kept as a plain attribute so the handler can read it directly.
        self.challenge = challenge
        message = f"URL verification challenge: {challenge}"
        super().__init__(message)
class MessageAdapter(abc.ABC):
"""消息适配器 ABC。
生命周期
__init__ verify_signature() receive_message() send_message() close()
子类必须实现全部抽象方法verify_signature 失败时调用方应拒绝处理
webhook 端点 fail-closedRedis 不可用或签名校验失败均返回 503/401
不可跳过 nonce dedup 直接处理消息
子类必须实现 verify_signature / receive_message / send_message 抽象方法
``_get_client`` / ``close`` 已在基类提供懒构造 httpx 客户端 + 释放资源
子类只需在 ``__init__`` 中调用 ``super().__init__()``
verify_signature 失败时调用方应拒绝处理webhook 端点 fail-closed
Redis 不可用或签名校验失败均返回 503/401不可跳过 nonce dedup 直接处理消息
"""
def __init__(self) -> None:
# 懒加载 httpx 客户端 — 避免未使用的适配器持有连接池
self._client: httpx.AsyncClient | None = None
def _get_client(self) -> httpx.AsyncClient:
"""懒构造 httpx 客户端(子类共享此实现)。"""
if self._client is None:
self._client = httpx.AsyncClient(timeout=10.0)
return self._client
@abc.abstractmethod
async def verify_signature(self, headers: dict[str, str], body: bytes) -> bool:
"""验证平台签名/token。返回 True 表示请求可信。"""
@ -74,6 +101,25 @@ class MessageAdapter(abc.ABC):
async def send_message(self, message: OutgoingMessage) -> bool:
"""向平台发送消息。返回 True 表示发送成功。"""
@abc.abstractmethod
async def close(self) -> None:
"""释放资源HTTP 客户端、连接池等)。"""
"""关闭 httpx 客户端(如已创建)。子类可覆盖以释放额外资源。"""
if self._client is not None:
await self._client.aclose()
self._client = None
# ---------------------------------------------------------------------------
# 辅助函数
# ---------------------------------------------------------------------------
def header_get(headers: dict[str, str], name: str) -> str | None:
"""大小写不敏感的 header 查找。"""
# 直接命中
if name in headers:
return headers[name]
lower = name.lower()
for k, v in headers.items():
if k.lower() == lower:
return v
return None

View File

@ -28,6 +28,7 @@ from agentkit.channels.base import (
IncomingMessage,
MessageAdapter,
OutgoingMessage,
header_get,
)
logger = logging.getLogger(__name__)
@ -68,25 +69,15 @@ class DingTalkMessageAdapter(MessageAdapter):
robot_code: str,
token: str | None = None,
) -> None:
super().__init__()
self.app_key = app_key
self.app_secret = app_secret
self.robot_code = robot_code
self.token = token
# 懒加载 httpx 客户端
self._client: httpx.AsyncClient | None = None
# ponytail: 简单 TTL 缓存 (token, expiry)。天花板:单实例内存;
# 升级路径Redis 缓存共享给多实例。
self._token_cache: tuple[str, float] | None = None
# ------------------------------------------------------------------
# httpx 客户端懒加载
# ------------------------------------------------------------------
def _get_client(self) -> httpx.AsyncClient:
if self._client is None:
self._client = httpx.AsyncClient(timeout=10.0)
return self._client
# ------------------------------------------------------------------
# 签名验证
# ------------------------------------------------------------------
@ -107,12 +98,12 @@ class DingTalkMessageAdapter(MessageAdapter):
"""
# Token 校验(若配置)
if self.token is not None:
token_header = _header_get(headers, "Token")
token_header = header_get(headers, "Token")
if token_header is None or not hmac.compare_digest(token_header, self.token):
return False
sign = _header_get(headers, "Sign")
timestamp_str = _header_get(headers, "Timestamp")
sign = header_get(headers, "Sign")
timestamp_str = header_get(headers, "Timestamp")
if sign is None and timestamp_str is None:
# 无签名头:仅当 token 已校验通过才放行
@ -253,29 +244,3 @@ class DingTalkMessageAdapter(MessageAdapter):
except httpx.HTTPError as exc:
logger.error("钉钉 accessToken 网络错误: %s", exc)
return None
# ------------------------------------------------------------------
# 资源释放
# ------------------------------------------------------------------
async def close(self) -> None:
"""关闭 httpx 客户端(如已创建)。"""
if self._client is not None:
await self._client.aclose()
self._client = None
# ---------------------------------------------------------------------------
# 辅助函数
# ---------------------------------------------------------------------------
def _header_get(headers: dict[str, str], name: str) -> str | None:
"""大小写不敏感的 header 查找。"""
if name in headers:
return headers[name]
lower = name.lower()
for k, v in headers.items():
if k.lower() == lower:
return v
return None

View File

@ -28,6 +28,8 @@ from agentkit.channels.base import (
IncomingMessage,
MessageAdapter,
OutgoingMessage,
URLVerificationChallenge,
header_get,
)
logger = logging.getLogger(__name__)
@ -45,18 +47,6 @@ _SEND_MESSAGE_URL = "https://open.feishu.cn/open-apis/im/v1/messages"
_MENTION_RE = re.compile(r"@_user_\d+\s*")
class URLVerificationChallenge(Exception):
"""飞书 URL 验证事件 — webhook 端点需返回 ``{"challenge": ...}`` 响应。
飞书在配置 webhook 时会发送一次 ``url_verification`` 事件要求服务端
原样返回 ``challenge`` 字段以验证 URL 可达
"""
def __init__(self, challenge: str) -> None:
super().__init__(f"URL verification challenge: {challenge}")
self.challenge = challenge
class SignatureVerificationError(Exception):
"""事件 ``verification_token`` 校验失败 — 拒绝处理。"""
@ -82,25 +72,15 @@ class FeishuMessageAdapter(MessageAdapter):
encrypt_key: str | None = None,
verification_token: str | None = None,
) -> None:
super().__init__()
self.app_id = app_id
self.app_secret = app_secret
self.encrypt_key = encrypt_key
self.verification_token = verification_token
# 懒加载 httpx 客户端 — 避免未使用的适配器持有连接池
self._client: httpx.AsyncClient | None = None
# ponytail: 简单 TTL 缓存 (token, expiry)。天花板:单实例内存;
# 升级路径Redis 缓存共享给多实例。
self._token_cache: tuple[str, float] | None = None
# ------------------------------------------------------------------
# httpx 客户端懒加载
# ------------------------------------------------------------------
def _get_client(self) -> httpx.AsyncClient:
if self._client is None:
self._client = httpx.AsyncClient(timeout=10.0)
return self._client
# ------------------------------------------------------------------
# 签名验证
# ------------------------------------------------------------------
@ -122,12 +102,12 @@ class FeishuMessageAdapter(MessageAdapter):
logger.warning("飞书适配器未配置 encrypt_key — 拒绝所有 webhook 请求")
return False
signature = _header_get(headers, "X-Lark-Signature")
signature = header_get(headers, "X-Lark-Signature")
if not signature:
return False
timestamp_str = _header_get(headers, "X-Lark-Request-Timestamp")
nonce = _header_get(headers, "X-Lark-Request-Nonce")
timestamp_str = header_get(headers, "X-Lark-Request-Timestamp")
nonce = header_get(headers, "X-Lark-Request-Nonce")
if not timestamp_str or not nonce:
return False
@ -335,30 +315,3 @@ class FeishuMessageAdapter(MessageAdapter):
except httpx.HTTPError as exc:
logger.error("飞书 tenant_token 网络错误: %s", exc)
return None
# ------------------------------------------------------------------
# 资源释放
# ------------------------------------------------------------------
async def close(self) -> None:
"""关闭 httpx 客户端(如已创建)。"""
if self._client is not None:
await self._client.aclose()
self._client = None
# ---------------------------------------------------------------------------
# 辅助函数
# ---------------------------------------------------------------------------
def _header_get(headers: dict[str, str], name: str) -> str | None:
"""大小写不敏感的 header 查找。"""
# 直接命中
if name in headers:
return headers[name]
lower = name.lower()
for k, v in headers.items():
if k.lower() == lower:
return v
return None

View File

@ -30,6 +30,8 @@ from agentkit.channels.base import (
IncomingMessage,
MessageAdapter,
OutgoingMessage,
URLVerificationChallenge,
header_get,
)
logger = logging.getLogger(__name__)
@ -44,18 +46,6 @@ _SEND_MESSAGE_URL = "https://slack.com/api/chat.postMessage"
_MENTION_RE = re.compile(r"<@[^>]+>\s*")
class URLVerificationChallenge(Exception):
"""Slack URL 验证事件 — webhook 端点需返回 ``{"challenge": ...}`` 响应。
Slack 在配置 Events API 时发送 ``url_verification`` 事件要求服务端
原样返回 ``challenge`` 字段以验证 URL 可达
"""
def __init__(self, challenge: str) -> None:
super().__init__(f"Slack URL verification challenge: {challenge}")
self.challenge = challenge
class SignatureVerificationError(Exception):
"""``verification_token`` 校验失败 — 拒绝处理。"""
@ -79,20 +69,10 @@ class SlackMessageAdapter(MessageAdapter):
signing_secret: str,
verification_token: str | None = None,
) -> None:
super().__init__()
self.bot_token = bot_token
self.signing_secret = signing_secret
self.verification_token = verification_token
# 懒加载 httpx 客户端
self._client: httpx.AsyncClient | None = None
# ------------------------------------------------------------------
# httpx 客户端懒加载
# ------------------------------------------------------------------
def _get_client(self) -> httpx.AsyncClient:
if self._client is None:
self._client = httpx.AsyncClient(timeout=10.0)
return self._client
# ------------------------------------------------------------------
# 签名验证
@ -111,8 +91,8 @@ class SlackMessageAdapter(MessageAdapter):
Returns:
True 表示签名校验通过
"""
signature = _header_get(headers, "X-Slack-Signature")
timestamp_str = _header_get(headers, "X-Slack-Request-Timestamp")
signature = header_get(headers, "X-Slack-Signature")
timestamp_str = header_get(headers, "X-Slack-Request-Timestamp")
if not signature or not timestamp_str:
return False
@ -259,29 +239,3 @@ class SlackMessageAdapter(MessageAdapter):
except httpx.HTTPError as exc:
logger.error("Slack send_message 网络错误: %s", exc)
return False
# ------------------------------------------------------------------
# 资源释放
# ------------------------------------------------------------------
async def close(self) -> None:
"""关闭 httpx 客户端(如已创建)。"""
if self._client is not None:
await self._client.aclose()
self._client = None
# ---------------------------------------------------------------------------
# 辅助函数
# ---------------------------------------------------------------------------
def _header_get(headers: dict[str, str], name: str) -> str | None:
"""大小写不敏感的 header 查找。"""
if name in headers:
return headers[name]
lower = name.lower()
for k, v in headers.items():
if k.lower() == lower:
return v
return None

View File

@ -29,6 +29,7 @@ from agentkit.channels.base import (
IncomingMessage,
MessageAdapter,
OutgoingMessage,
header_get,
)
logger = logging.getLogger(__name__)
@ -80,25 +81,15 @@ class WeComMessageAdapter(MessageAdapter):
token: str,
encoding_aes_key: str,
) -> None:
super().__init__()
self.corp_id = corp_id
self.agent_id = agent_id
self.corp_secret = corp_secret
self.token = token
self.encoding_aes_key = encoding_aes_key
# 懒加载 httpx 客户端
self._client: httpx.AsyncClient | None = None
# ponytail: 简单 TTL 缓存。天花板单实例内存升级路径Redis 共享。
self._token_cache: tuple[str, float] | None = None
# ------------------------------------------------------------------
# httpx 客户端懒加载
# ------------------------------------------------------------------
def _get_client(self) -> httpx.AsyncClient:
if self._client is None:
self._client = httpx.AsyncClient(timeout=10.0)
return self._client
# ------------------------------------------------------------------
# AES 密钥 / 加解密
# ------------------------------------------------------------------
@ -209,9 +200,9 @@ class WeComMessageAdapter(MessageAdapter):
Returns:
True 表示签名校验通过
"""
msg_signature = _header_get(headers, "msg_signature")
timestamp = _header_get(headers, "timestamp")
nonce = _header_get(headers, "nonce")
msg_signature = header_get(headers, "msg_signature")
timestamp = header_get(headers, "timestamp")
nonce = header_get(headers, "nonce")
if not msg_signature or not timestamp or not nonce:
return False
@ -247,8 +238,8 @@ class WeComMessageAdapter(MessageAdapter):
# URL 验证流程 — 内部 XML 包含 EchoStr
if "EchoStr" in inner:
timestamp = _header_get(headers, "timestamp") or ""
nonce = _header_get(headers, "nonce") or ""
timestamp = header_get(headers, "timestamp") or ""
nonce = header_get(headers, "nonce") or ""
encrypted_echo = self._encrypt(inner["EchoStr"])
# 计算响应签名
sig_parts = sorted([self.token, timestamp, nonce, encrypted_echo])
@ -362,29 +353,3 @@ class WeComMessageAdapter(MessageAdapter):
except httpx.HTTPError as exc:
logger.error("企微 access_token 网络错误: %s", exc)
return None
# ------------------------------------------------------------------
# 资源释放
# ------------------------------------------------------------------
async def close(self) -> None:
"""关闭 httpx 客户端(如已创建)。"""
if self._client is not None:
await self._client.aclose()
self._client = None
# ---------------------------------------------------------------------------
# 辅助函数
# ---------------------------------------------------------------------------
def _header_get(headers: dict[str, str], name: str) -> str | None:
"""大小写不敏感的 header 查找。"""
if name in headers:
return headers[name]
lower = name.lower()
for k, v in headers.items():
if k.lower() == lower:
return v
return None

View File

@ -803,18 +803,16 @@ class LitellmCacheManager:
"""返回 litellm acompletion 的 cache 参数(禁用缓存)。"""
return {"no-cache": True}
def detect_cache_hit(self, response: Any) -> bool:
"""检测 LiteLLM 响应是否为缓存命中
def record_cache_result(self, is_hit: bool) -> None:
"""记录单次 LLM 调用的缓存命中/未命中(用于 usage tracking 统计)
LiteLLM 在缓存命中时设置 ``response._hidden_params["cache_key"]``
命中判定由调用方完成gateway 通过 ``response.cache_hit`` 判定
本方法只负责更新计数器避免重复检测逻辑
"""
hidden = getattr(response, "_hidden_params", None)
if isinstance(hidden, dict):
if "cache_key" in hidden or hidden.get("cache_hit"):
self._hits += 1
return True
self._misses += 1
return False
if is_hit:
self._hits += 1
else:
self._misses += 1
def stats(self) -> dict[str, int]:
"""返回缓存统计。"""

View File

@ -224,10 +224,8 @@ class LLMGateway:
# U17 — 检测 LiteLLM 缓存命中(用于 usage tracking cost=0
is_cache_hit = getattr(response, "cache_hit", False)
if is_cache_hit and self._cache_manager is not None:
self._cache_manager._hits += 1
elif self._cache_manager is not None:
self._cache_manager._misses += 1
if self._cache_manager is not None:
self._cache_manager.record_cache_result(is_cache_hit)
# 计算成本(缓存命中时 cost=0
cost = 0.0 if is_cache_hit else self._calculate_cost(response.model, response.usage)

View File

@ -119,7 +119,6 @@ class LitellmProvider(LLMProvider):
accumulated_tool_calls: dict[int, dict[str, Any]] = {}
final_usage: TokenUsage | None = None
final_model: str = request.model
yielded_any = False
try:
# litellm.acompletion(stream=True) 的返回类型取决于版本 / 调用方式:
@ -129,7 +128,6 @@ class LitellmProvider(LLMProvider):
raw = litellm.acompletion(**kwargs)
stream = await raw if inspect.isawaitable(raw) else raw
async for chunk in stream:
yielded_any = True
parsed = self._parse_stream_chunk(
chunk,
request.model,
@ -156,10 +154,6 @@ class LitellmProvider(LLMProvider):
except Exception as e:
raise LLMProviderError(self._provider_type, str(e)) from e
# ponytail: 若流完全为空yielded_any=False上面仍会 yield 一个
# is_final=True 的空 chunk调用方据此判断空响应。无需额外分支。
_ = yielded_any # 标记保留(调试 / 未来扩展)
# ------------------------------------------------------------------
# 内部辅助
# ------------------------------------------------------------------
@ -184,7 +178,7 @@ class LitellmProvider(LLMProvider):
if request.timeout is not None:
kwargs["timeout"] = request.timeout
# U17 — 透传 LiteLLM cache 参数cache_key 或 no-cache到 litellm.acompletion
cache_params = getattr(request, "_cache", None)
cache_params = request._cache
if cache_params is not None:
kwargs["cache"] = cache_params
# 合并构造时传入的默认 kwargs如 max_connections 等provider特定参数

View File

@ -63,6 +63,9 @@ class MCPClient:
self._timeout = timeout
self._tools_cache: list[dict] | None = None
self._transport = transport
# U10 — 懒构造并缓存的 langchain client避免每次 list_tools/call_tool
# 都新建 MultiServerMCPClientstdio 传输下会反复 spawn 子进程)。
self._lc_client: Any = None
if transport is not None:
# 旧 Transport 路径 — 发出 DeprecationWarning但保持原有行为
@ -138,15 +141,34 @@ class MCPClient:
发出 DeprecationWarning建议迁移到 URL scheme 自动检测
"""
if isinstance(transport, HTTPTransport):
server_url = transport._endpoint
server_url = transport.endpoint_url
elif isinstance(transport, SSETransport):
server_url = transport._endpoint
server_url = transport.endpoint_url
elif isinstance(transport, StdioTransport):
server_url = f"stdio://{transport._command}"
else:
server_url = ""
return cls(server_url=server_url, transport=transport)
async def _get_lc_client(self) -> Any:
"""懒构造并缓存 langchain ``MultiServerMCPClient`` 实例。
首次调用时创建后续返回缓存避免每次 list_tools/call_tool 都新建
clientstdio 传输下会 spawn 新子进程造成连接/进程泄漏
"""
if self._lc_client is None:
client_cls = _import_langchain_client()
self._lc_client = client_cls({"server": self._langchain_config})
return self._lc_client
async def aclose(self) -> None:
    """Close and forget the cached langchain client, if one exists.

    The cached reference is detached *before* the client's own ``aclose`` is
    awaited. If that call raises, the exception still propagates to the
    caller, but the half-closed client is no longer cached, so the next
    ``_get_lc_client()`` builds a fresh one instead of reusing a broken
    instance. Detaching first also makes a second, overlapping ``aclose()``
    call a no-op rather than a double close.

    Clients that do not expose an ``aclose`` attribute are simply dropped.
    """
    client = self._lc_client
    if client is None:
        return

    # Drop the reference up front so a failing close cannot leave it cached.
    self._lc_client = None

    closer = getattr(client, "aclose", None)
    if closer is not None:
        await closer()
async def list_tools(self) -> list[dict]:
"""列出远程 MCP Server 上的工具
@ -165,9 +187,8 @@ class MCPClient:
self._tools_cache = tools
return self._tools_cache
# 新 langchain 路径
client_cls = _import_langchain_client()
client = client_cls({"server": self._langchain_config})
# 新 langchain 路径 — 复用缓存的 client
client = await self._get_lc_client()
lc_tools = await client.get_tools()
tools = [
{
@ -218,9 +239,8 @@ class MCPClient:
params={"name": tool_name, "arguments": arguments},
)
# 新 langchain 路径
client_cls = _import_langchain_client()
client = client_cls({"server": self._langchain_config})
# 新 langchain 路径 — 复用缓存的 client
client = await self._get_lc_client()
lc_tools = await client.get_tools()
for tool in lc_tools:
if tool.name == tool_name:

View File

@ -230,7 +230,6 @@ class MCPServer:
try:
from fastapi import FastAPI
from fastapi import Request # noqa: F401 — 用于 jsonrpc_endpoint 的类型注解
except ImportError:
raise ImportError("MCP Server requires fastapi: pip install fischer-agentkit[mcp]")

View File

@ -98,6 +98,11 @@ class HTTPTransport(Transport):
def is_connected(self) -> bool:
return self._client is not None and not self._client.is_closed
@property
def endpoint_url(self) -> str:
    """Expose the configured endpoint URL as a read-only attribute.

    Lets callers read the URL without reaching into the private
    ``_endpoint`` field.

    NOTE(review): the original docstring says the stored value has its
    trailing slash removed; that normalisation is not visible in this
    block, so confirm it where ``_endpoint`` is assigned.
    """
    url = self._endpoint
    return url
async def connect(self) -> None:
"""建立 HTTP 连接"""
if self.is_connected:
@ -220,6 +225,11 @@ class SSETransport(Transport):
def is_connected(self) -> bool:
return self._connected and self._client is not None and not self._client.is_closed
@property
def endpoint_url(self) -> str:
    """Expose the configured SSE endpoint URL as a read-only attribute.

    Lets callers read the URL without reaching into the private
    ``_endpoint`` field.

    NOTE(review): the original docstring says the stored value has its
    trailing slash removed; that normalisation is not visible in this
    block, so confirm it where ``_endpoint`` is assigned.
    """
    url = self._endpoint
    return url
async def connect(self) -> None:
"""建立 SSE 连接

View File

@ -21,21 +21,25 @@ import asyncio
import logging
import re
import time
from collections import OrderedDict
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import Response
from pydantic import BaseModel, Field
from agentkit.channels.base import ChannelType, MessageAdapter, OutgoingMessage
from agentkit.channels.dingtalk import DingTalkMessageAdapter
from agentkit.channels.feishu import FeishuMessageAdapter, URLVerificationChallenge
from agentkit.channels.secrets import SecretsStore
from agentkit.channels.slack import (
SlackMessageAdapter,
URLVerificationChallenge as SlackURLVerificationChallenge,
from agentkit.channels.base import (
ChannelType,
MessageAdapter,
OutgoingMessage,
URLVerificationChallenge,
)
from agentkit.channels.dingtalk import DingTalkMessageAdapter
from agentkit.channels.feishu import FeishuMessageAdapter
from agentkit.channels.secrets import SecretsStore
from agentkit.channels.slack import SlackMessageAdapter
from agentkit.channels.wecom import WeComMessageAdapter, WeComURLVerification
from agentkit.chat.skill_routing import ExecutionMode
from agentkit.server.auth.dependencies import require_permission
from agentkit.server.auth.permissions import Permission
@ -89,16 +93,40 @@ _RATE_LIMIT_WINDOW = 60.0 # 窗口大小(秒)
_RATE_LIMIT_MAX = 100 # 窗口内最大请求数
# nonce -> 过期时间戳(与飞书签名时间戳窗口一致)
_seen_nonces: dict[str, float] = {}
# OrderedDict按插入顺序遍历清理过期项时从头部弹出O(1) 摊销。
_seen_nonces: OrderedDict[str, float] = OrderedDict()
_NONCE_TTL = 300.0
# Webhook 后台处理并发上限 — 防止高流量下 LLM 调用无界并发
_WEBHOOK_MAX_CONCURRENT = 20
_webhook_semaphore: asyncio.Semaphore | None = None
# 持有后台任务引用,防止 GC 回收正在运行的 task
_pending_webhook_tasks: set[asyncio.Task[None]] = set()
# 适配器缓存 — channel_id -> adapter。避免 per-request 重建导致 token TTL 缓存失效。
# 配置变更PUT/DELETE channel时通过 _invalidate_adapter_cache 清除对应条目。
_adapter_cache: dict[str, MessageAdapter] = {}
def _get_webhook_semaphore() -> asyncio.Semaphore:
    """Return the shared webhook concurrency semaphore, creating it lazily.

    The semaphore is stored in the module-level ``_webhook_semaphore`` and is
    built on the first call with ``_WEBHOOK_MAX_CONCURRENT`` permits; every
    later call returns that same object.
    """
    global _webhook_semaphore
    semaphore = _webhook_semaphore
    if semaphore is None:
        # Deferred construction: nothing is allocated until first needed.
        semaphore = asyncio.Semaphore(_WEBHOOK_MAX_CONCURRENT)
        _webhook_semaphore = semaphore
    return semaphore
def _check_rate_limit(client_ip: str) -> bool:
"""滑动窗口限流。返回 True 表示放行False 表示超限。"""
"""滑动窗口限流。返回 True 表示放行False 表示超限。
ponytail: 过滤后 timestamps 为空时清除 IP 条目避免 _rate_limits 无界增长
"""
now = time.monotonic()
timestamps = _rate_limits.get(client_ip, [])
cutoff = now - _RATE_LIMIT_WINDOW
timestamps = [t for t in timestamps if t > cutoff]
timestamps = [t for t in _rate_limits.get(client_ip, []) if t > cutoff]
if not timestamps:
# 清理过期空条目 — 不活跃 IP 不再占用 dict 槽位
_rate_limits.pop(client_ip, None)
if len(timestamps) >= _RATE_LIMIT_MAX:
_rate_limits[client_ip] = timestamps
return False
@ -108,12 +136,20 @@ def _check_rate_limit(client_ip: str) -> bool:
def _check_nonce_dedup(nonce: str) -> bool:
"""Nonce 去重。返回 True 表示新 nonce应处理False 表示重复(跳过)。"""
"""Nonce 去重。返回 True 表示新 nonce应处理False 表示重复(跳过)。
使用 OrderedDict 按插入顺序清理过期项从头部弹出已过期条目
O(1) 摊销旧实现遍历整个字典为 O(N)fail-closed 语义过期 nonce 拒绝
"""
now = time.monotonic()
# 惰性清理过期项
expired = [k for k, v in _seen_nonces.items() if v < now]
for k in expired:
del _seen_nonces[k]
# 从头部弹出过期项nonce 按单调时间插入,头部最旧即最先过期)。
# ponytail: 摊销 O(1) 清理。旧实现遍历整个 dict 为 O(N)。
while _seen_nonces:
_oldest_key, oldest_expiry = next(iter(_seen_nonces.items()))
if oldest_expiry >= now:
break
_seen_nonces.popitem(last=False)
# 过期 nonce 已被清理 — 视为新 nonce与旧实现先删全部过期项再查一致
if nonce in _seen_nonces:
return False
_seen_nonces[nonce] = now + _NONCE_TTL
@ -124,6 +160,8 @@ def _reset_webhook_state() -> None:
"""重置限流与 nonce 状态(仅供测试使用)。"""
_rate_limits.clear()
_seen_nonces.clear()
_adapter_cache.clear()
_pending_webhook_tasks.clear()
def _validate_channel_id(channel_id: str) -> str:
@ -288,6 +326,9 @@ async def update_channel(
if payload.config is not None:
cfg["config"] = payload.config
# 配置变更 — 清除缓存适配器,下次 webhook 重建(凭证可能已更新)
await _invalidate_adapter_cache(channel_id)
return ChannelInfo(
channel_id=channel_id,
channel_type=ChannelType(cfg["channel_type"]),
@ -311,6 +352,9 @@ async def delete_channel(
for name in cfg.get("secret_keys", []):
await store.delete_secret(f"{channel_id}:{name}")
# 清除缓存适配器并关闭旧实例
await _invalidate_adapter_cache(channel_id)
return {"deleted": channel_id}
@ -320,13 +364,20 @@ async def delete_channel(
async def _build_adapter(channel_id: str) -> MessageAdapter:
"""根据渠道配置与 secrets 构造适配器实例
"""根据渠道配置与 secrets 构造适配器实例(带缓存)
支持飞书 / 钉钉 / 企微 / Slack 四种渠道类型 ``channel_type`` 分发
首次构造后缓存到 ``_adapter_cache``后续请求复用同一实例token TTL 缓存命中
secrets 获取使用 ``asyncio.gather`` 并行避免串行 await
Raises:
HTTPException: 渠道不存在404渠道类型不支持400缺少必要凭证500
"""
# 命中缓存直接返回per-request 重建会使 token TTL 缓存失效)
cached = _adapter_cache.get(channel_id)
if cached is not None:
return cached
cfg = _channels.get(channel_id)
if cfg is None:
raise HTTPException(status_code=404, detail=f"渠道 '{channel_id}' 不存在")
@ -335,43 +386,53 @@ async def _build_adapter(channel_id: str) -> MessageAdapter:
channel_type = cfg["channel_type"]
if channel_type == ChannelType.FEISHU.value:
app_id = await store.get_secret(f"{channel_id}:app_id")
app_secret = await store.get_secret(f"{channel_id}:app_secret")
encrypt_key = await store.get_secret(f"{channel_id}:encrypt_key")
verification_token = await store.get_secret(f"{channel_id}:verification_token")
app_id, app_secret, encrypt_key, verification_token = await asyncio.gather(
store.get_secret(f"{channel_id}:app_id"),
store.get_secret(f"{channel_id}:app_secret"),
store.get_secret(f"{channel_id}:encrypt_key"),
store.get_secret(f"{channel_id}:verification_token"),
)
if not app_id or not app_secret:
raise HTTPException(
status_code=500, detail=f"渠道 '{channel_id}' 缺少 app_id 或 app_secret"
)
return FeishuMessageAdapter(
adapter = FeishuMessageAdapter(
app_id=app_id,
app_secret=app_secret,
encrypt_key=encrypt_key,
verification_token=verification_token,
)
_adapter_cache[channel_id] = adapter
return adapter
if channel_type == ChannelType.DINGTALK.value:
app_key = await store.get_secret(f"{channel_id}:app_key")
app_secret = await store.get_secret(f"{channel_id}:app_secret")
robot_code = await store.get_secret(f"{channel_id}:robot_code")
token = await store.get_secret(f"{channel_id}:token")
app_key, app_secret, robot_code, token = await asyncio.gather(
store.get_secret(f"{channel_id}:app_key"),
store.get_secret(f"{channel_id}:app_secret"),
store.get_secret(f"{channel_id}:robot_code"),
store.get_secret(f"{channel_id}:token"),
)
if not all([app_key, app_secret, robot_code]):
raise HTTPException(
status_code=500, detail=f"渠道 '{channel_id}' 缺少 dingtalk 凭证"
)
return DingTalkMessageAdapter(
adapter = DingTalkMessageAdapter(
app_key=app_key,
app_secret=app_secret,
robot_code=robot_code,
token=token,
)
_adapter_cache[channel_id] = adapter
return adapter
if channel_type == ChannelType.WECOM.value:
corp_id = await store.get_secret(f"{channel_id}:corp_id")
corp_secret = await store.get_secret(f"{channel_id}:corp_secret")
token = await store.get_secret(f"{channel_id}:token")
encoding_aes_key = await store.get_secret(f"{channel_id}:encoding_aes_key")
agent_id_raw = await store.get_secret(f"{channel_id}:agent_id")
corp_id, corp_secret, token, encoding_aes_key, agent_id_raw = await asyncio.gather(
store.get_secret(f"{channel_id}:corp_id"),
store.get_secret(f"{channel_id}:corp_secret"),
store.get_secret(f"{channel_id}:token"),
store.get_secret(f"{channel_id}:encoding_aes_key"),
store.get_secret(f"{channel_id}:agent_id"),
)
if not all([corp_id, corp_secret, token, encoding_aes_key, agent_id_raw]):
raise HTTPException(
status_code=500, detail=f"渠道 '{channel_id}' 缺少 wecom 凭证"
@ -383,98 +444,124 @@ async def _build_adapter(channel_id: str) -> MessageAdapter:
status_code=500,
detail=f"渠道 '{channel_id}' agent_id 不是合法整数",
) from exc
return WeComMessageAdapter(
adapter = WeComMessageAdapter(
corp_id=corp_id,
agent_id=agent_id,
corp_secret=corp_secret,
token=token,
encoding_aes_key=encoding_aes_key,
)
_adapter_cache[channel_id] = adapter
return adapter
if channel_type == ChannelType.SLACK.value:
bot_token = await store.get_secret(f"{channel_id}:bot_token")
signing_secret = await store.get_secret(f"{channel_id}:signing_secret")
verification_token = await store.get_secret(f"{channel_id}:verification_token")
bot_token, signing_secret, verification_token = await asyncio.gather(
store.get_secret(f"{channel_id}:bot_token"),
store.get_secret(f"{channel_id}:signing_secret"),
store.get_secret(f"{channel_id}:verification_token"),
)
if not bot_token or not signing_secret:
raise HTTPException(
status_code=500, detail=f"渠道 '{channel_id}' 缺少 slack 凭证"
)
return SlackMessageAdapter(
adapter = SlackMessageAdapter(
bot_token=bot_token,
signing_secret=signing_secret,
verification_token=verification_token,
)
_adapter_cache[channel_id] = adapter
return adapter
raise HTTPException(status_code=400, detail=f"不支持的渠道类型: {channel_type}")
async def _invalidate_adapter_cache(channel_id: str) -> None:
    """Evict the cached adapter for ``channel_id`` and close it.

    Called when a channel's configuration is updated or the channel is
    deleted, so the next webhook rebuilds the adapter from current settings.
    The entry is removed from ``_adapter_cache`` before ``close()`` is
    awaited, so no new request can pick up the adapter while it is closing.

    A failure while closing is deliberately not propagated (it must not
    block the configuration change), but it is logged together with the
    traceback so a leaked connection pool can be diagnosed.

    Args:
        channel_id: Identifier of the channel whose adapter should be dropped.
    """
    stale = _adapter_cache.pop(channel_id, None)
    if stale is None:
        return
    try:
        await stale.close()
    except Exception:  # noqa: BLE001 - a failed close must not block the config change
        logger.debug(
            "关闭旧适配器异常已忽略: channel_id=%s", channel_id, exc_info=True
        )
async def close_all_adapters() -> None:
    """Close every cached adapter; intended for application shutdown.

    The cache is snapshotted and emptied *before* any ``close()`` is awaited.
    While this coroutine is suspended inside a ``close()`` call, other tasks
    may still run; emptying the cache first means they can no longer be
    handed an adapter that is closed or mid-close, and an adapter cached
    during shutdown is not silently discarded by a trailing ``clear()``.

    Close failures are ignored so one bad adapter cannot prevent the others
    from being closed; each failure is logged with its traceback.
    """
    adapters = list(_adapter_cache.items())
    _adapter_cache.clear()

    for channel_id, adapter in adapters:
        try:
            await adapter.close()
        except Exception:  # noqa: BLE001 - best-effort shutdown, keep closing the rest
            logger.debug(
                "关闭适配器异常已忽略: channel_id=%s", channel_id, exc_info=True
            )
async def _process_inbound_message(
app_state: Any, adapter: MessageAdapter, message: Any
) -> None:
"""后台处理入站消息 — 调用 chat 链路并通过适配器回复。
整个流程 try/except 包裹任何异常仅记录日志不向上抛出
webhook 必须保持响应能力``adapter.close()`` finally 中调用
webhook 必须保持响应能力处理逻辑受全局信号量限流
``_WEBHOOK_MAX_CONCURRENT``防止高流量下 LLM 调用无界并发
适配器由 ``_adapter_cache`` 管理不在 per-request 关闭关闭会清空 token 缓存
适配器类型不限 出站消息的 ``channel`` 取自入站消息以匹配平台
"""
try:
request_preprocessor = getattr(app_state, "request_preprocessor", None)
llm_gateway = getattr(app_state, "llm_gateway", None)
if request_preprocessor is None or llm_gateway is None:
logger.warning("app.state 缺少 request_preprocessor 或 llm_gateway — 跳过消息处理")
return
# 路由预处理 — IM 场景使用默认 agent无需技能注册表
routing = await request_preprocessor.preprocess(
content=message.content, default_agent_name="default"
)
final_content = ""
execution_mode = getattr(routing, "execution_mode", None)
# DIRECT_CHAT 模式 — 直接调用 LLM
if execution_mode is not None and execution_mode.value == "direct_chat":
response = await llm_gateway.chat(
messages=[{"role": "user", "content": message.content}],
model=routing.model or "default",
)
final_content = response.content
else:
# REACT 或其他模式 — 优先使用 ReActEngine失败回退到 DIRECT_CHAT
try:
from agentkit.core.react import ReActEngine
engine = ReActEngine(llm_gateway=llm_gateway)
result = await engine.execute(
messages=[{"role": "user", "content": message.content}],
tools=getattr(routing, "tools", None) or None,
model=routing.model or "default",
async with _get_webhook_semaphore():
try:
request_preprocessor = getattr(app_state, "request_preprocessor", None)
llm_gateway = getattr(app_state, "llm_gateway", None)
if request_preprocessor is None or llm_gateway is None:
logger.warning(
"app.state 缺少 request_preprocessor 或 llm_gateway — 跳过消息处理"
)
final_content = getattr(result, "content", "") or ""
except Exception as exc: # noqa: BLE001 — 回退路径需捕获全部异常
logger.warning("ReActEngine 执行失败,回退到 DIRECT_CHAT: %s", exc)
return
# 路由预处理 — IM 场景使用默认 agent无需技能注册表
routing = await request_preprocessor.preprocess(
content=message.content, default_agent_name="default"
)
final_content = ""
execution_mode = getattr(routing, "execution_mode", None)
# DIRECT_CHAT 模式 — 直接调用 LLM
if execution_mode == ExecutionMode.DIRECT_CHAT:
response = await llm_gateway.chat(
messages=[{"role": "user", "content": message.content}],
model=routing.model or "default",
)
final_content = response.content
else:
# REACT 或其他模式 — 优先使用 ReActEngine失败回退到 DIRECT_CHAT
try:
from agentkit.core.react import ReActEngine
if not final_content:
logger.warning("消息处理未产生内容 — 不发送回复")
return
engine = ReActEngine(llm_gateway=llm_gateway)
result = await engine.execute(
messages=[{"role": "user", "content": message.content}],
tools=getattr(routing, "tools", None) or None,
model=routing.model or "default",
)
final_content = getattr(result, "content", "") or ""
except Exception as exc: # noqa: BLE001 — 回退路径需捕获全部异常
logger.warning("ReActEngine 执行失败,回退到 DIRECT_CHAT: %s", exc)
response = await llm_gateway.chat(
messages=[{"role": "user", "content": message.content}],
model=routing.model or "default",
)
final_content = response.content
outgoing = OutgoingMessage(
channel=message.channel,
chat_id=message.chat_id,
content=final_content,
)
await adapter.send_message(outgoing)
except Exception as exc: # noqa: BLE001 — webhook 必须保持响应能力
logger.exception("处理入站消息失败: %s", exc)
finally:
try:
await adapter.close()
except Exception: # noqa: BLE001
logger.debug("adapter.close() 异常已忽略")
if not final_content:
logger.warning("消息处理未产生内容 — 不发送回复")
return
outgoing = OutgoingMessage(
channel=message.channel,
chat_id=message.chat_id,
content=final_content,
)
await adapter.send_message(outgoing)
except Exception as exc: # noqa: BLE001 — webhook 必须保持响应能力
logger.exception("处理入站消息失败: %s", exc)
def _get_client_ip(request: Request) -> str:
@ -527,7 +614,7 @@ async def channel_webhook(channel_id: str, request: Request) -> Any:
try:
message = await adapter.receive_message(headers_dict, body)
except (URLVerificationChallenge, SlackURLVerificationChallenge) as e:
except URLVerificationChallenge as e:
# URL 验证流程 — 飞书 / Slack 配置 webhook 时发送
return {"challenge": e.challenge}
except WeComURLVerification as e:
@ -535,6 +622,9 @@ async def channel_webhook(channel_id: str, request: Request) -> Any:
return Response(content=e.response_xml, media_type="application/xml")
# 异步处理 — 不阻塞 webhook 响应(平台要求快速返回 200
asyncio.create_task(_process_inbound_message(request.app.state, adapter, message))
# 持有 task 引用防止 GC 回收正在运行的后台任务
task = asyncio.create_task(_process_inbound_message(request.app.state, adapter, message))
_pending_webhook_tasks.add(task)
task.add_done_callback(_pending_webhook_tasks.discard)
return {"code": 0}

View File

@ -9,7 +9,7 @@
6. kb_acl_hash 隔离 不同 ACL hash 产生不同 key
7. kb_caching_disabled 禁用缓存安全要求 c
8. cache_params_for_hit / no_cache 返回正确 dict
9. detect_cache_hit _hidden_params cache_key 时返回 True
9. record_cache_result 记录命中/未命中到 stats 计数器
10. LitellmCacheConfig.from_cache_config 转换正确similarity_threshold=0.87
11. LitellmCacheManager.enable/disable litellm.cache 正确设置/清除
12. generate_cache_key 向后兼容 user_id=None, kb_acl_hash=None 时与旧版相同
@ -163,10 +163,10 @@ class TestCacheStats:
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
# 2 hits
manager.detect_cache_hit(_make_litellm_response(cache_key="k1"))
manager.detect_cache_hit(_make_litellm_response(cache_key="k2"))
manager.record_cache_result(True)
manager.record_cache_result(True)
# 1 miss
manager.detect_cache_hit(_make_litellm_response()) # 无 cache_key
manager.record_cache_result(False)
stats = manager.stats()
assert stats["total_hits"] == 2
@ -302,39 +302,6 @@ class TestCacheParams:
assert params == {"no-cache": True}
# ---------------------------------------------------------------------------
# 9. detect_cache_hit
# ---------------------------------------------------------------------------
class TestDetectCacheHit:
def test_hit_with_cache_key(self):
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
resp = _make_litellm_response(cache_key="some_key")
assert manager.detect_cache_hit(resp) is True
def test_hit_with_cache_hit_flag(self):
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
resp = _make_litellm_response()
resp._hidden_params = {"cache_hit": True}
assert manager.detect_cache_hit(resp) is True
def test_miss_without_cache_key(self):
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
resp = _make_litellm_response()
assert manager.detect_cache_hit(resp) is False
def test_miss_with_no_hidden_params(self):
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
resp = SimpleNamespace(_hidden_params=None)
assert manager.detect_cache_hit(resp) is False
def test_miss_with_no_hidden_params_attr(self):
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
resp = SimpleNamespace()
assert manager.detect_cache_hit(resp) is False
# ---------------------------------------------------------------------------
# 10. LitellmCacheConfig.from_cache_config
# ---------------------------------------------------------------------------