From 069dbc22b1a67f0155f97b2c27fe33e6114c9006 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Thu, 25 Jun 2026 21:41:15 +0800 Subject: [PATCH] =?UTF-8?q?feat(llm):=20U15=20=E2=80=94=20LiteLLM=20unifie?= =?UTF-8?q?d=20provider=20+=20api=5Fkey=20encrypted=20secrets=20migration?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 2 + src/agentkit/llm/config.py | 146 ++++- src/agentkit/llm/migration.py | 137 +++++ src/agentkit/llm/providers/__init__.py | 6 + .../llm/providers/litellm_provider.py | 384 +++++++++++++ src/agentkit/server/app.py | 76 ++- tests/unit/llm/test_config_migration.py | 287 +++++++++ tests/unit/llm/test_litellm_provider.py | 543 ++++++++++++++++++ 8 files changed, 1541 insertions(+), 40 deletions(-) create mode 100644 src/agentkit/llm/migration.py create mode 100644 src/agentkit/llm/providers/litellm_provider.py create mode 100644 tests/unit/llm/test_config_migration.py create mode 100644 tests/unit/llm/test_litellm_provider.py diff --git a/pyproject.toml b/pyproject.toml index 5244566..51d901a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,8 @@ dependencies = [ "aiosqlite>=0.20", # 加密 secrets store (U10 — 多端消息适配器 AES-256-GCM) "cryptography>=42.0", + # U15 — LiteLLM 统一 Provider 适配层(替换 6 个直接 API provider) + "litellm>=1.50", # Calendar & schedule (RRULE expansion) "python-dateutil>=2.9", # Calendar ICS import/export (U8) diff --git a/src/agentkit/llm/config.py b/src/agentkit/llm/config.py index a4ba01c..3dbfcd0 100644 --- a/src/agentkit/llm/config.py +++ b/src/agentkit/llm/config.py @@ -1,10 +1,17 @@ """LLM Config - 配置加载""" +import json +import logging from dataclasses import dataclass, field -from typing import Any +from typing import Any, TYPE_CHECKING from agentkit.llm.retry import CircuitBreakerConfig, RetryConfig +if TYPE_CHECKING: + from agentkit.channels.secrets import SecretsStore + +logger = logging.getLogger(__name__) + @dataclass class CacheConfig: @@ -58,6 +65,140 
# ProviderConfig members added for the U15 API-key migration. They belong in
# the ProviderConfig dataclass body, after its existing fields.

# JSON-encoded SecretEntry for the key; None while the key has not been migrated.
api_key_encrypted: str | None = None
# Where the key currently lives: "plaintext" (plaintext column only),
# "secrets_store" (encrypted column only) or "dual" (both columns populated
# during the migration window; the encrypted column is preferred on read).
api_key_source: str = "plaintext"


def get_api_key(self, secrets_store: "SecretsStore | None" = None) -> str:
    """Return the plaintext ``api_key`` without consulting the secrets store.

    ``SecretsStore.get_secret`` is a coroutine, so this synchronous accessor
    can never decrypt; it always returns the plaintext column. Use
    ``aget_api_key`` when the encrypted column has to be honoured.

    Args:
        secrets_store: Accepted for signature symmetry with ``aget_api_key``;
            it is not used for decryption.

    Returns:
        The plaintext key, which is ``""`` once the migration has cleared it.
    """
    if self.api_key_encrypted:
        if not self.api_key:
            # The key exists only in encrypted form: the caller is about to
            # receive an empty key, so make that visible instead of silent.
            logger.warning(
                "get_api_key: provider type=%s has only an encrypted key; "
                "sync access returns an empty key — use aget_api_key",
                self.type,
            )
        elif secrets_store is not None:
            logger.debug(
                "get_api_key: encrypted key set but sync access — fallback to plaintext"
            )
    return self.api_key


async def aget_api_key(self, secrets_store: "SecretsStore | None" = None) -> str:
    """Return the API key, preferring the encrypted column over plaintext.

    Only the ``key`` field of the encoded entry is needed to look the secret
    up, so the rest of the entry is not validated here. When the entry has no
    usable ``key`` the type-derived key from ``_secret_key_for_type`` is used.

    Args:
        secrets_store: Store to read from. ``None`` returns the plaintext key.

    Returns:
        The value read from the store, or the plaintext ``api_key`` when the
        encrypted column is unset, the store has no such key, or the lookup
        raises.
    """
    if not self.api_key_encrypted or secrets_store is None:
        return self.api_key

    try:
        secret_key = (
            json.loads(self.api_key_encrypted).get("key")
            or self._secret_key_for_type()
        )
        decrypted = await secrets_store.get_secret(secret_key)
    except Exception as exc:
        # Deliberate best-effort: a bad master key or a corrupt entry must not
        # take the provider down while a plaintext key is still available.
        logger.warning(
            "aget_api_key: decrypt failed for type=%s: %s — fallback to plaintext",
            self.type,
            exc,
        )
        return self.api_key

    if decrypted is not None:
        return decrypted

    logger.warning(
        "aget_api_key: encrypted key for provider type=%s "
        "not found in secrets_store — fallback to plaintext",
        self.type,
    )
    return self.api_key


async def migrate_to_secrets(self, secrets_store: "SecretsStore") -> None:
    """Move the plaintext ``api_key`` into ``secrets_store`` (idempotent).

    Steps:
        1. Nothing to do when there is no plaintext key (never set, or already
           migrated and cleared).
        2. Store the key with ``set_secret`` and record the returned entry,
           JSON-encoded, in ``api_key_encrypted``.
        3. Read the key back. Only when it matches is the plaintext cleared and
           ``api_key_source`` set to ``"secrets_store"``.

    If the read-back raises or does not match, the plaintext is kept and
    ``api_key_source`` becomes ``"dual"`` so that a later call retries.
    An exception from ``set_secret`` propagates with no state changed.

    Args:
        secrets_store: Store that receives the key.
    """
    plaintext = self.api_key
    if not plaintext:
        return

    secret_key = self._secret_key_for_type()
    entry = await secrets_store.set_secret(secret_key, plaintext)
    self.api_key_encrypted = self._encode_secret_entry(entry, secret_key)

    try:
        decrypted = await secrets_store.get_secret(secret_key)
    except Exception as exc:
        logger.warning(
            "migrate_to_secrets: verify read failed for type=%s: %s", self.type, exc
        )
        self.api_key_source = "dual"
        return

    if decrypted != plaintext:
        logger.error(
            "migrate_to_secrets: verify mismatch for type=%s "
            "— plaintext retained, source=dual",
            self.type,
        )
        self.api_key_source = "dual"
        return

    # Verified: the encrypted copy is now the only copy.
    self.api_key_source = "secrets_store"
    self.api_key = ""


def _secret_key_for_type(self) -> str:
    """Return the secrets-store key for this provider, namespaced by ``type``.

    NOTE(review): the key depends on ``type`` only, so two providers that
    share a type resolve to the same store key — confirm callers never
    migrate more than one provider per type.
    """
    return f"llm:provider:{self.type}:api_key"


@staticmethod
def _encode_secret_entry(entry: Any, key: str) -> str:
    """Serialize a secret entry to JSON, forcing its ``key`` field to ``key``.

    Args:
        entry: Object exposing ``model_dump()``, or anything ``dict()`` accepts.
        key: Store key to record alongside the entry fields.
    """
    data = entry.model_dump() if hasattr(entry, "model_dump") else dict(entry)
    data["key"] = key
    return json.dumps(data)


@staticmethod
def _decode_secret_entry(encoded: str) -> Any:
    """Rebuild a ``SecretEntry`` from the JSON written by ``_encode_secret_entry``.

    Raises:
        KeyError: If ``value``, ``nonce`` or ``salt`` is missing.
        json.JSONDecodeError: If ``encoded`` is not valid JSON.
    """
    # Imported lazily: the module only imports SecretsStore under TYPE_CHECKING.
    from agentkit.channels.secrets import SecretEntry

    data = json.loads(encoded)
    return SecretEntry(
        key=data.get("key", ""),
        value=data["value"],
        nonce=data["nonce"],
        salt=data["salt"],
        key_id=data.get("key_id", "default"),
        created_at=data.get("created_at", ""),
        updated_at=data.get("updated_at", ""),
    )
def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str, Any]]:
    """Migrate plaintext LLM API keys in ``agentkit.yaml`` into the SecretsStore.

    Flow:
        1. Load the YAML file and parse its ``llm`` section with
           ``LLMConfig.from_dict``.
        2. Create a ``SecretsStore()`` with its default construction.
        3. Run ``ProviderConfig.migrate_to_secrets`` for every provider that
           still has a plaintext key.
        4. If any provider changed, write ``api_key``, ``api_key_encrypted``
           and ``api_key_source`` back into the existing provider entries and
           replace the file atomically. All other keys are left untouched.

    The secret key is derived from the provider *type*, so a provider is
    reported as ``failed`` when another provider of the same type already
    owns the stored key. Its plaintext is kept.

    Args:
        config_path: Path to ``agentkit.yaml``.

    Returns:
        ``{provider_name: {"status": ..., "source": ..., "error"?: ...}}``
        where ``status`` is ``"migrated"``, ``"partial"`` (stored but not
        verified, plaintext kept), ``"skipped"`` or ``"failed"``.

    Raises:
        RuntimeError: If called from a running event loop (``asyncio.run``).
    """
    import asyncio
    import os
    import shutil
    import tempfile

    import yaml

    from agentkit.channels.secrets import SecretsStore
    from agentkit.llm.config import LLMConfig

    config_path = Path(config_path)
    raw = yaml.safe_load(config_path.read_text(encoding="utf-8")) or {}
    llm_section = raw.get("llm") or {}
    llm_config = LLMConfig.from_dict(llm_section)
    providers = llm_config.providers

    store = SecretsStore()

    def _key_state(pconf: Any) -> tuple[Any, Any, Any]:
        return (pconf.api_key, pconf.api_key_encrypted, pconf.api_key_source)

    before = {name: _key_state(pconf) for name, pconf in providers.items()}

    async def _run() -> dict[str, dict[str, Any]]:
        report: dict[str, dict[str, Any]] = {}
        # provider type -> (owner name, plaintext) for keys stored during this run
        claimed: dict[str, tuple[str, str]] = {}
        # provider type -> owner name for keys that exist only in encrypted form
        encrypted_only = {
            pconf.type: name
            for name, pconf in providers.items()
            if pconf.api_key_encrypted and not pconf.api_key
        }

        for name, pconf in providers.items():
            if not pconf.api_key:
                skipped: dict[str, Any] = {
                    "status": "skipped",
                    "source": pconf.api_key_source,
                }
                if pconf.api_key_source != "secrets_store":
                    skipped["error"] = "no plaintext api_key to migrate"
                report[name] = skipped
                continue

            owner = encrypted_only.get(pconf.type)
            if owner is None and pconf.type in claimed:
                prior_name, prior_key = claimed[pconf.type]
                # The same key value under the same secret key is harmless.
                if prior_key != pconf.api_key:
                    owner = prior_name
            if owner is not None:
                report[name] = {
                    "status": "failed",
                    "source": pconf.api_key_source,
                    "error": (
                        f"secret key for type '{pconf.type}' is already used by "
                        f"provider '{owner}'; migrating would overwrite its key"
                    ),
                }
                continue

            claimed[pconf.type] = (name, pconf.api_key)
            try:
                await pconf.migrate_to_secrets(store)
            except Exception as e:
                report[name] = {
                    "status": "failed",
                    "source": pconf.api_key_source,
                    "error": str(e),
                }
                continue
            migrated = pconf.api_key_source == "secrets_store"
            report[name] = {
                "status": "migrated" if migrated else "partial",
                "source": pconf.api_key_source,
            }
        return report

    report = asyncio.run(_run())

    # Dumping the YAML again drops comments and formatting, so leave the file
    # alone when no provider's key fields changed.
    if all(_key_state(pconf) == before[name] for name, pconf in providers.items()):
        return report

    if not isinstance(raw.get("llm"), dict):
        raw["llm"] = {}
    if not isinstance(raw["llm"].get("providers"), dict):
        raw["llm"]["providers"] = {}
    providers_raw = raw["llm"]["providers"]

    for name, pconf in providers.items():
        entry = providers_raw.get(name)
        if not isinstance(entry, dict):
            entry = {}
            providers_raw[name] = entry
        # Touch only the key columns; every other setting stays as written.
        entry["api_key"] = pconf.api_key
        entry["api_key_encrypted"] = pconf.api_key_encrypted
        entry["api_key_source"] = pconf.api_key_source

    # Write to a sibling temp file and swap it in, so a crash mid-write cannot
    # leave a truncated config behind.
    fd, tmp_name = tempfile.mkstemp(
        dir=str(config_path.parent), prefix=config_path.name + ".", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as tmp:
            yaml.safe_dump(raw, tmp, allow_unicode=True, sort_keys=False)
        shutil.copymode(config_path, tmp_name)
        os.replace(tmp_name, config_path)
    except BaseException:
        if os.path.exists(tmp_name):
            os.unlink(tmp_name)
        raise

    return report


# CLI wiring is deferred: an ``agentkit llm migrate-keys`` command only needs a
# thin wrapper that calls this function and prints the report.
AnthropicProvider from agentkit.llm.providers.doubao import DoubaoProvider from agentkit.llm.providers.gemini import GeminiProvider +from agentkit.llm.providers.litellm_provider import ( + LitellmProvider, + create_litellm_provider, +) from agentkit.llm.providers.openai import OpenAICompatibleProvider from agentkit.llm.providers.tracker import UsageSummary, UsageTracker from agentkit.llm.providers.usage_store import UsageRecord @@ -13,10 +17,12 @@ __all__ = [ "AnthropicProvider", "DoubaoProvider", "GeminiProvider", + "LitellmProvider", "OpenAICompatibleProvider", "UsageRecord", "UsageSummary", "UsageTracker", "WenxinProvider", "YuanbaoProvider", + "create_litellm_provider", ] diff --git a/src/agentkit/llm/providers/litellm_provider.py b/src/agentkit/llm/providers/litellm_provider.py new file mode 100644 index 0000000..15336c2 --- /dev/null +++ b/src/agentkit/llm/providers/litellm_provider.py @@ -0,0 +1,384 @@ +"""U15 — LiteLLM 统一 Provider 适配层。 + +用 LiteLLM 的 ``acompletion()`` 统一接口替换 6 个直接 API provider 适配器 +(OpenAI/Anthropic/Gemini/Doubao/Wenxin/Yuanbao)。LiteLLM 内部处理各家 API +的差异(消息格式、tool_calls 格式、streaming SSE 协议),本模块只负责: + +1. 把 ``LLMRequest`` 翻译成 LiteLLM ``acompletion`` 的 kwargs; +2. 把 LiteLLM 响应(OpenAI ChatCompletion 格式)翻译回 ``LLMResponse`` / + ``StreamChunk``; +3. 
# provider_type -> LiteLLM model-prefix mapping. The prefix is prepended to the
# request's model name to select the LiteLLM route. Types that are not listed
# fall back to "openai/".
_PROVIDER_TYPE_TO_PREFIX: dict[str, str] = {
    "openai": "openai/",
    "anthropic": "anthropic/",
    "gemini": "gemini/",
    # Doubao is served by Volcengine.
    "doubao": "volcengine/",
    # No native wenxin route is used here: it goes through "openai/" together
    # with a custom api_base (Qianfan's OpenAI-compatible endpoint), so
    # wenxin-specific options such as AK/SK auth are not available.
    "wenxin": "openai/",
    # Tencent Hunyuan.
    "yuanbao": "hunyuan/",
}


def _model_prefix_for(provider_type: str) -> str:
    """Return the LiteLLM model prefix for ``provider_type`` (default ``openai/``)."""
    return _PROVIDER_TYPE_TO_PREFIX.get(provider_type, "openai/")


class LitellmProvider(LLMProvider):
    """Provider that forwards ``LLMRequest`` objects to ``litellm.acompletion``.

    One instance represents one provider configuration (API key, base URL and
    provider type). Retries are disabled at this layer (``num_retries=0``);
    fallback and retry are left to the gateway.
    """

    def __init__(
        self,
        model_prefix: str,
        api_key: str,
        base_url: str | None = None,
        provider_type: str = "openai",
        **default_kwargs: Any,
    ) -> None:
        """Store the provider settings.

        Args:
            model_prefix: Prefix prepended to every request's model name.
            api_key: Key passed to LiteLLM as ``api_key``.
            base_url: Optional custom endpoint; an empty string means unset.
            provider_type: Label used when raising ``LLMProviderError``.
            **default_kwargs: Extra ``acompletion`` arguments used as defaults.
        """
        self._model_prefix = model_prefix
        self._api_key = api_key
        self._base_url = base_url or None
        self._provider_type = provider_type
        self._default_kwargs: dict[str, Any] = dict(default_kwargs)

    async def chat(self, request: LLMRequest) -> LLMResponse:
        """Run a non-streaming completion and translate the response.

        Raises:
            LLMProviderError: Wrapping any exception raised by LiteLLM.
        """
        import litellm

        kwargs = self._build_kwargs(request, stream=False)

        start = time.monotonic()
        try:
            response = await litellm.acompletion(**kwargs)
        except Exception as e:
            # One exception type for every backend keeps the gateway's
            # fallback chain simple.
            raise LLMProviderError(self._provider_type, str(e)) from e
        latency_ms = (time.monotonic() - start) * 1000

        return self._parse_response(response, request.model, latency_ms)

    async def chat_stream(self, request: LLMRequest) -> AsyncGenerator[StreamChunk, None]:
        """Stream a completion, yielding one ``StreamChunk`` per LiteLLM chunk.

        After the upstream stream ends, one extra chunk with ``is_final=True``
        is yielded. It carries the aggregated tool calls and the last usage
        seen. An empty upstream stream therefore still produces that final
        chunk.

        Raises:
            LLMProviderError: Wrapping any exception raised while streaming.
        """
        import litellm

        kwargs = self._build_kwargs(request, stream=True)

        accumulated_tool_calls: dict[int, dict[str, Any]] = {}
        final_usage: TokenUsage | None = None
        final_model: str = request.model

        try:
            # acompletion may hand back a coroutine that resolves to the
            # stream, or (e.g. an async-generator test double) the stream
            # itself; accept both.
            raw = litellm.acompletion(**kwargs)
            stream = await raw if inspect.isawaitable(raw) else raw
            async for chunk in stream:
                parsed = self._parse_stream_chunk(
                    chunk,
                    request.model,
                    accumulated_tool_calls,
                )
                if parsed.usage is not None:
                    final_usage = parsed.usage
                if parsed.model:
                    final_model = parsed.model
                # Empty-content chunks are forwarded too: a usage-only chunk
                # has no content.
                yield parsed

            yield StreamChunk(
                content="",
                model=final_model,
                tool_calls=self._finalize_tool_calls(accumulated_tool_calls),
                usage=final_usage,
                is_final=True,
            )
        except LLMProviderError:
            raise
        except Exception as e:
            raise LLMProviderError(self._provider_type, str(e)) from e

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _build_kwargs(self, request: LLMRequest, *, stream: bool) -> dict[str, Any]:
        """Build the keyword arguments for ``litellm.acompletion``.

        Construction-time defaults are applied first, so values taken from the
        request (and the forced ``num_retries=0``) always win over them.
        """
        kwargs: dict[str, Any] = dict(self._default_kwargs)
        kwargs.update(
            {
                "model": f"{self._model_prefix}{request.model}",
                "messages": request.messages,
                "temperature": request.temperature,
                "max_tokens": request.max_tokens,
                "api_key": self._api_key,
                "stream": stream,
                # Retries belong to the gateway; stacking them here would
                # multiply the worst-case latency.
                "num_retries": 0,
            }
        )
        if self._base_url:
            kwargs["api_base"] = self._base_url
        if request.tools:
            kwargs["tools"] = request.tools
            # Forward tool_choice only when set, rather than sending an
            # explicit null.
            if request.tool_choice is not None:
                kwargs["tool_choice"] = request.tool_choice
        if request.timeout is not None:
            kwargs["timeout"] = request.timeout
        return kwargs

    def _parse_response(
        self,
        response: Any,
        request_model: str,
        latency_ms: float,
    ) -> LLMResponse:
        """Translate a ChatCompletion-shaped response into ``LLMResponse``.

        Reads ``choices[0].message.{content, tool_calls}``, ``usage`` and
        ``model``. Missing pieces become an empty string, an empty list, a
        default ``TokenUsage`` and ``request_model`` respectively.
        """
        content = ""
        tool_calls: list[ToolCall] = []

        choices = getattr(response, "choices", None) or []
        message = getattr(choices[0], "message", None) if choices else None
        if message is not None:
            content = getattr(message, "content", None) or ""
            raw_tool_calls = getattr(message, "tool_calls", None)
            if raw_tool_calls:
                tool_calls = _parse_tool_calls(raw_tool_calls)

        usage_obj = getattr(response, "usage", None)
        usage = _parse_usage(usage_obj) if usage_obj is not None else TokenUsage()

        return LLMResponse(
            content=content,
            model=getattr(response, "model", None) or request_model,
            usage=usage,
            tool_calls=tool_calls,
            latency_ms=latency_ms,
        )

    def _parse_stream_chunk(
        self,
        chunk: Any,
        request_model: str,
        accumulated_tool_calls: dict[int, dict[str, Any]],
    ) -> StreamChunk:
        """Translate one streamed chunk into a non-final ``StreamChunk``.

        Tool-call fragments found in ``choices[0].delta`` are merged into
        ``accumulated_tool_calls`` (mutated in place). The returned chunk
        itself never carries tool calls.
        """
        content = ""
        choices = getattr(chunk, "choices", None) or []
        delta = getattr(choices[0], "delta", None) if choices else None
        if delta is not None:
            content = getattr(delta, "content", None) or ""
            raw_tool_calls = getattr(delta, "tool_calls", None)
            if raw_tool_calls:
                _accumulate_stream_tool_calls(raw_tool_calls, accumulated_tool_calls)

        usage_obj = getattr(chunk, "usage", None)
        usage: TokenUsage | None = (
            _parse_usage(usage_obj) if usage_obj is not None else None
        )

        return StreamChunk(
            content=content,
            model=getattr(chunk, "model", None) or request_model,
            tool_calls=[],
            usage=usage,
            is_final=False,
        )

    def _finalize_tool_calls(
        self,
        accumulated: dict[int, dict[str, Any]],
    ) -> list[ToolCall]:
        """Convert accumulated stream fragments into ``ToolCall`` objects.

        Calls are emitted in ascending index order. Arguments that are not
        valid JSON, or that decode to something other than an object, are
        preserved as ``{"raw": <text>}``.
        """
        tool_calls: list[ToolCall] = []
        for idx in sorted(accumulated):
            tc_data = accumulated[idx]
            args_str = tc_data.get("arguments_str", "")
            try:
                arguments = json.loads(args_str) if args_str else {}
            except json.JSONDecodeError:
                arguments = {"raw": args_str}
            if not isinstance(arguments, dict):
                arguments = {"raw": args_str}
            tool_calls.append(
                ToolCall(
                    id=tc_data.get("id", ""),
                    name=tc_data.get("name", ""),
                    arguments=arguments,
                )
            )
        return tool_calls
# ----------------------------------------------------------------------
# Response-parsing helpers (module level so tests can call them directly)
# ----------------------------------------------------------------------


def _field(obj: Any, name: str, default: Any = None) -> Any:
    """Read ``name`` from a mapping or an attribute-style object."""
    if isinstance(obj, dict):
        return obj.get(name, default)
    return getattr(obj, name, default)


def _coerce_arguments(args: Any) -> dict[str, Any]:
    """Normalize tool-call arguments to a dict.

    JSON text is decoded. Text that is not valid JSON, JSON that is not an
    object, and values of any other type are kept under ``"raw"``. ``None``
    and the empty string mean "no arguments".
    """
    if args is None or args == "":
        return {}
    if isinstance(args, dict):
        return args
    if isinstance(args, str):
        try:
            decoded = json.loads(args)
        except json.JSONDecodeError:
            return {"raw": args}
        return decoded if isinstance(decoded, dict) else {"raw": args}
    return {"raw": str(args)}


def _parse_tool_calls(raw_tool_calls: Any) -> list[ToolCall]:
    """Parse the ``tool_calls`` list of a non-streaming response.

    Each item supplies ``id`` and ``function.{name, arguments}`` as attributes
    or as mapping keys. Items without a ``function`` are skipped.
    """
    result: list[ToolCall] = []
    for tc in raw_tool_calls:
        func = _field(tc, "function")
        if func is None:
            continue
        result.append(
            ToolCall(
                id=_field(tc, "id") or "",
                name=_field(func, "name") or "",
                arguments=_coerce_arguments(_field(func, "arguments", "{}")),
            )
        )
    return result


def _parse_usage(usage_obj: Any) -> TokenUsage:
    """Build a ``TokenUsage`` from a usage object or dict; missing counts are 0."""
    return TokenUsage(
        prompt_tokens=int(_field(usage_obj, "prompt_tokens") or 0),
        completion_tokens=int(_field(usage_obj, "completion_tokens") or 0),
    )


def _slot_for_unindexed(tc_id: Any, accumulated: dict[int, dict[str, Any]]) -> int:
    """Choose the slot for a tool-call fragment that carries no ``index``.

    A fragment whose id matches an existing slot continues that slot. A
    fragment with a new id opens the next slot. A fragment without an id
    continues the most recent slot.
    """
    if not accumulated:
        return 0
    latest = max(accumulated)
    if not tc_id:
        return latest
    for idx, slot in accumulated.items():
        if slot["id"] == tc_id:
            return idx
    # The latest slot may still be waiting for its id.
    return latest + 1 if accumulated[latest]["id"] else latest


def _accumulate_stream_tool_calls(
    raw_tool_calls: Any,
    accumulated: dict[int, dict[str, Any]],
) -> None:
    """Merge the tool-call fragments of one stream chunk into ``accumulated``.

    Each fragment supplies ``index``, ``id`` and ``function.{name, arguments}``
    as attributes or mapping keys. ``arguments`` arrives in pieces that are
    concatenated per slot. Fragments without an ``index`` are routed by id,
    so that parallel calls are not merged into slot 0.

    Args:
        raw_tool_calls: Iterable of fragments from ``delta.tool_calls``.
        accumulated: ``index -> {"id", "name", "arguments_str"}``, mutated in
            place.
    """
    for tc in raw_tool_calls:
        tc_id = _field(tc, "id")
        idx = _field(tc, "index")
        if idx is None:
            idx = _slot_for_unindexed(tc_id, accumulated)

        slot = accumulated.setdefault(idx, {"id": "", "name": "", "arguments_str": ""})
        if tc_id:
            slot["id"] = tc_id

        func = _field(tc, "function")
        if func is None:
            continue
        name = _field(func, "name")
        if name:
            slot["name"] = name
        fragment = _field(func, "arguments")
        if fragment:
            # Already-decoded arguments are re-encoded so that the
            # concatenation stays valid.
            if not isinstance(fragment, str):
                fragment = json.dumps(fragment)
            slot["arguments_str"] += fragment


# ----------------------------------------------------------------------
# Factory
# ----------------------------------------------------------------------


def create_litellm_provider(
    provider_type: str,
    api_key: str,
    base_url: str | None = None,
    **kwargs: Any,
) -> LitellmProvider:
    """Create a ``LitellmProvider`` for a configured provider type.

    Args:
        provider_type: Provider type string from the configuration
            ("openai" | "anthropic" | "gemini" | "doubao" | "wenxin" |
            "yuanbao"). Unknown types use the ``openai/`` prefix.
        api_key: API key for the provider.
        base_url: Optional custom API base URL.
        **kwargs: Extra default arguments passed through to ``LitellmProvider``.

    Returns:
        A provider whose model prefix matches ``provider_type``.
    """
    return LitellmProvider(
        model_prefix=_model_prefix_for(provider_type),
        api_key=api_key,
        base_url=base_url,
        provider_type=provider_type,
        **kwargs,
    )
def _create_provider(name: str, pconf) -> object:
    """Create the LLM provider instance for one ``ProviderConfig``.

    Every provider type is served by ``LitellmProvider`` through
    ``create_litellm_provider``.

    This function is synchronous, so it can only use the plaintext
    ``pconf.api_key``. A provider whose key exists only in encrypted form
    cannot be built here; the caller must decrypt it first and write the
    result back to ``pconf.api_key``.

    NOTE(review): ``pconf.timeout`` and ``pconf.max_tokens`` are not forwarded
    on this path — confirm whether they should be passed as provider defaults.

    Args:
        name: Provider name from the configuration, used in messages only.
        pconf: Provider configuration; ``api_key``, ``api_key_encrypted``,
            ``type`` and ``base_url`` are read.

    Raises:
        ValueError: If only an encrypted key is available, or if the provider
            type is served through the generic ``openai/`` route and no
            ``base_url`` is configured.
    """
    # Types with their own LiteLLM prefix in the provider module. Everything
    # else (including "wenxin") goes through "openai/" and needs an explicit
    # endpoint. Keep in sync with the prefix table in litellm_provider.
    native_types = ("openai", "anthropic", "gemini", "doubao", "yuanbao")

    api_key = pconf.api_key
    if not api_key and pconf.api_key_encrypted:
        logger.warning(
            f"Provider '{name}' has encrypted key but no plaintext fallback "
            f"— skipped (run async pre-decrypt at startup)"
        )
        raise ValueError(
            f"Provider '{name}' api_key not available synchronously "
            f"(encrypted only). Run async pre-decrypt first."
        )

    base_url = pconf.base_url or None
    if base_url is None and pconf.type not in native_types:
        # Without an endpoint the generic route would send this provider's
        # key to the default OpenAI endpoint.
        raise ValueError(
            f"Provider '{name}' is missing base_url. "
            f"OpenAI-compatible providers require an explicit base_url in config."
        )

    return create_litellm_provider(
        provider_type=pconf.type,
        api_key=api_key,
        base_url=base_url,
    )
store = SecretsStore(master_key=b"x" * 32) + # 先在 store 里放一个 key + await store.set_secret("llm:provider:openai:api_key", "sk-decrypted") + # 构造一个加密的 ProviderConfig(直接用真实加密流程) + pconf = ProviderConfig( + api_key="", + base_url="", + type="openai", + api_key_encrypted="", # 占位,下面重新生成 + api_key_source="secrets_store", + ) + # 用 store 加密 "sk-decrypted" 得到真实 encrypted 字段 + entry = await store.set_secret("llm:provider:openai:api_key", "sk-decrypted") + pconf.api_key_encrypted = ProviderConfig._encode_secret_entry( + entry, "llm:provider:openai:api_key" + ) + + result = await pconf.aget_api_key(store) + assert result == "sk-decrypted" + + +# ---------------------------------------------------------------------- +# 17: aget_api_key dual-read 回退 +# ---------------------------------------------------------------------- + + +async def test_aget_api_key_dual_read_fallback_to_plaintext(): + """加密列存在但解密失败 → 回退 plaintext。""" + # secrets_store 返回 None(模拟 key 不存在 / 解密失败) + store = AsyncMock(spec=SecretsStore) + store.get_secret = AsyncMock(return_value=None) + + pconf = ProviderConfig( + api_key="sk-plaintext", + base_url="", + type="openai", + api_key_encrypted=json.dumps( + { + "key": "llm:provider:openai:api_key", + "value": "fake", + "nonce": "fake", + "salt": "fake", + "key_id": "default", + "created_at": "", + "updated_at": "", + } + ), + api_key_source="dual", + ) + + result = await pconf.aget_api_key(store) + assert result == "sk-plaintext" + + +async def test_aget_api_key_decrypt_exception_falls_back(): + """解密抛异常时回退 plaintext。""" + store = AsyncMock(spec=SecretsStore) + store.get_secret = AsyncMock(side_effect=RuntimeError("decrypt failed")) + + pconf = ProviderConfig( + api_key="sk-plaintext", + base_url="", + type="openai", + api_key_encrypted=json.dumps( + { + "key": "llm:provider:openai:api_key", + "value": "fake", + "nonce": "fake", + "salt": "fake", + "key_id": "default", + "created_at": "", + "updated_at": "", + } + ), + api_key_source="dual", + ) + + result = 
await pconf.aget_api_key(store) + assert result == "sk-plaintext" + + +async def test_aget_api_key_no_encrypted_returns_plaintext(): + """无加密列时直接返回 plaintext。""" + store = SecretsStore(master_key=b"x" * 32) + pconf = ProviderConfig(api_key="sk-plain", base_url="", type="openai") + assert await pconf.aget_api_key(store) == "sk-plain" + + +# ---------------------------------------------------------------------- +# 18: migrate_to_secrets 完整流程 +# ---------------------------------------------------------------------- + + +async def test_migrate_to_secrets_full_flow(): + """plaintext → secrets_store 迁移:加密、验证、清空 plaintext。""" + store = SecretsStore(master_key=b"x" * 32) + pconf = ProviderConfig( + api_key="sk-secret", + base_url="", + type="openai", + ) + + await pconf.migrate_to_secrets(store) + + assert pconf.api_key == "" # plaintext 清空 + assert pconf.api_key_encrypted is not None # 加密列写入 + assert pconf.api_key_source == "secrets_store" + + # 验证:通过 aget_api_key 能读回原 key + decrypted = await pconf.aget_api_key(store) + assert decrypted == "sk-secret" + + +# ---------------------------------------------------------------------- +# 19: migrate_to_secrets 幂等 +# ---------------------------------------------------------------------- + + +async def test_migrate_to_secrets_idempotent(): + """已迁移的 config 再次调用 migrate_to_secrets 是 no-op。""" + store = SecretsStore(master_key=b"x" * 32) + pconf = ProviderConfig(api_key="sk-secret", base_url="", type="openai") + + await pconf.migrate_to_secrets(store) + encrypted_after_first = pconf.api_key_encrypted + assert pconf.api_key_source == "secrets_store" + + # 第二次调用 — no-op + await pconf.migrate_to_secrets(store) + assert pconf.api_key_encrypted == encrypted_after_first + assert pconf.api_key_source == "secrets_store" + assert pconf.api_key == "" + + +async def test_migrate_to_secrets_skips_empty_plaintext(): + """plaintext 为空时跳过迁移。""" + store = SecretsStore(master_key=b"x" * 32) + pconf = ProviderConfig(api_key="", base_url="", 
type="openai") + + await pconf.migrate_to_secrets(store) + + assert pconf.api_key == "" + assert pconf.api_key_encrypted is None + assert pconf.api_key_source == "plaintext" + + +# ---------------------------------------------------------------------- +# 20-21: LLMConfig.from_dict 解析新字段 +# ---------------------------------------------------------------------- + + +def test_from_dict_parses_new_encrypted_fields(): + """LLMConfig.from_dict 解析 api_key_encrypted / api_key_source。""" + data = { + "providers": { + "openai": { + "api_key": "", + "base_url": "https://api.openai.com/v1", + "type": "openai", + "api_key_encrypted": '{"key":"k","value":"v","nonce":"n","salt":"s"}', + "api_key_source": "secrets_store", + }, + }, + } + + config = LLMConfig.from_dict(data) + pconf = config.providers["openai"] + assert pconf.api_key_encrypted == '{"key":"k","value":"v","nonce":"n","salt":"s"}' + assert pconf.api_key_source == "secrets_store" + + +def test_from_dict_defaults_new_fields_to_plaintext(): + """LLMConfig.from_dict 缺省新字段时默认 plaintext 行为。""" + data = { + "providers": { + "openai": { + "api_key": "sk-test", + "base_url": "https://api.openai.com/v1", + "type": "openai", + }, + }, + } + + config = LLMConfig.from_dict(data) + pconf = config.providers["openai"] + assert pconf.api_key_encrypted is None + assert pconf.api_key_source == "plaintext" + + +def test_from_dict_dual_source_during_migration(): + """双写窗口:api_key_source="dual" 时两个字段都保留。""" + data = { + "providers": { + "openai": { + "api_key": "sk-plaintext", + "base_url": "", + "type": "openai", + "api_key_encrypted": '{"key":"k","value":"v","nonce":"n","salt":"s"}', + "api_key_source": "dual", + }, + }, + } + + config = LLMConfig.from_dict(data) + pconf = config.providers["openai"] + assert pconf.api_key == "sk-plaintext" + assert pconf.api_key_encrypted is not None + assert pconf.api_key_source == "dual" + + +# ---------------------------------------------------------------------- +# 额外:_secret_key_for_type 命名空间 +# 
---------------------------------------------------------------------- + + +def test_secret_key_namespaced_by_type(): + """secret key 应按 provider type 命名空间隔离。""" + pconf = ProviderConfig(api_key="", base_url="", type="anthropic") + assert pconf._secret_key_for_type() == "llm:provider:anthropic:api_key" + + pconf2 = ProviderConfig(api_key="", base_url="", type="gemini") + assert pconf2._secret_key_for_type() == "llm:provider:gemini:api_key" + + +# ---------------------------------------------------------------------- +# 额外:encode/decode SecretEntry 往返 +# ---------------------------------------------------------------------- + + +async def test_encode_decode_secret_entry_roundtrip(): + """encode → decode 应保留关键字段。""" + store = SecretsStore(master_key=b"x" * 32) + entry = await store.set_secret("llm:provider:openai:api_key", "sk-xxx") + encoded = ProviderConfig._encode_secret_entry(entry, "llm:provider:openai:api_key") + decoded = ProviderConfig._decode_secret_entry(encoded) + + assert decoded.key == "llm:provider:openai:api_key" + assert decoded.value == entry.value + assert decoded.nonce == entry.nonce + assert decoded.salt == entry.salt diff --git a/tests/unit/llm/test_litellm_provider.py b/tests/unit/llm/test_litellm_provider.py new file mode 100644 index 0000000..d7d935c --- /dev/null +++ b/tests/unit/llm/test_litellm_provider.py @@ -0,0 +1,543 @@ +"""U15 — LitellmProvider 单元测试。 + +用 ``unittest.mock.AsyncMock`` mock ``litellm.acompletion``,覆盖: +- 6 个 provider 类型(openai/anthropic/gemini/doubao/wenxin/yuanbao)经 LiteLLM 调用 +- 未知 provider_type 回退到 openai/ 前缀 +- 自定义 base_url 透传 +- tools 透传 + tool_calls 响应解析 +- 流式 chunk 迭代 + 终止 chunk +- 流式空消息处理 +- LiteLLM 异常 → LLMProviderError +- latency_ms 非负 +""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch + +import pytest + +from agentkit.core.exceptions import LLMProviderError +from agentkit.llm.protocol import LLMRequest, LLMResponse +from 
agentkit.llm.providers.litellm_provider import ( + LitellmProvider, + _model_prefix_for, + create_litellm_provider, +) + + +# ---------------------------------------------------------------------- +# 测试辅助:构造 OpenAI 格式的 fake response / chunk +# ---------------------------------------------------------------------- + + +def _fake_response( + content: str = "Hello!", + model: str = "gpt-4o-mini", + prompt_tokens: int = 10, + completion_tokens: int = 5, + tool_calls: list | None = None, +) -> SimpleNamespace: + """构造非流式 fake response(OpenAI ChatCompletion 格式)。""" + message = SimpleNamespace(content=content, tool_calls=tool_calls) + return SimpleNamespace( + choices=[SimpleNamespace(message=message)], + usage=SimpleNamespace( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ), + model=model, + ) + + +def _fake_tool_call( + tc_id: str = "call_1", + name: str = "get_weather", + arguments: str = '{"city": "Beijing"}', +) -> SimpleNamespace: + """构造非流式 tool_call 对象。""" + return SimpleNamespace( + id=tc_id, + function=SimpleNamespace(name=name, arguments=arguments), + ) + + +def _fake_stream_chunk( + content: str = "", + model: str = "gpt-4o-mini", + tool_calls_delta: list | None = None, + usage: SimpleNamespace | None = None, +) -> SimpleNamespace: + """构造流式 chunk(OpenAI ChatCompletionChunk 格式)。""" + delta = SimpleNamespace(content=content, tool_calls=tool_calls_delta) + return SimpleNamespace( + choices=[SimpleNamespace(delta=delta)], + model=model, + usage=usage, + ) + + +def _fake_stream_tool_call_delta( + index: int = 0, + tc_id: str | None = "call_1", + name: str | None = "get_weather", + arguments_fragment: str | None = None, +) -> SimpleNamespace: + """构造流式 tool_call delta 片段。""" + return SimpleNamespace( + index=index, + id=tc_id, + function=SimpleNamespace(name=name, arguments=arguments_fragment), + ) + + +# ---------------------------------------------------------------------- +# 1-7: provider_type → model_prefix 映射 + 基本 chat +# 
# ----------------------------------------------------------------------
# 1-7: provider_type -> model_prefix mapping + basic chat
# ----------------------------------------------------------------------


@pytest.mark.parametrize(
    "provider_type, expected_prefix",
    [
        pytest.param("openai", "openai/", id="openai"),
        pytest.param("anthropic", "anthropic/", id="anthropic"),
        pytest.param("gemini", "gemini/", id="gemini"),
        pytest.param("doubao", "volcengine/", id="doubao"),
        # ponytail: wenxin falls back to openai/
        pytest.param("wenxin", "openai/", id="wenxin"),
        pytest.param("yuanbao", "hunyuan/", id="yuanbao"),
        # unknown -> openai/ fallback
        pytest.param("unknown_type", "openai/", id="unknown"),
    ],
)
def test_model_prefix_mapping(provider_type: str, expected_prefix: str):
    """Check the provider_type -> LiteLLM model prefix mapping, unknown types included."""
    resolved_prefix = _model_prefix_for(provider_type)
    assert resolved_prefix == expected_prefix


@pytest.mark.parametrize(
    "provider_type, expected_prefix, model_name",
    [
        pytest.param("openai", "openai/", "gpt-4o-mini", id="openai"),
        pytest.param("anthropic", "anthropic/", "claude-sonnet-4-20250514", id="anthropic"),
        pytest.param("gemini", "gemini/", "gemini-2.0-flash", id="gemini"),
        pytest.param("doubao", "volcengine/", "doubao-pro-32k", id="doubao"),
        pytest.param("wenxin", "openai/", "ernie-4.5-turbo-128k", id="wenxin"),
        pytest.param("yuanbao", "hunyuan/", "hunyuan-pro", id="yuanbao"),
    ],
)
async def test_chat_via_litellm_for_each_provider_type(
    provider_type: str,
    expected_prefix: str,
    model_name: str,
):
    """Each of the 6 provider types chats through LiteLLM.

    Checks the prefix stored on the provider and the translation of the fake
    response into an ``LLMResponse``.
    """
    provider = create_litellm_provider(provider_type, api_key="sk-test")
    assert provider._model_prefix == expected_prefix

    request = LLMRequest(
        messages=[{"role": "user", "content": "Hi"}],
        model=model_name,
    )
    canned_call = AsyncMock(return_value=_fake_response(content="ok", model=model_name))
    with patch("litellm.acompletion", new=canned_call):
        response = await provider.chat(request)

    assert isinstance(response, LLMResponse)
    assert response.content == "ok"
    assert response.model == model_name
    token_usage = response.usage
    assert token_usage.prompt_tokens == 10
    assert token_usage.completion_tokens == 5
# ----------------------------------------------------------------------
# 8: custom base_url pass-through
# ----------------------------------------------------------------------


async def test_custom_base_url_passed_through():
    """A custom base_url is handed to litellm.acompletion as ``api_base``."""
    provider = create_litellm_provider(
        "openai",
        api_key="sk-test",
        base_url="https://custom.api/v1",
    )
    recorded_call = AsyncMock(return_value=_fake_response())
    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o-mini")

    with patch("litellm.acompletion", new=recorded_call):
        await provider.chat(request)

    sent = recorded_call.call_args.kwargs
    assert sent["api_base"] == "https://custom.api/v1"
    assert sent["api_key"] == "sk-test"
    assert sent["model"] == "openai/gpt-4o-mini"


# ----------------------------------------------------------------------
# 9: tools pass-through
# ----------------------------------------------------------------------


async def test_tools_passed_through():
    """request.tools and tool_choice are forwarded to litellm.acompletion."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    recorded_call = AsyncMock(return_value=_fake_response())
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather",
            "parameters": {"type": "object", "properties": {}},
        },
    }
    tools = [weather_tool]
    request = LLMRequest(
        messages=[{"role": "user", "content": "weather?"}],
        model="gpt-4o",
        tools=tools,
        tool_choice="auto",
    )

    with patch("litellm.acompletion", new=recorded_call):
        await provider.chat(request)

    sent = recorded_call.call_args.kwargs
    assert sent["tools"] == tools
    assert sent["tool_choice"] == "auto"


# ----------------------------------------------------------------------
# 10: tool_calls response parsing
# ----------------------------------------------------------------------


async def test_tool_calls_in_response_parsed():
    """tool_calls in a non-streaming response end up in LLMResponse.tool_calls."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    weather_call = _fake_tool_call(
        tc_id="call_abc",
        name="get_weather",
        arguments='{"city": "Beijing"}',
    )
    canned = _fake_response(content="", tool_calls=[weather_call])
    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")

    with patch("litellm.acompletion", new=AsyncMock(return_value=canned)):
        response = await provider.chat(request)

    assert response.has_tool_calls
    assert len(response.tool_calls) == 1
    parsed = response.tool_calls[0]
    assert parsed.id == "call_abc"
    assert parsed.name == "get_weather"
    assert parsed.arguments == {"city": "Beijing"}


