369 lines
14 KiB
Python
369 lines
14 KiB
Python
import pytest
|
|
|
|
from app.services.llm.smart_router import (
|
|
ENGINE_COST_PROFILES,
|
|
CostTier,
|
|
EngineCostProfile,
|
|
SmartRouter,
|
|
)
|
|
from app.services.usage_tracker import (
|
|
UsageRecord,
|
|
UsageSummary,
|
|
UsageTracker,
|
|
)
|
|
|
|
|
|
class TestEngineCostProfile:
    """Tests for the EngineCostProfile data structure and ENGINE_COST_PROFILES."""

    def test_creation(self):
        """Constructor keyword arguments are exposed unchanged as attributes."""
        built = EngineCostProfile(
            engine_type="deepseek",
            cost_tier=CostTier.FREE,
            input_price_per_million=0.25,
            output_price_per_million=6.0,
            has_free_tier=True,
            requires_own_key=False,
            geo_relevance=4,
            domestic=True,
        )

        # Identity and pricing fields.
        assert built.engine_type == "deepseek"
        assert built.cost_tier == CostTier.FREE
        assert built.input_price_per_million == 0.25
        assert built.output_price_per_million == 6.0

        # Boolean flags are checked by identity, not mere truthiness.
        assert built.has_free_tier is True
        assert built.requires_own_key is False

        # Remaining descriptive fields.
        assert built.geo_relevance == 4
        assert built.domestic is True

    def test_cost_tier_values(self):
        """Each CostTier member compares equal to its lowercase string value."""
        member_to_raw = (
            (CostTier.FREE, "free"),
            (CostTier.LOW_COST, "low_cost"),
            (CostTier.MID_COST, "mid_cost"),
            (CostTier.HIGH_COST, "high_cost"),
        )
        for member, raw in member_to_raw:
            assert member == raw

    def test_engine_cost_profiles_completeness(self):
        """The registry is keyed by exactly the nine expected engine names."""
        expected_engines = {
            "deepseek",
            "qwen",
            "wenxin",
            "kimi",
            "doubao",
            "gemini",
            "yuanbao",
            "chatgpt",
            "perplexity",
        }
        registered = set(ENGINE_COST_PROFILES.keys())
        assert registered == expected_engines

    def test_free_tier_engines_have_free_flag(self):
        """Every registry entry in the FREE tier has has_free_tier set to True."""
        free_tier_entries = [
            (name, profile)
            for name, profile in ENGINE_COST_PROFILES.items()
            if profile.cost_tier == CostTier.FREE
        ]
        for name, profile in free_tier_entries:
            assert profile.has_free_tier is True, f"{name} is FREE tier but has_free_tier=False"

    def test_high_cost_engines_require_own_key(self):
        """Every registry entry in the HIGH_COST tier has requires_own_key set to True."""
        high_cost_entries = [
            (name, profile)
            for name, profile in ENGINE_COST_PROFILES.items()
            if profile.cost_tier == CostTier.HIGH_COST
        ]
        for name, profile in high_cost_entries:
            assert profile.requires_own_key is True, f"{name} is HIGH_COST but requires_own_key=False"
