feat(llm): U15 — LiteLLM unified provider + api_key encrypted secrets migration
This commit is contained in:
parent
13c516a54f
commit
069dbc22b1
|
|
@ -26,6 +26,8 @@ dependencies = [
|
||||||
"aiosqlite>=0.20",
|
"aiosqlite>=0.20",
|
||||||
# 加密 secrets store (U10 — 多端消息适配器 AES-256-GCM)
|
# 加密 secrets store (U10 — 多端消息适配器 AES-256-GCM)
|
||||||
"cryptography>=42.0",
|
"cryptography>=42.0",
|
||||||
|
# U15 — LiteLLM 统一 Provider 适配层(替换 6 个直接 API provider)
|
||||||
|
"litellm>=1.50",
|
||||||
# Calendar & schedule (RRULE expansion)
|
# Calendar & schedule (RRULE expansion)
|
||||||
"python-dateutil>=2.9",
|
"python-dateutil>=2.9",
|
||||||
# Calendar ICS import/export (U8)
|
# Calendar ICS import/export (U8)
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,17 @@
|
||||||
"""LLM Config - 配置加载"""
|
"""LLM Config - 配置加载"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
from agentkit.llm.retry import CircuitBreakerConfig, RetryConfig
|
from agentkit.llm.retry import CircuitBreakerConfig, RetryConfig
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from agentkit.channels.secrets import SecretsStore
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CacheConfig:
|
class CacheConfig:
|
||||||
|
|
@ -58,6 +65,140 @@ class ProviderConfig:
|
||||||
keepalive_expiry: float = 30.0 # httpx 保活连接过期时间(秒)
|
keepalive_expiry: float = 30.0 # httpx 保活连接过期时间(秒)
|
||||||
retry: RetryConfig | None = None
|
retry: RetryConfig | None = None
|
||||||
circuit_breaker: CircuitBreakerConfig | None = None
|
circuit_breaker: CircuitBreakerConfig | None = None
|
||||||
|
# U15 — API Key 加密迁移(plaintext → SecretsStore)。
|
||||||
|
# api_key_encrypted: JSON 编码的 SecretEntry(base64 nonce+salt+ciphertext)。
|
||||||
|
# 为 None 表示未迁移,仍走 plaintext api_key。
|
||||||
|
# api_key_source: 当前 key 来源标记,用于双写/双读迁移窗口。
|
||||||
|
# "plaintext" — 仅 plaintext 列;"secrets_store" — 仅加密列;
|
||||||
|
# "dual" — 双写窗口中(两列都有值,读时优先加密)。
|
||||||
|
api_key_encrypted: str | None = None
|
||||||
|
api_key_source: str = "plaintext"
|
||||||
|
|
||||||
|
def get_api_key(self, secrets_store: "SecretsStore | None" = None) -> str:
|
||||||
|
"""同步读取 API Key。
|
||||||
|
|
||||||
|
双读窗口优先级:``api_key_encrypted`` + ``secrets_store`` > plaintext。
|
||||||
|
若加密列存在但解密失败(store 为 None 或解密异常),回退到 plaintext
|
||||||
|
``api_key``,保证迁移期可用性。
|
||||||
|
|
||||||
|
注意:``SecretsStore.get_secret`` 是 async,本同步方法无法调用。
|
||||||
|
若 ``api_key_encrypted`` 已设置但 ``secrets_store`` 为 None,仍回退到
|
||||||
|
plaintext。需要解密时请用 ``aget_api_key``。
|
||||||
|
"""
|
||||||
|
if self.api_key_encrypted and secrets_store is not None:
|
||||||
|
# 同步上下文无法 await get_secret,调用方应改用 aget_api_key。
|
||||||
|
# 这里保持双读语义的回退:返回 plaintext。
|
||||||
|
logger.debug("get_api_key: encrypted key set but sync access — fallback to plaintext")
|
||||||
|
return self.api_key
|
||||||
|
|
||||||
|
async def aget_api_key(self, secrets_store: "SecretsStore | None" = None) -> str:
|
||||||
|
"""异步读取 API Key — 双读窗口优先加密列,失败回退 plaintext。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
secrets_store: 可选的加密存储。为 None 时直接返回 plaintext。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
解密后的 API Key;加密列解密失败时回退到 plaintext ``api_key``。
|
||||||
|
"""
|
||||||
|
if self.api_key_encrypted and secrets_store is not None:
|
||||||
|
try:
|
||||||
|
entry = self._decode_secret_entry(self.api_key_encrypted)
|
||||||
|
decrypted = await secrets_store.get_secret(entry.key)
|
||||||
|
if decrypted is not None:
|
||||||
|
return decrypted
|
||||||
|
# store 里没有这个 key(可能已被删除)— 回退 plaintext
|
||||||
|
logger.warning(
|
||||||
|
f"aget_api_key: encrypted key for provider type={self.type} "
|
||||||
|
f"not found in secrets_store — fallback to plaintext"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# 解密失败(master key 不匹配 / 密文损坏)— 回退 plaintext
|
||||||
|
logger.warning(
|
||||||
|
f"aget_api_key: decrypt failed for type={self.type}: {e} — fallback to plaintext"
|
||||||
|
)
|
||||||
|
return self.api_key
|
||||||
|
|
||||||
|
async def migrate_to_secrets(self, secrets_store: "SecretsStore") -> None:
|
||||||
|
"""把 plaintext api_key 迁移到 SecretsStore(幂等)。
|
||||||
|
|
||||||
|
迁移步骤(双写窗口):
|
||||||
|
1. 若 ``api_key_source == "secrets_store"`` 且 plaintext 已清空 → 已迁移,no-op。
|
||||||
|
2. 否则:调用 ``secrets_store.set_secret`` 加密存储 key。
|
||||||
|
3. 把返回的 SecretEntry JSON 编码写入 ``api_key_encrypted``。
|
||||||
|
4. 标记 ``api_key_source = "secrets_store"``。
|
||||||
|
5. 验证:调用 ``get_secret`` 读回对比,成功后清空 plaintext ``api_key=""``。
|
||||||
|
|
||||||
|
幂等性:重复调用不会重复加密(已迁移时直接返回)。
|
||||||
|
部分失败恢复:若 set 成功但验证失败,保留 plaintext 不清空,
|
||||||
|
``api_key_encrypted`` 已写入 — 下次重试时由幂等性保证最终一致。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
secrets_store: 用于加密存储的 SecretsStore 实例。
|
||||||
|
"""
|
||||||
|
# 幂等:已迁移完成(source 标记 + plaintext 已清空)
|
||||||
|
if self.api_key_source == "secrets_store" and not self.api_key and self.api_key_encrypted:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 没有 plaintext 可迁移(空 key)— 跳过
|
||||||
|
if not self.api_key:
|
||||||
|
return
|
||||||
|
|
||||||
|
secret_key = self._secret_key_for_type()
|
||||||
|
# 双写:先写加密列
|
||||||
|
entry = await secrets_store.set_secret(secret_key, self.api_key)
|
||||||
|
self.api_key_encrypted = self._encode_secret_entry(entry, secret_key)
|
||||||
|
|
||||||
|
# 验证:读回对比
|
||||||
|
try:
|
||||||
|
decrypted = await secrets_store.get_secret(secret_key)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"migrate_to_secrets: verify read failed for type={self.type}: {e}")
|
||||||
|
# 加密列已写但验证失败 — 保留 plaintext,标记 dual 待重试
|
||||||
|
self.api_key_source = "dual"
|
||||||
|
return
|
||||||
|
|
||||||
|
if decrypted != self.api_key:
|
||||||
|
logger.error(
|
||||||
|
f"migrate_to_secrets: verify mismatch for type={self.type} "
|
||||||
|
f"— plaintext retained, source=dual"
|
||||||
|
)
|
||||||
|
self.api_key_source = "dual"
|
||||||
|
return
|
||||||
|
|
||||||
|
# 验证通过:清空 plaintext,标记完成
|
||||||
|
self.api_key_source = "secrets_store"
|
||||||
|
self.api_key = ""
|
||||||
|
|
||||||
|
def _secret_key_for_type(self) -> str:
|
||||||
|
"""生成 SecretsStore 中的 key(按 provider type 命名空间隔离)。"""
|
||||||
|
return f"llm:provider:{self.type}:api_key"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _encode_secret_entry(entry: Any, key: str) -> str:
|
||||||
|
"""把 SecretEntry 编码为 JSON 字符串(含 key 字段)。"""
|
||||||
|
# entry 是 SecretEntry pydantic 模型,有 model_dump()
|
||||||
|
if hasattr(entry, "model_dump"):
|
||||||
|
data = entry.model_dump()
|
||||||
|
else:
|
||||||
|
data = dict(entry)
|
||||||
|
data["key"] = key
|
||||||
|
return json.dumps(data)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _decode_secret_entry(encoded: str) -> Any:
|
||||||
|
"""从 JSON 字符串解码 SecretEntry。返回带 .key 属性的对象。"""
|
||||||
|
from agentkit.channels.secrets import SecretEntry
|
||||||
|
|
||||||
|
data = json.loads(encoded)
|
||||||
|
return SecretEntry(
|
||||||
|
key=data.get("key", ""),
|
||||||
|
value=data["value"],
|
||||||
|
nonce=data["nonce"],
|
||||||
|
salt=data["salt"],
|
||||||
|
key_id=data.get("key_id", "default"),
|
||||||
|
created_at=data.get("created_at", ""),
|
||||||
|
updated_at=data.get("updated_at", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -105,6 +246,9 @@ class LLMConfig:
|
||||||
keepalive_expiry=pconf.get("keepalive_expiry", 30.0),
|
keepalive_expiry=pconf.get("keepalive_expiry", 30.0),
|
||||||
retry=retry,
|
retry=retry,
|
||||||
circuit_breaker=circuit_breaker,
|
circuit_breaker=circuit_breaker,
|
||||||
|
# U15 — 新增加密迁移字段,缺省时保持 plaintext 行为
|
||||||
|
api_key_encrypted=pconf.get("api_key_encrypted"),
|
||||||
|
api_key_source=pconf.get("api_key_source", "plaintext"),
|
||||||
)
|
)
|
||||||
cache = None
|
cache = None
|
||||||
cache_data = data.get("cache")
|
cache_data = data.get("cache")
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,137 @@
|
||||||
|
"""U15 — API Key 加密迁移辅助函数。
|
||||||
|
|
||||||
|
把 ``agentkit.yaml`` 中 LLM provider 的 plaintext ``api_key`` 迁移到
|
||||||
|
:class:`SecretsStore` 加密存储,并更新配置文件中的 ``api_key_encrypted`` /
|
||||||
|
``api_key_source`` 字段。
|
||||||
|
|
||||||
|
迁移模型(双写/双读窗口):
|
||||||
|
1. 读 ``agentkit.yaml`` → ``LLMConfig.providers``。
|
||||||
|
2. 对每个 provider 调用 ``ProviderConfig.migrate_to_secrets(store)``:
|
||||||
|
- 加密 plaintext 写入 SecretsStore;
|
||||||
|
- 验证读回一致后清空 plaintext,标记 ``api_key_source="secrets_store"``;
|
||||||
|
- 部分失败时保留 plaintext,标记 ``api_key_source="dual"`` 待重试。
|
||||||
|
3. 把更新后的 ``ProviderConfig`` 写回 YAML(plaintext 清空 + encrypted 列写入)。
|
||||||
|
4. 返回每个 provider 的迁移状态。
|
||||||
|
|
||||||
|
回滚步骤( documented ):
|
||||||
|
- 把 YAML 中 ``api_key_source`` 改回 ``"plaintext"``;
|
||||||
|
- 把 plaintext ``api_key`` 重新填回(从备份或 KMS 重新注入);
|
||||||
|
- ``api_key_encrypted`` 列可保留(不影响 plaintext 读取路径)。
|
||||||
|
- 重启服务即可。
|
||||||
|
|
||||||
|
ponytail: CLI 命令接线(``agentkit llm migrate-keys``)延迟实现 —
|
||||||
|
本模块的 ``migrate_api_keys_to_secrets`` 函数已可被 CLI / 运维脚本直接调用。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str, Any]]:
|
||||||
|
"""把 agentkit.yaml 中的 plaintext API Key 迁移到 SecretsStore。
|
||||||
|
|
||||||
|
流程:
|
||||||
|
1. 加载 YAML 配置(不依赖 ServerConfig,直接用 LLMConfig.from_dict)。
|
||||||
|
2. 初始化 SecretsStore(master key 从 env ``AGENTKIT_MASTER_KEY`` 读取)。
|
||||||
|
3. 对每个 provider 异步执行 ``migrate_to_secrets``。
|
||||||
|
4. 把更新后的 providers 段写回 YAML(保留其它段不变)。
|
||||||
|
5. 返回 ``{provider_name: {"status": ..., "source": ...}}`` 状态报告。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_path: ``agentkit.yaml`` 路径。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
每个 provider 的迁移状态字典:
|
||||||
|
``{"status": "migrated"|"skipped"|"failed", "source": str, "error"?: str}``。
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from agentkit.channels.secrets import SecretsStore
|
||||||
|
from agentkit.llm.config import LLMConfig
|
||||||
|
|
||||||
|
config_path = Path(config_path)
|
||||||
|
raw = yaml.safe_load(config_path.read_text(encoding="utf-8")) or {}
|
||||||
|
llm_section = raw.get("llm", {})
|
||||||
|
llm_config = LLMConfig.from_dict(llm_section)
|
||||||
|
|
||||||
|
store = SecretsStore() # master key 从 env 加载
|
||||||
|
|
||||||
|
async def _run() -> dict[str, dict[str, Any]]:
|
||||||
|
report: dict[str, dict[str, Any]] = {}
|
||||||
|
for name, pconf in llm_config.providers.items():
|
||||||
|
if pconf.api_key_source == "secrets_store" and not pconf.api_key:
|
||||||
|
report[name] = {"status": "skipped", "source": pconf.api_key_source}
|
||||||
|
continue
|
||||||
|
if not pconf.api_key:
|
||||||
|
report[name] = {
|
||||||
|
"status": "skipped",
|
||||||
|
"source": pconf.api_key_source,
|
||||||
|
"error": "no plaintext api_key to migrate",
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await pconf.migrate_to_secrets(store)
|
||||||
|
report[name] = {
|
||||||
|
"status": "migrated" if pconf.api_key_source == "secrets_store" else "partial",
|
||||||
|
"source": pconf.api_key_source,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
report[name] = {
|
||||||
|
"status": "failed",
|
||||||
|
"source": pconf.api_key_source,
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
return report
|
||||||
|
|
||||||
|
report = asyncio.run(_run())
|
||||||
|
|
||||||
|
# 写回 YAML:更新 llm.providers 段,保留其它段
|
||||||
|
providers_out: dict[str, dict[str, Any]] = {}
|
||||||
|
for name, pconf in llm_config.providers.items():
|
||||||
|
entry: dict[str, Any] = {
|
||||||
|
"type": pconf.type,
|
||||||
|
"base_url": pconf.base_url,
|
||||||
|
"models": pconf.models,
|
||||||
|
"max_tokens": pconf.max_tokens,
|
||||||
|
"timeout": pconf.timeout,
|
||||||
|
"max_connections": pconf.max_connections,
|
||||||
|
"max_keepalive_connections": pconf.max_keepalive_connections,
|
||||||
|
"keepalive_expiry": pconf.keepalive_expiry,
|
||||||
|
# plaintext 已清空(迁移成功)或保留(迁移失败 / dual)
|
||||||
|
"api_key": pconf.api_key,
|
||||||
|
"api_key_encrypted": pconf.api_key_encrypted,
|
||||||
|
"api_key_source": pconf.api_key_source,
|
||||||
|
}
|
||||||
|
if pconf.retry is not None:
|
||||||
|
entry["retry"] = {
|
||||||
|
"max_retries": pconf.retry.max_retries,
|
||||||
|
"base_delay": pconf.retry.base_delay,
|
||||||
|
"max_delay": pconf.retry.max_delay,
|
||||||
|
"exponential_base": pconf.retry.exponential_base,
|
||||||
|
}
|
||||||
|
if pconf.circuit_breaker is not None:
|
||||||
|
entry["circuit_breaker"] = {
|
||||||
|
"failure_threshold": pconf.circuit_breaker.failure_threshold,
|
||||||
|
"recovery_timeout": pconf.circuit_breaker.recovery_timeout,
|
||||||
|
"half_open_max": pconf.circuit_breaker.half_open_max,
|
||||||
|
}
|
||||||
|
providers_out[name] = entry
|
||||||
|
|
||||||
|
raw.setdefault("llm", {})["providers"] = providers_out
|
||||||
|
config_path.write_text(
|
||||||
|
yaml.safe_dump(raw, allow_unicode=True, sort_keys=False),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
return report
|
||||||
|
|
||||||
|
|
||||||
|
# ponytail: CLI wiring deferred — 把本函数接到 ``agentkit llm migrate-keys``
|
||||||
|
# Typer 命令时,只需在 cli/ 下新增 thin wrapper 调用本函数并打印 report。
|
||||||
|
|
@ -3,6 +3,10 @@
|
||||||
from agentkit.llm.providers.anthropic import AnthropicProvider
|
from agentkit.llm.providers.anthropic import AnthropicProvider
|
||||||
from agentkit.llm.providers.doubao import DoubaoProvider
|
from agentkit.llm.providers.doubao import DoubaoProvider
|
||||||
from agentkit.llm.providers.gemini import GeminiProvider
|
from agentkit.llm.providers.gemini import GeminiProvider
|
||||||
|
from agentkit.llm.providers.litellm_provider import (
|
||||||
|
LitellmProvider,
|
||||||
|
create_litellm_provider,
|
||||||
|
)
|
||||||
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||||||
from agentkit.llm.providers.tracker import UsageSummary, UsageTracker
|
from agentkit.llm.providers.tracker import UsageSummary, UsageTracker
|
||||||
from agentkit.llm.providers.usage_store import UsageRecord
|
from agentkit.llm.providers.usage_store import UsageRecord
|
||||||
|
|
@ -13,10 +17,12 @@ __all__ = [
|
||||||
"AnthropicProvider",
|
"AnthropicProvider",
|
||||||
"DoubaoProvider",
|
"DoubaoProvider",
|
||||||
"GeminiProvider",
|
"GeminiProvider",
|
||||||
|
"LitellmProvider",
|
||||||
"OpenAICompatibleProvider",
|
"OpenAICompatibleProvider",
|
||||||
"UsageRecord",
|
"UsageRecord",
|
||||||
"UsageSummary",
|
"UsageSummary",
|
||||||
"UsageTracker",
|
"UsageTracker",
|
||||||
"WenxinProvider",
|
"WenxinProvider",
|
||||||
"YuanbaoProvider",
|
"YuanbaoProvider",
|
||||||
|
"create_litellm_provider",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,384 @@
|
||||||
|
"""U15 — LiteLLM 统一 Provider 适配层。
|
||||||
|
|
||||||
|
用 LiteLLM 的 ``acompletion()`` 统一接口替换 6 个直接 API provider 适配器
|
||||||
|
(OpenAI/Anthropic/Gemini/Doubao/Wenxin/Yuanbao)。LiteLLM 内部处理各家 API
|
||||||
|
的差异(消息格式、tool_calls 格式、streaming SSE 协议),本模块只负责:
|
||||||
|
|
||||||
|
1. 把 ``LLMRequest`` 翻译成 LiteLLM ``acompletion`` 的 kwargs;
|
||||||
|
2. 把 LiteLLM 响应(OpenAI ChatCompletion 格式)翻译回 ``LLMResponse`` /
|
||||||
|
``StreamChunk``;
|
||||||
|
3. 把 LiteLLM 异常包装成 ``LLMProviderError``,保留 fallback 链可用性。
|
||||||
|
|
||||||
|
设计取舍(ponytail):
|
||||||
|
- 自建的 fallback 链 / usage tracking / 部门配额 全部保留在 ``LLMGateway``,
|
||||||
|
本 provider 不重复实现。LiteLLM 自带的 fallback / retry 在此处禁用
|
||||||
|
(``num_retries=0``),避免与 gateway 的 fallback 重复叠加导致超时放大。
|
||||||
|
- ``wenxin`` 在 LiteLLM 中无原生支持,回退到 ``openai/`` 前缀 + 自定义
|
||||||
|
``api_base``(千帆 OpenAI 兼容端点)。升级路径:LiteLLM 上游支持后切换到
|
||||||
|
``wenxin/`` 前缀。
|
||||||
|
- 旧的 6 个 provider 类(``OpenAICompatibleProvider`` 等)保留为死代码一个
|
||||||
|
release,便于回滚;新代码通过 ``create_litellm_provider`` 工厂构造。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.core.exceptions import LLMProviderError
|
||||||
|
from agentkit.llm.protocol import (
|
||||||
|
LLMProvider,
|
||||||
|
LLMRequest,
|
||||||
|
LLMResponse,
|
||||||
|
StreamChunk,
|
||||||
|
TokenUsage,
|
||||||
|
ToolCall,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# provider_type → LiteLLM model 前缀映射。
|
||||||
|
# LiteLLM 通过 model 字符串前缀路由到对应 SDK:``openai/gpt-4o``、
|
||||||
|
# ``anthropic/claude-...``、``gemini/gemini-...``、``volcengine/...``、
|
||||||
|
# ``hunyuan/...``。未知类型回退到 ``openai/``(最大兼容性)。
|
||||||
|
_PROVIDER_TYPE_TO_PREFIX: dict[str, str] = {
|
||||||
|
"openai": "openai/",
|
||||||
|
"anthropic": "anthropic/",
|
||||||
|
"gemini": "gemini/",
|
||||||
|
# 豆包由火山引擎提供,LiteLLM 前缀为 volcengine/
|
||||||
|
"doubao": "volcengine/",
|
||||||
|
# ponytail: LiteLLM 暂无 wenxin 原生支持,回退到 openai/ + 自定义 api_base
|
||||||
|
# (千帆 OpenAI 兼容端点)。ceiling: wenxin 专属参数(如 AK/SK 鉴权)不支持;
|
||||||
|
# 升级路径:上游支持后改 "wenxin/"。
|
||||||
|
"wenxin": "openai/",
|
||||||
|
# 腾讯混元
|
||||||
|
"yuanbao": "hunyuan/",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _model_prefix_for(provider_type: str) -> str:
|
||||||
|
"""根据 provider_type 返回 LiteLLM model 前缀;未知类型回退到 openai/。"""
|
||||||
|
return _PROVIDER_TYPE_TO_PREFIX.get(provider_type, "openai/")
|
||||||
|
|
||||||
|
|
||||||
|
class LitellmProvider(LLMProvider):
|
||||||
|
"""基于 LiteLLM ``acompletion`` 的统一 Provider。
|
||||||
|
|
||||||
|
一个实例对应一个 provider 配置(api_key + base_url + provider_type)。
|
||||||
|
``chat`` / ``chat_stream`` 把 ``LLMRequest`` 转发给 LiteLLM 并翻译响应。
|
||||||
|
|
||||||
|
注意:本类不持有 httpx 客户端 — LiteLLM 内部管理连接池。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_prefix: str,
|
||||||
|
api_key: str,
|
||||||
|
base_url: str | None = None,
|
||||||
|
provider_type: str = "openai",
|
||||||
|
**default_kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
self._model_prefix = model_prefix
|
||||||
|
self._api_key = api_key
|
||||||
|
self._base_url = base_url or None # 空字符串视作未设置
|
||||||
|
self._provider_type = provider_type
|
||||||
|
self._default_kwargs: dict[str, Any] = dict(default_kwargs)
|
||||||
|
|
||||||
|
async def chat(self, request: LLMRequest) -> LLMResponse:
|
||||||
|
"""非流式 chat — 调用 ``litellm.acompletion`` 并翻译响应。"""
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
kwargs = self._build_kwargs(request, stream=False)
|
||||||
|
|
||||||
|
start = time.monotonic()
|
||||||
|
try:
|
||||||
|
response = await litellm.acompletion(**kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
# LiteLLM 抛出各种异常(openai.APIError / anthropic.APIError / 自定义),
|
||||||
|
# 统一包装成 LLMProviderError 以便 gateway fallback 链识别。
|
||||||
|
raise LLMProviderError(self._provider_type, str(e)) from e
|
||||||
|
latency_ms = (time.monotonic() - start) * 1000
|
||||||
|
|
||||||
|
return self._parse_response(response, request.model, latency_ms)
|
||||||
|
|
||||||
|
async def chat_stream(self, request: LLMRequest) -> AsyncGenerator[StreamChunk, None]:
|
||||||
|
"""流式 chat — 调用 ``litellm.acompletion(stream=True)`` 并翻译 chunks。
|
||||||
|
|
||||||
|
异步生成器安全:本项目规则禁止在第一个 ``yield`` 前使用 ``return``,
|
||||||
|
本函数无早退分支,无需 ``return; yield`` 守卫。
|
||||||
|
"""
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
kwargs = self._build_kwargs(request, stream=True)
|
||||||
|
|
||||||
|
accumulated_tool_calls: dict[int, dict[str, Any]] = {}
|
||||||
|
final_usage: TokenUsage | None = None
|
||||||
|
final_model: str = request.model
|
||||||
|
yielded_any = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# litellm.acompletion(stream=True) 的返回类型取决于版本 / 调用方式:
|
||||||
|
# - 真实 litellm:返回 coroutine,await 后得到 async generator;
|
||||||
|
# - 测试 mock(async def + yield):调用即返回 async generator,不可 await。
|
||||||
|
# 用 isawaitable 兼容两种路径。
|
||||||
|
raw = litellm.acompletion(**kwargs)
|
||||||
|
stream = await raw if inspect.isawaitable(raw) else raw
|
||||||
|
async for chunk in stream:
|
||||||
|
yielded_any = True
|
||||||
|
parsed = self._parse_stream_chunk(
|
||||||
|
chunk,
|
||||||
|
request.model,
|
||||||
|
accumulated_tool_calls,
|
||||||
|
)
|
||||||
|
# 更新累计状态
|
||||||
|
if parsed.usage is not None:
|
||||||
|
final_usage = parsed.usage
|
||||||
|
if parsed.model:
|
||||||
|
final_model = parsed.model
|
||||||
|
# 内容块(含空内容)yield;usage-only 块也 yield(部分 provider
|
||||||
|
# 在最后一个 chunk 才给 usage,内容为空)。
|
||||||
|
yield parsed
|
||||||
|
|
||||||
|
# 流结束:yield 终止 chunk(聚合 tool_calls + usage)
|
||||||
|
tool_calls_list = self._finalize_tool_calls(accumulated_tool_calls)
|
||||||
|
yield StreamChunk(
|
||||||
|
content="",
|
||||||
|
model=final_model,
|
||||||
|
tool_calls=tool_calls_list,
|
||||||
|
usage=final_usage,
|
||||||
|
is_final=True,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise LLMProviderError(self._provider_type, str(e)) from e
|
||||||
|
|
||||||
|
# ponytail: 若流完全为空(yielded_any=False),上面仍会 yield 一个
|
||||||
|
# is_final=True 的空 chunk,调用方据此判断空响应。无需额外分支。
|
||||||
|
_ = yielded_any # 标记保留(调试 / 未来扩展)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# 内部辅助
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _build_kwargs(self, request: LLMRequest, *, stream: bool) -> dict[str, Any]:
|
||||||
|
"""从 LLMRequest 构造 litellm.acompletion kwargs。"""
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"model": f"{self._model_prefix}{request.model}",
|
||||||
|
"messages": request.messages,
|
||||||
|
"temperature": request.temperature,
|
||||||
|
"max_tokens": request.max_tokens,
|
||||||
|
"api_key": self._api_key,
|
||||||
|
"stream": stream,
|
||||||
|
# 禁用 LiteLLM 自带 retry — 由 gateway 的 fallback 链 / RetryPolicy 负责
|
||||||
|
"num_retries": 0,
|
||||||
|
}
|
||||||
|
if self._base_url:
|
||||||
|
kwargs["api_base"] = self._base_url
|
||||||
|
if request.tools:
|
||||||
|
kwargs["tools"] = request.tools
|
||||||
|
kwargs["tool_choice"] = request.tool_choice
|
||||||
|
if request.timeout is not None:
|
||||||
|
kwargs["timeout"] = request.timeout
|
||||||
|
# 合并构造时传入的默认 kwargs(如 max_connections 等provider特定参数)
|
||||||
|
kwargs.update(self._default_kwargs)
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
def _parse_response(
|
||||||
|
self,
|
||||||
|
response: Any,
|
||||||
|
request_model: str,
|
||||||
|
latency_ms: float,
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""把 litellm 响应(OpenAI ChatCompletion 格式)翻译成 LLMResponse。"""
|
||||||
|
# LiteLLM 响应统一为 OpenAI ChatCompletion 格式:
|
||||||
|
# response.choices[0].message.{content, tool_calls}
|
||||||
|
# response.usage.{prompt_tokens, completion_tokens}
|
||||||
|
# response.model
|
||||||
|
choices = getattr(response, "choices", None) or []
|
||||||
|
content = ""
|
||||||
|
tool_calls: list[ToolCall] = []
|
||||||
|
if choices:
|
||||||
|
message = getattr(choices[0], "message", None)
|
||||||
|
if message is not None:
|
||||||
|
content = getattr(message, "content", None) or ""
|
||||||
|
raw_tool_calls = getattr(message, "tool_calls", None)
|
||||||
|
if raw_tool_calls:
|
||||||
|
tool_calls = _parse_tool_calls(raw_tool_calls)
|
||||||
|
|
||||||
|
usage_obj = getattr(response, "usage", None)
|
||||||
|
usage = _parse_usage(usage_obj) if usage_obj is not None else TokenUsage()
|
||||||
|
|
||||||
|
model_name = getattr(response, "model", None) or request_model
|
||||||
|
|
||||||
|
return LLMResponse(
|
||||||
|
content=content,
|
||||||
|
model=model_name,
|
||||||
|
usage=usage,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _parse_stream_chunk(
|
||||||
|
self,
|
||||||
|
chunk: Any,
|
||||||
|
request_model: str,
|
||||||
|
accumulated_tool_calls: dict[int, dict[str, Any]],
|
||||||
|
) -> StreamChunk:
|
||||||
|
"""解析单个流式 chunk(非 final)。累计 tool_calls 到传入字典。"""
|
||||||
|
choices = getattr(chunk, "choices", None) or []
|
||||||
|
content = ""
|
||||||
|
model_name = getattr(chunk, "model", None) or request_model
|
||||||
|
usage: TokenUsage | None = None
|
||||||
|
|
||||||
|
if choices:
|
||||||
|
delta = getattr(choices[0], "delta", None)
|
||||||
|
if delta is not None:
|
||||||
|
content = getattr(delta, "content", None) or ""
|
||||||
|
raw_tool_calls = getattr(delta, "tool_calls", None)
|
||||||
|
if raw_tool_calls:
|
||||||
|
_accumulate_stream_tool_calls(raw_tool_calls, accumulated_tool_calls)
|
||||||
|
|
||||||
|
# 部分 provider 在最后一个 chunk 附带 usage
|
||||||
|
usage_obj = getattr(chunk, "usage", None)
|
||||||
|
if usage_obj is not None:
|
||||||
|
usage = _parse_usage(usage_obj)
|
||||||
|
|
||||||
|
return StreamChunk(
|
||||||
|
content=content,
|
||||||
|
model=model_name,
|
||||||
|
tool_calls=[], # 流式中间块不带 tool_calls,由 final chunk 聚合
|
||||||
|
usage=usage,
|
||||||
|
is_final=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _finalize_tool_calls(
|
||||||
|
self,
|
||||||
|
accumulated: dict[int, dict[str, Any]],
|
||||||
|
) -> list[ToolCall]:
|
||||||
|
"""把累计的流式 tool_calls 字典转成 ToolCall 列表。"""
|
||||||
|
tool_calls: list[ToolCall] = []
|
||||||
|
for idx in sorted(accumulated.keys()):
|
||||||
|
tc_data = accumulated[idx]
|
||||||
|
args_str = tc_data.get("arguments_str", "")
|
||||||
|
try:
|
||||||
|
arguments = json.loads(args_str) if args_str else {}
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
arguments = {"raw": args_str}
|
||||||
|
tool_calls.append(
|
||||||
|
ToolCall(
|
||||||
|
id=tc_data.get("id", ""),
|
||||||
|
name=tc_data.get("name", ""),
|
||||||
|
arguments=arguments,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return tool_calls
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# 响应解析辅助(模块级,便于测试 mock)
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_tool_calls(raw_tool_calls: Any) -> list[ToolCall]:
|
||||||
|
"""解析非流式响应的 tool_calls(OpenAI 格式 list[ChoiceMessageToolCall])。"""
|
||||||
|
result: list[ToolCall] = []
|
||||||
|
for tc in raw_tool_calls:
|
||||||
|
# LiteLLM 返回的对象有 .id / .function.{name,arguments}
|
||||||
|
tc_id = getattr(tc, "id", "") or ""
|
||||||
|
func = getattr(tc, "function", None)
|
||||||
|
if func is None:
|
||||||
|
continue
|
||||||
|
name = getattr(func, "name", "") or ""
|
||||||
|
args = getattr(func, "arguments", "{}")
|
||||||
|
if isinstance(args, str):
|
||||||
|
try:
|
||||||
|
args_dict = json.loads(args) if args else {}
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
args_dict = {"raw": args}
|
||||||
|
elif isinstance(args, dict):
|
||||||
|
args_dict = args
|
||||||
|
else:
|
||||||
|
args_dict = {"raw": str(args)}
|
||||||
|
result.append(ToolCall(id=tc_id, name=name, arguments=args_dict))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_usage(usage_obj: Any) -> TokenUsage:
|
||||||
|
"""解析 usage 对象(OpenAI CompletionUsage 或 dict)。"""
|
||||||
|
prompt = getattr(usage_obj, "prompt_tokens", None)
|
||||||
|
completion = getattr(usage_obj, "completion_tokens", None)
|
||||||
|
if prompt is None and isinstance(usage_obj, dict):
|
||||||
|
prompt = usage_obj.get("prompt_tokens", 0)
|
||||||
|
completion = usage_obj.get("completion_tokens", 0)
|
||||||
|
return TokenUsage(
|
||||||
|
prompt_tokens=int(prompt or 0),
|
||||||
|
completion_tokens=int(completion or 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _accumulate_stream_tool_calls(
|
||||||
|
raw_tool_calls: Any,
|
||||||
|
accumulated: dict[int, dict[str, Any]],
|
||||||
|
) -> None:
|
||||||
|
"""累计流式 chunk 里的 tool_calls 片段(OpenAI delta.tool_calls 格式)。
|
||||||
|
|
||||||
|
每个 delta tool_call 有 index/id/function.{name,arguments},arguments
|
||||||
|
是分片字符串,需跨 chunk 拼接。
|
||||||
|
"""
|
||||||
|
for tc in raw_tool_calls:
|
||||||
|
idx = getattr(tc, "index", 0) or 0
|
||||||
|
if idx not in accumulated:
|
||||||
|
accumulated[idx] = {
|
||||||
|
"id": "",
|
||||||
|
"name": "",
|
||||||
|
"arguments_str": "",
|
||||||
|
}
|
||||||
|
tc_id = getattr(tc, "id", None)
|
||||||
|
if tc_id:
|
||||||
|
accumulated[idx]["id"] = tc_id
|
||||||
|
func = getattr(tc, "function", None)
|
||||||
|
if func is not None:
|
||||||
|
name = getattr(func, "name", None)
|
||||||
|
if name:
|
||||||
|
accumulated[idx]["name"] = name
|
||||||
|
args_fragment = getattr(func, "arguments", None)
|
||||||
|
if args_fragment:
|
||||||
|
accumulated[idx]["arguments_str"] += args_fragment
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# 工厂函数
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def create_litellm_provider(
|
||||||
|
provider_type: str,
|
||||||
|
api_key: str,
|
||||||
|
base_url: str | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> LitellmProvider:
|
||||||
|
"""根据 provider_type 创建 LitellmProvider 实例。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_type: 配置中的 provider 类型字符串
|
||||||
|
("openai"|"anthropic"|"gemini"|"doubao"|"wenxin"|"yuanbao")。
|
||||||
|
未知类型回退到 "openai/" 前缀。
|
||||||
|
api_key: provider API key。
|
||||||
|
base_url: 可选自定义 API base URL。
|
||||||
|
**kwargs: 透传给 LitellmProvider 的额外默认参数。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
配置好 model_prefix 的 LitellmProvider。
|
||||||
|
"""
|
||||||
|
prefix = _model_prefix_for(provider_type)
|
||||||
|
return LitellmProvider(
|
||||||
|
model_prefix=prefix,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
provider_type=provider_type,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
@ -11,9 +11,13 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||||
from agentkit.core.agent_pool import AgentPool
|
from agentkit.core.agent_pool import AgentPool
|
||||||
from agentkit.core.event_queue import EventQueue, SubmissionQueue
|
from agentkit.core.event_queue import EventQueue, SubmissionQueue
|
||||||
from agentkit.llm.gateway import LLMGateway
|
from agentkit.llm.gateway import LLMGateway
|
||||||
from agentkit.llm.providers.anthropic import AnthropicProvider
|
|
||||||
from agentkit.llm.providers.gemini import GeminiProvider
|
# U15: 旧 provider 类保留导入用于向后兼容 / 回滚(_create_provider 已切到 LitellmProvider)。
|
||||||
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
# 不要删除 — 下个 release 稳定后再清理。
|
||||||
|
from agentkit.llm.providers.anthropic import AnthropicProvider # noqa: F401
|
||||||
|
from agentkit.llm.providers.gemini import GeminiProvider # noqa: F401
|
||||||
|
from agentkit.llm.providers.litellm_provider import create_litellm_provider
|
||||||
|
from agentkit.llm.providers.openai import OpenAICompatibleProvider # noqa: F401
|
||||||
from agentkit.mcp.manager import MCPManager
|
from agentkit.mcp.manager import MCPManager
|
||||||
from agentkit.quality.gate import QualityGate
|
from agentkit.quality.gate import QualityGate
|
||||||
from agentkit.quality.output import OutputStandardizer
|
from agentkit.quality.output import OutputStandardizer
|
||||||
|
|
@ -83,7 +87,8 @@ def _build_llm_gateway(config: ServerConfig) -> LLMGateway:
|
||||||
gateway = LLMGateway(config=config.llm_config, usage_store=usage_store)
|
gateway = LLMGateway(config=config.llm_config, usage_store=usage_store)
|
||||||
|
|
||||||
for name, pconf in config.llm_config.providers.items():
|
for name, pconf in config.llm_config.providers.items():
|
||||||
if not pconf.api_key:
|
# U15: 跳过既无 plaintext api_key 又无加密列的 provider
|
||||||
|
if not pconf.api_key and not pconf.api_key_encrypted:
|
||||||
continue # Skip providers without API keys
|
continue # Skip providers without API keys
|
||||||
try:
|
try:
|
||||||
provider = _create_provider(name, pconf)
|
provider = _create_provider(name, pconf)
|
||||||
|
|
@ -97,45 +102,38 @@ def _build_llm_gateway(config: ServerConfig) -> LLMGateway:
|
||||||
def _create_provider(name: str, pconf) -> object:
|
def _create_provider(name: str, pconf) -> object:
|
||||||
"""Create an LLM provider instance from ProviderConfig.
|
"""Create an LLM provider instance from ProviderConfig.
|
||||||
|
|
||||||
Shared by server app and CLI chat to avoid duplicated initialization logic.
|
U15: 统一使用 LitellmProvider 替换 6 个直接 API provider 适配器。
|
||||||
|
旧的 AnthropicProvider / GeminiProvider / OpenAICompatibleProvider 类保留
|
||||||
|
导入(向后兼容 / 回滚),但新代码路径走 LiteLLM。
|
||||||
|
|
||||||
|
ponytail: 加密 key 的异步解密(aget_api_key)在启动时调用需 async 上下文,
|
||||||
|
此处为同步函数 — 双读窗口保证 plaintext ``api_key`` 在迁移期仍可用,
|
||||||
|
故直接用 ``pconf.api_key``。完全迁移后(plaintext 清空)需在此处前增加
|
||||||
|
async 预解密步骤。升级路径:把 _build_llm_gateway 改为 async 并在 lifespan
|
||||||
|
中调用。
|
||||||
"""
|
"""
|
||||||
if pconf.type == "anthropic":
|
# api_key 解析:双读窗口优先 plaintext(同步可读);plaintext 为空时
|
||||||
return AnthropicProvider(
|
# 调用方应在此之前完成 async 解密并写回 pconf.api_key。
|
||||||
api_key=pconf.api_key,
|
api_key = pconf.api_key
|
||||||
model=list(pconf.models.keys())[0] if pconf.models else "claude-sonnet-4-20250514",
|
if not api_key and pconf.api_key_encrypted:
|
||||||
max_tokens=pconf.max_tokens,
|
# plaintext 已清空但加密列存在 — 同步上下文无法解密,跳过注册
|
||||||
base_url=pconf.base_url or "https://api.anthropic.com",
|
# (由 lifespan 异步预解密路径补注册,或运维回滚 plaintext)
|
||||||
timeout=pconf.timeout,
|
logger.warning(
|
||||||
max_connections=pconf.max_connections,
|
f"Provider '{name}' has encrypted key but no plaintext fallback "
|
||||||
max_keepalive_connections=pconf.max_keepalive_connections,
|
f"— skipped (run async pre-decrypt at startup)"
|
||||||
keepalive_expiry=pconf.keepalive_expiry,
|
|
||||||
)
|
)
|
||||||
elif pconf.type == "gemini":
|
raise ValueError(
|
||||||
return GeminiProvider(
|
f"Provider '{name}' api_key not available synchronously "
|
||||||
api_key=pconf.api_key,
|
f"(encrypted only). Run async pre-decrypt first."
|
||||||
model=list(pconf.models.keys())[0] if pconf.models else "gemini-2.0-flash",
|
|
||||||
max_output_tokens=pconf.max_tokens,
|
|
||||||
base_url=pconf.base_url or "https://generativelanguage.googleapis.com",
|
|
||||||
timeout=pconf.timeout,
|
|
||||||
max_connections=pconf.max_connections,
|
|
||||||
max_keepalive_connections=pconf.max_keepalive_connections,
|
|
||||||
keepalive_expiry=pconf.keepalive_expiry,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if not pconf.base_url:
|
|
||||||
raise ValueError(
|
|
||||||
f"Provider '{name}' is missing base_url. "
|
|
||||||
f"OpenAI-compatible providers require an explicit base_url in config."
|
|
||||||
)
|
|
||||||
return OpenAICompatibleProvider(
|
|
||||||
api_key=pconf.api_key,
|
|
||||||
base_url=pconf.base_url,
|
|
||||||
max_connections=pconf.max_connections,
|
|
||||||
max_keepalive_connections=pconf.max_keepalive_connections,
|
|
||||||
keepalive_expiry=pconf.keepalive_expiry,
|
|
||||||
timeout=pconf.timeout,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
base_url = pconf.base_url or None
|
||||||
|
return create_litellm_provider(
|
||||||
|
provider_type=pconf.type,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _build_skill_registry(config: ServerConfig) -> SkillRegistry:
|
def _build_skill_registry(config: ServerConfig) -> SkillRegistry:
|
||||||
"""Build SkillRegistry from ServerConfig, loading all skill configs."""
|
"""Build SkillRegistry from ServerConfig, loading all skill configs."""
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,287 @@
|
||||||
|
"""U15 — ProviderConfig API Key 加密迁移测试。
|
||||||
|
|
||||||
|
覆盖:
|
||||||
|
- get_api_key plaintext 回退
|
||||||
|
- aget_api_key + secrets_store 解密
|
||||||
|
- aget_api_key dual-read 回退(解密失败 → plaintext)
|
||||||
|
- migrate_to_secrets 完整流程
|
||||||
|
- migrate_to_secrets 幂等
|
||||||
|
- LLMConfig.from_dict 解析新字段
|
||||||
|
- LLMConfig.from_dict 缺省新字段
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from agentkit.channels.secrets import SecretsStore
|
||||||
|
from agentkit.llm.config import LLMConfig, ProviderConfig
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# 15: get_api_key plaintext
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_api_key_plaintext_no_store():
|
||||||
|
"""无 secrets_store 时 get_api_key 返回 plaintext。"""
|
||||||
|
pconf = ProviderConfig(api_key="sk-xxx", base_url="", type="openai")
|
||||||
|
assert pconf.get_api_key(None) == "sk-xxx"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_api_key_plaintext_with_store_sync():
|
||||||
|
"""有 secrets_store 但同步调用 — 仍返回 plaintext(async 解密不可用)。"""
|
||||||
|
pconf = ProviderConfig(api_key="sk-xxx", base_url="", type="openai")
|
||||||
|
# 即使传 store,同步路径无法 decrypt,回退 plaintext
|
||||||
|
store = object() # 任意非 None 对象
|
||||||
|
assert pconf.get_api_key(store) == "sk-xxx" # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# 16: aget_api_key + secrets_store 解密成功
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_aget_api_key_decrypts_from_store():
|
||||||
|
"""加密列 + secrets_store → 解密返回 key。"""
|
||||||
|
store = SecretsStore(master_key=b"x" * 32)
|
||||||
|
# 先在 store 里放一个 key
|
||||||
|
await store.set_secret("llm:provider:openai:api_key", "sk-decrypted")
|
||||||
|
# 构造一个加密的 ProviderConfig(直接用真实加密流程)
|
||||||
|
pconf = ProviderConfig(
|
||||||
|
api_key="",
|
||||||
|
base_url="",
|
||||||
|
type="openai",
|
||||||
|
api_key_encrypted="", # 占位,下面重新生成
|
||||||
|
api_key_source="secrets_store",
|
||||||
|
)
|
||||||
|
# 用 store 加密 "sk-decrypted" 得到真实 encrypted 字段
|
||||||
|
entry = await store.set_secret("llm:provider:openai:api_key", "sk-decrypted")
|
||||||
|
pconf.api_key_encrypted = ProviderConfig._encode_secret_entry(
|
||||||
|
entry, "llm:provider:openai:api_key"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await pconf.aget_api_key(store)
|
||||||
|
assert result == "sk-decrypted"
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# 17: aget_api_key dual-read 回退
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_aget_api_key_dual_read_fallback_to_plaintext():
|
||||||
|
"""加密列存在但解密失败 → 回退 plaintext。"""
|
||||||
|
# secrets_store 返回 None(模拟 key 不存在 / 解密失败)
|
||||||
|
store = AsyncMock(spec=SecretsStore)
|
||||||
|
store.get_secret = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
pconf = ProviderConfig(
|
||||||
|
api_key="sk-plaintext",
|
||||||
|
base_url="",
|
||||||
|
type="openai",
|
||||||
|
api_key_encrypted=json.dumps(
|
||||||
|
{
|
||||||
|
"key": "llm:provider:openai:api_key",
|
||||||
|
"value": "fake",
|
||||||
|
"nonce": "fake",
|
||||||
|
"salt": "fake",
|
||||||
|
"key_id": "default",
|
||||||
|
"created_at": "",
|
||||||
|
"updated_at": "",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
api_key_source="dual",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await pconf.aget_api_key(store)
|
||||||
|
assert result == "sk-plaintext"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_aget_api_key_decrypt_exception_falls_back():
|
||||||
|
"""解密抛异常时回退 plaintext。"""
|
||||||
|
store = AsyncMock(spec=SecretsStore)
|
||||||
|
store.get_secret = AsyncMock(side_effect=RuntimeError("decrypt failed"))
|
||||||
|
|
||||||
|
pconf = ProviderConfig(
|
||||||
|
api_key="sk-plaintext",
|
||||||
|
base_url="",
|
||||||
|
type="openai",
|
||||||
|
api_key_encrypted=json.dumps(
|
||||||
|
{
|
||||||
|
"key": "llm:provider:openai:api_key",
|
||||||
|
"value": "fake",
|
||||||
|
"nonce": "fake",
|
||||||
|
"salt": "fake",
|
||||||
|
"key_id": "default",
|
||||||
|
"created_at": "",
|
||||||
|
"updated_at": "",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
api_key_source="dual",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await pconf.aget_api_key(store)
|
||||||
|
assert result == "sk-plaintext"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_aget_api_key_no_encrypted_returns_plaintext():
|
||||||
|
"""无加密列时直接返回 plaintext。"""
|
||||||
|
store = SecretsStore(master_key=b"x" * 32)
|
||||||
|
pconf = ProviderConfig(api_key="sk-plain", base_url="", type="openai")
|
||||||
|
assert await pconf.aget_api_key(store) == "sk-plain"
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# 18: migrate_to_secrets 完整流程
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_migrate_to_secrets_full_flow():
|
||||||
|
"""plaintext → secrets_store 迁移:加密、验证、清空 plaintext。"""
|
||||||
|
store = SecretsStore(master_key=b"x" * 32)
|
||||||
|
pconf = ProviderConfig(
|
||||||
|
api_key="sk-secret",
|
||||||
|
base_url="",
|
||||||
|
type="openai",
|
||||||
|
)
|
||||||
|
|
||||||
|
await pconf.migrate_to_secrets(store)
|
||||||
|
|
||||||
|
assert pconf.api_key == "" # plaintext 清空
|
||||||
|
assert pconf.api_key_encrypted is not None # 加密列写入
|
||||||
|
assert pconf.api_key_source == "secrets_store"
|
||||||
|
|
||||||
|
# 验证:通过 aget_api_key 能读回原 key
|
||||||
|
decrypted = await pconf.aget_api_key(store)
|
||||||
|
assert decrypted == "sk-secret"
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# 19: migrate_to_secrets 幂等
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_migrate_to_secrets_idempotent():
|
||||||
|
"""已迁移的 config 再次调用 migrate_to_secrets 是 no-op。"""
|
||||||
|
store = SecretsStore(master_key=b"x" * 32)
|
||||||
|
pconf = ProviderConfig(api_key="sk-secret", base_url="", type="openai")
|
||||||
|
|
||||||
|
await pconf.migrate_to_secrets(store)
|
||||||
|
encrypted_after_first = pconf.api_key_encrypted
|
||||||
|
assert pconf.api_key_source == "secrets_store"
|
||||||
|
|
||||||
|
# 第二次调用 — no-op
|
||||||
|
await pconf.migrate_to_secrets(store)
|
||||||
|
assert pconf.api_key_encrypted == encrypted_after_first
|
||||||
|
assert pconf.api_key_source == "secrets_store"
|
||||||
|
assert pconf.api_key == ""
|
||||||
|
|
||||||
|
|
||||||
|
async def test_migrate_to_secrets_skips_empty_plaintext():
|
||||||
|
"""plaintext 为空时跳过迁移。"""
|
||||||
|
store = SecretsStore(master_key=b"x" * 32)
|
||||||
|
pconf = ProviderConfig(api_key="", base_url="", type="openai")
|
||||||
|
|
||||||
|
await pconf.migrate_to_secrets(store)
|
||||||
|
|
||||||
|
assert pconf.api_key == ""
|
||||||
|
assert pconf.api_key_encrypted is None
|
||||||
|
assert pconf.api_key_source == "plaintext"
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# 20-21: LLMConfig.from_dict 解析新字段
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_dict_parses_new_encrypted_fields():
|
||||||
|
"""LLMConfig.from_dict 解析 api_key_encrypted / api_key_source。"""
|
||||||
|
data = {
|
||||||
|
"providers": {
|
||||||
|
"openai": {
|
||||||
|
"api_key": "",
|
||||||
|
"base_url": "https://api.openai.com/v1",
|
||||||
|
"type": "openai",
|
||||||
|
"api_key_encrypted": '{"key":"k","value":"v","nonce":"n","salt":"s"}',
|
||||||
|
"api_key_source": "secrets_store",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
config = LLMConfig.from_dict(data)
|
||||||
|
pconf = config.providers["openai"]
|
||||||
|
assert pconf.api_key_encrypted == '{"key":"k","value":"v","nonce":"n","salt":"s"}'
|
||||||
|
assert pconf.api_key_source == "secrets_store"
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_dict_defaults_new_fields_to_plaintext():
|
||||||
|
"""LLMConfig.from_dict 缺省新字段时默认 plaintext 行为。"""
|
||||||
|
data = {
|
||||||
|
"providers": {
|
||||||
|
"openai": {
|
||||||
|
"api_key": "sk-test",
|
||||||
|
"base_url": "https://api.openai.com/v1",
|
||||||
|
"type": "openai",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
config = LLMConfig.from_dict(data)
|
||||||
|
pconf = config.providers["openai"]
|
||||||
|
assert pconf.api_key_encrypted is None
|
||||||
|
assert pconf.api_key_source == "plaintext"
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_dict_dual_source_during_migration():
|
||||||
|
"""双写窗口:api_key_source="dual" 时两个字段都保留。"""
|
||||||
|
data = {
|
||||||
|
"providers": {
|
||||||
|
"openai": {
|
||||||
|
"api_key": "sk-plaintext",
|
||||||
|
"base_url": "",
|
||||||
|
"type": "openai",
|
||||||
|
"api_key_encrypted": '{"key":"k","value":"v","nonce":"n","salt":"s"}',
|
||||||
|
"api_key_source": "dual",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
config = LLMConfig.from_dict(data)
|
||||||
|
pconf = config.providers["openai"]
|
||||||
|
assert pconf.api_key == "sk-plaintext"
|
||||||
|
assert pconf.api_key_encrypted is not None
|
||||||
|
assert pconf.api_key_source == "dual"
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# 额外:_secret_key_for_type 命名空间
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_secret_key_namespaced_by_type():
|
||||||
|
"""secret key 应按 provider type 命名空间隔离。"""
|
||||||
|
pconf = ProviderConfig(api_key="", base_url="", type="anthropic")
|
||||||
|
assert pconf._secret_key_for_type() == "llm:provider:anthropic:api_key"
|
||||||
|
|
||||||
|
pconf2 = ProviderConfig(api_key="", base_url="", type="gemini")
|
||||||
|
assert pconf2._secret_key_for_type() == "llm:provider:gemini:api_key"
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# 额外:encode/decode SecretEntry 往返
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_encode_decode_secret_entry_roundtrip():
|
||||||
|
"""encode → decode 应保留关键字段。"""
|
||||||
|
store = SecretsStore(master_key=b"x" * 32)
|
||||||
|
entry = await store.set_secret("llm:provider:openai:api_key", "sk-xxx")
|
||||||
|
encoded = ProviderConfig._encode_secret_entry(entry, "llm:provider:openai:api_key")
|
||||||
|
decoded = ProviderConfig._decode_secret_entry(encoded)
|
||||||
|
|
||||||
|
assert decoded.key == "llm:provider:openai:api_key"
|
||||||
|
assert decoded.value == entry.value
|
||||||
|
assert decoded.nonce == entry.nonce
|
||||||
|
assert decoded.salt == entry.salt
|
||||||
|
|
@ -0,0 +1,543 @@
|
||||||
|
"""U15 — LitellmProvider 单元测试。
|
||||||
|
|
||||||
|
用 ``unittest.mock.AsyncMock`` mock ``litellm.acompletion``,覆盖:
|
||||||
|
- 6 个 provider 类型(openai/anthropic/gemini/doubao/wenxin/yuanbao)经 LiteLLM 调用
|
||||||
|
- 未知 provider_type 回退到 openai/ 前缀
|
||||||
|
- 自定义 base_url 透传
|
||||||
|
- tools 透传 + tool_calls 响应解析
|
||||||
|
- 流式 chunk 迭代 + 终止 chunk
|
||||||
|
- 流式空消息处理
|
||||||
|
- LiteLLM 异常 → LLMProviderError
|
||||||
|
- latency_ms 非负
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.core.exceptions import LLMProviderError
|
||||||
|
from agentkit.llm.protocol import LLMRequest, LLMResponse
|
||||||
|
from agentkit.llm.providers.litellm_provider import (
|
||||||
|
LitellmProvider,
|
||||||
|
_model_prefix_for,
|
||||||
|
create_litellm_provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# 测试辅助:构造 OpenAI 格式的 fake response / chunk
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_response(
|
||||||
|
content: str = "Hello!",
|
||||||
|
model: str = "gpt-4o-mini",
|
||||||
|
prompt_tokens: int = 10,
|
||||||
|
completion_tokens: int = 5,
|
||||||
|
tool_calls: list | None = None,
|
||||||
|
) -> SimpleNamespace:
|
||||||
|
"""构造非流式 fake response(OpenAI ChatCompletion 格式)。"""
|
||||||
|
message = SimpleNamespace(content=content, tool_calls=tool_calls)
|
||||||
|
return SimpleNamespace(
|
||||||
|
choices=[SimpleNamespace(message=message)],
|
||||||
|
usage=SimpleNamespace(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
),
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_tool_call(
|
||||||
|
tc_id: str = "call_1",
|
||||||
|
name: str = "get_weather",
|
||||||
|
arguments: str = '{"city": "Beijing"}',
|
||||||
|
) -> SimpleNamespace:
|
||||||
|
"""构造非流式 tool_call 对象。"""
|
||||||
|
return SimpleNamespace(
|
||||||
|
id=tc_id,
|
||||||
|
function=SimpleNamespace(name=name, arguments=arguments),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_stream_chunk(
|
||||||
|
content: str = "",
|
||||||
|
model: str = "gpt-4o-mini",
|
||||||
|
tool_calls_delta: list | None = None,
|
||||||
|
usage: SimpleNamespace | None = None,
|
||||||
|
) -> SimpleNamespace:
|
||||||
|
"""构造流式 chunk(OpenAI ChatCompletionChunk 格式)。"""
|
||||||
|
delta = SimpleNamespace(content=content, tool_calls=tool_calls_delta)
|
||||||
|
return SimpleNamespace(
|
||||||
|
choices=[SimpleNamespace(delta=delta)],
|
||||||
|
model=model,
|
||||||
|
usage=usage,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_stream_tool_call_delta(
|
||||||
|
index: int = 0,
|
||||||
|
tc_id: str | None = "call_1",
|
||||||
|
name: str | None = "get_weather",
|
||||||
|
arguments_fragment: str | None = None,
|
||||||
|
) -> SimpleNamespace:
|
||||||
|
"""构造流式 tool_call delta 片段。"""
|
||||||
|
return SimpleNamespace(
|
||||||
|
index=index,
|
||||||
|
id=tc_id,
|
||||||
|
function=SimpleNamespace(name=name, arguments=arguments_fragment),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# 1-7: provider_type → model_prefix 映射 + 基本 chat
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"provider_type, expected_prefix",
|
||||||
|
[
|
||||||
|
("openai", "openai/"),
|
||||||
|
("anthropic", "anthropic/"),
|
||||||
|
("gemini", "gemini/"),
|
||||||
|
("doubao", "volcengine/"),
|
||||||
|
("wenxin", "openai/"), # ponytail: wenxin 回退到 openai/
|
||||||
|
("yuanbao", "hunyuan/"),
|
||||||
|
("unknown_type", "openai/"), # 未知 → openai/ 回退
|
||||||
|
],
|
||||||
|
ids=["openai", "anthropic", "gemini", "doubao", "wenxin", "yuanbao", "unknown"],
|
||||||
|
)
|
||||||
|
def test_model_prefix_mapping(provider_type: str, expected_prefix: str):
|
||||||
|
"""验证 provider_type → LiteLLM model 前缀映射(含未知类型回退)。"""
|
||||||
|
assert _model_prefix_for(provider_type) == expected_prefix
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"provider_type, expected_prefix, model_name",
|
||||||
|
[
|
||||||
|
("openai", "openai/", "gpt-4o-mini"),
|
||||||
|
("anthropic", "anthropic/", "claude-sonnet-4-20250514"),
|
||||||
|
("gemini", "gemini/", "gemini-2.0-flash"),
|
||||||
|
("doubao", "volcengine/", "doubao-pro-32k"),
|
||||||
|
("wenxin", "openai/", "ernie-4.5-turbo-128k"),
|
||||||
|
("yuanbao", "hunyuan/", "hunyuan-pro"),
|
||||||
|
],
|
||||||
|
ids=["openai", "anthropic", "gemini", "doubao", "wenxin", "yuanbao"],
|
||||||
|
)
|
||||||
|
async def test_chat_via_litellm_for_each_provider_type(
|
||||||
|
provider_type: str,
|
||||||
|
expected_prefix: str,
|
||||||
|
model_name: str,
|
||||||
|
):
|
||||||
|
"""6 个 provider 类型经 LiteLLM 调用 — 验证 model 前缀和响应翻译。"""
|
||||||
|
provider = create_litellm_provider(provider_type, api_key="sk-test")
|
||||||
|
assert provider._model_prefix == expected_prefix
|
||||||
|
|
||||||
|
fake_resp = _fake_response(content="ok", model=model_name)
|
||||||
|
with patch("litellm.acompletion", new=AsyncMock(return_value=fake_resp)):
|
||||||
|
response = await provider.chat(
|
||||||
|
LLMRequest(
|
||||||
|
messages=[{"role": "user", "content": "Hi"}],
|
||||||
|
model=model_name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, LLMResponse)
|
||||||
|
assert response.content == "ok"
|
||||||
|
assert response.model == model_name
|
||||||
|
assert response.usage.prompt_tokens == 10
|
||||||
|
assert response.usage.completion_tokens == 5
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# 8: 自定义 base_url 透传
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_custom_base_url_passed_through():
|
||||||
|
"""自定义 base_url 应作为 api_base 传给 litellm.acompletion。"""
|
||||||
|
provider = create_litellm_provider(
|
||||||
|
"openai",
|
||||||
|
api_key="sk-test",
|
||||||
|
base_url="https://custom.api/v1",
|
||||||
|
)
|
||||||
|
mock_acompletion = AsyncMock(return_value=_fake_response())
|
||||||
|
with patch("litellm.acompletion", new=mock_acompletion):
|
||||||
|
await provider.chat(
|
||||||
|
LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o-mini")
|
||||||
|
)
|
||||||
|
|
||||||
|
_, kwargs = mock_acompletion.call_args
|
||||||
|
assert kwargs["api_base"] == "https://custom.api/v1"
|
||||||
|
assert kwargs["api_key"] == "sk-test"
|
||||||
|
assert kwargs["model"] == "openai/gpt-4o-mini"
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# 9: tools 透传
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_tools_passed_through():
|
||||||
|
"""request.tools 应透传给 litellm.acompletion。"""
|
||||||
|
provider = create_litellm_provider("openai", api_key="sk-test")
|
||||||
|
mock_acompletion = AsyncMock(return_value=_fake_response())
|
||||||
|
with patch("litellm.acompletion", new=mock_acompletion):
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get weather",
|
||||||
|
"parameters": {"type": "object", "properties": {}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
await provider.chat(
|
||||||
|
LLMRequest(
|
||||||
|
messages=[{"role": "user", "content": "weather?"}],
|
||||||
|
model="gpt-4o",
|
||||||
|
tools=tools,
|
||||||
|
tool_choice="auto",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
_, kwargs = mock_acompletion.call_args
|
||||||
|
assert kwargs["tools"] == tools
|
||||||
|
assert kwargs["tool_choice"] == "auto"
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# 10: tool_calls 响应解析
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_tool_calls_in_response_parsed():
|
||||||
|
"""非流式响应中的 tool_calls 应解析成 LLMResponse.tool_calls。"""
|
||||||
|
provider = create_litellm_provider("openai", api_key="sk-test")
|
||||||
|
fake_resp = _fake_response(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
_fake_tool_call(
|
||||||
|
tc_id="call_abc",
|
||||||
|
name="get_weather",
|
||||||
|
arguments='{"city": "Beijing"}',
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
with patch("litellm.acompletion", new=AsyncMock(return_value=fake_resp)):
|
||||||
|
response = await provider.chat(
|
||||||
|
LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.has_tool_calls
|
||||||
|
assert len(response.tool_calls) == 1
|
||||||
|
tc = response.tool_calls[0]
|
||||||
|
assert tc.id == "call_abc"
|
||||||
|
assert tc.name == "get_weather"
|
||||||
|
assert tc.arguments == {"city": "Beijing"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_tool_calls_with_dict_arguments():
|
||||||
|
"""tool_call.arguments 为 dict 时直接采用。"""
|
||||||
|
provider = create_litellm_provider("openai", api_key="sk-test")
|
||||||
|
fake_resp = _fake_response(
|
||||||
|
tool_calls=[
|
||||||
|
SimpleNamespace(
|
||||||
|
id="call_1",
|
||||||
|
function=SimpleNamespace(name="fn", arguments={"k": "v"}),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
with patch("litellm.acompletion", new=AsyncMock(return_value=fake_resp)):
|
||||||
|
response = await provider.chat(
|
||||||
|
LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.tool_calls[0].arguments == {"k": "v"}
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# 11: 流式 yields chunks
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_streaming_yields_chunks():
|
||||||
|
"""流式应 yield 多个 is_final=False chunk + 一个 is_final=True 终止 chunk。"""
|
||||||
|
provider = create_litellm_provider("openai", api_key="sk-test")
|
||||||
|
|
||||||
|
chunks = [
|
||||||
|
_fake_stream_chunk(content="Hello"),
|
||||||
|
_fake_stream_chunk(content=" world"),
|
||||||
|
_fake_stream_chunk(
|
||||||
|
usage=SimpleNamespace(prompt_tokens=8, completion_tokens=2),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
async def _fake_stream(**_kwargs):
|
||||||
|
for c in chunks:
|
||||||
|
yield c
|
||||||
|
|
||||||
|
with patch("litellm.acompletion", new=_fake_stream):
|
||||||
|
results = []
|
||||||
|
async for chunk in provider.chat_stream(
|
||||||
|
LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o-mini")
|
||||||
|
):
|
||||||
|
results.append(chunk)
|
||||||
|
|
||||||
|
# 前 3 个是流式 chunk(is_final=False),最后一个是终止 chunk(is_final=True)
|
||||||
|
assert len(results) == 4
|
||||||
|
assert results[0].content == "Hello"
|
||||||
|
assert results[1].content == " world"
|
||||||
|
assert all(not r.is_final for r in results[:3])
|
||||||
|
assert results[-1].is_final is True
|
||||||
|
# 终止 chunk 含聚合 usage
|
||||||
|
assert results[-1].usage is not None
|
||||||
|
assert results[-1].usage.prompt_tokens == 8
|
||||||
|
assert results[-1].usage.completion_tokens == 2
|
||||||
|
|
||||||
|
|
||||||
|
async def test_streaming_aggregates_tool_calls():
|
||||||
|
"""流式 tool_calls 片段应聚合到终止 chunk。"""
|
||||||
|
provider = create_litellm_provider("openai", api_key="sk-test")
|
||||||
|
|
||||||
|
chunks = [
|
||||||
|
_fake_stream_chunk(
|
||||||
|
tool_calls_delta=[
|
||||||
|
_fake_stream_tool_call_delta(
|
||||||
|
index=0,
|
||||||
|
tc_id="call_1",
|
||||||
|
name="get_weather",
|
||||||
|
arguments_fragment='{"city":',
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
_fake_stream_chunk(
|
||||||
|
tool_calls_delta=[
|
||||||
|
_fake_stream_tool_call_delta(
|
||||||
|
index=0,
|
||||||
|
tc_id=None,
|
||||||
|
name=None,
|
||||||
|
arguments_fragment=' "Beijing"}',
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
async def _fake_stream(**_kwargs):
|
||||||
|
for c in chunks:
|
||||||
|
yield c
|
||||||
|
|
||||||
|
with patch("litellm.acompletion", new=_fake_stream):
|
||||||
|
results = []
|
||||||
|
async for chunk in provider.chat_stream(
|
||||||
|
LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
|
||||||
|
):
|
||||||
|
results.append(chunk)
|
||||||
|
|
||||||
|
final = results[-1]
|
||||||
|
assert final.is_final
|
||||||
|
assert len(final.tool_calls) == 1
|
||||||
|
assert final.tool_calls[0].id == "call_1"
|
||||||
|
assert final.tool_calls[0].name == "get_weather"
|
||||||
|
assert final.tool_calls[0].arguments == {"city": "Beijing"}
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# 12: 流式空消息处理
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_streaming_empty_messages_handled():
|
||||||
|
"""流式无 chunk 时仍应 yield 一个 is_final=True 空终止 chunk,不崩溃。"""
|
||||||
|
provider = create_litellm_provider("openai", api_key="sk-test")
|
||||||
|
|
||||||
|
async def _empty_stream(**_kwargs):
|
||||||
|
return
|
||||||
|
yield # 让函数成为 async generator(永不执行)
|
||||||
|
|
||||||
|
with patch("litellm.acompletion", new=_empty_stream):
|
||||||
|
results = []
|
||||||
|
async for chunk in provider.chat_stream(
|
||||||
|
LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
|
||||||
|
):
|
||||||
|
results.append(chunk)
|
||||||
|
|
||||||
|
# 至少 yield 一个终止 chunk
|
||||||
|
assert len(results) >= 1
|
||||||
|
assert results[-1].is_final is True
|
||||||
|
assert results[-1].content == ""
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# 13: LiteLLM 异常 → LLMProviderError
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_litellm_exception_raises_provider_error():
|
||||||
|
"""litellm.acompletion 抛异常时应包装成 LLMProviderError。"""
|
||||||
|
provider = create_litellm_provider("openai", api_key="sk-test")
|
||||||
|
|
||||||
|
with patch("litellm.acompletion", new=AsyncMock(side_effect=RuntimeError("boom"))):
|
||||||
|
with pytest.raises(LLMProviderError) as exc_info:
|
||||||
|
await provider.chat(
|
||||||
|
LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "boom" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_litellm_stream_exception_raises_provider_error():
|
||||||
|
"""流式 litellm.acompletion 抛异常时应包装成 LLMProviderError。"""
|
||||||
|
provider = create_litellm_provider("openai", api_key="sk-test")
|
||||||
|
|
||||||
|
async def _failing_stream(**_kwargs):
|
||||||
|
raise RuntimeError("stream boom")
|
||||||
|
yield # unreachable,仅为 async generator 语法
|
||||||
|
|
||||||
|
with patch("litellm.acompletion", new=_failing_stream):
|
||||||
|
with pytest.raises(LLMProviderError) as exc_info:
|
||||||
|
async for _ in provider.chat_stream(
|
||||||
|
LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert "stream boom" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# 14: latency 非负
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_latency_measured_non_negative():
|
||||||
|
"""LLMResponse.latency_ms 应为非负数。"""
|
||||||
|
provider = create_litellm_provider("openai", api_key="sk-test")
|
||||||
|
with patch("litellm.acompletion", new=AsyncMock(return_value=_fake_response())):
|
||||||
|
response = await provider.chat(
|
||||||
|
LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.latency_ms >= 0
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# 额外:num_retries=0 禁用 LiteLLM 自带 retry
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_litellm_num_retries_disabled():
|
||||||
|
"""应传 num_retries=0 禁用 LiteLLM 自带 retry(由 gateway fallback 负责)。"""
|
||||||
|
provider = create_litellm_provider("openai", api_key="sk-test")
|
||||||
|
mock_acompletion = AsyncMock(return_value=_fake_response())
|
||||||
|
with patch("litellm.acompletion", new=mock_acompletion):
|
||||||
|
await provider.chat(
|
||||||
|
LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
|
||||||
|
)
|
||||||
|
|
||||||
|
_, kwargs = mock_acompletion.call_args
|
||||||
|
assert kwargs["num_retries"] == 0
|
||||||
|
assert kwargs["stream"] is False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_streaming_passes_stream_true():
|
||||||
|
"""chat_stream 应传 stream=True。"""
|
||||||
|
provider = create_litellm_provider("openai", api_key="sk-test")
|
||||||
|
|
||||||
|
captured: dict = {}
|
||||||
|
|
||||||
|
async def _capturing_stream(**kwargs):
|
||||||
|
captured.update(kwargs)
|
||||||
|
|
||||||
|
async def _inner():
|
||||||
|
return
|
||||||
|
yield
|
||||||
|
|
||||||
|
return _inner()
|
||||||
|
|
||||||
|
with patch("litellm.acompletion", new=_capturing_stream):
|
||||||
|
async for _ in provider.chat_stream(
|
||||||
|
LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert captured["stream"] is True
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# 额外:timeout 透传
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_timeout_passed_through():
|
||||||
|
"""request.timeout 应透传给 litellm.acompletion。"""
|
||||||
|
provider = create_litellm_provider("openai", api_key="sk-test")
|
||||||
|
mock_acompletion = AsyncMock(return_value=_fake_response())
|
||||||
|
with patch("litellm.acompletion", new=mock_acompletion):
|
||||||
|
await provider.chat(
|
||||||
|
LLMRequest(
|
||||||
|
messages=[{"role": "user", "content": "Hi"}],
|
||||||
|
model="gpt-4o",
|
||||||
|
timeout=42.0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
_, kwargs = mock_acompletion.call_args
|
||||||
|
assert kwargs["timeout"] == 42.0
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# 额外:LitellmProvider 直接构造(不经工厂)
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_direct_litellm_provider_construction():
|
||||||
|
"""直接用 LitellmProvider(model_prefix=...) 构造也应工作。"""
|
||||||
|
provider = LitellmProvider(
|
||||||
|
model_prefix="anthropic/",
|
||||||
|
api_key="sk-test",
|
||||||
|
provider_type="anthropic",
|
||||||
|
)
|
||||||
|
fake_resp = _fake_response(content="hi", model="claude-3")
|
||||||
|
with patch("litellm.acompletion", new=AsyncMock(return_value=fake_resp)):
|
||||||
|
response = await provider.chat(
|
||||||
|
LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="claude-3")
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.content == "hi"
|
||||||
|
# 验证 model 前缀正确拼接
|
||||||
|
mock_acompletion = AsyncMock(return_value=fake_resp)
|
||||||
|
with patch("litellm.acompletion", new=mock_acompletion):
|
||||||
|
await provider.chat(
|
||||||
|
LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="claude-3")
|
||||||
|
)
|
||||||
|
_, kwargs = mock_acompletion.call_args
|
||||||
|
assert kwargs["model"] == "anthropic/claude-3"
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# 额外:JSON tool_calls 解析容错
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_tool_calls_invalid_json_arguments_wrapped():
|
||||||
|
"""tool_call.arguments 为非法 JSON 时应包装成 {"raw": ...} 不崩溃。"""
|
||||||
|
provider = create_litellm_provider("openai", api_key="sk-test")
|
||||||
|
fake_resp = _fake_response(
|
||||||
|
tool_calls=[
|
||||||
|
_fake_tool_call(
|
||||||
|
tc_id="call_1",
|
||||||
|
name="fn",
|
||||||
|
arguments="not-valid-json{",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
with patch("litellm.acompletion", new=AsyncMock(return_value=fake_resp)):
|
||||||
|
response = await provider.chat(
|
||||||
|
LLMRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-4o")
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.tool_calls[0].arguments == {"raw": "not-valid-json{"}
|
||||||
Loading…
Reference in New Issue