fischer-agentkit/src/agentkit/llm/gateway.py

623 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""LLM Gateway - 统一 LLM 调用入口"""
import asyncio
import logging
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
from agentkit.llm.config import LLMConfig
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, StreamChunk, TokenUsage
from agentkit.llm.providers.tracker import UsageSummary, UsageTracker
from agentkit.telemetry.tracing import get_tracer, _OTEL_AVAILABLE
from agentkit.telemetry.metrics import llm_token_histogram
logger = logging.getLogger(__name__)
class QuotaExceededError(Exception):
    """Signal that a department has exhausted one of its LLM quotas.

    The exception keeps every field the API layer needs to build a
    structured HTTP 429 body: ``department_id``, ``quota_type``, ``period``,
    ``limit`` and ``current``.
    """

    def __init__(
        self,
        department_id: str,
        quota_type: str,
        period: str,
        limit: Any,
        current: Any,
    ) -> None:
        message = (
            f"Quota exceeded for department {department_id}: "
            f"{quota_type} {period} (limit={limit}, current={current})"
        )
        super().__init__(message)
        # Structured metadata for callers that build error responses.
        self.department_id, self.quota_type, self.period = (
            department_id,
            quota_type,
            period,
        )
        self.limit, self.current = limit, current
class LLMGateway:
    """LLM gateway: provider registry, model-alias resolution, fallback,
    usage tracking and (optional) response caching.

    The module presents this class as the unified entry point for LLM calls.
    """

    def __init__(self, config: LLMConfig | None = None, usage_store: Any = None):
        """Create a gateway.

        Args:
            config: Gateway configuration (model aliases, fallbacks, provider
                cost tables, cache settings). Defaults to ``LLMConfig()``.
            usage_store: Optional store passed to :class:`UsageTracker`. When
                ``None`` the tracker is built with its own default store.
        """
        self._providers: dict[str, LLMProvider] = {}
        # Compare against None instead of relying on truthiness: a store object
        # that defines __len__/__bool__ (e.g. an empty mapping-like store) must
        # not be silently replaced by the tracker's default store.
        self._usage_tracker = (
            UsageTracker(store=usage_store) if usage_store is not None else UsageTracker()
        )
        self._config = config or LLMConfig()
        # Cache (U17 — LiteLLM cache manager; opt-in, disabled by default).
        self._cache_manager: Any = None  # LitellmCacheManager | None
        if self._config.cache and self._config.cache.enabled:
            from agentkit.llm.cache import LitellmCacheConfig, LitellmCacheManager

            litellm_config = LitellmCacheConfig.from_cache_config(self._config.cache)
            self._cache_manager = LitellmCacheManager(litellm_config)
            self._cache_manager.enable()
            logger.info(
                f"LLM cache enabled (LiteLLM, backend={self._config.cache.backend}, "
                f"similarity_threshold={litellm_config.similarity_threshold})"
            )

    def register_provider(self, name: str, provider: LLMProvider) -> None:
        """Register ``provider`` under ``name`` (replacing any previous one)."""
        self._providers[name] = provider
        logger.info(f"LLM provider '{name}' registered")

    @property
    def has_providers(self) -> bool:
        """Return True if at least one LLM provider is registered."""
        return bool(self._providers)

    async def chat(
        self,
        messages: list[dict[str, str]],
        model: str,
        agent_name: str = "",
        task_type: str = "",
        tools: list[dict] | None = None,
        tool_choice: str = "auto",
        timeout: float | None = None,
        user_id: str | None = None,
        department_ids: list[str] | None = None,
        db_path: Path | str | None = None,
        kb_id: str | None = None,
        kb_acl_hash: str | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Send a chat request, resolving model aliases and falling back on failure.

        Args:
            messages: Chat messages forwarded to the provider.
            model: Model name or alias; aliases come from ``config.model_aliases``.
            agent_name: Recorded with the usage entry.
            task_type: Accepted for API compatibility; not read by the gateway.
            tools, tool_choice, timeout: Forwarded in the :class:`LLMRequest`.
            user_id: Used for usage attribution and the per-user cache key.
            department_ids: Departments to enforce quotas on / attribute usage to.
            db_path: Quota database; quotas are only enforced when both this and
                ``department_ids`` are given.
            kb_id: Knowledge base of a RAG request; controls whether caching is
                allowed (fail-closed, see :meth:`_build_cache_params`).
            kb_acl_hash: Included in the cache key to scope cached answers by ACL.
            **kwargs: Extra :class:`LLMRequest` fields (e.g. ``temperature``).

        Returns:
            The first response that has non-blank content or tool calls.

        Raises:
            LLMProviderError: No provider is registered, or every candidate model
                failed or returned an empty response.
            QuotaExceededError: A department quota check failed.
        """
        resolved_model = self._resolve_model_alias(model)
        if not self._providers:
            raise LLMProviderError("", "No provider registered")

        # Quota enforcement only runs when the caller supplies both
        # department_ids and db_path (other call sites pass None).
        if department_ids and db_path:
            await self._enforce_quota(db_path, department_ids, resolved_model)

        # The span wraps everything after it is opened, so it is always ended
        # and any escaping exception is handed to the span's __exit__.
        with self._chat_span(resolved_model) as span:
            start = time.monotonic()

            if self._cache_manager is not None:
                kwargs["cache"] = await self._build_cache_params(
                    resolved_model=resolved_model,
                    messages=messages,
                    tools=tools,
                    tool_choice=tool_choice,
                    user_id=user_id,
                    kb_id=kb_id,
                    kb_acl_hash=kb_acl_hash,
                    request_kwargs=kwargs,
                )

            response = await self._chat_with_fallback(
                resolved_model=resolved_model,
                messages=messages,
                tools=tools,
                tool_choice=tool_choice,
                timeout=timeout,
                agent_name=agent_name,
                user_id=user_id,
                department_ids=department_ids,
                start=start,
                request_kwargs=kwargs,
            )

            latency_ms = (time.monotonic() - start) * 1000
            # U17 — LiteLLM flags cache hits on the response; hits cost nothing.
            is_cache_hit = getattr(response, "cache_hit", False)
            if self._cache_manager is not None:
                self._cache_manager.record_cache_result(is_cache_hit)
            # Treat a missing usage block as zero usage (same as chat_stream).
            usage = response.usage if response.usage is not None else TokenUsage()
            cost = 0.0 if is_cache_hit else self._calculate_cost(response.model, usage)

            await self._record_usage(
                agent_name=agent_name,
                model=response.model,
                usage=usage,
                cost=cost,
                latency_ms=latency_ms,
                user_id=user_id,
                department_ids=department_ids,
            )

            if span is not None:
                span.set_attribute("gen_ai.usage.input_tokens", usage.prompt_tokens)
                span.set_attribute("gen_ai.usage.output_tokens", usage.completion_tokens)
                span.set_attribute("gen_ai.response.model", response.model)
                span.set_attribute("gen_ai.duration.ms", int(latency_ms))
                if self._cache_manager is not None:
                    span.set_attribute("gen_ai.cache.hit", is_cache_hit)
                llm_token_histogram().record(
                    usage.total_tokens,
                    {"gen_ai.request.model": resolved_model},
                )
            return response

    def _chat_span(self, resolved_model: str):
        """Return a context manager for the ``gen_ai.chat`` client span.

        When OpenTelemetry is available and a tracer is configured, this is the
        tracer's ``start_as_current_span`` context manager (which yields the
        span); otherwise it is a no-op context manager that yields ``None``.
        """
        import contextlib

        if not _OTEL_AVAILABLE:
            return contextlib.nullcontext()
        tracer = get_tracer()
        if tracer is None:
            return contextlib.nullcontext()
        from opentelemetry.trace import SpanKind

        return tracer.start_as_current_span(
            "gen_ai.chat",
            kind=SpanKind.CLIENT,
            attributes={
                "gen_ai.system": resolved_model.split("/")[0]
                if "/" in resolved_model
                else "unknown",
                "gen_ai.operation.name": "chat",
                "gen_ai.request.model": resolved_model,
            },
        )

    async def _build_cache_params(
        self,
        *,
        resolved_model: str,
        messages: list[dict[str, str]],
        tools: list[dict] | None,
        tool_choice: str,
        user_id: str | None,
        kb_id: str | None,
        kb_acl_hash: str | None,
        request_kwargs: dict[str, Any],
    ) -> Any:
        """Return the LiteLLM ``cache`` request parameter for this call (U17).

        LiteLLM performs the cache read/write inside ``litellm.acompletion``;
        the gateway only builds a per-user, ACL-scoped cache key (security
        requirements a, b) and passes the cache parameters through to the
        provider.

        KB caching policy (security requirement c): non-RAG requests
        (``kb_id is None``) are cacheable because there is no KB data to
        protect. RAG requests are fail-closed: caching stays disabled unless the
        KB settings explicitly report ``caching_disabled=False``. A missing KB
        or a settings-store error therefore never lets data from a
        cache-disabled KB reach the cache.
        """
        from agentkit.llm.cache import LitellmCacheManager

        kb_caching_disabled = kb_id is not None
        if kb_id is not None:
            try:
                from agentkit.rag_platform.settings import get_settings_store

                settings = await get_settings_store().get_settings(kb_id)
                if settings is not None:
                    kb_caching_disabled = settings.caching_disabled
                # settings is None (KB not found) -> stay disabled (fail-closed).
            except Exception as e:
                # Read failure -> stay disabled (fail-closed).
                logger.warning(f"Failed to read KB cache settings for kb_id={kb_id}: {e}")

        if not self._cache_manager.should_cache(kb_caching_disabled, user_id):
            return LitellmCacheManager.cache_params_for_no_cache()
        cache_key = self._cache_manager.build_cache_key(
            model=resolved_model,
            messages=messages,
            temperature=request_kwargs.get("temperature", 0.7),
            tools=tools,
            tool_choice=tool_choice,
            max_tokens=request_kwargs.get("max_tokens", 2000),
            user_id=user_id,
            kb_acl_hash=kb_acl_hash,
        )
        return LitellmCacheManager.cache_params_for_hit(cache_key)

    async def _chat_with_fallback(
        self,
        *,
        resolved_model: str,
        messages: list[dict[str, str]],
        tools: list[dict] | None,
        tool_choice: str,
        timeout: float | None,
        agent_name: str,
        user_id: str | None,
        department_ids: list[str] | None,
        start: float,
        request_kwargs: dict[str, Any],
    ) -> LLMResponse:
        """Try the primary model and then its fallbacks; return the first usable response.

        A response is usable when it has non-blank content or tool calls.
        Candidates whose provider is not registered are skipped. Provider errors
        and empty responses move on to the next candidate.

        Raises:
            LLMProviderError: The last error seen, or a generic one when no
                candidate could even be resolved.
        """
        last_error: LLMProviderError | None = None
        for model_name in self._get_models_to_try(resolved_model):
            try:
                provider, actual_model = self._resolve_model(model_name)
            except ModelNotFoundError:
                continue
            req = LLMRequest(
                messages=messages,
                model=actual_model,
                tools=tools,
                tool_choice=tool_choice,
                timeout=timeout,
                **request_kwargs,
            )
            try:
                response = await provider.chat(req)
            except LLMProviderError as e:
                last_error = e
                logger.warning(f"Model '{model_name}' failed, trying next: {e}")
                continue

            has_text = response.content is not None and bool(response.content.strip())
            if has_text or response.tool_calls:
                return response

            # Providers may answer 200 OK with an empty body. Bill whatever usage
            # was reported, then treat the response as a failure and fall back.
            if response.usage:
                latency_ms = (time.monotonic() - start) * 1000
                is_cache_hit = getattr(response, "cache_hit", False)
                cost = 0.0 if is_cache_hit else self._calculate_cost(model_name, response.usage)
                await self._record_usage(
                    agent_name=agent_name,
                    model=model_name,
                    usage=response.usage,
                    cost=cost,
                    latency_ms=latency_ms,
                    user_id=user_id,
                    department_ids=department_ids,
                )
            logger.warning(
                f"Model '{model_name}' returned empty content with no tool_calls, "
                f"trying next fallback"
            )
            last_error = LLMProviderError(
                model_name,
                f"Empty response from {model_name} (no content, no tool_calls)",
            )

        raise last_error or LLMProviderError(
            "", f"All models failed for '{resolved_model}'"
        )

    async def chat_stream(
        self,
        messages: list[dict[str, str]],
        model: str,
        agent_name: str = "",
        task_type: str = "",
        tools: list[dict] | None = None,
        tool_choice: str = "auto",
        timeout: float | None = None,
        user_id: str | None = None,
        department_ids: list[str] | None = None,
        db_path: Path | str | None = None,
        **kwargs,
    ):
        """Stream a chat response with fallback support.

        If a model fails before any chunk is yielded, the next fallback model is
        tried. If it fails after chunks were sent, a terminal ``StreamChunk``
        (empty content, ``is_final=True``) is yielded and the stream ends,
        because the output cannot switch models mid-stream.

        A stream that ends without any non-blank content (and without chunks
        carrying ``tool_calls``) is treated as empty. If nothing was yielded
        yet, the next fallback is tried. Otherwise :class:`LLMProviderError` is
        raised to the consumer so it can retry with another model.

        Streaming responses are NOT cached.

        Raises:
            LLMProviderError: No provider registered, an empty stream after
                chunks were sent, or no candidate model could be resolved.
            QuotaExceededError: A department quota check failed.
            Exception: The last error seen when every candidate failed before
                yielding anything.
        """
        resolved_model = self._resolve_model_alias(model)
        if not self._providers:
            raise LLMProviderError("", "No provider registered")

        # Quota enforcement (same rule as chat()).
        if department_ids and db_path:
            await self._enforce_quota(db_path, department_ids, resolved_model)

        last_error: Exception | None = None
        for model_name in self._get_models_to_try(resolved_model):
            try:
                provider, actual_model = self._resolve_model(model_name)
            except ModelNotFoundError:
                continue
            stream_request = LLMRequest(
                messages=messages,
                model=actual_model,
                tools=tools,
                tool_choice=tool_choice,
                timeout=timeout,
                **kwargs,
            )
            chunk_yielded = False
            has_content = False
            has_tool_calls = False
            start = time.monotonic()
            final_usage = None
            final_model = model_name
            try:
                stream_obj = provider.chat_stream(stream_request)
                # Defensive: guard against misconfigured providers (e.g. an
                # AsyncMock in tests, or chat_stream accidentally turned into a
                # plain ``async def``) that return a coroutine instead of an
                # async generator. This replaces the cryptic "'async for'
                # requires an object with __aiter__" error with a clear one.
                if asyncio.iscoroutine(stream_obj):
                    # Close the coroutine we will never await, so Python does
                    # not emit a "coroutine ... was never awaited" warning.
                    stream_obj.close()
                    logger.error(
                        f"Provider '{model_name}'.chat_stream returned a "
                        f"coroutine instead of an async generator. "
                        f"Check that the method is defined as "
                        f"``async def chat_stream(...): ...; yield ...``."
                    )
                    raise TypeError(
                        f"Provider '{model_name}' returned a coroutine "
                        f"from chat_stream() — expected an async "
                        f"generator. This indicates a provider "
                        f"implementation bug."
                    )
                async for chunk in stream_obj:
                    chunk_yielded = True
                    if chunk.content and chunk.content.strip():
                        has_content = True
                    # getattr: tolerate StreamChunk implementations with or
                    # without a tool_calls field.
                    if getattr(chunk, "tool_calls", None):
                        has_tool_calls = True
                    if chunk.usage:
                        final_usage = chunk.usage
                    if chunk.model:
                        final_model = chunk.model
                    yield chunk

                # Track usage after the stream completed.
                latency_ms = (time.monotonic() - start) * 1000
                if final_usage is None:
                    final_usage = TokenUsage()
                cost = self._calculate_cost(final_model, final_usage)
                await self._record_usage(
                    agent_name=agent_name,
                    model=final_model,
                    usage=final_usage,
                    cost=cost,
                    latency_ms=latency_ms,
                    user_id=user_id,
                    department_ids=department_ids,
                )
            except Exception as e:
                last_error = e
                if chunk_yielded:
                    # Can't switch mid-stream; terminate gracefully.
                    logger.error(f"Stream failed after chunks sent for '{model_name}': {e}")
                    yield StreamChunk(
                        content="",
                        model=final_model,
                        usage=None,
                        is_final=True,
                    )
                    return
                # No chunks yet, try the next fallback.
                logger.warning(f"Stream failed for '{model_name}', trying fallback: {e}")
                continue

            if has_content or has_tool_calls:
                return  # Success, done.

            # Empty-stream detection. This runs outside the try above on
            # purpose: inside it, the broad ``except Exception`` would swallow
            # the error instead of letting the caller (e.g. ReActEngine) retry.
            logger.warning(f"Stream from '{model_name}' produced empty content")
            empty_error = LLMProviderError(
                model_name,
                f"Empty stream from {model_name}",
            )
            if chunk_yielded:
                # Chunks already reached the client; switching models here
                # would mix output, so surface the error instead.
                raise empty_error
            last_error = empty_error
            logger.warning(f"Stream failed for '{model_name}', trying fallback: {empty_error}")

        # All models failed.
        raise last_error or LLMProviderError(
            "", f"No provider available for streaming '{resolved_model}'"
        )

    def _get_models_to_try(self, resolved_model: str) -> list[str]:
        """Return ``[primary_model, *fallback_models]`` for the given resolved model."""
        return [resolved_model, *self._config.fallbacks.get(resolved_model, [])]

    def _resolve_model_alias(self, model: str) -> str:
        """Map a configured alias to its model name; return other names unchanged."""
        if model in self._config.model_aliases:
            return self._config.model_aliases[model]
        return model

    def get_provider_name_for_model(self, model: str) -> str | None:
        """Return the provider name for ``model`` (for provider-specific tweaks such
        as cache_control).

        ponytail: only alias resolution plus provider-prefix extraction; no
        internal state is consulted. Upgrade path: ServerConfig explicitly
        declares the provider per model.

        Returns None when the provider cannot be determined (prefix names an
        unregistered provider, or no "/" prefix with multiple providers); callers
        should then fall back to string concatenation.
        """
        resolved = self._resolve_model_alias(model)
        if "/" in resolved:
            provider_name = resolved.split("/", 1)[0]
            if provider_name in self._providers:
                return provider_name
            return None
        # No "/" prefix: only determinable when exactly one provider exists.
        if len(self._providers) == 1:
            return next(iter(self._providers))
        return None

    def _resolve_model(self, model: str) -> tuple[LLMProvider, str]:
        """Resolve ``model`` to ``(provider, actual_model_name)``.

        Accepted formats: ``"provider/model_name"`` or ``"model_name"``. A bare
        name is only matched automatically when exactly one provider is registered.

        Raises:
            ModelNotFoundError: The prefix names an unregistered provider, or a
                bare name is ambiguous / there are no providers.
        """
        if "/" in model:
            provider_name, model_name = model.split("/", 1)
            if provider_name not in self._providers:
                raise ModelNotFoundError(model)
            return self._providers[provider_name], model_name
        if len(self._providers) == 1:
            provider = next(iter(self._providers.values()))
            return provider, model
        raise ModelNotFoundError(model)

    def _get_fallback_model(self, model: str) -> str | None:
        """Return the first configured fallback for ``model``, or None."""
        fallbacks = self._config.fallbacks.get(model, [])
        return fallbacks[0] if fallbacks else None

    def _calculate_cost(self, model: str, usage: TokenUsage) -> float:
        """Compute the cost of ``usage`` from the per-model cost tables.

        Looks ``model`` up in each provider config's ``models`` mapping (first
        match wins) and applies ``cost_per_1k_input`` / ``cost_per_1k_output``
        (missing rates count as 0). Returns 0.0 when no provider lists the model.
        """
        for provider_config in self._config.providers.values():
            if model in provider_config.models:
                model_conf = provider_config.models[model]
                input_cost = usage.prompt_tokens * model_conf.get("cost_per_1k_input", 0) / 1000
                output_cost = (
                    usage.completion_tokens * model_conf.get("cost_per_1k_output", 0) / 1000
                )
                return input_cost + output_cost
        return 0.0

    def get_usage(
        self,
        agent_name: str | None = None,
        start_time=None,
        end_time=None,
    ) -> UsageSummary:
        """Query recorded usage, optionally filtered by agent and time window."""
        return self._usage_tracker.get_usage(
            agent_name=agent_name,
            start_time=start_time,
            end_time=end_time,
        )

    # ------------------------------------------------------------------
    # Quota enforcement helpers (U7)
    # ------------------------------------------------------------------
    async def _record_usage(
        self,
        agent_name: str,
        model: str,
        usage: TokenUsage,
        cost: float,
        latency_ms: float,
        user_id: str | None,
        department_ids: list[str] | None,
    ) -> None:
        """Record a usage event via the async store interface (KTD-6).

        Multi-department attribution (U2): when a user belongs to multiple
        departments, a separate :class:`UsageRecord` is created for each
        department. This ensures ``get_usage(dept_id)`` returns the correct
        total for every department the user belongs to, matching the quota
        check scope (which checks all departments). The cost is attributed in
        full to each department (not split). Callers without departments (e.g.
        API-key users) get a single record with ``department_id=None``.

        TOCTOU (KTD-2): This method is called *after* the LLM response is
        received. Between ``_enforce_quota`` (before the call) and this
        recording, concurrent requests may push usage over the limit. This race
        window is accepted; post-hoc reconciliation (periodic scans for
        over-limit users) handles violations.
        """
        for dept_id in department_ids or [None]:
            await self._usage_tracker.record_async(
                agent_name=agent_name,
                model=model,
                usage=usage,
                cost=cost,
                latency_ms=latency_ms,
                user_id=user_id,
                department_id=dept_id,
            )

    async def _enforce_quota(
        self,
        db_path: Path | str,
        department_ids: list[str],
        resolved_model: str,
    ) -> None:
        """Run all quota checks for the given departments.

        Strictest-wins: if ANY department fails ANY check, raises
        :class:`QuotaExceededError` and the request is rejected.

        Both daily and monthly periods are checked (U2): for each department,
        ``token_limit`` and ``cost_limit`` are evaluated against both ``daily``
        and ``monthly`` windows. Cost is compared in cents.

        Fail-closed (KTD-1): if the usage store is unavailable (Redis degraded),
        raises :class:`UsageStoreUnavailableError`. The caller must translate
        this to HTTP 503.

        TOCTOU (KTD-2): quota is checked *before* the LLM call, and usage is
        recorded *after*. Concurrent requests in this window may exceed the
        limit. This race is accepted; see :meth:`_record_usage` for the
        reconciliation strategy.
        """
        # Lazy import to avoid circular dependency (admin → ... → gateway).
        from agentkit.server.admin.quota_service import get_quota_service

        quota_service = get_quota_service()
        db = Path(db_path)
        for dept_id in department_ids:
            # 1. Model whitelist
            allowed, _reason = await quota_service.is_model_allowed(db, dept_id, resolved_model)
            if not allowed:
                raise QuotaExceededError(
                    department_id=dept_id,
                    quota_type="model_whitelist",
                    period="",
                    limit="",
                    current=resolved_model,
                )
            # 2. Token + cost limits (daily AND monthly). Query usage once per
            # period and reuse the summary for both the token and cost checks.
            for period in ("daily", "monthly"):
                summary = self._get_usage_summary(dept_id, period)
                current_tokens = int(summary.total_tokens)
                current_cost_cents = float(summary.total_cost) * 100.0
                await self._check_quota_value(
                    quota_service, db, dept_id, period, "token_limit", current_tokens
                )
                await self._check_quota_value(
                    quota_service, db, dept_id, period, "cost_limit", current_cost_cents
                )

    def _get_usage_summary(self, department_id: str, period: str) -> UsageSummary:
        """Return ``department_id``'s usage for the current UTC period.

        ``"monthly"`` starts at midnight on the 1st of the current month; any
        other value starts at today's midnight. The window ends now.
        """
        now = datetime.now(timezone.utc)
        if period == "monthly":
            start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
        else:
            start = now.replace(hour=0, minute=0, second=0, microsecond=0)
        return self._usage_tracker.get_usage(
            department_id=department_id, start_time=start, end_time=now
        )

    async def _check_quota_value(
        self,
        quota_service: Any,
        db: Path,
        dept_id: str,
        period: str,
        quota_type: str,
        current: float,
    ) -> None:
        """Check one quota (``token_limit`` or ``cost_limit``) against a
        caller-precomputed ``current`` value.

        Raises:
            QuotaExceededError: The quota service rejects ``current``; ``limit``
                is the configured ``limit_value`` (None if no quota row exists).
        """
        allowed, _reason = await quota_service.check_quota(db, dept_id, quota_type, period, current)
        if not allowed:
            quota = await quota_service.get_quota(db, dept_id, quota_type, period)
            limit = quota["limit_value"] if quota else None
            raise QuotaExceededError(
                department_id=dept_id,
                quota_type=quota_type,
                period=period,
                limit=limit,
                current=current,
            )