# Source: geo/backend/app/services/smart_router.py (147 lines, 5.7 KiB, Python)

from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    # Needed only for type annotations; TYPE_CHECKING is False at runtime,
    # so this import is never executed when the module is loaded.
    from app.services.api_key_manager import APIKeyManager
class CostTier(str, Enum):
    """Cost bucket an engine is assigned to.

    Members are also ``str`` instances, so they compare equal to their
    plain string values (``CostTier.FREE == "free"``).
    """

    FREE = "free"
    LOW_COST = "low_cost"
    MID_COST = "mid_cost"
    HIGH_COST = "high_cost"
@dataclass
class EngineCostProfile:
    """Static cost and capability description of a single engine."""

    # Engine identifier; the same string is used as the lookup key for profiles.
    engine_type: str
    # Cost bucket the engine is grouped under when engines are selected.
    cost_tier: CostTier
    # Price per one million input tokens (currency is not stated in this file).
    input_price_per_million: float
    # Price per one million output tokens (same unit as the input price).
    output_price_per_million: float
    # Whether the engine offers a free tier.
    has_free_tier: bool
    # Whether the engine can only be used with a key supplied by the user.
    requires_own_key: bool
    # Integer relevance score; NOTE(review): scale and meaning are not defined here.
    geo_relevance: int
    # Whether the engine counts as domestic for the ``prefer_domestic`` ordering.
    domestic: bool
# Static profile table, keyed by each profile's own ``engine_type``.
# Positional columns: engine_type, cost_tier, input price per million tokens,
# output price per million tokens, has_free_tier, requires_own_key,
# geo_relevance, domestic.
# NOTE(review): cost_tier does not follow the listed prices one-to-one
# (e.g. kimi is LOW_COST at 12.0/12.0 while chatgpt is HIGH_COST at 1.0/4.0);
# confirm the figures against the providers' current pricing.
ENGINE_COST_PROFILES: dict[str, EngineCostProfile] = {
    entry.engine_type: entry
    for entry in (
        EngineCostProfile("deepseek", CostTier.FREE, 0.25, 6.0, True, False, 4, True),
        EngineCostProfile("qwen", CostTier.FREE, 0.3, 0.6, True, False, 4, True),
        EngineCostProfile("wenxin", CostTier.FREE, 0.012, 0.012, True, False, 3, True),
        EngineCostProfile("kimi", CostTier.LOW_COST, 12.0, 12.0, True, False, 4, True),
        EngineCostProfile("doubao", CostTier.LOW_COST, 0.5, 0.9, True, False, 3, True),
        EngineCostProfile("gemini", CostTier.LOW_COST, 0.5, 2.0, True, False, 4, False),
        EngineCostProfile("yuanbao", CostTier.MID_COST, 0.8, 2.0, True, False, 3, True),
        EngineCostProfile("chatgpt", CostTier.HIGH_COST, 1.0, 4.0, False, True, 5, False),
        EngineCostProfile("perplexity", CostTier.HIGH_COST, 35.0, 35.0, False, True, 5, False),
    )
}

# Order in which tiers are walked when engines are selected.
_TIER_ORDER = [
    CostTier.FREE,
    CostTier.LOW_COST,
    CostTier.MID_COST,
    CostTier.HIGH_COST,
]
def _get_profile(engine: str) -> EngineCostProfile | None:
    """Return the cost profile registered for *engine*, or ``None`` if unknown."""
    return ENGINE_COST_PROFILES.get(engine)