async def test_tool_calls_with_dict_arguments():
    """When tool_call.arguments is already a dict it is used directly."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    dict_call = SimpleNamespace(
        id="call_1",
        function=SimpleNamespace(name="fn", arguments={"k": "v"}),
    )
    canned = _fake_response(tool_calls=[dict_call])
    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")

    with patch("litellm.acompletion", new=AsyncMock(return_value=canned)):
        response = await provider.chat(request)

    assert response.tool_calls[0].arguments == {"k": "v"}


# ----------------------------------------------------------------------
# 11: streaming yields chunks
# ----------------------------------------------------------------------


async def test_streaming_yields_chunks():
    """Streaming yields is_final=False chunks followed by one is_final=True chunk."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    scripted = [
        _fake_stream_chunk(content="Hello"),
        _fake_stream_chunk(content=" world"),
        _fake_stream_chunk(
            usage=SimpleNamespace(prompt_tokens=8, completion_tokens=2),
        ),
    ]

    async def _scripted_stream(**_ignored):
        for item in scripted:
            yield item

    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o-mini")
    with patch("litellm.acompletion", new=_scripted_stream):
        results = [chunk async for chunk in provider.chat_stream(request)]

    # The first 3 are streamed chunks (is_final=False); the last is the terminator.
    assert len(results) == 4
    assert results[0].content == "Hello"
    assert results[1].content == " world"
    assert not any(chunk.is_final for chunk in results[:3])
    terminator = results[-1]
    assert terminator.is_final is True
    # The terminating chunk carries the aggregated usage.
    assert terminator.usage is not None
    assert terminator.usage.prompt_tokens == 8
    assert terminator.usage.completion_tokens == 2


