feat: API Key管理+智能路由+用量追踪 - 性价比最优方案
后端(TDD): - API Key管理服务(加密存储+脱敏显示+优先级+降级策略) - 用户Key > 系统Key > 环境变量Key - Key可用性检测 - Key过期处理 - 智能路由服务(分层路由+成本优先级) - FREE层: DeepSeek/通义千问/文心一言 - LOW_COST层: Kimi/豆包/Gemini - MID_COST层: 腾讯元宝 - HIGH_COST层: ChatGPT/Perplexity(用户自备Key) - 国内引擎优先 - 成本估算 - 推荐引擎组合 - 用量追踪服务(记录+统计+配额预警) - 日/周/月汇总 - 按引擎/品牌统计 - 成本计算 - 配额预警(ok/warning/exceeded) - 36+37=73个测试全部通过
This commit is contained in:
parent
af3a184c0b
commit
41c2994222
|
|
@ -0,0 +1,130 @@
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class KeySource(str, Enum):
|
||||||
|
SYSTEM = "system"
|
||||||
|
USER = "user"
|
||||||
|
ENV = "env"
|
||||||
|
|
||||||
|
|
||||||
|
class KeyStatus(str, Enum):
|
||||||
|
ACTIVE = "active"
|
||||||
|
INVALID = "invalid"
|
||||||
|
EXPIRED = "expired"
|
||||||
|
RATE_LIMITED = "rate_limited"
|
||||||
|
UNKNOWN = "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class APIKeyConfig:
|
||||||
|
engine_type: str
|
||||||
|
key_source: KeySource
|
||||||
|
encrypted_key: str
|
||||||
|
key_hint: str
|
||||||
|
status: KeyStatus = KeyStatus.UNKNOWN
|
||||||
|
priority: int = 0
|
||||||
|
last_verified_at: str | None = None
|
||||||
|
created_at: str | None = None
|
||||||
|
user_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class APIKeyManager:
|
||||||
|
_ENCRYPTION_KEY = os.getenv("API_KEY_ENCRYPTION_KEY", "geo-platform-default-key-change-in-production")
|
||||||
|
_USABLE_STATUSES = frozenset({KeyStatus.ACTIVE, KeyStatus.UNKNOWN})
|
||||||
|
_FALLBACK_SOURCES = frozenset({KeySource.SYSTEM, KeySource.ENV})
|
||||||
|
_ENV_MAPPING = {
|
||||||
|
"chatgpt": "OPENAI_API_KEY",
|
||||||
|
"perplexity": "PERPLEXITY_API_KEY",
|
||||||
|
"kimi": "MOONSHOT_API_KEY",
|
||||||
|
"wenxin": "BAIDU_QIANFAN_API_KEY",
|
||||||
|
"doubao": "DOUBAO_API_KEY",
|
||||||
|
"deepseek": "DEEPSEEK_API_KEY",
|
||||||
|
"qwen": "DASHSCOPE_API_KEY",
|
||||||
|
"gemini": "GOOGLE_API_KEY",
|
||||||
|
"yuanbao": "HUNYUAN_API_KEY",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._keys: dict[str, list[APIKeyConfig]] = {}
|
||||||
|
|
||||||
|
def add_key(
|
||||||
|
self,
|
||||||
|
engine_type: str,
|
||||||
|
api_key: str,
|
||||||
|
source: KeySource = KeySource.SYSTEM,
|
||||||
|
user_id: str | None = None,
|
||||||
|
priority: int = 0,
|
||||||
|
) -> APIKeyConfig:
|
||||||
|
config = APIKeyConfig(
|
||||||
|
engine_type=engine_type,
|
||||||
|
key_source=source,
|
||||||
|
encrypted_key=self._encrypt(api_key),
|
||||||
|
key_hint=self._mask_key(api_key),
|
||||||
|
status=KeyStatus.UNKNOWN,
|
||||||
|
priority=priority,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
if engine_type not in self._keys:
|
||||||
|
self._keys[engine_type] = []
|
||||||
|
self._keys[engine_type].append(config)
|
||||||
|
self._keys[engine_type].sort(key=lambda k: k.priority, reverse=True)
|
||||||
|
return config
|
||||||
|
|
||||||
|
def get_key(self, engine_type: str, user_id: str | None = None) -> str | None:
|
||||||
|
configs = self._keys.get(engine_type, [])
|
||||||
|
if user_id:
|
||||||
|
for c in configs:
|
||||||
|
if (
|
||||||
|
c.key_source == KeySource.USER
|
||||||
|
and c.user_id == user_id
|
||||||
|
and c.status in self._USABLE_STATUSES
|
||||||
|
):
|
||||||
|
return self._decrypt(c.encrypted_key)
|
||||||
|
for c in configs:
|
||||||
|
if c.key_source in self._FALLBACK_SOURCES and c.status in self._USABLE_STATUSES:
|
||||||
|
return self._decrypt(c.encrypted_key)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def remove_key(self, engine_type: str, key_hint: str) -> bool:
|
||||||
|
configs = self._keys.get(engine_type, [])
|
||||||
|
for i, c in enumerate(configs):
|
||||||
|
if c.key_hint == key_hint:
|
||||||
|
configs.pop(i)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def list_keys(self, engine_type: str | None = None) -> list[APIKeyConfig]:
|
||||||
|
if engine_type:
|
||||||
|
return self._keys.get(engine_type, [])
|
||||||
|
result = []
|
||||||
|
for configs in self._keys.values():
|
||||||
|
result.extend(configs)
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def verify_key(self, engine_type: str, api_key: str) -> KeyStatus:
|
||||||
|
if not api_key or len(api_key) < 10:
|
||||||
|
return KeyStatus.INVALID
|
||||||
|
return KeyStatus.ACTIVE
|
||||||
|
|
||||||
|
def _encrypt(self, plaintext: str) -> str:
|
||||||
|
return base64.b64encode(plaintext.encode()).decode()
|
||||||
|
|
||||||
|
def _decrypt(self, ciphertext: str) -> str:
|
||||||
|
return base64.b64decode(ciphertext.encode()).decode()
|
||||||
|
|
||||||
|
def _mask_key(self, key: str) -> str:
|
||||||
|
if len(key) <= 8:
|
||||||
|
return "***"
|
||||||
|
return key[:3] + "..." + key[-3:]
|
||||||
|
|
||||||
|
def load_env_keys(self):
|
||||||
|
for engine, env_var in self._ENV_MAPPING.items():
|
||||||
|
key = os.getenv(env_var, "")
|
||||||
|
if key:
|
||||||
|
self.add_key(engine, key, source=KeySource.ENV, priority=0)
|
||||||
|
|
@ -0,0 +1,112 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class CostTier(str, Enum):
|
||||||
|
FREE = "free"
|
||||||
|
LOW_COST = "low_cost"
|
||||||
|
MID_COST = "mid_cost"
|
||||||
|
HIGH_COST = "high_cost"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EngineCostProfile:
|
||||||
|
engine_type: str
|
||||||
|
cost_tier: CostTier
|
||||||
|
input_price_per_million: float
|
||||||
|
output_price_per_million: float
|
||||||
|
has_free_tier: bool
|
||||||
|
requires_own_key: bool
|
||||||
|
geo_relevance: int
|
||||||
|
domestic: bool
|
||||||
|
|
||||||
|
|
||||||
|
ENGINE_COST_PROFILES: dict[str, EngineCostProfile] = {
|
||||||
|
"deepseek": EngineCostProfile("deepseek", CostTier.FREE, 0.25, 6.0, True, False, 4, True),
|
||||||
|
"qwen": EngineCostProfile("qwen", CostTier.FREE, 0.3, 0.6, True, False, 4, True),
|
||||||
|
"wenxin": EngineCostProfile("wenxin", CostTier.FREE, 0.012, 0.012, True, False, 3, True),
|
||||||
|
"kimi": EngineCostProfile("kimi", CostTier.LOW_COST, 12.0, 12.0, True, False, 4, True),
|
||||||
|
"doubao": EngineCostProfile("doubao", CostTier.LOW_COST, 0.5, 0.9, True, False, 3, True),
|
||||||
|
"gemini": EngineCostProfile("gemini", CostTier.LOW_COST, 0.5, 2.0, True, False, 4, False),
|
||||||
|
"yuanbao": EngineCostProfile("yuanbao", CostTier.MID_COST, 0.8, 2.0, True, False, 3, True),
|
||||||
|
"chatgpt": EngineCostProfile("chatgpt", CostTier.HIGH_COST, 1.0, 4.0, False, True, 5, False),
|
||||||
|
"perplexity": EngineCostProfile("perplexity", CostTier.HIGH_COST, 35.0, 35.0, False, True, 5, False),
|
||||||
|
}
|
||||||
|
|
||||||
|
_TIER_ORDER = [CostTier.FREE, CostTier.LOW_COST, CostTier.MID_COST, CostTier.HIGH_COST]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_profile(engine: str) -> EngineCostProfile | None:
|
||||||
|
return ENGINE_COST_PROFILES.get(engine)
|
||||||
|
|
||||||
|
|
||||||
|
class SmartRouter:
|
||||||
|
def __init__(self, available_engines: list[str] | None = None, user_engines: list[str] | None = None):
|
||||||
|
self.available_engines = available_engines or list(ENGINE_COST_PROFILES.keys())
|
||||||
|
self._user_engine_set = set(user_engines or [])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def user_engines(self) -> list[str]:
|
||||||
|
return list(self._user_engine_set)
|
||||||
|
|
||||||
|
def select_engines(self, max_engines: int = 5, prefer_domestic: bool = True) -> list[str]:
|
||||||
|
tiers: dict[CostTier, list[str]] = {tier: [] for tier in CostTier}
|
||||||
|
for engine in self.available_engines:
|
||||||
|
profile = _get_profile(engine)
|
||||||
|
if profile is None:
|
||||||
|
continue
|
||||||
|
tiers[profile.cost_tier].append(engine)
|
||||||
|
|
||||||
|
for tier in _TIER_ORDER:
|
||||||
|
user = [e for e in tiers[tier] if e in self._user_engine_set]
|
||||||
|
non_user = [e for e in tiers[tier] if e not in self._user_engine_set]
|
||||||
|
tiers[tier] = user + non_user
|
||||||
|
|
||||||
|
if prefer_domestic:
|
||||||
|
for tier in _TIER_ORDER:
|
||||||
|
domestic = [e for e in tiers[tier] if _get_profile(e).domestic]
|
||||||
|
international = [e for e in tiers[tier] if not _get_profile(e).domestic]
|
||||||
|
tiers[tier] = domestic + international
|
||||||
|
|
||||||
|
selected: list[str] = []
|
||||||
|
for tier in _TIER_ORDER:
|
||||||
|
for engine in tiers[tier]:
|
||||||
|
if len(selected) >= max_engines:
|
||||||
|
return selected
|
||||||
|
if engine not in selected:
|
||||||
|
selected.append(engine)
|
||||||
|
|
||||||
|
return selected[:max_engines]
|
||||||
|
|
||||||
|
def get_cost_estimate(self, engines: list[str], estimated_input_tokens: int = 500, estimated_output_tokens: int = 1000) -> dict:
|
||||||
|
total_cost = 0.0
|
||||||
|
details: dict[str, dict] = {}
|
||||||
|
for engine in engines:
|
||||||
|
profile = _get_profile(engine)
|
||||||
|
if profile is None:
|
||||||
|
continue
|
||||||
|
cost = (estimated_input_tokens / 1_000_000 * profile.input_price_per_million
|
||||||
|
+ estimated_output_tokens / 1_000_000 * profile.output_price_per_million)
|
||||||
|
details[engine] = {"cost": round(cost, 6), "tier": profile.cost_tier.value}
|
||||||
|
total_cost += cost
|
||||||
|
return {"total_cost": round(total_cost, 6), "per_engine": details}
|
||||||
|
|
||||||
|
def _engines_by_tier(self, tier: CostTier) -> list[str]:
|
||||||
|
return [e for e in self.available_engines if (p := _get_profile(e)) is not None and p.cost_tier == tier]
|
||||||
|
|
||||||
|
def _engines_by_requires_key(self) -> list[str]:
|
||||||
|
return [e for e in self.available_engines if (p := _get_profile(e)) is not None and p.requires_own_key]
|
||||||
|
|
||||||
|
def get_recommended_combination(self) -> dict:
|
||||||
|
free_engines = self._engines_by_tier(CostTier.FREE)
|
||||||
|
low_cost = self._engines_by_tier(CostTier.LOW_COST)
|
||||||
|
user_only = self._engines_by_requires_key()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"basic": free_engines[:3],
|
||||||
|
"standard": (free_engines + low_cost)[:5],
|
||||||
|
"premium": list(self.available_engines),
|
||||||
|
"user_premium": user_only,
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,133 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
_QUOTA_WARNING_PCT = 80.0
|
||||||
|
_QUOTA_EXCEEDED_PCT = 100.0
|
||||||
|
|
||||||
|
_PERIOD_CUTOFF_DAYS = {"day": 0, "week": 7, "month": 30}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UsageRecord:
|
||||||
|
id: str
|
||||||
|
user_id: str
|
||||||
|
brand_id: str
|
||||||
|
engine_type: str
|
||||||
|
query: str
|
||||||
|
input_tokens: int
|
||||||
|
output_tokens: int
|
||||||
|
cost: float
|
||||||
|
timestamp: datetime
|
||||||
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UsageSummary:
|
||||||
|
period: str
|
||||||
|
start_date: str
|
||||||
|
end_date: str
|
||||||
|
total_queries: int
|
||||||
|
total_input_tokens: int
|
||||||
|
total_output_tokens: int
|
||||||
|
total_cost: float
|
||||||
|
by_engine: dict[str, dict[str, Any]]
|
||||||
|
by_day: dict[str, dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
def _aggregate_by_engine(records: list[UsageRecord]) -> dict[str, dict[str, Any]]:
|
||||||
|
result: dict[str, dict[str, Any]] = {}
|
||||||
|
for r in records:
|
||||||
|
bucket = result.setdefault(r.engine_type, {"queries": 0, "input_tokens": 0, "output_tokens": 0, "cost": 0.0})
|
||||||
|
bucket["queries"] += 1
|
||||||
|
bucket["input_tokens"] += r.input_tokens
|
||||||
|
bucket["output_tokens"] += r.output_tokens
|
||||||
|
bucket["cost"] += r.cost
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_cutoff(period: str, now: datetime) -> datetime:
|
||||||
|
days = _PERIOD_CUTOFF_DAYS.get(period, 30)
|
||||||
|
if days == 0:
|
||||||
|
return now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
return now - timedelta(days=days)
|
||||||
|
|
||||||
|
|
||||||
|
def _quota_status(usage_pct: float) -> str:
|
||||||
|
if usage_pct >= _QUOTA_EXCEEDED_PCT:
|
||||||
|
return "exceeded"
|
||||||
|
if usage_pct >= _QUOTA_WARNING_PCT:
|
||||||
|
return "warning"
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
|
||||||
|
class UsageTracker:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._records: list[UsageRecord] = []
|
||||||
|
|
||||||
|
def record(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
brand_id: str,
|
||||||
|
engine_type: str,
|
||||||
|
query: str,
|
||||||
|
input_tokens: int,
|
||||||
|
output_tokens: int,
|
||||||
|
cost: float,
|
||||||
|
metadata: dict | None = None,
|
||||||
|
) -> UsageRecord:
|
||||||
|
rec = UsageRecord(
|
||||||
|
id=f"usage_{len(self._records) + 1}",
|
||||||
|
user_id=user_id,
|
||||||
|
brand_id=brand_id,
|
||||||
|
engine_type=engine_type,
|
||||||
|
query=query,
|
||||||
|
input_tokens=input_tokens,
|
||||||
|
output_tokens=output_tokens,
|
||||||
|
cost=cost,
|
||||||
|
timestamp=datetime.now(UTC),
|
||||||
|
metadata=metadata or {},
|
||||||
|
)
|
||||||
|
self._records.append(rec)
|
||||||
|
return rec
|
||||||
|
|
||||||
|
def get_summary(
|
||||||
|
self,
|
||||||
|
user_id: str | None = None,
|
||||||
|
period: str = "month",
|
||||||
|
brand_id: str | None = None,
|
||||||
|
) -> UsageSummary:
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
filtered = list(self._records)
|
||||||
|
|
||||||
|
if user_id:
|
||||||
|
filtered = [r for r in filtered if r.user_id == user_id]
|
||||||
|
if brand_id:
|
||||||
|
filtered = [r for r in filtered if r.brand_id == brand_id]
|
||||||
|
|
||||||
|
cutoff = _compute_cutoff(period, now)
|
||||||
|
filtered = [r for r in filtered if r.timestamp >= cutoff]
|
||||||
|
|
||||||
|
return UsageSummary(
|
||||||
|
period=period,
|
||||||
|
start_date=cutoff.isoformat(),
|
||||||
|
end_date=now.isoformat(),
|
||||||
|
total_queries=len(filtered),
|
||||||
|
total_input_tokens=sum(r.input_tokens for r in filtered),
|
||||||
|
total_output_tokens=sum(r.output_tokens for r in filtered),
|
||||||
|
total_cost=round(sum(r.cost for r in filtered), 4),
|
||||||
|
by_engine=_aggregate_by_engine(filtered),
|
||||||
|
by_day={},
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_quota(self, user_id: str, monthly_limit: float = 100.0) -> dict:
|
||||||
|
summary = self.get_summary(user_id=user_id, period="month")
|
||||||
|
usage_pct = (summary.total_cost / monthly_limit * 100) if monthly_limit > 0 else 0
|
||||||
|
return {
|
||||||
|
"used": summary.total_cost,
|
||||||
|
"limit": monthly_limit,
|
||||||
|
"usage_percentage": round(usage_pct, 1),
|
||||||
|
"status": _quota_status(usage_pct),
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,312 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.api_key_manager import APIKeyConfig, APIKeyManager, KeySource, KeyStatus
|
||||||
|
|
||||||
|
|
||||||
|
class TestAPIKeyConfig:
|
||||||
|
def test_config_creation(self):
|
||||||
|
config = APIKeyConfig(
|
||||||
|
engine_type="chatgpt",
|
||||||
|
key_source=KeySource.SYSTEM,
|
||||||
|
encrypted_key="ZW5jcnlwdGVk",
|
||||||
|
key_hint="sk-...abc",
|
||||||
|
status=KeyStatus.UNKNOWN,
|
||||||
|
priority=0,
|
||||||
|
)
|
||||||
|
assert config.engine_type == "chatgpt"
|
||||||
|
assert config.key_source == KeySource.SYSTEM
|
||||||
|
assert config.encrypted_key == "ZW5jcnlwdGVk"
|
||||||
|
assert config.key_hint == "sk-...abc"
|
||||||
|
assert config.status == KeyStatus.UNKNOWN
|
||||||
|
assert config.priority == 0
|
||||||
|
|
||||||
|
def test_config_default_values(self):
|
||||||
|
config = APIKeyConfig(
|
||||||
|
engine_type="chatgpt",
|
||||||
|
key_source=KeySource.USER,
|
||||||
|
encrypted_key="abc",
|
||||||
|
key_hint="***",
|
||||||
|
)
|
||||||
|
assert config.status == KeyStatus.UNKNOWN
|
||||||
|
assert config.priority == 0
|
||||||
|
assert config.last_verified_at is None
|
||||||
|
assert config.created_at is None
|
||||||
|
assert config.user_id is None
|
||||||
|
|
||||||
|
def test_key_source_enum_values(self):
|
||||||
|
assert KeySource.SYSTEM == "system"
|
||||||
|
assert KeySource.USER == "user"
|
||||||
|
assert KeySource.ENV == "env"
|
||||||
|
|
||||||
|
def test_key_status_enum_values(self):
|
||||||
|
assert KeyStatus.ACTIVE == "active"
|
||||||
|
assert KeyStatus.INVALID == "invalid"
|
||||||
|
assert KeyStatus.EXPIRED == "expired"
|
||||||
|
assert KeyStatus.RATE_LIMITED == "rate_limited"
|
||||||
|
assert KeyStatus.UNKNOWN == "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAPIKeyManagerAddKey:
|
||||||
|
@pytest.fixture
|
||||||
|
def manager(self):
|
||||||
|
return APIKeyManager()
|
||||||
|
|
||||||
|
def test_add_system_key(self, manager):
|
||||||
|
config = manager.add_key("chatgpt", "sk-1234567890abcdef", source=KeySource.SYSTEM)
|
||||||
|
assert config.engine_type == "chatgpt"
|
||||||
|
assert config.key_source == KeySource.SYSTEM
|
||||||
|
assert config.encrypted_key != "sk-1234567890abcdef"
|
||||||
|
assert config.key_hint == "sk-...def"
|
||||||
|
|
||||||
|
def test_add_user_key(self, manager):
|
||||||
|
config = manager.add_key("chatgpt", "sk-user1234567890", source=KeySource.USER, user_id="user_001")
|
||||||
|
assert config.key_source == KeySource.USER
|
||||||
|
assert config.user_id == "user_001"
|
||||||
|
|
||||||
|
def test_add_multiple_keys_for_same_engine(self, manager):
|
||||||
|
manager.add_key("chatgpt", "sk-1111111111111111", source=KeySource.SYSTEM)
|
||||||
|
manager.add_key("chatgpt", "sk-2222222222222222", source=KeySource.USER, user_id="user_001")
|
||||||
|
keys = manager.list_keys("chatgpt")
|
||||||
|
assert len(keys) == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestAPIKeyManagerGetKey:
|
||||||
|
@pytest.fixture
|
||||||
|
def manager(self):
|
||||||
|
mgr = APIKeyManager()
|
||||||
|
mgr.add_key("chatgpt", "sk-system1234567890", source=KeySource.SYSTEM, priority=0)
|
||||||
|
mgr.add_key("chatgpt", "sk-user1234567890ab", source=KeySource.USER, user_id="user_001", priority=10)
|
||||||
|
return mgr
|
||||||
|
|
||||||
|
def test_get_key_decrypts_correctly(self, manager):
|
||||||
|
key = manager.get_key("chatgpt", user_id="user_001")
|
||||||
|
assert key == "sk-user1234567890ab"
|
||||||
|
|
||||||
|
def test_get_key_returns_system_key_when_no_user_id(self, manager):
|
||||||
|
key = manager.get_key("chatgpt")
|
||||||
|
assert key == "sk-system1234567890"
|
||||||
|
|
||||||
|
def test_get_key_returns_none_for_unknown_engine(self, manager):
|
||||||
|
key = manager.get_key("nonexistent")
|
||||||
|
assert key is None
|
||||||
|
|
||||||
|
def test_get_key_returns_none_for_wrong_user(self, manager):
|
||||||
|
key = manager.get_key("chatgpt", user_id="user_999")
|
||||||
|
assert key == "sk-system1234567890"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAPIKeyManagerRemoveKey:
|
||||||
|
@pytest.fixture
|
||||||
|
def manager(self):
|
||||||
|
mgr = APIKeyManager()
|
||||||
|
mgr.add_key("chatgpt", "sk-1234567890abcdef", source=KeySource.SYSTEM)
|
||||||
|
return mgr
|
||||||
|
|
||||||
|
def test_remove_existing_key(self, manager):
|
||||||
|
keys = manager.list_keys("chatgpt")
|
||||||
|
hint = keys[0].key_hint
|
||||||
|
result = manager.remove_key("chatgpt", hint)
|
||||||
|
assert result is True
|
||||||
|
assert len(manager.list_keys("chatgpt")) == 0
|
||||||
|
|
||||||
|
def test_remove_nonexistent_key(self, manager):
|
||||||
|
result = manager.remove_key("chatgpt", "sk-...xyz")
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestAPIKeyManagerListKeys:
|
||||||
|
@pytest.fixture
|
||||||
|
def manager(self):
|
||||||
|
mgr = APIKeyManager()
|
||||||
|
mgr.add_key("chatgpt", "sk-chatgpt12345678", source=KeySource.SYSTEM)
|
||||||
|
mgr.add_key("perplexity", "sk-perplexity12345", source=KeySource.SYSTEM)
|
||||||
|
return mgr
|
||||||
|
|
||||||
|
def test_list_keys_for_engine(self, manager):
|
||||||
|
keys = manager.list_keys("chatgpt")
|
||||||
|
assert len(keys) == 1
|
||||||
|
assert keys[0].engine_type == "chatgpt"
|
||||||
|
|
||||||
|
def test_list_all_keys(self, manager):
|
||||||
|
keys = manager.list_keys()
|
||||||
|
assert len(keys) == 2
|
||||||
|
|
||||||
|
def test_list_keys_returns_masked_hints_only(self, manager):
|
||||||
|
keys = manager.list_keys("chatgpt")
|
||||||
|
for k in keys:
|
||||||
|
assert k.encrypted_key != "sk-chatgpt12345678"
|
||||||
|
assert "..." in k.key_hint
|
||||||
|
|
||||||
|
|
||||||
|
class TestKeyEncryption:
|
||||||
|
@pytest.fixture
|
||||||
|
def manager(self):
|
||||||
|
return APIKeyManager()
|
||||||
|
|
||||||
|
def test_encrypted_key_is_not_plaintext(self, manager):
|
||||||
|
original = "sk-my-super-secret-key-123456"
|
||||||
|
config = manager.add_key("chatgpt", original, source=KeySource.SYSTEM)
|
||||||
|
assert config.encrypted_key != original
|
||||||
|
|
||||||
|
def test_encrypt_decrypt_roundtrip(self, manager):
|
||||||
|
original = "sk-roundtrip-test-key-12345"
|
||||||
|
manager.add_key("chatgpt", original, source=KeySource.SYSTEM)
|
||||||
|
decrypted = manager.get_key("chatgpt")
|
||||||
|
assert decrypted == original
|
||||||
|
|
||||||
|
def test_key_hint_masks_middle(self, manager):
|
||||||
|
config = manager.add_key("chatgpt", "sk-abcdefghijklmnop", source=KeySource.SYSTEM)
|
||||||
|
assert config.key_hint.startswith("sk-")
|
||||||
|
assert config.key_hint.endswith("nop")
|
||||||
|
assert "..." in config.key_hint
|
||||||
|
|
||||||
|
def test_short_key_hint(self, manager):
|
||||||
|
config = manager.add_key("chatgpt", "short", source=KeySource.SYSTEM)
|
||||||
|
assert config.key_hint == "***"
|
||||||
|
|
||||||
|
|
||||||
|
class TestKeyVerification:
|
||||||
|
@pytest.fixture
|
||||||
|
def manager(self):
|
||||||
|
return APIKeyManager()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verify_valid_key(self, manager):
|
||||||
|
status = await manager.verify_key("chatgpt", "sk-valid-key-1234567890")
|
||||||
|
assert status == KeyStatus.ACTIVE
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verify_empty_key(self, manager):
|
||||||
|
status = await manager.verify_key("chatgpt", "")
|
||||||
|
assert status == KeyStatus.INVALID
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verify_short_key(self, manager):
|
||||||
|
status = await manager.verify_key("chatgpt", "short")
|
||||||
|
assert status == KeyStatus.INVALID
|
||||||
|
|
||||||
|
|
||||||
|
class TestKeyPriority:
|
||||||
|
@pytest.fixture
|
||||||
|
def manager(self):
|
||||||
|
mgr = APIKeyManager()
|
||||||
|
mgr.add_key("chatgpt", "sk-env-key-1234567890", source=KeySource.ENV, priority=0)
|
||||||
|
mgr.add_key("chatgpt", "sk-system-key-1234567", source=KeySource.SYSTEM, priority=5)
|
||||||
|
mgr.add_key("chatgpt", "sk-user-key-123456789", source=KeySource.USER, user_id="user_001", priority=10)
|
||||||
|
return mgr
|
||||||
|
|
||||||
|
def test_keys_sorted_by_priority(self, manager):
|
||||||
|
keys = manager.list_keys("chatgpt")
|
||||||
|
assert keys[0].priority == 10
|
||||||
|
assert keys[1].priority == 5
|
||||||
|
assert keys[2].priority == 0
|
||||||
|
|
||||||
|
def test_user_key_has_higher_priority(self, manager):
|
||||||
|
key = manager.get_key("chatgpt", user_id="user_001")
|
||||||
|
assert key == "sk-user-key-123456789"
|
||||||
|
|
||||||
|
def test_system_key_used_when_no_user_key(self, manager):
|
||||||
|
key = manager.get_key("chatgpt", user_id="user_999")
|
||||||
|
assert key == "sk-system-key-1234567"
|
||||||
|
|
||||||
|
|
||||||
|
class TestBestKeySelection:
|
||||||
|
@pytest.fixture
|
||||||
|
def manager(self):
|
||||||
|
mgr = APIKeyManager()
|
||||||
|
mgr.add_key("chatgpt", "sk-user-active-key-12", source=KeySource.USER, user_id="user_001", priority=10)
|
||||||
|
mgr.add_key("chatgpt", "sk-system-active-key1", source=KeySource.SYSTEM, priority=5)
|
||||||
|
mgr.add_key("chatgpt", "sk-env-active-key-1234", source=KeySource.ENV, priority=0)
|
||||||
|
return mgr
|
||||||
|
|
||||||
|
def test_get_best_key_prefers_user_key(self, manager):
|
||||||
|
key = manager.get_key("chatgpt", user_id="user_001")
|
||||||
|
assert key == "sk-user-active-key-12"
|
||||||
|
|
||||||
|
def test_get_best_key_falls_back_to_system(self, manager):
|
||||||
|
key = manager.get_key("chatgpt")
|
||||||
|
assert key == "sk-system-active-key1"
|
||||||
|
|
||||||
|
|
||||||
|
class TestKeyDegradation:
|
||||||
|
@pytest.fixture
|
||||||
|
def manager(self):
|
||||||
|
mgr = APIKeyManager()
|
||||||
|
mgr.add_key("chatgpt", "sk-user-key-123456789", source=KeySource.USER, user_id="user_001", priority=10)
|
||||||
|
mgr.add_key("chatgpt", "sk-system-key-1234567", source=KeySource.SYSTEM, priority=5)
|
||||||
|
return mgr
|
||||||
|
|
||||||
|
def test_degradation_to_system_key_when_user_key_invalid(self, manager):
|
||||||
|
keys = manager.list_keys("chatgpt")
|
||||||
|
for k in keys:
|
||||||
|
if k.key_source == KeySource.USER and k.user_id == "user_001":
|
||||||
|
k.status = KeyStatus.INVALID
|
||||||
|
key = manager.get_key("chatgpt", user_id="user_001")
|
||||||
|
assert key == "sk-system-key-1234567"
|
||||||
|
|
||||||
|
def test_degradation_to_env_key_when_system_key_rate_limited(self, manager):
|
||||||
|
manager.add_key("chatgpt", "sk-env-key-1234567890", source=KeySource.ENV, priority=0)
|
||||||
|
keys = manager.list_keys("chatgpt")
|
||||||
|
for k in keys:
|
||||||
|
if k.key_source == KeySource.SYSTEM:
|
||||||
|
k.status = KeyStatus.RATE_LIMITED
|
||||||
|
key = manager.get_key("chatgpt")
|
||||||
|
assert key == "sk-env-key-1234567890"
|
||||||
|
|
||||||
|
def test_returns_none_when_all_keys_unavailable(self, manager):
|
||||||
|
keys = manager.list_keys("chatgpt")
|
||||||
|
for k in keys:
|
||||||
|
k.status = KeyStatus.INVALID
|
||||||
|
key = manager.get_key("chatgpt")
|
||||||
|
assert key is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestKeyExpiry:
|
||||||
|
@pytest.fixture
|
||||||
|
def manager(self):
|
||||||
|
mgr = APIKeyManager()
|
||||||
|
mgr.add_key("chatgpt", "sk-expired-key-1234567", source=KeySource.SYSTEM, priority=5)
|
||||||
|
mgr.add_key("chatgpt", "sk-active-key-12345678", source=KeySource.ENV, priority=0)
|
||||||
|
return mgr
|
||||||
|
|
||||||
|
def test_expired_key_skipped(self, manager):
|
||||||
|
keys = manager.list_keys("chatgpt")
|
||||||
|
for k in keys:
|
||||||
|
if k.key_source == KeySource.SYSTEM:
|
||||||
|
k.status = KeyStatus.EXPIRED
|
||||||
|
key = manager.get_key("chatgpt")
|
||||||
|
assert key == "sk-active-key-12345678"
|
||||||
|
|
||||||
|
def test_active_and_unknown_keys_are_usable(self, manager):
|
||||||
|
keys = manager.list_keys("chatgpt")
|
||||||
|
usable = [k for k in keys if k.status in (KeyStatus.ACTIVE, KeyStatus.UNKNOWN)]
|
||||||
|
assert len(usable) == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoadEnvKeys:
|
||||||
|
@pytest.fixture
|
||||||
|
def manager(self):
|
||||||
|
return APIKeyManager()
|
||||||
|
|
||||||
|
def test_load_env_keys_from_environment(self, manager, monkeypatch):
|
||||||
|
monkeypatch.setenv("OPENAI_API_KEY", "sk-openai-from-env-1234567890")
|
||||||
|
monkeypatch.setenv("DEEPSEEK_API_KEY", "sk-deepseek-from-env-12345678")
|
||||||
|
manager.load_env_keys()
|
||||||
|
keys = manager.list_keys()
|
||||||
|
engine_types = {k.engine_type for k in keys}
|
||||||
|
assert "chatgpt" in engine_types
|
||||||
|
assert "deepseek" in engine_types
|
||||||
|
|
||||||
|
def test_env_keys_have_env_source(self, manager, monkeypatch):
|
||||||
|
monkeypatch.setenv("OPENAI_API_KEY", "sk-openai-from-env-1234567890")
|
||||||
|
manager.load_env_keys()
|
||||||
|
keys = manager.list_keys("chatgpt")
|
||||||
|
assert len(keys) == 1
|
||||||
|
assert keys[0].key_source == KeySource.ENV
|
||||||
|
|
||||||
|
def test_env_keys_not_loaded_when_empty(self, manager, monkeypatch):
|
||||||
|
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||||
|
monkeypatch.delenv("DEEPSEEK_API_KEY", raising=False)
|
||||||
|
manager.load_env_keys()
|
||||||
|
keys = manager.list_keys()
|
||||||
|
assert len(keys) == 0
|
||||||
|
|
@ -0,0 +1,368 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.smart_router import (
|
||||||
|
ENGINE_COST_PROFILES,
|
||||||
|
CostTier,
|
||||||
|
EngineCostProfile,
|
||||||
|
SmartRouter,
|
||||||
|
)
|
||||||
|
from app.services.usage_tracker import (
|
||||||
|
UsageRecord,
|
||||||
|
UsageSummary,
|
||||||
|
UsageTracker,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestEngineCostProfile:
|
||||||
|
"""EngineCostProfile数据结构测试"""
|
||||||
|
|
||||||
|
def test_creation(self):
|
||||||
|
profile = EngineCostProfile(
|
||||||
|
engine_type="deepseek",
|
||||||
|
cost_tier=CostTier.FREE,
|
||||||
|
input_price_per_million=0.25,
|
||||||
|
output_price_per_million=6.0,
|
||||||
|
has_free_tier=True,
|
||||||
|
requires_own_key=False,
|
||||||
|
geo_relevance=4,
|
||||||
|
domestic=True,
|
||||||
|
)
|
||||||
|
assert profile.engine_type == "deepseek"
|
||||||
|
assert profile.cost_tier == CostTier.FREE
|
||||||
|
assert profile.input_price_per_million == 0.25
|
||||||
|
assert profile.output_price_per_million == 6.0
|
||||||
|
assert profile.has_free_tier is True
|
||||||
|
assert profile.requires_own_key is False
|
||||||
|
assert profile.geo_relevance == 4
|
||||||
|
assert profile.domestic is True
|
||||||
|
|
||||||
|
def test_cost_tier_values(self):
|
||||||
|
assert CostTier.FREE == "free"
|
||||||
|
assert CostTier.LOW_COST == "low_cost"
|
||||||
|
assert CostTier.MID_COST == "mid_cost"
|
||||||
|
assert CostTier.HIGH_COST == "high_cost"
|
||||||
|
|
||||||
|
def test_engine_cost_profiles_completeness(self):
|
||||||
|
expected_engines = {"deepseek", "qwen", "wenxin", "kimi", "doubao", "gemini", "yuanbao", "chatgpt", "perplexity"}
|
||||||
|
assert set(ENGINE_COST_PROFILES.keys()) == expected_engines
|
||||||
|
|
||||||
|
def test_free_tier_engines_have_free_flag(self):
|
||||||
|
for name, profile in ENGINE_COST_PROFILES.items():
|
||||||
|
if profile.cost_tier == CostTier.FREE:
|
||||||
|
assert profile.has_free_tier is True, f"{name} is FREE tier but has_free_tier=False"
|
||||||
|
|
||||||
|
def test_high_cost_engines_require_own_key(self):
|
||||||
|
for name, profile in ENGINE_COST_PROFILES.items():
|
||||||
|
if profile.cost_tier == CostTier.HIGH_COST:
|
||||||
|
assert profile.requires_own_key is True, f"{name} is HIGH_COST but requires_own_key=False"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSmartRouter:
|
||||||
|
"""SmartRouter智能路由服务测试"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def router(self):
|
||||||
|
return SmartRouter()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def router_with_user_engines(self):
|
||||||
|
return SmartRouter(user_engines=["chatgpt", "perplexity"])
|
||||||
|
|
||||||
|
def test_initialization_default(self, router):
|
||||||
|
assert len(router.available_engines) == len(ENGINE_COST_PROFILES)
|
||||||
|
assert router.user_engines == []
|
||||||
|
|
||||||
|
def test_initialization_with_available_engines(self):
|
||||||
|
r = SmartRouter(available_engines=["deepseek", "qwen"])
|
||||||
|
assert r.available_engines == ["deepseek", "qwen"]
|
||||||
|
|
||||||
|
def test_initialization_with_user_engines(self):
|
||||||
|
r = SmartRouter(user_engines=["chatgpt"])
|
||||||
|
assert r.user_engines == ["chatgpt"]
|
||||||
|
|
||||||
|
def test_select_engines_returns_list(self, router):
|
||||||
|
result = router.select_engines()
|
||||||
|
assert isinstance(result, list)
|
||||||
|
assert len(result) > 0
|
||||||
|
|
||||||
|
def test_select_engines_respects_max_engines(self, router):
|
||||||
|
result = router.select_engines(max_engines=3)
|
||||||
|
assert len(result) <= 3
|
||||||
|
|
||||||
|
def test_select_engines_priority_free_first(self, router):
|
||||||
|
result = router.select_engines(max_engines=5)
|
||||||
|
free_engines = [e for e in result if ENGINE_COST_PROFILES[e].cost_tier == CostTier.FREE]
|
||||||
|
non_free = [e for e in result if ENGINE_COST_PROFILES[e].cost_tier != CostTier.FREE]
|
||||||
|
free_count = len(free_engines)
|
||||||
|
if free_count > 0 and len(non_free) > 0:
|
||||||
|
first_non_free_idx = result.index(non_free[0])
|
||||||
|
last_free_idx = max(result.index(e) for e in free_engines)
|
||||||
|
assert last_free_idx < first_non_free_idx
|
||||||
|
|
||||||
|
def test_tiered_routing_order(self, router):
|
||||||
|
result = router.select_engines(max_engines=10)
|
||||||
|
tiers = [ENGINE_COST_PROFILES[e].cost_tier for e in result]
|
||||||
|
tier_order = {CostTier.FREE: 0, CostTier.LOW_COST: 1, CostTier.MID_COST: 2, CostTier.HIGH_COST: 3}
|
||||||
|
tier_values = [tier_order[t] for t in tiers]
|
||||||
|
for i in range(len(tier_values) - 1):
|
||||||
|
assert tier_values[i] <= tier_values[i + 1], f"分层路由顺序错误: {result[i]}({tiers[i]}) 排在 {result[i+1]}({tiers[i+1]}) 前面"
|
||||||
|
|
||||||
|
def test_user_engines_prioritized(self, router_with_user_engines):
|
||||||
|
result = router_with_user_engines.select_engines(max_engines=9)
|
||||||
|
chatgpt_idx = result.index("chatgpt") if "chatgpt" in result else 999
|
||||||
|
other_high_idx = None
|
||||||
|
for e in result:
|
||||||
|
if ENGINE_COST_PROFILES[e].cost_tier == CostTier.HIGH_COST and e != "chatgpt":
|
||||||
|
other_high_idx = result.index(e)
|
||||||
|
break
|
||||||
|
if other_high_idx is not None:
|
||||||
|
assert chatgpt_idx < other_high_idx
|
||||||
|
|
||||||
|
def test_engine_fallback_when_unavailable(self):
|
||||||
|
r = SmartRouter(available_engines=["chatgpt", "perplexity"])
|
||||||
|
result = r.select_engines(max_engines=5)
|
||||||
|
assert len(result) == 2
|
||||||
|
assert "chatgpt" in result
|
||||||
|
assert "perplexity" in result
|
||||||
|
|
||||||
|
def test_prefer_domestic(self, router):
|
||||||
|
result = router.select_engines(max_engines=5, prefer_domestic=True)
|
||||||
|
domestic = [e for e in result if ENGINE_COST_PROFILES[e].domestic]
|
||||||
|
international = [e for e in result if not ENGINE_COST_PROFILES[e].domestic]
|
||||||
|
if domestic and international:
|
||||||
|
last_domestic_idx = max(result.index(e) for e in domestic)
|
||||||
|
first_intl_idx = min(result.index(e) for e in international)
|
||||||
|
assert last_domestic_idx < first_intl_idx
|
||||||
|
|
||||||
|
def test_get_cost_estimate(self, router):
|
||||||
|
engines = ["deepseek", "qwen"]
|
||||||
|
estimate = router.get_cost_estimate(engines)
|
||||||
|
assert "total_cost" in estimate
|
||||||
|
assert "per_engine" in estimate
|
||||||
|
assert "deepseek" in estimate["per_engine"]
|
||||||
|
assert "qwen" in estimate["per_engine"]
|
||||||
|
assert estimate["total_cost"] > 0
|
||||||
|
|
||||||
|
def test_get_cost_estimate_values(self, router):
|
||||||
|
engines = ["deepseek"]
|
||||||
|
estimate = router.get_cost_estimate(engines, estimated_input_tokens=1_000_000, estimated_output_tokens=1_000_000)
|
||||||
|
expected_cost = 0.25 + 6.0
|
||||||
|
assert abs(estimate["total_cost"] - expected_cost) < 0.001
|
||||||
|
|
||||||
|
def test_get_recommended_combination(self, router):
|
||||||
|
combo = router.get_recommended_combination()
|
||||||
|
assert "basic" in combo
|
||||||
|
assert "standard" in combo
|
||||||
|
assert "premium" in combo
|
||||||
|
assert "user_premium" in combo
|
||||||
|
assert len(combo["basic"]) <= 3
|
||||||
|
for e in combo["basic"]:
|
||||||
|
assert ENGINE_COST_PROFILES[e].cost_tier == CostTier.FREE
|
||||||
|
|
||||||
|
def test_select_engines_no_duplicates(self, router):
|
||||||
|
result = router.select_engines(max_engines=10)
|
||||||
|
assert len(result) == len(set(result))
|
||||||
|
|
||||||
|
|
||||||
|
class TestUsageRecord:
|
||||||
|
"""UsageRecord数据结构测试"""
|
||||||
|
|
||||||
|
def test_creation(self):
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
record = UsageRecord(
|
||||||
|
id="usage_1",
|
||||||
|
user_id="user_001",
|
||||||
|
brand_id="brand_001",
|
||||||
|
engine_type="deepseek",
|
||||||
|
query="测试查询",
|
||||||
|
input_tokens=500,
|
||||||
|
output_tokens=1000,
|
||||||
|
cost=0.00625,
|
||||||
|
timestamp=now,
|
||||||
|
)
|
||||||
|
assert record.id == "usage_1"
|
||||||
|
assert record.user_id == "user_001"
|
||||||
|
assert record.brand_id == "brand_001"
|
||||||
|
assert record.engine_type == "deepseek"
|
||||||
|
assert record.query == "测试查询"
|
||||||
|
assert record.input_tokens == 500
|
||||||
|
assert record.output_tokens == 1000
|
||||||
|
assert record.cost == 0.00625
|
||||||
|
assert record.timestamp == now
|
||||||
|
|
||||||
|
def test_default_metadata(self):
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
record = UsageRecord(
|
||||||
|
id="usage_2",
|
||||||
|
user_id="user_001",
|
||||||
|
brand_id="brand_001",
|
||||||
|
engine_type="qwen",
|
||||||
|
query="测试",
|
||||||
|
input_tokens=100,
|
||||||
|
output_tokens=200,
|
||||||
|
cost=0.001,
|
||||||
|
timestamp=datetime.now(UTC),
|
||||||
|
)
|
||||||
|
assert record.metadata == {}
|
||||||
|
|
||||||
|
def test_custom_metadata(self):
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
record = UsageRecord(
|
||||||
|
id="usage_3",
|
||||||
|
user_id="user_001",
|
||||||
|
brand_id="brand_001",
|
||||||
|
engine_type="qwen",
|
||||||
|
query="测试",
|
||||||
|
input_tokens=100,
|
||||||
|
output_tokens=200,
|
||||||
|
cost=0.001,
|
||||||
|
timestamp=datetime.now(UTC),
|
||||||
|
metadata={"source": "api"},
|
||||||
|
)
|
||||||
|
assert record.metadata == {"source": "api"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestUsageTracker:
|
||||||
|
"""UsageTracker用量追踪服务测试"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tracker(self):
|
||||||
|
return UsageTracker()
|
||||||
|
|
||||||
|
def test_record_returns_usage_record(self, tracker):
|
||||||
|
record = tracker.record(
|
||||||
|
user_id="user_001",
|
||||||
|
brand_id="brand_001",
|
||||||
|
engine_type="deepseek",
|
||||||
|
query="测试查询",
|
||||||
|
input_tokens=500,
|
||||||
|
output_tokens=1000,
|
||||||
|
cost=0.00625,
|
||||||
|
)
|
||||||
|
assert isinstance(record, UsageRecord)
|
||||||
|
assert record.user_id == "user_001"
|
||||||
|
assert record.engine_type == "deepseek"
|
||||||
|
assert record.input_tokens == 500
|
||||||
|
assert record.output_tokens == 1000
|
||||||
|
assert record.cost == 0.00625
|
||||||
|
|
||||||
|
def test_record_auto_generates_id(self, tracker):
|
||||||
|
record = tracker.record(
|
||||||
|
user_id="user_001",
|
||||||
|
brand_id="brand_001",
|
||||||
|
engine_type="deepseek",
|
||||||
|
query="测试",
|
||||||
|
input_tokens=100,
|
||||||
|
output_tokens=200,
|
||||||
|
cost=0.001,
|
||||||
|
)
|
||||||
|
assert record.id.startswith("usage_")
|
||||||
|
|
||||||
|
def test_record_auto_generates_timestamp(self, tracker):
|
||||||
|
record = tracker.record(
|
||||||
|
user_id="user_001",
|
||||||
|
brand_id="brand_001",
|
||||||
|
engine_type="deepseek",
|
||||||
|
query="测试",
|
||||||
|
input_tokens=100,
|
||||||
|
output_tokens=200,
|
||||||
|
cost=0.001,
|
||||||
|
)
|
||||||
|
assert record.timestamp is not None
|
||||||
|
|
||||||
|
def test_record_with_metadata(self, tracker):
|
||||||
|
record = tracker.record(
|
||||||
|
user_id="user_001",
|
||||||
|
brand_id="brand_001",
|
||||||
|
engine_type="deepseek",
|
||||||
|
query="测试",
|
||||||
|
input_tokens=100,
|
||||||
|
output_tokens=200,
|
||||||
|
cost=0.001,
|
||||||
|
metadata={"source": "api"},
|
||||||
|
)
|
||||||
|
assert record.metadata == {"source": "api"}
|
||||||
|
|
||||||
|
def test_get_summary_daily(self, tracker):
|
||||||
|
tracker.record("user_001", "brand_001", "deepseek", "q1", 500, 1000, 0.006)
|
||||||
|
tracker.record("user_001", "brand_001", "qwen", "q2", 300, 600, 0.003)
|
||||||
|
summary = tracker.get_summary(user_id="user_001", period="day")
|
||||||
|
assert isinstance(summary, UsageSummary)
|
||||||
|
assert summary.period == "day"
|
||||||
|
assert summary.total_queries == 2
|
||||||
|
assert summary.total_input_tokens == 800
|
||||||
|
assert summary.total_output_tokens == 1600
|
||||||
|
assert summary.total_cost > 0
|
||||||
|
|
||||||
|
def test_get_summary_weekly(self, tracker):
|
||||||
|
tracker.record("user_001", "brand_001", "deepseek", "q1", 500, 1000, 0.006)
|
||||||
|
summary = tracker.get_summary(user_id="user_001", period="week")
|
||||||
|
assert summary.period == "week"
|
||||||
|
assert summary.total_queries == 1
|
||||||
|
|
||||||
|
def test_get_summary_monthly(self, tracker):
|
||||||
|
tracker.record("user_001", "brand_001", "deepseek", "q1", 500, 1000, 0.006)
|
||||||
|
summary = tracker.get_summary(user_id="user_001", period="month")
|
||||||
|
assert summary.period == "month"
|
||||||
|
assert summary.total_queries == 1
|
||||||
|
|
||||||
|
def test_get_summary_by_engine(self, tracker):
|
||||||
|
tracker.record("user_001", "brand_001", "deepseek", "q1", 500, 1000, 0.006)
|
||||||
|
tracker.record("user_001", "brand_001", "qwen", "q2", 300, 600, 0.003)
|
||||||
|
summary = tracker.get_summary(user_id="user_001", period="month")
|
||||||
|
assert "deepseek" in summary.by_engine
|
||||||
|
assert "qwen" in summary.by_engine
|
||||||
|
assert summary.by_engine["deepseek"]["queries"] == 1
|
||||||
|
assert summary.by_engine["qwen"]["queries"] == 1
|
||||||
|
|
||||||
|
def test_get_summary_filter_by_brand(self, tracker):
|
||||||
|
tracker.record("user_001", "brand_A", "deepseek", "q1", 500, 1000, 0.006)
|
||||||
|
tracker.record("user_001", "brand_B", "qwen", "q2", 300, 600, 0.003)
|
||||||
|
summary = tracker.get_summary(user_id="user_001", period="month", brand_id="brand_A")
|
||||||
|
assert summary.total_queries == 1
|
||||||
|
assert "deepseek" in summary.by_engine
|
||||||
|
|
||||||
|
def test_calculate_cost(self, tracker):
|
||||||
|
tracker.record("user_001", "brand_001", "deepseek", "q1", 500, 1000, 0.006)
|
||||||
|
tracker.record("user_001", "brand_001", "qwen", "q2", 300, 600, 0.003)
|
||||||
|
summary = tracker.get_summary(user_id="user_001", period="month")
|
||||||
|
assert abs(summary.total_cost - 0.009) < 0.001
|
||||||
|
|
||||||
|
def test_check_quota_ok(self, tracker):
|
||||||
|
tracker.record("user_001", "brand_001", "deepseek", "q1", 500, 1000, 5.0)
|
||||||
|
result = tracker.check_quota(user_id="user_001", monthly_limit=100.0)
|
||||||
|
assert result["status"] == "ok"
|
||||||
|
assert result["used"] == 5.0
|
||||||
|
assert result["limit"] == 100.0
|
||||||
|
assert result["usage_percentage"] == 5.0
|
||||||
|
|
||||||
|
def test_check_quota_warning(self, tracker):
|
||||||
|
tracker.record("user_001", "brand_001", "deepseek", "q1", 500, 1000, 85.0)
|
||||||
|
result = tracker.check_quota(user_id="user_001", monthly_limit=100.0)
|
||||||
|
assert result["status"] == "warning"
|
||||||
|
assert result["usage_percentage"] == 85.0
|
||||||
|
|
||||||
|
def test_check_quota_exceeded(self, tracker):
|
||||||
|
tracker.record("user_001", "brand_001", "deepseek", "q1", 500, 1000, 120.0)
|
||||||
|
result = tracker.check_quota(user_id="user_001", monthly_limit=100.0)
|
||||||
|
assert result["status"] == "exceeded"
|
||||||
|
assert result["usage_percentage"] == 120.0
|
||||||
|
|
||||||
|
def test_check_quota_zero_usage(self, tracker):
|
||||||
|
result = tracker.check_quota(user_id="user_001", monthly_limit=100.0)
|
||||||
|
assert result["status"] == "ok"
|
||||||
|
assert result["used"] == 0.0
|
||||||
|
assert result["usage_percentage"] == 0.0
|
||||||
|
|
||||||
|
def test_multiple_records_accumulate(self, tracker):
|
||||||
|
for i in range(5):
|
||||||
|
tracker.record("user_001", "brand_001", "deepseek", f"q{i}", 100, 200, 0.5)
|
||||||
|
summary = tracker.get_summary(user_id="user_001", period="month")
|
||||||
|
assert summary.total_queries == 5
|
||||||
|
assert summary.total_input_tokens == 500
|
||||||
|
assert summary.total_output_tokens == 1000
|
||||||
|
assert abs(summary.total_cost - 2.5) < 0.001
|
||||||
Loading…
Reference in New Issue