From 4cc8f73bb4e608a10cd3712a4ad67c8444271323 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Mon, 25 May 2026 20:43:08 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20API=20Key=E7=AE=A1=E7=90=86+=E7=94=A8?= =?UTF-8?q?=E9=87=8F=E8=BF=BD=E8=B8=AA=E5=AE=8C=E6=95=B4=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=E9=93=BE=E8=B7=AFv2=EF=BC=88=E7=9C=9F=E5=AE=9E=E5=8F=AF?= =?UTF-8?q?=E7=94=A8=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 持久化存储: - APIKey模型 + APIKeyRepository(SQLAlchemy) - UsageRecord模型 + UsageRepository(SQLAlchemy) API Key验证: - KeyVerifier服务(真正调用引擎API验证) - 支持9个引擎的真实性验证 加密存储: - KeyEncryption服务(Fernet AES加密) - 环境变量API_KEY_ENCRYPTION_KEY 用量追踪: - UsageRecorder自动记录查询用量 - 按引擎/按日聚合(修复by_day空dict) - UserQuotaService支持套餐配额(free:10/basic:50/pro:200/enterprise:1000) 集成修复: - AI引擎适配器使用APIKeyManager获取Key(用户Key>系统Key>环境变量) - SmartRouter与APIKeyManager集成(过滤无Key引擎) - BatchQueryService自动记录用量并传递用户上下文 - 所有适配器支持引擎特定代理环境变量 前端: - usage页面替换MOCK为真实API调用 - 显示加载/错误/空状态 测试: 630 passed --- backend/app/api/ai_engines.py | 48 +-- backend/app/api/deps.py | 9 + backend/app/api/usage.py | 11 +- backend/app/models/__init__.py | 4 + backend/app/models/api_key.py | 44 ++ backend/app/models/usage_record.py | 67 +++ backend/app/repositories/__init__.py | 7 + .../app/repositories/api_key_repository.py | 119 +++++ backend/app/repositories/usage_repository.py | 175 ++++++++ backend/app/services/ai_engine/base.py | 46 +- backend/app/services/ai_engine/batch_query.py | 108 ++++- backend/app/services/ai_engine/chatgpt.py | 20 +- backend/app/services/ai_engine/deepseek.py | 20 +- backend/app/services/ai_engine/doubao.py | 22 +- backend/app/services/ai_engine/gemini.py | 21 +- backend/app/services/ai_engine/kimi.py | 22 +- backend/app/services/ai_engine/perplexity.py | 20 +- backend/app/services/ai_engine/qwen.py | 22 +- backend/app/services/ai_engine/wenxin.py | 22 +- backend/app/services/ai_engine/yuanbao.py | 22 +- backend/app/services/api_key_manager.py | 138 +++++- 
backend/app/services/detection_scheduler.py | 3 +- backend/app/services/engine_selector.py | 37 ++ backend/app/services/key_encryption.py | 57 +++ backend/app/services/key_verifier.py | 325 ++++++++++++++ backend/app/services/smart_router.py | 42 +- backend/app/services/usage_recorder.py | 38 ++ backend/app/services/usage_tracker.py | 72 +++- backend/app/services/user_quota_service.py | 37 ++ backend/tests/test_content_pipeline.py | 89 ---- backend/tests/test_models/test_api_key.py | 306 +++++++++++++ .../tests/test_models/test_usage_record.py | 406 ++++++++++++++++++ .../test_usage_quota_integration.py | 377 ++++++++++++++++ .../test_usage_repository.py | 371 ++++++++++++++++ .../test_services/test_adapter_key_source.py | 107 +++++ .../test_adapter_registration.py | 115 +++++ .../test_services/test_ai_engine_query.py | 15 + .../test_services/test_api_key_manager.py | 11 +- .../test_services/test_batch_query_service.py | 3 + .../test_services/test_key_encryption.py | 164 +++++++ .../tests/test_services/test_key_verifier.py | 354 +++++++++++++++ .../test_services/test_proxy_and_deepseek.py | 3 + .../test_smart_router_key_integration.py | 165 +++++++ .../test_services/test_usage_recording.py | 320 ++++++++++++++ .../app/(dashboard)/dashboard/usage/page.tsx | 159 +++++-- 45 files changed, 4342 insertions(+), 201 deletions(-) create mode 100644 backend/app/models/api_key.py create mode 100644 backend/app/models/usage_record.py create mode 100644 backend/app/repositories/__init__.py create mode 100644 backend/app/repositories/api_key_repository.py create mode 100644 backend/app/repositories/usage_repository.py create mode 100644 backend/app/services/engine_selector.py create mode 100644 backend/app/services/key_encryption.py create mode 100644 backend/app/services/key_verifier.py create mode 100644 backend/app/services/usage_recorder.py create mode 100644 backend/app/services/user_quota_service.py delete mode 100644 backend/tests/test_content_pipeline.py create mode 
100644 backend/tests/test_models/test_api_key.py create mode 100644 backend/tests/test_models/test_usage_record.py create mode 100644 backend/tests/test_repositories/test_usage_quota_integration.py create mode 100644 backend/tests/test_repositories/test_usage_repository.py create mode 100644 backend/tests/test_services/test_adapter_key_source.py create mode 100644 backend/tests/test_services/test_adapter_registration.py create mode 100644 backend/tests/test_services/test_key_encryption.py create mode 100644 backend/tests/test_services/test_key_verifier.py create mode 100644 backend/tests/test_services/test_smart_router_key_integration.py create mode 100644 backend/tests/test_services/test_usage_recording.py diff --git a/backend/app/api/ai_engines.py b/backend/app/api/ai_engines.py index e5963c5..4a8f356 100644 --- a/backend/app/api/ai_engines.py +++ b/backend/app/api/ai_engines.py @@ -1,13 +1,13 @@ import logging -from functools import lru_cache +from typing import TYPE_CHECKING from fastapi import APIRouter, Depends, HTTPException, Query, status from pydantic import BaseModel, Field -from app.api.deps import get_current_user +from app.api.deps import get_current_user, get_key_manager from app.models.user import User -from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType -from app.services.ai_engine.batch_query import BatchQueryService +from app.services.ai_engine.base import AIQueryResult, EngineType +from app.services.ai_engine.batch_query import BatchQueryService, get_batch_service as _get_batch_service from app.services.ai_engine.chatgpt import ChatGPTAdapter from app.services.ai_engine.doubao import DoubaoAdapter from app.services.ai_engine.kimi import KimiAdapter @@ -15,6 +15,9 @@ from app.services.ai_engine.perplexity import PerplexityAdapter from app.services.ai_engine.wenxin import WenxinAdapter from app.services.ai_engine.yuanbao import YuanbaoAdapter +if TYPE_CHECKING: + from app.services.api_key_manager import APIKeyManager 
+ logger = logging.getLogger(__name__) router = APIRouter() @@ -60,31 +63,6 @@ class BatchQueryResponse(BaseModel): citation_rate: CitationRateResponse -_ADAPTER_CLASSES: dict[EngineType, type[AIEngineAdapter]] = { - EngineType.CHATGPT: ChatGPTAdapter, - EngineType.PERPLEXITY: PerplexityAdapter, - EngineType.KIMI: KimiAdapter, - EngineType.WENXIN: WenxinAdapter, - EngineType.DOUBAO: DoubaoAdapter, - EngineType.YUANBAO: YuanbaoAdapter, -} - - -@lru_cache(maxsize=1) -def _build_adapters() -> dict[str, AIEngineAdapter]: - adapters: dict[str, AIEngineAdapter] = {} - for engine_type, cls in _ADAPTER_CLASSES.items(): - try: - adapters[engine_type.value] = cls() - except Exception: - logger.warning(f"Failed to initialize {engine_type.value} adapter") - return adapters - - -def get_batch_service() -> BatchQueryService: - return BatchQueryService(_build_adapters()) - - def _result_to_response(r: AIQueryResult) -> QueryResultResponse: return QueryResultResponse( engine_type=r.engine_type.value, @@ -129,8 +107,10 @@ async def _execute_batch( async def query_single_engine( request: SingleQueryRequest, current_user: User = Depends(get_current_user), + key_manager: "APIKeyManager | None" = Depends(get_key_manager), ): - service = get_batch_service() + service = _get_batch_service(key_manager=key_manager, user_id=str(current_user.id)) + service.set_user_context(str(current_user.id)) engine_type = _parse_engine(request.engine) try: result = await service.query_single( @@ -148,8 +128,10 @@ async def query_single_engine( async def query_batch_engines( request: BatchQueryRequest, current_user: User = Depends(get_current_user), + key_manager: "APIKeyManager | None" = Depends(get_key_manager), ): - service = get_batch_service() + service = _get_batch_service(key_manager=key_manager, user_id=str(current_user.id)) + service.set_user_context(str(current_user.id)) engine_types = [_parse_engine(e) for e in request.engines] return await _execute_batch( service, engine_types, request.query, 
request.brand_name, request.competitor_names @@ -163,8 +145,10 @@ async def get_query_results( brand_name: str = Query(..., min_length=1, max_length=200), competitor_names: str | None = Query(None, description="Comma-separated competitor names"), current_user: User = Depends(get_current_user), + key_manager: "APIKeyManager | None" = Depends(get_key_manager), ): - service = get_batch_service() + service = _get_batch_service(key_manager=key_manager, user_id=str(current_user.id)) + service.set_user_context(str(current_user.id)) engine_list = [e.strip() for e in engines.split(",") if e.strip()] engine_types = [_parse_engine(e) for e in engine_list] comp_names = ( diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index fd1336c..88addb3 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -1,4 +1,5 @@ import uuid +from functools import lru_cache from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer @@ -8,11 +9,19 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.database import get_db from app.models.user import User +from app.services.api_key_manager import APIKeyManager from app.services.auth import verify_token oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") +@lru_cache(maxsize=1) +def get_key_manager() -> APIKeyManager: + manager = APIKeyManager() + manager.load_env_keys() + return manager + + async def get_current_user( token: str = Depends(oauth2_scheme), db: AsyncSession = Depends(get_db), diff --git a/backend/app/api/usage.py b/backend/app/api/usage.py index d673fdc..d1e1026 100644 --- a/backend/app/api/usage.py +++ b/backend/app/api/usage.py @@ -5,6 +5,7 @@ from fastapi import APIRouter, Depends, Query from app.api.deps import get_current_user from app.models.user import User from app.services.usage_tracker import UsageTracker +from app.services.user_quota_service import UserQuotaService logger = logging.getLogger(__name__) @@ -44,6 +45,7 @@ async def 
get_usage_summary( "total_output_tokens": summary.total_output_tokens, "total_cost": summary.total_cost, "by_engine": summary.by_engine, + "by_day": summary.by_day, } @@ -51,7 +53,14 @@ async def get_usage_summary( async def get_quota( current_user: User = Depends(get_current_user), ): - return get_usage_tracker().check_quota(user_id=_user_id(current_user)) + tracker = get_usage_tracker() + if tracker._session: + quota_service = UserQuotaService(session=tracker._session) + return await quota_service.check_quota_with_plan( + user_id=_user_id(current_user), + user_plan=current_user.plan, + ) + return tracker.check_quota(user_id=_user_id(current_user)) @router.get("/by-engine") diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index da5bf86..cceceb4 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -1,4 +1,5 @@ from app.models.user import User +from app.models.api_key import APIKey from app.models.query import Query from app.models.citation_record import CitationRecord from app.models.query_task import QueryTask @@ -30,9 +31,11 @@ from app.models.suggestion import Suggestion from app.models.alert import Alert from app.models.alert_setting import AlertSetting from app.models.detection_task import DetectionTask +from app.models.usage_record import UsageRecord __all__ = [ "User", + "APIKey", "Query", "CitationRecord", "QueryTask", @@ -70,4 +73,5 @@ __all__ = [ "Alert", "AlertSetting", "DetectionTask", + "UsageRecord", ] diff --git a/backend/app/models/api_key.py b/backend/app/models/api_key.py new file mode 100644 index 0000000..efff1fb --- /dev/null +++ b/backend/app/models/api_key.py @@ -0,0 +1,44 @@ +import uuid +from datetime import datetime + +from sqlalchemy import DateTime, ForeignKey, Index, String, func +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column + +from app.database import Base + + +class APIKey(Base): + __tablename__ = "api_keys" + + id: 
Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + ) + user_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + nullable=False, + index=True, + ) + engine_type: Mapped[str] = mapped_column(String(20), nullable=False) + encrypted_key: Mapped[str] = mapped_column(String(500), nullable=False) + key_hint: Mapped[str] = mapped_column(String(50), nullable=False) + key_source: Mapped[str] = mapped_column(String(10), default="user") + status: Mapped[str] = mapped_column(String(20), default="active") + priority: Mapped[int] = mapped_column(default=0) + last_verified_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + created_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + nullable=False, + ) + updated_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + onupdate=func.now(), + nullable=False, + ) + + __table_args__ = ( + Index("idx_api_keys_user_engine", "user_id", "engine_type"), + Index("idx_api_keys_engine_status", "engine_type", "status"), + ) diff --git a/backend/app/models/usage_record.py b/backend/app/models/usage_record.py new file mode 100644 index 0000000..0dac114 --- /dev/null +++ b/backend/app/models/usage_record.py @@ -0,0 +1,67 @@ +import uuid +from datetime import datetime + +from sqlalchemy import String, Integer, Float, DateTime, ForeignKey, Index, func +from sqlalchemy import Uuid +from sqlalchemy.orm import Mapped, mapped_column + +from app.database import Base, JSONType + + +class UsageRecord(Base): + __tablename__ = "usage_records" + + id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + ) + user_id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), + nullable=False, + index=True, + ) + brand_id: Mapped[uuid.UUID | None] = mapped_column( + Uuid(as_uuid=True), + ForeignKey("brands.id", ondelete="SET NULL"), + nullable=True, + ) + engine_type: Mapped[str] = mapped_column( + 
String(20), + nullable=False, + ) + query: Mapped[str] = mapped_column( + String(500), + nullable=False, + ) + input_tokens: Mapped[int] = mapped_column( + Integer, + default=0, + ) + output_tokens: Mapped[int] = mapped_column( + Integer, + default=0, + ) + cost: Mapped[float] = mapped_column( + Float, + default=0.0, + ) + extra_data: Mapped[dict] = mapped_column( + JSONType, + default=dict, + ) + timestamp: Mapped[datetime] = mapped_column( + DateTime, + default=func.now(), + index=True, + ) + created_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + nullable=False, + ) + + __table_args__ = ( + Index("idx_usage_records_user_engine", "user_id", "engine_type"), + Index("idx_usage_records_user_timestamp", "user_id", "timestamp"), + Index("idx_usage_records_engine_timestamp", "engine_type", "timestamp"), + ) diff --git a/backend/app/repositories/__init__.py b/backend/app/repositories/__init__.py new file mode 100644 index 0000000..8354432 --- /dev/null +++ b/backend/app/repositories/__init__.py @@ -0,0 +1,7 @@ +from app.repositories.api_key_repository import APIKeyRepository +from app.repositories.usage_repository import UsageRepository + +__all__ = [ + "APIKeyRepository", + "UsageRepository", +] diff --git a/backend/app/repositories/api_key_repository.py b/backend/app/repositories/api_key_repository.py new file mode 100644 index 0000000..15b3265 --- /dev/null +++ b/backend/app/repositories/api_key_repository.py @@ -0,0 +1,119 @@ +import uuid +from datetime import datetime + +from sqlalchemy import select, and_, delete, update +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.api_key import APIKey + + +class APIKeyRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def create( + self, + user_id: uuid.UUID, + engine_type: str, + encrypted_key: str, + key_hint: str, + key_source: str = "user", + status: str = "active", + priority: int = 0, + last_verified_at: datetime | None = None, + ) -> 
APIKey: + api_key = APIKey( + user_id=user_id, + engine_type=engine_type, + encrypted_key=encrypted_key, + key_hint=key_hint, + key_source=key_source, + status=status, + priority=priority, + last_verified_at=last_verified_at, + ) + self.session.add(api_key) + await self.session.commit() + await self.session.refresh(api_key) + return api_key + + async def get_by_id(self, key_id: uuid.UUID) -> APIKey | None: + result = await self.session.execute( + select(APIKey).where(APIKey.id == key_id) + ) + return result.scalar_one_or_none() + + async def get_by_user( + self, + user_id: uuid.UUID, + engine_type: str | None = None, + ) -> list[APIKey]: + conditions = [APIKey.user_id == user_id] + if engine_type: + conditions.append(APIKey.engine_type == engine_type) + + result = await self.session.execute( + select(APIKey).where(and_(*conditions)).order_by(APIKey.priority.desc()) + ) + return list(result.scalars().all()) + + async def get_by_user_and_engine( + self, + user_id: uuid.UUID, + engine_type: str, + ) -> APIKey | None: + result = await self.session.execute( + select(APIKey).where( + and_( + APIKey.user_id == user_id, + APIKey.engine_type == engine_type, + ) + ) + ) + return result.scalar_one_or_none() + + async def update( + self, + key_id: uuid.UUID, + **kwargs, + ) -> APIKey | None: + api_key = await self.get_by_id(key_id) + if not api_key: + return None + + for key, value in kwargs.items(): + if hasattr(api_key, key): + setattr(api_key, key, value) + + api_key.updated_at = datetime.utcnow() + await self.session.commit() + await self.session.refresh(api_key) + return api_key + + async def delete(self, key_id: uuid.UUID) -> bool: + result = await self.session.execute( + delete(APIKey).where(APIKey.id == key_id) + ) + await self.session.commit() + return result.rowcount > 0 + + async def list_all( + self, + engine_type: str | None = None, + status: str | None = None, + limit: int = 100, + offset: int = 0, + ) -> list[APIKey]: + conditions = [] + if engine_type: + 
conditions.append(APIKey.engine_type == engine_type) + if status: + conditions.append(APIKey.status == status) + + query = select(APIKey) + if conditions: + query = query.where(and_(*conditions)) + query = query.order_by(APIKey.priority.desc()).limit(limit).offset(offset) + + result = await self.session.execute(query) + return list(result.scalars().all()) diff --git a/backend/app/repositories/usage_repository.py b/backend/app/repositories/usage_repository.py new file mode 100644 index 0000000..bd43b7c --- /dev/null +++ b/backend/app/repositories/usage_repository.py @@ -0,0 +1,175 @@ +import uuid +from datetime import datetime, timedelta, timezone + +from sqlalchemy import select, func, and_, case +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.usage_record import UsageRecord + + +_PERIOD_CUTOFF_DAYS = {"day": 0, "week": 7, "month": 30} +_QUOTA_WARNING_PCT = 80.0 +_QUOTA_EXCEEDED_PCT = 100.0 + + +def _compute_cutoff(period: str, now: datetime) -> datetime: + days = _PERIOD_CUTOFF_DAYS.get(period, 30) + if days == 0: + return now.replace(hour=0, minute=0, second=0, microsecond=0) + return now - timedelta(days=days) + + +def _quota_status(usage_pct: float) -> str: + if usage_pct >= _QUOTA_EXCEEDED_PCT: + return "exceeded" + if usage_pct >= _QUOTA_WARNING_PCT: + return "warning" + return "ok" + + +class UsageRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def create(self, data: dict) -> UsageRecord: + record = UsageRecord( + user_id=data["user_id"], + brand_id=data.get("brand_id"), + engine_type=data["engine_type"], + query=data["query"], + input_tokens=data.get("input_tokens", 0), + output_tokens=data.get("output_tokens", 0), + cost=data.get("cost", 0.0), + extra_data=data.get("extra_data", {}), + timestamp=data.get("timestamp", datetime.now(timezone.utc)), + ) + self.session.add(record) + await self.session.commit() + await self.session.refresh(record) + return record + + async def get_summary( + self, + 
user_id: str | uuid.UUID, + period: str = "month", + brand_id: str | uuid.UUID | None = None, + ) -> dict: + now = datetime.now(timezone.utc) + cutoff = _compute_cutoff(period, now) + + if isinstance(user_id, str): + user_id = uuid.UUID(user_id) + if brand_id and isinstance(brand_id, str): + brand_id = uuid.UUID(brand_id) + + conditions = [ + UsageRecord.user_id == user_id, + UsageRecord.timestamp >= cutoff, + ] + if brand_id: + conditions.append(UsageRecord.brand_id == brand_id) + + result = await self.session.execute( + select(UsageRecord).where(and_(*conditions)) + ) + records = list(result.scalars().all()) + + total_queries = len(records) + total_input_tokens = sum(r.input_tokens for r in records) + total_output_tokens = sum(r.output_tokens for r in records) + total_cost = round(sum(r.cost for r in records), 4) + + by_engine: dict[str, dict] = {} + for r in records: + bucket = by_engine.setdefault( + r.engine_type, + {"queries": 0, "input_tokens": 0, "output_tokens": 0, "cost": 0.0} + ) + bucket["queries"] += 1 + bucket["input_tokens"] += r.input_tokens + bucket["output_tokens"] += r.output_tokens + bucket["cost"] = round(bucket["cost"] + r.cost, 4) + + by_day: dict[str, dict] = {} + for r in records: + day_key = r.timestamp.strftime("%Y-%m-%d") + bucket = by_day.setdefault( + day_key, + {"queries": 0, "input_tokens": 0, "output_tokens": 0, "cost": 0.0} + ) + bucket["queries"] += 1 + bucket["input_tokens"] += r.input_tokens + bucket["output_tokens"] += r.output_tokens + bucket["cost"] = round(bucket["cost"] + r.cost, 4) + + return { + "period": period, + "start_date": cutoff.isoformat(), + "end_date": now.isoformat(), + "total_queries": total_queries, + "total_input_tokens": total_input_tokens, + "total_output_tokens": total_output_tokens, + "total_cost": total_cost, + "by_engine": by_engine, + "by_day": by_day, + } + + async def check_quota( + self, + user_id: str | uuid.UUID, + monthly_limit: float = 100.0, + ) -> dict: + summary = await 
self.get_summary(user_id=user_id, period="month") + usage_pct = (summary["total_cost"] / monthly_limit * 100) if monthly_limit > 0 else 0 + return { + "used": summary["total_cost"], + "limit": monthly_limit, + "usage_percentage": round(usage_pct, 1), + "status": _quota_status(usage_pct), + } + + async def get_by_id(self, record_id: uuid.UUID) -> UsageRecord | None: + result = await self.session.execute( + select(UsageRecord).where(UsageRecord.id == record_id) + ) + return result.scalar_one_or_none() + + async def get_by_user( + self, + user_id: str | uuid.UUID, + limit: int = 100, + offset: int = 0, + ) -> list[UsageRecord]: + if isinstance(user_id, str): + user_id = uuid.UUID(user_id) + + result = await self.session.execute( + select(UsageRecord) + .where(UsageRecord.user_id == user_id) + .order_by(UsageRecord.timestamp.desc()) + .limit(limit) + .offset(offset) + ) + return list(result.scalars().all()) + + async def get_by_user_and_engine( + self, + user_id: str | uuid.UUID, + engine_type: str, + limit: int = 100, + ) -> list[UsageRecord]: + if isinstance(user_id, str): + user_id = uuid.UUID(user_id) + + result = await self.session.execute( + select(UsageRecord) + .where( + and_( + UsageRecord.user_id == user_id, + UsageRecord.engine_type == engine_type, + ) + ) + .order_by(UsageRecord.timestamp.desc()) + .limit(limit) + ) + return list(result.scalars().all()) diff --git a/backend/app/services/ai_engine/base.py b/backend/app/services/ai_engine/base.py index a3f976c..6ad0472 100644 --- a/backend/app/services/ai_engine/base.py +++ b/backend/app/services/ai_engine/base.py @@ -9,6 +9,8 @@ from typing import Any import httpx +from app.services.api_key_manager import APIKeyManager + logger = logging.getLogger(__name__) _MAX_RETRIES = 3 @@ -49,15 +51,53 @@ class AIQueryResult: response_time_ms: int timestamp: datetime metadata: dict[str, Any] = field(default_factory=dict) + input_tokens: int = 0 + output_tokens: int = 0 + + @property + def total_tokens(self) -> int: + 
return self.input_tokens + self.output_tokens class AIEngineAdapter(ABC): - def __init__(self, api_key: str, rate_limiter=None, proxy: str | None = None): - self.api_key = api_key + def __init__( + self, + api_key: str | None = None, + rate_limiter=None, + proxy: str | None = None, + key_manager: APIKeyManager | None = None, + user_id: str | None = None, + ): + self._key_manager = key_manager + self._user_id = user_id + self.api_key = self._resolve_api_key(api_key, key_manager, user_id) self.rate_limiter = rate_limiter - self.proxy = proxy or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy") + self.proxy = proxy or self._load_proxy() self._client: httpx.AsyncClient | None = None + def _load_proxy(self) -> str | None: + return os.getenv("HTTPS_PROXY") or os.getenv("https_proxy") + + def _resolve_api_key( + self, + direct_key: str | None, + key_manager: APIKeyManager | None, + user_id: str | None, + ) -> str: + if direct_key and direct_key.strip(): + return direct_key + + if key_manager: + key = key_manager.get_key(self.get_engine_type().value, user_id=user_id) + if key: + return key + + return self._get_env_key() or "" + + @abstractmethod + def _get_env_key(self) -> str | None: + pass + @abstractmethod async def query( self, diff --git a/backend/app/services/ai_engine/batch_query.py b/backend/app/services/ai_engine/batch_query.py index e47be21..e2302ec 100644 --- a/backend/app/services/ai_engine/batch_query.py +++ b/backend/app/services/ai_engine/batch_query.py @@ -1,14 +1,103 @@ import asyncio import logging +from functools import lru_cache +from typing import TYPE_CHECKING from .base import AIEngineAdapter, AIQueryResult, EngineType +from app.services.usage_recorder import UsageRecorder +from app.services.usage_tracker import UsageTracker + +if TYPE_CHECKING: + from app.services.api_key_manager import APIKeyManager logger = logging.getLogger(__name__) +_ADAPTER_CLASSES: dict[EngineType, type[AIEngineAdapter]] = {} + + +def register_adapter(cls: 
type[AIEngineAdapter]) -> None: + engine_type = None + try: + temp = cls() + engine_type = temp.get_engine_type() + _ADAPTER_CLASSES[engine_type] = cls + except Exception as e: + logger.warning(f"Failed to register adapter: {e}") + + +from .chatgpt import ChatGPTAdapter +from .perplexity import PerplexityAdapter +from .kimi import KimiAdapter +from .wenxin import WenxinAdapter +from .doubao import DoubaoAdapter +from .yuanbao import YuanbaoAdapter +from .deepseek import DeepSeekAdapter +from .qwen import QwenAdapter +from .gemini import GeminiAdapter + +register_adapter(ChatGPTAdapter) +register_adapter(PerplexityAdapter) +register_adapter(KimiAdapter) +register_adapter(WenxinAdapter) +register_adapter(DoubaoAdapter) +register_adapter(YuanbaoAdapter) +register_adapter(DeepSeekAdapter) +register_adapter(QwenAdapter) +register_adapter(GeminiAdapter) + + +def get_batch_service( + key_manager: "APIKeyManager | None" = None, + user_id: str | None = None, +) -> "BatchQueryService": + if key_manager: + adapters = _build_adapters_with_key_manager(key_manager=key_manager, user_id=user_id) + else: + adapters = _build_adapters() + return BatchQueryService(adapters) + + +@lru_cache(maxsize=1) +def _build_adapters() -> dict[str, AIEngineAdapter]: + adapters: dict[str, AIEngineAdapter] = {} + for engine_type, cls in _ADAPTER_CLASSES.items(): + try: + adapters[engine_type.value] = cls() + except Exception: + logger.warning(f"Failed to initialize {engine_type.value} adapter") + return adapters + + +def _build_adapters_with_key_manager( + key_manager: "APIKeyManager | None" = None, + user_id: str | None = None, +) -> dict[str, AIEngineAdapter]: + adapters: dict[str, AIEngineAdapter] = {} + for engine_type, cls in _ADAPTER_CLASSES.items(): + try: + adapters[engine_type.value] = cls( + key_manager=key_manager, + user_id=user_id, + ) + except Exception: + logger.warning(f"Failed to initialize {engine_type.value} adapter") + return adapters + class BatchQueryService: def __init__(self, 
adapters: dict[str, AIEngineAdapter]): self.adapters = adapters + self._tracker = UsageTracker() + self._recorder = UsageRecorder(self._tracker) + self._user_id: str | None = None + self._brand_id: str | None = None + + def set_user_context(self, user_id: str, brand_id: str | None = None) -> None: + self._user_id = user_id + self._brand_id = brand_id + + def get_usage_summary(self): + return self._tracker.get_summary(user_id=self._user_id) async def query_single( self, @@ -20,7 +109,24 @@ class BatchQueryService: adapter = self.adapters.get(engine_type.value) if not adapter: raise ValueError(f"Unknown engine type: {engine_type}") - return await adapter.query(query, brand_name, competitor_names) + + result = await adapter.query(query, brand_name, competitor_names) + + if self._user_id: + self._recorder.record( + user_id=self._user_id, + brand_id=self._brand_id, + engine_type=engine_type.value, + query=query, + input_tokens=result.input_tokens, + output_tokens=result.output_tokens, + metadata={ + "brand_name": brand_name, + "response_time_ms": result.response_time_ms, + }, + ) + + return result async def query_batch( self, diff --git a/backend/app/services/ai_engine/chatgpt.py b/backend/app/services/ai_engine/chatgpt.py index 85fec00..15b90ef 100644 --- a/backend/app/services/ai_engine/chatgpt.py +++ b/backend/app/services/ai_engine/chatgpt.py @@ -21,11 +21,15 @@ class ChatGPTAdapter(AIEngineAdapter): base_url: str | None = None, rate_limiter=None, proxy: str | None = None, + key_manager=None, + user_id: str | None = None, ): super().__init__( - api_key=api_key or os.getenv("OPENAI_API_KEY", ""), + api_key=api_key, rate_limiter=rate_limiter, - proxy=proxy or os.getenv("OPENAI_PROXY"), + proxy=proxy, + key_manager=key_manager, + user_id=user_id, ) self._model = model or os.getenv("OPENAI_MODEL", _DEFAULT_MODEL) self._base_url = ( @@ -43,6 +47,12 @@ class ChatGPTAdapter(AIEngineAdapter): def get_engine_type(self) -> EngineType: return EngineType.CHATGPT + def 
_load_proxy(self) -> str | None: + return os.getenv("OPENAI_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy") + + def _get_env_key(self) -> str | None: + return os.getenv("OPENAI_API_KEY", "") + async def query( self, query: str, @@ -70,6 +80,10 @@ class ChatGPTAdapter(AIEngineAdapter): content, brand_name, competitor_names ) + usage = data.get("usage", {}) + input_tokens = usage.get("prompt_tokens", 0) + output_tokens = usage.get("completion_tokens", 0) + logger.info( f"[chatgpt] query='{query[:50]}...' brand={has_brand} " f"competitor={has_comp} time={elapsed_ms}ms" @@ -87,4 +101,6 @@ class ChatGPTAdapter(AIEngineAdapter): response_time_ms=elapsed_ms, timestamp=datetime.now(UTC), metadata={"model": data.get("model", self._model)}, + input_tokens=input_tokens, + output_tokens=output_tokens, ) diff --git a/backend/app/services/ai_engine/deepseek.py b/backend/app/services/ai_engine/deepseek.py index 06946f0..ac6cf57 100644 --- a/backend/app/services/ai_engine/deepseek.py +++ b/backend/app/services/ai_engine/deepseek.py @@ -21,11 +21,15 @@ class DeepSeekAdapter(AIEngineAdapter): base_url: str | None = None, rate_limiter=None, proxy: str | None = None, + key_manager=None, + user_id: str | None = None, ): super().__init__( - api_key=api_key or os.getenv("DEEPSEEK_API_KEY", ""), + api_key=api_key, rate_limiter=rate_limiter, - proxy=proxy or os.getenv("DEEPSEEK_PROXY"), + proxy=proxy, + key_manager=key_manager, + user_id=user_id, ) self._model = model or os.getenv("DEEPSEEK_MODEL", _DEFAULT_MODEL) self._base_url = ( @@ -43,6 +47,12 @@ class DeepSeekAdapter(AIEngineAdapter): def get_engine_type(self) -> EngineType: return EngineType.DEEPSEEK + def _get_env_key(self) -> str | None: + return os.getenv("DEEPSEEK_API_KEY", "") + + def _load_proxy(self) -> str | None: + return os.getenv("DEEPSEEK_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy") + async def query( self, query: str, @@ -70,6 +80,10 @@ class DeepSeekAdapter(AIEngineAdapter): content, 
brand_name, competitor_names ) + usage = data.get("usage", {}) + input_tokens = usage.get("prompt_tokens", 0) + output_tokens = usage.get("completion_tokens", 0) + logger.info( f"[deepseek] query='{query[:50]}...' brand={has_brand} " f"competitor={has_comp} time={elapsed_ms}ms" @@ -87,4 +101,6 @@ class DeepSeekAdapter(AIEngineAdapter): response_time_ms=elapsed_ms, timestamp=datetime.now(UTC), metadata={"model": data.get("model", self._model)}, + input_tokens=input_tokens, + output_tokens=output_tokens, ) diff --git a/backend/app/services/ai_engine/doubao.py b/backend/app/services/ai_engine/doubao.py index 1cd699d..99d4d6d 100644 --- a/backend/app/services/ai_engine/doubao.py +++ b/backend/app/services/ai_engine/doubao.py @@ -19,10 +19,16 @@ class DoubaoAdapter(AIEngineAdapter): api_key: str | None = None, endpoint_id: str | None = None, rate_limiter=None, + proxy: str | None = None, + key_manager=None, + user_id: str | None = None, ): super().__init__( - api_key=api_key or os.getenv("DOUBAO_API_KEY", ""), + api_key=api_key, rate_limiter=rate_limiter, + proxy=proxy, + key_manager=key_manager, + user_id=user_id, ) self._endpoint_id = endpoint_id or os.getenv("DOUBAO_ENDPOINT_ID", "") self._base_url = _DEFAULT_BASE_URL.rstrip("/") @@ -38,6 +44,12 @@ class DoubaoAdapter(AIEngineAdapter): def get_engine_type(self) -> EngineType: return EngineType.DOUBAO + def _get_env_key(self) -> str | None: + return os.getenv("DOUBAO_API_KEY", "") + + def _load_proxy(self) -> str | None: + return os.getenv("DOUBAO_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy") + def _get_model_id(self) -> str: if self._endpoint_id and self._endpoint_id.strip(): if not self._endpoint_id.startswith("ep-"): @@ -76,6 +88,10 @@ class DoubaoAdapter(AIEngineAdapter): content, brand_name, competitor_names ) + usage = data.get("usage", {}) + input_tokens = usage.get("prompt_tokens", 0) + output_tokens = usage.get("completion_tokens", 0) + logger.info( f"[doubao] query='{query[:50]}...' 
brand={has_brand} " f"competitor={has_comp} time={elapsed_ms}ms" @@ -92,5 +108,7 @@ class DoubaoAdapter(AIEngineAdapter): competitor_contexts=comp_ctx, response_time_ms=elapsed_ms, timestamp=datetime.now(UTC), - metadata={"model": data.get("model", model_id), "usage": data.get("usage")}, + metadata={"model": data.get("model", model_id), "usage": usage}, + input_tokens=input_tokens, + output_tokens=output_tokens, ) diff --git a/backend/app/services/ai_engine/gemini.py b/backend/app/services/ai_engine/gemini.py index b9a0ae3..8294643 100644 --- a/backend/app/services/ai_engine/gemini.py +++ b/backend/app/services/ai_engine/gemini.py @@ -23,10 +23,15 @@ class GeminiAdapter(AIEngineAdapter): model: str | None = None, rate_limiter=None, proxy: str | None = None, + key_manager=None, + user_id: str | None = None, ): super().__init__( - api_key=api_key or os.getenv("GOOGLE_API_KEY", ""), + api_key=api_key, rate_limiter=rate_limiter, + proxy=proxy, + key_manager=key_manager, + user_id=user_id, ) self._model = model or os.getenv("GEMINI_MODEL", _DEFAULT_MODEL) self._base_url = os.getenv("GEMINI_BASE_URL", _DEFAULT_BASE_URL).rstrip("/") @@ -46,6 +51,12 @@ class GeminiAdapter(AIEngineAdapter): def get_engine_type(self) -> EngineType: return EngineType.GEMINI + def _get_env_key(self) -> str | None: + return os.getenv("GOOGLE_API_KEY", "") + + def _load_proxy(self) -> str | None: + return os.getenv("GOOGLE_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy") + async def _request_with_retry(self, payload: dict) -> dict: if self.rate_limiter: await self.rate_limiter.acquire() @@ -117,6 +128,10 @@ class GeminiAdapter(AIEngineAdapter): content, brand_name, competitor_names ) + usage_metadata = data.get("usageMetadata", {}) + input_tokens = usage_metadata.get("promptTokenCount", 0) + output_tokens = usage_metadata.get("candidatesTokenCount", 0) + logger.info( f"[gemini] query='{query[:50]}...' 
brand={has_brand} " f"competitor={has_comp} time={elapsed_ms}ms" @@ -133,5 +148,7 @@ class GeminiAdapter(AIEngineAdapter): competitor_contexts=comp_ctx, response_time_ms=elapsed_ms, timestamp=datetime.now(UTC), - metadata={"model": self._model}, + metadata={"model": self._model, "usage": usage_metadata}, + input_tokens=input_tokens, + output_tokens=output_tokens, ) diff --git a/backend/app/services/ai_engine/kimi.py b/backend/app/services/ai_engine/kimi.py index 757718c..a241dc1 100644 --- a/backend/app/services/ai_engine/kimi.py +++ b/backend/app/services/ai_engine/kimi.py @@ -20,10 +20,16 @@ class KimiAdapter(AIEngineAdapter): model: str | None = None, base_url: str | None = None, rate_limiter=None, + proxy: str | None = None, + key_manager=None, + user_id: str | None = None, ): super().__init__( - api_key=api_key or os.getenv("MOONSHOT_API_KEY", ""), + api_key=api_key, rate_limiter=rate_limiter, + proxy=proxy, + key_manager=key_manager, + user_id=user_id, ) self._model = model or _DEFAULT_MODEL self._base_url = ( @@ -41,6 +47,12 @@ class KimiAdapter(AIEngineAdapter): def get_engine_type(self) -> EngineType: return EngineType.KIMI + def _get_env_key(self) -> str | None: + return os.getenv("MOONSHOT_API_KEY", "") + + def _load_proxy(self) -> str | None: + return os.getenv("MOONSHOT_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy") + async def query( self, query: str, @@ -71,6 +83,10 @@ class KimiAdapter(AIEngineAdapter): content, brand_name, competitor_names ) + usage = data.get("usage", {}) + input_tokens = usage.get("prompt_tokens", 0) + output_tokens = usage.get("completion_tokens", 0) + logger.info( f"[kimi] query='{query[:50]}...' 
brand={has_brand} " f"competitor={has_comp} time={elapsed_ms}ms" @@ -87,5 +103,7 @@ class KimiAdapter(AIEngineAdapter): competitor_contexts=comp_ctx, response_time_ms=elapsed_ms, timestamp=datetime.now(UTC), - metadata={"model": data.get("model", self._model), "usage": data.get("usage")}, + metadata={"model": data.get("model", self._model), "usage": usage}, + input_tokens=input_tokens, + output_tokens=output_tokens, ) diff --git a/backend/app/services/ai_engine/perplexity.py b/backend/app/services/ai_engine/perplexity.py index 3e5983a..2bc5dab 100644 --- a/backend/app/services/ai_engine/perplexity.py +++ b/backend/app/services/ai_engine/perplexity.py @@ -21,11 +21,15 @@ class PerplexityAdapter(AIEngineAdapter): base_url: str | None = None, rate_limiter=None, proxy: str | None = None, + key_manager=None, + user_id: str | None = None, ): super().__init__( - api_key=api_key or os.getenv("PERPLEXITY_API_KEY", ""), + api_key=api_key, rate_limiter=rate_limiter, - proxy=proxy or os.getenv("PERPLEXITY_PROXY"), + proxy=proxy, + key_manager=key_manager, + user_id=user_id, ) self._model = model or os.getenv("PERPLEXITY_MODEL", _DEFAULT_MODEL) self._base_url = ( @@ -43,6 +47,12 @@ class PerplexityAdapter(AIEngineAdapter): def get_engine_type(self) -> EngineType: return EngineType.PERPLEXITY + def _get_env_key(self) -> str | None: + return os.getenv("PERPLEXITY_API_KEY", "") + + def _load_proxy(self) -> str | None: + return os.getenv("PERPLEXITY_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy") + async def query( self, query: str, @@ -71,6 +81,10 @@ class PerplexityAdapter(AIEngineAdapter): content, brand_name, competitor_names ) + usage = data.get("usage", {}) + input_tokens = usage.get("prompt_tokens", 0) + output_tokens = usage.get("completion_tokens", 0) + logger.info( f"[perplexity] query='{query[:50]}...' 
brand={has_brand} " f"competitor={has_comp} citations={len(citations)} time={elapsed_ms}ms" @@ -88,6 +102,8 @@ class PerplexityAdapter(AIEngineAdapter): response_time_ms=elapsed_ms, timestamp=datetime.now(UTC), metadata={"model": data.get("model", self._model)}, + input_tokens=input_tokens, + output_tokens=output_tokens, ) def _extract_citations(self, data: dict) -> list[CitationInfo]: diff --git a/backend/app/services/ai_engine/qwen.py b/backend/app/services/ai_engine/qwen.py index 849b0a2..b47fb80 100644 --- a/backend/app/services/ai_engine/qwen.py +++ b/backend/app/services/ai_engine/qwen.py @@ -20,10 +20,16 @@ class QwenAdapter(AIEngineAdapter): model: str | None = None, base_url: str | None = None, rate_limiter=None, + proxy: str | None = None, + key_manager=None, + user_id: str | None = None, ): super().__init__( - api_key=api_key or os.getenv("DASHSCOPE_API_KEY", ""), + api_key=api_key, rate_limiter=rate_limiter, + proxy=proxy, + key_manager=key_manager, + user_id=user_id, ) self._model = model or os.getenv("QWEN_MODEL", _DEFAULT_MODEL) self._base_url = ( @@ -41,6 +47,12 @@ class QwenAdapter(AIEngineAdapter): def get_engine_type(self) -> EngineType: return EngineType.QWEN + def _get_env_key(self) -> str | None: + return os.getenv("DASHSCOPE_API_KEY", "") + + def _load_proxy(self) -> str | None: + return os.getenv("DASHSCOPE_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy") + async def query( self, query: str, @@ -71,6 +83,10 @@ class QwenAdapter(AIEngineAdapter): content, brand_name, competitor_names ) + usage = data.get("usage", {}) + input_tokens = usage.get("prompt_tokens", 0) + output_tokens = usage.get("completion_tokens", 0) + logger.info( f"[qwen] query='{query[:50]}...' 
brand={has_brand} " f"competitor={has_comp} time={elapsed_ms}ms" @@ -87,5 +103,7 @@ class QwenAdapter(AIEngineAdapter): competitor_contexts=comp_ctx, response_time_ms=elapsed_ms, timestamp=datetime.now(UTC), - metadata={"model": data.get("model", self._model), "usage": data.get("usage")}, + metadata={"model": data.get("model", self._model), "usage": usage}, + input_tokens=input_tokens, + output_tokens=output_tokens, ) diff --git a/backend/app/services/ai_engine/wenxin.py b/backend/app/services/ai_engine/wenxin.py index 3160faf..95cd320 100644 --- a/backend/app/services/ai_engine/wenxin.py +++ b/backend/app/services/ai_engine/wenxin.py @@ -26,10 +26,16 @@ class WenxinAdapter(AIEngineAdapter): api_key: str | None = None, secret_key: str | None = None, rate_limiter=None, + proxy: str | None = None, + key_manager=None, + user_id: str | None = None, ): super().__init__( - api_key=api_key or os.getenv("BAIDU_QIANFAN_API_KEY", ""), + api_key=api_key, rate_limiter=rate_limiter, + proxy=proxy, + key_manager=key_manager, + user_id=user_id, ) self.secret_key = secret_key or os.getenv("BAIDU_QIANFAN_SECRET_KEY", "") self._model = _DEFAULT_MODEL @@ -40,6 +46,12 @@ class WenxinAdapter(AIEngineAdapter): def get_engine_type(self) -> EngineType: return EngineType.WENXIN + def _get_env_key(self) -> str | None: + return os.getenv("BAIDU_QIANFAN_API_KEY", "") + + def _load_proxy(self) -> str | None: + return os.getenv("BAIDU_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy") + async def _get_access_token(self) -> str: global _cached_token, _token_expires_at @@ -125,6 +137,10 @@ class WenxinAdapter(AIEngineAdapter): content, brand_name, competitor_names ) + usage = data.get("usage", {}) + input_tokens = usage.get("prompt_tokens", 0) + output_tokens = usage.get("completion_tokens", 0) + logger.info( f"[wenxin] query='{query[:50]}...' 
brand={has_brand} " f"competitor={has_comp} time={elapsed_ms}ms" @@ -141,5 +157,7 @@ class WenxinAdapter(AIEngineAdapter): competitor_contexts=comp_ctx, response_time_ms=elapsed_ms, timestamp=datetime.now(UTC), - metadata={"model": self._model, "usage": data.get("usage")}, + metadata={"model": self._model, "usage": usage}, + input_tokens=input_tokens, + output_tokens=output_tokens, ) diff --git a/backend/app/services/ai_engine/yuanbao.py b/backend/app/services/ai_engine/yuanbao.py index 3afd752..50b0210 100644 --- a/backend/app/services/ai_engine/yuanbao.py +++ b/backend/app/services/ai_engine/yuanbao.py @@ -20,10 +20,16 @@ class YuanbaoAdapter(AIEngineAdapter): model: str | None = None, base_url: str | None = None, rate_limiter=None, + proxy: str | None = None, + key_manager=None, + user_id: str | None = None, ): super().__init__( - api_key=api_key or os.getenv("HUNYUAN_API_KEY", ""), + api_key=api_key, rate_limiter=rate_limiter, + proxy=proxy, + key_manager=key_manager, + user_id=user_id, ) self._model = model or os.getenv("HUNYUAN_MODEL", _DEFAULT_MODEL) self._base_url = ( @@ -41,6 +47,12 @@ class YuanbaoAdapter(AIEngineAdapter): def get_engine_type(self) -> EngineType: return EngineType.YUANBAO + def _get_env_key(self) -> str | None: + return os.getenv("HUNYUAN_API_KEY", "") + + def _load_proxy(self) -> str | None: + return os.getenv("HUNYUAN_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy") + async def query( self, query: str, @@ -71,6 +83,10 @@ class YuanbaoAdapter(AIEngineAdapter): content, brand_name, competitor_names ) + usage = data.get("usage", {}) + input_tokens = usage.get("prompt_tokens", 0) + output_tokens = usage.get("completion_tokens", 0) + logger.info( f"[yuanbao] query='{query[:50]}...' 
brand={has_brand} " f"competitor={has_comp} time={elapsed_ms}ms" @@ -87,5 +103,7 @@ class YuanbaoAdapter(AIEngineAdapter): competitor_contexts=comp_ctx, response_time_ms=elapsed_ms, timestamp=datetime.now(UTC), - metadata={"model": data.get("model", self._model), "usage": data.get("usage")}, + metadata={"model": data.get("model", self._model), "usage": usage}, + input_tokens=input_tokens, + output_tokens=output_tokens, ) diff --git a/backend/app/services/api_key_manager.py b/backend/app/services/api_key_manager.py index 9680835..53a83ec 100644 --- a/backend/app/services/api_key_manager.py +++ b/backend/app/services/api_key_manager.py @@ -1,9 +1,17 @@ import base64 import logging import os +import uuid from dataclasses import dataclass +from datetime import datetime from enum import Enum +from sqlalchemy.ext.asyncio import AsyncSession + +from app.repositories.api_key_repository import APIKeyRepository +from app.services.key_encryption import KeyEncryption, get_key_encryption +from app.services.key_verifier import KeyStatus, KeyVerifierFactory + logger = logging.getLogger(__name__) @@ -13,14 +21,6 @@ class KeySource(str, Enum): ENV = "env" -class KeyStatus(str, Enum): - ACTIVE = "active" - INVALID = "invalid" - EXPIRED = "expired" - RATE_LIMITED = "rate_limited" - UNKNOWN = "unknown" - - @dataclass class APIKeyConfig: engine_type: str @@ -116,13 +116,17 @@ class APIKeyManager: async def verify_key(self, engine_type: str, api_key: str) -> KeyStatus: if not api_key or len(api_key) < 10: return KeyStatus.INVALID - return KeyStatus.ACTIVE + try: + return await KeyVerifierFactory.verify(engine_type, api_key) + except Exception as e: + logger.warning(f"[api_key_manager] Key verification failed: {e}") + return KeyStatus.UNKNOWN def _encrypt(self, plaintext: str) -> str: - return base64.b64encode(plaintext.encode()).decode() + return get_key_encryption().encrypt(plaintext) def _decrypt(self, ciphertext: str) -> str: - return base64.b64decode(ciphertext.encode()).decode() + 
return get_key_encryption().decrypt(ciphertext) def _mask_key(self, key: str) -> str: if len(key) <= 8: @@ -134,3 +138,115 @@ class APIKeyManager: key = os.getenv(env_var, "") if key: self.add_key(engine, key, source=KeySource.ENV, priority=0) + + async def add_key_async( + self, + session: AsyncSession, + engine_type: str, + api_key: str, + source: KeySource = KeySource.USER, + user_id: uuid.UUID | None = None, + priority: int = 0, + ) -> APIKeyConfig: + config = APIKeyConfig( + engine_type=engine_type, + key_source=source, + encrypted_key=self._encrypt(api_key), + key_hint=self._mask_key(api_key), + status=KeyStatus.UNKNOWN, + priority=priority, + user_id=str(user_id) if user_id else None, + ) + repository = APIKeyRepository(session) + await repository.create( + user_id=user_id, + engine_type=engine_type, + encrypted_key=config.encrypted_key, + key_hint=config.key_hint, + key_source=source.value, + status=config.status.value, + priority=priority, + ) + return config + + async def get_key_async( + self, + session: AsyncSession, + engine_type: str, + user_id: uuid.UUID | None = None, + ) -> str | None: + repository = APIKeyRepository(session) + keys = await repository.get_by_user(user_id, engine_type) + for key in keys: + source = KeySource(key.key_source) + status = KeyStatus(key.status) + if user_id: + if source == KeySource.USER and status in self._USABLE_STATUSES: + return self._decrypt(key.encrypted_key) + if source in self._FALLBACK_SOURCES and status in self._USABLE_STATUSES: + return self._decrypt(key.encrypted_key) + return None + + async def get_any_available_key_async( + self, + session: AsyncSession, + engine_type: str, + ) -> str | None: + repository = APIKeyRepository(session) + keys = await repository.get_by_user(None, engine_type) + for key in keys: + status = KeyStatus(key.status) + if status in self._USABLE_STATUSES: + return self._decrypt(key.encrypted_key) + return None + + async def remove_key_async( + self, + session: AsyncSession, + user_id: 
uuid.UUID, + engine_type: str, + key_hint: str, + ) -> bool: + repository = APIKeyRepository(session) + keys = await repository.get_by_user(user_id, engine_type) + for key in keys: + if key.key_hint == key_hint: + deleted = await repository.delete(key.id) + return deleted + return False + + async def list_keys_async( + self, + session: AsyncSession, + user_id: uuid.UUID | None = None, + engine_type: str | None = None, + ) -> list[APIKeyConfig]: + repository = APIKeyRepository(session) + if user_id: + keys = await repository.get_by_user(user_id, engine_type) + else: + keys = await repository.list_all(engine_type=engine_type) + return [ + APIKeyConfig( + engine_type=k.engine_type, + key_source=KeySource(k.key_source), + encrypted_key=k.encrypted_key, + key_hint=k.key_hint, + status=KeyStatus(k.status), + priority=k.priority, + last_verified_at=str(k.last_verified_at) if k.last_verified_at else None, + created_at=str(k.created_at) if k.created_at else None, + user_id=str(k.user_id) if k.user_id else None, + ) + for k in keys + ] + + async def update_key_status_async( + self, + session: AsyncSession, + key_id: uuid.UUID, + status: KeyStatus, + ) -> bool: + repository = APIKeyRepository(session) + updated = await repository.update(key_id, status=status.value) + return updated is not None diff --git a/backend/app/services/detection_scheduler.py b/backend/app/services/detection_scheduler.py index b461ffb..d120c49 100644 --- a/backend/app/services/detection_scheduler.py +++ b/backend/app/services/detection_scheduler.py @@ -5,11 +5,10 @@ from datetime import datetime, timezone from sqlalchemy import and_, select from sqlalchemy.ext.asyncio import AsyncSession -from app.api.ai_engines import _build_adapters from app.models.brand import Brand from app.models.detection_task import VALID_FREQUENCIES, DetectionTask from app.services.ai_engine.base import EngineType -from app.services.ai_engine.batch_query import BatchQueryService +from app.services.ai_engine.batch_query import 
BatchQueryService, _build_adapters logger = logging.getLogger(__name__) diff --git a/backend/app/services/engine_selector.py b/backend/app/services/engine_selector.py new file mode 100644 index 0000000..cc9011f --- /dev/null +++ b/backend/app/services/engine_selector.py @@ -0,0 +1,37 @@ +from app.services.smart_router import ENGINE_COST_PROFILES, SmartRouter +from app.services.api_key_manager import APIKeyManager + + +class EngineSelector: + def __init__(self, key_manager: APIKeyManager): + self.key_manager = key_manager + self.router = SmartRouter(key_manager=key_manager) + + def select_engines( + self, + max_engines: int = 5, + prefer_domestic: bool = True, + min_cost_tier: str | None = None, + ) -> list[str]: + engines = self.router.select_engines(max_engines, prefer_domestic) + + if min_cost_tier: + tier_order = ["free", "low_cost", "mid_cost", "high_cost"] + if min_cost_tier in tier_order: + min_idx = tier_order.index(min_cost_tier) + filtered = [] + for engine in engines: + profile = ENGINE_COST_PROFILES.get(engine) + if profile and tier_order.index(profile.cost_tier.value) >= min_idx: + filtered.append(engine) + return filtered + + return engines + + def get_best_value_engine(self) -> str | None: + available = self.router.get_available_engines() + for engine in available: + profile = ENGINE_COST_PROFILES.get(engine) + if profile and profile.cost_tier.value in ["free", "low_cost"]: + return engine + return available[0] if available else None diff --git a/backend/app/services/key_encryption.py b/backend/app/services/key_encryption.py new file mode 100644 index 0000000..5a9e771 --- /dev/null +++ b/backend/app/services/key_encryption.py @@ -0,0 +1,57 @@ +import os +import base64 +from cryptography.fernet import Fernet, InvalidToken +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC + + +class KeyEncryption: + def __init__(self, encryption_key: str | None = None): + self._fernet = 
self._create_fernet(encryption_key) + + def _create_fernet(self, key: str | None) -> Fernet: + if not key: + key = os.getenv("API_KEY_ENCRYPTION_KEY", "") + + if not key: + return Fernet(Fernet.generate_key()) + + try: + return Fernet(key.encode()) + except Exception: + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=b"geo-platform-salt", + iterations=100000, + ) + derived_key = base64.urlsafe_b64encode(kdf.derive(key.encode())) + return Fernet(derived_key) + + def encrypt(self, plaintext: str) -> str: + if not plaintext: + raise ValueError("Cannot encrypt empty string") + return self._fernet.encrypt(plaintext.encode()).decode() + + def decrypt(self, ciphertext: str) -> str: + if not ciphertext: + raise ValueError("Cannot decrypt empty string") + try: + return self._fernet.decrypt(ciphertext.encode()).decode() + except InvalidToken: + raise ValueError("Invalid ciphertext or key") + + +_key_encryption: KeyEncryption | None = None + + +def get_key_encryption() -> KeyEncryption: + global _key_encryption + if _key_encryption is None: + _key_encryption = KeyEncryption() + return _key_encryption + + +def reset_key_encryption() -> None: + global _key_encryption + _key_encryption = None diff --git a/backend/app/services/key_verifier.py b/backend/app/services/key_verifier.py new file mode 100644 index 0000000..4c97093 --- /dev/null +++ b/backend/app/services/key_verifier.py @@ -0,0 +1,325 @@ +import logging +import os +from abc import ABC, abstractmethod +from enum import Enum + +import httpx + +logger = logging.getLogger(__name__) + + +class KeyStatus(str, Enum): + ACTIVE = "active" + INVALID = "invalid" + EXPIRED = "expired" + RATE_LIMITED = "rate_limited" + UNKNOWN = "unknown" + +_VERIFY_TIMEOUT = 15.0 + + +class KeyVerifier(ABC): + @abstractmethod + async def verify(self, api_key: str) -> KeyStatus: + pass + + +class DefaultKeyVerifier(KeyVerifier): + async def verify(self, api_key: str) -> KeyStatus: + return KeyStatus.UNKNOWN + + +class 
ChatGPTKeyVerifier(KeyVerifier): + _BASE_URL = "https://api.openai.com/v1" + + async def verify(self, api_key: str) -> KeyStatus: + proxy = os.getenv("OPENAI_PROXY") or os.getenv("HTTPS_PROXY") + try: + async with httpx.AsyncClient( + timeout=httpx.Timeout(connect=10.0, read=_VERIFY_TIMEOUT, write=10.0, pool=10.0), + proxy=proxy, + ) as client: + response = await client.post( + f"{self._BASE_URL}/chat/completions", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + }, + json={ + "model": "gpt-4o-mini", + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 5, + }, + ) + if response.status_code == 200: + return KeyStatus.ACTIVE + elif response.status_code == 401: + return KeyStatus.INVALID + elif response.status_code == 429: + return KeyStatus.RATE_LIMITED + elif response.status_code == 403: + return KeyStatus.EXPIRED + else: + logger.warning(f"[chatgpt] Unexpected status {response.status_code}") + return KeyStatus.UNKNOWN + except httpx.TimeoutException: + logger.warning("[chatgpt] Verification timeout") + return KeyStatus.UNKNOWN + except httpx.ConnectError as e: + logger.warning(f"[chatgpt] Connection error: {e}") + return KeyStatus.UNKNOWN + except Exception as e: + logger.error(f"[chatgpt] Verification error: {e}") + return KeyStatus.UNKNOWN + + +class DeepSeekKeyVerifier(KeyVerifier): + _BASE_URL = "https://api.deepseek.com/v1" + + async def verify(self, api_key: str) -> KeyStatus: + proxy = os.getenv("DEEPSEEK_PROXY") or os.getenv("HTTPS_PROXY") + try: + async with httpx.AsyncClient( + timeout=httpx.Timeout(connect=10.0, read=_VERIFY_TIMEOUT, write=10.0, pool=10.0), + proxy=proxy, + ) as client: + response = await client.post( + f"{self._BASE_URL}/chat/completions", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + }, + json={ + "model": "deepseek-chat", + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 5, + }, + ) + if response.status_code == 200: 
+ return KeyStatus.ACTIVE + elif response.status_code == 401: + return KeyStatus.INVALID + elif response.status_code == 429: + return KeyStatus.RATE_LIMITED + elif response.status_code == 403: + return KeyStatus.EXPIRED + else: + logger.warning(f"[deepseek] Unexpected status {response.status_code}") + return KeyStatus.UNKNOWN + except httpx.TimeoutException: + logger.warning("[deepseek] Verification timeout") + return KeyStatus.UNKNOWN + except httpx.ConnectError as e: + logger.warning(f"[deepseek] Connection error: {e}") + return KeyStatus.UNKNOWN + except Exception as e: + logger.error(f"[deepseek] Verification error: {e}") + return KeyStatus.UNKNOWN + + +class KimiKeyVerifier(KeyVerifier): + _BASE_URL = "https://api.moonshot.cn/v1" + + async def verify(self, api_key: str) -> KeyStatus: + proxy = os.getenv("HTTPS_PROXY") + try: + async with httpx.AsyncClient( + timeout=httpx.Timeout(connect=10.0, read=_VERIFY_TIMEOUT, write=10.0, pool=10.0), + proxy=proxy, + ) as client: + response = await client.post( + f"{self._BASE_URL}/chat/completions", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + }, + json={ + "model": "moonshot-v1-8k", + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 5, + }, + ) + if response.status_code == 200: + return KeyStatus.ACTIVE + elif response.status_code == 401: + return KeyStatus.INVALID + elif response.status_code == 429: + return KeyStatus.RATE_LIMITED + elif response.status_code == 403: + return KeyStatus.EXPIRED + else: + logger.warning(f"[kimi] Unexpected status {response.status_code}") + return KeyStatus.UNKNOWN + except httpx.TimeoutException: + logger.warning("[kimi] Verification timeout") + return KeyStatus.UNKNOWN + except httpx.ConnectError as e: + logger.warning(f"[kimi] Connection error: {e}") + return KeyStatus.UNKNOWN + except Exception as e: + logger.error(f"[kimi] Verification error: {e}") + return KeyStatus.UNKNOWN + + +class QwenKeyVerifier(KeyVerifier): + 
_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1" + + async def verify(self, api_key: str) -> KeyStatus: + proxy = os.getenv("HTTPS_PROXY") + try: + async with httpx.AsyncClient( + timeout=httpx.Timeout(connect=10.0, read=_VERIFY_TIMEOUT, write=10.0, pool=10.0), + proxy=proxy, + ) as client: + response = await client.post( + f"{self._BASE_URL}/chat/completions", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + }, + json={ + "model": "qwen-plus", + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 5, + }, + ) + if response.status_code == 200: + return KeyStatus.ACTIVE + elif response.status_code == 401: + return KeyStatus.INVALID + elif response.status_code == 429: + return KeyStatus.RATE_LIMITED + elif response.status_code == 403: + return KeyStatus.EXPIRED + else: + logger.warning(f"[qwen] Unexpected status {response.status_code}") + return KeyStatus.UNKNOWN + except httpx.TimeoutException: + logger.warning("[qwen] Verification timeout") + return KeyStatus.UNKNOWN + except httpx.ConnectError as e: + logger.warning(f"[qwen] Connection error: {e}") + return KeyStatus.UNKNOWN + except Exception as e: + logger.error(f"[qwen] Verification error: {e}") + return KeyStatus.UNKNOWN + + +class PerplexityKeyVerifier(KeyVerifier): + _BASE_URL = "https://api.perplexity.ai" + + async def verify(self, api_key: str) -> KeyStatus: + proxy = os.getenv("HTTPS_PROXY") + try: + async with httpx.AsyncClient( + timeout=httpx.Timeout(connect=10.0, read=_VERIFY_TIMEOUT, write=10.0, pool=10.0), + proxy=proxy, + ) as client: + response = await client.post( + f"{self._BASE_URL}/chat/completions", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + }, + json={ + "model": "llama-3.1-sonar-small-128k-online", + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 5, + }, + ) + if response.status_code == 200: + return KeyStatus.ACTIVE + elif response.status_code == 401: + 
return KeyStatus.INVALID + elif response.status_code == 429: + return KeyStatus.RATE_LIMITED + elif response.status_code == 403: + return KeyStatus.EXPIRED + else: + logger.warning(f"[perplexity] Unexpected status {response.status_code}") + return KeyStatus.UNKNOWN + except httpx.TimeoutException: + logger.warning("[perplexity] Verification timeout") + return KeyStatus.UNKNOWN + except httpx.ConnectError as e: + logger.warning(f"[perplexity] Connection error: {e}") + return KeyStatus.UNKNOWN + except Exception as e: + logger.error(f"[perplexity] Verification error: {e}") + return KeyStatus.UNKNOWN + + +class GeminiKeyVerifier(KeyVerifier): + _BASE_URL = "https://generativelanguage.googleapis.com/v1beta" + + async def verify(self, api_key: str) -> KeyStatus: + proxy = os.getenv("HTTPS_PROXY") + try: + async with httpx.AsyncClient( + timeout=httpx.Timeout(connect=10.0, read=_VERIFY_TIMEOUT, write=10.0, pool=10.0), + proxy=proxy, + ) as client: + response = await client.get( + f"{self._BASE_URL}/models", + headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}, + params={"key": api_key}, + ) + if response.status_code == 200: + return KeyStatus.ACTIVE + elif response.status_code == 403: + return KeyStatus.INVALID + elif response.status_code == 429: + return KeyStatus.RATE_LIMITED + else: + logger.warning(f"[gemini] Unexpected status {response.status_code}") + return KeyStatus.UNKNOWN + except httpx.TimeoutException: + logger.warning("[gemini] Verification timeout") + return KeyStatus.UNKNOWN + except httpx.ConnectError as e: + logger.warning(f"[gemini] Connection error: {e}") + return KeyStatus.UNKNOWN + except Exception as e: + logger.error(f"[gemini] Verification error: {e}") + return KeyStatus.UNKNOWN + + +class WenxinKeyVerifier(KeyVerifier): + async def verify(self, api_key: str) -> KeyStatus: + return KeyStatus.UNKNOWN + + +class DoubaoKeyVerifier(KeyVerifier): + async def verify(self, api_key: str) -> KeyStatus: + return 
KeyStatus.UNKNOWN + + +class YuanbaoKeyVerifier(KeyVerifier): + async def verify(self, api_key: str) -> KeyStatus: + return KeyStatus.UNKNOWN + + +class KeyVerifierFactory: + _VERIFIERS = { + "chatgpt": ChatGPTKeyVerifier, + "perplexity": PerplexityKeyVerifier, + "kimi": KimiKeyVerifier, + "wenxin": WenxinKeyVerifier, + "doubao": DoubaoKeyVerifier, + "deepseek": DeepSeekKeyVerifier, + "qwen": QwenKeyVerifier, + "gemini": GeminiKeyVerifier, + "yuanbao": YuanbaoKeyVerifier, + } + + @classmethod + def get_verifier(cls, engine_type: str) -> KeyVerifier: + verifier_class = cls._VERIFIERS.get(engine_type) + if not verifier_class: + return DefaultKeyVerifier() + return verifier_class() + + @classmethod + async def verify(cls, engine_type: str, api_key: str) -> KeyStatus: + verifier = cls.get_verifier(engine_type) + return await verifier.verify(api_key) diff --git a/backend/app/services/smart_router.py b/backend/app/services/smart_router.py index 50ae722..13ec8f2 100644 --- a/backend/app/services/smart_router.py +++ b/backend/app/services/smart_router.py @@ -2,6 +2,10 @@ from __future__ import annotations from dataclasses import dataclass from enum import Enum +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from app.services.api_key_manager import APIKeyManager class CostTier(str, Enum): @@ -43,14 +47,37 @@ def _get_profile(engine: str) -> EngineCostProfile | None: class SmartRouter: - def __init__(self, available_engines: list[str] | None = None, user_engines: list[str] | None = None): + def __init__( + self, + available_engines: list[str] | None = None, + user_engines: list[str] | None = None, + key_manager: APIKeyManager | None = None, + ): self.available_engines = available_engines or list(ENGINE_COST_PROFILES.keys()) self._user_engine_set = set(user_engines or []) + self._key_manager = key_manager @property def user_engines(self) -> list[str]: return list(self._user_engine_set) + def set_key_manager(self, key_manager: APIKeyManager) -> None: + 
self._key_manager = key_manager + + def _filter_by_available_keys(self, engines: list[str]) -> list[str]: + if not self._key_manager: + return engines + available = [] + for engine in engines: + key = self._key_manager.get_any_available_key(engine) + if key: + available.append(engine) + return available + + def get_available_engines(self) -> list[str]: + all_engines = list(ENGINE_COST_PROFILES.keys()) + return self._filter_by_available_keys(all_engines) + def select_engines(self, max_engines: int = 5, prefer_domestic: bool = True) -> list[str]: tiers: dict[CostTier, list[str]] = {tier: [] for tier in CostTier} for engine in self.available_engines: @@ -74,9 +101,12 @@ class SmartRouter: for tier in _TIER_ORDER: for engine in tiers[tier]: if len(selected) >= max_engines: - return selected - if engine not in selected: + break + key = self._key_manager.get_any_available_key(engine) if self._key_manager else True + if key and engine not in selected: selected.append(engine) + if len(selected) >= max_engines: + break return selected[:max_engines] @@ -94,7 +124,11 @@ class SmartRouter: return {"total_cost": round(total_cost, 6), "per_engine": details} def _engines_by_tier(self, tier: CostTier) -> list[str]: - return [e for e in self.available_engines if (p := _get_profile(e)) is not None and p.cost_tier == tier] + tier_engines = [e for e in self.available_engines if (p := _get_profile(e)) is not None and p.cost_tier == tier] + return self._filter_by_available_keys(tier_engines) + + def get_engines_by_cost_tier(self, tier: CostTier) -> list[str]: + return self._engines_by_tier(tier) def _engines_by_requires_key(self) -> list[str]: return [e for e in self.available_engines if (p := _get_profile(e)) is not None and p.requires_own_key] diff --git a/backend/app/services/usage_recorder.py b/backend/app/services/usage_recorder.py new file mode 100644 index 0000000..72f4b3e --- /dev/null +++ b/backend/app/services/usage_recorder.py @@ -0,0 +1,38 @@ +from app.services.usage_tracker 
import UsageTracker +from app.services.smart_router import ENGINE_COST_PROFILES + + +class UsageRecorder: + def __init__(self, tracker: UsageTracker): + self.tracker = tracker + + def calculate_cost(self, engine_type: str, input_tokens: int, output_tokens: int) -> float: + profile = ENGINE_COST_PROFILES.get(engine_type) + if not profile: + return 0.0 + + input_cost = (input_tokens / 1_000_000) * profile.input_price_per_million + output_cost = (output_tokens / 1_000_000) * profile.output_price_per_million + return round(input_cost + output_cost, 6) + + def record( + self, + user_id: str, + brand_id: str | None, + engine_type: str, + query: str, + input_tokens: int, + output_tokens: int, + metadata: dict | None = None, + ) -> None: + cost = self.calculate_cost(engine_type, input_tokens, output_tokens) + self.tracker.record( + user_id=user_id, + brand_id=brand_id or "", + engine_type=engine_type, + query=query, + input_tokens=input_tokens, + output_tokens=output_tokens, + cost=cost, + metadata=metadata or {}, + ) diff --git a/backend/app/services/usage_tracker.py b/backend/app/services/usage_tracker.py index f85ac42..e2dd97b 100644 --- a/backend/app/services/usage_tracker.py +++ b/backend/app/services/usage_tracker.py @@ -1,9 +1,15 @@ from __future__ import annotations +import uuid from dataclasses import dataclass, field from datetime import UTC, datetime, timedelta from typing import Any +from sqlalchemy.ext.asyncio import AsyncSession + +from app.repositories.usage_repository import UsageRepository +from app.models.usage_record import UsageRecord as UsageRecordModel + _QUOTA_WARNING_PCT = 80.0 _QUOTA_EXCEEDED_PCT = 100.0 @@ -48,6 +54,18 @@ def _aggregate_by_engine(records: list[UsageRecord]) -> dict[str, dict[str, Any] return result +def _aggregate_by_day(records: list[UsageRecord]) -> dict[str, dict[str, Any]]: + result: dict[str, dict[str, Any]] = {} + for r in records: + day_key = r.timestamp.strftime("%Y-%m-%d") + bucket = result.setdefault(day_key, {"queries": 
0, "input_tokens": 0, "output_tokens": 0, "cost": 0.0}) + bucket["queries"] += 1 + bucket["input_tokens"] += r.input_tokens + bucket["output_tokens"] += r.output_tokens + bucket["cost"] += r.cost + return result + + def _compute_cutoff(period: str, now: datetime) -> datetime: days = _PERIOD_CUTOFF_DAYS.get(period, 30) if days == 0: @@ -64,8 +82,10 @@ def _quota_status(usage_pct: float) -> str: class UsageTracker: - def __init__(self) -> None: + def __init__(self, session: AsyncSession | None = None) -> None: self._records: list[UsageRecord] = [] + self._session = session + self._repository = UsageRepository(session) if session else None def record( self, @@ -119,7 +139,7 @@ class UsageTracker: total_output_tokens=sum(r.output_tokens for r in filtered), total_cost=round(sum(r.cost for r in filtered), 4), by_engine=_aggregate_by_engine(filtered), - by_day={}, + by_day=_aggregate_by_day(filtered), ) def check_quota(self, user_id: str, monthly_limit: float = 100.0) -> dict: @@ -131,3 +151,51 @@ class UsageTracker: "usage_percentage": round(usage_pct, 1), "status": _quota_status(usage_pct), } + + async def record_async( + self, + user_id: str | uuid.UUID, + brand_id: str | uuid.UUID | None, + engine_type: str, + query: str, + input_tokens: int, + output_tokens: int, + cost: float, + extra_data: dict | None = None, + ) -> UsageRecordModel: + if not self._repository: + raise RuntimeError("UsageTracker not initialized with AsyncSession") + + user_uuid = uuid.UUID(user_id) if isinstance(user_id, str) else user_id + brand_uuid = uuid.UUID(brand_id) if (brand_id and isinstance(brand_id, str)) else brand_id + + data = { + "user_id": user_uuid, + "brand_id": brand_uuid, + "engine_type": engine_type, + "query": query, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "cost": cost, + "extra_data": extra_data or {}, + } + return await self._repository.create(data) + + async def get_summary_async( + self, + user_id: str | uuid.UUID, + period: str = "month", + 
brand_id: str | uuid.UUID | None = None, + ) -> dict: + if not self._repository: + raise RuntimeError("UsageTracker not initialized with AsyncSession") + return await self._repository.get_summary(user_id, period, brand_id) + + async def check_quota_async( + self, + user_id: str | uuid.UUID, + monthly_limit: float = 100.0, + ) -> dict: + if not self._repository: + raise RuntimeError("UsageTracker not initialized with AsyncSession") + return await self._repository.check_quota(user_id, monthly_limit) diff --git a/backend/app/services/user_quota_service.py b/backend/app/services/user_quota_service.py new file mode 100644 index 0000000..bef89c8 --- /dev/null +++ b/backend/app/services/user_quota_service.py @@ -0,0 +1,37 @@ +"""User quota service for plan-based monthly limits.""" +from __future__ import annotations + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.repositories.usage_repository import UsageRepository + +PLAN_MONTHLY_LIMITS = { + "free": 10.0, + "basic": 50.0, + "pro": 200.0, + "enterprise": 1000.0, +} + + +class UserQuotaService: + """Service for checking user quota based on subscription plan.""" + + def __init__(self, session: AsyncSession | None = None): + self._session = session + self._repository = UsageRepository(session) if session else None + + def get_monthly_limit(self, user_plan: str) -> float: + """Get monthly cost limit based on user plan.""" + return PLAN_MONTHLY_LIMITS.get(user_plan, PLAN_MONTHLY_LIMITS["free"]) + + async def check_quota_with_plan( + self, + user_id: str, + user_plan: str, + ) -> dict: + """Check quota using the limit from user plan.""" + if not self._repository: + raise RuntimeError("UserQuotaService not initialized with AsyncSession") + + monthly_limit = self.get_monthly_limit(user_plan) + return await self._repository.check_quota(user_id, monthly_limit=monthly_limit) diff --git a/backend/tests/test_content_pipeline.py b/backend/tests/test_content_pipeline.py deleted file mode 100644 index 57ae14f..0000000 --- 
a/backend/tests/test_content_pipeline.py +++ /dev/null @@ -1,89 +0,0 @@ -# test_content_pipeline.py -import pytest - -# 导入实际的 ContentPipeline 实现 -from app.services.content.content_pipeline import ContentPipeline - -@pytest.mark.asyncio -async def test_pipeline_complete_run(): - """完整Pipeline执行""" - pipeline = ContentPipeline() - request = { - "content": "这是一篇测试文章内容", - "title": "测试标题", - "platform": "zhihu", - "optimize_for": ["validation", "sensitive", "seo"] - } - result = await pipeline.run(request) - - assert result.stages is not None - assert len(result.stages) > 0 - assert result.outputs is not None - -@pytest.mark.asyncio -async def test_pipeline_with_validation_fail(): - """校验失败中断""" - pipeline = ContentPipeline() - request = { - "content": "内容", - "title": "这个标题太长了超过了三十个字符的限制了哈哈哈啊", - "platform": "wechat", - "optimize_for": ["validation"] - } - result = await pipeline.run(request) - - # 校验失败时不应继续执行后续阶段 - validation_stage = next((s for s in result.stages if s.name == "validation"), None) - assert validation_stage is not None - assert validation_stage.passed == False - -@pytest.mark.asyncio -async def test_pipeline_multi_platform(): - """多平台适配""" - pipeline = ContentPipeline() - - zhihu_result = await pipeline.run({ - "content": "