async def test_streaming_aggregates_tool_calls():
    """Streamed tool_call fragments are merged into the terminating chunk."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    opening_fragment = _fake_stream_tool_call_delta(
        index=0,
        tc_id="call_1",
        name="get_weather",
        arguments_fragment='{"city":',
    )
    closing_fragment = _fake_stream_tool_call_delta(
        index=0,
        tc_id=None,
        name=None,
        arguments_fragment=' "Beijing"}',
    )
    scripted = [
        _fake_stream_chunk(tool_calls_delta=[opening_fragment]),
        _fake_stream_chunk(tool_calls_delta=[closing_fragment]),
    ]

    async def _scripted_stream(**_ignored):
        for item in scripted:
            yield item

    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
    with patch("litellm.acompletion", new=_scripted_stream):
        results = [chunk async for chunk in provider.chat_stream(request)]

    terminator = results[-1]
    assert terminator.is_final
    assert len(terminator.tool_calls) == 1
    merged = terminator.tool_calls[0]
    assert merged.id == "call_1"
    assert merged.name == "get_weather"
    assert merged.arguments == {"city": "Beijing"}


# ----------------------------------------------------------------------
# 12: streaming with no messages
# ----------------------------------------------------------------------


async def test_streaming_empty_messages_handled():
    """A stream with no chunks still yields one empty is_final=True chunk."""
    provider = create_litellm_provider("openai", api_key="sk-test")

    async def _no_chunks(**_ignored):
        if False:  # never taken; the yield only makes this an async generator
            yield

    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
    with patch("litellm.acompletion", new=_no_chunks):
        results = [chunk async for chunk in provider.chat_stream(request)]

    # At least the terminating chunk must be yielded.
    assert len(results) >= 1
    assert results[-1].is_final is True
    assert results[-1].content == ""


# ----------------------------------------------------------------------
# 13: LiteLLM exception -> LLMProviderError
# ----------------------------------------------------------------------


async def test_litellm_exception_raises_provider_error():
    """An exception from litellm.acompletion surfaces as LLMProviderError."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    exploding_call = AsyncMock(side_effect=RuntimeError("boom"))
    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")

    with patch("litellm.acompletion", new=exploding_call), pytest.raises(
        LLMProviderError
    ) as exc_info:
        await provider.chat(request)

    assert "boom" in str(exc_info.value)


