refactor: ce-simplify-code 审查修复 — 去重 + 效率 + 死代码清理
3 个审查代理(复用/质量/效率)发现 15 个问题,全部修复: 效率与安全(6 项): - MCPClient 缓存 MultiServerMCPClient 单例 + aclose(),修复连接/子进程泄漏 - _rate_limits 清理空 IP 条目,修复 X-Forwarded-For 欺骗下内存泄漏 - _seen_nonces 改用 OrderedDict,O(1) 摊销过期清理 - webhook 后台任务加 Semaphore(20) + 任务引用追踪,限制无界并发 - _build_adapter 用 asyncio.gather 并行解密 secrets - 适配器实例缓存(_adapter_cache),token TTL 缓存跨请求命中 去重(4 项): - header_get 提取到 channels/base.py,4 个适配器统一 import - _get_client/close() 移入 MessageAdapter 基类,子类继承 - URLVerificationChallenge 统一到 base.py,feishu/slack/wecom 共用 - Transport ABC 添加 endpoint_url 属性,from_transport 不再访问私有字段 死代码与类型安全(5 项): - detect_cache_hit 死方法替换为 record_cache_result 公开 API - execution_mode.value == "direct_chat" 改用枚举比较 - 删除 yielded_any 死变量、重复 from fastapi import Request、 多余 getattr 防御 453 tests passed, ruff clean(预存 F841 非本次引入)
This commit is contained in:
parent
793476cafa
commit
1ccaf56b9a
|
|
@ -11,6 +11,8 @@ from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
class ChannelType(str, Enum):
|
class ChannelType(str, Enum):
|
||||||
"""支持的消息平台渠道。"""
|
"""支持的消息平台渠道。"""
|
||||||
|
|
@ -51,17 +53,42 @@ class OutgoingMessage:
|
||||||
reply_to_message_id: str | None = None
|
reply_to_message_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class URLVerificationChallenge(Exception):
|
||||||
|
"""URL 验证事件 — webhook 端点需返回 ``{"challenge": ...}`` 响应。
|
||||||
|
|
||||||
|
飞书 / Slack 在配置 webhook 时发送一次 ``url_verification`` 事件,
|
||||||
|
要求服务端原样返回 ``challenge`` 字段以验证 URL 可达。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, challenge: str) -> None:
|
||||||
|
super().__init__(f"URL verification challenge: {challenge}")
|
||||||
|
self.challenge = challenge
|
||||||
|
|
||||||
|
|
||||||
class MessageAdapter(abc.ABC):
|
class MessageAdapter(abc.ABC):
|
||||||
"""消息适配器 ABC。
|
"""消息适配器 ABC。
|
||||||
|
|
||||||
生命周期:
|
生命周期:
|
||||||
__init__ → verify_signature() → receive_message() → send_message() → close()
|
__init__ → verify_signature() → receive_message() → send_message() → close()
|
||||||
|
|
||||||
子类必须实现全部抽象方法。verify_signature 失败时调用方应拒绝处理
|
子类必须实现 verify_signature / receive_message / send_message 抽象方法。
|
||||||
(webhook 端点 fail-closed:Redis 不可用或签名校验失败均返回 503/401,
|
``_get_client`` / ``close`` 已在基类提供(懒构造 httpx 客户端 + 释放资源),
|
||||||
不可跳过 nonce dedup 直接处理消息)。
|
子类只需在 ``__init__`` 中调用 ``super().__init__()``。
|
||||||
|
|
||||||
|
verify_signature 失败时调用方应拒绝处理(webhook 端点 fail-closed:
|
||||||
|
Redis 不可用或签名校验失败均返回 503/401,不可跳过 nonce dedup 直接处理消息)。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
# 懒加载 httpx 客户端 — 避免未使用的适配器持有连接池
|
||||||
|
self._client: httpx.AsyncClient | None = None
|
||||||
|
|
||||||
|
def _get_client(self) -> httpx.AsyncClient:
|
||||||
|
"""懒构造 httpx 客户端(子类共享此实现)。"""
|
||||||
|
if self._client is None:
|
||||||
|
self._client = httpx.AsyncClient(timeout=10.0)
|
||||||
|
return self._client
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def verify_signature(self, headers: dict[str, str], body: bytes) -> bool:
|
async def verify_signature(self, headers: dict[str, str], body: bytes) -> bool:
|
||||||
"""验证平台签名/token。返回 True 表示请求可信。"""
|
"""验证平台签名/token。返回 True 表示请求可信。"""
|
||||||
|
|
@ -74,6 +101,25 @@ class MessageAdapter(abc.ABC):
|
||||||
async def send_message(self, message: OutgoingMessage) -> bool:
|
async def send_message(self, message: OutgoingMessage) -> bool:
|
||||||
"""向平台发送消息。返回 True 表示发送成功。"""
|
"""向平台发送消息。返回 True 表示发送成功。"""
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""释放资源(HTTP 客户端、连接池等)。"""
|
"""关闭 httpx 客户端(如已创建)。子类可覆盖以释放额外资源。"""
|
||||||
|
if self._client is not None:
|
||||||
|
await self._client.aclose()
|
||||||
|
self._client = None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 辅助函数
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def header_get(headers: dict[str, str], name: str) -> str | None:
|
||||||
|
"""大小写不敏感的 header 查找。"""
|
||||||
|
# 直接命中
|
||||||
|
if name in headers:
|
||||||
|
return headers[name]
|
||||||
|
lower = name.lower()
|
||||||
|
for k, v in headers.items():
|
||||||
|
if k.lower() == lower:
|
||||||
|
return v
|
||||||
|
return None
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ from agentkit.channels.base import (
|
||||||
IncomingMessage,
|
IncomingMessage,
|
||||||
MessageAdapter,
|
MessageAdapter,
|
||||||
OutgoingMessage,
|
OutgoingMessage,
|
||||||
|
header_get,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -68,25 +69,15 @@ class DingTalkMessageAdapter(MessageAdapter):
|
||||||
robot_code: str,
|
robot_code: str,
|
||||||
token: str | None = None,
|
token: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
self.app_key = app_key
|
self.app_key = app_key
|
||||||
self.app_secret = app_secret
|
self.app_secret = app_secret
|
||||||
self.robot_code = robot_code
|
self.robot_code = robot_code
|
||||||
self.token = token
|
self.token = token
|
||||||
# 懒加载 httpx 客户端
|
|
||||||
self._client: httpx.AsyncClient | None = None
|
|
||||||
# ponytail: 简单 TTL 缓存 (token, expiry)。天花板:单实例内存;
|
# ponytail: 简单 TTL 缓存 (token, expiry)。天花板:单实例内存;
|
||||||
# 升级路径:Redis 缓存共享给多实例。
|
# 升级路径:Redis 缓存共享给多实例。
|
||||||
self._token_cache: tuple[str, float] | None = None
|
self._token_cache: tuple[str, float] | None = None
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# httpx 客户端懒加载
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _get_client(self) -> httpx.AsyncClient:
|
|
||||||
if self._client is None:
|
|
||||||
self._client = httpx.AsyncClient(timeout=10.0)
|
|
||||||
return self._client
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# 签名验证
|
# 签名验证
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
@ -107,12 +98,12 @@ class DingTalkMessageAdapter(MessageAdapter):
|
||||||
"""
|
"""
|
||||||
# Token 校验(若配置)
|
# Token 校验(若配置)
|
||||||
if self.token is not None:
|
if self.token is not None:
|
||||||
token_header = _header_get(headers, "Token")
|
token_header = header_get(headers, "Token")
|
||||||
if token_header is None or not hmac.compare_digest(token_header, self.token):
|
if token_header is None or not hmac.compare_digest(token_header, self.token):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
sign = _header_get(headers, "Sign")
|
sign = header_get(headers, "Sign")
|
||||||
timestamp_str = _header_get(headers, "Timestamp")
|
timestamp_str = header_get(headers, "Timestamp")
|
||||||
|
|
||||||
if sign is None and timestamp_str is None:
|
if sign is None and timestamp_str is None:
|
||||||
# 无签名头:仅当 token 已校验通过才放行
|
# 无签名头:仅当 token 已校验通过才放行
|
||||||
|
|
@ -253,29 +244,3 @@ class DingTalkMessageAdapter(MessageAdapter):
|
||||||
except httpx.HTTPError as exc:
|
except httpx.HTTPError as exc:
|
||||||
logger.error("钉钉 accessToken 网络错误: %s", exc)
|
logger.error("钉钉 accessToken 网络错误: %s", exc)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 资源释放
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
"""关闭 httpx 客户端(如已创建)。"""
|
|
||||||
if self._client is not None:
|
|
||||||
await self._client.aclose()
|
|
||||||
self._client = None
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# 辅助函数
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _header_get(headers: dict[str, str], name: str) -> str | None:
|
|
||||||
"""大小写不敏感的 header 查找。"""
|
|
||||||
if name in headers:
|
|
||||||
return headers[name]
|
|
||||||
lower = name.lower()
|
|
||||||
for k, v in headers.items():
|
|
||||||
if k.lower() == lower:
|
|
||||||
return v
|
|
||||||
return None
|
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,8 @@ from agentkit.channels.base import (
|
||||||
IncomingMessage,
|
IncomingMessage,
|
||||||
MessageAdapter,
|
MessageAdapter,
|
||||||
OutgoingMessage,
|
OutgoingMessage,
|
||||||
|
URLVerificationChallenge,
|
||||||
|
header_get,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -45,18 +47,6 @@ _SEND_MESSAGE_URL = "https://open.feishu.cn/open-apis/im/v1/messages"
|
||||||
_MENTION_RE = re.compile(r"@_user_\d+\s*")
|
_MENTION_RE = re.compile(r"@_user_\d+\s*")
|
||||||
|
|
||||||
|
|
||||||
class URLVerificationChallenge(Exception):
|
|
||||||
"""飞书 URL 验证事件 — webhook 端点需返回 ``{"challenge": ...}`` 响应。
|
|
||||||
|
|
||||||
飞书在配置 webhook 时会发送一次 ``url_verification`` 事件,要求服务端
|
|
||||||
原样返回 ``challenge`` 字段以验证 URL 可达。
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, challenge: str) -> None:
|
|
||||||
super().__init__(f"URL verification challenge: {challenge}")
|
|
||||||
self.challenge = challenge
|
|
||||||
|
|
||||||
|
|
||||||
class SignatureVerificationError(Exception):
|
class SignatureVerificationError(Exception):
|
||||||
"""事件 ``verification_token`` 校验失败 — 拒绝处理。"""
|
"""事件 ``verification_token`` 校验失败 — 拒绝处理。"""
|
||||||
|
|
||||||
|
|
@ -82,25 +72,15 @@ class FeishuMessageAdapter(MessageAdapter):
|
||||||
encrypt_key: str | None = None,
|
encrypt_key: str | None = None,
|
||||||
verification_token: str | None = None,
|
verification_token: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
self.app_id = app_id
|
self.app_id = app_id
|
||||||
self.app_secret = app_secret
|
self.app_secret = app_secret
|
||||||
self.encrypt_key = encrypt_key
|
self.encrypt_key = encrypt_key
|
||||||
self.verification_token = verification_token
|
self.verification_token = verification_token
|
||||||
# 懒加载 httpx 客户端 — 避免未使用的适配器持有连接池
|
|
||||||
self._client: httpx.AsyncClient | None = None
|
|
||||||
# ponytail: 简单 TTL 缓存 (token, expiry)。天花板:单实例内存;
|
# ponytail: 简单 TTL 缓存 (token, expiry)。天花板:单实例内存;
|
||||||
# 升级路径:Redis 缓存共享给多实例。
|
# 升级路径:Redis 缓存共享给多实例。
|
||||||
self._token_cache: tuple[str, float] | None = None
|
self._token_cache: tuple[str, float] | None = None
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# httpx 客户端懒加载
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _get_client(self) -> httpx.AsyncClient:
|
|
||||||
if self._client is None:
|
|
||||||
self._client = httpx.AsyncClient(timeout=10.0)
|
|
||||||
return self._client
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# 签名验证
|
# 签名验证
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
@ -122,12 +102,12 @@ class FeishuMessageAdapter(MessageAdapter):
|
||||||
logger.warning("飞书适配器未配置 encrypt_key — 拒绝所有 webhook 请求")
|
logger.warning("飞书适配器未配置 encrypt_key — 拒绝所有 webhook 请求")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
signature = _header_get(headers, "X-Lark-Signature")
|
signature = header_get(headers, "X-Lark-Signature")
|
||||||
if not signature:
|
if not signature:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
timestamp_str = _header_get(headers, "X-Lark-Request-Timestamp")
|
timestamp_str = header_get(headers, "X-Lark-Request-Timestamp")
|
||||||
nonce = _header_get(headers, "X-Lark-Request-Nonce")
|
nonce = header_get(headers, "X-Lark-Request-Nonce")
|
||||||
if not timestamp_str or not nonce:
|
if not timestamp_str or not nonce:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -335,30 +315,3 @@ class FeishuMessageAdapter(MessageAdapter):
|
||||||
except httpx.HTTPError as exc:
|
except httpx.HTTPError as exc:
|
||||||
logger.error("飞书 tenant_token 网络错误: %s", exc)
|
logger.error("飞书 tenant_token 网络错误: %s", exc)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 资源释放
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
"""关闭 httpx 客户端(如已创建)。"""
|
|
||||||
if self._client is not None:
|
|
||||||
await self._client.aclose()
|
|
||||||
self._client = None
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# 辅助函数
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _header_get(headers: dict[str, str], name: str) -> str | None:
|
|
||||||
"""大小写不敏感的 header 查找。"""
|
|
||||||
# 直接命中
|
|
||||||
if name in headers:
|
|
||||||
return headers[name]
|
|
||||||
lower = name.lower()
|
|
||||||
for k, v in headers.items():
|
|
||||||
if k.lower() == lower:
|
|
||||||
return v
|
|
||||||
return None
|
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,8 @@ from agentkit.channels.base import (
|
||||||
IncomingMessage,
|
IncomingMessage,
|
||||||
MessageAdapter,
|
MessageAdapter,
|
||||||
OutgoingMessage,
|
OutgoingMessage,
|
||||||
|
URLVerificationChallenge,
|
||||||
|
header_get,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -44,18 +46,6 @@ _SEND_MESSAGE_URL = "https://slack.com/api/chat.postMessage"
|
||||||
_MENTION_RE = re.compile(r"<@[^>]+>\s*")
|
_MENTION_RE = re.compile(r"<@[^>]+>\s*")
|
||||||
|
|
||||||
|
|
||||||
class URLVerificationChallenge(Exception):
|
|
||||||
"""Slack URL 验证事件 — webhook 端点需返回 ``{"challenge": ...}`` 响应。
|
|
||||||
|
|
||||||
Slack 在配置 Events API 时发送 ``url_verification`` 事件,要求服务端
|
|
||||||
原样返回 ``challenge`` 字段以验证 URL 可达。
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, challenge: str) -> None:
|
|
||||||
super().__init__(f"Slack URL verification challenge: {challenge}")
|
|
||||||
self.challenge = challenge
|
|
||||||
|
|
||||||
|
|
||||||
class SignatureVerificationError(Exception):
|
class SignatureVerificationError(Exception):
|
||||||
"""``verification_token`` 校验失败 — 拒绝处理。"""
|
"""``verification_token`` 校验失败 — 拒绝处理。"""
|
||||||
|
|
||||||
|
|
@ -79,20 +69,10 @@ class SlackMessageAdapter(MessageAdapter):
|
||||||
signing_secret: str,
|
signing_secret: str,
|
||||||
verification_token: str | None = None,
|
verification_token: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
self.bot_token = bot_token
|
self.bot_token = bot_token
|
||||||
self.signing_secret = signing_secret
|
self.signing_secret = signing_secret
|
||||||
self.verification_token = verification_token
|
self.verification_token = verification_token
|
||||||
# 懒加载 httpx 客户端
|
|
||||||
self._client: httpx.AsyncClient | None = None
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# httpx 客户端懒加载
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _get_client(self) -> httpx.AsyncClient:
|
|
||||||
if self._client is None:
|
|
||||||
self._client = httpx.AsyncClient(timeout=10.0)
|
|
||||||
return self._client
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# 签名验证
|
# 签名验证
|
||||||
|
|
@ -111,8 +91,8 @@ class SlackMessageAdapter(MessageAdapter):
|
||||||
Returns:
|
Returns:
|
||||||
True 表示签名校验通过。
|
True 表示签名校验通过。
|
||||||
"""
|
"""
|
||||||
signature = _header_get(headers, "X-Slack-Signature")
|
signature = header_get(headers, "X-Slack-Signature")
|
||||||
timestamp_str = _header_get(headers, "X-Slack-Request-Timestamp")
|
timestamp_str = header_get(headers, "X-Slack-Request-Timestamp")
|
||||||
if not signature or not timestamp_str:
|
if not signature or not timestamp_str:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -259,29 +239,3 @@ class SlackMessageAdapter(MessageAdapter):
|
||||||
except httpx.HTTPError as exc:
|
except httpx.HTTPError as exc:
|
||||||
logger.error("Slack send_message 网络错误: %s", exc)
|
logger.error("Slack send_message 网络错误: %s", exc)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 资源释放
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
"""关闭 httpx 客户端(如已创建)。"""
|
|
||||||
if self._client is not None:
|
|
||||||
await self._client.aclose()
|
|
||||||
self._client = None
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# 辅助函数
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _header_get(headers: dict[str, str], name: str) -> str | None:
|
|
||||||
"""大小写不敏感的 header 查找。"""
|
|
||||||
if name in headers:
|
|
||||||
return headers[name]
|
|
||||||
lower = name.lower()
|
|
||||||
for k, v in headers.items():
|
|
||||||
if k.lower() == lower:
|
|
||||||
return v
|
|
||||||
return None
|
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ from agentkit.channels.base import (
|
||||||
IncomingMessage,
|
IncomingMessage,
|
||||||
MessageAdapter,
|
MessageAdapter,
|
||||||
OutgoingMessage,
|
OutgoingMessage,
|
||||||
|
header_get,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -80,25 +81,15 @@ class WeComMessageAdapter(MessageAdapter):
|
||||||
token: str,
|
token: str,
|
||||||
encoding_aes_key: str,
|
encoding_aes_key: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
self.corp_id = corp_id
|
self.corp_id = corp_id
|
||||||
self.agent_id = agent_id
|
self.agent_id = agent_id
|
||||||
self.corp_secret = corp_secret
|
self.corp_secret = corp_secret
|
||||||
self.token = token
|
self.token = token
|
||||||
self.encoding_aes_key = encoding_aes_key
|
self.encoding_aes_key = encoding_aes_key
|
||||||
# 懒加载 httpx 客户端
|
|
||||||
self._client: httpx.AsyncClient | None = None
|
|
||||||
# ponytail: 简单 TTL 缓存。天花板:单实例内存;升级路径:Redis 共享。
|
# ponytail: 简单 TTL 缓存。天花板:单实例内存;升级路径:Redis 共享。
|
||||||
self._token_cache: tuple[str, float] | None = None
|
self._token_cache: tuple[str, float] | None = None
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# httpx 客户端懒加载
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _get_client(self) -> httpx.AsyncClient:
|
|
||||||
if self._client is None:
|
|
||||||
self._client = httpx.AsyncClient(timeout=10.0)
|
|
||||||
return self._client
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# AES 密钥 / 加解密
|
# AES 密钥 / 加解密
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
@ -209,9 +200,9 @@ class WeComMessageAdapter(MessageAdapter):
|
||||||
Returns:
|
Returns:
|
||||||
True 表示签名校验通过。
|
True 表示签名校验通过。
|
||||||
"""
|
"""
|
||||||
msg_signature = _header_get(headers, "msg_signature")
|
msg_signature = header_get(headers, "msg_signature")
|
||||||
timestamp = _header_get(headers, "timestamp")
|
timestamp = header_get(headers, "timestamp")
|
||||||
nonce = _header_get(headers, "nonce")
|
nonce = header_get(headers, "nonce")
|
||||||
if not msg_signature or not timestamp or not nonce:
|
if not msg_signature or not timestamp or not nonce:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -247,8 +238,8 @@ class WeComMessageAdapter(MessageAdapter):
|
||||||
|
|
||||||
# URL 验证流程 — 内部 XML 包含 EchoStr
|
# URL 验证流程 — 内部 XML 包含 EchoStr
|
||||||
if "EchoStr" in inner:
|
if "EchoStr" in inner:
|
||||||
timestamp = _header_get(headers, "timestamp") or ""
|
timestamp = header_get(headers, "timestamp") or ""
|
||||||
nonce = _header_get(headers, "nonce") or ""
|
nonce = header_get(headers, "nonce") or ""
|
||||||
encrypted_echo = self._encrypt(inner["EchoStr"])
|
encrypted_echo = self._encrypt(inner["EchoStr"])
|
||||||
# 计算响应签名
|
# 计算响应签名
|
||||||
sig_parts = sorted([self.token, timestamp, nonce, encrypted_echo])
|
sig_parts = sorted([self.token, timestamp, nonce, encrypted_echo])
|
||||||
|
|
@ -362,29 +353,3 @@ class WeComMessageAdapter(MessageAdapter):
|
||||||
except httpx.HTTPError as exc:
|
except httpx.HTTPError as exc:
|
||||||
logger.error("企微 access_token 网络错误: %s", exc)
|
logger.error("企微 access_token 网络错误: %s", exc)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 资源释放
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
"""关闭 httpx 客户端(如已创建)。"""
|
|
||||||
if self._client is not None:
|
|
||||||
await self._client.aclose()
|
|
||||||
self._client = None
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# 辅助函数
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _header_get(headers: dict[str, str], name: str) -> str | None:
|
|
||||||
"""大小写不敏感的 header 查找。"""
|
|
||||||
if name in headers:
|
|
||||||
return headers[name]
|
|
||||||
lower = name.lower()
|
|
||||||
for k, v in headers.items():
|
|
||||||
if k.lower() == lower:
|
|
||||||
return v
|
|
||||||
return None
|
|
||||||
|
|
|
||||||
|
|
@ -803,18 +803,16 @@ class LitellmCacheManager:
|
||||||
"""返回 litellm acompletion 的 cache 参数(禁用缓存)。"""
|
"""返回 litellm acompletion 的 cache 参数(禁用缓存)。"""
|
||||||
return {"no-cache": True}
|
return {"no-cache": True}
|
||||||
|
|
||||||
def detect_cache_hit(self, response: Any) -> bool:
|
def record_cache_result(self, is_hit: bool) -> None:
|
||||||
"""检测 LiteLLM 响应是否为缓存命中。
|
"""记录单次 LLM 调用的缓存命中/未命中(用于 usage tracking 统计)。
|
||||||
|
|
||||||
LiteLLM 在缓存命中时设置 ``response._hidden_params["cache_key"]``。
|
命中判定由调用方完成(gateway 通过 ``response.cache_hit`` 判定),
|
||||||
|
本方法只负责更新计数器,避免重复检测逻辑。
|
||||||
"""
|
"""
|
||||||
hidden = getattr(response, "_hidden_params", None)
|
if is_hit:
|
||||||
if isinstance(hidden, dict):
|
|
||||||
if "cache_key" in hidden or hidden.get("cache_hit"):
|
|
||||||
self._hits += 1
|
self._hits += 1
|
||||||
return True
|
else:
|
||||||
self._misses += 1
|
self._misses += 1
|
||||||
return False
|
|
||||||
|
|
||||||
def stats(self) -> dict[str, int]:
|
def stats(self) -> dict[str, int]:
|
||||||
"""返回缓存统计。"""
|
"""返回缓存统计。"""
|
||||||
|
|
|
||||||
|
|
@ -224,10 +224,8 @@ class LLMGateway:
|
||||||
|
|
||||||
# U17 — 检测 LiteLLM 缓存命中(用于 usage tracking cost=0)
|
# U17 — 检测 LiteLLM 缓存命中(用于 usage tracking cost=0)
|
||||||
is_cache_hit = getattr(response, "cache_hit", False)
|
is_cache_hit = getattr(response, "cache_hit", False)
|
||||||
if is_cache_hit and self._cache_manager is not None:
|
if self._cache_manager is not None:
|
||||||
self._cache_manager._hits += 1
|
self._cache_manager.record_cache_result(is_cache_hit)
|
||||||
elif self._cache_manager is not None:
|
|
||||||
self._cache_manager._misses += 1
|
|
||||||
|
|
||||||
# 计算成本(缓存命中时 cost=0)
|
# 计算成本(缓存命中时 cost=0)
|
||||||
cost = 0.0 if is_cache_hit else self._calculate_cost(response.model, response.usage)
|
cost = 0.0 if is_cache_hit else self._calculate_cost(response.model, response.usage)
|
||||||
|
|
|
||||||
|
|
@ -119,7 +119,6 @@ class LitellmProvider(LLMProvider):
|
||||||
accumulated_tool_calls: dict[int, dict[str, Any]] = {}
|
accumulated_tool_calls: dict[int, dict[str, Any]] = {}
|
||||||
final_usage: TokenUsage | None = None
|
final_usage: TokenUsage | None = None
|
||||||
final_model: str = request.model
|
final_model: str = request.model
|
||||||
yielded_any = False
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# litellm.acompletion(stream=True) 的返回类型取决于版本 / 调用方式:
|
# litellm.acompletion(stream=True) 的返回类型取决于版本 / 调用方式:
|
||||||
|
|
@ -129,7 +128,6 @@ class LitellmProvider(LLMProvider):
|
||||||
raw = litellm.acompletion(**kwargs)
|
raw = litellm.acompletion(**kwargs)
|
||||||
stream = await raw if inspect.isawaitable(raw) else raw
|
stream = await raw if inspect.isawaitable(raw) else raw
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
yielded_any = True
|
|
||||||
parsed = self._parse_stream_chunk(
|
parsed = self._parse_stream_chunk(
|
||||||
chunk,
|
chunk,
|
||||||
request.model,
|
request.model,
|
||||||
|
|
@ -156,10 +154,6 @@ class LitellmProvider(LLMProvider):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise LLMProviderError(self._provider_type, str(e)) from e
|
raise LLMProviderError(self._provider_type, str(e)) from e
|
||||||
|
|
||||||
# ponytail: 若流完全为空(yielded_any=False),上面仍会 yield 一个
|
|
||||||
# is_final=True 的空 chunk,调用方据此判断空响应。无需额外分支。
|
|
||||||
_ = yielded_any # 标记保留(调试 / 未来扩展)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# 内部辅助
|
# 内部辅助
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
@ -184,7 +178,7 @@ class LitellmProvider(LLMProvider):
|
||||||
if request.timeout is not None:
|
if request.timeout is not None:
|
||||||
kwargs["timeout"] = request.timeout
|
kwargs["timeout"] = request.timeout
|
||||||
# U17 — 透传 LiteLLM cache 参数(cache_key 或 no-cache)到 litellm.acompletion
|
# U17 — 透传 LiteLLM cache 参数(cache_key 或 no-cache)到 litellm.acompletion
|
||||||
cache_params = getattr(request, "_cache", None)
|
cache_params = request._cache
|
||||||
if cache_params is not None:
|
if cache_params is not None:
|
||||||
kwargs["cache"] = cache_params
|
kwargs["cache"] = cache_params
|
||||||
# 合并构造时传入的默认 kwargs(如 max_connections 等provider特定参数)
|
# 合并构造时传入的默认 kwargs(如 max_connections 等provider特定参数)
|
||||||
|
|
|
||||||
|
|
@ -63,6 +63,9 @@ class MCPClient:
|
||||||
self._timeout = timeout
|
self._timeout = timeout
|
||||||
self._tools_cache: list[dict] | None = None
|
self._tools_cache: list[dict] | None = None
|
||||||
self._transport = transport
|
self._transport = transport
|
||||||
|
# U10 — 懒构造并缓存的 langchain client,避免每次 list_tools/call_tool
|
||||||
|
# 都新建 MultiServerMCPClient(stdio 传输下会反复 spawn 子进程)。
|
||||||
|
self._lc_client: Any = None
|
||||||
|
|
||||||
if transport is not None:
|
if transport is not None:
|
||||||
# 旧 Transport 路径 — 发出 DeprecationWarning,但保持原有行为
|
# 旧 Transport 路径 — 发出 DeprecationWarning,但保持原有行为
|
||||||
|
|
@ -138,15 +141,34 @@ class MCPClient:
|
||||||
发出 DeprecationWarning。建议迁移到 URL scheme 自动检测。
|
发出 DeprecationWarning。建议迁移到 URL scheme 自动检测。
|
||||||
"""
|
"""
|
||||||
if isinstance(transport, HTTPTransport):
|
if isinstance(transport, HTTPTransport):
|
||||||
server_url = transport._endpoint
|
server_url = transport.endpoint_url
|
||||||
elif isinstance(transport, SSETransport):
|
elif isinstance(transport, SSETransport):
|
||||||
server_url = transport._endpoint
|
server_url = transport.endpoint_url
|
||||||
elif isinstance(transport, StdioTransport):
|
elif isinstance(transport, StdioTransport):
|
||||||
server_url = f"stdio://{transport._command}"
|
server_url = f"stdio://{transport._command}"
|
||||||
else:
|
else:
|
||||||
server_url = ""
|
server_url = ""
|
||||||
return cls(server_url=server_url, transport=transport)
|
return cls(server_url=server_url, transport=transport)
|
||||||
|
|
||||||
|
async def _get_lc_client(self) -> Any:
|
||||||
|
"""懒构造并缓存 langchain ``MultiServerMCPClient`` 实例。
|
||||||
|
|
||||||
|
首次调用时创建,后续返回缓存,避免每次 list_tools/call_tool 都新建
|
||||||
|
client(stdio 传输下会 spawn 新子进程,造成连接/进程泄漏)。
|
||||||
|
"""
|
||||||
|
if self._lc_client is None:
|
||||||
|
client_cls = _import_langchain_client()
|
||||||
|
self._lc_client = client_cls({"server": self._langchain_config})
|
||||||
|
return self._lc_client
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
"""关闭缓存的 langchain client(如果它提供 ``aclose`` 方法)。"""
|
||||||
|
if self._lc_client is not None:
|
||||||
|
aclose = getattr(self._lc_client, "aclose", None)
|
||||||
|
if aclose is not None:
|
||||||
|
await aclose()
|
||||||
|
self._lc_client = None
|
||||||
|
|
||||||
async def list_tools(self) -> list[dict]:
|
async def list_tools(self) -> list[dict]:
|
||||||
"""列出远程 MCP Server 上的工具
|
"""列出远程 MCP Server 上的工具
|
||||||
|
|
||||||
|
|
@ -165,9 +187,8 @@ class MCPClient:
|
||||||
self._tools_cache = tools
|
self._tools_cache = tools
|
||||||
return self._tools_cache
|
return self._tools_cache
|
||||||
|
|
||||||
# 新 langchain 路径
|
# 新 langchain 路径 — 复用缓存的 client
|
||||||
client_cls = _import_langchain_client()
|
client = await self._get_lc_client()
|
||||||
client = client_cls({"server": self._langchain_config})
|
|
||||||
lc_tools = await client.get_tools()
|
lc_tools = await client.get_tools()
|
||||||
tools = [
|
tools = [
|
||||||
{
|
{
|
||||||
|
|
@ -218,9 +239,8 @@ class MCPClient:
|
||||||
params={"name": tool_name, "arguments": arguments},
|
params={"name": tool_name, "arguments": arguments},
|
||||||
)
|
)
|
||||||
|
|
||||||
# 新 langchain 路径
|
# 新 langchain 路径 — 复用缓存的 client
|
||||||
client_cls = _import_langchain_client()
|
client = await self._get_lc_client()
|
||||||
client = client_cls({"server": self._langchain_config})
|
|
||||||
lc_tools = await client.get_tools()
|
lc_tools = await client.get_tools()
|
||||||
for tool in lc_tools:
|
for tool in lc_tools:
|
||||||
if tool.name == tool_name:
|
if tool.name == tool_name:
|
||||||
|
|
|
||||||
|
|
@ -230,7 +230,6 @@ class MCPServer:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi import Request # noqa: F401 — 用于 jsonrpc_endpoint 的类型注解
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("MCP Server requires fastapi: pip install fischer-agentkit[mcp]")
|
raise ImportError("MCP Server requires fastapi: pip install fischer-agentkit[mcp]")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -98,6 +98,11 @@ class HTTPTransport(Transport):
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
return self._client is not None and not self._client.is_closed
|
return self._client is not None and not self._client.is_closed
|
||||||
|
|
||||||
|
@property
|
||||||
|
def endpoint_url(self) -> str:
|
||||||
|
"""已配置的端点 URL(去掉尾部斜杠)。"""
|
||||||
|
return self._endpoint
|
||||||
|
|
||||||
async def connect(self) -> None:
|
async def connect(self) -> None:
|
||||||
"""建立 HTTP 连接"""
|
"""建立 HTTP 连接"""
|
||||||
if self.is_connected:
|
if self.is_connected:
|
||||||
|
|
@ -220,6 +225,11 @@ class SSETransport(Transport):
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
return self._connected and self._client is not None and not self._client.is_closed
|
return self._connected and self._client is not None and not self._client.is_closed
|
||||||
|
|
||||||
|
@property
|
||||||
|
def endpoint_url(self) -> str:
|
||||||
|
"""已配置的端点 URL(去掉尾部斜杠)。"""
|
||||||
|
return self._endpoint
|
||||||
|
|
||||||
async def connect(self) -> None:
|
async def connect(self) -> None:
|
||||||
"""建立 SSE 连接
|
"""建立 SSE 连接
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -21,21 +21,25 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
from collections import OrderedDict
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from fastapi.responses import Response
|
from fastapi.responses import Response
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from agentkit.channels.base import ChannelType, MessageAdapter, OutgoingMessage
|
from agentkit.channels.base import (
|
||||||
from agentkit.channels.dingtalk import DingTalkMessageAdapter
|
ChannelType,
|
||||||
from agentkit.channels.feishu import FeishuMessageAdapter, URLVerificationChallenge
|
MessageAdapter,
|
||||||
from agentkit.channels.secrets import SecretsStore
|
OutgoingMessage,
|
||||||
from agentkit.channels.slack import (
|
URLVerificationChallenge,
|
||||||
SlackMessageAdapter,
|
|
||||||
URLVerificationChallenge as SlackURLVerificationChallenge,
|
|
||||||
)
|
)
|
||||||
|
from agentkit.channels.dingtalk import DingTalkMessageAdapter
|
||||||
|
from agentkit.channels.feishu import FeishuMessageAdapter
|
||||||
|
from agentkit.channels.secrets import SecretsStore
|
||||||
|
from agentkit.channels.slack import SlackMessageAdapter
|
||||||
from agentkit.channels.wecom import WeComMessageAdapter, WeComURLVerification
|
from agentkit.channels.wecom import WeComMessageAdapter, WeComURLVerification
|
||||||
|
from agentkit.chat.skill_routing import ExecutionMode
|
||||||
from agentkit.server.auth.dependencies import require_permission
|
from agentkit.server.auth.dependencies import require_permission
|
||||||
from agentkit.server.auth.permissions import Permission
|
from agentkit.server.auth.permissions import Permission
|
||||||
|
|
||||||
|
|
@ -89,16 +93,40 @@ _RATE_LIMIT_WINDOW = 60.0 # 窗口大小(秒)
|
||||||
_RATE_LIMIT_MAX = 100 # 窗口内最大请求数
|
_RATE_LIMIT_MAX = 100 # 窗口内最大请求数
|
||||||
|
|
||||||
# nonce -> 过期时间戳(与飞书签名时间戳窗口一致)
|
# nonce -> 过期时间戳(与飞书签名时间戳窗口一致)
|
||||||
_seen_nonces: dict[str, float] = {}
|
# OrderedDict:按插入顺序遍历,清理过期项时从头部弹出,O(1) 摊销。
|
||||||
|
_seen_nonces: OrderedDict[str, float] = OrderedDict()
|
||||||
_NONCE_TTL = 300.0
|
_NONCE_TTL = 300.0
|
||||||
|
|
||||||
|
# Webhook 后台处理并发上限 — 防止高流量下 LLM 调用无界并发
|
||||||
|
_WEBHOOK_MAX_CONCURRENT = 20
|
||||||
|
_webhook_semaphore: asyncio.Semaphore | None = None
|
||||||
|
# 持有后台任务引用,防止 GC 回收正在运行的 task
|
||||||
|
_pending_webhook_tasks: set[asyncio.Task[None]] = set()
|
||||||
|
|
||||||
|
# 适配器缓存 — channel_id -> adapter。避免 per-request 重建导致 token TTL 缓存失效。
|
||||||
|
# 配置变更(PUT/DELETE channel)时通过 _invalidate_adapter_cache 清除对应条目。
|
||||||
|
_adapter_cache: dict[str, MessageAdapter] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_webhook_semaphore() -> asyncio.Semaphore:
|
||||||
|
"""懒构造 webhook 并发信号量(事件循环首次需要时创建)。"""
|
||||||
|
global _webhook_semaphore
|
||||||
|
if _webhook_semaphore is None:
|
||||||
|
_webhook_semaphore = asyncio.Semaphore(_WEBHOOK_MAX_CONCURRENT)
|
||||||
|
return _webhook_semaphore
|
||||||
|
|
||||||
|
|
||||||
def _check_rate_limit(client_ip: str) -> bool:
|
def _check_rate_limit(client_ip: str) -> bool:
|
||||||
"""滑动窗口限流。返回 True 表示放行,False 表示超限。"""
|
"""滑动窗口限流。返回 True 表示放行,False 表示超限。
|
||||||
|
|
||||||
|
ponytail: 过滤后 timestamps 为空时清除 IP 条目,避免 _rate_limits 无界增长。
|
||||||
|
"""
|
||||||
now = time.monotonic()
|
now = time.monotonic()
|
||||||
timestamps = _rate_limits.get(client_ip, [])
|
|
||||||
cutoff = now - _RATE_LIMIT_WINDOW
|
cutoff = now - _RATE_LIMIT_WINDOW
|
||||||
timestamps = [t for t in timestamps if t > cutoff]
|
timestamps = [t for t in _rate_limits.get(client_ip, []) if t > cutoff]
|
||||||
|
if not timestamps:
|
||||||
|
# 清理过期空条目 — 不活跃 IP 不再占用 dict 槽位
|
||||||
|
_rate_limits.pop(client_ip, None)
|
||||||
if len(timestamps) >= _RATE_LIMIT_MAX:
|
if len(timestamps) >= _RATE_LIMIT_MAX:
|
||||||
_rate_limits[client_ip] = timestamps
|
_rate_limits[client_ip] = timestamps
|
||||||
return False
|
return False
|
||||||
|
|
@ -108,12 +136,20 @@ def _check_rate_limit(client_ip: str) -> bool:
|
||||||
|
|
||||||
|
|
||||||
def _check_nonce_dedup(nonce: str) -> bool:
|
def _check_nonce_dedup(nonce: str) -> bool:
|
||||||
"""Nonce 去重。返回 True 表示新 nonce(应处理),False 表示重复(跳过)。"""
|
"""Nonce 去重。返回 True 表示新 nonce(应处理),False 表示重复(跳过)。
|
||||||
|
|
||||||
|
使用 OrderedDict 按插入顺序清理过期项:从头部弹出已过期条目,
|
||||||
|
O(1) 摊销(旧实现遍历整个字典为 O(N))。fail-closed 语义:过期 nonce 拒绝。
|
||||||
|
"""
|
||||||
now = time.monotonic()
|
now = time.monotonic()
|
||||||
# 惰性清理过期项
|
# 从头部弹出过期项(nonce 按单调时间插入,头部最旧即最先过期)。
|
||||||
expired = [k for k, v in _seen_nonces.items() if v < now]
|
# ponytail: 摊销 O(1) 清理。旧实现遍历整个 dict 为 O(N)。
|
||||||
for k in expired:
|
while _seen_nonces:
|
||||||
del _seen_nonces[k]
|
_oldest_key, oldest_expiry = next(iter(_seen_nonces.items()))
|
||||||
|
if oldest_expiry >= now:
|
||||||
|
break
|
||||||
|
_seen_nonces.popitem(last=False)
|
||||||
|
# 过期 nonce 已被清理 — 视为新 nonce(与旧实现先删全部过期项再查一致)
|
||||||
if nonce in _seen_nonces:
|
if nonce in _seen_nonces:
|
||||||
return False
|
return False
|
||||||
_seen_nonces[nonce] = now + _NONCE_TTL
|
_seen_nonces[nonce] = now + _NONCE_TTL
|
||||||
|
|
@ -124,6 +160,8 @@ def _reset_webhook_state() -> None:
|
||||||
"""重置限流与 nonce 状态(仅供测试使用)。"""
|
"""重置限流与 nonce 状态(仅供测试使用)。"""
|
||||||
_rate_limits.clear()
|
_rate_limits.clear()
|
||||||
_seen_nonces.clear()
|
_seen_nonces.clear()
|
||||||
|
_adapter_cache.clear()
|
||||||
|
_pending_webhook_tasks.clear()
|
||||||
|
|
||||||
|
|
||||||
def _validate_channel_id(channel_id: str) -> str:
|
def _validate_channel_id(channel_id: str) -> str:
|
||||||
|
|
@ -288,6 +326,9 @@ async def update_channel(
|
||||||
if payload.config is not None:
|
if payload.config is not None:
|
||||||
cfg["config"] = payload.config
|
cfg["config"] = payload.config
|
||||||
|
|
||||||
|
# 配置变更 — 清除缓存适配器,下次 webhook 重建(凭证可能已更新)
|
||||||
|
await _invalidate_adapter_cache(channel_id)
|
||||||
|
|
||||||
return ChannelInfo(
|
return ChannelInfo(
|
||||||
channel_id=channel_id,
|
channel_id=channel_id,
|
||||||
channel_type=ChannelType(cfg["channel_type"]),
|
channel_type=ChannelType(cfg["channel_type"]),
|
||||||
|
|
@ -311,6 +352,9 @@ async def delete_channel(
|
||||||
for name in cfg.get("secret_keys", []):
|
for name in cfg.get("secret_keys", []):
|
||||||
await store.delete_secret(f"{channel_id}:{name}")
|
await store.delete_secret(f"{channel_id}:{name}")
|
||||||
|
|
||||||
|
# 清除缓存适配器并关闭旧实例
|
||||||
|
await _invalidate_adapter_cache(channel_id)
|
||||||
|
|
||||||
return {"deleted": channel_id}
|
return {"deleted": channel_id}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -320,13 +364,20 @@ async def delete_channel(
|
||||||
|
|
||||||
|
|
||||||
async def _build_adapter(channel_id: str) -> MessageAdapter:
|
async def _build_adapter(channel_id: str) -> MessageAdapter:
|
||||||
"""根据渠道配置与 secrets 构造适配器实例。
|
"""根据渠道配置与 secrets 构造适配器实例(带缓存)。
|
||||||
|
|
||||||
支持飞书 / 钉钉 / 企微 / Slack 四种渠道类型,按 ``channel_type`` 分发。
|
支持飞书 / 钉钉 / 企微 / Slack 四种渠道类型,按 ``channel_type`` 分发。
|
||||||
|
首次构造后缓存到 ``_adapter_cache``,后续请求复用同一实例(token TTL 缓存命中)。
|
||||||
|
secrets 获取使用 ``asyncio.gather`` 并行,避免串行 await。
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException: 渠道不存在(404)、渠道类型不支持(400)、缺少必要凭证(500)。
|
HTTPException: 渠道不存在(404)、渠道类型不支持(400)、缺少必要凭证(500)。
|
||||||
"""
|
"""
|
||||||
|
# 命中缓存直接返回(per-request 重建会使 token TTL 缓存失效)
|
||||||
|
cached = _adapter_cache.get(channel_id)
|
||||||
|
if cached is not None:
|
||||||
|
return cached
|
||||||
|
|
||||||
cfg = _channels.get(channel_id)
|
cfg = _channels.get(channel_id)
|
||||||
if cfg is None:
|
if cfg is None:
|
||||||
raise HTTPException(status_code=404, detail=f"渠道 '{channel_id}' 不存在")
|
raise HTTPException(status_code=404, detail=f"渠道 '{channel_id}' 不存在")
|
||||||
|
|
@ -335,43 +386,53 @@ async def _build_adapter(channel_id: str) -> MessageAdapter:
|
||||||
channel_type = cfg["channel_type"]
|
channel_type = cfg["channel_type"]
|
||||||
|
|
||||||
if channel_type == ChannelType.FEISHU.value:
|
if channel_type == ChannelType.FEISHU.value:
|
||||||
app_id = await store.get_secret(f"{channel_id}:app_id")
|
app_id, app_secret, encrypt_key, verification_token = await asyncio.gather(
|
||||||
app_secret = await store.get_secret(f"{channel_id}:app_secret")
|
store.get_secret(f"{channel_id}:app_id"),
|
||||||
encrypt_key = await store.get_secret(f"{channel_id}:encrypt_key")
|
store.get_secret(f"{channel_id}:app_secret"),
|
||||||
verification_token = await store.get_secret(f"{channel_id}:verification_token")
|
store.get_secret(f"{channel_id}:encrypt_key"),
|
||||||
|
store.get_secret(f"{channel_id}:verification_token"),
|
||||||
|
)
|
||||||
if not app_id or not app_secret:
|
if not app_id or not app_secret:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500, detail=f"渠道 '{channel_id}' 缺少 app_id 或 app_secret"
|
status_code=500, detail=f"渠道 '{channel_id}' 缺少 app_id 或 app_secret"
|
||||||
)
|
)
|
||||||
return FeishuMessageAdapter(
|
adapter = FeishuMessageAdapter(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
app_secret=app_secret,
|
app_secret=app_secret,
|
||||||
encrypt_key=encrypt_key,
|
encrypt_key=encrypt_key,
|
||||||
verification_token=verification_token,
|
verification_token=verification_token,
|
||||||
)
|
)
|
||||||
|
_adapter_cache[channel_id] = adapter
|
||||||
|
return adapter
|
||||||
|
|
||||||
if channel_type == ChannelType.DINGTALK.value:
|
if channel_type == ChannelType.DINGTALK.value:
|
||||||
app_key = await store.get_secret(f"{channel_id}:app_key")
|
app_key, app_secret, robot_code, token = await asyncio.gather(
|
||||||
app_secret = await store.get_secret(f"{channel_id}:app_secret")
|
store.get_secret(f"{channel_id}:app_key"),
|
||||||
robot_code = await store.get_secret(f"{channel_id}:robot_code")
|
store.get_secret(f"{channel_id}:app_secret"),
|
||||||
token = await store.get_secret(f"{channel_id}:token")
|
store.get_secret(f"{channel_id}:robot_code"),
|
||||||
|
store.get_secret(f"{channel_id}:token"),
|
||||||
|
)
|
||||||
if not all([app_key, app_secret, robot_code]):
|
if not all([app_key, app_secret, robot_code]):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500, detail=f"渠道 '{channel_id}' 缺少 dingtalk 凭证"
|
status_code=500, detail=f"渠道 '{channel_id}' 缺少 dingtalk 凭证"
|
||||||
)
|
)
|
||||||
return DingTalkMessageAdapter(
|
adapter = DingTalkMessageAdapter(
|
||||||
app_key=app_key,
|
app_key=app_key,
|
||||||
app_secret=app_secret,
|
app_secret=app_secret,
|
||||||
robot_code=robot_code,
|
robot_code=robot_code,
|
||||||
token=token,
|
token=token,
|
||||||
)
|
)
|
||||||
|
_adapter_cache[channel_id] = adapter
|
||||||
|
return adapter
|
||||||
|
|
||||||
if channel_type == ChannelType.WECOM.value:
|
if channel_type == ChannelType.WECOM.value:
|
||||||
corp_id = await store.get_secret(f"{channel_id}:corp_id")
|
corp_id, corp_secret, token, encoding_aes_key, agent_id_raw = await asyncio.gather(
|
||||||
corp_secret = await store.get_secret(f"{channel_id}:corp_secret")
|
store.get_secret(f"{channel_id}:corp_id"),
|
||||||
token = await store.get_secret(f"{channel_id}:token")
|
store.get_secret(f"{channel_id}:corp_secret"),
|
||||||
encoding_aes_key = await store.get_secret(f"{channel_id}:encoding_aes_key")
|
store.get_secret(f"{channel_id}:token"),
|
||||||
agent_id_raw = await store.get_secret(f"{channel_id}:agent_id")
|
store.get_secret(f"{channel_id}:encoding_aes_key"),
|
||||||
|
store.get_secret(f"{channel_id}:agent_id"),
|
||||||
|
)
|
||||||
if not all([corp_id, corp_secret, token, encoding_aes_key, agent_id_raw]):
|
if not all([corp_id, corp_secret, token, encoding_aes_key, agent_id_raw]):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500, detail=f"渠道 '{channel_id}' 缺少 wecom 凭证"
|
status_code=500, detail=f"渠道 '{channel_id}' 缺少 wecom 凭证"
|
||||||
|
|
@ -383,45 +444,76 @@ async def _build_adapter(channel_id: str) -> MessageAdapter:
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail=f"渠道 '{channel_id}' agent_id 不是合法整数",
|
detail=f"渠道 '{channel_id}' agent_id 不是合法整数",
|
||||||
) from exc
|
) from exc
|
||||||
return WeComMessageAdapter(
|
adapter = WeComMessageAdapter(
|
||||||
corp_id=corp_id,
|
corp_id=corp_id,
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
corp_secret=corp_secret,
|
corp_secret=corp_secret,
|
||||||
token=token,
|
token=token,
|
||||||
encoding_aes_key=encoding_aes_key,
|
encoding_aes_key=encoding_aes_key,
|
||||||
)
|
)
|
||||||
|
_adapter_cache[channel_id] = adapter
|
||||||
|
return adapter
|
||||||
|
|
||||||
if channel_type == ChannelType.SLACK.value:
|
if channel_type == ChannelType.SLACK.value:
|
||||||
bot_token = await store.get_secret(f"{channel_id}:bot_token")
|
bot_token, signing_secret, verification_token = await asyncio.gather(
|
||||||
signing_secret = await store.get_secret(f"{channel_id}:signing_secret")
|
store.get_secret(f"{channel_id}:bot_token"),
|
||||||
verification_token = await store.get_secret(f"{channel_id}:verification_token")
|
store.get_secret(f"{channel_id}:signing_secret"),
|
||||||
|
store.get_secret(f"{channel_id}:verification_token"),
|
||||||
|
)
|
||||||
if not bot_token or not signing_secret:
|
if not bot_token or not signing_secret:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500, detail=f"渠道 '{channel_id}' 缺少 slack 凭证"
|
status_code=500, detail=f"渠道 '{channel_id}' 缺少 slack 凭证"
|
||||||
)
|
)
|
||||||
return SlackMessageAdapter(
|
adapter = SlackMessageAdapter(
|
||||||
bot_token=bot_token,
|
bot_token=bot_token,
|
||||||
signing_secret=signing_secret,
|
signing_secret=signing_secret,
|
||||||
verification_token=verification_token,
|
verification_token=verification_token,
|
||||||
)
|
)
|
||||||
|
_adapter_cache[channel_id] = adapter
|
||||||
|
return adapter
|
||||||
|
|
||||||
raise HTTPException(status_code=400, detail=f"不支持的渠道类型: {channel_type}")
|
raise HTTPException(status_code=400, detail=f"不支持的渠道类型: {channel_type}")
|
||||||
|
|
||||||
|
|
||||||
|
async def _invalidate_adapter_cache(channel_id: str) -> None:
|
||||||
|
"""清除指定渠道的缓存适配器并关闭旧实例(配置变更/删除时调用)。"""
|
||||||
|
old = _adapter_cache.pop(channel_id, None)
|
||||||
|
if old is not None:
|
||||||
|
try:
|
||||||
|
await old.close()
|
||||||
|
except Exception: # noqa: BLE001 — 关闭异常不应阻塞配置变更
|
||||||
|
logger.debug("关闭旧适配器异常已忽略: channel_id=%s", channel_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def close_all_adapters() -> None:
|
||||||
|
"""关闭所有缓存的适配器(供 app shutdown 调用)。"""
|
||||||
|
for channel_id, adapter in list(_adapter_cache.items()):
|
||||||
|
try:
|
||||||
|
await adapter.close()
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
logger.debug("关闭适配器异常已忽略: channel_id=%s", channel_id)
|
||||||
|
_adapter_cache.clear()
|
||||||
|
|
||||||
|
|
||||||
async def _process_inbound_message(
|
async def _process_inbound_message(
|
||||||
app_state: Any, adapter: MessageAdapter, message: Any
|
app_state: Any, adapter: MessageAdapter, message: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""后台处理入站消息 — 调用 chat 链路并通过适配器回复。
|
"""后台处理入站消息 — 调用 chat 链路并通过适配器回复。
|
||||||
|
|
||||||
整个流程 try/except 包裹,任何异常仅记录日志,不向上抛出
|
整个流程 try/except 包裹,任何异常仅记录日志,不向上抛出
|
||||||
(webhook 必须保持响应能力)。``adapter.close()`` 在 finally 中调用。
|
(webhook 必须保持响应能力)。处理逻辑受全局信号量限流
|
||||||
|
(``_WEBHOOK_MAX_CONCURRENT``),防止高流量下 LLM 调用无界并发。
|
||||||
|
适配器由 ``_adapter_cache`` 管理,不在 per-request 关闭(关闭会清空 token 缓存)。
|
||||||
适配器类型不限 — 出站消息的 ``channel`` 取自入站消息以匹配平台。
|
适配器类型不限 — 出站消息的 ``channel`` 取自入站消息以匹配平台。
|
||||||
"""
|
"""
|
||||||
|
async with _get_webhook_semaphore():
|
||||||
try:
|
try:
|
||||||
request_preprocessor = getattr(app_state, "request_preprocessor", None)
|
request_preprocessor = getattr(app_state, "request_preprocessor", None)
|
||||||
llm_gateway = getattr(app_state, "llm_gateway", None)
|
llm_gateway = getattr(app_state, "llm_gateway", None)
|
||||||
if request_preprocessor is None or llm_gateway is None:
|
if request_preprocessor is None or llm_gateway is None:
|
||||||
logger.warning("app.state 缺少 request_preprocessor 或 llm_gateway — 跳过消息处理")
|
logger.warning(
|
||||||
|
"app.state 缺少 request_preprocessor 或 llm_gateway — 跳过消息处理"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# 路由预处理 — IM 场景使用默认 agent,无需技能注册表
|
# 路由预处理 — IM 场景使用默认 agent,无需技能注册表
|
||||||
|
|
@ -432,7 +524,7 @@ async def _process_inbound_message(
|
||||||
final_content = ""
|
final_content = ""
|
||||||
execution_mode = getattr(routing, "execution_mode", None)
|
execution_mode = getattr(routing, "execution_mode", None)
|
||||||
# DIRECT_CHAT 模式 — 直接调用 LLM
|
# DIRECT_CHAT 模式 — 直接调用 LLM
|
||||||
if execution_mode is not None and execution_mode.value == "direct_chat":
|
if execution_mode == ExecutionMode.DIRECT_CHAT:
|
||||||
response = await llm_gateway.chat(
|
response = await llm_gateway.chat(
|
||||||
messages=[{"role": "user", "content": message.content}],
|
messages=[{"role": "user", "content": message.content}],
|
||||||
model=routing.model or "default",
|
model=routing.model or "default",
|
||||||
|
|
@ -470,11 +562,6 @@ async def _process_inbound_message(
|
||||||
await adapter.send_message(outgoing)
|
await adapter.send_message(outgoing)
|
||||||
except Exception as exc: # noqa: BLE001 — webhook 必须保持响应能力
|
except Exception as exc: # noqa: BLE001 — webhook 必须保持响应能力
|
||||||
logger.exception("处理入站消息失败: %s", exc)
|
logger.exception("处理入站消息失败: %s", exc)
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
await adapter.close()
|
|
||||||
except Exception: # noqa: BLE001
|
|
||||||
logger.debug("adapter.close() 异常已忽略")
|
|
||||||
|
|
||||||
|
|
||||||
def _get_client_ip(request: Request) -> str:
|
def _get_client_ip(request: Request) -> str:
|
||||||
|
|
@ -527,7 +614,7 @@ async def channel_webhook(channel_id: str, request: Request) -> Any:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
message = await adapter.receive_message(headers_dict, body)
|
message = await adapter.receive_message(headers_dict, body)
|
||||||
except (URLVerificationChallenge, SlackURLVerificationChallenge) as e:
|
except URLVerificationChallenge as e:
|
||||||
# URL 验证流程 — 飞书 / Slack 配置 webhook 时发送
|
# URL 验证流程 — 飞书 / Slack 配置 webhook 时发送
|
||||||
return {"challenge": e.challenge}
|
return {"challenge": e.challenge}
|
||||||
except WeComURLVerification as e:
|
except WeComURLVerification as e:
|
||||||
|
|
@ -535,6 +622,9 @@ async def channel_webhook(channel_id: str, request: Request) -> Any:
|
||||||
return Response(content=e.response_xml, media_type="application/xml")
|
return Response(content=e.response_xml, media_type="application/xml")
|
||||||
|
|
||||||
# 异步处理 — 不阻塞 webhook 响应(平台要求快速返回 200)
|
# 异步处理 — 不阻塞 webhook 响应(平台要求快速返回 200)
|
||||||
asyncio.create_task(_process_inbound_message(request.app.state, adapter, message))
|
# 持有 task 引用防止 GC 回收正在运行的后台任务
|
||||||
|
task = asyncio.create_task(_process_inbound_message(request.app.state, adapter, message))
|
||||||
|
_pending_webhook_tasks.add(task)
|
||||||
|
task.add_done_callback(_pending_webhook_tasks.discard)
|
||||||
|
|
||||||
return {"code": 0}
|
return {"code": 0}
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@
|
||||||
6. kb_acl_hash 隔离 — 不同 ACL hash 产生不同 key
|
6. kb_acl_hash 隔离 — 不同 ACL hash 产生不同 key
|
||||||
7. kb_caching_disabled 禁用缓存(安全要求 c)
|
7. kb_caching_disabled 禁用缓存(安全要求 c)
|
||||||
8. cache_params_for_hit / no_cache — 返回正确 dict
|
8. cache_params_for_hit / no_cache — 返回正确 dict
|
||||||
9. detect_cache_hit — _hidden_params 含 cache_key 时返回 True
|
9. record_cache_result — 记录命中/未命中到 stats 计数器
|
||||||
10. LitellmCacheConfig.from_cache_config — 转换正确,similarity_threshold=0.87
|
10. LitellmCacheConfig.from_cache_config — 转换正确,similarity_threshold=0.87
|
||||||
11. LitellmCacheManager.enable/disable — litellm.cache 正确设置/清除
|
11. LitellmCacheManager.enable/disable — litellm.cache 正确设置/清除
|
||||||
12. generate_cache_key 向后兼容 — user_id=None, kb_acl_hash=None 时与旧版相同
|
12. generate_cache_key 向后兼容 — user_id=None, kb_acl_hash=None 时与旧版相同
|
||||||
|
|
@ -163,10 +163,10 @@ class TestCacheStats:
|
||||||
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||||
|
|
||||||
# 2 hits
|
# 2 hits
|
||||||
manager.detect_cache_hit(_make_litellm_response(cache_key="k1"))
|
manager.record_cache_result(True)
|
||||||
manager.detect_cache_hit(_make_litellm_response(cache_key="k2"))
|
manager.record_cache_result(True)
|
||||||
# 1 miss
|
# 1 miss
|
||||||
manager.detect_cache_hit(_make_litellm_response()) # 无 cache_key
|
manager.record_cache_result(False)
|
||||||
|
|
||||||
stats = manager.stats()
|
stats = manager.stats()
|
||||||
assert stats["total_hits"] == 2
|
assert stats["total_hits"] == 2
|
||||||
|
|
@ -302,39 +302,6 @@ class TestCacheParams:
|
||||||
assert params == {"no-cache": True}
|
assert params == {"no-cache": True}
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# 9. detect_cache_hit
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestDetectCacheHit:
|
|
||||||
def test_hit_with_cache_key(self):
|
|
||||||
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
|
||||||
resp = _make_litellm_response(cache_key="some_key")
|
|
||||||
assert manager.detect_cache_hit(resp) is True
|
|
||||||
|
|
||||||
def test_hit_with_cache_hit_flag(self):
|
|
||||||
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
|
||||||
resp = _make_litellm_response()
|
|
||||||
resp._hidden_params = {"cache_hit": True}
|
|
||||||
assert manager.detect_cache_hit(resp) is True
|
|
||||||
|
|
||||||
def test_miss_without_cache_key(self):
|
|
||||||
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
|
||||||
resp = _make_litellm_response()
|
|
||||||
assert manager.detect_cache_hit(resp) is False
|
|
||||||
|
|
||||||
def test_miss_with_no_hidden_params(self):
|
|
||||||
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
|
||||||
resp = SimpleNamespace(_hidden_params=None)
|
|
||||||
assert manager.detect_cache_hit(resp) is False
|
|
||||||
|
|
||||||
def test_miss_with_no_hidden_params_attr(self):
|
|
||||||
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
|
||||||
resp = SimpleNamespace()
|
|
||||||
assert manager.detect_cache_hit(resp) is False
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# 10. LitellmCacheConfig.from_cache_config
|
# 10. LitellmCacheConfig.from_cache_config
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue