"""Unit tests for SmartRouter cost-tiered engine routing and UsageTracker usage accounting."""

import itertools
from datetime import UTC, datetime

import pytest

from app.services.llm.smart_router import (
    ENGINE_COST_PROFILES,
    CostTier,
    EngineCostProfile,
    SmartRouter,
)
from app.services.usage_tracker import (
    UsageRecord,
    UsageSummary,
    UsageTracker,
)


class TestEngineCostProfile:
    """Tests for the EngineCostProfile data structure and the ENGINE_COST_PROFILES table."""

    def test_creation(self):
        profile = EngineCostProfile(
            engine_type="deepseek",
            cost_tier=CostTier.FREE,
            input_price_per_million=0.25,
            output_price_per_million=6.0,
            has_free_tier=True,
            requires_own_key=False,
            geo_relevance=4,
            domestic=True,
        )
        assert profile.engine_type == "deepseek"
        assert profile.cost_tier == CostTier.FREE
        assert profile.input_price_per_million == 0.25
        assert profile.output_price_per_million == 6.0
        assert profile.has_free_tier is True
        assert profile.requires_own_key is False
        assert profile.geo_relevance == 4
        assert profile.domestic is True

    def test_cost_tier_values(self):
        # CostTier members are expected to compare equal to their plain string values.
        assert CostTier.FREE == "free"
        assert CostTier.LOW_COST == "low_cost"
        assert CostTier.MID_COST == "mid_cost"
        assert CostTier.HIGH_COST == "high_cost"

    def test_engine_cost_profiles_completeness(self):
        expected_engines = {
            "deepseek",
            "qwen",
            "wenxin",
            "kimi",
            "doubao",
            "gemini",
            "yuanbao",
            "chatgpt",
            "perplexity",
        }
        assert set(ENGINE_COST_PROFILES.keys()) == expected_engines

    def test_free_tier_engines_have_free_flag(self):
        for name, profile in ENGINE_COST_PROFILES.items():
            if profile.cost_tier == CostTier.FREE:
                assert profile.has_free_tier is True, f"{name} is FREE tier but has_free_tier=False"

    def test_high_cost_engines_require_own_key(self):
        for name, profile in ENGINE_COST_PROFILES.items():
            if profile.cost_tier == CostTier.HIGH_COST:
                assert profile.requires_own_key is True, f"{name} is HIGH_COST but requires_own_key=False"


class TestSmartRouter:
    """Tests for the SmartRouter smart-routing service."""

    @pytest.fixture
    def router(self):
        return SmartRouter()

    @pytest.fixture
    def router_with_user_engines(self):
        return SmartRouter(user_engines=["chatgpt", "perplexity"])

    def test_initialization_default(self, router):
        assert len(router.available_engines) == len(ENGINE_COST_PROFILES)
        assert router.user_engines == []

    def test_initialization_with_available_engines(self):
        r = SmartRouter(available_engines=["deepseek", "qwen"])
        assert r.available_engines == ["deepseek", "qwen"]

    def test_initialization_with_user_engines(self):
        r = SmartRouter(user_engines=["chatgpt"])
        assert r.user_engines == ["chatgpt"]

    def test_select_engines_returns_list(self, router):
        result = router.select_engines()
        assert isinstance(result, list)
        assert len(result) > 0

    def test_select_engines_respects_max_engines(self, router):
        result = router.select_engines(max_engines=3)
        assert len(result) <= 3

    def test_select_engines_priority_free_first(self, router):
        result = router.select_engines(max_engines=5)
        free_engines = [e for e in result if ENGINE_COST_PROFILES[e].cost_tier == CostTier.FREE]
        non_free = [e for e in result if ENGINE_COST_PROFILES[e].cost_tier != CostTier.FREE]
        # The ordering check is only meaningful when the selection mixes free and paid engines.
        if free_engines and non_free:
            first_non_free_idx = result.index(non_free[0])
            last_free_idx = max(result.index(e) for e in free_engines)
            assert last_free_idx < first_non_free_idx

    def test_tiered_routing_order(self, router):
        result = router.select_engines(max_engines=10)
        # Rank of each tier; the selection must be non-decreasing in rank.
        tier_order = {
            CostTier.FREE: 0,
            CostTier.LOW_COST: 1,
            CostTier.MID_COST: 2,
            CostTier.HIGH_COST: 3,
        }
        tiers = [ENGINE_COST_PROFILES[e].cost_tier for e in result]
        for (prev_engine, prev_tier), (next_engine, next_tier) in itertools.pairwise(zip(result, tiers)):
            assert tier_order[prev_tier] <= tier_order[next_tier], (
                f"分层路由顺序错误: {prev_engine}({prev_tier}) 排在 {next_engine}({next_tier}) 前面"
            )

    def test_user_engines_prioritized(self, router_with_user_engines):
        # max_engines=9 matches the size of the profile table, so nothing should be cut.
        result = router_with_user_engines.select_engines(max_engines=9)
        # Fail loudly if the user-supplied engine was dropped, rather than letting the
        # ordering check below pass vacuously.
        assert "chatgpt" in result
        chatgpt_idx = result.index("chatgpt")
        other_high_idx = next(
            (
                idx
                for idx, e in enumerate(result)
                if e != "chatgpt" and ENGINE_COST_PROFILES[e].cost_tier == CostTier.HIGH_COST
            ),
            None,
        )
        if other_high_idx is not None:
            assert chatgpt_idx < other_high_idx

    def test_engine_fallback_when_unavailable(self):
        r = SmartRouter(available_engines=["chatgpt", "perplexity"])
        result = r.select_engines(max_engines=5)
        assert len(result) == 2
        assert "chatgpt" in result
        assert "perplexity" in result

    def test_prefer_domestic(self, router):
        result = router.select_engines(max_engines=5, prefer_domestic=True)
        domestic = [e for e in result if ENGINE_COST_PROFILES[e].domestic]
        international = [e for e in result if not ENGINE_COST_PROFILES[e].domestic]
        if domestic and international:
            last_domestic_idx = max(result.index(e) for e in domestic)
            first_intl_idx = min(result.index(e) for e in international)
            assert last_domestic_idx < first_intl_idx

    def test_get_cost_estimate(self, router):
        engines = ["deepseek", "qwen"]
        estimate = router.get_cost_estimate(engines)
        assert "total_cost" in estimate
        assert "per_engine" in estimate
        assert "deepseek" in estimate["per_engine"]
        assert "qwen" in estimate["per_engine"]
        assert estimate["total_cost"] > 0

    def test_get_cost_estimate_values(self, router):
        engines = ["deepseek"]
        estimate = router.get_cost_estimate(
            engines,
            estimated_input_tokens=1_000_000,
            estimated_output_tokens=1_000_000,
        )
        # One million tokens each way -> input price + output price per million.
        expected_cost = 0.25 + 6.0
        assert estimate["total_cost"] == pytest.approx(expected_cost, abs=1e-3)

    def test_get_recommended_combination(self, router):
        combo = router.get_recommended_combination()
        assert "basic" in combo
        assert "standard" in combo
        assert "premium" in combo
        assert "user_premium" in combo
        assert len(combo["basic"]) <= 3
        for e in combo["basic"]:
            assert ENGINE_COST_PROFILES[e].cost_tier == CostTier.FREE

    def test_select_engines_no_duplicates(self, router):
        result = router.select_engines(max_engines=10)
        assert len(result) == len(set(result))