async def test_litellm_stream_exception_raises_provider_error():
    """An exception raised while streaming surfaces as LLMProviderError."""
    provider = create_litellm_provider("openai", api_key="sk-test")

    async def _exploding_stream(**_ignored):
        raise RuntimeError("stream boom")
        yield  # unreachable; present only to make this an async generator

    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
    with patch("litellm.acompletion", new=_exploding_stream), pytest.raises(
        LLMProviderError
    ) as exc_info:
        async for _ in provider.chat_stream(request):
            pass

    assert "stream boom" in str(exc_info.value)


# ----------------------------------------------------------------------
# 14: latency is non-negative
# ----------------------------------------------------------------------


async def test_latency_measured_non_negative():
    """LLMResponse.latency_ms is a non-negative number."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")

    with patch("litellm.acompletion", new=AsyncMock(return_value=_fake_response())):
        response = await provider.chat(request)

    assert response.latency_ms >= 0
# ----------------------------------------------------------------------
# Extra: num_retries=0 disables LiteLLM's built-in retry
# ----------------------------------------------------------------------


async def test_litellm_num_retries_disabled():
    """num_retries=0 is passed so LiteLLM's own retry is off (gateway fallback owns it)."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    recorded_call = AsyncMock(return_value=_fake_response())
    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")

    with patch("litellm.acompletion", new=recorded_call):
        await provider.chat(request)

    sent = recorded_call.call_args.kwargs
    assert sent["num_retries"] == 0
    assert sent["stream"] is False