class SmartRouter:
    """Choose which engines to use, walking cost tiers in ``_TIER_ORDER``.

    The router works from three inputs: the engines it may consider
    (``available_engines``), the engines the user has singled out
    (``user_engines``), and an optional key manager used to drop engines
    for which no key is currently available.
    """

    def __init__(
        self,
        available_engines: list[str] | None = None,
        user_engines: list[str] | None = None,
        key_manager: APIKeyManager | None = None,
    ):
        """Create a router.

        Args:
            available_engines: Engines the router may pick from. ``None`` or
                an empty list falls back to every engine in
                ``ENGINE_COST_PROFILES``.
            user_engines: Engines to rank ahead of the others inside a tier.
            key_manager: Object exposing ``get_any_available_key(engine)``.
                When absent, no key-based filtering is performed.
        """
        # ``or`` (not ``is None``): an empty list also selects the default set.
        self.available_engines = available_engines or list(ENGINE_COST_PROFILES.keys())
        self._user_engine_set = set(user_engines or [])
        self._key_manager = key_manager

    @property
    def user_engines(self) -> list[str]:
        """User-designated engines as a sorted list.

        Sorted so the result is stable between runs; iterating the
        underlying set directly would expose hash-dependent ordering.
        """
        return sorted(self._user_engine_set)

    def set_key_manager(self, key_manager: APIKeyManager) -> None:
        """Install or replace the key manager used for availability checks."""
        self._key_manager = key_manager

    def _filter_by_available_keys(self, engines: list[str]) -> list[str]:
        """Keep the engines for which the key manager returns a truthy key.

        Without a key manager the input list is returned unchanged (the same
        object, not a copy).
        """
        if not self._key_manager:
            return engines
        return [
            engine
            for engine in engines
            if self._key_manager.get_any_available_key(engine)
        ]

    def get_available_engines(self) -> list[str]:
        """Return every profiled engine that currently has a usable key.

        This starts from ``ENGINE_COST_PROFILES``, not from
        ``self.available_engines``.
        """
        return self._filter_by_available_keys(list(ENGINE_COST_PROFILES.keys()))

    def select_engines(self, max_engines: int = 5, prefer_domestic: bool = True) -> list[str]:
        """Pick up to *max_engines* engines, walking tiers in ``_TIER_ORDER``.

        Within a tier, engines are ranked domestic-first (only when
        *prefer_domestic* is true), then user-designated-first; ties keep
        their order from ``self.available_engines``. Engines without a
        profile are ignored, duplicates are selected once, and, when a key
        manager is set, engines without an available key are skipped.

        Returns:
            The selected engine names, at most *max_engines* long (empty
            when *max_engines* is zero or negative).
        """
        profiles: dict[str, EngineCostProfile] = {}
        tiers: dict[CostTier, list[str]] = {tier: [] for tier in CostTier}
        for engine in self.available_engines:
            profile = _get_profile(engine)
            if profile is None:
                continue
            profiles[engine] = profile
            tiers[profile.cost_tier].append(engine)

        def rank(engine: str) -> tuple[bool, bool]:
            # False sorts before True: domestic first, then user engines first.
            foreign = prefer_domestic and not profiles[engine].domestic
            return (foreign, engine not in self._user_engine_set)

        selected: list[str] = []
        for tier in _TIER_ORDER:
            # sorted() is stable, so equally ranked engines keep input order.
            for engine in sorted(tiers[tier], key=rank):
                if len(selected) >= max_engines:
                    return selected
                if engine in selected:
                    # Skip duplicates before touching the key manager.
                    continue
                if self._key_manager and not self._key_manager.get_any_available_key(engine):
                    continue
                selected.append(engine)
        return selected

    def get_cost_estimate(
        self,
        engines: list[str],
        estimated_input_tokens: int = 500,
        estimated_output_tokens: int = 1000,
    ) -> dict:
        """Estimate the cost of one request against each of *engines*.

        Cost per engine is ``tokens / 1_000_000 * price_per_million`` summed
        over input and output. Engines without a profile are skipped.

        Returns:
            ``{"total_cost": float, "per_engine": {engine: {"cost", "tier"}}}``
            with costs rounded to six decimals. An engine listed more than
            once appears once in ``per_engine`` but is added to
            ``total_cost`` for every occurrence.
        """
        total_cost = 0.0
        details: dict[str, dict] = {}
        for engine in engines:
            profile = _get_profile(engine)
            if profile is None:
                continue
            cost = (
                estimated_input_tokens / 1_000_000 * profile.input_price_per_million
                + estimated_output_tokens / 1_000_000 * profile.output_price_per_million
            )
            details[engine] = {"cost": round(cost, 6), "tier": profile.cost_tier.value}
            total_cost += cost
        return {"total_cost": round(total_cost, 6), "per_engine": details}

    def _engines_by_tier(self, tier: CostTier) -> list[str]:
        """Return available engines in *tier* that pass the key filter."""
        tier_engines = [
            engine
            for engine in self.available_engines
            if (profile := _get_profile(engine)) is not None and profile.cost_tier == tier
        ]
        return self._filter_by_available_keys(tier_engines)

    def get_engines_by_cost_tier(self, tier: CostTier) -> list[str]:
        """Public wrapper around ``_engines_by_tier``."""
        return self._engines_by_tier(tier)

    def _engines_by_requires_key(self) -> list[str]:
        """Return available engines whose profile has ``requires_own_key`` set.

        No key-availability filtering is applied here.
        """
        return [
            engine
            for engine in self.available_engines
            if (profile := _get_profile(engine)) is not None and profile.requires_own_key
        ]

    def get_recommended_combination(self) -> dict:
        """Return preset engine bundles.

        Keys of the result:
            ``basic``: up to three FREE-tier engines (key-filtered).
            ``standard``: up to five engines from FREE then LOW_COST
                (key-filtered).
            ``premium``: a copy of ``self.available_engines`` (not filtered
                by key availability).
            ``user_premium``: engines that require the user's own key (not
                filtered by key availability).
        """
        free_engines = self._engines_by_tier(CostTier.FREE)
        low_cost = self._engines_by_tier(CostTier.LOW_COST)
        user_only = self._engines_by_requires_key()
        return {
            "basic": free_engines[:3],
            "standard": (free_engines + low_cost)[:5],
            "premium": list(self.available_engines),
            "user_premium": user_only,
        }