|
|
|
|
|
|
class TestSmartRouter:
    """Tests for the SmartRouter engine-selection service."""

    @pytest.fixture
    def router(self):
        """Provide a router constructed with no arguments."""
        return SmartRouter()

    @pytest.fixture
    def router_with_user_engines(self):
        """Provide a router constructed with chatgpt and perplexity as user engines."""
        return SmartRouter(user_engines=["chatgpt", "perplexity"])

    def test_initialization_default(self, router):
        """A default router lists one available engine per cost profile and no user engines."""
        profile_count = len(ENGINE_COST_PROFILES)
        assert len(router.available_engines) == profile_count
        assert router.user_engines == []

    def test_initialization_with_available_engines(self):
        """An explicit available_engines list is stored as given."""
        restricted = SmartRouter(available_engines=["deepseek", "qwen"])
        assert restricted.available_engines == ["deepseek", "qwen"]

    def test_initialization_with_user_engines(self):
        """An explicit user_engines list is stored as given."""
        personalised = SmartRouter(user_engines=["chatgpt"])
        assert personalised.user_engines == ["chatgpt"]

    def test_select_engines_returns_list(self, router):
        """select_engines() with defaults returns a non-empty list."""
        picked = router.select_engines()
        assert isinstance(picked, list)
        assert len(picked) > 0

    def test_select_engines_respects_max_engines(self, router):
        """The selection never contains more than max_engines entries."""
        limit = 3
        picked = router.select_engines(max_engines=limit)
        assert len(picked) <= limit

    def test_select_engines_priority_free_first(self, router):
        """When free and non-free engines are both selected, all free ones come first."""
        picked = router.select_engines(max_engines=5)

        def tier_of(engine):
            return ENGINE_COST_PROFILES[engine].cost_tier

        free = [name for name in picked if tier_of(name) == CostTier.FREE]
        paid = [name for name in picked if tier_of(name) != CostTier.FREE]

        # Ordering can only be checked when both groups are present.
        if not free or not paid:
            return

        first_paid_pos = picked.index(paid[0])
        last_free_pos = max(map(picked.index, free))
        assert last_free_pos < first_paid_pos

    def test_tiered_routing_order(self, router):
        """Selected engines are ordered FREE, LOW_COST, MID_COST, then HIGH_COST."""
        picked = router.select_engines(max_engines=10)
        tiers = [ENGINE_COST_PROFILES[name].cost_tier for name in picked]

        precedence = {
            CostTier.FREE: 0,
            CostTier.LOW_COST: 1,
            CostTier.MID_COST: 2,
            CostTier.HIGH_COST: 3,
        }
        ranks = [precedence[tier] for tier in tiers]

        # Compare every adjacent pair; pos indexes the left-hand element.
        for pos, (current, following) in enumerate(zip(ranks, ranks[1:])):
            assert current <= following, f"分层路由顺序错误: {picked[pos]}({tiers[pos]}) 排在 {picked[pos + 1]}({tiers[pos + 1]}) 前面"

    def test_user_engines_prioritized(self, router_with_user_engines):
        """chatgpt is placed before the first other HIGH_COST engine in the selection."""
        picked = router_with_user_engines.select_engines(max_engines=9)

        # A missing chatgpt gets a sentinel position that fails the comparison below.
        try:
            chatgpt_pos = picked.index("chatgpt")
        except ValueError:
            chatgpt_pos = 999

        rival = next(
            (
                name
                for name in picked
                if ENGINE_COST_PROFILES[name].cost_tier == CostTier.HIGH_COST and name != "chatgpt"
            ),
            None,
        )
        if rival is not None:
            assert chatgpt_pos < picked.index(rival)

    def test_engine_fallback_when_unavailable(self):
        """With only two engines available, exactly those two are selected."""
        limited = SmartRouter(available_engines=["chatgpt", "perplexity"])
        picked = limited.select_engines(max_engines=5)

        assert len(picked) == 2
        for name in ("chatgpt", "perplexity"):
            assert name in picked

    def test_prefer_domestic(self, router):
        """With prefer_domestic=True, every domestic engine precedes every international one."""
        picked = router.select_engines(max_engines=5, prefer_domestic=True)

        home, abroad = [], []
        for name in picked:
            if ENGINE_COST_PROFILES[name].domestic:
                home.append(name)
            else:
                abroad.append(name)

        # Ordering can only be checked when both groups are present.
        if not (home and abroad):
            return

        last_home_pos = max(picked.index(name) for name in home)
        first_abroad_pos = min(picked.index(name) for name in abroad)
        assert last_home_pos < first_abroad_pos

    def test_get_cost_estimate(self, router):
        """The estimate has a positive total_cost and a per_engine entry for each engine."""
        requested = ["deepseek", "qwen"]
        estimate = router.get_cost_estimate(requested)

        for key in ("total_cost", "per_engine"):
            assert key in estimate
        for name in requested:
            assert name in estimate["per_engine"]
        assert estimate["total_cost"] > 0

    def test_get_cost_estimate_values(self, router):
        """One million input and output tokens on deepseek total 0.25 + 6.0."""
        estimate = router.get_cost_estimate(
            ["deepseek"],
            estimated_input_tokens=1_000_000,
            estimated_output_tokens=1_000_000,
        )
        expected_cost = 0.25 + 6.0
        deviation = abs(estimate["total_cost"] - expected_cost)
        assert deviation < 0.001

    def test_get_recommended_combination(self, router):
        """All four plan keys exist and the basic plan holds at most three FREE engines."""
        combo = router.get_recommended_combination()

        for plan in ("basic", "standard", "premium", "user_premium"):
            assert plan in combo

        basic_plan = combo["basic"]
        assert len(basic_plan) <= 3
        for name in basic_plan:
            assert ENGINE_COST_PROFILES[name].cost_tier == CostTier.FREE

    def test_select_engines_no_duplicates(self, router):
        """No engine appears twice in the selection."""
        picked = router.select_engines(max_engines=10)
        unique_count = len(set(picked))
        assert len(picked) == unique_count
