# geo/backend/app/middleware/llm_metrics.py
#
# 103 lines
# 3.0 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""LLM调用指标包装"""
import time
from typing import Optional
from app.middleware.prometheus_metrics import (
LLM_REQUESTS_TOTAL,
LLM_REQUEST_DURATION_SECONDS,
LLM_TOKENS_TOTAL,
LLM_COST_ESTIMATED,
)
# Estimated LLM cost in USD per token, keyed by (provider, model).
# Each entry carries separate rates for prompt and completion tokens.
# Rates are written in scientific notation: e.g. 5e-6 USD/token equals
# 5 USD per million tokens.
# NOTE(review): these are hard-coded list prices and may drift from the
# providers' current pricing — verify periodically.
LLM_COST_PER_TOKEN = {
    # OpenAI
    ("openai", "gpt-4o"): {"prompt": 5e-6, "completion": 1.5e-5},
    ("openai", "gpt-4o-mini"): {"prompt": 1.5e-7, "completion": 6e-7},
    ("openai", "gpt-4-turbo"): {"prompt": 1e-5, "completion": 3e-5},
    # DeepSeek
    ("deepseek", "deepseek-chat"): {"prompt": 1.4e-7, "completion": 2.8e-7},
    ("deepseek", "deepseek-coder"): {"prompt": 1.4e-7, "completion": 2.8e-7},
}
class LLMMetricsWrapper:
    """Record Prometheus metrics for calls to one LLM provider/model pair.

    Every metric emitted by an instance carries the ``provider`` and
    ``model`` labels given at construction time.
    """

    def __init__(self, provider: str, model: str):
        # Used verbatim as Prometheus label values and as the lookup key
        # into LLM_COST_PER_TOKEN.
        self.provider = provider
        self.model = model

    def record_request(
        self,
        status: str,
        duration: float,
        prompt_tokens: Optional[int] = None,
        completion_tokens: Optional[int] = None,
    ):
        """Record the metrics for one finished LLM request.

        Args:
            status: Value for the ``status`` label of LLM_REQUESTS_TOTAL.
            duration: Value observed on LLM_REQUEST_DURATION_SECONDS
                (presumably seconds, per the metric name — confirm callers).
            prompt_tokens: Prompt token count; ``None`` skips the prompt
                token counter.
            completion_tokens: Completion token count; ``None`` skips the
                completion token counter.

        The estimated-cost counter is incremented only when the estimate
        is strictly positive (unknown models estimate to 0.0).
        """
        common = {"provider": self.provider, "model": self.model}

        # Request count first, then latency.
        LLM_REQUESTS_TOTAL.labels(**common, status=status).inc()
        LLM_REQUEST_DURATION_SECONDS.labels(**common).observe(duration)

        # Token usage: one series per token type; None means "not reported".
        for token_type, count in (
            ("prompt", prompt_tokens),
            ("completion", completion_tokens),
        ):
            if count is None:
                continue
            LLM_TOKENS_TOTAL.labels(**common, token_type=token_type).inc(count)

        estimated = self._estimate_cost(prompt_tokens, completion_tokens)
        if estimated > 0:
            LLM_COST_ESTIMATED.labels(**common).inc(estimated)

    def _estimate_cost(
        self,
        prompt_tokens: Optional[int],
        completion_tokens: Optional[int]
    ) -> float:
        """Return the estimated USD cost of a single request.

        Rates come from LLM_COST_PER_TOKEN keyed by ``(provider, model)``;
        a pair that is not listed costs 0.0. Token counts that are ``None``
        or 0 contribute nothing.
        """
        rates = LLM_COST_PER_TOKEN.get((self.provider, self.model))
        if not rates:
            return 0.0

        total = 0.0
        for kind, count in (("prompt", prompt_tokens), ("completion", completion_tokens)):
            if count:
                total += count * rates[kind]
        return total
# Process-wide cache of metric wrappers, one per (provider, model) pair.
_llm_metrics_cache: dict[tuple[str, str], LLMMetricsWrapper] = {}


def get_llm_metrics(provider: str, model: str) -> LLMMetricsWrapper:
    """Return the cached LLMMetricsWrapper for ``provider``/``model``.

    A wrapper is created the first time a pair is requested and reused on
    every later call with the same pair.

    The cache key is the ``(provider, model)`` tuple rather than a joined
    string. With a ``f"{provider}:{model}"`` key, names containing ``":"``
    could collide: ``("a:b", "c")`` and ``("a", "b:c")`` would share one
    wrapper and emit metrics under the wrong labels.
    """
    key = (provider, model)
    wrapper = _llm_metrics_cache.get(key)
    if wrapper is None:
        wrapper = LLMMetricsWrapper(provider, model)
        _llm_metrics_cache[key] = wrapper
    return wrapper