fischer-agentkit/src/agentkit/llm/gateway.py

211 lines
7.2 KiB
Python

"""LLM Gateway - 统一 LLM 调用入口"""
import logging
import time
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
from agentkit.llm.config import LLMConfig
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, StreamChunk, TokenUsage
from agentkit.llm.providers.tracker import UsageSummary, UsageTracker
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class LLMGateway:
    """Unified entry point for LLM calls.

    The gateway keeps a registry of named providers, resolves model aliases
    and ``"provider/model"`` identifiers, retries failed chat calls on the
    fallback models listed in the config, and records usage (tokens, cost,
    latency) through a ``UsageTracker``.
    """

    def __init__(self, config: LLMConfig | None = None):
        """Create a gateway with an empty provider registry.

        Args:
            config: Gateway configuration. A default ``LLMConfig()`` is
                created when the argument is omitted or falsy.
        """
        self._providers: dict[str, LLMProvider] = {}
        self._usage_tracker = UsageTracker()
        self._config = config or LLMConfig()

    def register_provider(self, name: str, provider: LLMProvider) -> None:
        """Register ``provider`` under ``name``, replacing any existing entry."""
        self._providers[name] = provider
        logger.info(f"LLM provider '{name}' registered")

    @property
    def has_providers(self) -> bool:
        """Return True if at least one LLM provider is registered."""
        return bool(self._providers)

    async def chat(
        self,
        messages: list[dict[str, str]],
        model: str,
        agent_name: str = "",
        task_type: str = "",
        tools: list[dict] | None = None,
        tool_choice: str = "auto",
        **kwargs,
    ) -> LLMResponse:
        """Send a chat request, resolving aliases and applying fallbacks.

        Args:
            messages: Chat messages forwarded to the provider.
            model: Model alias, ``"provider/model"`` identifier, or bare
                model name.
            agent_name: Label stored with the usage record.
            task_type: Accepted for interface compatibility; not used here.
            tools: Optional tool definitions forwarded to the provider.
            tool_choice: Tool selection mode forwarded to the provider.
            **kwargs: Extra fields passed through to ``LLMRequest``.

        Returns:
            The response of the first model (primary, then fallbacks in
            configured order) that succeeds.

        Raises:
            LLMProviderError: If no provider is registered, the model cannot
                be resolved, or the primary model and every fallback fail.
                When nothing succeeds, the most recent provider error is
                raised (the primary error if no fallback produced one).
        """
        resolved_model = self._resolve_model_alias(model)
        provider, actual_model = self._resolve_provider(resolved_model)
        request = self._build_request(actual_model, messages, tools, tool_choice, kwargs)

        # Latency covers the primary attempt plus any fallback attempts.
        start = time.monotonic()
        try:
            response = await provider.chat(request)
        except LLMProviderError as primary_error:
            response = await self._chat_with_fallbacks(
                resolved_model, primary_error, messages, tools, tool_choice, kwargs
            )
        latency_ms = (time.monotonic() - start) * 1000

        self._record_usage(agent_name, response.model, response.usage, latency_ms)
        return response

    async def _chat_with_fallbacks(
        self,
        resolved_model: str,
        primary_error: LLMProviderError,
        messages: list[dict[str, str]],
        tools: list[dict] | None,
        tool_choice: str,
        extra: dict,
    ) -> LLMResponse:
        """Try each fallback configured for ``resolved_model`` in order.

        Returns the first successful response. If no fallback succeeds (or
        none is configured) the most recent provider error is raised, so the
        caller sees the real failure instead of a generic message.
        """
        last_error = primary_error
        for fb_model in self._config.fallbacks.get(resolved_model, []):
            logger.warning(f"Model '{resolved_model}' failed, falling back to '{fb_model}'")
            try:
                fb_provider, fb_actual = self._resolve_model(fb_model)
            except ModelNotFoundError as e:
                # A misconfigured fallback entry must not abort the chain;
                # skip it and keep the last real provider error.
                logger.warning(f"Fallback model '{fb_model}' could not be resolved: {e}")
                continue
            try:
                fb_request = self._build_request(fb_actual, messages, tools, tool_choice, extra)
                return await fb_provider.chat(fb_request)
            except LLMProviderError as e:
                last_error = e
                logger.warning(f"Fallback model '{fb_model}' also failed: {e}")
        raise last_error

    async def chat_stream(
        self,
        messages: list[dict[str, str]],
        model: str,
        agent_name: str = "",
        task_type: str = "",
        tools: list[dict] | None = None,
        tool_choice: str = "auto",
        **kwargs,
    ):
        """Stream a chat response, yielding the provider's chunks unchanged.

        Arguments have the same meaning as in ``chat``. No fallback is
        attempted for streaming calls. Usage is recorded only after the
        provider stream has been consumed to the end; if the consumer stops
        early or the provider raises, nothing is recorded.

        Raises:
            LLMProviderError: If no provider is registered or the model
                cannot be resolved (raised on first iteration, since this is
                an async generator).
        """
        resolved_model = self._resolve_model_alias(model)
        provider, actual_model = self._resolve_provider(resolved_model)
        request = self._build_request(actual_model, messages, tools, tool_choice, kwargs)

        start = time.monotonic()
        final_usage = None
        # Default to the name the provider was actually called with, so the
        # usage record and cost lookup line up with what chat() records via
        # response.model (presumably cost tables are keyed by bare model
        # names, not "provider/model" - TODO confirm against LLMConfig).
        final_model = actual_model
        async for chunk in provider.chat_stream(request):
            # Later chunks override earlier ones; the last reported usage
            # and model are the ones recorded.
            if chunk.usage:
                final_usage = chunk.usage
            if chunk.model:
                final_model = chunk.model
            yield chunk

        # Track usage after the stream completes.
        latency_ms = (time.monotonic() - start) * 1000
        if final_usage is None:
            final_usage = TokenUsage()
        self._record_usage(agent_name, final_model, final_usage, latency_ms)

    def _resolve_provider(self, resolved_model: str) -> tuple[LLMProvider, str]:
        """Resolve an alias-free model name for a public call.

        Raises:
            LLMProviderError: If no provider is registered, or wrapping the
                ``ModelNotFoundError`` raised by ``_resolve_model``.
        """
        if not self._providers:
            raise LLMProviderError("", "No provider registered")
        try:
            return self._resolve_model(resolved_model)
        except ModelNotFoundError as e:
            raise LLMProviderError("", str(e)) from e

    def _build_request(
        self,
        model: str,
        messages: list[dict[str, str]],
        tools: list[dict] | None,
        tool_choice: str,
        extra: dict,
    ) -> LLMRequest:
        """Build the ``LLMRequest`` sent to a provider for ``model``."""
        return LLMRequest(
            messages=messages,
            model=model,
            tools=tools,
            tool_choice=tool_choice,
            **extra,
        )

    def _record_usage(
        self,
        agent_name: str,
        model: str,
        usage: TokenUsage,
        latency_ms: float,
    ) -> None:
        """Compute the cost of a finished call and hand it to the tracker."""
        cost = self._calculate_cost(model, usage)
        self._usage_tracker.record(
            agent_name=agent_name,
            model=model,
            usage=usage,
            cost=cost,
            latency_ms=latency_ms,
        )

    def _resolve_model_alias(self, model: str) -> str:
        """Return the configured target of alias ``model``, or ``model`` itself."""
        return self._config.model_aliases.get(model, model)

    def _resolve_model(self, model: str) -> tuple[LLMProvider, str]:
        """Resolve ``model`` into ``(provider, actual_model_name)``.

        Accepted forms are ``"provider/model_name"`` and a bare
        ``"model_name"``. A bare name is accepted only when exactly one
        provider is registered, in which case that provider is used.

        Raises:
            ModelNotFoundError: If the named provider is not registered, or
                a bare name is given while zero or several providers exist.
        """
        if "/" in model:
            # Split on the first "/" only; the model name may contain more.
            provider_name, model_name = model.split("/", 1)
            if provider_name not in self._providers:
                raise ModelNotFoundError(model)
            return self._providers[provider_name], model_name

        # No "provider/" prefix: unambiguous only with a single provider.
        if len(self._providers) == 1:
            provider = next(iter(self._providers.values()))
            return provider, model
        raise ModelNotFoundError(model)

    def _get_fallback_model(self, model: str) -> str | None:
        """Return the first fallback configured for ``model``, or None."""
        fallbacks = self._config.fallbacks.get(model, [])
        return fallbacks[0] if fallbacks else None

    def _calculate_cost(self, model: str, usage: TokenUsage) -> float:
        """Compute the cost of ``usage`` for ``model``.

        Scans the ``models`` mapping of every provider config; the first
        entry whose key equals ``model`` supplies the per-1k-token rates
        (``cost_per_1k_input`` / ``cost_per_1k_output``, each defaulting to
        0). Returns 0.0 when no provider config lists the model.
        """
        for provider_config in self._config.providers.values():
            if model in provider_config.models:
                model_conf = provider_config.models[model]
                input_cost = usage.prompt_tokens * model_conf.get("cost_per_1k_input", 0) / 1000
                output_cost = usage.completion_tokens * model_conf.get("cost_per_1k_output", 0) / 1000
                return input_cost + output_cost
        return 0.0

    def get_usage(
        self,
        agent_name: str | None = None,
        start_time=None,
        end_time=None,
    ) -> UsageSummary:
        """Query recorded usage from the tracker.

        All arguments are forwarded unchanged to ``UsageTracker.get_usage``.
        NOTE(review): the expected type of ``start_time`` / ``end_time`` is
        not visible here - confirm against the tracker.
        """
        return self._usage_tracker.get_usage(
            agent_name=agent_name,
            start_time=start_time,
            end_time=end_time,
        )