async def test_streaming_passes_stream_true():
    """chat_stream passes stream=True to litellm.acompletion."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    seen_kwargs: dict = {}

    async def _no_chunks():
        if False:  # never taken; the yield only makes this an async generator
            yield

    async def _recording_acompletion(**kwargs):
        # Awaitable stand-in: remembers the call and returns an empty stream.
        seen_kwargs.update(kwargs)
        return _no_chunks()

    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
    with patch("litellm.acompletion", new=_recording_acompletion):
        async for _ in provider.chat_stream(request):
            pass

    assert seen_kwargs["stream"] is True


# ----------------------------------------------------------------------
# Extra: timeout pass-through
# ----------------------------------------------------------------------


async def test_timeout_passed_through():
    """request.timeout is forwarded to litellm.acompletion."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    recorded_call = AsyncMock(return_value=_fake_response())
    request = LLMRequest(
        messages=[{"role": "user", "content": "Hi"}],
        model="gpt-4o",
        timeout=42.0,
    )

    with patch("litellm.acompletion", new=recorded_call):
        await provider.chat(request)

    sent = recorded_call.call_args.kwargs
    assert sent["timeout"] == 42.0
# ----------------------------------------------------------------------
# Extra: constructing LitellmProvider directly (bypassing the factory)
# ----------------------------------------------------------------------