class TestUsageRecord:
    """Tests for the UsageRecord data structure."""

    def test_creation(self):
        now = datetime.now(UTC)
        record = UsageRecord(
            id="usage_1",
            user_id="user_001",
            brand_id="brand_001",
            engine_type="deepseek",
            query="测试查询",
            input_tokens=500,
            output_tokens=1000,
            cost=0.00625,
            timestamp=now,
        )
        assert record.id == "usage_1"
        assert record.user_id == "user_001"
        assert record.brand_id == "brand_001"
        assert record.engine_type == "deepseek"
        assert record.query == "测试查询"
        assert record.input_tokens == 500
        assert record.output_tokens == 1000
        assert record.cost == 0.00625
        assert record.timestamp == now

    def test_default_metadata(self):
        record = UsageRecord(
            id="usage_2",
            user_id="user_001",
            brand_id="brand_001",
            engine_type="qwen",
            query="测试",
            input_tokens=100,
            output_tokens=200,
            cost=0.001,
            timestamp=datetime.now(UTC),
        )
        assert record.metadata == {}

    def test_custom_metadata(self):
        record = UsageRecord(
            id="usage_3",
            user_id="user_001",
            brand_id="brand_001",
            engine_type="qwen",
            query="测试",
            input_tokens=100,
            output_tokens=200,
            cost=0.001,
            timestamp=datetime.now(UTC),
            metadata={"source": "api"},
        )
        assert record.metadata == {"source": "api"}