|
|
|
|
|
|
class TestUsageRecord:
    """Tests for the UsageRecord data structure."""

    @staticmethod
    def _qwen_record(record_id, **extra):
        """Build the small qwen record shared by the metadata tests.

        ``extra`` is forwarded to the UsageRecord constructor so a test can
        supply optional fields such as ``metadata``.
        """
        from datetime import UTC, datetime

        return UsageRecord(
            id=record_id,
            user_id="user_001",
            brand_id="brand_001",
            engine_type="qwen",
            query="测试",
            input_tokens=100,
            output_tokens=200,
            cost=0.001,
            timestamp=datetime.now(UTC),
            **extra,
        )

    def test_creation(self):
        """Constructor keyword arguments are exposed unchanged as attributes."""
        from datetime import UTC, datetime

        created_at = datetime.now(UTC)
        entry = UsageRecord(
            id="usage_1",
            user_id="user_001",
            brand_id="brand_001",
            engine_type="deepseek",
            query="测试查询",
            input_tokens=500,
            output_tokens=1000,
            cost=0.00625,
            timestamp=created_at,
        )

        # Identifiers.
        assert entry.id == "usage_1"
        assert entry.user_id == "user_001"
        assert entry.brand_id == "brand_001"

        # Request description.
        assert entry.engine_type == "deepseek"
        assert entry.query == "测试查询"

        # Token counts, cost and timestamp.
        assert entry.input_tokens == 500
        assert entry.output_tokens == 1000
        assert entry.cost == 0.00625
        assert entry.timestamp == created_at

    def test_default_metadata(self):
        """A record built without metadata reports an empty dict."""
        entry = self._qwen_record("usage_2")
        assert entry.metadata == {}

    def test_custom_metadata(self):
        """A metadata dict passed to the constructor is returned as given."""
        entry = self._qwen_record("usage_3", metadata={"source": "api"})
        assert entry.metadata == {"source": "api"}