async def test_direct_litellm_provider_construction():
    """A provider built with LitellmProvider(model_prefix=...) works like a factory-built one.

    One patched call is enough: the same AsyncMock supplies the canned
    response and records the keyword arguments, so both the translated
    content and the prefixed model name are checked from a single chat.
    """
    provider = LitellmProvider(
        model_prefix="anthropic/",
        api_key="sk-test",
        provider_type="anthropic",
    )
    mock_acompletion = AsyncMock(
        return_value=_fake_response(content="hi", model="claude-3")
    )
    with patch("litellm.acompletion", new=mock_acompletion):
        response = await provider.chat(
            LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="claude-3")
        )

    assert response.content == "hi"
    # The model prefix must be prepended to the requested model name.
    _, kwargs = mock_acompletion.call_args
    assert kwargs["model"] == "anthropic/claude-3"


# ----------------------------------------------------------------------
# Extra: tolerant parsing of JSON tool_calls
# ----------------------------------------------------------------------


async def test_tool_calls_invalid_json_arguments_wrapped():
    """Invalid JSON in tool_call.arguments is wrapped as {"raw": ...} instead of raising."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    fake_resp = _fake_response(
        tool_calls=[
            _fake_tool_call(
                tc_id="call_1",
                name="fn",
                arguments="not-valid-json{",
            )
        ]
    )
    with patch("litellm.acompletion", new=AsyncMock(return_value=fake_resp)):
        response = await provider.chat(
            LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
        )

    assert response.tool_calls[0].arguments == {"raw": "not-valid-json{"}