diff --git a/src/agentkit/channels/feishu.py b/src/agentkit/channels/feishu.py index c9d4355..cfac706 100644 --- a/src/agentkit/channels/feishu.py +++ b/src/agentkit/channels/feishu.py @@ -36,8 +36,8 @@ logger = logging.getLogger(__name__) # 签名时间戳允许的最大偏移(秒)— 与飞书官方文档保持一致 _SIGNATURE_MAX_AGE_SECONDS = 300 -# tenant_access_token 缓存 TTL(秒)— 飞书 token 实际有效期为 2h,留 5min 余量 -_TOKEN_CACHE_TTL = 300.0 +# tenant_access_token 缓存 TTL(秒)— 飞书 token 实际有效期 2h(7200s),留 5min 余量 +_TOKEN_CACHE_TTL = 6900.0 # 飞书开放平台 API 端点 _TENANT_TOKEN_URL = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal" diff --git a/src/agentkit/channels/wecom.py b/src/agentkit/channels/wecom.py index 429ca60..9535d35 100644 --- a/src/agentkit/channels/wecom.py +++ b/src/agentkit/channels/wecom.py @@ -34,6 +34,9 @@ from agentkit.channels.base import ( logger = logging.getLogger(__name__) +# 企微签名时间戳允许的最大偏移(秒)— 5 分钟窗口防重放 +_SIGNATURE_MAX_AGE_SECONDS = 300 + # access_token 缓存 TTL(秒)— 企微 token 有效期 2 小时 _TOKEN_CACHE_TTL = 7200.0 @@ -206,6 +209,19 @@ class WeComMessageAdapter(MessageAdapter): if not msg_signature or not timestamp or not nonce: return False + # 时间戳新鲜度校验 — 5 分钟窗口防重放攻击 + try: + ts_int = int(timestamp) + except (TypeError, ValueError): + return False + now = int(time.time()) + if abs(now - ts_int) > _SIGNATURE_MAX_AGE_SECONDS: + logger.warning( + "企微 webhook 时间戳超出 %ds 窗口: ts=%s now=%d", + _SIGNATURE_MAX_AGE_SECONDS, timestamp, now, + ) + return False + encrypt = self._extract_encrypt(body) if not encrypt: return False diff --git a/src/agentkit/llm/cache.py b/src/agentkit/llm/cache.py index c58e724..b0f334c 100644 --- a/src/agentkit/llm/cache.py +++ b/src/agentkit/llm/cache.py @@ -783,14 +783,19 @@ class LitellmCacheManager: kb_caching_disabled: bool = False, user_id: str | None = None, ) -> bool: - """判断当前请求是否应该缓存(安全要求 c)。 + """判断当前请求是否应该缓存(安全要求 a, c, e)。 - - KB 设置 caching_disabled=True → 不缓存 - - 其余情况缓存(user_id 为 None 时仍可缓存,但 key 不含 user scope) + - KB 设置 caching_disabled=True → 不缓存(要求 c) + - per_user_namespace 开启且 user_id 为 None → 不缓存(要求 a, e) + 匿名请求无法做 per-user 隔离,缓存会导致跨用户泄漏,强制 no-cache。 """ - _ = user_id # 预留:未来支持 per-user 缓存禁用 if kb_caching_disabled: return False + if self._config.per_user_namespace and user_id is None: + # 安全要求 (a)(e):per-user namespace 开启时拒绝匿名缓存, + # 避免不同匿名用户命中同一缓存键导致跨用户泄漏。 + logger.debug("should_cache: per_user_namespace on but user_id=None — skip cache") + return False return True @staticmethod diff --git a/src/agentkit/llm/cache_key.py b/src/agentkit/llm/cache_key.py index 362b993..3a5211b 100644 --- a/src/agentkit/llm/cache_key.py +++ b/src/agentkit/llm/cache_key.py @@ -39,21 +39,23 @@ def generate_cache_key( 64-character hex SHA-256 hash string. """ system_prompt = _extract_system_prompt(messages) - components = [ - _hash_str(model), - _hash_str(system_prompt), - _hash_json(messages), - _hash_str(f"{temperature:.2f}"), - _hash_json(tools), - _hash_str(tool_choice), - _hash_str(str(max_tokens)), + # 单次 SHA-256:用分隔符拼接所有组件,避免逐组件 hash 再 hash 的冗余计算。 + # 分隔符使用长度前缀防止歧义(如 "ab" + "cd" vs "a" + "bcd")。 + parts = [ + f"m:{model}", + f"s:{system_prompt}", + f"msg:{json.dumps(messages, sort_keys=True, ensure_ascii=False)}", + f"t:{temperature:.2f}", + f"tools:{json.dumps(tools, sort_keys=True, ensure_ascii=False) if tools is not None else 'null'}", + f"tc:{tool_choice}", + f"mt:{max_tokens}", ] # U17 — per-user namespace + ACL scope hash(安全要求 a, b, e) if user_id is not None: - components.append(_hash_str(f"user:{user_id}")) + parts.append(f"u:{user_id}") if kb_acl_hash is not None: - components.append(_hash_str(f"acl:{kb_acl_hash}")) - combined = "".join(components) + parts.append(f"a:{kb_acl_hash}") + combined = "\x1f".join(parts) # US (Unit Separator) 防止组件内容注入分隔符 return hashlib.sha256(combined.encode()).hexdigest() @@ -63,15 +65,3 @@ def _extract_system_prompt(messages: list[dict[str, str]]) -> str: if msg.get("role") == "system": return msg.get("content", "") return "" - - -def _hash_str(s: str) -> str: - """SHA-256 hash of a string.""" - return hashlib.sha256(s.encode()).hexdigest() - - -def _hash_json(obj: Any) -> str: - """SHA-256 hash of a JSON-serializable object.""" - if obj is None: - return hashlib.sha256(b"null").hexdigest() - return hashlib.sha256(json.dumps(obj, sort_keys=True, ensure_ascii=False).encode()).hexdigest() diff --git a/src/agentkit/llm/config.py b/src/agentkit/llm/config.py index 3dbfcd0..6357582 100644 --- a/src/agentkit/llm/config.py +++ b/src/agentkit/llm/config.py @@ -74,21 +74,15 @@ class ProviderConfig: api_key_encrypted: str | None = None api_key_source: str = "plaintext" - def get_api_key(self, secrets_store: "SecretsStore | None" = None) -> str: - """同步读取 API Key。 + def get_api_key(self) -> str: + """同步读取 API Key — 返回 plaintext。 - 双读窗口优先级:``api_key_encrypted`` + ``secrets_store`` > plaintext。 - 若加密列存在但解密失败(store 为 None 或解密异常),回退到 plaintext - ``api_key``,保证迁移期可用性。 - - 注意:``SecretsStore.get_secret`` 是 async,本同步方法无法调用。 - 若 ``api_key_encrypted`` 已设置但 ``secrets_store`` 为 None,仍回退到 - plaintext。需要解密时请用 ``aget_api_key``。 + 双读窗口的同步入口:无法 await ``SecretsStore.get_secret``, + 因此加密列需通过异步方法 :meth:`aget_api_key` 读取。 + 本方法始终返回 plaintext ``api_key``(迁移期保证可用性)。 """ - if self.api_key_encrypted and secrets_store is not None: - # 同步上下文无法 await get_secret,调用方应改用 aget_api_key。 - # 这里保持双读语义的回退:返回 plaintext。 - logger.debug("get_api_key: encrypted key set but sync access — fallback to plaintext") + if self.api_key_encrypted: + logger.debug("get_api_key: encrypted key set — use aget_api_key for decryption") return self.api_key async def aget_api_key(self, secrets_store: "SecretsStore | None" = None) -> str: diff --git a/src/agentkit/llm/gateway.py b/src/agentkit/llm/gateway.py index 691eca9..f0d62e7 100644 --- a/src/agentkit/llm/gateway.py +++ b/src/agentkit/llm/gateway.py @@ -552,30 +552,39 @@ class LLMGateway: ) # 2. Token + cost limits (daily AND monthly) - await self._check_quota_period(quota_service, db, dept_id, "daily", "token_limit") - await self._check_quota_period(quota_service, db, dept_id, "daily", "cost_limit") - await self._check_quota_period(quota_service, db, dept_id, "monthly", "token_limit") - await self._check_quota_period(quota_service, db, dept_id, "monthly", "cost_limit") + # 优化:每个 period 只查一次 get_usage,复用 summary 检查 token + cost + for period in ("daily", "monthly"): + summary = self._get_usage_summary(dept_id, period) + current_tokens = int(summary.total_tokens) + current_cost_cents = float(summary.total_cost) * 100.0 + await self._check_quota_value( + quota_service, db, dept_id, period, "token_limit", current_tokens + ) + await self._check_quota_value( + quota_service, db, dept_id, period, "cost_limit", current_cost_cents + ) - async def _check_quota_period( + def _get_usage_summary(self, department_id: str, period: str) -> UsageSummary: + """返回 department_id 在当前 period 的 usage summary(单次查询,供 token+cost 复用)。""" + now = datetime.now(timezone.utc) + if period == "monthly": + start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) + else: + start = now.replace(hour=0, minute=0, second=0, microsecond=0) + return self._usage_tracker.get_usage( + department_id=department_id, start_time=start, end_time=now + ) + + async def _check_quota_value( self, quota_service: Any, db: Path, dept_id: str, period: str, quota_type: str, + current: float, ) -> None: - """Check a single quota (token_limit or cost_limit) for a period. - - Raises :class:`QuotaExceededError` if the current usage exceeds - the configured limit. ``period`` is ``"daily"`` or ``"monthly"``; - ``quota_type`` is ``"token_limit"`` or ``"cost_limit"``. - """ - if quota_type == "token_limit": - current = await self._get_current_usage_for_quota(dept_id, period) - else: - current = await self._get_current_cost_for_quota(dept_id, period) - + """检查单个配额(token_limit 或 cost_limit)— current 由调用方预计算传入。""" allowed, _reason = await quota_service.check_quota(db, dept_id, quota_type, period, current) if not allowed: quota = await quota_service.get_quota(db, dept_id, quota_type, period) @@ -587,38 +596,3 @@ class LLMGateway: limit=limit, current=current, ) - - async def _get_current_usage_for_quota(self, department_id: str, period: str) -> int: - """Return total tokens used by ``department_id`` in the current period. - - ``period`` is ``"daily"`` or ``"monthly"``. For ``"daily"`` the - window is since 00:00 UTC today; for ``"monthly"`` since the - first of the current month. - """ - now = datetime.now(timezone.utc) - if period == "monthly": - start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) - else: - start = now.replace(hour=0, minute=0, second=0, microsecond=0) - summary = self._usage_tracker.get_usage( - department_id=department_id, start_time=start, end_time=now - ) - return int(summary.total_tokens) - - async def _get_current_cost_for_quota(self, department_id: str, period: str) -> float: - """Return total cost (in cents) for ``department_id`` in the current period. - - ``period`` is ``"daily"`` or ``"monthly"``. Quota cost_limit is - stored in cents, so we convert the float USD cost from the usage - store to cents (×100) for comparison. - """ - now = datetime.now(timezone.utc) - if period == "monthly": - start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) - else: - start = now.replace(hour=0, minute=0, second=0, microsecond=0) - summary = self._usage_tracker.get_usage( - department_id=department_id, start_time=start, end_time=now - ) - # cost_limit is stored in cents; convert from USD to cents. - return float(summary.total_cost) * 100.0 diff --git a/src/agentkit/server/app.py b/src/agentkit/server/app.py index e8496e9..d01bede 100644 --- a/src/agentkit/server/app.py +++ b/src/agentkit/server/app.py @@ -556,6 +556,20 @@ async def lifespan(app: FastAPI): if cal_scheduler is not None: await cal_scheduler.stop() + # U10/U11 — 关闭渠道适配器并等待后台 webhook 任务完成 + # 防止 httpx 连接泄漏和丢失正在处理的 IM 回复 + from agentkit.server.routes.channels import _pending_webhook_tasks + + if _pending_webhook_tasks: + logger.info("等待 %d 个后台 webhook 任务完成", len(_pending_webhook_tasks)) + await asyncio.gather(*_pending_webhook_tasks, return_exceptions=True) + try: + from agentkit.server.routes.channels import close_all_adapters + + await close_all_adapters() + except Exception: + logger.debug("close_all_adapters 异常已忽略") + def _on_config_change(app: FastAPI, config: ServerConfig) -> None: """Handle config change by reloading affected components. diff --git a/src/agentkit/server/routes/channels.py b/src/agentkit/server/routes/channels.py index 71b46c2..22bf8e1 100644 --- a/src/agentkit/server/routes/channels.py +++ b/src/agentkit/server/routes/channels.py @@ -506,6 +506,15 @@ async def _process_inbound_message( 适配器由 ``_adapter_cache`` 管理,不在 per-request 关闭(关闭会清空 token 缓存)。 适配器类型不限 — 出站消息的 ``channel`` 取自入站消息以匹配平台。 """ + + async def _direct_chat(llm_gateway: Any, routing: Any) -> str: + """DIRECT_CHAT 路径 — 直接调用 LLM(主路径与 ReAct 回退共用)。""" + response = await llm_gateway.chat( + messages=[{"role": "user", "content": message.content}], + model=routing.model or "default", + ) + return response.content + async with _get_webhook_semaphore(): try: request_preprocessor = getattr(app_state, "request_preprocessor", None) @@ -523,13 +532,8 @@ async def _process_inbound_message( final_content = "" execution_mode = getattr(routing, "execution_mode", None) - # DIRECT_CHAT 模式 — 直接调用 LLM if execution_mode == ExecutionMode.DIRECT_CHAT: - response = await llm_gateway.chat( - messages=[{"role": "user", "content": message.content}], - model=routing.model or "default", - ) - final_content = response.content + final_content = await _direct_chat(llm_gateway, routing) else: # REACT 或其他模式 — 优先使用 ReActEngine,失败回退到 DIRECT_CHAT try: @@ -544,11 +548,7 @@ async def _process_inbound_message( final_content = getattr(result, "content", "") or "" except Exception as exc: # noqa: BLE001 — 回退路径需捕获全部异常 logger.warning("ReActEngine 执行失败,回退到 DIRECT_CHAT: %s", exc) - response = await llm_gateway.chat( - messages=[{"role": "user", "content": message.content}], - model=routing.model or "default", - ) - final_content = response.content + final_content = await _direct_chat(llm_gateway, routing) if not final_content: logger.warning("消息处理未产生内容 — 不发送回复") @@ -620,9 +620,16 @@ async def channel_webhook(channel_id: str, request: Request) -> Any: except WeComURLVerification as e: # 企微 URL 验证 — 返回 XML 响应 return Response(content=e.response_xml, media_type="application/xml") + except Exception as exc: # noqa: BLE001 — 防止 receive_message 异常导致 500 触发平台重试风暴 + logger.warning("receive_message 解析失败 channel=%s: %s", channel_id, exc) + return {"code": 0, "msg": "invalid_payload"} # 异步处理 — 不阻塞 webhook 响应(平台要求快速返回 200) # 持有 task 引用防止 GC 回收正在运行的后台任务 + # 有界化:超过 2x 并发上限时拒绝新任务(防突发流量下 set 无界增长) + if len(_pending_webhook_tasks) >= _WEBHOOK_MAX_CONCURRENT * 2: + logger.warning("webhook 后台任务积压 %d,拒绝新任务", len(_pending_webhook_tasks)) + raise HTTPException(status_code=429, detail="服务器繁忙,请稍后重试") task = asyncio.create_task(_process_inbound_message(request.app.state, adapter, message)) _pending_webhook_tasks.add(task) task.add_done_callback(_pending_webhook_tasks.discard) diff --git a/tests/unit/channels/test_wecom.py b/tests/unit/channels/test_wecom.py index d0824c8..c47f8ef 100644 --- a/tests/unit/channels/test_wecom.py +++ b/tests/unit/channels/test_wecom.py @@ -17,6 +17,7 @@ from __future__ import annotations import base64 import hashlib +import time from unittest.mock import AsyncMock, MagicMock from xml.etree import ElementTree as ET @@ -108,7 +109,7 @@ class TestSignatureVerification: inner_xml = _build_inner_xml({"MsgId": "m1", "Content": "hi"}) body = _build_body(adapter, inner_xml) encrypt = adapter._extract_encrypt(body) or "" - ts, nonce = "1609459200", "n1" + ts, nonce = str(int(time.time())), "n1" sig = _signature(adapter.token, ts, nonce, encrypt) headers = {"msg_signature": sig, "timestamp": ts, "nonce": nonce} assert await adapter.verify_signature(headers, body) is True @@ -118,10 +119,23 @@ class TestSignatureVerification: adapter = _make_adapter() inner_xml = _build_inner_xml({"MsgId": "m1", "Content": "hi"}) body = _build_body(adapter, inner_xml) - ts, nonce = "1609459200", "n1" + ts, nonce = str(int(time.time())), "n1" headers = {"msg_signature": "tampered", "timestamp": ts, "nonce": nonce} assert await adapter.verify_signature(headers, body) is False + async def test_expired_timestamp_rejected(self): + """时间戳超出 5 分钟窗口返回 False(防重放)。""" + adapter = _make_adapter() + inner_xml = _build_inner_xml({"MsgId": "m1", "Content": "hi"}) + body = _build_body(adapter, inner_xml) + encrypt = adapter._extract_encrypt(body) or "" + # 10 分钟前的时间戳 — 超出 300s 窗口 + old_ts = str(int(time.time()) - 600) + nonce = "n1" + sig = _signature(adapter.token, old_ts, nonce, encrypt) + headers = {"msg_signature": sig, "timestamp": old_ts, "nonce": nonce} + assert await adapter.verify_signature(headers, body) is False + async def test_missing_query_params(self): """缺少 msg_signature/timestamp/nonce 返回 False。""" adapter = _make_adapter() diff --git a/tests/unit/llm/test_cache.py b/tests/unit/llm/test_cache.py index 45cefe8..471d1ae 100644 --- a/tests/unit/llm/test_cache.py +++ b/tests/unit/llm/test_cache.py @@ -276,15 +276,27 @@ class TestKBCachingDisabled: def test_should_cache_returns_false_when_disabled(self): manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) - assert manager.should_cache(kb_caching_disabled=True) is False + assert manager.should_cache(kb_caching_disabled=True, user_id="u1") is False def test_should_cache_returns_true_when_enabled(self): manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) - assert manager.should_cache(kb_caching_disabled=False) is True + assert manager.should_cache(kb_caching_disabled=False, user_id="u1") is True - def test_should_cache_default_true(self): + def test_should_cache_default_with_user(self): manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) - assert manager.should_cache() is True + assert manager.should_cache(user_id="u1") is True + + def test_should_cache_rejects_anonymous(self): + """安全要求 (a)(e) — per_user_namespace 开启且 user_id=None 时拒绝缓存。""" + manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) + assert manager.should_cache(user_id=None) is False + + def test_should_cache_allows_anonymous_when_namespace_off(self): + """per_user_namespace=False 时匿名请求可缓存(向后兼容场景)。""" + cfg = LitellmCacheConfig(backend="memory") + cfg.per_user_namespace = False + manager = LitellmCacheManager(cfg) + assert manager.should_cache(user_id=None) is True # --------------------------------------------------------------------------- diff --git a/tests/unit/llm/test_config_migration.py b/tests/unit/llm/test_config_migration.py index 68e2895..d5ae86b 100644 --- a/tests/unit/llm/test_config_migration.py +++ b/tests/unit/llm/test_config_migration.py @@ -25,17 +25,9 @@ from agentkit.llm.config import LLMConfig, ProviderConfig def test_get_api_key_plaintext_no_store(): - """无 secrets_store 时 get_api_key 返回 plaintext。""" + """get_api_key 返回 plaintext(同步入口,不处理加密列)。""" pconf = ProviderConfig(api_key="sk-xxx", base_url="", type="openai") - assert pconf.get_api_key(None) == "sk-xxx" - - -def test_get_api_key_plaintext_with_store_sync(): - """有 secrets_store 但同步调用 — 仍返回 plaintext(async 解密不可用)。""" - pconf = ProviderConfig(api_key="sk-xxx", base_url="", type="openai") - # 即使传 store,同步路径无法 decrypt,回退 plaintext - store = object() # 任意非 None 对象 - assert pconf.get_api_key(store) == "sk-xxx" # type: ignore[arg-type] + assert pconf.get_api_key() == "sk-xxx" # ----------------------------------------------------------------------