feat(llm): U15 — LiteLLM unified provider + api_key encrypted secrets migration

This commit is contained in:
chiguyong 2026-06-25 21:41:15 +08:00
parent 13c516a54f
commit 069dbc22b1
8 changed files with 1541 additions and 40 deletions

View File

@ -26,6 +26,8 @@ dependencies = [
"aiosqlite>=0.20", "aiosqlite>=0.20",
# 加密 secrets store (U10 — 多端消息适配器 AES-256-GCM) # 加密 secrets store (U10 — 多端消息适配器 AES-256-GCM)
"cryptography>=42.0", "cryptography>=42.0",
# U15 — LiteLLM unified provider adapter layer (replaces the 6 direct-API provider adapters)
"litellm>=1.50",
# Calendar & schedule (RRULE expansion) # Calendar & schedule (RRULE expansion)
"python-dateutil>=2.9", "python-dateutil>=2.9",
# Calendar ICS import/export (U8) # Calendar ICS import/export (U8)

View File

@ -1,10 +1,17 @@
"""LLM Config - 配置加载""" """LLM Config - 配置加载"""
import json
import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any, TYPE_CHECKING
from agentkit.llm.retry import CircuitBreakerConfig, RetryConfig from agentkit.llm.retry import CircuitBreakerConfig, RetryConfig
if TYPE_CHECKING:
from agentkit.channels.secrets import SecretsStore
logger = logging.getLogger(__name__)
@dataclass @dataclass
class CacheConfig: class CacheConfig:
@ -58,6 +65,140 @@ class ProviderConfig:
keepalive_expiry: float = 30.0 # httpx 保活连接过期时间(秒) keepalive_expiry: float = 30.0 # httpx 保活连接过期时间(秒)
retry: RetryConfig | None = None retry: RetryConfig | None = None
circuit_breaker: CircuitBreakerConfig | None = None circuit_breaker: CircuitBreakerConfig | None = None
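# U15 — API key encryption migration (plaintext → SecretsStore).
# api_key_encrypted: JSON-encoded SecretEntry (base64 nonce + salt + ciphertext);
# None means "not migrated yet", in which case the plaintext api_key is still used.
# api_key_source: marks where the key currently lives; used during the
# dual-write / dual-read migration window.
# "plaintext" — plaintext column only; "secrets_store" — encrypted column only;
# "dual" — inside the dual-write window (both columns set, reads prefer the encrypted one).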
api_key_encrypted: str | None = None
api_key_source: str = "plaintext"
def get_api_key(self, secrets_store: "SecretsStore | None" = None) -> str:
    """Return the plaintext API key from a synchronous context.

    This method never decrypts: it always returns ``self.api_key``. When an
    encrypted key is configured and a store is supplied, it only emits a debug
    log noting that the sync path falls back to plaintext. Use
    ``aget_api_key`` when the encrypted value is needed.

    Args:
        secrets_store: Optional secrets store; only used to decide whether
            the debug message is logged.

    Returns:
        The plaintext ``api_key`` (possibly an empty string).
    """
    store_supplied = secrets_store is not None
    if store_supplied and self.api_key_encrypted:
        # A sync caller cannot await the store; record the fallback and move on.
        logger.debug("get_api_key: encrypted key set but sync access — fallback to plaintext")
    plaintext = self.api_key
    return plaintext
async def aget_api_key(self, secrets_store: "SecretsStore | None" = None) -> str:
    """Asynchronously resolve the API key, preferring the encrypted column.

    The encrypted column is consulted only when all of the following hold:
    ``api_key_encrypted`` is set, a ``secrets_store`` is supplied, and
    ``api_key_source`` is not ``"plaintext"``. The last condition honours the
    rollback procedure, in which the source is switched back to
    ``"plaintext"`` while a (possibly stale) ``api_key_encrypted`` value is
    left in the config; in that state the plaintext key must win.

    Args:
        secrets_store: Optional encrypted store. ``None`` returns plaintext.

    Returns:
        The decrypted key, or the plaintext ``api_key`` when the encrypted
        column is not in use, the secret is missing from the store, or
        decoding/decryption raises.
    """
    encrypted_in_use = (
        bool(self.api_key_encrypted)
        and secrets_store is not None
        and self.api_key_source != "plaintext"
    )
    if not encrypted_in_use:
        return self.api_key

    try:
        entry = self._decode_secret_entry(self.api_key_encrypted)
        decrypted = await secrets_store.get_secret(entry.key)
    except Exception as e:
        # Deliberate best-effort: any decode/decrypt failure (bad master key,
        # corrupt payload, store error) must not take the provider down while
        # a plaintext fallback may still exist.
        logger.warning(
            "aget_api_key: decrypt failed for type=%s: %s — fallback to plaintext",
            self.type,
            e,
        )
        return self.api_key

    if decrypted is not None:
        return decrypted

    # The store has no value under this key (e.g. it was deleted).
    logger.warning(
        "aget_api_key: encrypted key for provider type=%s "
        "not found in secrets_store — fallback to plaintext",
        self.type,
    )
    return self.api_key
async def migrate_to_secrets(self, secrets_store: "SecretsStore") -> None:
    """Move the plaintext ``api_key`` into the secrets store.

    Steps:
        1. If there is no plaintext key, return (covers both "already
           migrated" and "nothing to migrate"), which makes repeat calls a no-op.
        2. Store the key with ``secrets_store.set_secret`` and record the
           JSON-encoded entry in ``api_key_encrypted``.
        3. Read the secret back to verify it.
        4. On success, set ``api_key_source = "secrets_store"`` and clear the
           plaintext key.

    Failure handling:
        * The verification read raises: plaintext is kept and the source is
          set to ``"dual"`` so a later call can retry.
        * The verification read returns a different value: the freshly
          written entry is known to be wrong, so ``api_key_encrypted`` is
          reset to ``None`` and the source to ``"plaintext"``. Leaving it in
          place under ``"dual"`` would make ``aget_api_key`` prefer the bad
          value over the good plaintext key.

    Exceptions raised by ``set_secret`` propagate to the caller unchanged.

    Args:
        secrets_store: Store used to persist and read back the key.
    """
    if not self.api_key:
        return

    secret_key = self._secret_key_for_type()

    # Dual-write: persist the encrypted copy first; plaintext stays untouched
    # until the read-back below succeeds.
    entry = await secrets_store.set_secret(secret_key, self.api_key)
    self.api_key_encrypted = self._encode_secret_entry(entry, secret_key)

    try:
        decrypted = await secrets_store.get_secret(secret_key)
    except Exception as e:
        logger.warning(f"migrate_to_secrets: verify read failed for type={self.type}: {e}")
        self.api_key_source = "dual"
        return

    if decrypted != self.api_key:
        logger.error(
            "migrate_to_secrets: verify mismatch for type=%s "
            "— encrypted entry discarded, plaintext retained",
            self.type,
        )
        self.api_key_encrypted = None
        self.api_key_source = "plaintext"
        return

    # Verified: the store is now the only holder of the key.
    self.api_key_source = "secrets_store"
    self.api_key = ""
def _secret_key_for_type(self) -> str:
"""生成 SecretsStore 中的 key按 provider type 命名空间隔离)。"""
return f"llm:provider:{self.type}:api_key"
@staticmethod
def _encode_secret_entry(entry: Any, key: str) -> str:
"""把 SecretEntry 编码为 JSON 字符串(含 key 字段)。"""
# entry 是 SecretEntry pydantic 模型,有 model_dump()
if hasattr(entry, "model_dump"):
data = entry.model_dump()
else:
data = dict(entry)
data["key"] = key
return json.dumps(data)
@staticmethod
def _decode_secret_entry(encoded: str) -> Any:
    """Rebuild a ``SecretEntry`` from its JSON string form.

    ``value``, ``nonce`` and ``salt`` are required (a missing one raises
    ``KeyError``); ``key``, ``key_id``, ``created_at`` and ``updated_at``
    fall back to ``""`` / ``"default"`` / ``""`` / ``""``.

    Args:
        encoded: JSON text as produced by ``_encode_secret_entry``.

    Returns:
        A ``SecretEntry`` instance.
    """
    # Imported lazily, as in the rest of this class, to avoid a hard
    # module-level dependency on the secrets module.
    from agentkit.channels.secrets import SecretEntry

    payload = json.loads(encoded)
    fields = {
        "key": payload.get("key", ""),
        "value": payload["value"],
        "nonce": payload["nonce"],
        "salt": payload["salt"],
        "key_id": payload.get("key_id", "default"),
        "created_at": payload.get("created_at", ""),
        "updated_at": payload.get("updated_at", ""),
    }
    return SecretEntry(**fields)
@dataclass @dataclass
@ -105,6 +246,9 @@ class LLMConfig:
keepalive_expiry=pconf.get("keepalive_expiry", 30.0), keepalive_expiry=pconf.get("keepalive_expiry", 30.0),
retry=retry, retry=retry,
circuit_breaker=circuit_breaker, circuit_breaker=circuit_breaker,
# U15 — 新增加密迁移字段,缺省时保持 plaintext 行为
api_key_encrypted=pconf.get("api_key_encrypted"),
api_key_source=pconf.get("api_key_source", "plaintext"),
) )
cache = None cache = None
cache_data = data.get("cache") cache_data = data.get("cache")

View File

@ -0,0 +1,137 @@
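"""U15 — helper for migrating API keys to encrypted storage.

Migrates the plaintext ``api_key`` of each LLM provider in ``agentkit.yaml``
into the encrypted :class:`SecretsStore`, and updates the ``api_key_encrypted`` /
``api_key_source`` fields in the configuration file.

Migration model (dual-write / dual-read window):
1. Read ``LLMConfig.providers`` from ``agentkit.yaml``.
2. For each provider, call ``ProviderConfig.migrate_to_secrets(store)``:
   - encrypt the plaintext key and write it to the SecretsStore;
   - once the read-back matches, clear the plaintext and mark ``api_key_source="secrets_store"``;
   - on partial failure, keep the plaintext and mark ``api_key_source="dual"`` for a later retry.
3. Write the updated ``ProviderConfig`` back to the YAML (plaintext cleared, encrypted column written).
4. Return the migration status of each provider.

Rollback steps (documented):
- set ``api_key_source`` back to ``"plaintext"`` in the YAML;
- fill the plaintext ``api_key`` back in (from a backup, or re-injected from the KMS);
- the ``api_key_encrypted`` column may be kept; it does not affect the plaintext read path;
- restart the service.

ponytail: wiring of the CLI command (``agentkit llm migrate-keys``) is deferred.
The ``migrate_api_keys_to_secrets`` function in this module can already be called
directly from a CLI or an ops script.
"""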
from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str, Any]]:
    """Migrate plaintext provider API keys in ``agentkit.yaml`` to the SecretsStore.

    Flow:
        1. Load the YAML and parse its ``llm`` section with
           ``LLMConfig.from_dict``.
        2. Create a ``SecretsStore()`` with its default construction.
        3. Run ``ProviderConfig.migrate_to_secrets`` for every provider that
           still has a plaintext key.
        4. Write back only the three key-related fields (``api_key``,
           ``api_key_encrypted``, ``api_key_source``) of providers whose state
           changed. Every other key in the YAML, including provider options
           this function does not know about, is left exactly as loaded.
        5. Return a per-provider status report.

    Safety measures:
        * Secret-key collisions: the store key is derived from the provider
          *type* only, so two providers of the same type would overwrite each
          other's secret. Such providers are reported as ``"failed"`` and
          left untouched instead of being migrated.
        * The file is rewritten via a temporary file plus ``os.replace`` so
          an interrupted write cannot leave a truncated config behind.
        * Nothing is written when no provider changed.

    Note: this function calls ``asyncio.run`` and therefore must not be
    invoked from inside a running event loop.

    Args:
        config_path: Path to ``agentkit.yaml``.

    Returns:
        ``{provider_name: {"status": ..., "source": ..., "error"?: ...}}``
        where ``status`` is one of ``"migrated"``, ``"partial"`` (encrypted
        copy written but not verified; plaintext kept), ``"skipped"`` or
        ``"failed"``.
    """
    import asyncio
    import contextlib
    import json
    import os
    import shutil
    import tempfile

    import yaml

    from agentkit.channels.secrets import SecretsStore
    from agentkit.llm.config import LLMConfig

    config_path = Path(config_path)
    raw = yaml.safe_load(config_path.read_text(encoding="utf-8")) or {}
    # ``llm:`` may be present but empty (None) in the YAML.
    llm_section = raw.get("llm") or {}
    llm_config = LLMConfig.from_dict(llm_section)
    providers = llm_config.providers

    store = SecretsStore()

    def _stored_secret_key(pconf: Any) -> str | None:
        """Return the store key recorded in ``api_key_encrypted``, if readable."""
        if not pconf.api_key_encrypted:
            return None
        try:
            decoded = json.loads(pconf.api_key_encrypted)
        except ValueError:
            return None
        key = decoded.get("key") if isinstance(decoded, dict) else None
        return key or None

    # Which providers use (or are about to use) each store key.
    owners: dict[str, set[str]] = {}
    for name, pconf in providers.items():
        if pconf.api_key:
            owners.setdefault(pconf._secret_key_for_type(), set()).add(name)
        stored_key = _stored_secret_key(pconf)
        if stored_key:
            owners.setdefault(stored_key, set()).add(name)

    def _key_state(pconf: Any) -> tuple[Any, Any, Any]:
        return (pconf.api_key, pconf.api_key_encrypted, pconf.api_key_source)

    before = {name: _key_state(pconf) for name, pconf in providers.items()}

    async def _run() -> dict[str, dict[str, Any]]:
        report: dict[str, dict[str, Any]] = {}
        for name, pconf in providers.items():
            if pconf.api_key_source == "secrets_store" and not pconf.api_key:
                report[name] = {"status": "skipped", "source": pconf.api_key_source}
                continue
            if not pconf.api_key:
                report[name] = {
                    "status": "skipped",
                    "source": pconf.api_key_source,
                    "error": "no plaintext api_key to migrate",
                }
                continue

            secret_key = pconf._secret_key_for_type()
            rivals = sorted(owners.get(secret_key, set()) - {name})
            if rivals:
                # Migrating would overwrite (or be overwritten by) another
                # provider's secret; keep the plaintext and report instead.
                report[name] = {
                    "status": "failed",
                    "source": pconf.api_key_source,
                    "error": (
                        f"secret key {secret_key!r} is shared with provider(s) "
                        f"{', '.join(rivals)}; migration would overwrite a stored key"
                    ),
                }
                continue

            try:
                await pconf.migrate_to_secrets(store)
                report[name] = {
                    "status": "migrated" if pconf.api_key_source == "secrets_store" else "partial",
                    "source": pconf.api_key_source,
                }
            except Exception as e:
                report[name] = {
                    "status": "failed",
                    "source": pconf.api_key_source,
                    "error": str(e),
                }
        return report

    report = asyncio.run(_run())

    changed = [name for name, pconf in providers.items() if _key_state(pconf) != before[name]]
    if not changed:
        return report

    raw_providers = llm_section.get("providers")
    if not isinstance(raw_providers, dict):
        return report
    for name in changed:
        target = raw_providers.get(name)
        if not isinstance(target, dict):
            continue
        pconf = providers[name]
        target["api_key"] = pconf.api_key
        target["api_key_encrypted"] = pconf.api_key_encrypted
        target["api_key_source"] = pconf.api_key_source

    # Render first so a serialization error cannot touch the file on disk.
    rendered = yaml.safe_dump(raw, allow_unicode=True, sort_keys=False)
    fd, tmp_name = tempfile.mkstemp(
        dir=str(config_path.parent),
        prefix=config_path.name + ".",
        suffix=".tmp",
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as handle:
            handle.write(rendered)
        # mkstemp creates the file 0600; keep the original file's mode bits.
        shutil.copymode(config_path, tmp_name)
        os.replace(tmp_name, config_path)
    except BaseException:
        with contextlib.suppress(OSError):
            os.unlink(tmp_name)
        raise
    return report
# ponytail: CLI wiring deferred — 把本函数接到 ``agentkit llm migrate-keys``
# Typer 命令时,只需在 cli/ 下新增 thin wrapper 调用本函数并打印 report。

View File

@ -3,6 +3,10 @@
from agentkit.llm.providers.anthropic import AnthropicProvider from agentkit.llm.providers.anthropic import AnthropicProvider
from agentkit.llm.providers.doubao import DoubaoProvider from agentkit.llm.providers.doubao import DoubaoProvider
from agentkit.llm.providers.gemini import GeminiProvider from agentkit.llm.providers.gemini import GeminiProvider
from agentkit.llm.providers.litellm_provider import (
LitellmProvider,
create_litellm_provider,
)
from agentkit.llm.providers.openai import OpenAICompatibleProvider from agentkit.llm.providers.openai import OpenAICompatibleProvider
from agentkit.llm.providers.tracker import UsageSummary, UsageTracker from agentkit.llm.providers.tracker import UsageSummary, UsageTracker
from agentkit.llm.providers.usage_store import UsageRecord from agentkit.llm.providers.usage_store import UsageRecord
@ -13,10 +17,12 @@ __all__ = [
"AnthropicProvider", "AnthropicProvider",
"DoubaoProvider", "DoubaoProvider",
"GeminiProvider", "GeminiProvider",
"LitellmProvider",
"OpenAICompatibleProvider", "OpenAICompatibleProvider",
"UsageRecord", "UsageRecord",
"UsageSummary", "UsageSummary",
"UsageTracker", "UsageTracker",
"WenxinProvider", "WenxinProvider",
"YuanbaoProvider", "YuanbaoProvider",
"create_litellm_provider",
] ]

View File

@ -0,0 +1,384 @@
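"""U15 — LiteLLM unified provider adapter layer.

Uses LiteLLM's unified ``acompletion()`` interface to replace the 6 direct-API
provider adapters (OpenAI / Anthropic / Gemini / Doubao / Wenxin / Yuanbao).
LiteLLM handles the differences between the vendor APIs internally (message
format, tool_calls format, streaming SSE protocol); this module is only
responsible for:
1. translating an ``LLMRequest`` into ``acompletion`` kwargs for LiteLLM;
2. translating the LiteLLM response (OpenAI ChatCompletion format) back into
   ``LLMResponse`` / ``StreamChunk``;
3. wrapping LiteLLM exceptions in ``LLMProviderError`` so the fallback chain
   keeps working.

Design trade-offs (ponytail):
- The in-house fallback / usage tracking / per-department quota all stay in
  ``LLMGateway``; this provider does not re-implement them. LiteLLM's own
  fallback / retry is disabled here (``num_retries=0``) so it does not stack on
  top of the gateway fallback and amplify timeouts.
- ``wenxin`` has no native support in LiteLLM, so it falls back to the
  ``openai/`` prefix plus a custom ``api_base`` (the Qianfan OpenAI-compatible
  endpoint). Upgrade path: switch to a ``wenxin/`` prefix once LiteLLM supports
  it upstream.
- The 6 old provider classes and ``OpenAICompatibleProvider`` are kept as dead
  code for one release to make rollback easy; new code is constructed through
  the ``create_litellm_provider`` factory.
"""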
from __future__ import annotations
import inspect
import json
import logging
import time
from collections.abc import AsyncGenerator
from typing import Any
from agentkit.core.exceptions import LLMProviderError
from agentkit.llm.protocol import (
LLMProvider,
LLMRequest,
LLMResponse,
StreamChunk,
TokenUsage,
ToolCall,
)
logger = logging.getLogger(__name__)
# provider_type → LiteLLM model 前缀映射。
# LiteLLM 通过 model 字符串前缀路由到对应 SDK``openai/gpt-4o``、
# ``anthropic/claude-...``、``gemini/gemini-...``、``volcengine/...``、
# ``hunyuan/...``。未知类型回退到 ``openai/``(最大兼容性)。
_PROVIDER_TYPE_TO_PREFIX: dict[str, str] = {
"openai": "openai/",
"anthropic": "anthropic/",
"gemini": "gemini/",
# 豆包由火山引擎提供LiteLLM 前缀为 volcengine/
"doubao": "volcengine/",
# ponytail: LiteLLM 暂无 wenxin 原生支持,回退到 openai/ + 自定义 api_base
# (千帆 OpenAI 兼容端点。ceiling: wenxin 专属参数(如 AK/SK 鉴权)不支持;
# 升级路径:上游支持后改 "wenxin/"。
"wenxin": "openai/",
# 腾讯混元
"yuanbao": "hunyuan/",
}
def _model_prefix_for(provider_type: str) -> str:
"""根据 provider_type 返回 LiteLLM model 前缀;未知类型回退到 openai/。"""
return _PROVIDER_TYPE_TO_PREFIX.get(provider_type, "openai/")
class LitellmProvider(LLMProvider):
    """Unified provider built on LiteLLM's ``acompletion``.

    One instance corresponds to one provider configuration (api_key +
    base_url + provider_type). ``chat`` / ``chat_stream`` forward an
    ``LLMRequest`` to LiteLLM and translate the response.

    The class creates no HTTP client of its own; every call goes through
    ``litellm.acompletion``.
    """

    def __init__(
        self,
        model_prefix: str,
        api_key: str,
        base_url: str | None = None,
        provider_type: str = "openai",
        **default_kwargs: Any,
    ) -> None:
        """Store the provider settings.

        Args:
            model_prefix: Prefix prepended to every request's model name.
            api_key: API key sent with every request.
            base_url: Optional custom API base; an empty string counts as unset.
            provider_type: Label used when raising ``LLMProviderError``.
            **default_kwargs: Extra ``acompletion`` keyword arguments used as
                defaults; values derived from the request take precedence.
        """
        self._model_prefix = model_prefix
        self._api_key = api_key
        self._base_url = base_url or None  # empty string is treated as unset
        self._provider_type = provider_type
        self._default_kwargs: dict[str, Any] = dict(default_kwargs)

    async def chat(self, request: LLMRequest) -> LLMResponse:
        """Non-streaming chat: call ``litellm.acompletion`` and translate the result.

        Raises:
            LLMProviderError: wraps any exception raised by the LiteLLM call.
        """
        import litellm

        kwargs = self._build_kwargs(request, stream=False)
        start = time.monotonic()
        try:
            response = await litellm.acompletion(**kwargs)
        except Exception as e:
            # Wrap everything in one error type so the gateway fallback chain
            # can recognise provider failures.
            raise LLMProviderError(self._provider_type, str(e)) from e

        latency_ms = (time.monotonic() - start) * 1000
        return self._parse_response(response, request.model, latency_ms)

    async def chat_stream(self, request: LLMRequest) -> AsyncGenerator[StreamChunk, None]:
        """Streaming chat: call ``litellm.acompletion(stream=True)`` and translate chunks.

        Every upstream chunk is yielded as a non-final ``StreamChunk``. After
        the stream ends, one extra chunk with ``is_final=True`` is yielded,
        carrying the aggregated tool calls and the last usage seen. An empty
        stream therefore still produces that single final chunk.

        Async-generator safety: the project forbids ``return`` before the
        first ``yield``; this function has no early-exit branch.

        Raises:
            LLMProviderError: wraps any exception raised while opening or
                consuming the stream.
        """
        import litellm

        kwargs = self._build_kwargs(request, stream=True)
        accumulated_tool_calls: dict[int, dict[str, Any]] = {}
        final_usage: TokenUsage | None = None
        final_model: str = request.model

        try:
            # ``acompletion(stream=True)`` may hand back either an awaitable
            # that resolves to an async iterator, or (e.g. an ``async def`` +
            # ``yield`` test double) the async iterator itself, which cannot
            # be awaited. ``isawaitable`` covers both.
            raw = litellm.acompletion(**kwargs)
            stream = await raw if inspect.isawaitable(raw) else raw

            async for chunk in stream:
                parsed = self._parse_stream_chunk(
                    chunk,
                    request.model,
                    accumulated_tool_calls,
                )
                if parsed.usage is not None:
                    final_usage = parsed.usage
                if parsed.model:
                    final_model = parsed.model
                # Yield every chunk, including empty-content and usage-only
                # ones (some providers send usage in a trailing empty chunk).
                yield parsed

            # End of stream: emit the terminating chunk.
            yield StreamChunk(
                content="",
                model=final_model,
                tool_calls=self._finalize_tool_calls(accumulated_tool_calls),
                usage=final_usage,
                is_final=True,
            )
        except Exception as e:
            raise LLMProviderError(self._provider_type, str(e)) from e

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _build_kwargs(self, request: LLMRequest, *, stream: bool) -> dict[str, Any]:
        """Build the ``litellm.acompletion`` keyword arguments for ``request``.

        Constructor-time ``default_kwargs`` form the base layer; everything
        derived from the request or from this provider's own settings is
        applied on top, so a default can never replace the model, messages,
        api key, ``stream`` flag or a per-request timeout.
        """
        kwargs: dict[str, Any] = dict(self._default_kwargs)
        kwargs.update(
            {
                "model": f"{self._model_prefix}{request.model}",
                "messages": request.messages,
                "temperature": request.temperature,
                "max_tokens": request.max_tokens,
                "api_key": self._api_key,
                "stream": stream,
                # LiteLLM's own retry is disabled; retries and fallback are
                # the gateway's job.
                "num_retries": 0,
            }
        )
        if self._base_url:
            kwargs["api_base"] = self._base_url
        if request.tools:
            kwargs["tools"] = request.tools
            kwargs["tool_choice"] = request.tool_choice
        if request.timeout is not None:
            kwargs["timeout"] = request.timeout
        return kwargs

    def _parse_response(
        self,
        response: Any,
        request_model: str,
        latency_ms: float,
    ) -> LLMResponse:
        """Translate a LiteLLM response into an ``LLMResponse``.

        Reads ``response.choices[0].message.{content, tool_calls}``,
        ``response.usage`` and ``response.model``; every missing piece falls
        back to an empty value (or to ``request_model`` for the model name).
        """
        choices = getattr(response, "choices", None) or []
        content = ""
        tool_calls: list[ToolCall] = []

        if choices:
            message = getattr(choices[0], "message", None)
            if message is not None:
                content = getattr(message, "content", None) or ""
                raw_tool_calls = getattr(message, "tool_calls", None)
                if raw_tool_calls:
                    tool_calls = _parse_tool_calls(raw_tool_calls)

        usage_obj = getattr(response, "usage", None)
        usage = _parse_usage(usage_obj) if usage_obj is not None else TokenUsage()

        model_name = getattr(response, "model", None) or request_model

        return LLMResponse(
            content=content,
            model=model_name,
            usage=usage,
            tool_calls=tool_calls,
            latency_ms=latency_ms,
        )

    def _parse_stream_chunk(
        self,
        chunk: Any,
        request_model: str,
        accumulated_tool_calls: dict[int, dict[str, Any]],
    ) -> StreamChunk:
        """Translate one streaming chunk into a non-final ``StreamChunk``.

        Tool-call fragments found in the chunk's delta are merged into
        ``accumulated_tool_calls`` (mutated in place); the returned chunk
        itself always carries an empty ``tool_calls`` list.
        """
        choices = getattr(chunk, "choices", None) or []
        content = ""
        model_name = getattr(chunk, "model", None) or request_model
        usage: TokenUsage | None = None

        if choices:
            delta = getattr(choices[0], "delta", None)
            if delta is not None:
                content = getattr(delta, "content", None) or ""
                raw_tool_calls = getattr(delta, "tool_calls", None)
                if raw_tool_calls:
                    _accumulate_stream_tool_calls(raw_tool_calls, accumulated_tool_calls)

        # Some providers attach usage to the last chunk.
        usage_obj = getattr(chunk, "usage", None)
        if usage_obj is not None:
            usage = _parse_usage(usage_obj)

        return StreamChunk(
            content=content,
            model=model_name,
            tool_calls=[],  # aggregated tool calls go out with the final chunk
            usage=usage,
            is_final=False,
        )

    def _finalize_tool_calls(
        self,
        accumulated: dict[int, dict[str, Any]],
    ) -> list[ToolCall]:
        """Convert the accumulated streaming fragments into ``ToolCall`` objects.

        Entries are emitted in ascending index order. The concatenated
        argument string is parsed as JSON; an empty string becomes ``{}`` and
        unparseable text is preserved as ``{"raw": <text>}``.
        """
        tool_calls: list[ToolCall] = []
        for idx in sorted(accumulated.keys()):
            tc_data = accumulated[idx]
            args_str = tc_data.get("arguments_str", "")
            try:
                arguments = json.loads(args_str) if args_str else {}
            except json.JSONDecodeError:
                arguments = {"raw": args_str}
            tool_calls.append(
                ToolCall(
                    id=tc_data.get("id", ""),
                    name=tc_data.get("name", ""),
                    arguments=arguments,
                )
            )
        return tool_calls
# ----------------------------------------------------------------------
# 响应解析辅助(模块级,便于测试 mock
# ----------------------------------------------------------------------
def _parse_tool_calls(raw_tool_calls: Any) -> list[ToolCall]:
"""解析非流式响应的 tool_callsOpenAI 格式 list[ChoiceMessageToolCall])。"""
result: list[ToolCall] = []
for tc in raw_tool_calls:
# LiteLLM 返回的对象有 .id / .function.{name,arguments}
tc_id = getattr(tc, "id", "") or ""
func = getattr(tc, "function", None)
if func is None:
continue
name = getattr(func, "name", "") or ""
args = getattr(func, "arguments", "{}")
if isinstance(args, str):
try:
args_dict = json.loads(args) if args else {}
except json.JSONDecodeError:
args_dict = {"raw": args}
elif isinstance(args, dict):
args_dict = args
else:
args_dict = {"raw": str(args)}
result.append(ToolCall(id=tc_id, name=name, arguments=args_dict))
return result
def _parse_usage(usage_obj: Any) -> TokenUsage:
    """Build a ``TokenUsage`` from a usage object or a plain dict.

    Attribute access is tried first; when ``prompt_tokens`` is not available
    as an attribute and the object is a ``dict``, both counts are read from
    the dict instead. Missing or falsy counts become ``0``.
    """
    prompt_tokens = getattr(usage_obj, "prompt_tokens", None)
    completion_tokens = getattr(usage_obj, "completion_tokens", None)

    needs_dict_lookup = prompt_tokens is None and isinstance(usage_obj, dict)
    if needs_dict_lookup:
        prompt_tokens = usage_obj.get("prompt_tokens", 0)
        completion_tokens = usage_obj.get("completion_tokens", 0)

    prompt_count = int(prompt_tokens or 0)
    completion_count = int(completion_tokens or 0)
    return TokenUsage(
        prompt_tokens=prompt_count,
        completion_tokens=completion_count,
    )
def _accumulate_stream_tool_calls(
raw_tool_calls: Any,
accumulated: dict[int, dict[str, Any]],
) -> None:
"""累计流式 chunk 里的 tool_calls 片段OpenAI delta.tool_calls 格式)。
每个 delta tool_call index/id/function.{name,arguments}arguments
是分片字符串需跨 chunk 拼接
"""
for tc in raw_tool_calls:
idx = getattr(tc, "index", 0) or 0
if idx not in accumulated:
accumulated[idx] = {
"id": "",
"name": "",
"arguments_str": "",
}
tc_id = getattr(tc, "id", None)
if tc_id:
accumulated[idx]["id"] = tc_id
func = getattr(tc, "function", None)
if func is not None:
name = getattr(func, "name", None)
if name:
accumulated[idx]["name"] = name
args_fragment = getattr(func, "arguments", None)
if args_fragment:
accumulated[idx]["arguments_str"] += args_fragment
# ----------------------------------------------------------------------
# 工厂函数
# ----------------------------------------------------------------------
def create_litellm_provider(
    provider_type: str,
    api_key: str,
    base_url: str | None = None,
    **kwargs: Any,
) -> LitellmProvider:
    """Create a ``LitellmProvider`` for the given provider type.

    Args:
        provider_type: Provider type string from the configuration
            ("openai" | "anthropic" | "gemini" | "doubao" | "wenxin" |
            "yuanbao"); the model prefix is resolved by ``_model_prefix_for``.
        api_key: The provider's API key.
        base_url: Optional custom API base URL.
        **kwargs: Extra default arguments passed through to ``LitellmProvider``.

    Returns:
        A ``LitellmProvider`` configured with the resolved model prefix.
    """
    return LitellmProvider(
        model_prefix=_model_prefix_for(provider_type),
        api_key=api_key,
        base_url=base_url,
        provider_type=provider_type,
        **kwargs,
    )

View File

@ -11,9 +11,13 @@ from fastapi.middleware.cors import CORSMiddleware
from agentkit.core.agent_pool import AgentPool from agentkit.core.agent_pool import AgentPool
from agentkit.core.event_queue import EventQueue, SubmissionQueue from agentkit.core.event_queue import EventQueue, SubmissionQueue
from agentkit.llm.gateway import LLMGateway from agentkit.llm.gateway import LLMGateway
from agentkit.llm.providers.anthropic import AnthropicProvider
from agentkit.llm.providers.gemini import GeminiProvider # U15: 旧 provider 类保留导入用于向后兼容 / 回滚_create_provider 已切到 LitellmProvider
from agentkit.llm.providers.openai import OpenAICompatibleProvider # 不要删除 — 下个 release 稳定后再清理。
from agentkit.llm.providers.anthropic import AnthropicProvider # noqa: F401
from agentkit.llm.providers.gemini import GeminiProvider # noqa: F401
from agentkit.llm.providers.litellm_provider import create_litellm_provider
from agentkit.llm.providers.openai import OpenAICompatibleProvider # noqa: F401
from agentkit.mcp.manager import MCPManager from agentkit.mcp.manager import MCPManager
from agentkit.quality.gate import QualityGate from agentkit.quality.gate import QualityGate
from agentkit.quality.output import OutputStandardizer from agentkit.quality.output import OutputStandardizer
@ -83,7 +87,8 @@ def _build_llm_gateway(config: ServerConfig) -> LLMGateway:
gateway = LLMGateway(config=config.llm_config, usage_store=usage_store) gateway = LLMGateway(config=config.llm_config, usage_store=usage_store)
for name, pconf in config.llm_config.providers.items(): for name, pconf in config.llm_config.providers.items():
if not pconf.api_key: # U15: 跳过既无 plaintext api_key 又无加密列的 provider
if not pconf.api_key and not pconf.api_key_encrypted:
continue # Skip providers without API keys continue # Skip providers without API keys
try: try:
provider = _create_provider(name, pconf) provider = _create_provider(name, pconf)
@ -97,45 +102,38 @@ def _build_llm_gateway(config: ServerConfig) -> LLMGateway:
def _create_provider(name: str, pconf) -> object:
    """Create an LLM provider instance from a ProviderConfig.

    U15: every provider is now built as a ``LitellmProvider`` through
    ``create_litellm_provider``. The old AnthropicProvider / GeminiProvider /
    OpenAICompatibleProvider classes stay imported for backward
    compatibility / rollback only.

    ponytail: decrypting an encrypted key (``aget_api_key``) needs an async
    context, and this function is synchronous. During the dual-read window
    the plaintext ``api_key`` is still available, so it is used directly.
    Once a provider is fully migrated (plaintext cleared), an async
    pre-decrypt step must run before this function. Upgrade path: make
    ``_build_llm_gateway`` async and call it from the lifespan handler.

    NOTE(review): ``pconf.timeout`` / ``pconf.max_tokens`` are not forwarded
    to the LiteLLM provider here — confirm that this is intended.

    Args:
        name: Provider name from the configuration (used in messages only).
        pconf: The provider's ``ProviderConfig``.

    Raises:
        ValueError: if only an encrypted key is available, or if a provider
            type that is routed through the OpenAI-compatible prefix has no
            ``base_url``.
    """
    api_key = pconf.api_key
    if not api_key and pconf.api_key_encrypted:
        # Plaintext is cleared but an encrypted value exists: it cannot be
        # decrypted synchronously, so this provider is not registered here.
        logger.warning(
            f"Provider '{name}' has encrypted key but no plaintext fallback "
            f"— skipped (run async pre-decrypt at startup)"
        )
        raise ValueError(
            f"Provider '{name}' api_key not available synchronously "
            f"(encrypted only). Run async pre-decrypt first."
        )

    base_url = pconf.base_url or None

    # Types with their own LiteLLM prefix can run without a base_url. Every
    # other type (including "wenxin") is sent through the "openai/" prefix;
    # without an explicit base_url its key would go to the default OpenAI
    # endpoint, so the pre-U15 guard is kept for those types.
    natively_routed = ("openai", "anthropic", "gemini", "doubao", "yuanbao")
    if base_url is None and pconf.type not in natively_routed:
        raise ValueError(
            f"Provider '{name}' is missing base_url. "
            f"OpenAI-compatible providers require an explicit base_url in config."
        )

    return create_litellm_provider(
        provider_type=pconf.type,
        api_key=api_key,
        base_url=base_url,
    )
def _build_skill_registry(config: ServerConfig) -> SkillRegistry: def _build_skill_registry(config: ServerConfig) -> SkillRegistry:
"""Build SkillRegistry from ServerConfig, loading all skill configs.""" """Build SkillRegistry from ServerConfig, loading all skill configs."""

View File

@ -0,0 +1,287 @@
"""U15 — ProviderConfig API Key 加密迁移测试。
覆盖
- get_api_key plaintext 回退
- aget_api_key + secrets_store 解密
- aget_api_key dual-read 回退解密失败 plaintext
- migrate_to_secrets 完整流程
- migrate_to_secrets 幂等
- LLMConfig.from_dict 解析新字段
- LLMConfig.from_dict 缺省新字段
"""
from __future__ import annotations
import json
from unittest.mock import AsyncMock
from agentkit.channels.secrets import SecretsStore
from agentkit.llm.config import LLMConfig, ProviderConfig
# ----------------------------------------------------------------------
# 15: get_api_key plaintext
# ----------------------------------------------------------------------
def test_get_api_key_plaintext_no_store():
    """Without a secrets store, ``get_api_key`` returns the plaintext key."""
    config = ProviderConfig(api_key="sk-xxx", base_url="", type="openai")
    resolved = config.get_api_key(None)
    assert resolved == "sk-xxx"
def test_get_api_key_plaintext_with_store_sync():
    """A sync call with a store and an encrypted key still returns plaintext.

    ``api_key_encrypted`` is set so the "encrypted key + store supplied"
    branch of ``get_api_key`` is actually exercised; without it the test
    only covers the plain return.
    """
    pconf = ProviderConfig(
        api_key="sk-xxx",
        base_url="",
        type="openai",
        api_key_encrypted='{"key":"k","value":"v","nonce":"n","salt":"s"}',
        api_key_source="dual",
    )
    # The sync path never touches the store, so any non-None object will do.
    store = object()
    assert pconf.get_api_key(store) == "sk-xxx"  # type: ignore[arg-type]
# ----------------------------------------------------------------------
# 16: aget_api_key + secrets_store 解密成功
# ----------------------------------------------------------------------
async def test_aget_api_key_decrypts_from_store():
    """Encrypted column + secrets store: the decrypted key is returned."""
    store = SecretsStore(master_key=b"x" * 32)
    secret_key = "llm:provider:openai:api_key"

    # Store the secret once and build the encrypted column from the entry
    # the store returned (the real encryption path).
    entry = await store.set_secret(secret_key, "sk-decrypted")
    pconf = ProviderConfig(
        api_key="",
        base_url="",
        type="openai",
        api_key_encrypted=ProviderConfig._encode_secret_entry(entry, secret_key),
        api_key_source="secrets_store",
    )

    result = await pconf.aget_api_key(store)
    assert result == "sk-decrypted"
# ----------------------------------------------------------------------
# 17: aget_api_key dual-read 回退
# ----------------------------------------------------------------------
async def test_aget_api_key_dual_read_fallback_to_plaintext():
    """Encrypted column present but the store yields nothing: plaintext wins."""
    fake_entry = {
        "key": "llm:provider:openai:api_key",
        "value": "fake",
        "nonce": "fake",
        "salt": "fake",
        "key_id": "default",
        "created_at": "",
        "updated_at": "",
    }
    # The store answers None, standing in for "key missing".
    store = AsyncMock(spec=SecretsStore)
    store.get_secret = AsyncMock(return_value=None)

    pconf = ProviderConfig(
        api_key="sk-plaintext",
        base_url="",
        type="openai",
        api_key_encrypted=json.dumps(fake_entry),
        api_key_source="dual",
    )

    assert await pconf.aget_api_key(store) == "sk-plaintext"
async def test_aget_api_key_decrypt_exception_falls_back():
    """When the store raises during the read, plaintext is returned."""
    fake_entry = {
        "key": "llm:provider:openai:api_key",
        "value": "fake",
        "nonce": "fake",
        "salt": "fake",
        "key_id": "default",
        "created_at": "",
        "updated_at": "",
    }
    store = AsyncMock(spec=SecretsStore)
    store.get_secret = AsyncMock(side_effect=RuntimeError("decrypt failed"))

    pconf = ProviderConfig(
        api_key="sk-plaintext",
        base_url="",
        type="openai",
        api_key_encrypted=json.dumps(fake_entry),
        api_key_source="dual",
    )

    assert await pconf.aget_api_key(store) == "sk-plaintext"
async def test_aget_api_key_no_encrypted_returns_plaintext():
    """With no encrypted column, the plaintext key is returned directly."""
    config = ProviderConfig(api_key="sk-plain", base_url="", type="openai")
    secrets = SecretsStore(master_key=b"x" * 32)
    resolved = await config.aget_api_key(secrets)
    assert resolved == "sk-plain"
# ----------------------------------------------------------------------
# 18: migrate_to_secrets 完整流程
# ----------------------------------------------------------------------
async def test_migrate_to_secrets_full_flow():
    """Plaintext to secrets store: encrypt, verify, then clear the plaintext."""
    secrets = SecretsStore(master_key=b"x" * 32)
    config = ProviderConfig(api_key="sk-secret", base_url="", type="openai")

    await config.migrate_to_secrets(secrets)

    # Plaintext cleared, encrypted column written, source flipped.
    assert config.api_key == ""
    assert config.api_key_encrypted is not None
    assert config.api_key_source == "secrets_store"

    # The original key is readable again through the async accessor.
    round_tripped = await config.aget_api_key(secrets)
    assert round_tripped == "sk-secret"
# ----------------------------------------------------------------------
# 19: migrate_to_secrets 幂等
# ----------------------------------------------------------------------
async def test_migrate_to_secrets_idempotent():
    """Calling ``migrate_to_secrets`` on a migrated config changes nothing."""
    secrets = SecretsStore(master_key=b"x" * 32)
    config = ProviderConfig(api_key="sk-secret", base_url="", type="openai")

    await config.migrate_to_secrets(secrets)
    first_encrypted = config.api_key_encrypted
    assert config.api_key_source == "secrets_store"

    # Second call must be a no-op.
    await config.migrate_to_secrets(secrets)
    assert config.api_key_encrypted == first_encrypted
    assert config.api_key_source == "secrets_store"
    assert config.api_key == ""
async def test_migrate_to_secrets_skips_empty_plaintext():
    """An empty plaintext key means there is nothing to migrate."""
    secrets = SecretsStore(master_key=b"x" * 32)
    config = ProviderConfig(api_key="", base_url="", type="openai")

    await config.migrate_to_secrets(secrets)

    assert config.api_key == ""
    assert config.api_key_encrypted is None
    assert config.api_key_source == "plaintext"
# ----------------------------------------------------------------------
# 20-21: LLMConfig.from_dict 解析新字段
# ----------------------------------------------------------------------
def test_from_dict_parses_new_encrypted_fields():
    """``LLMConfig.from_dict`` carries over api_key_encrypted / api_key_source."""
    encrypted_blob = '{"key":"k","value":"v","nonce":"n","salt":"s"}'
    openai_section = {
        "api_key": "",
        "base_url": "https://api.openai.com/v1",
        "type": "openai",
        "api_key_encrypted": encrypted_blob,
        "api_key_source": "secrets_store",
    }

    parsed = LLMConfig.from_dict({"providers": {"openai": openai_section}})
    provider_cfg = parsed.providers["openai"]

    assert provider_cfg.api_key_encrypted == encrypted_blob
    assert provider_cfg.api_key_source == "secrets_store"
def test_from_dict_defaults_new_fields_to_plaintext():
    """Omitting the new fields in the input dict yields the plaintext defaults."""
    openai_section = {
        "api_key": "sk-test",
        "base_url": "https://api.openai.com/v1",
        "type": "openai",
    }

    parsed = LLMConfig.from_dict({"providers": {"openai": openai_section}})
    provider_cfg = parsed.providers["openai"]

    assert provider_cfg.api_key_encrypted is None
    assert provider_cfg.api_key_source == "plaintext"
def test_from_dict_dual_source_during_migration():
    """With ``api_key_source="dual"`` both the plaintext and encrypted fields are kept."""
    openai_section = {
        "api_key": "sk-plaintext",
        "base_url": "",
        "type": "openai",
        "api_key_encrypted": '{"key":"k","value":"v","nonce":"n","salt":"s"}',
        "api_key_source": "dual",
    }

    parsed = LLMConfig.from_dict({"providers": {"openai": openai_section}})
    provider_cfg = parsed.providers["openai"]

    assert provider_cfg.api_key == "sk-plaintext"
    assert provider_cfg.api_key_encrypted is not None
    assert provider_cfg.api_key_source == "dual"
# ----------------------------------------------------------------------
# Extra: _secret_key_for_type namespacing
# ----------------------------------------------------------------------
def test_secret_key_namespaced_by_type():
    """The secret key name embeds the provider type, so types do not collide."""
    expected_keys = {
        "anthropic": "llm:provider:anthropic:api_key",
        "gemini": "llm:provider:gemini:api_key",
    }
    for provider_type, secret_key in expected_keys.items():
        provider_cfg = ProviderConfig(api_key="", base_url="", type=provider_type)
        assert provider_cfg._secret_key_for_type() == secret_key
# ----------------------------------------------------------------------
# Extra: encode/decode SecretEntry round trip
# ----------------------------------------------------------------------
async def test_encode_decode_secret_entry_roundtrip():
    """Encoding then decoding a SecretEntry keeps key, value, nonce and salt."""
    secret_name = "llm:provider:openai:api_key"
    secrets = SecretsStore(master_key=b"x" * 32)
    original = await secrets.set_secret(secret_name, "sk-xxx")

    serialized = ProviderConfig._encode_secret_entry(original, secret_name)
    restored = ProviderConfig._decode_secret_entry(serialized)

    assert restored.key == secret_name
    assert restored.value == original.value
    assert restored.nonce == original.nonce
    assert restored.salt == original.salt

View File

@ -0,0 +1,543 @@
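"""U15 — unit tests for LitellmProvider.

``litellm.acompletion`` is mocked (with ``unittest.mock.AsyncMock`` or small
async stand-ins) to cover:
- the six provider types (openai/anthropic/gemini/doubao/wenxin/yuanbao) called through LiteLLM
- fallback to the ``openai/`` prefix for an unknown provider_type
- pass-through of a custom base_url
- pass-through of tools and parsing of tool_calls in the response
- iteration over streamed chunks plus the terminating chunk
- handling of a stream that yields no chunks
- LiteLLM exceptions wrapped as LLMProviderError
- non-negative latency_ms
"""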
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
from agentkit.core.exceptions import LLMProviderError
from agentkit.llm.protocol import LLMRequest, LLMResponse
from agentkit.llm.providers.litellm_provider import (
LitellmProvider,
_model_prefix_for,
create_litellm_provider,
)
# ----------------------------------------------------------------------
# 测试辅助:构造 OpenAI 格式的 fake response / chunk
# ----------------------------------------------------------------------
def _fake_response(
    content: str = "Hello!",
    model: str = "gpt-4o-mini",
    prompt_tokens: int = 10,
    completion_tokens: int = 5,
    tool_calls: list | None = None,
) -> SimpleNamespace:
    """Build a fake non-streaming response shaped like an OpenAI ChatCompletion.

    The result exposes ``choices[0].message.content``,
    ``choices[0].message.tool_calls``, ``usage.prompt_tokens``,
    ``usage.completion_tokens`` and ``model``.
    """
    only_choice = SimpleNamespace(
        message=SimpleNamespace(content=content, tool_calls=tool_calls),
    )
    token_usage = SimpleNamespace(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    return SimpleNamespace(choices=[only_choice], usage=token_usage, model=model)
def _fake_tool_call(
    tc_id: str = "call_1",
    name: str = "get_weather",
    arguments: str = '{"city": "Beijing"}',
) -> SimpleNamespace:
    """Build a fake non-streaming tool_call object with ``id`` and ``function``."""
    called_function = SimpleNamespace(name=name, arguments=arguments)
    return SimpleNamespace(id=tc_id, function=called_function)
def _fake_stream_chunk(
    content: str = "",
    model: str = "gpt-4o-mini",
    tool_calls_delta: list | None = None,
    usage: SimpleNamespace | None = None,
) -> SimpleNamespace:
    """Build a fake streaming chunk shaped like an OpenAI ChatCompletionChunk.

    The result exposes ``choices[0].delta.content``,
    ``choices[0].delta.tool_calls``, ``model`` and ``usage``.
    """
    only_choice = SimpleNamespace(
        delta=SimpleNamespace(content=content, tool_calls=tool_calls_delta),
    )
    return SimpleNamespace(choices=[only_choice], model=model, usage=usage)
def _fake_stream_tool_call_delta(
    index: int = 0,
    tc_id: str | None = "call_1",
    name: str | None = "get_weather",
    arguments_fragment: str | None = None,
) -> SimpleNamespace:
    """Build one fake streaming tool_call delta fragment.

    The result exposes ``index``, ``id`` and ``function`` (with ``name`` and
    ``arguments``, the latter holding ``arguments_fragment``).
    """
    function_part = SimpleNamespace(name=name, arguments=arguments_fragment)
    return SimpleNamespace(index=index, id=tc_id, function=function_part)
# ----------------------------------------------------------------------
# 1-7: provider_type → model_prefix 映射 + 基本 chat
# ----------------------------------------------------------------------
@pytest.mark.parametrize(
    "provider_type, expected_prefix",
    [
        pytest.param("openai", "openai/", id="openai"),
        pytest.param("anthropic", "anthropic/", id="anthropic"),
        pytest.param("gemini", "gemini/", id="gemini"),
        pytest.param("doubao", "volcengine/", id="doubao"),
        # wenxin falls back to the openai/ prefix.
        pytest.param("wenxin", "openai/", id="wenxin"),
        pytest.param("yuanbao", "hunyuan/", id="yuanbao"),
        # Unknown provider types fall back to openai/ as well.
        pytest.param("unknown_type", "openai/", id="unknown"),
    ],
)
def test_model_prefix_mapping(provider_type: str, expected_prefix: str):
    """Check the provider_type -> LiteLLM model prefix mapping, including the fallback."""
    assert _model_prefix_for(provider_type) == expected_prefix
@pytest.mark.parametrize(
    "provider_type, expected_prefix, model_name",
    [
        pytest.param("openai", "openai/", "gpt-4o-mini", id="openai"),
        pytest.param("anthropic", "anthropic/", "claude-sonnet-4-20250514", id="anthropic"),
        pytest.param("gemini", "gemini/", "gemini-2.0-flash", id="gemini"),
        pytest.param("doubao", "volcengine/", "doubao-pro-32k", id="doubao"),
        pytest.param("wenxin", "openai/", "ernie-4.5-turbo-128k", id="wenxin"),
        pytest.param("yuanbao", "hunyuan/", "hunyuan-pro", id="yuanbao"),
    ],
)
async def test_chat_via_litellm_for_each_provider_type(
    provider_type: str,
    expected_prefix: str,
    model_name: str,
):
    """Run a chat through LiteLLM for each of the six provider types.

    Checks the prefix stored on the provider and that the mocked LiteLLM
    response is translated into an ``LLMResponse``.
    """
    provider = create_litellm_provider(provider_type, api_key="sk-test")
    assert provider._model_prefix == expected_prefix

    canned = _fake_response(content="ok", model=model_name)
    request = LLMRequest(
        messages=[{"role": "user", "content": "Hi"}],
        model=model_name,
    )
    with patch("litellm.acompletion", new=AsyncMock(return_value=canned)):
        response = await provider.chat(request)

    assert isinstance(response, LLMResponse)
    assert response.content == "ok"
    assert response.model == model_name
    # Token counts are the defaults baked into _fake_response.
    assert response.usage.prompt_tokens == 10
    assert response.usage.completion_tokens == 5
# ----------------------------------------------------------------------
# 8: 自定义 base_url 透传
# ----------------------------------------------------------------------
async def test_custom_base_url_passed_through():
    """A custom base_url must reach ``litellm.acompletion`` as ``api_base``."""
    provider = create_litellm_provider(
        "openai",
        api_key="sk-test",
        base_url="https://custom.api/v1",
    )
    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o-mini")
    acompletion_mock = AsyncMock(return_value=_fake_response())

    with patch("litellm.acompletion", new=acompletion_mock):
        await provider.chat(request)

    sent = acompletion_mock.call_args.kwargs
    assert sent["api_base"] == "https://custom.api/v1"
    assert sent["api_key"] == "sk-test"
    assert sent["model"] == "openai/gpt-4o-mini"
# ----------------------------------------------------------------------
# 9: tools 透传
# ----------------------------------------------------------------------
async def test_tools_passed_through():
    """``request.tools`` and ``tool_choice`` must be forwarded to ``litellm.acompletion``."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather",
            "parameters": {"type": "object", "properties": {}},
        },
    }
    tools = [weather_tool]
    request = LLMRequest(
        messages=[{"role": "user", "content": "weather?"}],
        model="gpt-4o",
        tools=tools,
        tool_choice="auto",
    )
    acompletion_mock = AsyncMock(return_value=_fake_response())

    with patch("litellm.acompletion", new=acompletion_mock):
        await provider.chat(request)

    sent = acompletion_mock.call_args.kwargs
    assert sent["tools"] == tools
    assert sent["tool_choice"] == "auto"
# ----------------------------------------------------------------------
# 10: tool_calls 响应解析
# ----------------------------------------------------------------------
async def test_tool_calls_in_response_parsed():
    """tool_calls in a non-streaming response end up in ``LLMResponse.tool_calls``."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    weather_call = _fake_tool_call(
        tc_id="call_abc",
        name="get_weather",
        arguments='{"city": "Beijing"}',
    )
    canned = _fake_response(content="", tool_calls=[weather_call])
    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")

    with patch("litellm.acompletion", new=AsyncMock(return_value=canned)):
        response = await provider.chat(request)

    assert response.has_tool_calls
    assert len(response.tool_calls) == 1
    parsed_call = response.tool_calls[0]
    assert parsed_call.id == "call_abc"
    assert parsed_call.name == "get_weather"
    # The JSON string of arguments is expected back as a dict.
    assert parsed_call.arguments == {"city": "Beijing"}
async def test_tool_calls_with_dict_arguments():
    """When ``tool_call.arguments`` is already a dict it is used as-is."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    dict_args_call = SimpleNamespace(
        id="call_1",
        function=SimpleNamespace(name="fn", arguments={"k": "v"}),
    )
    canned = _fake_response(tool_calls=[dict_args_call])
    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")

    with patch("litellm.acompletion", new=AsyncMock(return_value=canned)):
        response = await provider.chat(request)

    assert response.tool_calls[0].arguments == {"k": "v"}
# ----------------------------------------------------------------------
# 11: 流式 yields chunks
# ----------------------------------------------------------------------
async def test_streaming_yields_chunks():
    """Streaming yields non-final chunks followed by one ``is_final=True`` chunk."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    upstream_chunks = [
        _fake_stream_chunk(content="Hello"),
        _fake_stream_chunk(content=" world"),
        _fake_stream_chunk(
            usage=SimpleNamespace(prompt_tokens=8, completion_tokens=2),
        ),
    ]

    async def _replay_chunks(**_kwargs):
        # Stand-in for litellm.acompletion: replays the canned chunks in order.
        for item in upstream_chunks:
            yield item

    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o-mini")
    with patch("litellm.acompletion", new=_replay_chunks):
        received = [chunk async for chunk in provider.chat_stream(request)]

    # Three streamed chunks (is_final=False) plus one terminating chunk (is_final=True).
    assert len(received) == 4
    assert received[0].content == "Hello"
    assert received[1].content == " world"
    assert all(not item.is_final for item in received[:3])

    terminator = received[-1]
    assert terminator.is_final is True
    # The terminating chunk carries the aggregated usage.
    assert terminator.usage is not None
    assert terminator.usage.prompt_tokens == 8
    assert terminator.usage.completion_tokens == 2
async def test_streaming_aggregates_tool_calls():
    """Streamed tool_call fragments are merged into the terminating chunk."""
    provider = create_litellm_provider("openai", api_key="sk-test")

    # First fragment carries id + name and the start of the JSON arguments.
    opening_delta = _fake_stream_tool_call_delta(
        index=0,
        tc_id="call_1",
        name="get_weather",
        arguments_fragment='{"city":',
    )
    # Second fragment only continues the arguments for the same index.
    closing_delta = _fake_stream_tool_call_delta(
        index=0,
        tc_id=None,
        name=None,
        arguments_fragment=' "Beijing"}',
    )
    upstream_chunks = [
        _fake_stream_chunk(tool_calls_delta=[opening_delta]),
        _fake_stream_chunk(tool_calls_delta=[closing_delta]),
    ]

    async def _replay_chunks(**_kwargs):
        for item in upstream_chunks:
            yield item

    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
    with patch("litellm.acompletion", new=_replay_chunks):
        received = [chunk async for chunk in provider.chat_stream(request)]

    terminator = received[-1]
    assert terminator.is_final
    assert len(terminator.tool_calls) == 1
    merged_call = terminator.tool_calls[0]
    assert merged_call.id == "call_1"
    assert merged_call.name == "get_weather"
    assert merged_call.arguments == {"city": "Beijing"}
# ----------------------------------------------------------------------
# 12: 流式空消息处理
# ----------------------------------------------------------------------
async def test_streaming_empty_messages_handled():
    """A stream with no chunks still yields one empty ``is_final=True`` chunk."""
    provider = create_litellm_provider("openai", api_key="sk-test")

    async def _no_chunks(**_kwargs):
        # Async generator that finishes immediately without yielding anything.
        for _never in ():
            yield _never

    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
    with patch("litellm.acompletion", new=_no_chunks):
        received = [chunk async for chunk in provider.chat_stream(request)]

    # At least the terminating chunk must have been produced.
    assert len(received) >= 1
    assert received[-1].is_final is True
    assert received[-1].content == ""
# ----------------------------------------------------------------------
# 13: LiteLLM 异常 → LLMProviderError
# ----------------------------------------------------------------------
async def test_litellm_exception_raises_provider_error():
    """An exception from ``litellm.acompletion`` is wrapped in ``LLMProviderError``."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
    exploding_mock = AsyncMock(side_effect=RuntimeError("boom"))

    with patch("litellm.acompletion", new=exploding_mock), pytest.raises(
        LLMProviderError
    ) as exc_info:
        await provider.chat(request)

    # The original message must survive in the wrapped error.
    assert "boom" in str(exc_info.value)
async def test_litellm_stream_exception_raises_provider_error():
    """An exception raised while streaming is wrapped in ``LLMProviderError``."""
    provider = create_litellm_provider("openai", api_key="sk-test")

    async def _exploding_stream(**_kwargs):
        raise RuntimeError("stream boom")
        yield  # unreachable; only here to make this an async generator

    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
    with patch("litellm.acompletion", new=_exploding_stream), pytest.raises(
        LLMProviderError
    ) as exc_info:
        async for _ in provider.chat_stream(request):
            pass

    assert "stream boom" in str(exc_info.value)
# ----------------------------------------------------------------------
# 14: latency 非负
# ----------------------------------------------------------------------
async def test_latency_measured_non_negative():
    """``LLMResponse.latency_ms`` must never be negative."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
    acompletion_mock = AsyncMock(return_value=_fake_response())

    with patch("litellm.acompletion", new=acompletion_mock):
        response = await provider.chat(request)

    assert response.latency_ms >= 0
# ----------------------------------------------------------------------
# Extra: num_retries=0 disables LiteLLM's built-in retry
# ----------------------------------------------------------------------
async def test_litellm_num_retries_disabled():
    """``num_retries=0`` is passed so LiteLLM's own retry is off (the gateway fallback handles retries)."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
    acompletion_mock = AsyncMock(return_value=_fake_response())

    with patch("litellm.acompletion", new=acompletion_mock):
        await provider.chat(request)

    sent = acompletion_mock.call_args.kwargs
    assert sent["num_retries"] == 0
    # The non-streaming path must also ask for stream=False explicitly.
    assert sent["stream"] is False
async def test_streaming_passes_stream_true():
    """``chat_stream`` must call ``litellm.acompletion`` with ``stream=True``."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    seen_kwargs: dict = {}

    async def _record_kwargs(**kwargs):
        # Coroutine stand-in: remember the call kwargs, then hand back an
        # async generator that produces no chunks.
        seen_kwargs.update(kwargs)

        async def _empty():
            return
            yield  # unreachable; only here to make _empty an async generator

        return _empty()

    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
    with patch("litellm.acompletion", new=_record_kwargs):
        async for _ in provider.chat_stream(request):
            pass

    assert seen_kwargs["stream"] is True
# ----------------------------------------------------------------------
# Extra: timeout pass-through
# ----------------------------------------------------------------------
async def test_timeout_passed_through():
    """``request.timeout`` must be forwarded to ``litellm.acompletion``."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    request = LLMRequest(
        messages=[{"role": "user", "content": "Hi"}],
        model="gpt-4o",
        timeout=42.0,
    )
    acompletion_mock = AsyncMock(return_value=_fake_response())

    with patch("litellm.acompletion", new=acompletion_mock):
        await provider.chat(request)

    assert acompletion_mock.call_args.kwargs["timeout"] == 42.0
# ----------------------------------------------------------------------
# Extra: constructing LitellmProvider directly (without the factory)
# ----------------------------------------------------------------------
async def test_direct_litellm_provider_construction():
    """Building ``LitellmProvider(model_prefix=...)`` directly works like the factory."""
    provider = LitellmProvider(
        model_prefix="anthropic/",
        api_key="sk-test",
        provider_type="anthropic",
    )
    canned = _fake_response(content="hi", model="claude-3")

    # First call: the response content is translated.
    with patch("litellm.acompletion", new=AsyncMock(return_value=canned)):
        response = await provider.chat(
            LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="claude-3")
        )
    assert response.content == "hi"

    # Second call: the model prefix is prepended to the requested model name.
    acompletion_mock = AsyncMock(return_value=canned)
    with patch("litellm.acompletion", new=acompletion_mock):
        await provider.chat(
            LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="claude-3")
        )
    assert acompletion_mock.call_args.kwargs["model"] == "anthropic/claude-3"
# ----------------------------------------------------------------------
# Extra: tolerance for invalid JSON in tool_calls arguments
# ----------------------------------------------------------------------
async def test_tool_calls_invalid_json_arguments_wrapped():
    """Invalid JSON in ``tool_call.arguments`` is wrapped as ``{"raw": ...}`` instead of crashing."""
    provider = create_litellm_provider("openai", api_key="sk-test")
    broken_call = _fake_tool_call(
        tc_id="call_1",
        name="fn",
        arguments="not-valid-json{",
    )
    canned = _fake_response(tool_calls=[broken_call])
    request = LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")

    with patch("litellm.acompletion", new=AsyncMock(return_value=canned)):
        response = await provider.chat(request)

    assert response.tool_calls[0].arguments == {"raw": "not-valid-json{"}