211 lines
7.2 KiB
Python
211 lines
7.2 KiB
Python
"""LLM Gateway - 统一 LLM 调用入口"""
|
|
|
|
import logging
|
|
import time
|
|
|
|
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
|
|
from agentkit.llm.config import LLMConfig
|
|
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, StreamChunk, TokenUsage
|
|
from agentkit.llm.providers.tracker import UsageSummary, UsageTracker
|
|
|
|
# Module-level logger named after this module (the standard logging.getLogger(__name__) idiom).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class LLMGateway:
    """LLM gateway: provider registry, model-alias resolution, fallback and usage tracking.

    Model identifiers are ``"provider/model_name"`` or a bare ``"model_name"``;
    a bare name is accepted only while exactly one provider is registered.
    Aliases from ``config.model_aliases`` are applied first, and
    ``config.fallbacks`` maps a resolved model to the ordered list of models
    that :meth:`chat` tries when the primary call raises ``LLMProviderError``.
    """

    def __init__(self, config: LLMConfig | None = None):
        """Create a gateway with no providers registered.

        Args:
            config: Gateway configuration; ``LLMConfig()`` is used when not provided.
        """
        self._providers: dict[str, LLMProvider] = {}
        self._usage_tracker = UsageTracker()
        self._config = config or LLMConfig()

    def register_provider(self, name: str, provider: LLMProvider) -> None:
        """Register ``provider`` under ``name``, replacing any provider already using that name."""
        self._providers[name] = provider
        logger.info("LLM provider '%s' registered", name)

    @property
    def has_providers(self) -> bool:
        """Return True if at least one LLM provider is registered."""
        return bool(self._providers)

    async def chat(
        self,
        messages: list[dict[str, str]],
        model: str,
        agent_name: str = "",
        task_type: str = "",
        tools: list[dict] | None = None,
        tool_choice: str = "auto",
        **kwargs,
    ) -> LLMResponse:
        """Send a chat request, resolving aliases and falling back on provider errors.

        Args:
            messages: Chat messages forwarded to the provider unchanged.
            model: Model alias, ``"provider/model_name"`` or bare model name.
            agent_name: Label stored with the usage record.
            task_type: Accepted for API compatibility; not used by this method.
            tools: Optional tool definitions forwarded to the provider.
            tool_choice: Tool-choice mode forwarded to the provider.
            **kwargs: Extra fields passed through to ``LLMRequest``.

        Returns:
            The response of the first model that succeeds: the primary model,
            then each configured fallback in order. The recorded latency spans
            the whole call, including failed attempts.

        Raises:
            LLMProviderError: If no provider is registered, the primary model
                cannot be resolved, or the primary model and every configured
                fallback fail. With no fallbacks configured the primary
                model's error is re-raised unchanged; otherwise the last
                fallback's error is raised.
        """
        resolved_model = self._resolve_model_alias(model)
        provider, actual_model = self._resolve_provider(resolved_model)
        request = self._build_request(messages, actual_model, tools, tool_choice, kwargs)

        start = time.monotonic()
        try:
            response = await provider.chat(request)
        except LLMProviderError as primary_error:
            response = await self._chat_with_fallbacks(
                resolved_model, primary_error, messages, tools, tool_choice, kwargs
            )
        latency_ms = (time.monotonic() - start) * 1000

        self._record_usage(agent_name, response.model, response.usage, latency_ms)
        return response

    async def _chat_with_fallbacks(
        self,
        resolved_model: str,
        primary_error: LLMProviderError,
        messages: list[dict[str, str]],
        tools: list[dict] | None,
        tool_choice: str,
        extra: dict,
    ) -> LLMResponse:
        """Try each fallback configured for ``resolved_model`` in order and return the first success.

        Raises the last error seen; when no fallbacks are configured, that is
        ``primary_error`` itself.
        """
        last_error: LLMProviderError = primary_error
        for fb_model in self._config.fallbacks.get(resolved_model, []):
            logger.warning("Model '%s' failed, falling back to '%s'", resolved_model, fb_model)
            try:
                # Fallbacks may be written as aliases, just like the primary model.
                fb_provider, fb_actual = self._resolve_model(self._resolve_model_alias(fb_model))
                fb_request = self._build_request(messages, fb_actual, tools, tool_choice, extra)
                return await fb_provider.chat(fb_request)
            except ModelNotFoundError as e:
                # A misconfigured fallback must neither abort the chain nor leak an
                # exception type callers of chat() do not expect.
                wrapped = LLMProviderError("", str(e))
                wrapped.__cause__ = e
                last_error = wrapped
                logger.warning("Fallback model '%s' could not be resolved: %s", fb_model, e)
            except LLMProviderError as e:
                last_error = e
                logger.warning("Fallback model '%s' also failed: %s", fb_model, e)
        raise last_error

    async def chat_stream(
        self,
        messages: list[dict[str, str]],
        model: str,
        agent_name: str = "",
        task_type: str = "",
        tools: list[dict] | None = None,
        tool_choice: str = "auto",
        **kwargs,
    ):
        """Stream a chat response, yielding the provider's ``StreamChunk`` objects as-is.

        Unlike :meth:`chat`, no fallback is attempted. Usage is recorded when the
        stream finishes, or when the consumer closes this generator early. The
        last truthy ``chunk.usage`` / ``chunk.model`` wins; the defaults are zero
        tokens and the resolved model name. Errors raised by the provider
        mid-stream propagate without a usage record.

        Raises:
            LLMProviderError: On first iteration, if no provider is registered or
                the model cannot be resolved.
        """
        resolved_model = self._resolve_model_alias(model)
        provider, actual_model = self._resolve_provider(resolved_model)
        request = self._build_request(messages, actual_model, tools, tool_choice, kwargs)

        start = time.monotonic()
        final_usage: TokenUsage | None = None
        final_model = resolved_model

        try:
            async for chunk in provider.chat_stream(request):
                if chunk.usage:
                    final_usage = chunk.usage
                if chunk.model:
                    final_model = chunk.model
                yield chunk
        except GeneratorExit:
            # Consumer stopped early (break / aclose()); the tokens were still spent.
            self._record_usage(
                agent_name, final_model, final_usage, (time.monotonic() - start) * 1000
            )
            raise

        self._record_usage(agent_name, final_model, final_usage, (time.monotonic() - start) * 1000)

    @staticmethod
    def _build_request(
        messages: list[dict[str, str]],
        model: str,
        tools: list[dict] | None,
        tool_choice: str,
        extra: dict,
    ) -> LLMRequest:
        """Build an ``LLMRequest``; ``extra`` holds the caller's pass-through keyword arguments."""
        return LLMRequest(
            messages=messages,
            model=model,
            tools=tools,
            tool_choice=tool_choice,
            **extra,
        )

    def _resolve_provider(self, resolved_model: str) -> tuple[LLMProvider, str]:
        """Resolve an alias-free model id, reporting every failure as ``LLMProviderError``."""
        if not self._providers:
            raise LLMProviderError("", "No provider registered")
        try:
            return self._resolve_model(resolved_model)
        except ModelNotFoundError as e:
            raise LLMProviderError("", str(e)) from e

    def _record_usage(
        self,
        agent_name: str,
        model: str,
        usage: TokenUsage | None,
        latency_ms: float,
    ) -> None:
        """Price ``usage`` for ``model`` and store it; a missing usage counts as zero tokens."""
        if usage is None:
            usage = TokenUsage()
        self._usage_tracker.record(
            agent_name=agent_name,
            model=model,
            usage=usage,
            cost=self._calculate_cost(model, usage),
            latency_ms=latency_ms,
        )

    def _resolve_model_alias(self, model: str) -> str:
        """Return the target of ``model`` in ``config.model_aliases``, or ``model`` unchanged."""
        return self._config.model_aliases.get(model, model)

    def _resolve_model(self, model: str) -> tuple[LLMProvider, str]:
        """Split a model id into ``(provider, actual_model_name)``.

        ``"provider/model_name"`` selects the named provider; everything after
        the first ``/`` is the model name. A bare ``"model_name"`` is routed to
        the sole registered provider.

        Raises:
            ModelNotFoundError: If the named provider is not registered, or a
                bare name is given while zero or several providers are registered.
        """
        if "/" in model:
            provider_name, model_name = model.split("/", 1)
            if provider_name not in self._providers:
                raise ModelNotFoundError(model)
            return self._providers[provider_name], model_name

        # No "provider/" prefix: only unambiguous when exactly one provider exists.
        if len(self._providers) == 1:
            provider = next(iter(self._providers.values()))
            return provider, model

        raise ModelNotFoundError(model)

    def _get_fallback_model(self, model: str) -> str | None:
        """Return the first configured fallback for ``model``, or None.

        Kept for compatibility; :meth:`chat` walks the whole fallback list instead.
        """
        fallbacks = self._config.fallbacks.get(model, [])
        return fallbacks[0] if fallbacks else None

    def _calculate_cost(self, model: str, usage: TokenUsage) -> float:
        """Compute the cost of ``usage`` from per-1k-token prices in the provider configs.

        The first provider config whose ``models`` contains ``model`` is used;
        missing price keys count as 0, and an unknown model costs 0.0.
        """
        for provider_config in self._config.providers.values():
            if model in provider_config.models:
                model_conf = provider_config.models[model]
                input_cost = usage.prompt_tokens * model_conf.get("cost_per_1k_input", 0) / 1000
                output_cost = (
                    usage.completion_tokens * model_conf.get("cost_per_1k_output", 0) / 1000
                )
                return input_cost + output_cost
        return 0.0

    def get_usage(
        self,
        agent_name: str | None = None,
        start_time=None,
        end_time=None,
    ) -> UsageSummary:
        """Return aggregated usage from the tracker, optionally filtered by agent and time window."""
        return self._usage_tracker.get_usage(
            agent_name=agent_name,
            start_time=start_time,
            end_time=end_time,
        )
|