class TestUsageTracker:
    """Tests for the UsageTracker usage-tracking service."""

    @pytest.fixture
    def tracker(self):
        return UsageTracker()

    def test_record_returns_usage_record(self, tracker):
        record = tracker.record(
            user_id="user_001",
            brand_id="brand_001",
            engine_type="deepseek",
            query="测试查询",
            input_tokens=500,
            output_tokens=1000,
            cost=0.00625,
        )
        assert isinstance(record, UsageRecord)
        assert record.user_id == "user_001"
        assert record.engine_type == "deepseek"
        assert record.input_tokens == 500
        assert record.output_tokens == 1000
        assert record.cost == 0.00625

    def test_record_auto_generates_id(self, tracker):
        record = tracker.record(
            user_id="user_001",
            brand_id="brand_001",
            engine_type="deepseek",
            query="测试",
            input_tokens=100,
            output_tokens=200,
            cost=0.001,
        )
        assert record.id.startswith("usage_")

    def test_record_auto_generates_timestamp(self, tracker):
        record = tracker.record(
            user_id="user_001",
            brand_id="brand_001",
            engine_type="deepseek",
            query="测试",
            input_tokens=100,
            output_tokens=200,
            cost=0.001,
        )
        assert record.timestamp is not None

    def test_record_with_metadata(self, tracker):
        record = tracker.record(
            user_id="user_001",
            brand_id="brand_001",
            engine_type="deepseek",
            query="测试",
            input_tokens=100,
            output_tokens=200,
            cost=0.001,
            metadata={"source": "api"},
        )
        assert record.metadata == {"source": "api"}

    def test_get_summary_daily(self, tracker):
        tracker.record("user_001", "brand_001", "deepseek", "q1", 500, 1000, 0.006)
        tracker.record("user_001", "brand_001", "qwen", "q2", 300, 600, 0.003)
        summary = tracker.get_summary(user_id="user_001", period="day")
        assert isinstance(summary, UsageSummary)
        assert summary.period == "day"
        assert summary.total_queries == 2
        assert summary.total_input_tokens == 800
        assert summary.total_output_tokens == 1600
        # Both records were just created, so today's cost is their sum.
        assert summary.total_cost == pytest.approx(0.009, abs=1e-3)

    def test_get_summary_weekly(self, tracker):
        tracker.record("user_001", "brand_001", "deepseek", "q1", 500, 1000, 0.006)
        summary = tracker.get_summary(user_id="user_001", period="week")
        assert summary.period == "week"
        assert summary.total_queries == 1

    def test_get_summary_monthly(self, tracker):
        tracker.record("user_001", "brand_001", "deepseek", "q1", 500, 1000, 0.006)
        summary = tracker.get_summary(user_id="user_001", period="month")
        assert summary.period == "month"
        assert summary.total_queries == 1

    def test_get_summary_by_engine(self, tracker):
        tracker.record("user_001", "brand_001", "deepseek", "q1", 500, 1000, 0.006)
        tracker.record("user_001", "brand_001", "qwen", "q2", 300, 600, 0.003)
        summary = tracker.get_summary(user_id="user_001", period="month")
        assert "deepseek" in summary.by_engine
        assert "qwen" in summary.by_engine
        assert summary.by_engine["deepseek"]["queries"] == 1
        assert summary.by_engine["qwen"]["queries"] == 1

    def test_get_summary_filter_by_brand(self, tracker):
        tracker.record("user_001", "brand_A", "deepseek", "q1", 500, 1000, 0.006)
        tracker.record("user_001", "brand_B", "qwen", "q2", 300, 600, 0.003)
        summary = tracker.get_summary(user_id="user_001", period="month", brand_id="brand_A")
        assert summary.total_queries == 1
        assert "deepseek" in summary.by_engine

    def test_calculate_cost(self, tracker):
        tracker.record("user_001", "brand_001", "deepseek", "q1", 500, 1000, 0.006)
        tracker.record("user_001", "brand_001", "qwen", "q2", 300, 600, 0.003)
        summary = tracker.get_summary(user_id="user_001", period="month")
        assert summary.total_cost == pytest.approx(0.009, abs=1e-3)

    # Quota percentages are floats derived by division; compare with approx so the
    # tests do not depend on the exact order of the multiply/divide in the implementation.
    def test_check_quota_ok(self, tracker):
        tracker.record("user_001", "brand_001", "deepseek", "q1", 500, 1000, 5.0)
        result = tracker.check_quota(user_id="user_001", monthly_limit=100.0)
        assert result["status"] == "ok"
        assert result["used"] == pytest.approx(5.0)
        assert result["limit"] == 100.0
        assert result["usage_percentage"] == pytest.approx(5.0)

    def test_check_quota_warning(self, tracker):
        tracker.record("user_001", "brand_001", "deepseek", "q1", 500, 1000, 85.0)
        result = tracker.check_quota(user_id="user_001", monthly_limit=100.0)
        assert result["status"] == "warning"
        assert result["usage_percentage"] == pytest.approx(85.0)

    def test_check_quota_exceeded(self, tracker):
        tracker.record("user_001", "brand_001", "deepseek", "q1", 500, 1000, 120.0)
        result = tracker.check_quota(user_id="user_001", monthly_limit=100.0)
        assert result["status"] == "exceeded"
        assert result["usage_percentage"] == pytest.approx(120.0)

    def test_check_quota_zero_usage(self, tracker):
        result = tracker.check_quota(user_id="user_001", monthly_limit=100.0)
        assert result["status"] == "ok"
        assert result["used"] == 0.0
        assert result["usage_percentage"] == 0.0

    def test_multiple_records_accumulate(self, tracker):
        for i in range(5):
            tracker.record("user_001", "brand_001", "deepseek", f"q{i}", 100, 200, 0.5)
        summary = tracker.get_summary(user_id="user_001", period="month")
        assert summary.total_queries == 5
        assert summary.total_input_tokens == 500
        assert summary.total_output_tokens == 1000
        assert summary.total_cost == pytest.approx(2.5, abs=1e-3)