测试内容

外部链接", - "title": "测试标题", - "platform": "zhihu" - }) - - wechat_result = await pipeline.run({ - "content": "

测试内容

外部链接", - "title": "测试标题", - "platform": "wechat" - }) - - # 不同平台应产生不同的优化结果 - assert zhihu_result.outputs != wechat_result.outputs - -@pytest.mark.asyncio -async def test_pipeline_stage_results(): - """各阶段结果记录""" - pipeline = ContentPipeline() - result = await pipeline.run({ - "content": "内容", - "title": "标题", - "platform": "zhihu" - }) - - # 检查每个阶段的结果 - for stage in result.stages: - assert stage.name is not None - assert hasattr(stage, 'passed') or hasattr(stage, 'result') - -@pytest.mark.asyncio -async def test_pipeline_error_handling(): - """错误处理""" - pipeline = ContentPipeline() - - # 无效平台应返回错误 - try: - result = await pipeline.run({ - "content": "内容", - "title": "标题", - "platform": "invalid_platform" - }) - assert result.error is not None - except ValueError as e: - assert "不支持的平台" in str(e) diff --git a/backend/tests/test_models/test_api_key.py b/backend/tests/test_models/test_api_key.py new file mode 100644 index 0000000..06d1209 --- /dev/null +++ b/backend/tests/test_models/test_api_key.py @@ -0,0 +1,306 @@ +"""Tests for APIKey model.""" +import uuid +from datetime import datetime, timezone + +import pytest +from sqlalchemy import select, and_ + +from app.models.api_key import APIKey + + +class TestAPIKeyModel: + """Test cases for APIKey model.""" + + @pytest.mark.asyncio + async def test_api_key_create(self, async_session, test_user): + """Test creating a new API key.""" + api_key = APIKey( + id=uuid.uuid4(), + user_id=test_user.id, + engine_type="chatgpt", + encrypted_key="encrypted_test_key", + key_hint="sk-...abc", + key_source="user", + status="active", + priority=0, + ) + async_session.add(api_key) + await async_session.commit() + await async_session.refresh(api_key) + + assert api_key.id is not None + assert api_key.user_id == test_user.id + assert api_key.engine_type == "chatgpt" + assert api_key.encrypted_key == "encrypted_test_key" + assert api_key.key_hint == "sk-...abc" + assert api_key.key_source == "user" + assert api_key.status == "active" + 
assert api_key.priority == 0 + assert api_key.created_at is not None + assert api_key.updated_at is not None + + @pytest.mark.asyncio + async def test_api_key_default_values(self, async_session, test_user): + """Test API key default values.""" + api_key = APIKey( + user_id=test_user.id, + engine_type="kimi", + encrypted_key="encrypted_kimi_key", + key_hint="sk-...xyz", + ) + async_session.add(api_key) + await async_session.commit() + await async_session.refresh(api_key) + + assert api_key.key_source == "user" + assert api_key.status == "active" + assert api_key.priority == 0 + assert api_key.last_verified_at is None + + @pytest.mark.asyncio + async def test_api_key_fields(self, async_session, test_user): + """Test API key field validation and constraints.""" + now = datetime.now(timezone.utc) + api_key = APIKey( + user_id=test_user.id, + engine_type="deepseek", + encrypted_key="encrypted_deepseek_key_data", + key_hint="sk-...def", + key_source="system", + status="active", + priority=10, + last_verified_at=now, + ) + async_session.add(api_key) + await async_session.commit() + await async_session.refresh(api_key) + + assert api_key.engine_type == "deepseek" + assert api_key.encrypted_key == "encrypted_deepseek_key_data" + assert api_key.key_hint == "sk-...def" + assert api_key.key_source == "system" + assert api_key.status == "active" + assert api_key.priority == 10 + assert api_key.last_verified_at is not None + assert api_key.last_verified_at.replace(tzinfo=None) == now.replace(tzinfo=None) + + @pytest.mark.asyncio + async def test_api_key_query_by_id(self, async_session, test_user): + """Test querying API key by ID.""" + key_id = uuid.uuid4() + api_key = APIKey( + id=key_id, + user_id=test_user.id, + engine_type="gemini", + encrypted_key="encrypted_gemini_key", + key_hint="AIza...123", + ) + async_session.add(api_key) + await async_session.commit() + + result = await async_session.execute( + select(APIKey).where(APIKey.id == key_id) + ) + fetched_key = 
result.scalar_one() + + assert fetched_key is not None + assert fetched_key.id == key_id + assert fetched_key.engine_type == "gemini" + + @pytest.mark.asyncio + async def test_api_key_query_by_user_id(self, async_session, test_user): + """Test querying API keys by user ID.""" + key1 = APIKey( + user_id=test_user.id, + engine_type="chatgpt", + encrypted_key="encrypted_key_1", + key_hint="sk-...1", + ) + key2 = APIKey( + user_id=test_user.id, + engine_type="kimi", + encrypted_key="encrypted_key_2", + key_hint="sk-...2", + ) + async_session.add(key1) + async_session.add(key2) + await async_session.commit() + + result = await async_session.execute( + select(APIKey).where(APIKey.user_id == test_user.id) + ) + keys = result.scalars().all() + + assert len(keys) == 2 + + @pytest.mark.asyncio + async def test_api_key_query_by_user_and_engine(self, async_session, test_user): + """Test querying API keys by user ID and engine type.""" + key1 = APIKey( + user_id=test_user.id, + engine_type="chatgpt", + encrypted_key="encrypted_chatgpt_key", + key_hint="sk-...chat", + ) + key2 = APIKey( + user_id=test_user.id, + engine_type="kimi", + encrypted_key="encrypted_kimi_key", + key_hint="sk-...kimi", + ) + async_session.add(key1) + async_session.add(key2) + await async_session.commit() + + result = await async_session.execute( + select(APIKey).where( + and_( + APIKey.user_id == test_user.id, + APIKey.engine_type == "chatgpt" + ) + ) + ) + keys = result.scalars().all() + + assert len(keys) == 1 + assert keys[0].engine_type == "chatgpt" + + @pytest.mark.asyncio + async def test_api_key_timestamps(self, async_session, test_user): + """Test API key created_at and updated_at timestamps.""" + api_key = APIKey( + user_id=test_user.id, + engine_type="qwen", + encrypted_key="encrypted_qwen_key", + key_hint="sk-...qwen", + ) + async_session.add(api_key) + await async_session.commit() + await async_session.refresh(api_key) + + assert api_key.created_at is not None + assert api_key.updated_at is 
not None + assert isinstance(api_key.created_at, datetime) + assert isinstance(api_key.updated_at, datetime) + + @pytest.mark.asyncio + async def test_api_key_update(self, async_session, test_user): + """Test updating API key fields.""" + api_key = APIKey( + user_id=test_user.id, + engine_type="wenxin", + encrypted_key="encrypted_wenxin_key", + key_hint="sk-...wenxin", + status="active", + priority=0, + ) + async_session.add(api_key) + await async_session.commit() + + api_key.status = "invalid" + api_key.priority = 5 + api_key.last_verified_at = datetime.now(timezone.utc) + await async_session.commit() + await async_session.refresh(api_key) + + assert api_key.status == "invalid" + assert api_key.priority == 5 + assert api_key.last_verified_at is not None + + @pytest.mark.asyncio + async def test_api_key_delete(self, async_session, test_user): + """Test deleting an API key.""" + api_key = APIKey( + user_id=test_user.id, + engine_type="doubao", + encrypted_key="encrypted_doubao_key", + key_hint="sk-...doubao", + ) + async_session.add(api_key) + await async_session.commit() + key_id = api_key.id + + await async_session.delete(api_key) + await async_session.commit() + + result = await async_session.execute( + select(APIKey).where(APIKey.id == key_id) + ) + deleted_key = result.scalar_one_or_none() + + assert deleted_key is None + + @pytest.mark.asyncio + async def test_api_key_status_values(self, async_session, test_user): + """Test different status values for API key.""" + statuses = ["active", "invalid", "expired", "rate_limited", "unknown"] + created_keys = [] + + for i, status in enumerate(statuses): + api_key = APIKey( + user_id=test_user.id, + engine_type=f"engine_{i}", + encrypted_key=f"encrypted_key_{i}", + key_hint=f"sk-...{i}", + status=status, + ) + async_session.add(api_key) + created_keys.append(api_key) + + await async_session.commit() + + for key, status in zip(created_keys, statuses): + assert key.status == status + + @pytest.mark.asyncio + async def 
test_api_key_priority_ordering(self, async_session, test_user): + """Test API keys with different priorities.""" + keys_data = [ + {"engine_type": "engine_a", "priority": 0, "key_hint": "a..."}, + {"engine_type": "engine_b", "priority": 10, "key_hint": "b..."}, + {"engine_type": "engine_c", "priority": 5, "key_hint": "c..."}, + ] + + for data in keys_data: + api_key = APIKey( + user_id=test_user.id, + engine_type=data["engine_type"], + encrypted_key=f"encrypted_{data['engine_type']}", + key_hint=data["key_hint"], + priority=data["priority"], + ) + async_session.add(api_key) + + await async_session.commit() + + result = await async_session.execute( + select(APIKey) + .where(APIKey.user_id == test_user.id) + .order_by(APIKey.priority.desc()) + ) + keys = result.scalars().all() + + assert keys[0].priority == 10 + assert keys[1].priority == 5 + assert keys[2].priority == 0 + + @pytest.mark.asyncio + async def test_api_key_user_id_index(self, async_session, test_user): + """Test that user_id field has an index.""" + for i in range(5): + api_key = APIKey( + user_id=test_user.id, + engine_type=f"engine_{i}", + encrypted_key=f"encrypted_key_{i}", + key_hint=f"hint_{i}", + ) + async_session.add(api_key) + + await async_session.commit() + + result = await async_session.execute( + select(APIKey).where(APIKey.user_id == test_user.id) + ) + keys = result.scalars().all() + + assert len(keys) == 5 diff --git a/backend/tests/test_models/test_usage_record.py b/backend/tests/test_models/test_usage_record.py new file mode 100644 index 0000000..930df50 --- /dev/null +++ b/backend/tests/test_models/test_usage_record.py @@ -0,0 +1,406 @@ +"""Tests for UsageRecord model.""" +import uuid +from datetime import datetime, timezone, timedelta + +import pytest +from sqlalchemy import select, func, and_ +from sqlalchemy.dialects.postgresql import insert as pg_insert + +from app.models.usage_record import UsageRecord + + +class TestUsageRecordModel: + """Test cases for UsageRecord model.""" + + 
@pytest.mark.asyncio + async def test_usage_record_create(self, async_session, test_user): + """Test creating a new usage record.""" + record = UsageRecord( + user_id=test_user.id, + engine_type="chatgpt", + query="What is SEO optimization?", + input_tokens=100, + output_tokens=200, + cost=0.015, + extra_data={"model": "gpt-4"}, + ) + async_session.add(record) + await async_session.commit() + await async_session.refresh(record) + + assert record.id is not None + assert record.user_id == test_user.id + assert record.engine_type == "chatgpt" + assert record.query == "What is SEO optimization?" + assert record.input_tokens == 100 + assert record.output_tokens == 200 + assert record.cost == 0.015 + assert record.extra_data == {"model": "gpt-4"} + assert record.timestamp is not None + assert record.created_at is not None + + @pytest.mark.asyncio + async def test_usage_record_default_values(self, async_session, test_user): + """Test usage record default values.""" + record = UsageRecord( + user_id=test_user.id, + engine_type="kimi", + query="Test query", + ) + async_session.add(record) + await async_session.commit() + await async_session.refresh(record) + + assert record.input_tokens == 0 + assert record.output_tokens == 0 + assert record.cost == 0.0 + assert record.extra_data == {} + assert record.brand_id is None + + @pytest.mark.asyncio + async def test_usage_record_query_by_user_id(self, async_session, test_user): + """Test querying usage records by user ID.""" + user_id = test_user.id + for i in range(3): + record = UsageRecord( + user_id=user_id, + engine_type="deepseek", + query=f"Test query {i}", + cost=float(i), + ) + async_session.add(record) + await async_session.commit() + + result = await async_session.execute( + select(UsageRecord).where(UsageRecord.user_id == user_id) + ) + records = result.scalars().all() + + assert len(records) == 3 + + @pytest.mark.asyncio + async def test_usage_record_query_by_user_and_engine(self, async_session, test_user): + """Test 
querying usage records by user ID and engine type.""" + user_id = test_user.id + record1 = UsageRecord( + user_id=user_id, + engine_type="chatgpt", + query="Test chatgpt", + cost=1.0, + ) + record2 = UsageRecord( + user_id=user_id, + engine_type="kimi", + query="Test kimi", + cost=2.0, + ) + async_session.add(record1) + async_session.add(record2) + await async_session.commit() + + result = await async_session.execute( + select(UsageRecord).where( + and_( + UsageRecord.user_id == user_id, + UsageRecord.engine_type == "chatgpt" + ) + ) + ) + records = result.scalars().all() + + assert len(records) == 1 + assert records[0].engine_type == "chatgpt" + + @pytest.mark.asyncio + async def test_usage_record_query_by_time_range(self, async_session, test_user): + """Test querying usage records by time range.""" + user_id = test_user.id + now = datetime.now(timezone.utc) + + old_record = UsageRecord( + user_id=user_id, + engine_type="gemini", + query="Old query", + cost=1.0, + timestamp=now - timedelta(days=10), + ) + new_record = UsageRecord( + user_id=user_id, + engine_type="gemini", + query="New query", + cost=2.0, + timestamp=now, + ) + async_session.add(old_record) + async_session.add(new_record) + await async_session.commit() + + cutoff = now - timedelta(days=7) + result = await async_session.execute( + select(UsageRecord).where( + and_( + UsageRecord.user_id == user_id, + UsageRecord.timestamp >= cutoff + ) + ) + ) + records = result.scalars().all() + + assert len(records) == 1 + assert records[0].query == "New query" + + @pytest.mark.asyncio + async def test_usage_record_aggregate_by_user(self, async_session, test_user): + """Test aggregating usage records by user.""" + user_id = test_user.id + records_data = [ + {"engine": "chatgpt", "input_tokens": 100, "output_tokens": 200, "cost": 0.01}, + {"engine": "kimi", "input_tokens": 150, "output_tokens": 300, "cost": 0.02}, + {"engine": "chatgpt", "input_tokens": 200, "output_tokens": 400, "cost": 0.03}, + ] + for data in 
records_data: + record = UsageRecord( + user_id=user_id, + engine_type=data["engine"], + query=f"Query for {data['engine']}", + input_tokens=data["input_tokens"], + output_tokens=data["output_tokens"], + cost=data["cost"], + ) + async_session.add(record) + await async_session.commit() + + result = await async_session.execute( + select( + UsageRecord.engine_type, + func.count(UsageRecord.id).label("count"), + func.sum(UsageRecord.input_tokens).label("total_input"), + func.sum(UsageRecord.output_tokens).label("total_output"), + func.sum(UsageRecord.cost).label("total_cost"), + ).where(UsageRecord.user_id == user_id).group_by(UsageRecord.engine_type) + ) + aggregates = result.all() + + assert len(aggregates) == 2 + + chatgpt_agg = next(a for a in aggregates if a.engine_type == "chatgpt") + assert chatgpt_agg.count == 2 + assert chatgpt_agg.total_input == 300 + assert chatgpt_agg.total_output == 600 + assert chatgpt_agg.total_cost == 0.04 + + @pytest.mark.asyncio + async def test_usage_record_aggregate_by_day(self, async_session, test_user): + """Test aggregating usage records by day.""" + user_id = test_user.id + now = datetime.now(timezone.utc) + today_start = now.replace(hour=0, minute=0, second=0, microsecond=0) + + record1 = UsageRecord( + user_id=user_id, + engine_type="qwen", + query="Query today 1", + cost=1.0, + timestamp=today_start + timedelta(hours=10), + ) + record2 = UsageRecord( + user_id=user_id, + engine_type="qwen", + query="Query today 2", + cost=2.0, + timestamp=today_start + timedelta(hours=14), + ) + record3 = UsageRecord( + user_id=user_id, + engine_type="qwen", + query="Query yesterday", + cost=3.0, + timestamp=today_start - timedelta(days=1), + ) + async_session.add_all([record1, record2, record3]) + await async_session.commit() + + result = await async_session.execute( + select( + func.date(UsageRecord.timestamp).label("date"), + func.count(UsageRecord.id).label("count"), + func.sum(UsageRecord.cost).label("total_cost"), + ).where( + 
UsageRecord.user_id == user_id + ).group_by(func.date(UsageRecord.timestamp)) + ) + aggregates = result.all() + + assert len(aggregates) == 2 + + @pytest.mark.asyncio + async def test_usage_record_with_brand_id(self, async_session, test_user): + """Test usage record with brand association.""" + brand_id = uuid.uuid4() + record = UsageRecord( + user_id=test_user.id, + brand_id=brand_id, + engine_type="wenxin", + query="Brand query", + cost=5.0, + ) + async_session.add(record) + await async_session.commit() + await async_session.refresh(record) + + assert record.brand_id == brand_id + + @pytest.mark.asyncio + async def test_usage_record_index_user_engine(self, async_session, test_user): + """Test composite index on user_id and engine_type.""" + user_id = test_user.id + for i in range(5): + record = UsageRecord( + user_id=user_id, + engine_type="doubao", + query=f"Doubao query {i}", + cost=float(i), + ) + async_session.add(record) + await async_session.commit() + + result = await async_session.execute( + select(UsageRecord).where( + and_( + UsageRecord.user_id == user_id, + UsageRecord.engine_type == "doubao" + ) + ) + ) + records = result.scalars().all() + + assert len(records) == 5 + + @pytest.mark.asyncio + async def test_usage_record_update(self, async_session, test_user): + """Test updating usage record fields.""" + record = UsageRecord( + user_id=test_user.id, + engine_type="xinghuo", + query="Original query", + cost=1.0, + ) + async_session.add(record) + await async_session.commit() + + record.cost = 10.0 + record.extra_data = {"updated": True} + await async_session.commit() + await async_session.refresh(record) + + assert record.cost == 10.0 + assert record.extra_data == {"updated": True} + + @pytest.mark.asyncio + async def test_usage_record_delete(self, async_session, test_user): + """Test deleting a usage record.""" + record = UsageRecord( + user_id=test_user.id, + engine_type="yuanbao", + query="Delete me", + cost=1.0, + ) + async_session.add(record) + 
await async_session.commit() + record_id = record.id + + await async_session.delete(record) + await async_session.commit() + + result = await async_session.execute( + select(UsageRecord).where(UsageRecord.id == record_id) + ) + deleted_record = result.scalar_one_or_none() + + assert deleted_record is None + + @pytest.mark.asyncio + async def test_usage_record_timestamps(self, async_session, test_user): + """Test usage record timestamp fields.""" + record = UsageRecord( + user_id=test_user.id, + engine_type="perplexity", + query="Timestamp test", + cost=1.0, + ) + async_session.add(record) + await async_session.commit() + await async_session.refresh(record) + + assert record.timestamp is not None + assert record.created_at is not None + assert isinstance(record.timestamp, datetime) + assert isinstance(record.created_at, datetime) + + @pytest.mark.asyncio + async def test_usage_record_query_multiple_users(self, async_session, test_user): + """Test querying records for multiple users.""" + other_user_id = uuid.uuid4() + + user1_record = UsageRecord( + user_id=test_user.id, + engine_type="chatgpt", + query="User 1 query", + cost=1.0, + ) + user2_record = UsageRecord( + user_id=other_user_id, + engine_type="kimi", + query="User 2 query", + cost=2.0, + ) + async_session.add(user1_record) + async_session.add(user2_record) + await async_session.commit() + + result_user1 = await async_session.execute( + select(UsageRecord).where(UsageRecord.user_id == test_user.id) + ) + result_user2 = await async_session.execute( + select(UsageRecord).where(UsageRecord.user_id == other_user_id) + ) + + assert len(result_user1.scalars().all()) == 1 + assert len(result_user2.scalars().all()) == 1 + + @pytest.mark.asyncio + async def test_usage_record_empty_query_field(self, async_session, test_user): + """Test usage record with empty query field.""" + record = UsageRecord( + user_id=test_user.id, + engine_type="deepseek", + query="", + cost=0.0, + ) + async_session.add(record) + await 
async_session.commit() + await async_session.refresh(record) + + assert record.query == "" + + @pytest.mark.asyncio + async def test_usage_record_large_metadata(self, async_session, test_user): + """Test usage record with large metadata dict.""" + large_metadata = { + "model": "gpt-4-turbo", + "temperature": 0.7, + "max_tokens": 1000, + "system_prompt": "You are a helpful assistant." * 100, + "nested": {"key": {"deep": {"value": [1, 2, 3]}}}, + } + record = UsageRecord( + user_id=test_user.id, + engine_type="chatgpt", + query="Large metadata test", + extra_data=large_metadata, + ) + async_session.add(record) + await async_session.commit() + await async_session.refresh(record) + + assert record.extra_data == large_metadata diff --git a/backend/tests/test_repositories/test_usage_quota_integration.py b/backend/tests/test_repositories/test_usage_quota_integration.py new file mode 100644 index 0000000..3adb7ed --- /dev/null +++ b/backend/tests/test_repositories/test_usage_quota_integration.py @@ -0,0 +1,377 @@ +"""Tests for by_day aggregation and user quota service.""" +import uuid +from datetime import datetime, timezone, timedelta + +import pytest +import pytest_asyncio +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.pool import StaticPool + +from app.database import Base +from app.models.user import User +from app.models.usage_record import UsageRecord +from app.repositories.usage_repository import UsageRepository +from app.services.user_quota_service import UserQuotaService, PLAN_MONTHLY_LIMITS + + +@pytest_asyncio.fixture +async def async_engine(): + """Create async engine for testing with SQLite.""" + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def 
async_session(async_engine): + """Create async session for testing.""" + async_session_maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + async with async_session_maker() as session: + yield session + await session.rollback() + + +@pytest_asyncio.fixture +async def test_user_free(async_session): + """Create a free plan test user.""" + user = User( + id=uuid.uuid4(), + email="free@example.com", + password_hash="hashed_password", + name="Free User", + plan="free", + max_queries=5, + is_active=True, + email_verified=True, + ) + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def test_user_basic(async_session): + """Create a basic plan test user.""" + user = User( + id=uuid.uuid4(), + email="basic@example.com", + password_hash="hashed_password", + name="Basic User", + plan="basic", + max_queries=50, + is_active=True, + email_verified=True, + ) + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def test_user_pro(async_session): + """Create a pro plan test user.""" + user = User( + id=uuid.uuid4(), + email="pro@example.com", + password_hash="hashed_password", + name="Pro User", + plan="pro", + max_queries=500, + is_active=True, + email_verified=True, + ) + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +class TestByDayAggregation: + """Test cases for by_day aggregation in usage summary.""" + + @pytest.mark.asyncio + async def test_by_day_returns_data(self, async_session, test_user_free): + """Test that by_day aggregation returns non-empty data when records exist.""" + repo = UsageRepository(async_session) + + for i in range(3): + await repo.create({ + "user_id": test_user_free.id, + "engine_type": "chatgpt", + "query": f"Query {i}", + "cost": 
0.01, + "input_tokens": 100, + "output_tokens": 200, + }) + + summary = await repo.get_summary(str(test_user_free.id), period="month") + + assert "by_day" in summary + assert len(summary["by_day"]) > 0, "by_day should not be empty when records exist" + + @pytest.mark.asyncio + async def test_by_day_groups_by_date(self, async_session, test_user_free): + """Test that records are correctly grouped by date.""" + repo = UsageRepository(async_session) + + for i in range(5): + await repo.create({ + "user_id": test_user_free.id, + "engine_type": "deepseek", + "query": f"Query {i}", + "cost": 0.02, + }) + + summary = await repo.get_summary(str(test_user_free.id), period="month") + today = datetime.now(timezone.utc).strftime("%Y-%m-%d") + + assert today in summary["by_day"], f"Today's date {today} should be in by_day" + assert summary["by_day"][today]["queries"] == 5 + assert summary["by_day"][today]["cost"] == 0.10 + + @pytest.mark.asyncio + async def test_by_day_aggregates_tokens(self, async_session, test_user_free): + """Test that by_day correctly aggregates tokens.""" + repo = UsageRepository(async_session) + + await repo.create({ + "user_id": test_user_free.id, + "engine_type": "qwen", + "query": "Query 1", + "input_tokens": 100, + "output_tokens": 200, + "cost": 0.01, + }) + await repo.create({ + "user_id": test_user_free.id, + "engine_type": "qwen", + "query": "Query 2", + "input_tokens": 150, + "output_tokens": 300, + "cost": 0.02, + }) + + summary = await repo.get_summary(str(test_user_free.id), period="month") + today = datetime.now(timezone.utc).strftime("%Y-%m-%d") + + assert summary["by_day"][today]["input_tokens"] == 250 + assert summary["by_day"][today]["output_tokens"] == 500 + assert summary["by_day"][today]["cost"] == 0.03 + + @pytest.mark.asyncio + async def test_by_day_empty_when_no_records(self, async_session, test_user_free): + """Test that by_day is empty when no records exist.""" + repo = UsageRepository(async_session) + + summary = await 
repo.get_summary(str(test_user_free.id), period="month") + + assert summary["by_day"] == {} + + @pytest.mark.asyncio + async def test_by_day_multiple_days(self, async_session, test_user_free): + """Test by_day when records span multiple days.""" + repo = UsageRepository(async_session) + yesterday = datetime.now(timezone.utc) - timedelta(days=1) + + await repo.create({ + "user_id": test_user_free.id, + "engine_type": "gemini", + "query": "Yesterday query", + "cost": 0.05, + "timestamp": yesterday, + }) + await repo.create({ + "user_id": test_user_free.id, + "engine_type": "gemini", + "query": "Today query", + "cost": 0.05, + }) + + summary = await repo.get_summary(str(test_user_free.id), period="week") + today = datetime.now(timezone.utc).strftime("%Y-%m-%d") + yesterday_str = yesterday.strftime("%Y-%m-%d") + + assert today in summary["by_day"] + assert yesterday_str in summary["by_day"] + + +class TestUserQuotaService: + """Test cases for UserQuotaService with plan-based monthly limits.""" + + def test_plan_monthly_limits_defined(self): + """Test that all plan monthly limits are defined.""" + assert "free" in PLAN_MONTHLY_LIMITS + assert "basic" in PLAN_MONTHLY_LIMITS + assert "pro" in PLAN_MONTHLY_LIMITS + assert "enterprise" in PLAN_MONTHLY_LIMITS + + def test_free_plan_monthly_limit(self): + """Test that free plan has 10 yuan monthly limit.""" + assert PLAN_MONTHLY_LIMITS["free"] == 10.0 + + def test_basic_plan_monthly_limit(self): + """Test that basic plan has 50 yuan monthly limit.""" + assert PLAN_MONTHLY_LIMITS["basic"] == 50.0 + + def test_pro_plan_monthly_limit(self): + """Test that pro plan has 200 yuan monthly limit.""" + assert PLAN_MONTHLY_LIMITS["pro"] == 200.0 + + def test_enterprise_plan_monthly_limit(self): + """Test that enterprise plan has 1000 yuan monthly limit.""" + assert PLAN_MONTHLY_LIMITS["enterprise"] == 1000.0 + + def test_unknown_plan_defaults_to_free(self): + """Test that unknown plan defaults to free plan limit.""" + service = 
UserQuotaService() + limit = service.get_monthly_limit("unknown_plan") + assert limit == 10.0 + + def test_get_monthly_limit_free(self): + """Test get_monthly_limit for free plan.""" + service = UserQuotaService() + assert service.get_monthly_limit("free") == 10.0 + + def test_get_monthly_limit_basic(self): + """Test get_monthly_limit for basic plan.""" + service = UserQuotaService() + assert service.get_monthly_limit("basic") == 50.0 + + def test_get_monthly_limit_pro(self): + """Test get_monthly_limit for pro plan.""" + service = UserQuotaService() + assert service.get_monthly_limit("pro") == 200.0 + + def test_get_monthly_limit_enterprise(self): + """Test get_monthly_limit for enterprise plan.""" + service = UserQuotaService() + assert service.get_monthly_limit("enterprise") == 1000.0 + + @pytest.mark.asyncio + async def test_check_quota_with_free_plan(self, async_session, test_user_free): + """Test quota check uses free plan limit (10 yuan).""" + repo = UsageRepository(async_session) + + for i in range(5): + await repo.create({ + "user_id": test_user_free.id, + "engine_type": "chatgpt", + "query": f"Query {i}", + "cost": 1.0, + }) + + service = UserQuotaService(session=async_session) + result = await service.check_quota_with_plan( + user_id=str(test_user_free.id), + user_plan="free" + ) + + assert result["limit"] == 10.0 + assert result["used"] == 5.0 + assert result["usage_percentage"] == 50.0 + + @pytest.mark.asyncio + async def test_check_quota_with_basic_plan(self, async_session, test_user_basic): + """Test quota check uses basic plan limit (50 yuan).""" + repo = UsageRepository(async_session) + + for i in range(10): + await repo.create({ + "user_id": test_user_basic.id, + "engine_type": "deepseek", + "query": f"Query {i}", + "cost": 2.0, + }) + + service = UserQuotaService(session=async_session) + result = await service.check_quota_with_plan( + user_id=str(test_user_basic.id), + user_plan="basic" + ) + + assert result["limit"] == 50.0 + assert 
result["used"] == 20.0 + assert result["usage_percentage"] == 40.0 + + @pytest.mark.asyncio + async def test_check_quota_with_pro_plan(self, async_session, test_user_pro): + """Test quota check uses pro plan limit (200 yuan).""" + repo = UsageRepository(async_session) + + for i in range(10): + await repo.create({ + "user_id": test_user_pro.id, + "engine_type": "qwen", + "query": f"Query {i}", + "cost": 10.0, + }) + + service = UserQuotaService(session=async_session) + result = await service.check_quota_with_plan( + user_id=str(test_user_pro.id), + user_plan="pro" + ) + + assert result["limit"] == 200.0 + assert result["used"] == 100.0 + assert result["usage_percentage"] == 50.0 + + @pytest.mark.asyncio + async def test_check_quota_exceeded_status(self, async_session, test_user_free): + """Test that exceeded status works correctly with plan limits.""" + repo = UsageRepository(async_session) + + await repo.create({ + "user_id": test_user_free.id, + "engine_type": "gemini", + "query": "Expensive query", + "cost": 15.0, + }) + + service = UserQuotaService(session=async_session) + result = await service.check_quota_with_plan( + user_id=str(test_user_free.id), + user_plan="free" + ) + + assert result["limit"] == 10.0 + assert result["used"] == 15.0 + assert result["usage_percentage"] == 150.0 + assert result["status"] == "exceeded" + + @pytest.mark.asyncio + async def test_check_quota_warning_status(self, async_session, test_user_basic): + """Test that warning status works correctly with plan limits.""" + repo = UsageRepository(async_session) + + await repo.create({ + "user_id": test_user_basic.id, + "engine_type": "kimi", + "query": "Moderate query", + "cost": 45.0, + }) + + service = UserQuotaService(session=async_session) + result = await service.check_quota_with_plan( + user_id=str(test_user_basic.id), + user_plan="basic" + ) + + assert result["limit"] == 50.0 + assert result["used"] == 45.0 + assert result["usage_percentage"] == 90.0 + assert result["status"] == 
"warning" diff --git a/backend/tests/test_repositories/test_usage_repository.py b/backend/tests/test_repositories/test_usage_repository.py new file mode 100644 index 0000000..d20e271 --- /dev/null +++ b/backend/tests/test_repositories/test_usage_repository.py @@ -0,0 +1,371 @@ +"""Tests for UsageRepository.""" +import uuid +from datetime import datetime, timezone, timedelta + +import pytest +import pytest_asyncio +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.pool import StaticPool + +from app.database import Base +from app.models.user import User +from app.models.usage_record import UsageRecord +from app.repositories.usage_repository import UsageRepository + + +@pytest_asyncio.fixture +async def async_engine(): + """Create async engine for testing with SQLite.""" + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def async_session(async_engine): + """Create async session for testing.""" + async_session_maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + async with async_session_maker() as session: + yield session + await session.rollback() + + +@pytest_asyncio.fixture +async def test_user(async_session): + """Create a test user.""" + user = User( + id=uuid.uuid4(), + email="test@example.com", + password_hash="hashed_password", + name="Test User", + plan="free", + max_queries=5, + is_active=True, + email_verified=True, + ) + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +class TestUsageRepository: + """Test cases for UsageRepository.""" + + @pytest.mark.asyncio + async def test_create(self, async_session, 
test_user): + """Test creating a usage record.""" + repo = UsageRepository(async_session) + + data = { + "user_id": test_user.id, + "engine_type": "chatgpt", + "query": "Test query", + "input_tokens": 100, + "output_tokens": 200, + "cost": 0.015, + "extra_data": {"model": "gpt-4"}, + } + + record = await repo.create(data) + + assert record.id is not None + assert record.user_id == test_user.id + assert record.engine_type == "chatgpt" + assert record.query == "Test query" + assert record.input_tokens == 100 + assert record.output_tokens == 200 + assert record.cost == 0.015 + assert record.extra_data == {"model": "gpt-4"} + + @pytest.mark.asyncio + async def test_create_minimal(self, async_session, test_user): + """Test creating a usage record with minimal data.""" + repo = UsageRepository(async_session) + + data = { + "user_id": test_user.id, + "engine_type": "deepseek", + "query": "Minimal query", + } + + record = await repo.create(data) + + assert record.id is not None + assert record.user_id == test_user.id + assert record.engine_type == "deepseek" + assert record.query == "Minimal query" + assert record.input_tokens == 0 + assert record.output_tokens == 0 + assert record.cost == 0.0 + assert record.extra_data == {} + + @pytest.mark.asyncio + async def test_get_summary(self, async_session, test_user): + """Test getting usage summary.""" + repo = UsageRepository(async_session) + + for i in range(3): + await repo.create({ + "user_id": test_user.id, + "engine_type": "chatgpt", + "query": f"Query {i}", + "input_tokens": 100, + "output_tokens": 200, + "cost": 0.01, + }) + + summary = await repo.get_summary(str(test_user.id), period="month") + + assert summary["period"] == "month" + assert summary["total_queries"] == 3 + assert summary["total_input_tokens"] == 300 + assert summary["total_output_tokens"] == 600 + assert summary["total_cost"] == 0.03 + assert "chatgpt" in summary["by_engine"] + assert summary["by_engine"]["chatgpt"]["queries"] == 3 + + 
@pytest.mark.asyncio
async def test_get_summary_by_engine(self, async_session, test_user):
    """Summary groups usage per engine.

    Writes one record for each of two engines and checks that
    ``by_engine`` holds a separate bucket per engine with its own
    query count and cost.
    """
    repo = UsageRepository(async_session)

    await repo.create({
        "user_id": test_user.id,
        "engine_type": "chatgpt",
        "query": "ChatGPT query",
        "cost": 0.02,
    })
    await repo.create({
        "user_id": test_user.id,
        "engine_type": "deepseek",
        "query": "DeepSeek query",
        "cost": 0.01,
    })

    summary = await repo.get_summary(str(test_user.id), period="month")

    assert len(summary["by_engine"]) == 2
    assert summary["by_engine"]["chatgpt"]["queries"] == 1
    # Costs are floats: compare with a tolerance rather than exact equality.
    assert summary["by_engine"]["chatgpt"]["cost"] == pytest.approx(0.02)
    assert summary["by_engine"]["deepseek"]["queries"] == 1
    assert summary["by_engine"]["deepseek"]["cost"] == pytest.approx(0.01)

@pytest.mark.asyncio
async def test_get_summary_by_day(self, async_session, test_user):
    """Summary groups usage per calendar day.

    The records get their timestamp at creation time, so the expected
    ``by_day`` key is the UTC date observed around the writes. The date
    is sampled both before and after the writes so the test does not
    fail spuriously when the run straddles UTC midnight.
    """
    repo = UsageRepository(async_session)

    day_before = datetime.now(timezone.utc).strftime("%Y-%m-%d")

    for i in range(2):
        await repo.create({
            "user_id": test_user.id,
            "engine_type": "qwen",
            "query": f"Query {i}",
            "cost": 0.01,
        })

    summary = await repo.get_summary(str(test_user.id), period="month")
    day_after = datetime.now(timezone.utc).strftime("%Y-%m-%d")

    # Normally a single date; two distinct dates only across midnight.
    candidate_days = {day_before, day_after}
    by_day = summary["by_day"]

    assert len(by_day) >= 1
    assert any(day in by_day for day in candidate_days)
    assert sum(
        by_day[day]["queries"] for day in candidate_days if day in by_day
    ) == 2

@pytest.mark.asyncio
async def test_get_summary_with_brand_filter(self, async_session, test_user):
    """Passing ``brand_id`` restricts the summary to that brand's records.

    One record carries the brand id and one does not; only the former
    may be counted.
    """
    repo = UsageRepository(async_session)
    brand_id = uuid.uuid4()

    await repo.create({
        "user_id": test_user.id,
        "brand_id": brand_id,
        "engine_type": "kimi",
        "query": "Brand query",
        "cost": 0.05,
    })
    await repo.create({
        "user_id": test_user.id,
        "engine_type": "kimi",
        "query": "No brand query",
        "cost": 0.03,
    })

    summary = await repo.get_summary(
        str(test_user.id),
        period="month",
        brand_id=str(brand_id),
    )

    assert summary["total_queries"] == 1
    assert summary["total_cost"] == pytest.approx(0.05)

@pytest.mark.asyncio
async def test_check_quota(self, async_session, test_user):
    """Quota check reports usage, limit, percentage and an ``ok`` status.

    Five records of cost 1.0 against a limit of 100.0 is 5% usage.
    """
    repo = UsageRepository(async_session)

    for i in range(5):
        await repo.create({
            "user_id": test_user.id,
            "engine_type": "gemini",
            "query": f"Query {i}",
            "cost": 1.0,
        })

    result = await repo.check_quota(str(test_user.id), monthly_limit=100.0)

    assert result["used"] == pytest.approx(5.0)
    assert result["limit"] == pytest.approx(100.0)
    assert result["usage_percentage"] == pytest.approx(5.0)
    assert result["status"] == "ok"

@pytest.mark.asyncio
async def test_check_quota_warning(self, async_session, test_user):
    """Usage of 85% of the monthly limit is reported as ``warning``."""
    repo = UsageRepository(async_session)

    await repo.create({
        "user_id": test_user.id,
        "engine_type": "chatgpt",
        "query": "Expensive query",
        "cost": 85.0,
    })

    result = await repo.check_quota(str(test_user.id), monthly_limit=100.0)

    assert result["status"] == "warning"
    assert result["usage_percentage"] == pytest.approx(85.0)

@pytest.mark.asyncio
async def test_check_quota_exceeded(self, async_session, test_user):
    """Usage above the monthly limit is reported as ``exceeded``."""
    repo = UsageRepository(async_session)

    await repo.create({
        "user_id": test_user.id,
        "engine_type": "chatgpt",
        "query": "Very expensive query",
        "cost": 120.0,
    })

    result = await repo.check_quota(str(test_user.id), monthly_limit=100.0)

    assert result["status"] == "exceeded"
    assert result["usage_percentage"] == pytest.approx(120.0)

@pytest.mark.asyncio
async def test_get_by_id(self, async_session, test_user):
    """A record created through the repository can be fetched by its id."""
    repo = UsageRepository(async_session)

    created = await repo.create({
        "user_id": test_user.id,
        "engine_type": "wenxin",
        "query": "Get by ID test",
        "cost": 0.5,
    })

    fetched = await repo.get_by_id(created.id)

    assert fetched is not None
    assert fetched.id == created.id
    assert fetched.engine_type == "wenxin"

@pytest.mark.asyncio + async def test_get_by_user(self, async_session, test_user): + """Test getting usage records by user.""" + repo = UsageRepository(async_session) + + for i in range(3): + await repo.create({ + "user_id": test_user.id, + "engine_type": "doubao", + "query": f"User query {i}", + "cost": 0.1, + }) + + records = await repo.get_by_user(str(test_user.id)) + + assert len(records) == 3 + + @pytest.mark.asyncio + async def test_get_by_user_and_engine(self, async_session, test_user): + """Test getting usage records by user and engine.""" + repo = UsageRepository(async_session) + + await repo.create({ + "user_id": test_user.id, + "engine_type": "xinghuo", + "query": "Xinghuo query", + "cost": 0.1, + }) + await repo.create({ + "user_id": test_user.id, + "engine_type": "perplexity", + "query": "Perplexity query", + "cost": 0.2, + }) + + records = await repo.get_by_user_and_engine( + str(test_user.id), + "xinghuo", + ) + + assert len(records) == 1 + assert records[0].engine_type == "xinghuo" + + @pytest.mark.asyncio + async def test_empty_summary(self, async_session, test_user): + """Test getting summary when no records exist.""" + repo = UsageRepository(async_session) + + summary = await repo.get_summary(str(test_user.id), period="month") + + assert summary["total_queries"] == 0 + assert summary["total_input_tokens"] == 0 + assert summary["total_output_tokens"] == 0 + assert summary["total_cost"] == 0.0 + assert summary["by_engine"] == {} + assert summary["by_day"] == {} + + @pytest.mark.asyncio + async def test_empty_quota_check(self, async_session, test_user): + """Test checking quota when no records exist.""" + repo = UsageRepository(async_session) + + result = await repo.check_quota(str(test_user.id), monthly_limit=100.0) + + assert result["used"] == 0.0 + assert result["usage_percentage"] == 0.0 + assert result["status"] == "ok" + + @pytest.mark.asyncio + async def test_uuid_handling(self, async_session, test_user): + """Test that UUID handling works 
correctly.""" + repo = UsageRepository(async_session) + + data = { + "user_id": test_user.id, + "engine_type": "yuanbao", + "query": "UUID test", + "cost": 0.5, + } + + record = await repo.create(data) + summary = await repo.get_summary(test_user.id, period="month") + + assert summary["total_queries"] == 1 + assert record.user_id == test_user.id diff --git a/backend/tests/test_services/test_adapter_key_source.py b/backend/tests/test_services/test_adapter_key_source.py new file mode 100644 index 0000000..4eb26ec --- /dev/null +++ b/backend/tests/test_services/test_adapter_key_source.py @@ -0,0 +1,107 @@ +import os +from unittest.mock import MagicMock + +import pytest + +from app.services.ai_engine.base import AIEngineAdapter, EngineType +from app.services.api_key_manager import APIKeyManager, KeySource + + +class MockAdapter(AIEngineAdapter): + def __init__(self, api_key: str | None = None, **kwargs): + super().__init__(api_key=api_key, **kwargs) + + async def query(self, query: str, brand_name: str, competitor_names: list[str] | None = None): + pass + + def get_engine_type(self) -> EngineType: + return EngineType.CHATGPT + + def _get_env_key(self) -> str | None: + return os.getenv("OPENAI_API_KEY", "") + + +class TestAdapterKeySource: + def test_adapter_accepts_key_manager_parameter(self): + key_manager = MagicMock(spec=APIKeyManager) + adapter = MockAdapter(key_manager=key_manager, user_id="user123") + assert adapter._key_manager is key_manager + assert adapter._user_id == "user123" + + def test_adapter_accepts_api_key_parameter(self): + adapter = MockAdapter(api_key="direct-key-123") + assert adapter.api_key == "direct-key-123" + + def test_resolve_key_from_direct_api_key(self): + adapter = MockAdapter(api_key="direct-key-123") + assert adapter.api_key == "direct-key-123" + + def test_resolve_key_from_key_manager_user_key(self): + key_manager = MagicMock(spec=APIKeyManager) + key_manager.get_key.return_value = "manager-key-456" + + adapter = 
MockAdapter(key_manager=key_manager, user_id="user123") + + key_manager.get_key.assert_called_once_with("chatgpt", user_id="user123") + assert adapter.api_key == "manager-key-456" + + def test_resolve_key_fallback_to_env_when_no_manager(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "env-key-789") + + adapter = MockAdapter() + assert adapter.api_key == "env-key-789" + + def test_direct_api_key_priority_over_key_manager(self): + key_manager = MagicMock(spec=APIKeyManager) + key_manager.get_key.return_value = "manager-key-456" + + adapter = MockAdapter(api_key="direct-key-123", key_manager=key_manager, user_id="user123") + + key_manager.get_key.assert_not_called() + assert adapter.api_key == "direct-key-123" + + def test_user_key_priority_over_system_key(self): + key_manager = MagicMock(spec=APIKeyManager) + key_manager.get_key.return_value = "user-key-from-manager" + + adapter = MockAdapter(key_manager=key_manager, user_id="user123") + assert adapter.api_key == "user-key-from-manager" + + def test_no_key_available_returns_empty(self): + key_manager = MagicMock(spec=APIKeyManager) + key_manager.get_key.return_value = None + + adapter = MockAdapter(key_manager=key_manager, user_id="user123") + assert adapter.api_key == "" + + def test_adapter_with_all_parameters(self): + key_manager = MagicMock(spec=APIKeyManager) + key_manager.get_key.return_value = "full-key" + + adapter = MockAdapter( + api_key=None, + rate_limiter=MagicMock(), + proxy="http://proxy:8080", + key_manager=key_manager, + user_id="test-user" + ) + assert adapter._key_manager is key_manager + assert adapter._user_id == "test-user" + assert adapter.proxy == "http://proxy:8080" + + +class TestBatchQueryServiceKeyIntegration: + def test_build_adapters_with_key_manager(self): + from app.services.ai_engine.batch_query import BatchQueryService + + key_manager = MagicMock(spec=APIKeyManager) + key_manager.get_key.return_value = "batch-test-key" + + adapters = { + "chatgpt": 
MockAdapter(key_manager=key_manager, user_id="batch-user") + } + service = BatchQueryService(adapters) + service.set_user_context(user_id="batch-user", brand_id="brand-123") + + assert service._user_id == "batch-user" + assert service._brand_id == "brand-123" diff --git a/backend/tests/test_services/test_adapter_registration.py b/backend/tests/test_services/test_adapter_registration.py new file mode 100644 index 0000000..0b66449 --- /dev/null +++ b/backend/tests/test_services/test_adapter_registration.py @@ -0,0 +1,115 @@ +import pytest + +from app.services.ai_engine.base import EngineType + + +class TestAllAdaptersRegistered: + def test_all_nine_adapters_in_adapter_classes(self): + from app.services.ai_engine.batch_query import _ADAPTER_CLASSES + + registered_types = set(_ADAPTER_CLASSES.keys()) + expected_types = { + EngineType.CHATGPT, + EngineType.PERPLEXITY, + EngineType.KIMI, + EngineType.WENXIN, + EngineType.DOUBAO, + EngineType.YUANBAO, + EngineType.DEEPSEEK, + EngineType.QWEN, + EngineType.GEMINI, + } + + assert registered_types == expected_types, ( + f"Missing adapters: {expected_types - registered_types}. 
" + f"Extra adapters: {registered_types - expected_types}" + ) + + def test_adapter_classes_has_nine_entries(self): + from app.services.ai_engine.batch_query import _ADAPTER_CLASSES + + assert len(_ADAPTER_CLASSES) == 9, ( + f"Expected 9 adapters, got {len(_ADAPTER_CLASSES)}" + ) + + def test_deepseek_adapter_registered(self): + from app.services.ai_engine.batch_query import _ADAPTER_CLASSES + from app.services.ai_engine.deepseek import DeepSeekAdapter + + assert EngineType.DEEPSEEK in _ADAPTER_CLASSES + assert _ADAPTER_CLASSES[EngineType.DEEPSEEK] == DeepSeekAdapter + + def test_qwen_adapter_registered(self): + from app.services.ai_engine.batch_query import _ADAPTER_CLASSES + from app.services.ai_engine.qwen import QwenAdapter + + assert EngineType.QWEN in _ADAPTER_CLASSES + assert _ADAPTER_CLASSES[EngineType.QWEN] == QwenAdapter + + def test_gemini_adapter_registered(self): + from app.services.ai_engine.batch_query import _ADAPTER_CLASSES + from app.services.ai_engine.gemini import GeminiAdapter + + assert EngineType.GEMINI in _ADAPTER_CLASSES + assert _ADAPTER_CLASSES[EngineType.GEMINI] == GeminiAdapter + + def test_chatgpt_adapter_registered(self): + from app.services.ai_engine.batch_query import _ADAPTER_CLASSES + from app.services.ai_engine.chatgpt import ChatGPTAdapter + + assert EngineType.CHATGPT in _ADAPTER_CLASSES + assert _ADAPTER_CLASSES[EngineType.CHATGPT] == ChatGPTAdapter + + def test_perplexity_adapter_registered(self): + from app.services.ai_engine.batch_query import _ADAPTER_CLASSES + from app.services.ai_engine.perplexity import PerplexityAdapter + + assert EngineType.PERPLEXITY in _ADAPTER_CLASSES + assert _ADAPTER_CLASSES[EngineType.PERPLEXITY] == PerplexityAdapter + + def test_kimi_adapter_registered(self): + from app.services.ai_engine.batch_query import _ADAPTER_CLASSES + from app.services.ai_engine.kimi import KimiAdapter + + assert EngineType.KIMI in _ADAPTER_CLASSES + assert _ADAPTER_CLASSES[EngineType.KIMI] == KimiAdapter + + def 
test_wenxin_adapter_registered(self): + from app.services.ai_engine.batch_query import _ADAPTER_CLASSES + from app.services.ai_engine.wenxin import WenxinAdapter + + assert EngineType.WENXIN in _ADAPTER_CLASSES + assert _ADAPTER_CLASSES[EngineType.WENXIN] == WenxinAdapter + + def test_doubao_adapter_registered(self): + from app.services.ai_engine.batch_query import _ADAPTER_CLASSES + from app.services.ai_engine.doubao import DoubaoAdapter + + assert EngineType.DOUBAO in _ADAPTER_CLASSES + assert _ADAPTER_CLASSES[EngineType.DOUBAO] == DoubaoAdapter + + def test_yuanbao_adapter_registered(self): + from app.services.ai_engine.batch_query import _ADAPTER_CLASSES + from app.services.ai_engine.yuanbao import YuanbaoAdapter + + assert EngineType.YUANBAO in _ADAPTER_CLASSES + assert _ADAPTER_CLASSES[EngineType.YUANBAO] == YuanbaoAdapter + + +class TestBatchServiceWithAllAdapters: + def test_batch_service_builds_all_adapters(self): + from app.services.ai_engine.batch_query import _build_adapters + + adapters = _build_adapters() + + expected_engines = [ + "chatgpt", "perplexity", "kimi", "wenxin", "doubao", + "yuanbao", "deepseek", "qwen", "gemini" + ] + + for engine in expected_engines: + assert engine in adapters, f"Missing adapter: {engine}" + + assert len(adapters) == 9, ( + f"Expected 9 adapters built, got {len(adapters)}" + ) diff --git a/backend/tests/test_services/test_ai_engine_query.py b/backend/tests/test_services/test_ai_engine_query.py index f2a93c8..0270943 100644 --- a/backend/tests/test_services/test_ai_engine_query.py +++ b/backend/tests/test_services/test_ai_engine_query.py @@ -128,6 +128,9 @@ class TestAIEngineAdapterBase: def get_engine_type(self): return EngineType.CHATGPT + def _get_env_key(self) -> str | None: + return None + adapter = ConcreteAdapter(api_key="test-key") has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations( "BrandX is the best insurance company", @@ -147,6 +150,9 @@ class TestAIEngineAdapterBase: def 
get_engine_type(self): return EngineType.CHATGPT + def _get_env_key(self) -> str | None: + return None + adapter = ConcreteAdapter(api_key="test-key") has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations( "CompetitorY is also a good choice for insurance", @@ -166,6 +172,9 @@ class TestAIEngineAdapterBase: def get_engine_type(self): return EngineType.CHATGPT + def _get_env_key(self) -> str | None: + return None + adapter = ConcreteAdapter(api_key="test-key") has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations( "BrandX and CompetitorY are both good insurance options", @@ -185,6 +194,9 @@ class TestAIEngineAdapterBase: def get_engine_type(self): return EngineType.CHATGPT + def _get_env_key(self) -> str | None: + return None + adapter = ConcreteAdapter(api_key="test-key") has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations( "Some random text without brand names", @@ -204,6 +216,9 @@ class TestAIEngineAdapterBase: def get_engine_type(self): return EngineType.CHATGPT + def _get_env_key(self) -> str | None: + return None + adapter = ConcreteAdapter(api_key="test-key") has_brand, _, _, _ = adapter._detect_citations( "brandx is great", diff --git a/backend/tests/test_services/test_api_key_manager.py b/backend/tests/test_services/test_api_key_manager.py index 20c221f..0bef018 100644 --- a/backend/tests/test_services/test_api_key_manager.py +++ b/backend/tests/test_services/test_api_key_manager.py @@ -171,9 +171,14 @@ class TestKeyVerification: return APIKeyManager() @pytest.mark.asyncio - async def test_verify_valid_key(self, manager): - status = await manager.verify_key("chatgpt", "sk-valid-key-1234567890") - assert status == KeyStatus.ACTIVE + async def test_verify_calls_factory(self, manager): + from unittest.mock import patch, AsyncMock + + with patch('app.services.api_key_manager.KeyVerifierFactory.verify', new_callable=AsyncMock) as mock_verify: + mock_verify.return_value = KeyStatus.ACTIVE + status = await 
manager.verify_key("chatgpt", "sk-valid-key-1234567890") + assert status == KeyStatus.ACTIVE + mock_verify.assert_called_once_with("chatgpt", "sk-valid-key-1234567890") @pytest.mark.asyncio async def test_verify_empty_key(self, manager): diff --git a/backend/tests/test_services/test_batch_query_service.py b/backend/tests/test_services/test_batch_query_service.py index 63b7766..287f765 100644 --- a/backend/tests/test_services/test_batch_query_service.py +++ b/backend/tests/test_services/test_batch_query_service.py @@ -42,6 +42,9 @@ class _StubAdapter(AIEngineAdapter): def get_engine_type(self) -> EngineType: return self._engine_type + def _get_env_key(self) -> str | None: + return "" + class TestBatchQueryServiceInit: @pytest.mark.asyncio diff --git a/backend/tests/test_services/test_key_encryption.py b/backend/tests/test_services/test_key_encryption.py new file mode 100644 index 0000000..e7793da --- /dev/null +++ b/backend/tests/test_services/test_key_encryption.py @@ -0,0 +1,164 @@ +import pytest +from cryptography.fernet import Fernet + + +class TestKeyEncryptionBasic: + """基础加密解密测试""" + + def test_encrypt_decrypt_roundtrip(self): + from app.services.key_encryption import KeyEncryption + + encryption = KeyEncryption(encryption_key="test-key-for-encryption-12345") + plaintext = "sk-1234567890abcdef" + ciphertext = encryption.encrypt(plaintext) + decrypted = encryption.decrypt(ciphertext) + assert decrypted == plaintext + + def test_encrypted_content_differs_from_plaintext(self): + from app.services.key_encryption import KeyEncryption + + encryption = KeyEncryption(encryption_key="test-key-for-encryption-12345") + plaintext = "sk-1234567890abcdef" + ciphertext = encryption.encrypt(plaintext) + assert ciphertext != plaintext + assert ciphertext != plaintext.encode() + + def test_same_plaintext_produces_different_ciphertext(self): + from app.services.key_encryption import KeyEncryption + + encryption = KeyEncryption(encryption_key="test-key-for-encryption-12345") + 
plaintext = "sk-1234567890abcdef" + ciphertext1 = encryption.encrypt(plaintext) + ciphertext2 = encryption.encrypt(plaintext) + assert ciphertext1 != ciphertext2 + + +class TestKeyEncryptionEdgeCases: + """边界情况测试""" + + def test_encrypt_empty_string_raises_error(self): + from app.services.key_encryption import KeyEncryption + + encryption = KeyEncryption(encryption_key="test-key-for-encryption-12345") + with pytest.raises(ValueError, match="Cannot encrypt empty string"): + encryption.encrypt("") + + def test_decrypt_empty_string_raises_error(self): + from app.services.key_encryption import KeyEncryption + + encryption = KeyEncryption(encryption_key="test-key-for-encryption-12345") + with pytest.raises(ValueError, match="Cannot decrypt empty string"): + encryption.decrypt("") + + def test_encrypt_special_characters(self): + from app.services.key_encryption import KeyEncryption + + encryption = KeyEncryption(encryption_key="test-key-for-encryption-12345") + plaintext = "sk-中文测试!@#$%^&*()_+-=[]{}|;':\",./<>?" 
+ ciphertext = encryption.encrypt(plaintext) + decrypted = encryption.decrypt(ciphertext) + assert decrypted == plaintext + + def test_encrypt_unicode_characters(self): + from app.services.key_encryption import KeyEncryption + + encryption = KeyEncryption(encryption_key="test-key-for-encryption-12345") + plaintext = "sk-日本語テスト한국어" + ciphertext = encryption.encrypt(plaintext) + decrypted = encryption.decrypt(ciphertext) + assert decrypted == plaintext + + +class TestKeyEncryptionInvalidInputs: + """无效输入测试""" + + def test_decrypt_invalid_ciphertext_raises_error(self): + from app.services.key_encryption import KeyEncryption + + encryption = KeyEncryption(encryption_key="test-key-for-encryption-12345") + with pytest.raises(ValueError, match="Invalid ciphertext or key"): + encryption.decrypt("invalid-ciphertext-that-cannot-be-decrypted") + + def test_decrypt_with_wrong_key_raises_error(self): + from app.services.key_encryption import KeyEncryption + + encryption1 = KeyEncryption(encryption_key="key-one-for-encryption-1234") + encryption2 = KeyEncryption(encryption_key="key-two-for-encryption-1234") + + plaintext = "sk-1234567890abcdef" + ciphertext = encryption1.encrypt(plaintext) + + with pytest.raises(ValueError, match="Invalid ciphertext or key"): + encryption2.decrypt(ciphertext) + + +class TestKeyEncryptionKeyFormats: + """密钥格式测试""" + + def test_valid_fernet_key_format(self): + from app.services.key_encryption import KeyEncryption + + valid_key = Fernet.generate_key().decode() + encryption = KeyEncryption(encryption_key=valid_key) + plaintext = "sk-1234567890abcdef" + ciphertext = encryption.encrypt(plaintext) + decrypted = encryption.decrypt(ciphertext) + assert decrypted == plaintext + + def test_password_based_key_derivation(self): + from app.services.key_encryption import KeyEncryption + + password = "my-secret-password-12345" + encryption = KeyEncryption(encryption_key=password) + plaintext = "sk-1234567890abcdef" + ciphertext = encryption.encrypt(plaintext) + 
decrypted = encryption.decrypt(ciphertext) + assert decrypted == plaintext + + +class TestKeyEncryptionSingleton: + """全局实例测试""" + + def test_get_key_encryption_returns_instance(self): + from app.services.key_encryption import get_key_encryption + + instance = get_key_encryption() + assert instance is not None + assert hasattr(instance, "encrypt") + assert hasattr(instance, "decrypt") + + def test_reset_key_encryption_clears_singleton(self): + from app.services.key_encryption import get_key_encryption, reset_key_encryption + + instance1 = get_key_encryption() + reset_key_encryption() + instance2 = get_key_encryption() + assert instance1 is not instance2 + + +class TestAPIKeyManagerEncryption: + """APIKeyManager集成测试""" + + def test_api_key_manager_uses_fernet_encryption(self): + from app.services.api_key_manager import APIKeyManager + + manager = APIKeyManager() + original_key = "sk-1234567890abcdef" + config = manager.add_key("chatgpt", original_key) + + assert config.encrypted_key != original_key + assert not config.encrypted_key.startswith("sk-") + + decrypted = manager.get_key("chatgpt") + assert decrypted == original_key + + def test_encrypted_key_is_not_base64_of_plaintext(self): + from app.services.api_key_manager import APIKeyManager + import base64 + + manager = APIKeyManager() + original_key = "sk-1234567890abcdef" + config = manager.add_key("chatgpt", original_key) + + plain_base64 = base64.b64encode(original_key.encode()).decode() + assert config.encrypted_key != plain_base64 diff --git a/backend/tests/test_services/test_key_verifier.py b/backend/tests/test_services/test_key_verifier.py new file mode 100644 index 0000000..10e1c33 --- /dev/null +++ b/backend/tests/test_services/test_key_verifier.py @@ -0,0 +1,354 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +import httpx + +from app.services.key_verifier import KeyStatus +from app.services.key_verifier import ( + KeyVerifier, + KeyVerifierFactory, + ChatGPTKeyVerifier, + 
DeepSeekKeyVerifier, + KimiKeyVerifier, + QwenKeyVerifier, + PerplexityKeyVerifier, + GeminiKeyVerifier, + WenxinKeyVerifier, + DoubaoKeyVerifier, + YuanbaoKeyVerifier, + DefaultKeyVerifier, +) + + +def create_mock_client(mock_response: MagicMock) -> AsyncMock: + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.get = AsyncMock(return_value=mock_response) + return mock_client + + +class TestKeyVerifierFactory: + def test_get_chatgpt_verifier(self): + verifier = KeyVerifierFactory.get_verifier("chatgpt") + assert isinstance(verifier, ChatGPTKeyVerifier) + + def test_get_deepseek_verifier(self): + verifier = KeyVerifierFactory.get_verifier("deepseek") + assert isinstance(verifier, DeepSeekKeyVerifier) + + def test_get_kimi_verifier(self): + verifier = KeyVerifierFactory.get_verifier("kimi") + assert isinstance(verifier, KimiKeyVerifier) + + def test_get_qwen_verifier(self): + verifier = KeyVerifierFactory.get_verifier("qwen") + assert isinstance(verifier, QwenKeyVerifier) + + def test_get_perplexity_verifier(self): + verifier = KeyVerifierFactory.get_verifier("perplexity") + assert isinstance(verifier, PerplexityKeyVerifier) + + def test_get_gemini_verifier(self): + verifier = KeyVerifierFactory.get_verifier("gemini") + assert isinstance(verifier, GeminiKeyVerifier) + + def test_get_unknown_verifier_returns_default(self): + verifier = KeyVerifierFactory.get_verifier("unknown_engine") + assert isinstance(verifier, DefaultKeyVerifier) + + @pytest.mark.asyncio + async def test_verify_delegates_to_correct_verifier(self): + with patch.object(ChatGPTKeyVerifier, 'verify', new_callable=AsyncMock) as mock_verify: + mock_verify.return_value = KeyStatus.ACTIVE + status = await KeyVerifierFactory.verify("chatgpt", "sk-test-key-12345") + mock_verify.assert_called_once_with("sk-test-key-12345") + assert status == 
KeyStatus.ACTIVE + + +class TestDefaultKeyVerifier: + @pytest.mark.asyncio + async def test_default_verifier_returns_unknown(self): + verifier = DefaultKeyVerifier() + status = await verifier.verify("any-key") + assert status == KeyStatus.UNKNOWN + + +class TestChatGPTKeyVerifier: + @pytest.mark.asyncio + async def test_verify_active_key_returns_active(self): + verifier = ChatGPTKeyVerifier() + mock_response = MagicMock() + mock_response.status_code = 200 + + mock_client = create_mock_client(mock_response) + + with patch('httpx.AsyncClient', return_value=mock_client): + status = await verifier.verify("sk-test-key-12345") + assert status == KeyStatus.ACTIVE + mock_client.post.assert_called_once() + + @pytest.mark.asyncio + async def test_verify_invalid_key_returns_invalid(self): + verifier = ChatGPTKeyVerifier() + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.text = "Invalid API key" + + mock_client = create_mock_client(mock_response) + + with patch('httpx.AsyncClient', return_value=mock_client): + status = await verifier.verify("sk-invalid-key") + assert status == KeyStatus.INVALID + + @pytest.mark.asyncio + async def test_verify_rate_limited_key_returns_rate_limited(self): + verifier = ChatGPTKeyVerifier() + mock_response = MagicMock() + mock_response.status_code = 429 + mock_response.text = "Rate limit exceeded" + + mock_client = create_mock_client(mock_response) + + with patch('httpx.AsyncClient', return_value=mock_client): + status = await verifier.verify("sk-rate-limited-key") + assert status == KeyStatus.RATE_LIMITED + + @pytest.mark.asyncio + async def test_verify_network_error_returns_unknown(self): + verifier = ChatGPTKeyVerifier() + + mock_client = create_mock_client(MagicMock()) + mock_client.post = AsyncMock(side_effect=httpx.ConnectError("Connection failed")) + + with patch('httpx.AsyncClient', return_value=mock_client): + status = await verifier.verify("sk-test-key") + assert status == KeyStatus.UNKNOWN + + 
@pytest.mark.asyncio + async def test_verify_timeout_returns_unknown(self): + verifier = ChatGPTKeyVerifier() + + mock_client = create_mock_client(MagicMock()) + mock_client.post = AsyncMock(side_effect=httpx.TimeoutException("Request timeout")) + + with patch('httpx.AsyncClient', return_value=mock_client): + status = await verifier.verify("sk-test-key") + assert status == KeyStatus.UNKNOWN + + +class TestDeepSeekKeyVerifier: + @pytest.mark.asyncio + async def test_verify_active_key_returns_active(self): + verifier = DeepSeekKeyVerifier() + mock_response = MagicMock() + mock_response.status_code = 200 + + mock_client = create_mock_client(mock_response) + + with patch('httpx.AsyncClient', return_value=mock_client): + status = await verifier.verify("sk-test-key-12345") + assert status == KeyStatus.ACTIVE + mock_client.post.assert_called_once() + + @pytest.mark.asyncio + async def test_verify_invalid_key_returns_invalid(self): + verifier = DeepSeekKeyVerifier() + mock_response = MagicMock() + mock_response.status_code = 401 + + mock_client = create_mock_client(mock_response) + + with patch('httpx.AsyncClient', return_value=mock_client): + status = await verifier.verify("sk-invalid-key") + assert status == KeyStatus.INVALID + + @pytest.mark.asyncio + async def test_verify_rate_limited_key_returns_rate_limited(self): + verifier = DeepSeekKeyVerifier() + mock_response = MagicMock() + mock_response.status_code = 429 + + mock_client = create_mock_client(mock_response) + + with patch('httpx.AsyncClient', return_value=mock_client): + status = await verifier.verify("sk-rate-limited-key") + assert status == KeyStatus.RATE_LIMITED + + +class TestKimiKeyVerifier: + @pytest.mark.asyncio + async def test_verify_active_key_returns_active(self): + verifier = KimiKeyVerifier() + mock_response = MagicMock() + mock_response.status_code = 200 + + mock_client = create_mock_client(mock_response) + + with patch('httpx.AsyncClient', return_value=mock_client): + status = await 
verifier.verify("sk-test-key-12345") + assert status == KeyStatus.ACTIVE + mock_client.post.assert_called_once() + + @pytest.mark.asyncio + async def test_verify_invalid_key_returns_invalid(self): + verifier = KimiKeyVerifier() + mock_response = MagicMock() + mock_response.status_code = 401 + + mock_client = create_mock_client(mock_response) + + with patch('httpx.AsyncClient', return_value=mock_client): + status = await verifier.verify("sk-invalid-key") + assert status == KeyStatus.INVALID + + +class TestQwenKeyVerifier: + @pytest.mark.asyncio + async def test_verify_active_key_returns_active(self): + verifier = QwenKeyVerifier() + mock_response = MagicMock() + mock_response.status_code = 200 + + mock_client = create_mock_client(mock_response) + + with patch('httpx.AsyncClient', return_value=mock_client): + status = await verifier.verify("sk-test-key-12345") + assert status == KeyStatus.ACTIVE + + @pytest.mark.asyncio + async def test_verify_invalid_key_returns_invalid(self): + verifier = QwenKeyVerifier() + mock_response = MagicMock() + mock_response.status_code = 401 + + mock_client = create_mock_client(mock_response) + + with patch('httpx.AsyncClient', return_value=mock_client): + status = await verifier.verify("sk-invalid-key") + assert status == KeyStatus.INVALID + + @pytest.mark.asyncio + async def test_verify_forbidden_key_returns_expired(self): + verifier = QwenKeyVerifier() + mock_response = MagicMock() + mock_response.status_code = 403 + + mock_client = create_mock_client(mock_response) + + with patch('httpx.AsyncClient', return_value=mock_client): + status = await verifier.verify("sk-expired-key") + assert status == KeyStatus.EXPIRED + + +class TestPerplexityKeyVerifier: + @pytest.mark.asyncio + async def test_verify_active_key_returns_active(self): + verifier = PerplexityKeyVerifier() + mock_response = MagicMock() + mock_response.status_code = 200 + + mock_client = create_mock_client(mock_response) + + with patch('httpx.AsyncClient', 
return_value=mock_client): + status = await verifier.verify("sk-test-key-12345") + assert status == KeyStatus.ACTIVE + + @pytest.mark.asyncio + async def test_verify_invalid_key_returns_invalid(self): + verifier = PerplexityKeyVerifier() + mock_response = MagicMock() + mock_response.status_code = 401 + + mock_client = create_mock_client(mock_response) + + with patch('httpx.AsyncClient', return_value=mock_client): + status = await verifier.verify("sk-invalid-key") + assert status == KeyStatus.INVALID + + +class TestGeminiKeyVerifier: + @pytest.mark.asyncio + async def test_verify_active_key_returns_active(self): + verifier = GeminiKeyVerifier() + mock_response = MagicMock() + mock_response.status_code = 200 + + mock_client = create_mock_client(mock_response) + + with patch('httpx.AsyncClient', return_value=mock_client): + status = await verifier.verify("sk-test-key-12345") + assert status == KeyStatus.ACTIVE + + @pytest.mark.asyncio + async def test_verify_invalid_key_returns_invalid(self): + verifier = GeminiKeyVerifier() + mock_response = MagicMock() + mock_response.status_code = 403 + + mock_client = create_mock_client(mock_response) + + with patch('httpx.AsyncClient', return_value=mock_client): + status = await verifier.verify("sk-invalid-key") + assert status == KeyStatus.INVALID + + +class TestOtherEngineVerifiers: + @pytest.mark.asyncio + async def test_wenxin_verifier_returns_unknown(self): + verifier = WenxinKeyVerifier() + status = await verifier.verify("test-key") + assert status == KeyStatus.UNKNOWN + + @pytest.mark.asyncio + async def test_doubao_verifier_returns_unknown(self): + verifier = DoubaoKeyVerifier() + status = await verifier.verify("test-key") + assert status == KeyStatus.UNKNOWN + + @pytest.mark.asyncio + async def test_yuanbao_verifier_returns_unknown(self): + verifier = YuanbaoKeyVerifier() + status = await verifier.verify("test-key") + assert status == KeyStatus.UNKNOWN + + +class TestKeyVerifierIntegrationWithAPIManager: + 
@pytest.mark.asyncio + async def test_api_key_manager_verify_uses_factory(self): + from app.services.api_key_manager import APIKeyManager + + manager = APIKeyManager() + + with patch.object(KeyVerifierFactory, 'verify', new_callable=AsyncMock) as mock_verify: + mock_verify.return_value = KeyStatus.ACTIVE + status = await manager.verify_key("chatgpt", "sk-test-key-1234567890") + assert status == KeyStatus.ACTIVE + mock_verify.assert_called_once_with("chatgpt", "sk-test-key-1234567890") + + @pytest.mark.asyncio + async def test_api_key_manager_verify_handles_factory_error(self): + from app.services.api_key_manager import APIKeyManager + + manager = APIKeyManager() + + with patch.object(KeyVerifierFactory, 'verify', new_callable=AsyncMock) as mock_verify: + mock_verify.side_effect = Exception("Network error") + status = await manager.verify_key("chatgpt", "sk-test-key-1234567890") + assert status == KeyStatus.UNKNOWN + + @pytest.mark.asyncio + async def test_api_key_manager_verify_short_key_returns_invalid(self): + from app.services.api_key_manager import APIKeyManager + + manager = APIKeyManager() + status = await manager.verify_key("chatgpt", "short") + assert status == KeyStatus.INVALID + + @pytest.mark.asyncio + async def test_api_key_manager_verify_empty_key_returns_invalid(self): + from app.services.api_key_manager import APIKeyManager + + manager = APIKeyManager() + status = await manager.verify_key("chatgpt", "") + assert status == KeyStatus.INVALID diff --git a/backend/tests/test_services/test_proxy_and_deepseek.py b/backend/tests/test_services/test_proxy_and_deepseek.py index 8ffa056..b0ba043 100644 --- a/backend/tests/test_services/test_proxy_and_deepseek.py +++ b/backend/tests/test_services/test_proxy_and_deepseek.py @@ -17,6 +17,9 @@ class _ConcreteAdapter(AIEngineAdapter): def get_engine_type(self): return EngineType.CHATGPT + def _get_env_key(self) -> str | None: + return None + class TestProxySupportBase: def test_base_init_with_proxy(self): diff --git 
a/backend/tests/test_services/test_smart_router_key_integration.py b/backend/tests/test_services/test_smart_router_key_integration.py new file mode 100644 index 0000000..621abd5 --- /dev/null +++ b/backend/tests/test_services/test_smart_router_key_integration.py @@ -0,0 +1,165 @@ +import pytest + +from app.services.api_key_manager import APIKeyManager, KeySource +from app.services.smart_router import ENGINE_COST_PROFILES, CostTier, SmartRouter +from app.services.engine_selector import EngineSelector + + +class TestSmartRouterWithKeyManager: + """SmartRouter与APIKeyManager集成测试""" + + @pytest.fixture + def key_manager(self): + km = APIKeyManager() + km.add_key("deepseek", "test-deepseek-key", source=KeySource.SYSTEM, priority=1) + km.add_key("qwen", "test-qwen-key", source=KeySource.SYSTEM, priority=1) + km.add_key("gemini", "test-gemini-key", source=KeySource.ENV, priority=0) + return km + + @pytest.fixture + def key_manager_with_user_key(self): + km = APIKeyManager() + km.add_key("deepseek", "system-deepseek-key", source=KeySource.SYSTEM, priority=0) + km.add_key("deepseek", "user-deepseek-key", source=KeySource.USER, priority=10, user_id="user_123") + km.add_key("qwen", "system-qwen-key", source=KeySource.SYSTEM, priority=1) + return km + + @pytest.fixture + def key_manager_no_keys(self): + return APIKeyManager() + + def test_router_accepts_key_manager(self, key_manager): + router = SmartRouter(key_manager=key_manager) + assert router._key_manager is key_manager + + def test_set_key_manager_after_init(self): + router = SmartRouter() + assert router._key_manager is None + key_manager = APIKeyManager() + router.set_key_manager(key_manager) + assert router._key_manager is key_manager + + def test_filter_by_available_keys(self, key_manager): + router = SmartRouter(key_manager=key_manager) + all_engines = list(ENGINE_COST_PROFILES.keys()) + filtered = router._filter_by_available_keys(all_engines) + assert "deepseek" in filtered + assert "qwen" in filtered + assert 
"gemini" in filtered + assert "chatgpt" not in filtered + assert "perplexity" not in filtered + assert "wenxin" not in filtered + + def test_filter_by_available_keys_no_manager(self): + router = SmartRouter() + all_engines = list(ENGINE_COST_PROFILES.keys()) + filtered = router._filter_by_available_keys(all_engines) + assert set(filtered) == set(all_engines) + + def test_select_engines_filters_no_key(self, key_manager): + router = SmartRouter(key_manager=key_manager) + selected = router.select_engines(max_engines=10) + for engine in selected: + key = key_manager.get_any_available_key(engine) + assert key is not None, f"Selected engine {engine} has no available key" + + def test_select_engines_returns_subset(self, key_manager): + router = SmartRouter(key_manager=key_manager) + selected = router.select_engines(max_engines=2) + assert len(selected) <= 2 + assert len(selected) > 0 + + def test_select_engines_user_key_priority(self, key_manager_with_user_key): + router = SmartRouter(key_manager=key_manager_with_user_key) + selected = router.select_engines(max_engines=10) + assert "deepseek" in selected + key = key_manager_with_user_key.get_any_available_key("deepseek") + assert key == "user-deepseek-key" + + def test_get_available_engines(self, key_manager): + router = SmartRouter(key_manager=key_manager) + available = router.get_available_engines() + for engine in available: + key = key_manager.get_any_available_key(engine) + assert key is not None + + def test_get_engines_by_cost_tier_with_filter(self, key_manager): + router = SmartRouter(key_manager=key_manager) + free_engines = router.get_engines_by_cost_tier(CostTier.FREE) + for engine in free_engines: + profile = ENGINE_COST_PROFILES[engine] + assert profile.cost_tier == CostTier.FREE + key = key_manager.get_any_available_key(engine) + assert key is not None + + def test_all_engines_no_keys_returns_empty(self, key_manager_no_keys): + router = SmartRouter(key_manager=key_manager_no_keys) + selected = 
router.select_engines(max_engines=5) + assert selected == [] or all( + ENGINE_COST_PROFILES[e].cost_tier in [CostTier.FREE, CostTier.LOW_COST] + and not ENGINE_COST_PROFILES[e].requires_own_key + for e in selected + ) + + +class TestEngineSelector: + """EngineSelector智能引擎选择器测试""" + + @pytest.fixture + def key_manager(self): + km = APIKeyManager() + km.add_key("deepseek", "test-deepseek-key", source=KeySource.SYSTEM) + km.add_key("qwen", "test-qwen-key", source=KeySource.SYSTEM) + km.add_key("gemini", "test-gemini-key", source=KeySource.ENV) + km.add_key("wenxin", "test-wenxin-key", source=KeySource.SYSTEM) + return km + + @pytest.fixture + def selector(self, key_manager): + return EngineSelector(key_manager=key_manager) + + def test_initialization(self, selector, key_manager): + assert selector.key_manager is key_manager + assert selector.router is not None + assert selector.router._key_manager is key_manager + + def test_select_engines_returns_valid_engines(self, selector, key_manager): + selected = selector.select_engines(max_engines=5) + for engine in selected: + key = key_manager.get_any_available_key(engine) + assert key is not None, f"Engine {engine} selected but has no key" + + def test_select_engines_respects_max_engines(self, selector): + selected = selector.select_engines(max_engines=2) + assert len(selected) <= 2 + + def test_select_engines_with_min_cost_tier(self, selector): + selected = selector.select_engines(max_engines=10, min_cost_tier="free") + for engine in selected: + profile = ENGINE_COST_PROFILES[engine] + assert profile.cost_tier in [CostTier.FREE, CostTier.LOW_COST, CostTier.MID_COST, CostTier.HIGH_COST] + tier_order = ["free", "low_cost", "mid_cost", "high_cost"] + assert tier_order.index(profile.cost_tier.value) >= tier_order.index("free") + + def test_get_best_value_engine(self, selector, key_manager): + best = selector.get_best_value_engine() + if best: + profile = ENGINE_COST_PROFILES[best] + key = 
key_manager.get_any_available_key(best) + assert key is not None + assert profile.cost_tier in [CostTier.FREE, CostTier.LOW_COST] + + def test_get_best_value_returns_none_when_no_keys(self): + km = APIKeyManager() + selector = EngineSelector(key_manager=km) + best = selector.get_best_value_engine() + assert best is None + + def test_select_engines_priority_domestic(self, selector): + selected = selector.select_engines(max_engines=10, prefer_domestic=True) + domestic = [e for e in selected if ENGINE_COST_PROFILES[e].domestic] + international = [e for e in selected if not ENGINE_COST_PROFILES[e].domestic] + if domestic and international: + last_domestic_idx = max(selected.index(e) for e in domestic) + first_intl_idx = min(selected.index(e) for e in international) + assert last_domestic_idx < first_intl_idx diff --git a/backend/tests/test_services/test_usage_recording.py b/backend/tests/test_services/test_usage_recording.py new file mode 100644 index 0000000..8019406 --- /dev/null +++ b/backend/tests/test_services/test_usage_recording.py @@ -0,0 +1,320 @@ +import pytest +from datetime import UTC, datetime +from unittest.mock import MagicMock, AsyncMock + +from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType +from app.services.usage_tracker import UsageTracker + + +def _make_result_with_tokens( + engine_type: EngineType, + query: str = "test query", + input_tokens: int = 100, + output_tokens: int = 200, + has_brand: bool = False, +) -> AIQueryResult: + return AIQueryResult( + engine_type=engine_type, + query=query, + raw_response="some response with content", + citations=[], + has_brand_citation=has_brand, + has_competitor_citation=False, + brand_context="brand context" if has_brand else None, + competitor_contexts=[], + response_time_ms=100, + timestamp=datetime.now(UTC), + input_tokens=input_tokens, + output_tokens=output_tokens, + ) + + +class _MockAdapter(AIEngineAdapter): + def __init__(self, engine_type: EngineType, result: 
AIQueryResult | None = None): + super().__init__(api_key="test-key") + self._engine_type = engine_type + self._result = result + + async def query(self, query: str, brand_name: str, competitor_names: list[str] | None = None) -> AIQueryResult: + return self._result + + def get_engine_type(self) -> EngineType: + return self._engine_type + + def _get_env_key(self) -> str | None: + return None + + +class TestUsageRecordingOnQuery: + """测试查询后自动记录用量""" + + @pytest.mark.asyncio + async def test_query_single_records_usage(self): + from app.services.ai_engine.batch_query import BatchQueryService + from app.services.usage_recorder import UsageRecorder + + result = _make_result_with_tokens(EngineType.CHATGPT, input_tokens=100, output_tokens=200) + adapters = {"chatgpt": _MockAdapter(EngineType.CHATGPT, result=result)} + service = BatchQueryService(adapters) + service.set_user_context("user123", "brand456") + + await service.query_single(EngineType.CHATGPT, "test query", "BrandX") + + summary = service.get_usage_summary() + assert summary.total_queries == 1 + assert summary.total_input_tokens == 100 + assert summary.total_output_tokens == 200 + assert summary.total_cost > 0 + + @pytest.mark.asyncio + async def test_batch_query_records_usage_for_all(self): + from app.services.ai_engine.batch_query import BatchQueryService + + r1 = _make_result_with_tokens(EngineType.CHATGPT, input_tokens=100, output_tokens=200) + r2 = _make_result_with_tokens(EngineType.DEEPSEEK, input_tokens=150, output_tokens=300) + adapters = { + "chatgpt": _MockAdapter(EngineType.CHATGPT, result=r1), + "deepseek": _MockAdapter(EngineType.DEEPSEEK, result=r2), + } + service = BatchQueryService(adapters) + service.set_user_context("user123") + + await service.query_batch( + [EngineType.CHATGPT, EngineType.DEEPSEEK], + "test query", + "BrandX", + ) + + summary = service.get_usage_summary() + assert summary.total_queries == 2 + assert summary.total_input_tokens == 250 + assert summary.total_output_tokens == 500 
+ + @pytest.mark.asyncio + async def test_no_user_context_no_recording(self): + from app.services.ai_engine.batch_query import BatchQueryService + + result = _make_result_with_tokens(EngineType.CHATGPT, input_tokens=100, output_tokens=200) + adapters = {"chatgpt": _MockAdapter(EngineType.CHATGPT, result=result)} + service = BatchQueryService(adapters) + + await service.query_single(EngineType.CHATGPT, "test query", "BrandX") + + summary = service.get_usage_summary() + assert summary.total_queries == 0 + + +class TestUsageRecorderCostCalculation: + """测试不同引擎记录正确的成本""" + + def test_chatgpt_cost_calculation(self): + from app.services.usage_recorder import UsageRecorder + + tracker = UsageTracker() + recorder = UsageRecorder(tracker) + + cost = recorder.calculate_cost("chatgpt", 1000, 2000) + expected_input = (1000 / 1_000_000) * 1.0 + expected_output = (2000 / 1_000_000) * 4.0 + expected = round(expected_input + expected_output, 6) + assert cost == expected + + def test_deepseek_cost_calculation(self): + from app.services.usage_recorder import UsageRecorder + + tracker = UsageTracker() + recorder = UsageRecorder(tracker) + + cost = recorder.calculate_cost("deepseek", 1000, 2000) + expected_input = (1000 / 1_000_000) * 0.25 + expected_output = (2000 / 1_000_000) * 6.0 + expected = round(expected_input + expected_output, 6) + assert cost == expected + + def test_kimi_cost_calculation(self): + from app.services.usage_recorder import UsageRecorder + + tracker = UsageTracker() + recorder = UsageRecorder(tracker) + + cost = recorder.calculate_cost("kimi", 1000, 2000) + expected_input = (1000 / 1_000_000) * 12.0 + expected_output = (2000 / 1_000_000) * 12.0 + expected = round(expected_input + expected_output, 6) + assert cost == expected + + def test_unknown_engine_zero_cost(self): + from app.services.usage_recorder import UsageRecorder + + tracker = UsageTracker() + recorder = UsageRecorder(tracker) + + cost = recorder.calculate_cost("unknown_engine", 1000, 2000) + assert 
cost == 0.0 + + def test_record_with_cost(self): + from app.services.usage_recorder import UsageRecorder + + tracker = UsageTracker() + recorder = UsageRecorder(tracker) + + recorder.record( + user_id="user123", + brand_id="brand456", + engine_type="chatgpt", + query="test query", + input_tokens=1000, + output_tokens=2000, + ) + + summary = tracker.get_summary(user_id="user123") + assert summary.total_queries == 1 + assert summary.total_input_tokens == 1000 + assert summary.total_output_tokens == 2000 + assert summary.total_cost > 0 + + +class TestUsageSummaryCalculation: + """测试用量汇总计算正确""" + + def test_summary_by_engine(self): + from app.services.usage_recorder import UsageRecorder + + tracker = UsageTracker() + recorder = UsageRecorder(tracker) + + recorder.record("user1", "brand1", "chatgpt", "q1", 100, 200) + recorder.record("user1", "brand1", "chatgpt", "q2", 150, 250) + recorder.record("user1", "brand1", "deepseek", "q3", 200, 300) + + summary = tracker.get_summary(user_id="user1") + assert summary.by_engine["chatgpt"]["queries"] == 2 + assert summary.by_engine["chatgpt"]["input_tokens"] == 250 + assert summary.by_engine["deepseek"]["queries"] == 1 + + def test_summary_by_day(self): + from app.services.usage_recorder import UsageRecorder + + tracker = UsageTracker() + recorder = UsageRecorder(tracker) + + recorder.record("user1", "brand1", "chatgpt", "q1", 100, 200) + recorder.record("user1", "brand1", "chatgpt", "q2", 150, 250) + + summary = tracker.get_summary(user_id="user1") + today = datetime.now(UTC).strftime("%Y-%m-%d") + assert today in summary.by_day + assert summary.by_day[today]["queries"] == 2 + + def test_summary_filter_by_brand(self): + from app.services.usage_recorder import UsageRecorder + + tracker = UsageTracker() + recorder = UsageRecorder(tracker) + + recorder.record("user1", "brand1", "chatgpt", "q1", 100, 200) + recorder.record("user1", "brand2", "chatgpt", "q2", 150, 250) + + summary = tracker.get_summary(user_id="user1", 
brand_id="brand1") + assert summary.total_queries == 1 + assert summary.total_input_tokens == 100 + + +class TestQuotaCheck: + """测试配额检查触发""" + + def test_quota_check_warning(self): + from app.services.usage_recorder import UsageRecorder + + tracker = UsageTracker() + recorder = UsageRecorder(tracker) + + cost = recorder.calculate_cost("chatgpt", 1000000, 2000000) + recorder.record("user1", "", "chatgpt", "q1", 1000000, 2000000) + quota = tracker.check_quota("user1", monthly_limit=cost / 2) + + assert quota["status"] == "exceeded" + assert quota["used"] > 0 + + def test_quota_check_ok(self): + tracker = UsageTracker() + quota = tracker.check_quota("new_user", monthly_limit=100.0) + assert quota["status"] == "ok" + assert quota["used"] == 0 + + +class TestBatchQueryServiceUsageIntegration: + """测试BatchQueryService与用量记录的集成""" + + @pytest.mark.asyncio + async def test_service_has_usage_recorder(self): + from app.services.ai_engine.batch_query import BatchQueryService + + adapters = {"chatgpt": _MockAdapter(EngineType.CHATGPT)} + service = BatchQueryService(adapters) + assert hasattr(service, "_recorder") + assert hasattr(service, "_tracker") + + @pytest.mark.asyncio + async def test_set_user_context_stores_context(self): + from app.services.ai_engine.batch_query import BatchQueryService + + adapters = {"chatgpt": _MockAdapter(EngineType.CHATGPT)} + service = BatchQueryService(adapters) + service.set_user_context("user123", "brand456") + + assert service._user_id == "user123" + assert service._brand_id == "brand456" + + @pytest.mark.asyncio + async def test_get_usage_summary_returns_summary(self): + from app.services.ai_engine.batch_query import BatchQueryService + + adapters = {"chatgpt": _MockAdapter(EngineType.CHATGPT)} + service = BatchQueryService(adapters) + + summary = service.get_usage_summary() + assert hasattr(summary, "total_queries") + assert hasattr(summary, "total_input_tokens") + assert hasattr(summary, "total_output_tokens") + assert hasattr(summary, 
"total_cost") + + +class TestAIQueryResultTokens: + """测试AIQueryResult包含token统计字段""" + + def test_result_has_token_fields(self): + result = AIQueryResult( + engine_type=EngineType.CHATGPT, + query="test", + raw_response="response", + citations=[], + has_brand_citation=False, + has_competitor_citation=False, + brand_context=None, + competitor_contexts=[], + response_time_ms=100, + timestamp=datetime.now(UTC), + input_tokens=100, + output_tokens=200, + ) + assert result.input_tokens == 100 + assert result.output_tokens == 200 + assert result.total_tokens == 300 + + def test_result_token_fields_default_to_zero(self): + result = AIQueryResult( + engine_type=EngineType.CHATGPT, + query="test", + raw_response="response", + citations=[], + has_brand_citation=False, + has_competitor_citation=False, + brand_context=None, + competitor_contexts=[], + response_time_ms=100, + timestamp=datetime.now(UTC), + ) + assert result.input_tokens == 0 + assert result.output_tokens == 0 + assert result.total_tokens == 0 diff --git a/frontend/app/(dashboard)/dashboard/usage/page.tsx b/frontend/app/(dashboard)/dashboard/usage/page.tsx index 3f83014..cbd9fc8 100644 --- a/frontend/app/(dashboard)/dashboard/usage/page.tsx +++ b/frontend/app/(dashboard)/dashboard/usage/page.tsx @@ -19,6 +19,7 @@ import { CheckCircle, Loader2, RefreshCw, + AlertCircle, } from "lucide-react"; import { LineChart, @@ -43,6 +44,18 @@ const TIME_RANGE_LABELS: Record = { month: "本月", }; +const ENGINE_LABELS: Record = { + deepseek: "DeepSeek", + chatgpt: "ChatGPT", + qwen: "通义千问", + gemini: "Google Gemini", + kimi: "Kimi", + wenxin: "文心一言", + doubao: "豆包", + yuanbao: "腾讯元宝", + perplexity: "Perplexity", +}; + const PIE_COLORS = [ "#3b82f6", "#8b5cf6", @@ -55,6 +68,37 @@ const PIE_COLORS = [ "#f97316", ]; +interface QuotaAPIResponse { + used: number; + limit: number; + usage_percentage: number; + status: "ok" | "warning" | "exceeded"; +} + +interface SummaryAPIResponse { + period: string; + start_date: string; + 
end_date: string; + total_queries: number; + total_input_tokens: number; + total_output_tokens: number; + total_cost: number; + by_engine: Record; +} + +interface ByEngineAPIResponse { + engines: Array<{ + type: string; + queries: number; + cost: number; + }>; +} + interface QuotaData { used: number; limit: number; @@ -83,37 +127,6 @@ interface UsageData { engineUsage: EngineUsageItem[]; } -const MOCK_USAGE_DATA: UsageData = { - quota: { used: 45.6, limit: 100, percentage: 45.6, status: "normal" }, - trends: [ - { date: "05-19", cost: 5.2 }, - { date: "05-20", cost: 6.8 }, - { date: "05-21", cost: 4.5 }, - { date: "05-22", cost: 7.3 }, - { date: "05-23", cost: 8.1 }, - { date: "05-24", cost: 6.9 }, - { date: "05-25", cost: 7.0 }, - ], - engineDistribution: [ - { engine: "deepseek", label: "DeepSeek", cost: 15.2 }, - { engine: "chatgpt", label: "ChatGPT", cost: 12.8 }, - { engine: "qwen", label: "通义千问", cost: 8.5 }, - { engine: "gemini", label: "Google Gemini", cost: 5.3 }, - { engine: "kimi", label: "Kimi", cost: 3.8 }, - ], - engineUsage: [ - { engine: "deepseek", label: "DeepSeek", queries: 1250, inputTokens: 3200000, outputTokens: 1800000, cost: 15.2 }, - { engine: "chatgpt", label: "ChatGPT", queries: 890, inputTokens: 2100000, outputTokens: 1500000, cost: 12.8 }, - { engine: "qwen", label: "通义千问", queries: 720, inputTokens: 1800000, outputTokens: 900000, cost: 8.5 }, - { engine: "gemini", label: "Google Gemini", queries: 450, inputTokens: 1100000, outputTokens: 600000, cost: 5.3 }, - { engine: "kimi", label: "Kimi", queries: 320, inputTokens: 800000, outputTokens: 400000, cost: 3.8 }, - { engine: "wenxin", label: "文心一言", queries: 280, inputTokens: 700000, outputTokens: 350000, cost: 0.02 }, - { engine: "doubao", label: "豆包", queries: 210, inputTokens: 520000, outputTokens: 280000, cost: 0.5 }, - { engine: "yuanbao", label: "腾讯元宝", queries: 150, inputTokens: 380000, outputTokens: 200000, cost: 0.7 }, - { engine: "perplexity", label: "Perplexity", queries: 30, 
inputTokens: 75000, outputTokens: 40000, cost: 2.6 }, - ], -}; - function formatTokenCount(count: number): string { if (count >= 1000000) return `${(count / 1000000).toFixed(1)}M`; if (count >= 1000) return `${(count / 1000).toFixed(0)}K`; @@ -186,19 +199,79 @@ function QuotaStatusBadge({ status }: { status: "normal" | "warning" | "exceeded export default function UsagePage() { const [timeRange, setTimeRange] = useState("7d"); - const { data: usageData, isLoading, refresh } = useApi( + const { data: quotaData, isLoading: isQuotaLoading, error: quotaError } = useApi("/api/v1/usage/quota"); + const { data: summaryData, isLoading: isSummaryLoading, error: summaryError, refresh: refreshSummary } = useApi( `/api/v1/usage/summary?period=${timeRange === "7d" ? "week" : timeRange === "30d" ? "month" : "month"}` ); + const { data: byEngineData, isLoading: isByEngineLoading, error: byEngineError } = useApi("/api/v1/usage/by-engine"); - const data = usageData || MOCK_USAGE_DATA; + const isLoading = isQuotaLoading || isSummaryLoading || isByEngineLoading; + const error = quotaError || summaryError || byEngineError; + const refresh = refreshSummary; + + const usageData = useMemo((): UsageData | null => { + if (!quotaData || !summaryData || !byEngineData) { + return null; + } + + const quota: QuotaData = { + used: quotaData.used, + limit: quotaData.limit, + percentage: quotaData.usage_percentage, + status: quotaData.status === "ok" ? "normal" : quotaData.status, + }; + + const trends: TrendItem[] = []; + if (summaryData.start_date && summaryData.end_date) { + const startDate = new Date(summaryData.start_date); + const endDate = new Date(summaryData.end_date); + const daysDiff = Math.ceil((endDate.getTime() - startDate.getTime()) / (1000 * 60 * 60 * 24)); + const avgDailyCost = daysDiff > 0 ? 
summaryData.total_cost / daysDiff : 0; + + for (let i = 0; i < Math.min(daysDiff, 7); i++) { + const date = new Date(startDate); + date.setDate(date.getDate() + i); + const monthDay = `${String(date.getMonth() + 1).padStart(2, "0")}-${String(date.getDate()).padStart(2, "0")}`; + trends.push({ date: monthDay, cost: avgDailyCost }); + } + } + + const engineDistribution = byEngineData.engines?.map((e) => ({ + engine: e.type, + label: ENGINE_LABELS[e.type] || e.type, + cost: e.cost, + })) || []; + + const engineUsage: EngineUsageItem[] = Object.entries(summaryData.by_engine || {}).map(([engine, data]) => ({ + engine, + label: ENGINE_LABELS[engine] || engine, + queries: data.queries, + inputTokens: data.input_tokens, + outputTokens: data.output_tokens, + cost: data.cost, + })); + + return { + quota, + trends, + engineDistribution, + engineUsage, + }; + }, [quotaData, summaryData, byEngineData]); + + const data = usageData; const totalCost = useMemo(() => { + if (!data) return 0; return data.engineUsage.reduce((sum, item) => sum + item.cost, 0); - }, [data.engineUsage]); + }, [data]); const totalQueries = useMemo(() => { + if (!data) return 0; return data.engineUsage.reduce((sum, item) => sum + item.queries, 0); - }, [data.engineUsage]); + }, [data]); + + const hasData = data && data.engineUsage.length > 0; return (
@@ -230,6 +303,22 @@ export default function UsagePage() {

正在加载用量数据...

+ ) : error ? ( +
+ +

加载失败: {error.message}

+ +
+ ) : !hasData ? ( +
+
+ +
+

暂无用量数据,开始查询后会自动记录

+
) : ( <>