|
|
|
|
|
|
class TestUsageTracker:
    """Tests for the UsageTracker usage-tracking service.

    Values the tracker derives by arithmetic (summed costs, usage
    percentages) are compared with ``pytest.approx`` so that harmless
    floating-point rounding does not fail a test. Values that are passed in
    and read straight back are still compared exactly.
    """

    @pytest.fixture
    def tracker(self):
        """Provide a fresh tracker so records never leak between tests."""
        return UsageTracker()

    def test_record_returns_usage_record(self, tracker):
        """record() returns a UsageRecord carrying the supplied values."""
        record = tracker.record(
            user_id="user_001",
            brand_id="brand_001",
            engine_type="deepseek",
            query="测试查询",
            input_tokens=500,
            output_tokens=1000,
            cost=0.00625,
        )
        assert isinstance(record, UsageRecord)
        assert record.user_id == "user_001"
        assert record.engine_type == "deepseek"
        assert record.input_tokens == 500
        assert record.output_tokens == 1000
        assert record.cost == 0.00625

    def test_record_auto_generates_id(self, tracker):
        """record() assigns an id that starts with the "usage_" prefix."""
        record = tracker.record(
            user_id="user_001",
            brand_id="brand_001",
            engine_type="deepseek",
            query="测试",
            input_tokens=100,
            output_tokens=200,
            cost=0.001,
        )
        assert record.id.startswith("usage_")

    def test_record_auto_generates_timestamp(self, tracker):
        """record() assigns a timestamp even though none is supplied."""
        record = tracker.record(
            user_id="user_001",
            brand_id="brand_001",
            engine_type="deepseek",
            query="测试",
            input_tokens=100,
            output_tokens=200,
            cost=0.001,
        )
        assert record.timestamp is not None

    def test_record_with_metadata(self, tracker):
        """A metadata dict passed to record() is kept on the returned record."""
        record = tracker.record(
            user_id="user_001",
            brand_id="brand_001",
            engine_type="deepseek",
            query="测试",
            input_tokens=100,
            output_tokens=200,
            cost=0.001,
            metadata={"source": "api"},
        )
        assert record.metadata == {"source": "api"}

    def test_get_summary_daily(self, tracker):
        """The daily summary aggregates query count, tokens and cost of both records."""
        tracker.record("user_001", "brand_001", "deepseek", "q1", 500, 1000, 0.006)
        tracker.record("user_001", "brand_001", "qwen", "q2", 300, 600, 0.003)
        summary = tracker.get_summary(user_id="user_001", period="day")
        assert isinstance(summary, UsageSummary)
        assert summary.period == "day"
        assert summary.total_queries == 2
        assert summary.total_input_tokens == 800
        assert summary.total_output_tokens == 1600
        assert summary.total_cost > 0

    def test_get_summary_weekly(self, tracker):
        """The weekly summary reports its period and the single recorded query."""
        tracker.record("user_001", "brand_001", "deepseek", "q1", 500, 1000, 0.006)
        summary = tracker.get_summary(user_id="user_001", period="week")
        assert summary.period == "week"
        assert summary.total_queries == 1

    def test_get_summary_monthly(self, tracker):
        """The monthly summary reports its period and the single recorded query."""
        tracker.record("user_001", "brand_001", "deepseek", "q1", 500, 1000, 0.006)
        summary = tracker.get_summary(user_id="user_001", period="month")
        assert summary.period == "month"
        assert summary.total_queries == 1

    def test_get_summary_by_engine(self, tracker):
        """by_engine holds one entry per engine, each with its own query count."""
        tracker.record("user_001", "brand_001", "deepseek", "q1", 500, 1000, 0.006)
        tracker.record("user_001", "brand_001", "qwen", "q2", 300, 600, 0.003)
        summary = tracker.get_summary(user_id="user_001", period="month")
        assert "deepseek" in summary.by_engine
        assert "qwen" in summary.by_engine
        assert summary.by_engine["deepseek"]["queries"] == 1
        assert summary.by_engine["qwen"]["queries"] == 1

    def test_get_summary_filter_by_brand(self, tracker):
        """Passing brand_id restricts the summary to that brand's records."""
        tracker.record("user_001", "brand_A", "deepseek", "q1", 500, 1000, 0.006)
        tracker.record("user_001", "brand_B", "qwen", "q2", 300, 600, 0.003)
        summary = tracker.get_summary(user_id="user_001", period="month", brand_id="brand_A")
        assert summary.total_queries == 1
        assert "deepseek" in summary.by_engine

    def test_calculate_cost(self, tracker):
        """total_cost is the sum of the recorded costs (0.006 + 0.003)."""
        tracker.record("user_001", "brand_001", "deepseek", "q1", 500, 1000, 0.006)
        tracker.record("user_001", "brand_001", "qwen", "q2", 300, 600, 0.003)
        summary = tracker.get_summary(user_id="user_001", period="month")
        assert summary.total_cost == pytest.approx(0.009, abs=0.001)

    def test_check_quota_ok(self, tracker):
        """Spending 5 of a 100 limit yields status "ok" at 5 percent usage."""
        tracker.record("user_001", "brand_001", "deepseek", "q1", 500, 1000, 5.0)
        result = tracker.check_quota(user_id="user_001", monthly_limit=100.0)
        assert result["status"] == "ok"
        assert result["used"] == pytest.approx(5.0)
        # The limit is passed in and echoed back, so it is compared exactly.
        assert result["limit"] == 100.0
        assert result["usage_percentage"] == pytest.approx(5.0)

    def test_check_quota_warning(self, tracker):
        """Spending 85 of a 100 limit yields status "warning" at 85 percent usage."""
        tracker.record("user_001", "brand_001", "deepseek", "q1", 500, 1000, 85.0)
        result = tracker.check_quota(user_id="user_001", monthly_limit=100.0)
        assert result["status"] == "warning"
        assert result["usage_percentage"] == pytest.approx(85.0)

    def test_check_quota_exceeded(self, tracker):
        """Spending 120 of a 100 limit yields status "exceeded" at 120 percent usage."""
        tracker.record("user_001", "brand_001", "deepseek", "q1", 500, 1000, 120.0)
        result = tracker.check_quota(user_id="user_001", monthly_limit=100.0)
        assert result["status"] == "exceeded"
        assert result["usage_percentage"] == pytest.approx(120.0)

    def test_check_quota_zero_usage(self, tracker):
        """With nothing recorded, the quota is "ok" with zero used and zero percent."""
        result = tracker.check_quota(user_id="user_001", monthly_limit=100.0)
        assert result["status"] == "ok"
        assert result["used"] == pytest.approx(0.0)
        assert result["usage_percentage"] == pytest.approx(0.0)

    def test_multiple_records_accumulate(self, tracker):
        """Five identical records add up in query count, token totals and cost."""
        for i in range(5):
            tracker.record("user_001", "brand_001", "deepseek", f"q{i}", 100, 200, 0.5)
        summary = tracker.get_summary(user_id="user_001", period="month")
        assert summary.total_queries == 5
        assert summary.total_input_tokens == 500
        assert summary.total_output_tokens == 1000
        assert summary.total_cost == pytest.approx(2.5, abs=0.001)
|