refactor: ce-simplify-code 审查修复 — 去重 + 效率 + 死代码清理
3 个审查代理(复用/质量/效率)发现 15 个问题,全部修复: 效率与安全(6 项): - MCPClient 缓存 MultiServerMCPClient 单例 + aclose(),修复连接/子进程泄漏 - _rate_limits 清理空 IP 条目,修复 X-Forwarded-For 欺骗下内存泄漏 - _seen_nonces 改用 OrderedDict,O(1) 摊销过期清理 - webhook 后台任务加 Semaphore(20) + 任务引用追踪,限制无界并发 - _build_adapter 用 asyncio.gather 并行解密 secrets - 适配器实例缓存(_adapter_cache),token TTL 缓存跨请求命中 去重(4 项): - header_get 提取到 channels/base.py,4 个适配器统一 import - _get_client/close() 移入 MessageAdapter 基类,子类继承 - URLVerificationChallenge 统一到 base.py,feishu/slack/wecom 共用 - Transport ABC 添加 endpoint_url 属性,from_transport 不再访问私有字段 死代码与类型安全(5 项): - detect_cache_hit 死方法替换为 record_cache_result 公开 API - execution_mode.value == "direct_chat" 改用枚举比较 - 删除 yielded_any 死变量、重复 from fastapi import Request、 多余 getattr 防御 453 tests passed, ruff clean(预存 F841 非本次引入)
This commit is contained in:
parent
793476cafa
commit
1ccaf56b9a
|
|
@ -11,6 +11,8 @@ from dataclasses import dataclass, field
|
|||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
class ChannelType(str, Enum):
|
||||
"""支持的消息平台渠道。"""
|
||||
|
|
@ -51,17 +53,42 @@ class OutgoingMessage:
|
|||
reply_to_message_id: str | None = None
|
||||
|
||||
|
||||
class URLVerificationChallenge(Exception):
    """Signal that an incoming webhook event is a URL-verification handshake.

    Raised so the webhook endpoint can reply with ``{"challenge": ...}``.
    Feishu and Slack send a one-off ``url_verification`` event when a webhook
    is configured and expect the ``challenge`` field echoed back unchanged to
    prove the URL is reachable.
    """

    def __init__(self, challenge: str) -> None:
        message = f"URL verification challenge: {challenge}"
        super().__init__(message)
        # Keep the raw value so the handler can echo it back verbatim.
        self.challenge = challenge
