fix(review): ce-code-review P1+P2 修复 — 安全/可靠性/性能

P1 安全与可靠性(4 项):
- wecom: verify_signature 增加时间戳新鲜度校验(5 分钟窗口防重放)
- cache: should_cache 在 per_user_namespace 开启时拒绝 user_id=None
  匿名请求,避免跨用户缓存泄漏(安全要求 a/e)
- channels: webhook receive_message 异常兜底,防止 500 触发平台重试风暴
- app: shutdown 调用 close_all_adapters + await _pending_webhook_tasks,
  防止 httpx 连接泄漏和丢失 IM 回复

P2 效率与可维护性(5 项):
- feishu: _TOKEN_CACHE_TTL 300 → 6900(2h 减 5min 余量,避免 24x 过频刷新)
- channels: _pending_webhook_tasks 有界化(2x 并发上限时 429 拒绝)
- gateway: quota 检查每 period 单次 get_usage,复用 summary 检查 token+cost
- cache_key: generate_cache_key 合并为单次 SHA-256(消除 8-10 次冗余哈希)
- config: ProviderConfig.get_api_key 移除未用的 secrets_store 参数

P3 去重(1 项):
- channels: _process_inbound_message DIRECT_CHAT 路径提取 _direct_chat 辅助函数

测试:
- test_wecom: 时间戳改用 int(time.time()),新增 test_expired_timestamp_rejected
- test_cache: should_cache 测试覆盖匿名拒绝 + namespace_off 兼容
- test_config_migration: get_api_key 测试适配新签名
- channels/config_migration/quota_enforcement 测试全部通过
This commit is contained in:
chiguyong 2026-06-26 01:40:31 +08:00
parent 1ccaf56b9a
commit 53faa60472
11 changed files with 138 additions and 120 deletions

View File

@ -36,8 +36,8 @@ logger = logging.getLogger(__name__)
# 签名时间戳允许的最大偏移(秒)— 与飞书官方文档保持一致
_SIGNATURE_MAX_AGE_SECONDS = 300
# tenant_access_token 缓存 TTL— 飞书 token 实际有效期 2h留 5min 余量
_TOKEN_CACHE_TTL = 300.0
# tenant_access_token 缓存 TTL— 飞书 token 实际有效期 2h(7200s),留 5min 余量
_TOKEN_CACHE_TTL = 6900.0
# 飞书开放平台 API 端点
_TENANT_TOKEN_URL = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal"

View File

@ -34,6 +34,9 @@ from agentkit.channels.base import (
logger = logging.getLogger(__name__)
# 企微签名时间戳允许的最大偏移(秒)— 5 分钟窗口防重放
_SIGNATURE_MAX_AGE_SECONDS = 300
# access_token 缓存 TTL— 企微 token 有效期 2 小时
_TOKEN_CACHE_TTL = 7200.0
@ -206,6 +209,19 @@ class WeComMessageAdapter(MessageAdapter):
if not msg_signature or not timestamp or not nonce:
return False
# 时间戳新鲜度校验 — 5 分钟窗口防重放攻击
try:
ts_int = int(timestamp)
except (TypeError, ValueError):
return False
now = int(time.time())
if abs(now - ts_int) > _SIGNATURE_MAX_AGE_SECONDS:
logger.warning(
"企微 webhook 时间戳超出 %ds 窗口: ts=%s now=%d",
_SIGNATURE_MAX_AGE_SECONDS, timestamp, now,
)
return False
encrypt = self._extract_encrypt(body)
if not encrypt:
return False

View File

@ -783,14 +783,19 @@ class LitellmCacheManager:
kb_caching_disabled: bool = False,
user_id: str | None = None,
) -> bool:
"""判断当前请求是否应该缓存(安全要求 c
"""判断当前请求是否应该缓存(安全要求 a, c, e)。
- KB 设置 caching_disabled=True 不缓存
- 其余情况缓存user_id None 时仍可缓存 key 不含 user scope
- KB 设置 caching_disabled=True 不缓存要求 c
- per_user_namespace 开启且 user_id None 不缓存要求 a, e
匿名请求无法做 per-user 隔离缓存会导致跨用户泄漏强制 no-cache
"""
_ = user_id # 预留:未来支持 per-user 缓存禁用
if kb_caching_disabled:
return False
if self._config.per_user_namespace and user_id is None:
# 安全要求 (a)(e)per-user namespace 开启时拒绝匿名缓存,
# 避免不同匿名用户命中同一缓存键导致跨用户泄漏。
logger.debug("should_cache: per_user_namespace on but user_id=None — skip cache")
return False
return True
@staticmethod

View File

@ -39,21 +39,23 @@ def generate_cache_key(
64-character hex SHA-256 hash string.
"""
system_prompt = _extract_system_prompt(messages)
components = [
_hash_str(model),
_hash_str(system_prompt),
_hash_json(messages),
_hash_str(f"{temperature:.2f}"),
_hash_json(tools),
_hash_str(tool_choice),
_hash_str(str(max_tokens)),
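# 单次 SHA-256:用分隔符拼接所有组件,避免逐组件 hash 再 hash 的冗余计算。
# 每个组件带类型标签前缀(m: / s: / msg: 等),组件之间以 \x1f 分隔。
# 注意:这里并未使用长度前缀;messages 与 tools 经 json.dumps 转义后不会含 \x1f,
# 但 model、system_prompt、tool_choice、user_id 等未转义字段若含 \x1f 仍可能产生歧义。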
parts = [
f"m:{model}",
f"s:{system_prompt}",
f"msg:{json.dumps(messages, sort_keys=True, ensure_ascii=False)}",
f"t:{temperature:.2f}",
f"tools:{json.dumps(tools, sort_keys=True, ensure_ascii=False) if tools is not None else 'null'}",
f"tc:{tool_choice}",
f"mt:{max_tokens}",
]
# U17 — per-user namespace + ACL scope hash安全要求 a, b, e
if user_id is not None:
components.append(_hash_str(f"user:{user_id}"))
parts.append(f"u:{user_id}")
if kb_acl_hash is not None:
components.append(_hash_str(f"acl:{kb_acl_hash}"))
combined = "".join(components)
parts.append(f"a:{kb_acl_hash}")
combined = "\x1f".join(parts) # US (Unit Separator) 防止组件内容注入分隔符
return hashlib.sha256(combined.encode()).hexdigest()
@ -63,15 +65,3 @@ def _extract_system_prompt(messages: list[dict[str, str]]) -> str:
if msg.get("role") == "system":
return msg.get("content", "")
return ""
def _hash_str(s: str) -> str:
"""SHA-256 hash of a string."""
return hashlib.sha256(s.encode()).hexdigest()
def _hash_json(obj: Any) -> str:
"""SHA-256 hash of a JSON-serializable object."""
if obj is None:
return hashlib.sha256(b"null").hexdigest()
return hashlib.sha256(json.dumps(obj, sort_keys=True, ensure_ascii=False).encode()).hexdigest()

View File

@ -74,21 +74,15 @@ class ProviderConfig:
api_key_encrypted: str | None = None
api_key_source: str = "plaintext"
def get_api_key(self, secrets_store: "SecretsStore | None" = None) -> str:
"""同步读取 API Key
def get_api_key(self) -> str:
"""同步读取 API Key — 返回 plaintext
双读窗口优先级``api_key_encrypted`` + ``secrets_store`` > plaintext
若加密列存在但解密失败store None 或解密异常回退到 plaintext
``api_key``保证迁移期可用性
注意``SecretsStore.get_secret`` async本同步方法无法调用
``api_key_encrypted`` 已设置但 ``secrets_store`` None仍回退到
plaintext需要解密时请用 ``aget_api_key``
双读窗口的同步入口无法 await ``SecretsStore.get_secret``
因此加密列需通过异步方法 :meth:`aget_api_key` 读取
本方法始终返回 plaintext ``api_key``迁移期保证可用性
"""
if self.api_key_encrypted and secrets_store is not None:
# 同步上下文无法 await get_secret调用方应改用 aget_api_key。
# 这里保持双读语义的回退:返回 plaintext。
logger.debug("get_api_key: encrypted key set but sync access — fallback to plaintext")
if self.api_key_encrypted:
logger.debug("get_api_key: encrypted key set — use aget_api_key for decryption")
return self.api_key
async def aget_api_key(self, secrets_store: "SecretsStore | None" = None) -> str:

View File

@ -552,30 +552,39 @@ class LLMGateway:
)
# 2. Token + cost limits (daily AND monthly)
await self._check_quota_period(quota_service, db, dept_id, "daily", "token_limit")
await self._check_quota_period(quota_service, db, dept_id, "daily", "cost_limit")
await self._check_quota_period(quota_service, db, dept_id, "monthly", "token_limit")
await self._check_quota_period(quota_service, db, dept_id, "monthly", "cost_limit")
# 优化:每个 period 只查一次 get_usage,复用 summary 检查 token + cost
for period in ("daily", "monthly"):
summary = self._get_usage_summary(dept_id, period)
current_tokens = int(summary.total_tokens)
current_cost_cents = float(summary.total_cost) * 100.0
await self._check_quota_value(
quota_service, db, dept_id, period, "token_limit", current_tokens
)
await self._check_quota_value(
quota_service, db, dept_id, period, "cost_limit", current_cost_cents
)
async def _check_quota_period(
def _get_usage_summary(self, department_id: str, period: str) -> UsageSummary:
"""返回 department_id 在当前 period 的 usage summary单次查询供 token+cost 复用)。"""
now = datetime.now(timezone.utc)
if period == "monthly":
start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
else:
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
return self._usage_tracker.get_usage(
department_id=department_id, start_time=start, end_time=now
)
async def _check_quota_value(
self,
quota_service: Any,
db: Path,
dept_id: str,
period: str,
quota_type: str,
current: float,
) -> None:
"""Check a single quota (token_limit or cost_limit) for a period.
Raises :class:`QuotaExceededError` if the current usage exceeds
the configured limit. ``period`` is ``"daily"`` or ``"monthly"``;
``quota_type`` is ``"token_limit"`` or ``"cost_limit"``.
"""
if quota_type == "token_limit":
current = await self._get_current_usage_for_quota(dept_id, period)
else:
current = await self._get_current_cost_for_quota(dept_id, period)
"""检查单个配额token_limit 或 cost_limit— current 由调用方预计算传入。"""
allowed, _reason = await quota_service.check_quota(db, dept_id, quota_type, period, current)
if not allowed:
quota = await quota_service.get_quota(db, dept_id, quota_type, period)
@ -587,38 +596,3 @@ class LLMGateway:
limit=limit,
current=current,
)
async def _get_current_usage_for_quota(self, department_id: str, period: str) -> int:
"""Return total tokens used by ``department_id`` in the current period.
``period`` is ``"daily"`` or ``"monthly"``. For ``"daily"`` the
window is since 00:00 UTC today; for ``"monthly"`` since the
first of the current month.
"""
now = datetime.now(timezone.utc)
if period == "monthly":
start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
else:
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
summary = self._usage_tracker.get_usage(
department_id=department_id, start_time=start, end_time=now
)
return int(summary.total_tokens)
async def _get_current_cost_for_quota(self, department_id: str, period: str) -> float:
"""Return total cost (in cents) for ``department_id`` in the current period.
``period`` is ``"daily"`` or ``"monthly"``. Quota cost_limit is
stored in cents, so we convert the float USD cost from the usage
store to cents (×100) for comparison.
"""
now = datetime.now(timezone.utc)
if period == "monthly":
start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
else:
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
summary = self._usage_tracker.get_usage(
department_id=department_id, start_time=start, end_time=now
)
# cost_limit is stored in cents; convert from USD to cents.
return float(summary.total_cost) * 100.0

View File

@ -556,6 +556,20 @@ async def lifespan(app: FastAPI):
if cal_scheduler is not None:
await cal_scheduler.stop()
# U10/U11 — 关闭渠道适配器并等待后台 webhook 任务完成
# 防止 httpx 连接泄漏和丢失正在处理的 IM 回复
from agentkit.server.routes.channels import _pending_webhook_tasks
if _pending_webhook_tasks:
logger.info("等待 %d 个后台 webhook 任务完成", len(_pending_webhook_tasks))
await asyncio.gather(*_pending_webhook_tasks, return_exceptions=True)
try:
from agentkit.server.routes.channels import close_all_adapters
await close_all_adapters()
except Exception:
logger.debug("close_all_adapters 异常已忽略")
def _on_config_change(app: FastAPI, config: ServerConfig) -> None:
"""Handle config change by reloading affected components.

View File

@ -506,6 +506,15 @@ async def _process_inbound_message(
适配器由 ``_adapter_cache`` 管理不在 per-request 关闭关闭会清空 token 缓存
适配器类型不限 出站消息的 ``channel`` 取自入站消息以匹配平台
"""
async def _direct_chat(llm_gateway: Any, routing: Any) -> str:
    """Answer the inbound message with a single direct LLM call.

    Shared by the DIRECT_CHAT path and the ReAct fallback path. The text
    comes from ``message.content``, where ``message`` is captured from the
    enclosing scope. The model is ``routing.model``, or ``"default"`` when
    that value is falsy.
    """
    user_turn = {"role": "user", "content": message.content}
    model_name = routing.model or "default"
    reply = await llm_gateway.chat(messages=[user_turn], model=model_name)
    return reply.content
async with _get_webhook_semaphore():
try:
request_preprocessor = getattr(app_state, "request_preprocessor", None)
@ -523,13 +532,8 @@ async def _process_inbound_message(
final_content = ""
execution_mode = getattr(routing, "execution_mode", None)
# DIRECT_CHAT 模式 — 直接调用 LLM
if execution_mode == ExecutionMode.DIRECT_CHAT:
response = await llm_gateway.chat(
messages=[{"role": "user", "content": message.content}],
model=routing.model or "default",
)
final_content = response.content
final_content = await _direct_chat(llm_gateway, routing)
else:
# REACT 或其他模式 — 优先使用 ReActEngine失败回退到 DIRECT_CHAT
try:
@ -544,11 +548,7 @@ async def _process_inbound_message(
final_content = getattr(result, "content", "") or ""
except Exception as exc: # noqa: BLE001 — 回退路径需捕获全部异常
logger.warning("ReActEngine 执行失败,回退到 DIRECT_CHAT: %s", exc)
response = await llm_gateway.chat(
messages=[{"role": "user", "content": message.content}],
model=routing.model or "default",
)
final_content = response.content
final_content = await _direct_chat(llm_gateway, routing)
if not final_content:
logger.warning("消息处理未产生内容 — 不发送回复")
@ -620,9 +620,16 @@ async def channel_webhook(channel_id: str, request: Request) -> Any:
except WeComURLVerification as e:
# 企微 URL 验证 — 返回 XML 响应
return Response(content=e.response_xml, media_type="application/xml")
except Exception as exc: # noqa: BLE001 — 防止 receive_message 异常导致 500 触发平台重试风暴
logger.warning("receive_message 解析失败 channel=%s: %s", channel_id, exc)
return {"code": 0, "msg": "invalid_payload"}
# 异步处理 — 不阻塞 webhook 响应(平台要求快速返回 200)
# 持有 task 引用防止 GC 回收正在运行的后台任务
# 有界化:超过 2x 并发上限时拒绝新任务(防突发流量下 set 无界增长)
if len(_pending_webhook_tasks) >= _WEBHOOK_MAX_CONCURRENT * 2:
logger.warning("webhook 后台任务积压 %d,拒绝新任务", len(_pending_webhook_tasks))
raise HTTPException(status_code=429, detail="服务器繁忙,请稍后重试")
task = asyncio.create_task(_process_inbound_message(request.app.state, adapter, message))
_pending_webhook_tasks.add(task)
task.add_done_callback(_pending_webhook_tasks.discard)

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import base64
import hashlib
import time
from unittest.mock import AsyncMock, MagicMock
from xml.etree import ElementTree as ET
@ -108,7 +109,7 @@ class TestSignatureVerification:
inner_xml = _build_inner_xml({"MsgId": "m1", "Content": "hi"})
body = _build_body(adapter, inner_xml)
encrypt = adapter._extract_encrypt(body) or ""
ts, nonce = "1609459200", "n1"
ts, nonce = str(int(time.time())), "n1"
sig = _signature(adapter.token, ts, nonce, encrypt)
headers = {"msg_signature": sig, "timestamp": ts, "nonce": nonce}
assert await adapter.verify_signature(headers, body) is True
@ -118,10 +119,23 @@ class TestSignatureVerification:
adapter = _make_adapter()
inner_xml = _build_inner_xml({"MsgId": "m1", "Content": "hi"})
body = _build_body(adapter, inner_xml)
ts, nonce = "1609459200", "n1"
ts, nonce = str(int(time.time())), "n1"
headers = {"msg_signature": "tampered", "timestamp": ts, "nonce": nonce}
assert await adapter.verify_signature(headers, body) is False
async def test_expired_timestamp_rejected(self):
    """A correctly signed request with a 10-minute-old timestamp is rejected.

    The signature is computed over the stale timestamp itself, so the
    ``False`` result cannot come from a signature mismatch. It is expected
    to come from the 300s replay-protection window.
    """
    adapter = _make_adapter()
    body = _build_body(adapter, _build_inner_xml({"MsgId": "m1", "Content": "hi"}))
    encrypt = adapter._extract_encrypt(body) or ""
    # 600 seconds in the past, beyond the 300s window.
    stale_ts = str(int(time.time()) - 600)
    nonce = "n1"
    headers = {
        "msg_signature": _signature(adapter.token, stale_ts, nonce, encrypt),
        "timestamp": stale_ts,
        "nonce": nonce,
    }
    verified = await adapter.verify_signature(headers, body)
    assert verified is False
async def test_missing_query_params(self):
"""缺少 msg_signature/timestamp/nonce 返回 False。"""
adapter = _make_adapter()

View File

@ -276,15 +276,27 @@ class TestKBCachingDisabled:
def test_should_cache_returns_false_when_disabled(self):
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
assert manager.should_cache(kb_caching_disabled=True) is False
assert manager.should_cache(kb_caching_disabled=True, user_id="u1") is False
def test_should_cache_returns_true_when_enabled(self):
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
assert manager.should_cache(kb_caching_disabled=False) is True
assert manager.should_cache(kb_caching_disabled=False, user_id="u1") is True
def test_should_cache_default_true(self):
def test_should_cache_default_with_user(self):
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
assert manager.should_cache() is True
assert manager.should_cache(user_id="u1") is True
def test_should_cache_rejects_anonymous(self):
    """Security requirements (a)(e): an anonymous request must not be cached.

    NOTE(review): relies on ``LitellmCacheConfig`` enabling
    ``per_user_namespace`` by default; confirm against the config class.
    """
    config = LitellmCacheConfig(backend="memory")
    manager = LitellmCacheManager(config)
    decision = manager.should_cache(user_id=None)
    assert decision is False
def test_should_cache_allows_anonymous_when_namespace_off(self):
    """With ``per_user_namespace`` off, anonymous requests stay cacheable.

    Covers the backward-compatible configuration.
    """
    config = LitellmCacheConfig(backend="memory")
    # Switch the per-user namespace off before the manager is built.
    config.per_user_namespace = False
    anonymous_allowed = LitellmCacheManager(config).should_cache(user_id=None)
    assert anonymous_allowed is True
# ---------------------------------------------------------------------------

View File

@ -25,17 +25,9 @@ from agentkit.llm.config import LLMConfig, ProviderConfig
def test_get_api_key_plaintext_no_store():
"""无 secrets_store 时 get_api_key 返回 plaintext。"""
"""get_api_key 返回 plaintext(同步入口,不处理加密列)"""
pconf = ProviderConfig(api_key="sk-xxx", base_url="", type="openai")
assert pconf.get_api_key(None) == "sk-xxx"
def test_get_api_key_plaintext_with_store_sync():
"""有 secrets_store 但同步调用 — 仍返回 plaintextasync 解密不可用)。"""
pconf = ProviderConfig(api_key="sk-xxx", base_url="", type="openai")
# 即使传 store同步路径无法 decrypt回退 plaintext
store = object() # 任意非 None 对象
assert pconf.get_api_key(store) == "sk-xxx" # type: ignore[arg-type]
assert pconf.get_api_key() == "sk-xxx"
# ----------------------------------------------------------------------