|
||||
|
||||
|
||||
class MessageAdapter(abc.ABC):
|
||||
"""消息适配器 ABC。
|
||||
|
||||
生命周期:
|
||||
__init__ → verify_signature() → receive_message() → send_message() → close()
|
||||
|
||||
子类必须实现全部抽象方法。verify_signature 失败时调用方应拒绝处理
|
||||
(webhook 端点 fail-closed:Redis 不可用或签名校验失败均返回 503/401,
|
||||
不可跳过 nonce dedup 直接处理消息)。
|
||||
子类必须实现 verify_signature / receive_message / send_message 抽象方法。
|
||||
``_get_client`` / ``close`` 已在基类提供(懒构造 httpx 客户端 + 释放资源),
|
||||
子类只需在 ``__init__`` 中调用 ``super().__init__()``。
|
||||
|
||||
verify_signature 失败时调用方应拒绝处理(webhook 端点 fail-closed:
|
||||
Redis 不可用或签名校验失败均返回 503/401,不可跳过 nonce dedup 直接处理消息)。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
    # Start without an HTTP client; it is created lazily on first use so an
    # adapter that is never used does not hold a connection pool.
    self._client: httpx.AsyncClient | None = None
|
||||
|
||||
def _get_client(self) -> httpx.AsyncClient:
    """Return the adapter's httpx client, building it on first use.

    The instance is stored on ``self._client`` and reused by later calls;
    a newly built client uses a 10 second timeout.
    """
    client = self._client
    if client is None:
        client = httpx.AsyncClient(timeout=10.0)
        self._client = client
    return client
|
||||
|
||||
@abc.abstractmethod
async def verify_signature(self, headers: dict[str, str], body: bytes) -> bool:
    """Verify the platform signature/token of an incoming request.

    Returns:
        True when the request can be trusted.
    """
|
||||
|
|
@ -74,6 +101,25 @@ class MessageAdapter(abc.ABC):
|
|||
async def send_message(self, message: OutgoingMessage) -> bool:
|
||||
"""向平台发送消息。返回 True 表示发送成功。"""
|
||||
|
||||
async def close(self) -> None:
    """Close the lazily created httpx client, if one exists.

    This is the concrete default shared by all adapters, so it must not be
    marked abstract: subclasses that do not define their own ``close`` rely
    on it. Subclasses may override it to release additional resources.

    Calling it again after a close is a no-op. If the underlying ``aclose()``
    raises, the exception propagates but the adapter no longer references the
    client.
    """
    client = self._client
    if client is None:
        return
    # Detach before awaiting so a failing or concurrent close can never leave
    # the adapter holding a client that is closed or half-closed.
    self._client = None
    await client.aclose()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def header_get(headers: dict[str, str], name: str) -> str | None:
    """Look up a header value by name, ignoring case.

    An exact-case key wins. Otherwise the first key (in iteration order)
    whose lowercase form equals ``name.lower()`` is used.

    Returns:
        The header value, or None when no key matches.
    """
    if name not in headers:
        wanted = name.lower()
        matches = (value for key, value in headers.items() if key.lower() == wanted)
        return next(matches, None)
    # Exact-case hit: no scan needed.
    return headers[name]
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from agentkit.channels.base import (
|
|||
IncomingMessage,
|
||||
MessageAdapter,
|
||||
OutgoingMessage,
|
||||
header_get,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -68,25 +69,15 @@ class DingTalkMessageAdapter(MessageAdapter):
|
|||
robot_code: str,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.app_key = app_key
|
||||
self.app_secret = app_secret
|
||||
self.robot_code = robot_code
|
||||
self.token = token
|
||||
# 懒加载 httpx 客户端
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
# ponytail: 简单 TTL 缓存 (token, expiry)。天花板:单实例内存;
|
||||
# 升级路径:Redis 缓存共享给多实例。
|
||||
self._token_cache: tuple[str, float] | None = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# httpx 客户端懒加载
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_client(self) -> httpx.AsyncClient:
|
||||
if self._client is None:
|
||||
self._client = httpx.AsyncClient(timeout=10.0)
|
||||
return self._client
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 签名验证
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -107,12 +98,12 @@ class DingTalkMessageAdapter(MessageAdapter):
|
|||
"""
|
||||
# Token 校验(若配置)
|
||||
if self.token is not None:
|
||||
token_header = _header_get(headers, "Token")
|
||||
token_header = header_get(headers, "Token")
|
||||
if token_header is None or not hmac.compare_digest(token_header, self.token):
|
||||
return False
|
||||
|
||||
sign = _header_get(headers, "Sign")
|
||||
timestamp_str = _header_get(headers, "Timestamp")
|
||||
sign = header_get(headers, "Sign")
|
||||
timestamp_str = header_get(headers, "Timestamp")
|
||||
|
||||
if sign is None and timestamp_str is None:
|
||||
# 无签名头:仅当 token 已校验通过才放行
|
||||
|
|
@ -253,29 +244,3 @@ class DingTalkMessageAdapter(MessageAdapter):
|
|||
except httpx.HTTPError as exc:
|
||||
logger.error("钉钉 accessToken 网络错误: %s", exc)
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 资源释放
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭 httpx 客户端(如已创建)。"""
|
||||
if self._client is not None:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _header_get(headers: dict[str, str], name: str) -> str | None:
|
||||
"""大小写不敏感的 header 查找。"""
|
||||
if name in headers:
|
||||
return headers[name]
|
||||
lower = name.lower()
|
||||
for k, v in headers.items():
|
||||
if k.lower() == lower:
|
||||
return v
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -28,6 +28,8 @@ from agentkit.channels.base import (
|
|||
IncomingMessage,
|
||||
MessageAdapter,
|
||||
OutgoingMessage,
|
||||
URLVerificationChallenge,
|
||||
header_get,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -45,18 +47,6 @@ _SEND_MESSAGE_URL = "https://open.feishu.cn/open-apis/im/v1/messages"
|
|||
_MENTION_RE = re.compile(r"@_user_\d+\s*")
|
||||
|
||||
|
||||
class URLVerificationChallenge(Exception):
|
||||
"""飞书 URL 验证事件 — webhook 端点需返回 ``{"challenge": ...}`` 响应。
|
||||
|
||||
飞书在配置 webhook 时会发送一次 ``url_verification`` 事件,要求服务端
|
||||
原样返回 ``challenge`` 字段以验证 URL 可达。
|
||||
"""
|
||||
|
||||
def __init__(self, challenge: str) -> None:
|
||||
super().__init__(f"URL verification challenge: {challenge}")
|
||||
self.challenge = challenge
|
||||
|
||||
|
||||
class SignatureVerificationError(Exception):
|
||||
"""事件 ``verification_token`` 校验失败 — 拒绝处理。"""
|
||||
|
||||
|
|
@ -82,25 +72,15 @@ class FeishuMessageAdapter(MessageAdapter):
|
|||
encrypt_key: str | None = None,
|
||||
verification_token: str | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.app_id = app_id
|
||||
self.app_secret = app_secret
|
||||
self.encrypt_key = encrypt_key
|
||||
self.verification_token = verification_token
|
||||
# 懒加载 httpx 客户端 — 避免未使用的适配器持有连接池
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
# ponytail: 简单 TTL 缓存 (token, expiry)。天花板:单实例内存;
|
||||
# 升级路径:Redis 缓存共享给多实例。
|
||||
self._token_cache: tuple[str, float] | None = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# httpx 客户端懒加载
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_client(self) -> httpx.AsyncClient:
|
||||
if self._client is None:
|
||||
self._client = httpx.AsyncClient(timeout=10.0)
|
||||
return self._client
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 签名验证
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -122,12 +102,12 @@ class FeishuMessageAdapter(MessageAdapter):
|
|||
logger.warning("飞书适配器未配置 encrypt_key — 拒绝所有 webhook 请求")
|
||||
return False
|
||||
|
||||
signature = _header_get(headers, "X-Lark-Signature")
|
||||
signature = header_get(headers, "X-Lark-Signature")
|
||||
if not signature:
|
||||
return False
|
||||
|
||||
timestamp_str = _header_get(headers, "X-Lark-Request-Timestamp")
|
||||
nonce = _header_get(headers, "X-Lark-Request-Nonce")
|
||||
timestamp_str = header_get(headers, "X-Lark-Request-Timestamp")
|
||||
nonce = header_get(headers, "X-Lark-Request-Nonce")
|
||||
if not timestamp_str or not nonce:
|
||||
return False
|
||||
|
||||
|
|
@ -335,30 +315,3 @@ class FeishuMessageAdapter(MessageAdapter):
|
|||
except httpx.HTTPError as exc:
|
||||
logger.error("飞书 tenant_token 网络错误: %s", exc)
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 资源释放
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭 httpx 客户端(如已创建)。"""
|
||||
if self._client is not None:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _header_get(headers: dict[str, str], name: str) -> str | None:
|
||||
"""大小写不敏感的 header 查找。"""
|
||||
# 直接命中
|
||||
if name in headers:
|
||||
return headers[name]
|
||||
lower = name.lower()
|
||||
for k, v in headers.items():
|
||||
if k.lower() == lower:
|
||||
return v
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -30,6 +30,8 @@ from agentkit.channels.base import (
|
|||
IncomingMessage,
|
||||
MessageAdapter,
|
||||
OutgoingMessage,
|
||||
URLVerificationChallenge,
|
||||
header_get,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -44,18 +46,6 @@ _SEND_MESSAGE_URL = "https://slack.com/api/chat.postMessage"
|
|||
_MENTION_RE = re.compile(r"<@[^>]+>\s*")
|
||||
|
||||
|
||||
class URLVerificationChallenge(Exception):
|
||||
"""Slack URL 验证事件 — webhook 端点需返回 ``{"challenge": ...}`` 响应。
|
||||
|
||||
Slack 在配置 Events API 时发送 ``url_verification`` 事件,要求服务端
|
||||
原样返回 ``challenge`` 字段以验证 URL 可达。
|
||||
"""
|
||||
|
||||
def __init__(self, challenge: str) -> None:
|
||||
super().__init__(f"Slack URL verification challenge: {challenge}")
|
||||
self.challenge = challenge
|
||||
|
||||
|
||||
class SignatureVerificationError(Exception):
|
||||
"""``verification_token`` 校验失败 — 拒绝处理。"""
|
||||
|
||||
|
|
@ -79,20 +69,10 @@ class SlackMessageAdapter(MessageAdapter):
|
|||
signing_secret: str,
|
||||
verification_token: str | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.bot_token = bot_token
|
||||
self.signing_secret = signing_secret
|
||||
self.verification_token = verification_token
|
||||
# 懒加载 httpx 客户端
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# httpx 客户端懒加载
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_client(self) -> httpx.AsyncClient:
|
||||
if self._client is None:
|
||||
self._client = httpx.AsyncClient(timeout=10.0)
|
||||
return self._client
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 签名验证
|
||||
|
|
@ -111,8 +91,8 @@ class SlackMessageAdapter(MessageAdapter):
|
|||
Returns:
|
||||
True 表示签名校验通过。
|
||||
"""
|
||||
signature = _header_get(headers, "X-Slack-Signature")
|
||||
timestamp_str = _header_get(headers, "X-Slack-Request-Timestamp")
|
||||
signature = header_get(headers, "X-Slack-Signature")
|
||||
timestamp_str = header_get(headers, "X-Slack-Request-Timestamp")
|
||||
if not signature or not timestamp_str:
|
||||
return False
|
||||
|
||||
|
|
@ -259,29 +239,3 @@ class SlackMessageAdapter(MessageAdapter):
|
|||
except httpx.HTTPError as exc:
|
||||
logger.error("Slack send_message 网络错误: %s", exc)
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 资源释放
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭 httpx 客户端(如已创建)。"""
|
||||
if self._client is not None:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _header_get(headers: dict[str, str], name: str) -> str | None:
|
||||
"""大小写不敏感的 header 查找。"""
|
||||
if name in headers:
|
||||
return headers[name]
|
||||
lower = name.lower()
|
||||
for k, v in headers.items():
|
||||
if k.lower() == lower:
|
||||
return v
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from agentkit.channels.base import (
|
|||
IncomingMessage,
|
||||
MessageAdapter,
|
||||
OutgoingMessage,
|
||||
header_get,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -80,25 +81,15 @@ class WeComMessageAdapter(MessageAdapter):
|
|||
token: str,
|
||||
encoding_aes_key: str,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.corp_id = corp_id
|
||||
self.agent_id = agent_id
|
||||
self.corp_secret = corp_secret
|
||||
self.token = token
|
||||
self.encoding_aes_key = encoding_aes_key
|
||||
# 懒加载 httpx 客户端
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
# ponytail: 简单 TTL 缓存。天花板:单实例内存;升级路径:Redis 共享。
|
||||
self._token_cache: tuple[str, float] | None = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# httpx 客户端懒加载
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_client(self) -> httpx.AsyncClient:
|
||||
if self._client is None:
|
||||
self._client = httpx.AsyncClient(timeout=10.0)
|
||||
return self._client
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# AES 密钥 / 加解密
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -209,9 +200,9 @@ class WeComMessageAdapter(MessageAdapter):
|
|||
Returns:
|
||||
True 表示签名校验通过。
|
||||
"""
|
||||
msg_signature = _header_get(headers, "msg_signature")
|
||||
timestamp = _header_get(headers, "timestamp")
|
||||
nonce = _header_get(headers, "nonce")
|
||||
msg_signature = header_get(headers, "msg_signature")
|
||||
timestamp = header_get(headers, "timestamp")
|
||||
nonce = header_get(headers, "nonce")
|
||||
if not msg_signature or not timestamp or not nonce:
|
||||
return False
|
||||
|
||||
|
|
@ -247,8 +238,8 @@ class WeComMessageAdapter(MessageAdapter):
|
|||
|
||||
# URL 验证流程 — 内部 XML 包含 EchoStr
|
||||
if "EchoStr" in inner:
|
||||
timestamp = _header_get(headers, "timestamp") or ""
|
||||
nonce = _header_get(headers, "nonce") or ""
|
||||
timestamp = header_get(headers, "timestamp") or ""
|
||||
nonce = header_get(headers, "nonce") or ""
|
||||
encrypted_echo = self._encrypt(inner["EchoStr"])
|
||||
# 计算响应签名
|
||||
sig_parts = sorted([self.token, timestamp, nonce, encrypted_echo])
|
||||
|
|
@ -362,29 +353,3 @@ class WeComMessageAdapter(MessageAdapter):
|
|||
except httpx.HTTPError as exc:
|
||||
logger.error("企微 access_token 网络错误: %s", exc)
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 资源释放
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭 httpx 客户端(如已创建)。"""
|
||||
if self._client is not None:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _header_get(headers: dict[str, str], name: str) -> str | None:
|
||||
"""大小写不敏感的 header 查找。"""
|
||||
if name in headers:
|
||||
return headers[name]
|
||||
lower = name.lower()
|
||||
for k, v in headers.items():
|
||||
if k.lower() == lower:
|
||||
return v
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -803,18 +803,16 @@ class LitellmCacheManager:
|
|||
"""返回 litellm acompletion 的 cache 参数(禁用缓存)。"""
|
||||
return {"no-cache": True}
|
||||
|
||||
def detect_cache_hit(self, response: Any) -> bool:
|
||||
"""检测 LiteLLM 响应是否为缓存命中。
|
||||
def record_cache_result(self, is_hit: bool) -> None:
|
||||
"""记录单次 LLM 调用的缓存命中/未命中(用于 usage tracking 统计)。
|
||||
|
||||
LiteLLM 在缓存命中时设置 ``response._hidden_params["cache_key"]``。
|
||||
命中判定由调用方完成(gateway 通过 ``response.cache_hit`` 判定),
|
||||
本方法只负责更新计数器,避免重复检测逻辑。
|
||||
"""
|
||||
hidden = getattr(response, "_hidden_params", None)
|
||||
if isinstance(hidden, dict):
|
||||
if "cache_key" in hidden or hidden.get("cache_hit"):
|
||||
if is_hit:
|
||||
self._hits += 1
|
||||
return True
|
||||
else:
|
||||
self._misses += 1
|
||||
return False
|
||||
|
||||
def stats(self) -> dict[str, int]:
|
||||
"""返回缓存统计。"""
|
||||
|
|
|
|||
|
|
@ -224,10 +224,8 @@ class LLMGateway:
|
|||
|
||||
# U17 — 检测 LiteLLM 缓存命中(用于 usage tracking cost=0)
|
||||
is_cache_hit = getattr(response, "cache_hit", False)
|
||||
if is_cache_hit and self._cache_manager is not None:
|
||||
self._cache_manager._hits += 1
|
||||
elif self._cache_manager is not None:
|
||||
self._cache_manager._misses += 1
|
||||
if self._cache_manager is not None:
|
||||
self._cache_manager.record_cache_result(is_cache_hit)
|
||||
|
||||
# 计算成本(缓存命中时 cost=0)
|
||||
cost = 0.0 if is_cache_hit else self._calculate_cost(response.model, response.usage)
|
||||
|
|
|
|||
|
|
@ -119,7 +119,6 @@ class LitellmProvider(LLMProvider):
|
|||
accumulated_tool_calls: dict[int, dict[str, Any]] = {}
|
||||
final_usage: TokenUsage | None = None
|
||||
final_model: str = request.model
|
||||
yielded_any = False
|
||||
|
||||
try:
|
||||
# litellm.acompletion(stream=True) 的返回类型取决于版本 / 调用方式:
|
||||
|
|
@ -129,7 +128,6 @@ class LitellmProvider(LLMProvider):
|
|||
raw = litellm.acompletion(**kwargs)
|
||||
stream = await raw if inspect.isawaitable(raw) else raw
|
||||
async for chunk in stream:
|
||||
yielded_any = True
|
||||
parsed = self._parse_stream_chunk(
|
||||
chunk,
|
||||
request.model,
|
||||
|
|
@ -156,10 +154,6 @@ class LitellmProvider(LLMProvider):
|
|||
except Exception as e:
|
||||
raise LLMProviderError(self._provider_type, str(e)) from e
|
||||
|
||||
# ponytail: 若流完全为空(yielded_any=False),上面仍会 yield 一个
|
||||
# is_final=True 的空 chunk,调用方据此判断空响应。无需额外分支。
|
||||
_ = yielded_any # 标记保留(调试 / 未来扩展)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 内部辅助
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -184,7 +178,7 @@ class LitellmProvider(LLMProvider):
|
|||
if request.timeout is not None:
|
||||
kwargs["timeout"] = request.timeout
|
||||
# U17 — 透传 LiteLLM cache 参数(cache_key 或 no-cache)到 litellm.acompletion
|
||||
cache_params = getattr(request, "_cache", None)
|
||||
cache_params = request._cache
|
||||
if cache_params is not None:
|
||||
kwargs["cache"] = cache_params
|
||||
# 合并构造时传入的默认 kwargs(如 max_connections 等provider特定参数)
|
||||
|
|
|
|||
|
|
@ -63,6 +63,9 @@ class MCPClient:
|
|||
self._timeout = timeout
|
||||
self._tools_cache: list[dict] | None = None
|
||||
self._transport = transport
|
||||
# U10 — 懒构造并缓存的 langchain client,避免每次 list_tools/call_tool
|
||||
# 都新建 MultiServerMCPClient(stdio 传输下会反复 spawn 子进程)。
|
||||
self._lc_client: Any = None
|
||||
|
||||
if transport is not None:
|
||||
# 旧 Transport 路径 — 发出 DeprecationWarning,但保持原有行为
|
||||
|
|
@ -138,15 +141,34 @@ class MCPClient:
|
|||
发出 DeprecationWarning。建议迁移到 URL scheme 自动检测。
|
||||
"""
|
||||
if isinstance(transport, HTTPTransport):
|
||||
server_url = transport._endpoint
|
||||
server_url = transport.endpoint_url
|
||||
elif isinstance(transport, SSETransport):
|
||||
server_url = transport._endpoint
|
||||
server_url = transport.endpoint_url
|
||||
elif isinstance(transport, StdioTransport):
|
||||
server_url = f"stdio://{transport._command}"
|
||||
else:
|
||||
server_url = ""
|
||||
return cls(server_url=server_url, transport=transport)
|
||||
|
||||
async def _get_lc_client(self) -> Any:
    """Return the cached langchain ``MultiServerMCPClient``, creating it on first use.

    The client is built once from ``self._langchain_config`` and stored on
    ``self._lc_client``; later calls return that same object instead of
    constructing a new client for every ``list_tools`` / ``call_tool``.
    """
    cached = self._lc_client
    if cached is not None:
        return cached
    factory = _import_langchain_client()
    cached = factory({"server": self._langchain_config})
    self._lc_client = cached
    return cached
|
||||
|
||||
async def aclose(self) -> None:
    """Close and drop the cached langchain client, if one was created.

    The client's own ``aclose`` coroutine is awaited when it provides one;
    a client without it is simply dropped. Calling this again afterwards is
    a no-op.

    Raises:
        Whatever the underlying ``aclose()`` raises. The cached reference is
        already cleared at that point, so the next use builds a fresh client
        instead of reusing a half-closed one.
    """
    client = self._lc_client
    if client is None:
        return
    # Clear the cache before awaiting so a failing close cannot leave a
    # half-closed client behind for later calls.
    self._lc_client = None
    closer = getattr(client, "aclose", None)
    if closer is not None:
        await closer()
|
||||
|
||||
async def list_tools(self) -> list[dict]:
|
||||
"""列出远程 MCP Server 上的工具
|
||||
|
||||
|
|
@ -165,9 +187,8 @@ class MCPClient:
|
|||
self._tools_cache = tools
|
||||
return self._tools_cache
|
||||
|
||||
# 新 langchain 路径
|
||||
client_cls = _import_langchain_client()
|
||||
client = client_cls({"server": self._langchain_config})
|
||||
# 新 langchain 路径 — 复用缓存的 client
|
||||
client = await self._get_lc_client()
|
||||
lc_tools = await client.get_tools()
|
||||
tools = [
|
||||
{
|
||||
|
|
@ -218,9 +239,8 @@ class MCPClient:
|
|||
params={"name": tool_name, "arguments": arguments},
|
||||
)
|
||||
|
||||
# 新 langchain 路径
|
||||
client_cls = _import_langchain_client()
|
||||
client = client_cls({"server": self._langchain_config})
|
||||
# 新 langchain 路径 — 复用缓存的 client
|
||||
client = await self._get_lc_client()
|
||||
lc_tools = await client.get_tools()
|
||||
for tool in lc_tools:
|
||||
if tool.name == tool_name:
|
||||
|
|
|
|||
|
|
@ -230,7 +230,6 @@ class MCPServer:
|
|||
|
||||
try:
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Request # noqa: F401 — 用于 jsonrpc_endpoint 的类型注解
|
||||
except ImportError:
|
||||
raise ImportError("MCP Server requires fastapi: pip install fischer-agentkit[mcp]")
|
||||
|
||||
|
|
|
|||
|
|
@ -98,6 +98,11 @@ class HTTPTransport(Transport):
|
|||
def is_connected(self) -> bool:
|
||||
return self._client is not None and not self._client.is_closed
|
||||
|
||||
@property
def endpoint_url(self) -> str:
    """Configured endpoint URL, exposed read-only so callers need not read ``_endpoint``.

    NOTE(review): the previous docstring said the trailing slash is already
    stripped; that normalisation is not visible here, so it presumably happens
    where ``_endpoint`` is assigned. Confirm.
    """
    return self._endpoint
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""建立 HTTP 连接"""
|
||||
if self.is_connected:
|
||||
|
|
@ -220,6 +225,11 @@ class SSETransport(Transport):
|
|||
def is_connected(self) -> bool:
|
||||
return self._connected and self._client is not None and not self._client.is_closed
|
||||
|
||||
@property
def endpoint_url(self) -> str:
    """Configured SSE endpoint URL, exposed read-only so callers need not read ``_endpoint``.

    NOTE(review): the previous docstring said the trailing slash is already
    stripped; that normalisation is not visible here, so it presumably happens
    where ``_endpoint`` is assigned. Confirm.
    """
    return self._endpoint
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""建立 SSE 连接
|
||||
|
||||
|
|
|
|||
|
|
@ -21,21 +21,25 @@ import asyncio
|
|||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from agentkit.channels.base import ChannelType, MessageAdapter, OutgoingMessage
|
||||
from agentkit.channels.dingtalk import DingTalkMessageAdapter
|
||||
from agentkit.channels.feishu import FeishuMessageAdapter, URLVerificationChallenge
|
||||
from agentkit.channels.secrets import SecretsStore
|
||||
from agentkit.channels.slack import (
|
||||
SlackMessageAdapter,
|
||||
URLVerificationChallenge as SlackURLVerificationChallenge,
|
||||
from agentkit.channels.base import (
|
||||
ChannelType,
|
||||
MessageAdapter,
|
||||
OutgoingMessage,
|
||||
URLVerificationChallenge,
|
||||
)
|
||||
from agentkit.channels.dingtalk import DingTalkMessageAdapter
|
||||
from agentkit.channels.feishu import FeishuMessageAdapter
|
||||
from agentkit.channels.secrets import SecretsStore
|
||||
from agentkit.channels.slack import SlackMessageAdapter
|
||||
from agentkit.channels.wecom import WeComMessageAdapter, WeComURLVerification
|
||||
from agentkit.chat.skill_routing import ExecutionMode
|
||||
from agentkit.server.auth.dependencies import require_permission
|
||||
from agentkit.server.auth.permissions import Permission
|
||||
|
||||
|
|
@ -89,16 +93,40 @@ _RATE_LIMIT_WINDOW = 60.0 # 窗口大小(秒)
|
|||
_RATE_LIMIT_MAX = 100 # 窗口内最大请求数
|
||||
|
||||
# nonce -> 过期时间戳(与飞书签名时间戳窗口一致)
|
||||
_seen_nonces: dict[str, float] = {}
|
||||
# OrderedDict:按插入顺序遍历,清理过期项时从头部弹出,O(1) 摊销。
|
||||
_seen_nonces: OrderedDict[str, float] = OrderedDict()
|
||||
_NONCE_TTL = 300.0
|
||||
|
||||
# Webhook 后台处理并发上限 — 防止高流量下 LLM 调用无界并发
|
||||
_WEBHOOK_MAX_CONCURRENT = 20
|
||||
_webhook_semaphore: asyncio.Semaphore | None = None
|
||||
# 持有后台任务引用,防止 GC 回收正在运行的 task
|
||||
_pending_webhook_tasks: set[asyncio.Task[None]] = set()
|
||||
|
||||
# 适配器缓存 — channel_id -> adapter。避免 per-request 重建导致 token TTL 缓存失效。
|
||||
# 配置变更(PUT/DELETE channel)时通过 _invalidate_adapter_cache 清除对应条目。
|
||||
_adapter_cache: dict[str, MessageAdapter] = {}
|
||||
|
||||
|
||||
def _get_webhook_semaphore() -> asyncio.Semaphore:
    """Return the module-wide webhook concurrency semaphore, creating it on first call.

    The semaphore is sized by ``_WEBHOOK_MAX_CONCURRENT`` and stored in the
    module global ``_webhook_semaphore`` so every caller shares one instance.
    """
    global _webhook_semaphore
    semaphore = _webhook_semaphore
    if semaphore is None:
        semaphore = asyncio.Semaphore(_WEBHOOK_MAX_CONCURRENT)
        _webhook_semaphore = semaphore
    return semaphore
|
||||
|
||||
|
||||
def _check_rate_limit(client_ip: str) -> bool:
|
||||
"""滑动窗口限流。返回 True 表示放行,False 表示超限。"""
|
||||
"""滑动窗口限流。返回 True 表示放行,False 表示超限。
|
||||
|
||||
ponytail: 过滤后 timestamps 为空时清除 IP 条目,避免 _rate_limits 无界增长。
|
||||
"""
|
||||
now = time.monotonic()
|
||||
timestamps = _rate_limits.get(client_ip, [])
|
||||
cutoff = now - _RATE_LIMIT_WINDOW
|
||||
timestamps = [t for t in timestamps if t > cutoff]
|
||||
timestamps = [t for t in _rate_limits.get(client_ip, []) if t > cutoff]
|
||||
if not timestamps:
|
||||
# 清理过期空条目 — 不活跃 IP 不再占用 dict 槽位
|
||||
_rate_limits.pop(client_ip, None)
|
||||
if len(timestamps) >= _RATE_LIMIT_MAX:
|
||||
_rate_limits[client_ip] = timestamps
|
||||
return False
|
||||
|
|
@ -108,12 +136,20 @@ def _check_rate_limit(client_ip: str) -> bool:
|
|||
|
||||
|
||||
def _check_nonce_dedup(nonce: str) -> bool:
|
||||
"""Nonce 去重。返回 True 表示新 nonce(应处理),False 表示重复(跳过)。"""
|
||||
"""Nonce 去重。返回 True 表示新 nonce(应处理),False 表示重复(跳过)。
|
||||
|
||||
使用 OrderedDict 按插入顺序清理过期项:从头部弹出已过期条目,
|
||||
O(1) 摊销(旧实现遍历整个字典为 O(N))。注意:已过期的 nonce 会先被清理,随后视为新 nonce 放行;只有仍在 TTL 内的重复 nonce 才会被拒绝。
|
||||
"""
|
||||
now = time.monotonic()
|
||||
# 惰性清理过期项
|
||||
expired = [k for k, v in _seen_nonces.items() if v < now]
|
||||
for k in expired:
|
||||
del _seen_nonces[k]
|
||||
# 从头部弹出过期项(nonce 按单调时间插入,头部最旧即最先过期)。
|
||||
# ponytail: 摊销 O(1) 清理。旧实现遍历整个 dict 为 O(N)。
|
||||
while _seen_nonces:
|
||||
_oldest_key, oldest_expiry = next(iter(_seen_nonces.items()))
|
||||
if oldest_expiry >= now:
|
||||
break
|
||||
_seen_nonces.popitem(last=False)
|
||||
# 过期 nonce 已被清理 — 视为新 nonce(与旧实现先删全部过期项再查一致)
|
||||
if nonce in _seen_nonces:
|
||||
return False
|
||||
_seen_nonces[nonce] = now + _NONCE_TTL
|
||||
|
|
@ -124,6 +160,8 @@ def _reset_webhook_state() -> None:
|
|||
"""重置限流与 nonce 状态(仅供测试使用)。"""
|
||||
_rate_limits.clear()
|
||||
_seen_nonces.clear()
|
||||
_adapter_cache.clear()
|
||||
_pending_webhook_tasks.clear()
|
||||
|
||||
|
||||
def _validate_channel_id(channel_id: str) -> str:
|
||||
|
|
@ -288,6 +326,9 @@ async def update_channel(
|
|||
if payload.config is not None:
|
||||
cfg["config"] = payload.config
|
||||
|
||||
# 配置变更 — 清除缓存适配器,下次 webhook 重建(凭证可能已更新)
|
||||
await _invalidate_adapter_cache(channel_id)
|
||||
|
||||
return ChannelInfo(
|
||||
channel_id=channel_id,
|
||||
channel_type=ChannelType(cfg["channel_type"]),
|
||||
|
|
@ -311,6 +352,9 @@ async def delete_channel(
|
|||
for name in cfg.get("secret_keys", []):
|
||||
await store.delete_secret(f"{channel_id}:{name}")
|
||||
|
||||
# 清除缓存适配器并关闭旧实例
|
||||
await _invalidate_adapter_cache(channel_id)
|
||||
|
||||
return {"deleted": channel_id}
|
||||
|
||||
|
||||
|
|
@ -320,13 +364,20 @@ async def delete_channel(
|
|||
|
||||
|
||||
async def _build_adapter(channel_id: str) -> MessageAdapter:
|
||||
"""根据渠道配置与 secrets 构造适配器实例。
|
||||
"""根据渠道配置与 secrets 构造适配器实例(带缓存)。
|
||||
|
||||
支持飞书 / 钉钉 / 企微 / Slack 四种渠道类型,按 ``channel_type`` 分发。
|
||||
首次构造后缓存到 ``_adapter_cache``,后续请求复用同一实例(token TTL 缓存命中)。
|
||||
secrets 获取使用 ``asyncio.gather`` 并行,避免串行 await。
|
||||
|
||||
Raises:
|
||||
HTTPException: 渠道不存在(404)、渠道类型不支持(400)、缺少必要凭证(500)。
|
||||
"""
|
||||
# 命中缓存直接返回(per-request 重建会使 token TTL 缓存失效)
|
||||
cached = _adapter_cache.get(channel_id)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
cfg = _channels.get(channel_id)
|
||||
if cfg is None:
|
||||
raise HTTPException(status_code=404, detail=f"渠道 '{channel_id}' 不存在")
|
||||
|
|
@ -335,43 +386,53 @@ async def _build_adapter(channel_id: str) -> MessageAdapter:
|
|||
channel_type = cfg["channel_type"]
|
||||
|
||||
if channel_type == ChannelType.FEISHU.value:
|
||||
app_id = await store.get_secret(f"{channel_id}:app_id")
|
||||
app_secret = await store.get_secret(f"{channel_id}:app_secret")
|
||||
encrypt_key = await store.get_secret(f"{channel_id}:encrypt_key")
|
||||
verification_token = await store.get_secret(f"{channel_id}:verification_token")
|
||||
app_id, app_secret, encrypt_key, verification_token = await asyncio.gather(
|
||||
store.get_secret(f"{channel_id}:app_id"),
|
||||
store.get_secret(f"{channel_id}:app_secret"),
|
||||
store.get_secret(f"{channel_id}:encrypt_key"),
|
||||
store.get_secret(f"{channel_id}:verification_token"),
|
||||
)
|
||||
if not app_id or not app_secret:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"渠道 '{channel_id}' 缺少 app_id 或 app_secret"
|
||||
)
|
||||
return FeishuMessageAdapter(
|
||||
adapter = FeishuMessageAdapter(
|
||||
app_id=app_id,
|
||||
app_secret=app_secret,
|
||||
encrypt_key=encrypt_key,
|
||||
verification_token=verification_token,
|
||||
)
|
||||
_adapter_cache[channel_id] = adapter
|
||||
return adapter
|
||||
|
||||
if channel_type == ChannelType.DINGTALK.value:
|
||||
app_key = await store.get_secret(f"{channel_id}:app_key")
|
||||
app_secret = await store.get_secret(f"{channel_id}:app_secret")
|
||||
robot_code = await store.get_secret(f"{channel_id}:robot_code")
|
||||
token = await store.get_secret(f"{channel_id}:token")
|
||||
app_key, app_secret, robot_code, token = await asyncio.gather(
|
||||
store.get_secret(f"{channel_id}:app_key"),
|
||||
store.get_secret(f"{channel_id}:app_secret"),
|
||||
store.get_secret(f"{channel_id}:robot_code"),
|
||||
store.get_secret(f"{channel_id}:token"),
|
||||
)
|
||||
if not all([app_key, app_secret, robot_code]):
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"渠道 '{channel_id}' 缺少 dingtalk 凭证"
|
||||
)
|
||||
return DingTalkMessageAdapter(
|
||||
adapter = DingTalkMessageAdapter(
|
||||
app_key=app_key,
|
||||
app_secret=app_secret,
|
||||
robot_code=robot_code,
|
||||
token=token,
|
||||
)
|
||||
_adapter_cache[channel_id] = adapter
|
||||
return adapter
|
||||
|
||||
if channel_type == ChannelType.WECOM.value:
|
||||
corp_id = await store.get_secret(f"{channel_id}:corp_id")
|
||||
corp_secret = await store.get_secret(f"{channel_id}:corp_secret")
|
||||
token = await store.get_secret(f"{channel_id}:token")
|
||||
encoding_aes_key = await store.get_secret(f"{channel_id}:encoding_aes_key")
|
||||
agent_id_raw = await store.get_secret(f"{channel_id}:agent_id")
|
||||
corp_id, corp_secret, token, encoding_aes_key, agent_id_raw = await asyncio.gather(
|
||||
store.get_secret(f"{channel_id}:corp_id"),
|
||||
store.get_secret(f"{channel_id}:corp_secret"),
|
||||
store.get_secret(f"{channel_id}:token"),
|
||||
store.get_secret(f"{channel_id}:encoding_aes_key"),
|
||||
store.get_secret(f"{channel_id}:agent_id"),
|
||||
)
|
||||
if not all([corp_id, corp_secret, token, encoding_aes_key, agent_id_raw]):
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"渠道 '{channel_id}' 缺少 wecom 凭证"
|
||||
|
|
@ -383,45 +444,76 @@ async def _build_adapter(channel_id: str) -> MessageAdapter:
|
|||
status_code=500,
|
||||
detail=f"渠道 '{channel_id}' agent_id 不是合法整数",
|
||||
) from exc
|
||||
return WeComMessageAdapter(
|
||||
adapter = WeComMessageAdapter(
|
||||
corp_id=corp_id,
|
||||
agent_id=agent_id,
|
||||
corp_secret=corp_secret,
|
||||
token=token,
|
||||
encoding_aes_key=encoding_aes_key,
|
||||
)
|
||||
_adapter_cache[channel_id] = adapter
|
||||
return adapter
|
||||
|
||||
if channel_type == ChannelType.SLACK.value:
|
||||
bot_token = await store.get_secret(f"{channel_id}:bot_token")
|
||||
signing_secret = await store.get_secret(f"{channel_id}:signing_secret")
|
||||
verification_token = await store.get_secret(f"{channel_id}:verification_token")
|
||||
bot_token, signing_secret, verification_token = await asyncio.gather(
|
||||
store.get_secret(f"{channel_id}:bot_token"),
|
||||
store.get_secret(f"{channel_id}:signing_secret"),
|
||||
store.get_secret(f"{channel_id}:verification_token"),
|
||||
)
|
||||
if not bot_token or not signing_secret:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"渠道 '{channel_id}' 缺少 slack 凭证"
|
||||
)
|
||||
return SlackMessageAdapter(
|
||||
adapter = SlackMessageAdapter(
|
||||
bot_token=bot_token,
|
||||
signing_secret=signing_secret,
|
||||
verification_token=verification_token,
|
||||
)
|
||||
_adapter_cache[channel_id] = adapter
|
||||
return adapter
|
||||
|
||||
raise HTTPException(status_code=400, detail=f"不支持的渠道类型: {channel_type}")
|
||||
|
||||
|
||||
async def _invalidate_adapter_cache(channel_id: str) -> None:
    """Drop the cached adapter for ``channel_id`` and close the old instance.

    Intended to be called when a channel's configuration changes or the
    channel is deleted, so the next lookup does not reuse a stale adapter.

    Args:
        channel_id: Key of the entry to evict from ``_adapter_cache``.
            An unknown id is a no-op.

    The close is best-effort: any ``Exception`` raised by ``close()`` is
    logged at debug level (with traceback) and swallowed, so a failing
    close never blocks the configuration change itself.
    """
    # Evict first so the stale adapter is unreachable even if close() fails.
    old = _adapter_cache.pop(channel_id, None)
    if old is None:
        return
    try:
        await old.close()
    except Exception:  # noqa: BLE001 — a close failure must not block the config change
        # exc_info keeps the swallowed error diagnosable from debug logs.
        logger.debug(
            "关闭旧适配器异常已忽略: channel_id=%s", channel_id, exc_info=True
        )
|
||||
|
||||
|
||||
async def close_all_adapters() -> None:
    """Close every cached adapter and leave ``_adapter_cache`` empty.

    Meant to be called on application shutdown.

    Entries are removed from the cache one at a time, *before* their
    ``close()`` is awaited. Compared with "snapshot, close, then clear()":

    * an adapter that another task caches while a close is being awaited is
      still picked up and closed, instead of being silently dropped by a
      final ``clear()``;
    * an adapter is no longer reachable through the cache once its close
      has started.

    A failing ``close()`` is logged at debug level (with traceback) and does
    not prevent the remaining adapters from being closed.
    """
    while _adapter_cache:
        # First key in iteration order, so adapters are closed in the same
        # order a plain traversal of the cache would visit them.
        channel_id = next(iter(_adapter_cache))
        adapter = _adapter_cache.pop(channel_id)
        try:
            await adapter.close()
        except Exception:  # noqa: BLE001 — shutdown is best-effort
            logger.debug(
                "关闭适配器异常已忽略: channel_id=%s", channel_id, exc_info=True
            )
|
||||
|
||||
|
||||
async def _process_inbound_message(
|
||||
app_state: Any, adapter: MessageAdapter, message: Any
|
||||
) -> None:
|
||||
"""后台处理入站消息 — 调用 chat 链路并通过适配器回复。
|
||||
|
||||
整个流程 try/except 包裹,任何异常仅记录日志,不向上抛出
|
||||
(webhook 必须保持响应能力)。``adapter.close()`` 在 finally 中调用。
|
||||
(webhook 必须保持响应能力)。处理逻辑受全局信号量限流
|
||||
(``_WEBHOOK_MAX_CONCURRENT``),防止高流量下 LLM 调用无界并发。
|
||||
适配器由 ``_adapter_cache`` 管理,不在 per-request 关闭(关闭会清空 token 缓存)。
|
||||
适配器类型不限 — 出站消息的 ``channel`` 取自入站消息以匹配平台。
|
||||
"""
|
||||
async with _get_webhook_semaphore():
|
||||
try:
|
||||
request_preprocessor = getattr(app_state, "request_preprocessor", None)
|
||||
llm_gateway = getattr(app_state, "llm_gateway", None)
|
||||
if request_preprocessor is None or llm_gateway is None:
|
||||
logger.warning("app.state 缺少 request_preprocessor 或 llm_gateway — 跳过消息处理")
|
||||
logger.warning(
|
||||
"app.state 缺少 request_preprocessor 或 llm_gateway — 跳过消息处理"
|
||||
)
|
||||
return
|
||||
|
||||
# 路由预处理 — IM 场景使用默认 agent,无需技能注册表
|
||||
|
|
@ -432,7 +524,7 @@ async def _process_inbound_message(
|
|||
final_content = ""
|
||||
execution_mode = getattr(routing, "execution_mode", None)
|
||||
# DIRECT_CHAT 模式 — 直接调用 LLM
|
||||
if execution_mode is not None and execution_mode.value == "direct_chat":
|
||||
if execution_mode == ExecutionMode.DIRECT_CHAT:
|
||||
response = await llm_gateway.chat(
|
||||
messages=[{"role": "user", "content": message.content}],
|
||||
model=routing.model or "default",
|
||||
|
|
@ -470,11 +562,6 @@ async def _process_inbound_message(
|
|||
await adapter.send_message(outgoing)
|
||||
except Exception as exc: # noqa: BLE001 — webhook 必须保持响应能力
|
||||
logger.exception("处理入站消息失败: %s", exc)
|
||||
finally:
|
||||
try:
|
||||
await adapter.close()
|
||||
except Exception: # noqa: BLE001
|
||||
logger.debug("adapter.close() 异常已忽略")
|
||||
|
||||
|
||||
def _get_client_ip(request: Request) -> str:
|
||||
|
|
@ -527,7 +614,7 @@ async def channel_webhook(channel_id: str, request: Request) -> Any:
|
|||
|
||||
try:
|
||||
message = await adapter.receive_message(headers_dict, body)
|
||||
except (URLVerificationChallenge, SlackURLVerificationChallenge) as e:
|
||||
except URLVerificationChallenge as e:
|
||||
# URL 验证流程 — 飞书 / Slack 配置 webhook 时发送
|
||||
return {"challenge": e.challenge}
|
||||
except WeComURLVerification as e:
|
||||
|
|
@ -535,6 +622,9 @@ async def channel_webhook(channel_id: str, request: Request) -> Any:
|
|||
return Response(content=e.response_xml, media_type="application/xml")
|
||||
|
||||
# 异步处理 — 不阻塞 webhook 响应(平台要求快速返回 200)
|
||||
asyncio.create_task(_process_inbound_message(request.app.state, adapter, message))
|
||||
# 持有 task 引用防止 GC 回收正在运行的后台任务
|
||||
task = asyncio.create_task(_process_inbound_message(request.app.state, adapter, message))
|
||||
_pending_webhook_tasks.add(task)
|
||||
task.add_done_callback(_pending_webhook_tasks.discard)
|
||||
|
||||
return {"code": 0}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
6. kb_acl_hash 隔离 — 不同 ACL hash 产生不同 key
|
||||
7. kb_caching_disabled 禁用缓存(安全要求 c)
|
||||
8. cache_params_for_hit / no_cache — 返回正确 dict
|
||||
9. detect_cache_hit — _hidden_params 含 cache_key 时返回 True
|
||||
9. record_cache_result — 记录命中/未命中到 stats 计数器
|
||||
10. LitellmCacheConfig.from_cache_config — 转换正确,similarity_threshold=0.87
|
||||
11. LitellmCacheManager.enable/disable — litellm.cache 正确设置/清除
|
||||
12. generate_cache_key 向后兼容 — user_id=None, kb_acl_hash=None 时与旧版相同
|
||||
|
|
@ -163,10 +163,10 @@ class TestCacheStats:
|
|||
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||
|
||||
# 2 hits
|
||||
manager.detect_cache_hit(_make_litellm_response(cache_key="k1"))
|
||||
manager.detect_cache_hit(_make_litellm_response(cache_key="k2"))
|
||||
manager.record_cache_result(True)
|
||||
manager.record_cache_result(True)
|
||||
# 1 miss
|
||||
manager.detect_cache_hit(_make_litellm_response()) # 无 cache_key
|
||||
manager.record_cache_result(False)
|
||||
|
||||
stats = manager.stats()
|
||||
assert stats["total_hits"] == 2
|
||||
|
|
@ -302,39 +302,6 @@ class TestCacheParams:
|
|||
assert params == {"no-cache": True}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 9. detect_cache_hit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDetectCacheHit:
    """Hit / miss classification performed by ``detect_cache_hit``."""

    @staticmethod
    def _detect(response):
        """Run ``detect_cache_hit`` on a freshly built in-memory manager."""
        cache_manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
        return cache_manager.detect_cache_hit(response)

    def test_hit_with_cache_key(self):
        # A response built with a cache_key counts as a hit.
        response = _make_litellm_response(cache_key="some_key")
        assert self._detect(response) is True

    def test_hit_with_cache_hit_flag(self):
        # An explicit cache_hit flag in _hidden_params counts as a hit.
        response = _make_litellm_response()
        response._hidden_params = {"cache_hit": True}
        assert self._detect(response) is True

    def test_miss_without_cache_key(self):
        # A response built without a cache_key is a miss.
        response = _make_litellm_response()
        assert self._detect(response) is False

    def test_miss_with_no_hidden_params(self):
        # _hidden_params present but None is a miss.
        response = SimpleNamespace(_hidden_params=None)
        assert self._detect(response) is False

    def test_miss_with_no_hidden_params_attr(self):
        # No _hidden_params attribute at all is a miss.
        response = SimpleNamespace()
        assert self._detect(response) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 10. LitellmCacheConfig.from_cache_config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
Loading…
Reference in New Issue