feat: API Key管理+用量追踪完整功能链路v2(真实可用)
持久化存储: - APIKey模型 + APIKeyRepository(SQLAlchemy) - UsageRecord模型 + UsageRepository(SQLAlchemy) API Key验证: - KeyVerifier服务(真正调用引擎API验证) - 支持9个引擎的真实性验证 加密存储: - KeyEncryption服务(Fernet AES加密) - 环境变量API_KEY_ENCRYPTION_KEY 用量追踪: - UsageRecorder自动记录查询用量 - 按引擎/按日聚合(修复by_day空dict) - UserQuotaService支持套餐配额(free:10/basic:50/pro:200/enterprise:1000) 集成修复: - AI引擎适配器使用APIKeyManager获取Key(用户Key>系统Key>环境变量) - SmartRouter与APIKeyManager集成(过滤无Key引擎) - BatchQueryService自动记录用量并传递用户上下文 - 所有适配器支持引擎特定代理环境变量 前端: - usage页面替换MOCK为真实API调用 - 显示加载/错误/空状态 测试: 630 passed
This commit is contained in:
parent
290ef5a273
commit
4cc8f73bb4
|
|
@ -1,13 +1,13 @@
|
|||
import logging
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.deps import get_current_user, get_key_manager
|
||||
from app.models.user import User
|
||||
from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType
|
||||
from app.services.ai_engine.batch_query import BatchQueryService
|
||||
from app.services.ai_engine.base import AIQueryResult, EngineType
|
||||
from app.services.ai_engine.batch_query import BatchQueryService, get_batch_service as _get_batch_service
|
||||
from app.services.ai_engine.chatgpt import ChatGPTAdapter
|
||||
from app.services.ai_engine.doubao import DoubaoAdapter
|
||||
from app.services.ai_engine.kimi import KimiAdapter
|
||||
|
|
@ -15,6 +15,9 @@ from app.services.ai_engine.perplexity import PerplexityAdapter
|
|||
from app.services.ai_engine.wenxin import WenxinAdapter
|
||||
from app.services.ai_engine.yuanbao import YuanbaoAdapter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.api_key_manager import APIKeyManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -60,31 +63,6 @@ class BatchQueryResponse(BaseModel):
|
|||
citation_rate: CitationRateResponse
|
||||
|
||||
|
||||
_ADAPTER_CLASSES: dict[EngineType, type[AIEngineAdapter]] = {
|
||||
EngineType.CHATGPT: ChatGPTAdapter,
|
||||
EngineType.PERPLEXITY: PerplexityAdapter,
|
||||
EngineType.KIMI: KimiAdapter,
|
||||
EngineType.WENXIN: WenxinAdapter,
|
||||
EngineType.DOUBAO: DoubaoAdapter,
|
||||
EngineType.YUANBAO: YuanbaoAdapter,
|
||||
}
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _build_adapters() -> dict[str, AIEngineAdapter]:
|
||||
adapters: dict[str, AIEngineAdapter] = {}
|
||||
for engine_type, cls in _ADAPTER_CLASSES.items():
|
||||
try:
|
||||
adapters[engine_type.value] = cls()
|
||||
except Exception:
|
||||
logger.warning(f"Failed to initialize {engine_type.value} adapter")
|
||||
return adapters
|
||||
|
||||
|
||||
def get_batch_service() -> BatchQueryService:
|
||||
return BatchQueryService(_build_adapters())
|
||||
|
||||
|
||||
def _result_to_response(r: AIQueryResult) -> QueryResultResponse:
|
||||
return QueryResultResponse(
|
||||
engine_type=r.engine_type.value,
|
||||
|
|
@ -129,8 +107,10 @@ async def _execute_batch(
|
|||
async def query_single_engine(
|
||||
request: SingleQueryRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
key_manager: "APIKeyManager | None" = Depends(get_key_manager),
|
||||
):
|
||||
service = get_batch_service()
|
||||
service = _get_batch_service(key_manager=key_manager, user_id=str(current_user.id))
|
||||
service.set_user_context(str(current_user.id))
|
||||
engine_type = _parse_engine(request.engine)
|
||||
try:
|
||||
result = await service.query_single(
|
||||
|
|
@ -148,8 +128,10 @@ async def query_single_engine(
|
|||
async def query_batch_engines(
|
||||
request: BatchQueryRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
key_manager: "APIKeyManager | None" = Depends(get_key_manager),
|
||||
):
|
||||
service = get_batch_service()
|
||||
service = _get_batch_service(key_manager=key_manager, user_id=str(current_user.id))
|
||||
service.set_user_context(str(current_user.id))
|
||||
engine_types = [_parse_engine(e) for e in request.engines]
|
||||
return await _execute_batch(
|
||||
service, engine_types, request.query, request.brand_name, request.competitor_names
|
||||
|
|
@ -163,8 +145,10 @@ async def get_query_results(
|
|||
brand_name: str = Query(..., min_length=1, max_length=200),
|
||||
competitor_names: str | None = Query(None, description="Comma-separated competitor names"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
key_manager: "APIKeyManager | None" = Depends(get_key_manager),
|
||||
):
|
||||
service = get_batch_service()
|
||||
service = _get_batch_service(key_manager=key_manager, user_id=str(current_user.id))
|
||||
service.set_user_context(str(current_user.id))
|
||||
engine_list = [e.strip() for e in engines.split(",") if e.strip()]
|
||||
engine_types = [_parse_engine(e) for e in engine_list]
|
||||
comp_names = (
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import uuid
|
||||
from functools import lru_cache
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
|
@ -8,11 +9,19 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.services.api_key_manager import APIKeyManager
|
||||
from app.services.auth import verify_token
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_key_manager() -> APIKeyManager:
|
||||
manager = APIKeyManager()
|
||||
manager.load_env_keys()
|
||||
return manager
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from fastapi import APIRouter, Depends, Query
|
|||
from app.api.deps import get_current_user
|
||||
from app.models.user import User
|
||||
from app.services.usage_tracker import UsageTracker
|
||||
from app.services.user_quota_service import UserQuotaService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -44,6 +45,7 @@ async def get_usage_summary(
|
|||
"total_output_tokens": summary.total_output_tokens,
|
||||
"total_cost": summary.total_cost,
|
||||
"by_engine": summary.by_engine,
|
||||
"by_day": summary.by_day,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -51,7 +53,14 @@ async def get_usage_summary(
|
|||
async def get_quota(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
return get_usage_tracker().check_quota(user_id=_user_id(current_user))
|
||||
tracker = get_usage_tracker()
|
||||
if tracker._session:
|
||||
quota_service = UserQuotaService(session=tracker._session)
|
||||
return await quota_service.check_quota_with_plan(
|
||||
user_id=_user_id(current_user),
|
||||
user_plan=current_user.plan,
|
||||
)
|
||||
return tracker.check_quota(user_id=_user_id(current_user))
|
||||
|
||||
|
||||
@router.get("/by-engine")
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from app.models.user import User
|
||||
from app.models.api_key import APIKey
|
||||
from app.models.query import Query
|
||||
from app.models.citation_record import CitationRecord
|
||||
from app.models.query_task import QueryTask
|
||||
|
|
@ -30,9 +31,11 @@ from app.models.suggestion import Suggestion
|
|||
from app.models.alert import Alert
|
||||
from app.models.alert_setting import AlertSetting
|
||||
from app.models.detection_task import DetectionTask
|
||||
from app.models.usage_record import UsageRecord
|
||||
|
||||
__all__ = [
|
||||
"User",
|
||||
"APIKey",
|
||||
"Query",
|
||||
"CitationRecord",
|
||||
"QueryTask",
|
||||
|
|
@ -70,4 +73,5 @@ __all__ = [
|
|||
"Alert",
|
||||
"AlertSetting",
|
||||
"DetectionTask",
|
||||
"UsageRecord",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,44 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, Index, String, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class APIKey(Base):
|
||||
__tablename__ = "api_keys"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4,
|
||||
)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
engine_type: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
encrypted_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||
key_hint: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
key_source: Mapped[str] = mapped_column(String(10), default="user")
|
||||
status: Mapped[str] = mapped_column(String(20), default="active")
|
||||
priority: Mapped[int] = mapped_column(default=0)
|
||||
last_verified_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_api_keys_user_engine", "user_id", "engine_type"),
|
||||
Index("idx_api_keys_engine_status", "engine_type", "status"),
|
||||
)
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import String, Integer, Float, DateTime, ForeignKey, Index, func
|
||||
from sqlalchemy import Uuid
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database import Base, JSONType
|
||||
|
||||
|
||||
class UsageRecord(Base):
|
||||
__tablename__ = "usage_records"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4,
|
||||
)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
brand_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
ForeignKey("brands.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
engine_type: Mapped[str] = mapped_column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
)
|
||||
query: Mapped[str] = mapped_column(
|
||||
String(500),
|
||||
nullable=False,
|
||||
)
|
||||
input_tokens: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
default=0,
|
||||
)
|
||||
output_tokens: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
default=0,
|
||||
)
|
||||
cost: Mapped[float] = mapped_column(
|
||||
Float,
|
||||
default=0.0,
|
||||
)
|
||||
extra_data: Mapped[dict] = mapped_column(
|
||||
JSONType,
|
||||
default=dict,
|
||||
)
|
||||
timestamp: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
default=func.now(),
|
||||
index=True,
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_usage_records_user_engine", "user_id", "engine_type"),
|
||||
Index("idx_usage_records_user_timestamp", "user_id", "timestamp"),
|
||||
Index("idx_usage_records_engine_timestamp", "engine_type", "timestamp"),
|
||||
)
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
from app.repositories.api_key_repository import APIKeyRepository
|
||||
from app.repositories.usage_repository import UsageRepository
|
||||
|
||||
__all__ = [
|
||||
"APIKeyRepository",
|
||||
"UsageRepository",
|
||||
]
|
||||
|
|
@ -0,0 +1,119 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import select, and_, delete, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.api_key import APIKey
|
||||
|
||||
|
||||
class APIKeyRepository:
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def create(
|
||||
self,
|
||||
user_id: uuid.UUID,
|
||||
engine_type: str,
|
||||
encrypted_key: str,
|
||||
key_hint: str,
|
||||
key_source: str = "user",
|
||||
status: str = "active",
|
||||
priority: int = 0,
|
||||
last_verified_at: datetime | None = None,
|
||||
) -> APIKey:
|
||||
api_key = APIKey(
|
||||
user_id=user_id,
|
||||
engine_type=engine_type,
|
||||
encrypted_key=encrypted_key,
|
||||
key_hint=key_hint,
|
||||
key_source=key_source,
|
||||
status=status,
|
||||
priority=priority,
|
||||
last_verified_at=last_verified_at,
|
||||
)
|
||||
self.session.add(api_key)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(api_key)
|
||||
return api_key
|
||||
|
||||
async def get_by_id(self, key_id: uuid.UUID) -> APIKey | None:
|
||||
result = await self.session.execute(
|
||||
select(APIKey).where(APIKey.id == key_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_user(
|
||||
self,
|
||||
user_id: uuid.UUID,
|
||||
engine_type: str | None = None,
|
||||
) -> list[APIKey]:
|
||||
conditions = [APIKey.user_id == user_id]
|
||||
if engine_type:
|
||||
conditions.append(APIKey.engine_type == engine_type)
|
||||
|
||||
result = await self.session.execute(
|
||||
select(APIKey).where(and_(*conditions)).order_by(APIKey.priority.desc())
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_by_user_and_engine(
|
||||
self,
|
||||
user_id: uuid.UUID,
|
||||
engine_type: str,
|
||||
) -> APIKey | None:
|
||||
result = await self.session.execute(
|
||||
select(APIKey).where(
|
||||
and_(
|
||||
APIKey.user_id == user_id,
|
||||
APIKey.engine_type == engine_type,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update(
|
||||
self,
|
||||
key_id: uuid.UUID,
|
||||
**kwargs,
|
||||
) -> APIKey | None:
|
||||
api_key = await self.get_by_id(key_id)
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(api_key, key):
|
||||
setattr(api_key, key, value)
|
||||
|
||||
api_key.updated_at = datetime.utcnow()
|
||||
await self.session.commit()
|
||||
await self.session.refresh(api_key)
|
||||
return api_key
|
||||
|
||||
async def delete(self, key_id: uuid.UUID) -> bool:
|
||||
result = await self.session.execute(
|
||||
delete(APIKey).where(APIKey.id == key_id)
|
||||
)
|
||||
await self.session.commit()
|
||||
return result.rowcount > 0
|
||||
|
||||
async def list_all(
|
||||
self,
|
||||
engine_type: str | None = None,
|
||||
status: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[APIKey]:
|
||||
conditions = []
|
||||
if engine_type:
|
||||
conditions.append(APIKey.engine_type == engine_type)
|
||||
if status:
|
||||
conditions.append(APIKey.status == status)
|
||||
|
||||
query = select(APIKey)
|
||||
if conditions:
|
||||
query = query.where(and_(*conditions))
|
||||
query = query.order_by(APIKey.priority.desc()).limit(limit).offset(offset)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
|
@ -0,0 +1,175 @@
|
|||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import select, func, and_, case
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.usage_record import UsageRecord
|
||||
|
||||
|
||||
_PERIOD_CUTOFF_DAYS = {"day": 0, "week": 7, "month": 30}
|
||||
_QUOTA_WARNING_PCT = 80.0
|
||||
_QUOTA_EXCEEDED_PCT = 100.0
|
||||
|
||||
|
||||
def _compute_cutoff(period: str, now: datetime) -> datetime:
|
||||
days = _PERIOD_CUTOFF_DAYS.get(period, 30)
|
||||
if days == 0:
|
||||
return now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
return now - timedelta(days=days)
|
||||
|
||||
|
||||
def _quota_status(usage_pct: float) -> str:
|
||||
if usage_pct >= _QUOTA_EXCEEDED_PCT:
|
||||
return "exceeded"
|
||||
if usage_pct >= _QUOTA_WARNING_PCT:
|
||||
return "warning"
|
||||
return "ok"
|
||||
|
||||
|
||||
class UsageRepository:
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def create(self, data: dict) -> UsageRecord:
|
||||
record = UsageRecord(
|
||||
user_id=data["user_id"],
|
||||
brand_id=data.get("brand_id"),
|
||||
engine_type=data["engine_type"],
|
||||
query=data["query"],
|
||||
input_tokens=data.get("input_tokens", 0),
|
||||
output_tokens=data.get("output_tokens", 0),
|
||||
cost=data.get("cost", 0.0),
|
||||
extra_data=data.get("extra_data", {}),
|
||||
timestamp=data.get("timestamp", datetime.now(timezone.utc)),
|
||||
)
|
||||
self.session.add(record)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(record)
|
||||
return record
|
||||
|
||||
async def get_summary(
|
||||
self,
|
||||
user_id: str | uuid.UUID,
|
||||
period: str = "month",
|
||||
brand_id: str | uuid.UUID | None = None,
|
||||
) -> dict:
|
||||
now = datetime.now(timezone.utc)
|
||||
cutoff = _compute_cutoff(period, now)
|
||||
|
||||
if isinstance(user_id, str):
|
||||
user_id = uuid.UUID(user_id)
|
||||
if brand_id and isinstance(brand_id, str):
|
||||
brand_id = uuid.UUID(brand_id)
|
||||
|
||||
conditions = [
|
||||
UsageRecord.user_id == user_id,
|
||||
UsageRecord.timestamp >= cutoff,
|
||||
]
|
||||
if brand_id:
|
||||
conditions.append(UsageRecord.brand_id == brand_id)
|
||||
|
||||
result = await self.session.execute(
|
||||
select(UsageRecord).where(and_(*conditions))
|
||||
)
|
||||
records = list(result.scalars().all())
|
||||
|
||||
total_queries = len(records)
|
||||
total_input_tokens = sum(r.input_tokens for r in records)
|
||||
total_output_tokens = sum(r.output_tokens for r in records)
|
||||
total_cost = round(sum(r.cost for r in records), 4)
|
||||
|
||||
by_engine: dict[str, dict] = {}
|
||||
for r in records:
|
||||
bucket = by_engine.setdefault(
|
||||
r.engine_type,
|
||||
{"queries": 0, "input_tokens": 0, "output_tokens": 0, "cost": 0.0}
|
||||
)
|
||||
bucket["queries"] += 1
|
||||
bucket["input_tokens"] += r.input_tokens
|
||||
bucket["output_tokens"] += r.output_tokens
|
||||
bucket["cost"] = round(bucket["cost"] + r.cost, 4)
|
||||
|
||||
by_day: dict[str, dict] = {}
|
||||
for r in records:
|
||||
day_key = r.timestamp.strftime("%Y-%m-%d")
|
||||
bucket = by_day.setdefault(
|
||||
day_key,
|
||||
{"queries": 0, "input_tokens": 0, "output_tokens": 0, "cost": 0.0}
|
||||
)
|
||||
bucket["queries"] += 1
|
||||
bucket["input_tokens"] += r.input_tokens
|
||||
bucket["output_tokens"] += r.output_tokens
|
||||
bucket["cost"] = round(bucket["cost"] + r.cost, 4)
|
||||
|
||||
return {
|
||||
"period": period,
|
||||
"start_date": cutoff.isoformat(),
|
||||
"end_date": now.isoformat(),
|
||||
"total_queries": total_queries,
|
||||
"total_input_tokens": total_input_tokens,
|
||||
"total_output_tokens": total_output_tokens,
|
||||
"total_cost": total_cost,
|
||||
"by_engine": by_engine,
|
||||
"by_day": by_day,
|
||||
}
|
||||
|
||||
async def check_quota(
|
||||
self,
|
||||
user_id: str | uuid.UUID,
|
||||
monthly_limit: float = 100.0,
|
||||
) -> dict:
|
||||
summary = await self.get_summary(user_id=user_id, period="month")
|
||||
usage_pct = (summary["total_cost"] / monthly_limit * 100) if monthly_limit > 0 else 0
|
||||
return {
|
||||
"used": summary["total_cost"],
|
||||
"limit": monthly_limit,
|
||||
"usage_percentage": round(usage_pct, 1),
|
||||
"status": _quota_status(usage_pct),
|
||||
}
|
||||
|
||||
async def get_by_id(self, record_id: uuid.UUID) -> UsageRecord | None:
|
||||
result = await self.session.execute(
|
||||
select(UsageRecord).where(UsageRecord.id == record_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_user(
|
||||
self,
|
||||
user_id: str | uuid.UUID,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[UsageRecord]:
|
||||
if isinstance(user_id, str):
|
||||
user_id = uuid.UUID(user_id)
|
||||
|
||||
result = await self.session.execute(
|
||||
select(UsageRecord)
|
||||
.where(UsageRecord.user_id == user_id)
|
||||
.order_by(UsageRecord.timestamp.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_by_user_and_engine(
|
||||
self,
|
||||
user_id: str | uuid.UUID,
|
||||
engine_type: str,
|
||||
limit: int = 100,
|
||||
) -> list[UsageRecord]:
|
||||
if isinstance(user_id, str):
|
||||
user_id = uuid.UUID(user_id)
|
||||
|
||||
result = await self.session.execute(
|
||||
select(UsageRecord)
|
||||
.where(
|
||||
and_(
|
||||
UsageRecord.user_id == user_id,
|
||||
UsageRecord.engine_type == engine_type,
|
||||
)
|
||||
)
|
||||
.order_by(UsageRecord.timestamp.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
|
@ -9,6 +9,8 @@ from typing import Any
|
|||
|
||||
import httpx
|
||||
|
||||
from app.services.api_key_manager import APIKeyManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MAX_RETRIES = 3
|
||||
|
|
@ -49,15 +51,53 @@ class AIQueryResult:
|
|||
response_time_ms: int
|
||||
timestamp: datetime
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
return self.input_tokens + self.output_tokens
|
||||
|
||||
|
||||
class AIEngineAdapter(ABC):
|
||||
def __init__(self, api_key: str, rate_limiter=None, proxy: str | None = None):
|
||||
self.api_key = api_key
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
rate_limiter=None,
|
||||
proxy: str | None = None,
|
||||
key_manager: APIKeyManager | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
self._key_manager = key_manager
|
||||
self._user_id = user_id
|
||||
self.api_key = self._resolve_api_key(api_key, key_manager, user_id)
|
||||
self.rate_limiter = rate_limiter
|
||||
self.proxy = proxy or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")
|
||||
self.proxy = proxy or self._load_proxy()
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
|
||||
def _load_proxy(self) -> str | None:
|
||||
return os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")
|
||||
|
||||
def _resolve_api_key(
|
||||
self,
|
||||
direct_key: str | None,
|
||||
key_manager: APIKeyManager | None,
|
||||
user_id: str | None,
|
||||
) -> str:
|
||||
if direct_key and direct_key.strip():
|
||||
return direct_key
|
||||
|
||||
if key_manager:
|
||||
key = key_manager.get_key(self.get_engine_type().value, user_id=user_id)
|
||||
if key:
|
||||
return key
|
||||
|
||||
return self._get_env_key() or ""
|
||||
|
||||
@abstractmethod
|
||||
def _get_env_key(self) -> str | None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def query(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -1,14 +1,103 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .base import AIEngineAdapter, AIQueryResult, EngineType
|
||||
from app.services.usage_recorder import UsageRecorder
|
||||
from app.services.usage_tracker import UsageTracker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.api_key_manager import APIKeyManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ADAPTER_CLASSES: dict[EngineType, type[AIEngineAdapter]] = {}
|
||||
|
||||
|
||||
def register_adapter(cls: type[AIEngineAdapter]) -> None:
|
||||
engine_type = None
|
||||
try:
|
||||
temp = cls()
|
||||
engine_type = temp.get_engine_type()
|
||||
_ADAPTER_CLASSES[engine_type] = cls
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to register adapter: {e}")
|
||||
|
||||
|
||||
from .chatgpt import ChatGPTAdapter
|
||||
from .perplexity import PerplexityAdapter
|
||||
from .kimi import KimiAdapter
|
||||
from .wenxin import WenxinAdapter
|
||||
from .doubao import DoubaoAdapter
|
||||
from .yuanbao import YuanbaoAdapter
|
||||
from .deepseek import DeepSeekAdapter
|
||||
from .qwen import QwenAdapter
|
||||
from .gemini import GeminiAdapter
|
||||
|
||||
register_adapter(ChatGPTAdapter)
|
||||
register_adapter(PerplexityAdapter)
|
||||
register_adapter(KimiAdapter)
|
||||
register_adapter(WenxinAdapter)
|
||||
register_adapter(DoubaoAdapter)
|
||||
register_adapter(YuanbaoAdapter)
|
||||
register_adapter(DeepSeekAdapter)
|
||||
register_adapter(QwenAdapter)
|
||||
register_adapter(GeminiAdapter)
|
||||
|
||||
|
||||
def get_batch_service(
|
||||
key_manager: "APIKeyManager | None" = None,
|
||||
user_id: str | None = None,
|
||||
) -> "BatchQueryService":
|
||||
if key_manager:
|
||||
adapters = _build_adapters_with_key_manager(key_manager=key_manager, user_id=user_id)
|
||||
else:
|
||||
adapters = _build_adapters()
|
||||
return BatchQueryService(adapters)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _build_adapters() -> dict[str, AIEngineAdapter]:
|
||||
adapters: dict[str, AIEngineAdapter] = {}
|
||||
for engine_type, cls in _ADAPTER_CLASSES.items():
|
||||
try:
|
||||
adapters[engine_type.value] = cls()
|
||||
except Exception:
|
||||
logger.warning(f"Failed to initialize {engine_type.value} adapter")
|
||||
return adapters
|
||||
|
||||
|
||||
def _build_adapters_with_key_manager(
|
||||
key_manager: "APIKeyManager | None" = None,
|
||||
user_id: str | None = None,
|
||||
) -> dict[str, AIEngineAdapter]:
|
||||
adapters: dict[str, AIEngineAdapter] = {}
|
||||
for engine_type, cls in _ADAPTER_CLASSES.items():
|
||||
try:
|
||||
adapters[engine_type.value] = cls(
|
||||
key_manager=key_manager,
|
||||
user_id=user_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(f"Failed to initialize {engine_type.value} adapter")
|
||||
return adapters
|
||||
|
||||
|
||||
class BatchQueryService:
|
||||
def __init__(self, adapters: dict[str, AIEngineAdapter]):
|
||||
self.adapters = adapters
|
||||
self._tracker = UsageTracker()
|
||||
self._recorder = UsageRecorder(self._tracker)
|
||||
self._user_id: str | None = None
|
||||
self._brand_id: str | None = None
|
||||
|
||||
def set_user_context(self, user_id: str, brand_id: str | None = None) -> None:
|
||||
self._user_id = user_id
|
||||
self._brand_id = brand_id
|
||||
|
||||
def get_usage_summary(self):
|
||||
return self._tracker.get_summary(user_id=self._user_id)
|
||||
|
||||
async def query_single(
|
||||
self,
|
||||
|
|
@ -20,7 +109,24 @@ class BatchQueryService:
|
|||
adapter = self.adapters.get(engine_type.value)
|
||||
if not adapter:
|
||||
raise ValueError(f"Unknown engine type: {engine_type}")
|
||||
return await adapter.query(query, brand_name, competitor_names)
|
||||
|
||||
result = await adapter.query(query, brand_name, competitor_names)
|
||||
|
||||
if self._user_id:
|
||||
self._recorder.record(
|
||||
user_id=self._user_id,
|
||||
brand_id=self._brand_id,
|
||||
engine_type=engine_type.value,
|
||||
query=query,
|
||||
input_tokens=result.input_tokens,
|
||||
output_tokens=result.output_tokens,
|
||||
metadata={
|
||||
"brand_name": brand_name,
|
||||
"response_time_ms": result.response_time_ms,
|
||||
},
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def query_batch(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -21,11 +21,15 @@ class ChatGPTAdapter(AIEngineAdapter):
|
|||
base_url: str | None = None,
|
||||
rate_limiter=None,
|
||||
proxy: str | None = None,
|
||||
key_manager=None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key or os.getenv("OPENAI_API_KEY", ""),
|
||||
api_key=api_key,
|
||||
rate_limiter=rate_limiter,
|
||||
proxy=proxy or os.getenv("OPENAI_PROXY"),
|
||||
proxy=proxy,
|
||||
key_manager=key_manager,
|
||||
user_id=user_id,
|
||||
)
|
||||
self._model = model or os.getenv("OPENAI_MODEL", _DEFAULT_MODEL)
|
||||
self._base_url = (
|
||||
|
|
@ -43,6 +47,12 @@ class ChatGPTAdapter(AIEngineAdapter):
|
|||
def get_engine_type(self) -> EngineType:
|
||||
return EngineType.CHATGPT
|
||||
|
||||
def _load_proxy(self) -> str | None:
|
||||
return os.getenv("OPENAI_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")
|
||||
|
||||
def _get_env_key(self) -> str | None:
|
||||
return os.getenv("OPENAI_API_KEY", "")
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query: str,
|
||||
|
|
@ -70,6 +80,10 @@ class ChatGPTAdapter(AIEngineAdapter):
|
|||
content, brand_name, competitor_names
|
||||
)
|
||||
|
||||
usage = data.get("usage", {})
|
||||
input_tokens = usage.get("prompt_tokens", 0)
|
||||
output_tokens = usage.get("completion_tokens", 0)
|
||||
|
||||
logger.info(
|
||||
f"[chatgpt] query='{query[:50]}...' brand={has_brand} "
|
||||
f"competitor={has_comp} time={elapsed_ms}ms"
|
||||
|
|
@ -87,4 +101,6 @@ class ChatGPTAdapter(AIEngineAdapter):
|
|||
response_time_ms=elapsed_ms,
|
||||
timestamp=datetime.now(UTC),
|
||||
metadata={"model": data.get("model", self._model)},
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -21,11 +21,15 @@ class DeepSeekAdapter(AIEngineAdapter):
|
|||
base_url: str | None = None,
|
||||
rate_limiter=None,
|
||||
proxy: str | None = None,
|
||||
key_manager=None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key or os.getenv("DEEPSEEK_API_KEY", ""),
|
||||
api_key=api_key,
|
||||
rate_limiter=rate_limiter,
|
||||
proxy=proxy or os.getenv("DEEPSEEK_PROXY"),
|
||||
proxy=proxy,
|
||||
key_manager=key_manager,
|
||||
user_id=user_id,
|
||||
)
|
||||
self._model = model or os.getenv("DEEPSEEK_MODEL", _DEFAULT_MODEL)
|
||||
self._base_url = (
|
||||
|
|
@ -43,6 +47,12 @@ class DeepSeekAdapter(AIEngineAdapter):
|
|||
def get_engine_type(self) -> EngineType:
|
||||
return EngineType.DEEPSEEK
|
||||
|
||||
def _get_env_key(self) -> str | None:
|
||||
return os.getenv("DEEPSEEK_API_KEY", "")
|
||||
|
||||
def _load_proxy(self) -> str | None:
|
||||
return os.getenv("DEEPSEEK_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query: str,
|
||||
|
|
@ -70,6 +80,10 @@ class DeepSeekAdapter(AIEngineAdapter):
|
|||
content, brand_name, competitor_names
|
||||
)
|
||||
|
||||
usage = data.get("usage", {})
|
||||
input_tokens = usage.get("prompt_tokens", 0)
|
||||
output_tokens = usage.get("completion_tokens", 0)
|
||||
|
||||
logger.info(
|
||||
f"[deepseek] query='{query[:50]}...' brand={has_brand} "
|
||||
f"competitor={has_comp} time={elapsed_ms}ms"
|
||||
|
|
@ -87,4 +101,6 @@ class DeepSeekAdapter(AIEngineAdapter):
|
|||
response_time_ms=elapsed_ms,
|
||||
timestamp=datetime.now(UTC),
|
||||
metadata={"model": data.get("model", self._model)},
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -19,10 +19,16 @@ class DoubaoAdapter(AIEngineAdapter):
|
|||
api_key: str | None = None,
|
||||
endpoint_id: str | None = None,
|
||||
rate_limiter=None,
|
||||
proxy: str | None = None,
|
||||
key_manager=None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key or os.getenv("DOUBAO_API_KEY", ""),
|
||||
api_key=api_key,
|
||||
rate_limiter=rate_limiter,
|
||||
proxy=proxy,
|
||||
key_manager=key_manager,
|
||||
user_id=user_id,
|
||||
)
|
||||
self._endpoint_id = endpoint_id or os.getenv("DOUBAO_ENDPOINT_ID", "")
|
||||
self._base_url = _DEFAULT_BASE_URL.rstrip("/")
|
||||
|
|
@ -38,6 +44,12 @@ class DoubaoAdapter(AIEngineAdapter):
|
|||
def get_engine_type(self) -> EngineType:
|
||||
return EngineType.DOUBAO
|
||||
|
||||
def _get_env_key(self) -> str | None:
|
||||
return os.getenv("DOUBAO_API_KEY", "")
|
||||
|
||||
def _load_proxy(self) -> str | None:
|
||||
return os.getenv("DOUBAO_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")
|
||||
|
||||
def _get_model_id(self) -> str:
|
||||
if self._endpoint_id and self._endpoint_id.strip():
|
||||
if not self._endpoint_id.startswith("ep-"):
|
||||
|
|
@ -76,6 +88,10 @@ class DoubaoAdapter(AIEngineAdapter):
|
|||
content, brand_name, competitor_names
|
||||
)
|
||||
|
||||
usage = data.get("usage", {})
|
||||
input_tokens = usage.get("prompt_tokens", 0)
|
||||
output_tokens = usage.get("completion_tokens", 0)
|
||||
|
||||
logger.info(
|
||||
f"[doubao] query='{query[:50]}...' brand={has_brand} "
|
||||
f"competitor={has_comp} time={elapsed_ms}ms"
|
||||
|
|
@ -92,5 +108,7 @@ class DoubaoAdapter(AIEngineAdapter):
|
|||
competitor_contexts=comp_ctx,
|
||||
response_time_ms=elapsed_ms,
|
||||
timestamp=datetime.now(UTC),
|
||||
metadata={"model": data.get("model", model_id), "usage": data.get("usage")},
|
||||
metadata={"model": data.get("model", model_id), "usage": usage},
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -23,10 +23,15 @@ class GeminiAdapter(AIEngineAdapter):
|
|||
model: str | None = None,
|
||||
rate_limiter=None,
|
||||
proxy: str | None = None,
|
||||
key_manager=None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key or os.getenv("GOOGLE_API_KEY", ""),
|
||||
api_key=api_key,
|
||||
rate_limiter=rate_limiter,
|
||||
proxy=proxy,
|
||||
key_manager=key_manager,
|
||||
user_id=user_id,
|
||||
)
|
||||
self._model = model or os.getenv("GEMINI_MODEL", _DEFAULT_MODEL)
|
||||
self._base_url = os.getenv("GEMINI_BASE_URL", _DEFAULT_BASE_URL).rstrip("/")
|
||||
|
|
@ -46,6 +51,12 @@ class GeminiAdapter(AIEngineAdapter):
|
|||
def get_engine_type(self) -> EngineType:
|
||||
return EngineType.GEMINI
|
||||
|
||||
def _get_env_key(self) -> str | None:
|
||||
return os.getenv("GOOGLE_API_KEY", "")
|
||||
|
||||
def _load_proxy(self) -> str | None:
|
||||
return os.getenv("GOOGLE_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")
|
||||
|
||||
async def _request_with_retry(self, payload: dict) -> dict:
|
||||
if self.rate_limiter:
|
||||
await self.rate_limiter.acquire()
|
||||
|
|
@ -117,6 +128,10 @@ class GeminiAdapter(AIEngineAdapter):
|
|||
content, brand_name, competitor_names
|
||||
)
|
||||
|
||||
usage_metadata = data.get("usageMetadata", {})
|
||||
input_tokens = usage_metadata.get("promptTokenCount", 0)
|
||||
output_tokens = usage_metadata.get("candidatesTokenCount", 0)
|
||||
|
||||
logger.info(
|
||||
f"[gemini] query='{query[:50]}...' brand={has_brand} "
|
||||
f"competitor={has_comp} time={elapsed_ms}ms"
|
||||
|
|
@ -133,5 +148,7 @@ class GeminiAdapter(AIEngineAdapter):
|
|||
competitor_contexts=comp_ctx,
|
||||
response_time_ms=elapsed_ms,
|
||||
timestamp=datetime.now(UTC),
|
||||
metadata={"model": self._model},
|
||||
metadata={"model": self._model, "usage": usage_metadata},
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -20,10 +20,16 @@ class KimiAdapter(AIEngineAdapter):
|
|||
model: str | None = None,
|
||||
base_url: str | None = None,
|
||||
rate_limiter=None,
|
||||
proxy: str | None = None,
|
||||
key_manager=None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key or os.getenv("MOONSHOT_API_KEY", ""),
|
||||
api_key=api_key,
|
||||
rate_limiter=rate_limiter,
|
||||
proxy=proxy,
|
||||
key_manager=key_manager,
|
||||
user_id=user_id,
|
||||
)
|
||||
self._model = model or _DEFAULT_MODEL
|
||||
self._base_url = (
|
||||
|
|
@ -41,6 +47,12 @@ class KimiAdapter(AIEngineAdapter):
|
|||
def get_engine_type(self) -> EngineType:
|
||||
return EngineType.KIMI
|
||||
|
||||
def _get_env_key(self) -> str | None:
|
||||
return os.getenv("MOONSHOT_API_KEY", "")
|
||||
|
||||
def _load_proxy(self) -> str | None:
|
||||
return os.getenv("MOONSHOT_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query: str,
|
||||
|
|
@ -71,6 +83,10 @@ class KimiAdapter(AIEngineAdapter):
|
|||
content, brand_name, competitor_names
|
||||
)
|
||||
|
||||
usage = data.get("usage", {})
|
||||
input_tokens = usage.get("prompt_tokens", 0)
|
||||
output_tokens = usage.get("completion_tokens", 0)
|
||||
|
||||
logger.info(
|
||||
f"[kimi] query='{query[:50]}...' brand={has_brand} "
|
||||
f"competitor={has_comp} time={elapsed_ms}ms"
|
||||
|
|
@ -87,5 +103,7 @@ class KimiAdapter(AIEngineAdapter):
|
|||
competitor_contexts=comp_ctx,
|
||||
response_time_ms=elapsed_ms,
|
||||
timestamp=datetime.now(UTC),
|
||||
metadata={"model": data.get("model", self._model), "usage": data.get("usage")},
|
||||
metadata={"model": data.get("model", self._model), "usage": usage},
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -21,11 +21,15 @@ class PerplexityAdapter(AIEngineAdapter):
|
|||
base_url: str | None = None,
|
||||
rate_limiter=None,
|
||||
proxy: str | None = None,
|
||||
key_manager=None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key or os.getenv("PERPLEXITY_API_KEY", ""),
|
||||
api_key=api_key,
|
||||
rate_limiter=rate_limiter,
|
||||
proxy=proxy or os.getenv("PERPLEXITY_PROXY"),
|
||||
proxy=proxy,
|
||||
key_manager=key_manager,
|
||||
user_id=user_id,
|
||||
)
|
||||
self._model = model or os.getenv("PERPLEXITY_MODEL", _DEFAULT_MODEL)
|
||||
self._base_url = (
|
||||
|
|
@ -43,6 +47,12 @@ class PerplexityAdapter(AIEngineAdapter):
|
|||
def get_engine_type(self) -> EngineType:
|
||||
return EngineType.PERPLEXITY
|
||||
|
||||
def _get_env_key(self) -> str | None:
|
||||
return os.getenv("PERPLEXITY_API_KEY", "")
|
||||
|
||||
def _load_proxy(self) -> str | None:
|
||||
return os.getenv("PERPLEXITY_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query: str,
|
||||
|
|
@ -71,6 +81,10 @@ class PerplexityAdapter(AIEngineAdapter):
|
|||
content, brand_name, competitor_names
|
||||
)
|
||||
|
||||
usage = data.get("usage", {})
|
||||
input_tokens = usage.get("prompt_tokens", 0)
|
||||
output_tokens = usage.get("completion_tokens", 0)
|
||||
|
||||
logger.info(
|
||||
f"[perplexity] query='{query[:50]}...' brand={has_brand} "
|
||||
f"competitor={has_comp} citations={len(citations)} time={elapsed_ms}ms"
|
||||
|
|
@ -88,6 +102,8 @@ class PerplexityAdapter(AIEngineAdapter):
|
|||
response_time_ms=elapsed_ms,
|
||||
timestamp=datetime.now(UTC),
|
||||
metadata={"model": data.get("model", self._model)},
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
)
|
||||
|
||||
def _extract_citations(self, data: dict) -> list[CitationInfo]:
|
||||
|
|
|
|||
|
|
@ -20,10 +20,16 @@ class QwenAdapter(AIEngineAdapter):
|
|||
model: str | None = None,
|
||||
base_url: str | None = None,
|
||||
rate_limiter=None,
|
||||
proxy: str | None = None,
|
||||
key_manager=None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key or os.getenv("DASHSCOPE_API_KEY", ""),
|
||||
api_key=api_key,
|
||||
rate_limiter=rate_limiter,
|
||||
proxy=proxy,
|
||||
key_manager=key_manager,
|
||||
user_id=user_id,
|
||||
)
|
||||
self._model = model or os.getenv("QWEN_MODEL", _DEFAULT_MODEL)
|
||||
self._base_url = (
|
||||
|
|
@ -41,6 +47,12 @@ class QwenAdapter(AIEngineAdapter):
|
|||
def get_engine_type(self) -> EngineType:
|
||||
return EngineType.QWEN
|
||||
|
||||
def _get_env_key(self) -> str | None:
|
||||
return os.getenv("DASHSCOPE_API_KEY", "")
|
||||
|
||||
def _load_proxy(self) -> str | None:
|
||||
return os.getenv("DASHSCOPE_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query: str,
|
||||
|
|
@ -71,6 +83,10 @@ class QwenAdapter(AIEngineAdapter):
|
|||
content, brand_name, competitor_names
|
||||
)
|
||||
|
||||
usage = data.get("usage", {})
|
||||
input_tokens = usage.get("prompt_tokens", 0)
|
||||
output_tokens = usage.get("completion_tokens", 0)
|
||||
|
||||
logger.info(
|
||||
f"[qwen] query='{query[:50]}...' brand={has_brand} "
|
||||
f"competitor={has_comp} time={elapsed_ms}ms"
|
||||
|
|
@ -87,5 +103,7 @@ class QwenAdapter(AIEngineAdapter):
|
|||
competitor_contexts=comp_ctx,
|
||||
response_time_ms=elapsed_ms,
|
||||
timestamp=datetime.now(UTC),
|
||||
metadata={"model": data.get("model", self._model), "usage": data.get("usage")},
|
||||
metadata={"model": data.get("model", self._model), "usage": usage},
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -26,10 +26,16 @@ class WenxinAdapter(AIEngineAdapter):
|
|||
api_key: str | None = None,
|
||||
secret_key: str | None = None,
|
||||
rate_limiter=None,
|
||||
proxy: str | None = None,
|
||||
key_manager=None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key or os.getenv("BAIDU_QIANFAN_API_KEY", ""),
|
||||
api_key=api_key,
|
||||
rate_limiter=rate_limiter,
|
||||
proxy=proxy,
|
||||
key_manager=key_manager,
|
||||
user_id=user_id,
|
||||
)
|
||||
self.secret_key = secret_key or os.getenv("BAIDU_QIANFAN_SECRET_KEY", "")
|
||||
self._model = _DEFAULT_MODEL
|
||||
|
|
@ -40,6 +46,12 @@ class WenxinAdapter(AIEngineAdapter):
|
|||
def get_engine_type(self) -> EngineType:
|
||||
return EngineType.WENXIN
|
||||
|
||||
def _get_env_key(self) -> str | None:
|
||||
return os.getenv("BAIDU_QIANFAN_API_KEY", "")
|
||||
|
||||
def _load_proxy(self) -> str | None:
|
||||
return os.getenv("BAIDU_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")
|
||||
|
||||
async def _get_access_token(self) -> str:
|
||||
global _cached_token, _token_expires_at
|
||||
|
||||
|
|
@ -125,6 +137,10 @@ class WenxinAdapter(AIEngineAdapter):
|
|||
content, brand_name, competitor_names
|
||||
)
|
||||
|
||||
usage = data.get("usage", {})
|
||||
input_tokens = usage.get("prompt_tokens", 0)
|
||||
output_tokens = usage.get("completion_tokens", 0)
|
||||
|
||||
logger.info(
|
||||
f"[wenxin] query='{query[:50]}...' brand={has_brand} "
|
||||
f"competitor={has_comp} time={elapsed_ms}ms"
|
||||
|
|
@ -141,5 +157,7 @@ class WenxinAdapter(AIEngineAdapter):
|
|||
competitor_contexts=comp_ctx,
|
||||
response_time_ms=elapsed_ms,
|
||||
timestamp=datetime.now(UTC),
|
||||
metadata={"model": self._model, "usage": data.get("usage")},
|
||||
metadata={"model": self._model, "usage": usage},
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -20,10 +20,16 @@ class YuanbaoAdapter(AIEngineAdapter):
|
|||
model: str | None = None,
|
||||
base_url: str | None = None,
|
||||
rate_limiter=None,
|
||||
proxy: str | None = None,
|
||||
key_manager=None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key or os.getenv("HUNYUAN_API_KEY", ""),
|
||||
api_key=api_key,
|
||||
rate_limiter=rate_limiter,
|
||||
proxy=proxy,
|
||||
key_manager=key_manager,
|
||||
user_id=user_id,
|
||||
)
|
||||
self._model = model or os.getenv("HUNYUAN_MODEL", _DEFAULT_MODEL)
|
||||
self._base_url = (
|
||||
|
|
@ -41,6 +47,12 @@ class YuanbaoAdapter(AIEngineAdapter):
|
|||
def get_engine_type(self) -> EngineType:
|
||||
return EngineType.YUANBAO
|
||||
|
||||
def _get_env_key(self) -> str | None:
|
||||
return os.getenv("HUNYUAN_API_KEY", "")
|
||||
|
||||
def _load_proxy(self) -> str | None:
|
||||
return os.getenv("HUNYUAN_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query: str,
|
||||
|
|
@ -71,6 +83,10 @@ class YuanbaoAdapter(AIEngineAdapter):
|
|||
content, brand_name, competitor_names
|
||||
)
|
||||
|
||||
usage = data.get("usage", {})
|
||||
input_tokens = usage.get("prompt_tokens", 0)
|
||||
output_tokens = usage.get("completion_tokens", 0)
|
||||
|
||||
logger.info(
|
||||
f"[yuanbao] query='{query[:50]}...' brand={has_brand} "
|
||||
f"competitor={has_comp} time={elapsed_ms}ms"
|
||||
|
|
@ -87,5 +103,7 @@ class YuanbaoAdapter(AIEngineAdapter):
|
|||
competitor_contexts=comp_ctx,
|
||||
response_time_ms=elapsed_ms,
|
||||
timestamp=datetime.now(UTC),
|
||||
metadata={"model": data.get("model", self._model), "usage": data.get("usage")},
|
||||
metadata={"model": data.get("model", self._model), "usage": usage},
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,17 @@
|
|||
import base64
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.repositories.api_key_repository import APIKeyRepository
|
||||
from app.services.key_encryption import KeyEncryption, get_key_encryption
|
||||
from app.services.key_verifier import KeyStatus, KeyVerifierFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -13,14 +21,6 @@ class KeySource(str, Enum):
|
|||
ENV = "env"
|
||||
|
||||
|
||||
class KeyStatus(str, Enum):
|
||||
ACTIVE = "active"
|
||||
INVALID = "invalid"
|
||||
EXPIRED = "expired"
|
||||
RATE_LIMITED = "rate_limited"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
@dataclass
|
||||
class APIKeyConfig:
|
||||
engine_type: str
|
||||
|
|
@ -116,13 +116,17 @@ class APIKeyManager:
|
|||
async def verify_key(self, engine_type: str, api_key: str) -> KeyStatus:
|
||||
if not api_key or len(api_key) < 10:
|
||||
return KeyStatus.INVALID
|
||||
return KeyStatus.ACTIVE
|
||||
try:
|
||||
return await KeyVerifierFactory.verify(engine_type, api_key)
|
||||
except Exception as e:
|
||||
logger.warning(f"[api_key_manager] Key verification failed: {e}")
|
||||
return KeyStatus.UNKNOWN
|
||||
|
||||
def _encrypt(self, plaintext: str) -> str:
|
||||
return base64.b64encode(plaintext.encode()).decode()
|
||||
return get_key_encryption().encrypt(plaintext)
|
||||
|
||||
def _decrypt(self, ciphertext: str) -> str:
|
||||
return base64.b64decode(ciphertext.encode()).decode()
|
||||
return get_key_encryption().decrypt(ciphertext)
|
||||
|
||||
def _mask_key(self, key: str) -> str:
|
||||
if len(key) <= 8:
|
||||
|
|
@ -134,3 +138,115 @@ class APIKeyManager:
|
|||
key = os.getenv(env_var, "")
|
||||
if key:
|
||||
self.add_key(engine, key, source=KeySource.ENV, priority=0)
|
||||
|
||||
async def add_key_async(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
engine_type: str,
|
||||
api_key: str,
|
||||
source: KeySource = KeySource.USER,
|
||||
user_id: uuid.UUID | None = None,
|
||||
priority: int = 0,
|
||||
) -> APIKeyConfig:
|
||||
config = APIKeyConfig(
|
||||
engine_type=engine_type,
|
||||
key_source=source,
|
||||
encrypted_key=self._encrypt(api_key),
|
||||
key_hint=self._mask_key(api_key),
|
||||
status=KeyStatus.UNKNOWN,
|
||||
priority=priority,
|
||||
user_id=str(user_id) if user_id else None,
|
||||
)
|
||||
repository = APIKeyRepository(session)
|
||||
await repository.create(
|
||||
user_id=user_id,
|
||||
engine_type=engine_type,
|
||||
encrypted_key=config.encrypted_key,
|
||||
key_hint=config.key_hint,
|
||||
key_source=source.value,
|
||||
status=config.status.value,
|
||||
priority=priority,
|
||||
)
|
||||
return config
|
||||
|
||||
async def get_key_async(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
engine_type: str,
|
||||
user_id: uuid.UUID | None = None,
|
||||
) -> str | None:
|
||||
repository = APIKeyRepository(session)
|
||||
keys = await repository.get_by_user(user_id, engine_type)
|
||||
for key in keys:
|
||||
source = KeySource(key.key_source)
|
||||
status = KeyStatus(key.status)
|
||||
if user_id:
|
||||
if source == KeySource.USER and status in self._USABLE_STATUSES:
|
||||
return self._decrypt(key.encrypted_key)
|
||||
if source in self._FALLBACK_SOURCES and status in self._USABLE_STATUSES:
|
||||
return self._decrypt(key.encrypted_key)
|
||||
return None
|
||||
|
||||
async def get_any_available_key_async(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
engine_type: str,
|
||||
) -> str | None:
|
||||
repository = APIKeyRepository(session)
|
||||
keys = await repository.get_by_user(None, engine_type)
|
||||
for key in keys:
|
||||
status = KeyStatus(key.status)
|
||||
if status in self._USABLE_STATUSES:
|
||||
return self._decrypt(key.encrypted_key)
|
||||
return None
|
||||
|
||||
async def remove_key_async(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
engine_type: str,
|
||||
key_hint: str,
|
||||
) -> bool:
|
||||
repository = APIKeyRepository(session)
|
||||
keys = await repository.get_by_user(user_id, engine_type)
|
||||
for key in keys:
|
||||
if key.key_hint == key_hint:
|
||||
deleted = await repository.delete(key.id)
|
||||
return deleted
|
||||
return False
|
||||
|
||||
async def list_keys_async(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
user_id: uuid.UUID | None = None,
|
||||
engine_type: str | None = None,
|
||||
) -> list[APIKeyConfig]:
|
||||
repository = APIKeyRepository(session)
|
||||
if user_id:
|
||||
keys = await repository.get_by_user(user_id, engine_type)
|
||||
else:
|
||||
keys = await repository.list_all(engine_type=engine_type)
|
||||
return [
|
||||
APIKeyConfig(
|
||||
engine_type=k.engine_type,
|
||||
key_source=KeySource(k.key_source),
|
||||
encrypted_key=k.encrypted_key,
|
||||
key_hint=k.key_hint,
|
||||
status=KeyStatus(k.status),
|
||||
priority=k.priority,
|
||||
last_verified_at=str(k.last_verified_at) if k.last_verified_at else None,
|
||||
created_at=str(k.created_at) if k.created_at else None,
|
||||
user_id=str(k.user_id) if k.user_id else None,
|
||||
)
|
||||
for k in keys
|
||||
]
|
||||
|
||||
async def update_key_status_async(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
key_id: uuid.UUID,
|
||||
status: KeyStatus,
|
||||
) -> bool:
|
||||
repository = APIKeyRepository(session)
|
||||
updated = await repository.update(key_id, status=status.value)
|
||||
return updated is not None
|
||||
|
|
|
|||
|
|
@ -5,11 +5,10 @@ from datetime import datetime, timezone
|
|||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.ai_engines import _build_adapters
|
||||
from app.models.brand import Brand
|
||||
from app.models.detection_task import VALID_FREQUENCIES, DetectionTask
|
||||
from app.services.ai_engine.base import EngineType
|
||||
from app.services.ai_engine.batch_query import BatchQueryService
|
||||
from app.services.ai_engine.batch_query import BatchQueryService, _build_adapters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,37 @@
|
|||
from app.services.smart_router import ENGINE_COST_PROFILES, SmartRouter
|
||||
from app.services.api_key_manager import APIKeyManager
|
||||
|
||||
|
||||
class EngineSelector:
|
||||
def __init__(self, key_manager: APIKeyManager):
|
||||
self.key_manager = key_manager
|
||||
self.router = SmartRouter(key_manager=key_manager)
|
||||
|
||||
def select_engines(
|
||||
self,
|
||||
max_engines: int = 5,
|
||||
prefer_domestic: bool = True,
|
||||
min_cost_tier: str | None = None,
|
||||
) -> list[str]:
|
||||
engines = self.router.select_engines(max_engines, prefer_domestic)
|
||||
|
||||
if min_cost_tier:
|
||||
tier_order = ["free", "low_cost", "mid_cost", "high_cost"]
|
||||
if min_cost_tier in tier_order:
|
||||
min_idx = tier_order.index(min_cost_tier)
|
||||
filtered = []
|
||||
for engine in engines:
|
||||
profile = ENGINE_COST_PROFILES.get(engine)
|
||||
if profile and tier_order.index(profile.cost_tier.value) >= min_idx:
|
||||
filtered.append(engine)
|
||||
return filtered
|
||||
|
||||
return engines
|
||||
|
||||
def get_best_value_engine(self) -> str | None:
|
||||
available = self.router.get_available_engines()
|
||||
for engine in available:
|
||||
profile = ENGINE_COST_PROFILES.get(engine)
|
||||
if profile and profile.cost_tier.value in ["free", "low_cost"]:
|
||||
return engine
|
||||
return available[0] if available else None
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
import os
|
||||
import base64
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
|
||||
|
||||
class KeyEncryption:
|
||||
def __init__(self, encryption_key: str | None = None):
|
||||
self._fernet = self._create_fernet(encryption_key)
|
||||
|
||||
def _create_fernet(self, key: str | None) -> Fernet:
|
||||
if not key:
|
||||
key = os.getenv("API_KEY_ENCRYPTION_KEY", "")
|
||||
|
||||
if not key:
|
||||
return Fernet(Fernet.generate_key())
|
||||
|
||||
try:
|
||||
return Fernet(key.encode())
|
||||
except Exception:
|
||||
kdf = PBKDF2HMAC(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=b"geo-platform-salt",
|
||||
iterations=100000,
|
||||
)
|
||||
derived_key = base64.urlsafe_b64encode(kdf.derive(key.encode()))
|
||||
return Fernet(derived_key)
|
||||
|
||||
def encrypt(self, plaintext: str) -> str:
|
||||
if not plaintext:
|
||||
raise ValueError("Cannot encrypt empty string")
|
||||
return self._fernet.encrypt(plaintext.encode()).decode()
|
||||
|
||||
def decrypt(self, ciphertext: str) -> str:
|
||||
if not ciphertext:
|
||||
raise ValueError("Cannot decrypt empty string")
|
||||
try:
|
||||
return self._fernet.decrypt(ciphertext.encode()).decode()
|
||||
except InvalidToken:
|
||||
raise ValueError("Invalid ciphertext or key")
|
||||
|
||||
|
||||
_key_encryption: KeyEncryption | None = None
|
||||
|
||||
|
||||
def get_key_encryption() -> KeyEncryption:
|
||||
global _key_encryption
|
||||
if _key_encryption is None:
|
||||
_key_encryption = KeyEncryption()
|
||||
return _key_encryption
|
||||
|
||||
|
||||
def reset_key_encryption() -> None:
|
||||
global _key_encryption
|
||||
_key_encryption = None
|
||||
|
|
@ -0,0 +1,325 @@
|
|||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KeyStatus(str, Enum):
|
||||
ACTIVE = "active"
|
||||
INVALID = "invalid"
|
||||
EXPIRED = "expired"
|
||||
RATE_LIMITED = "rate_limited"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
_VERIFY_TIMEOUT = 15.0
|
||||
|
||||
|
||||
class KeyVerifier(ABC):
|
||||
@abstractmethod
|
||||
async def verify(self, api_key: str) -> KeyStatus:
|
||||
pass
|
||||
|
||||
|
||||
class DefaultKeyVerifier(KeyVerifier):
|
||||
async def verify(self, api_key: str) -> KeyStatus:
|
||||
return KeyStatus.UNKNOWN
|
||||
|
||||
|
||||
class ChatGPTKeyVerifier(KeyVerifier):
|
||||
_BASE_URL = "https://api.openai.com/v1"
|
||||
|
||||
async def verify(self, api_key: str) -> KeyStatus:
|
||||
proxy = os.getenv("OPENAI_PROXY") or os.getenv("HTTPS_PROXY")
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(connect=10.0, read=_VERIFY_TIMEOUT, write=10.0, pool=10.0),
|
||||
proxy=proxy,
|
||||
) as client:
|
||||
response = await client.post(
|
||||
f"{self._BASE_URL}/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"model": "gpt-4o-mini",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"max_tokens": 5,
|
||||
},
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return KeyStatus.ACTIVE
|
||||
elif response.status_code == 401:
|
||||
return KeyStatus.INVALID
|
||||
elif response.status_code == 429:
|
||||
return KeyStatus.RATE_LIMITED
|
||||
elif response.status_code == 403:
|
||||
return KeyStatus.EXPIRED
|
||||
else:
|
||||
logger.warning(f"[chatgpt] Unexpected status {response.status_code}")
|
||||
return KeyStatus.UNKNOWN
|
||||
except httpx.TimeoutException:
|
||||
logger.warning("[chatgpt] Verification timeout")
|
||||
return KeyStatus.UNKNOWN
|
||||
except httpx.ConnectError as e:
|
||||
logger.warning(f"[chatgpt] Connection error: {e}")
|
||||
return KeyStatus.UNKNOWN
|
||||
except Exception as e:
|
||||
logger.error(f"[chatgpt] Verification error: {e}")
|
||||
return KeyStatus.UNKNOWN
|
||||
|
||||
|
||||
class DeepSeekKeyVerifier(KeyVerifier):
|
||||
_BASE_URL = "https://api.deepseek.com/v1"
|
||||
|
||||
async def verify(self, api_key: str) -> KeyStatus:
|
||||
proxy = os.getenv("DEEPSEEK_PROXY") or os.getenv("HTTPS_PROXY")
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(connect=10.0, read=_VERIFY_TIMEOUT, write=10.0, pool=10.0),
|
||||
proxy=proxy,
|
||||
) as client:
|
||||
response = await client.post(
|
||||
f"{self._BASE_URL}/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"model": "deepseek-chat",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"max_tokens": 5,
|
||||
},
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return KeyStatus.ACTIVE
|
||||
elif response.status_code == 401:
|
||||
return KeyStatus.INVALID
|
||||
elif response.status_code == 429:
|
||||
return KeyStatus.RATE_LIMITED
|
||||
elif response.status_code == 403:
|
||||
return KeyStatus.EXPIRED
|
||||
else:
|
||||
logger.warning(f"[deepseek] Unexpected status {response.status_code}")
|
||||
return KeyStatus.UNKNOWN
|
||||
except httpx.TimeoutException:
|
||||
logger.warning("[deepseek] Verification timeout")
|
||||
return KeyStatus.UNKNOWN
|
||||
except httpx.ConnectError as e:
|
||||
logger.warning(f"[deepseek] Connection error: {e}")
|
||||
return KeyStatus.UNKNOWN
|
||||
except Exception as e:
|
||||
logger.error(f"[deepseek] Verification error: {e}")
|
||||
return KeyStatus.UNKNOWN
|
||||
|
||||
|
||||
class KimiKeyVerifier(KeyVerifier):
|
||||
_BASE_URL = "https://api.moonshot.cn/v1"
|
||||
|
||||
async def verify(self, api_key: str) -> KeyStatus:
|
||||
proxy = os.getenv("HTTPS_PROXY")
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(connect=10.0, read=_VERIFY_TIMEOUT, write=10.0, pool=10.0),
|
||||
proxy=proxy,
|
||||
) as client:
|
||||
response = await client.post(
|
||||
f"{self._BASE_URL}/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"model": "moonshot-v1-8k",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"max_tokens": 5,
|
||||
},
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return KeyStatus.ACTIVE
|
||||
elif response.status_code == 401:
|
||||
return KeyStatus.INVALID
|
||||
elif response.status_code == 429:
|
||||
return KeyStatus.RATE_LIMITED
|
||||
elif response.status_code == 403:
|
||||
return KeyStatus.EXPIRED
|
||||
else:
|
||||
logger.warning(f"[kimi] Unexpected status {response.status_code}")
|
||||
return KeyStatus.UNKNOWN
|
||||
except httpx.TimeoutException:
|
||||
logger.warning("[kimi] Verification timeout")
|
||||
return KeyStatus.UNKNOWN
|
||||
except httpx.ConnectError as e:
|
||||
logger.warning(f"[kimi] Connection error: {e}")
|
||||
return KeyStatus.UNKNOWN
|
||||
except Exception as e:
|
||||
logger.error(f"[kimi] Verification error: {e}")
|
||||
return KeyStatus.UNKNOWN
|
||||
|
||||
|
||||
class QwenKeyVerifier(KeyVerifier):
|
||||
_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
|
||||
async def verify(self, api_key: str) -> KeyStatus:
|
||||
proxy = os.getenv("HTTPS_PROXY")
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(connect=10.0, read=_VERIFY_TIMEOUT, write=10.0, pool=10.0),
|
||||
proxy=proxy,
|
||||
) as client:
|
||||
response = await client.post(
|
||||
f"{self._BASE_URL}/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"model": "qwen-plus",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"max_tokens": 5,
|
||||
},
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return KeyStatus.ACTIVE
|
||||
elif response.status_code == 401:
|
||||
return KeyStatus.INVALID
|
||||
elif response.status_code == 429:
|
||||
return KeyStatus.RATE_LIMITED
|
||||
elif response.status_code == 403:
|
||||
return KeyStatus.EXPIRED
|
||||
else:
|
||||
logger.warning(f"[qwen] Unexpected status {response.status_code}")
|
||||
return KeyStatus.UNKNOWN
|
||||
except httpx.TimeoutException:
|
||||
logger.warning("[qwen] Verification timeout")
|
||||
return KeyStatus.UNKNOWN
|
||||
except httpx.ConnectError as e:
|
||||
logger.warning(f"[qwen] Connection error: {e}")
|
||||
return KeyStatus.UNKNOWN
|
||||
except Exception as e:
|
||||
logger.error(f"[qwen] Verification error: {e}")
|
||||
return KeyStatus.UNKNOWN
|
||||
|
||||
|
||||
class PerplexityKeyVerifier(KeyVerifier):
|
||||
_BASE_URL = "https://api.perplexity.ai"
|
||||
|
||||
async def verify(self, api_key: str) -> KeyStatus:
|
||||
proxy = os.getenv("HTTPS_PROXY")
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(connect=10.0, read=_VERIFY_TIMEOUT, write=10.0, pool=10.0),
|
||||
proxy=proxy,
|
||||
) as client:
|
||||
response = await client.post(
|
||||
f"{self._BASE_URL}/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"model": "llama-3.1-sonar-small-128k-online",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"max_tokens": 5,
|
||||
},
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return KeyStatus.ACTIVE
|
||||
elif response.status_code == 401:
|
||||
return KeyStatus.INVALID
|
||||
elif response.status_code == 429:
|
||||
return KeyStatus.RATE_LIMITED
|
||||
elif response.status_code == 403:
|
||||
return KeyStatus.EXPIRED
|
||||
else:
|
||||
logger.warning(f"[perplexity] Unexpected status {response.status_code}")
|
||||
return KeyStatus.UNKNOWN
|
||||
except httpx.TimeoutException:
|
||||
logger.warning("[perplexity] Verification timeout")
|
||||
return KeyStatus.UNKNOWN
|
||||
except httpx.ConnectError as e:
|
||||
logger.warning(f"[perplexity] Connection error: {e}")
|
||||
return KeyStatus.UNKNOWN
|
||||
except Exception as e:
|
||||
logger.error(f"[perplexity] Verification error: {e}")
|
||||
return KeyStatus.UNKNOWN
|
||||
|
||||
|
||||
class GeminiKeyVerifier(KeyVerifier):
|
||||
_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
|
||||
|
||||
async def verify(self, api_key: str) -> KeyStatus:
|
||||
proxy = os.getenv("HTTPS_PROXY")
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(connect=10.0, read=_VERIFY_TIMEOUT, write=10.0, pool=10.0),
|
||||
proxy=proxy,
|
||||
) as client:
|
||||
response = await client.get(
|
||||
f"{self._BASE_URL}/models",
|
||||
headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
|
||||
params={"key": api_key},
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return KeyStatus.ACTIVE
|
||||
elif response.status_code == 403:
|
||||
return KeyStatus.INVALID
|
||||
elif response.status_code == 429:
|
||||
return KeyStatus.RATE_LIMITED
|
||||
else:
|
||||
logger.warning(f"[gemini] Unexpected status {response.status_code}")
|
||||
return KeyStatus.UNKNOWN
|
||||
except httpx.TimeoutException:
|
||||
logger.warning("[gemini] Verification timeout")
|
||||
return KeyStatus.UNKNOWN
|
||||
except httpx.ConnectError as e:
|
||||
logger.warning(f"[gemini] Connection error: {e}")
|
||||
return KeyStatus.UNKNOWN
|
||||
except Exception as e:
|
||||
logger.error(f"[gemini] Verification error: {e}")
|
||||
return KeyStatus.UNKNOWN
|
||||
|
||||
|
||||
class WenxinKeyVerifier(KeyVerifier):
|
||||
async def verify(self, api_key: str) -> KeyStatus:
|
||||
return KeyStatus.UNKNOWN
|
||||
|
||||
|
||||
class DoubaoKeyVerifier(KeyVerifier):
|
||||
async def verify(self, api_key: str) -> KeyStatus:
|
||||
return KeyStatus.UNKNOWN
|
||||
|
||||
|
||||
class YuanbaoKeyVerifier(KeyVerifier):
|
||||
async def verify(self, api_key: str) -> KeyStatus:
|
||||
return KeyStatus.UNKNOWN
|
||||
|
||||
|
||||
class KeyVerifierFactory:
|
||||
_VERIFIERS = {
|
||||
"chatgpt": ChatGPTKeyVerifier,
|
||||
"perplexity": PerplexityKeyVerifier,
|
||||
"kimi": KimiKeyVerifier,
|
||||
"wenxin": WenxinKeyVerifier,
|
||||
"doubao": DoubaoKeyVerifier,
|
||||
"deepseek": DeepSeekKeyVerifier,
|
||||
"qwen": QwenKeyVerifier,
|
||||
"gemini": GeminiKeyVerifier,
|
||||
"yuanbao": YuanbaoKeyVerifier,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_verifier(cls, engine_type: str) -> KeyVerifier:
|
||||
verifier_class = cls._VERIFIERS.get(engine_type)
|
||||
if not verifier_class:
|
||||
return DefaultKeyVerifier()
|
||||
return verifier_class()
|
||||
|
||||
@classmethod
|
||||
async def verify(cls, engine_type: str, api_key: str) -> KeyStatus:
|
||||
verifier = cls.get_verifier(engine_type)
|
||||
return await verifier.verify(api_key)
|
||||
|
|
@ -2,6 +2,10 @@ from __future__ import annotations
|
|||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.api_key_manager import APIKeyManager
|
||||
|
||||
|
||||
class CostTier(str, Enum):
|
||||
|
|
@ -43,14 +47,37 @@ def _get_profile(engine: str) -> EngineCostProfile | None:
|
|||
|
||||
|
||||
class SmartRouter:
|
||||
def __init__(self, available_engines: list[str] | None = None, user_engines: list[str] | None = None):
|
||||
def __init__(
|
||||
self,
|
||||
available_engines: list[str] | None = None,
|
||||
user_engines: list[str] | None = None,
|
||||
key_manager: APIKeyManager | None = None,
|
||||
):
|
||||
self.available_engines = available_engines or list(ENGINE_COST_PROFILES.keys())
|
||||
self._user_engine_set = set(user_engines or [])
|
||||
self._key_manager = key_manager
|
||||
|
||||
@property
|
||||
def user_engines(self) -> list[str]:
|
||||
return list(self._user_engine_set)
|
||||
|
||||
def set_key_manager(self, key_manager: APIKeyManager) -> None:
|
||||
self._key_manager = key_manager
|
||||
|
||||
def _filter_by_available_keys(self, engines: list[str]) -> list[str]:
|
||||
if not self._key_manager:
|
||||
return engines
|
||||
available = []
|
||||
for engine in engines:
|
||||
key = self._key_manager.get_any_available_key(engine)
|
||||
if key:
|
||||
available.append(engine)
|
||||
return available
|
||||
|
||||
def get_available_engines(self) -> list[str]:
|
||||
all_engines = list(ENGINE_COST_PROFILES.keys())
|
||||
return self._filter_by_available_keys(all_engines)
|
||||
|
||||
def select_engines(self, max_engines: int = 5, prefer_domestic: bool = True) -> list[str]:
|
||||
tiers: dict[CostTier, list[str]] = {tier: [] for tier in CostTier}
|
||||
for engine in self.available_engines:
|
||||
|
|
@ -74,9 +101,12 @@ class SmartRouter:
|
|||
for tier in _TIER_ORDER:
|
||||
for engine in tiers[tier]:
|
||||
if len(selected) >= max_engines:
|
||||
return selected
|
||||
if engine not in selected:
|
||||
break
|
||||
key = self._key_manager.get_any_available_key(engine) if self._key_manager else True
|
||||
if key and engine not in selected:
|
||||
selected.append(engine)
|
||||
if len(selected) >= max_engines:
|
||||
break
|
||||
|
||||
return selected[:max_engines]
|
||||
|
||||
|
|
@ -94,7 +124,11 @@ class SmartRouter:
|
|||
return {"total_cost": round(total_cost, 6), "per_engine": details}
|
||||
|
||||
def _engines_by_tier(self, tier: CostTier) -> list[str]:
|
||||
return [e for e in self.available_engines if (p := _get_profile(e)) is not None and p.cost_tier == tier]
|
||||
tier_engines = [e for e in self.available_engines if (p := _get_profile(e)) is not None and p.cost_tier == tier]
|
||||
return self._filter_by_available_keys(tier_engines)
|
||||
|
||||
def get_engines_by_cost_tier(self, tier: CostTier) -> list[str]:
|
||||
return self._engines_by_tier(tier)
|
||||
|
||||
def _engines_by_requires_key(self) -> list[str]:
|
||||
return [e for e in self.available_engines if (p := _get_profile(e)) is not None and p.requires_own_key]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,38 @@
|
|||
from app.services.usage_tracker import UsageTracker
|
||||
from app.services.smart_router import ENGINE_COST_PROFILES
|
||||
|
||||
|
||||
class UsageRecorder:
|
||||
def __init__(self, tracker: UsageTracker):
|
||||
self.tracker = tracker
|
||||
|
||||
def calculate_cost(self, engine_type: str, input_tokens: int, output_tokens: int) -> float:
|
||||
profile = ENGINE_COST_PROFILES.get(engine_type)
|
||||
if not profile:
|
||||
return 0.0
|
||||
|
||||
input_cost = (input_tokens / 1_000_000) * profile.input_price_per_million
|
||||
output_cost = (output_tokens / 1_000_000) * profile.output_price_per_million
|
||||
return round(input_cost + output_cost, 6)
|
||||
|
||||
def record(
|
||||
self,
|
||||
user_id: str,
|
||||
brand_id: str | None,
|
||||
engine_type: str,
|
||||
query: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
metadata: dict | None = None,
|
||||
) -> None:
|
||||
cost = self.calculate_cost(engine_type, input_tokens, output_tokens)
|
||||
self.tracker.record(
|
||||
user_id=user_id,
|
||||
brand_id=brand_id or "",
|
||||
engine_type=engine_type,
|
||||
query=query,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cost=cost,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
|
@ -1,9 +1,15 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.repositories.usage_repository import UsageRepository
|
||||
from app.models.usage_record import UsageRecord as UsageRecordModel
|
||||
|
||||
_QUOTA_WARNING_PCT = 80.0
|
||||
_QUOTA_EXCEEDED_PCT = 100.0
|
||||
|
||||
|
|
@ -48,6 +54,18 @@ def _aggregate_by_engine(records: list[UsageRecord]) -> dict[str, dict[str, Any]
|
|||
return result
|
||||
|
||||
|
||||
def _aggregate_by_day(records: list[UsageRecord]) -> dict[str, dict[str, Any]]:
|
||||
result: dict[str, dict[str, Any]] = {}
|
||||
for r in records:
|
||||
day_key = r.timestamp.strftime("%Y-%m-%d")
|
||||
bucket = result.setdefault(day_key, {"queries": 0, "input_tokens": 0, "output_tokens": 0, "cost": 0.0})
|
||||
bucket["queries"] += 1
|
||||
bucket["input_tokens"] += r.input_tokens
|
||||
bucket["output_tokens"] += r.output_tokens
|
||||
bucket["cost"] += r.cost
|
||||
return result
|
||||
|
||||
|
||||
def _compute_cutoff(period: str, now: datetime) -> datetime:
|
||||
days = _PERIOD_CUTOFF_DAYS.get(period, 30)
|
||||
if days == 0:
|
||||
|
|
@ -64,8 +82,10 @@ def _quota_status(usage_pct: float) -> str:
|
|||
|
||||
|
||||
class UsageTracker:
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, session: AsyncSession | None = None) -> None:
|
||||
self._records: list[UsageRecord] = []
|
||||
self._session = session
|
||||
self._repository = UsageRepository(session) if session else None
|
||||
|
||||
def record(
|
||||
self,
|
||||
|
|
@ -119,7 +139,7 @@ class UsageTracker:
|
|||
total_output_tokens=sum(r.output_tokens for r in filtered),
|
||||
total_cost=round(sum(r.cost for r in filtered), 4),
|
||||
by_engine=_aggregate_by_engine(filtered),
|
||||
by_day={},
|
||||
by_day=_aggregate_by_day(filtered),
|
||||
)
|
||||
|
||||
def check_quota(self, user_id: str, monthly_limit: float = 100.0) -> dict:
|
||||
|
|
@ -131,3 +151,51 @@ class UsageTracker:
|
|||
"usage_percentage": round(usage_pct, 1),
|
||||
"status": _quota_status(usage_pct),
|
||||
}
|
||||
|
||||
async def record_async(
|
||||
self,
|
||||
user_id: str | uuid.UUID,
|
||||
brand_id: str | uuid.UUID | None,
|
||||
engine_type: str,
|
||||
query: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cost: float,
|
||||
extra_data: dict | None = None,
|
||||
) -> UsageRecordModel:
|
||||
if not self._repository:
|
||||
raise RuntimeError("UsageTracker not initialized with AsyncSession")
|
||||
|
||||
user_uuid = uuid.UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
brand_uuid = uuid.UUID(brand_id) if (brand_id and isinstance(brand_id, str)) else brand_id
|
||||
|
||||
data = {
|
||||
"user_id": user_uuid,
|
||||
"brand_id": brand_uuid,
|
||||
"engine_type": engine_type,
|
||||
"query": query,
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"cost": cost,
|
||||
"extra_data": extra_data or {},
|
||||
}
|
||||
return await self._repository.create(data)
|
||||
|
||||
async def get_summary_async(
|
||||
self,
|
||||
user_id: str | uuid.UUID,
|
||||
period: str = "month",
|
||||
brand_id: str | uuid.UUID | None = None,
|
||||
) -> dict:
|
||||
if not self._repository:
|
||||
raise RuntimeError("UsageTracker not initialized with AsyncSession")
|
||||
return await self._repository.get_summary(user_id, period, brand_id)
|
||||
|
||||
async def check_quota_async(
|
||||
self,
|
||||
user_id: str | uuid.UUID,
|
||||
monthly_limit: float = 100.0,
|
||||
) -> dict:
|
||||
if not self._repository:
|
||||
raise RuntimeError("UsageTracker not initialized with AsyncSession")
|
||||
return await self._repository.check_quota(user_id, monthly_limit)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,37 @@
|
|||
"""User quota service for plan-based monthly limits."""
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.repositories.usage_repository import UsageRepository
|
||||
|
||||
PLAN_MONTHLY_LIMITS = {
|
||||
"free": 10.0,
|
||||
"basic": 50.0,
|
||||
"pro": 200.0,
|
||||
"enterprise": 1000.0,
|
||||
}
|
||||
|
||||
|
||||
class UserQuotaService:
|
||||
"""Service for checking user quota based on subscription plan."""
|
||||
|
||||
def __init__(self, session: AsyncSession | None = None):
|
||||
self._session = session
|
||||
self._repository = UsageRepository(session) if session else None
|
||||
|
||||
def get_monthly_limit(self, user_plan: str) -> float:
|
||||
"""Get monthly cost limit based on user plan."""
|
||||
return PLAN_MONTHLY_LIMITS.get(user_plan, PLAN_MONTHLY_LIMITS["free"])
|
||||
|
||||
async def check_quota_with_plan(
|
||||
self,
|
||||
user_id: str,
|
||||
user_plan: str,
|
||||
) -> dict:
|
||||
"""Check quota using the limit from user plan."""
|
||||
if not self._repository:
|
||||
raise RuntimeError("UserQuotaService not initialized with AsyncSession")
|
||||
|
||||
monthly_limit = self.get_monthly_limit(user_plan)
|
||||
return await self._repository.check_quota(user_id, monthly_limit=monthly_limit)
|
||||
|
|
@ -1,89 +0,0 @@
|
|||
# test_content_pipeline.py
|
||||
import pytest
|
||||
|
||||
# 导入实际的 ContentPipeline 实现
|
||||
from app.services.content.content_pipeline import ContentPipeline
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_complete_run():
|
||||
"""完整Pipeline执行"""
|
||||
pipeline = ContentPipeline()
|
||||
request = {
|
||||
"content": "这是一篇测试文章内容",
|
||||
"title": "测试标题",
|
||||
"platform": "zhihu",
|
||||
"optimize_for": ["validation", "sensitive", "seo"]
|
||||
}
|
||||
result = await pipeline.run(request)
|
||||
|
||||
assert result.stages is not None
|
||||
assert len(result.stages) > 0
|
||||
assert result.outputs is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_with_validation_fail():
|
||||
"""校验失败中断"""
|
||||
pipeline = ContentPipeline()
|
||||
request = {
|
||||
"content": "内容",
|
||||
"title": "这个标题太长了超过了三十个字符的限制了哈哈哈啊",
|
||||
"platform": "wechat",
|
||||
"optimize_for": ["validation"]
|
||||
}
|
||||
result = await pipeline.run(request)
|
||||
|
||||
# 校验失败时不应继续执行后续阶段
|
||||
validation_stage = next((s for s in result.stages if s.name == "validation"), None)
|
||||
assert validation_stage is not None
|
||||
assert validation_stage.passed == False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_multi_platform():
|
||||
"""多平台适配"""
|
||||
pipeline = ContentPipeline()
|
||||
|
||||
zhihu_result = await pipeline.run({
|
||||
"content": "<p>测试内容</p><a href='http://baidu.com'>外部链接</a>",
|
||||
"title": "测试标题",
|
||||
"platform": "zhihu"
|
||||
})
|
||||
|
||||
wechat_result = await pipeline.run({
|
||||
"content": "<p>测试内容</p><a href='http://baidu.com'>外部链接</a>",
|
||||
"title": "测试标题",
|
||||
"platform": "wechat"
|
||||
})
|
||||
|
||||
# 不同平台应产生不同的优化结果
|
||||
assert zhihu_result.outputs != wechat_result.outputs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_stage_results():
|
||||
"""各阶段结果记录"""
|
||||
pipeline = ContentPipeline()
|
||||
result = await pipeline.run({
|
||||
"content": "内容",
|
||||
"title": "标题",
|
||||
"platform": "zhihu"
|
||||
})
|
||||
|
||||
# 检查每个阶段的结果
|
||||
for stage in result.stages:
|
||||
assert stage.name is not None
|
||||
assert hasattr(stage, 'passed') or hasattr(stage, 'result')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_error_handling():
|
||||
"""错误处理"""
|
||||
pipeline = ContentPipeline()
|
||||
|
||||
# 无效平台应返回错误
|
||||
try:
|
||||
result = await pipeline.run({
|
||||
"content": "内容",
|
||||
"title": "标题",
|
||||
"platform": "invalid_platform"
|
||||
})
|
||||
assert result.error is not None
|
||||
except ValueError as e:
|
||||
assert "不支持的平台" in str(e)
|
||||
|
|
@ -0,0 +1,306 @@
|
|||
"""Tests for APIKey model."""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select, and_
|
||||
|
||||
from app.models.api_key import APIKey
|
||||
|
||||
|
||||
class TestAPIKeyModel:
|
||||
"""Test cases for APIKey model."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_create(self, async_session, test_user):
|
||||
"""Test creating a new API key."""
|
||||
api_key = APIKey(
|
||||
id=uuid.uuid4(),
|
||||
user_id=test_user.id,
|
||||
engine_type="chatgpt",
|
||||
encrypted_key="encrypted_test_key",
|
||||
key_hint="sk-...abc",
|
||||
key_source="user",
|
||||
status="active",
|
||||
priority=0,
|
||||
)
|
||||
async_session.add(api_key)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(api_key)
|
||||
|
||||
assert api_key.id is not None
|
||||
assert api_key.user_id == test_user.id
|
||||
assert api_key.engine_type == "chatgpt"
|
||||
assert api_key.encrypted_key == "encrypted_test_key"
|
||||
assert api_key.key_hint == "sk-...abc"
|
||||
assert api_key.key_source == "user"
|
||||
assert api_key.status == "active"
|
||||
assert api_key.priority == 0
|
||||
assert api_key.created_at is not None
|
||||
assert api_key.updated_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_default_values(self, async_session, test_user):
|
||||
"""Test API key default values."""
|
||||
api_key = APIKey(
|
||||
user_id=test_user.id,
|
||||
engine_type="kimi",
|
||||
encrypted_key="encrypted_kimi_key",
|
||||
key_hint="sk-...xyz",
|
||||
)
|
||||
async_session.add(api_key)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(api_key)
|
||||
|
||||
assert api_key.key_source == "user"
|
||||
assert api_key.status == "active"
|
||||
assert api_key.priority == 0
|
||||
assert api_key.last_verified_at is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_fields(self, async_session, test_user):
|
||||
"""Test API key field validation and constraints."""
|
||||
now = datetime.now(timezone.utc)
|
||||
api_key = APIKey(
|
||||
user_id=test_user.id,
|
||||
engine_type="deepseek",
|
||||
encrypted_key="encrypted_deepseek_key_data",
|
||||
key_hint="sk-...def",
|
||||
key_source="system",
|
||||
status="active",
|
||||
priority=10,
|
||||
last_verified_at=now,
|
||||
)
|
||||
async_session.add(api_key)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(api_key)
|
||||
|
||||
assert api_key.engine_type == "deepseek"
|
||||
assert api_key.encrypted_key == "encrypted_deepseek_key_data"
|
||||
assert api_key.key_hint == "sk-...def"
|
||||
assert api_key.key_source == "system"
|
||||
assert api_key.status == "active"
|
||||
assert api_key.priority == 10
|
||||
assert api_key.last_verified_at is not None
|
||||
assert api_key.last_verified_at.replace(tzinfo=None) == now.replace(tzinfo=None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_query_by_id(self, async_session, test_user):
|
||||
"""Test querying API key by ID."""
|
||||
key_id = uuid.uuid4()
|
||||
api_key = APIKey(
|
||||
id=key_id,
|
||||
user_id=test_user.id,
|
||||
engine_type="gemini",
|
||||
encrypted_key="encrypted_gemini_key",
|
||||
key_hint="AIza...123",
|
||||
)
|
||||
async_session.add(api_key)
|
||||
await async_session.commit()
|
||||
|
||||
result = await async_session.execute(
|
||||
select(APIKey).where(APIKey.id == key_id)
|
||||
)
|
||||
fetched_key = result.scalar_one()
|
||||
|
||||
assert fetched_key is not None
|
||||
assert fetched_key.id == key_id
|
||||
assert fetched_key.engine_type == "gemini"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_query_by_user_id(self, async_session, test_user):
|
||||
"""Test querying API keys by user ID."""
|
||||
key1 = APIKey(
|
||||
user_id=test_user.id,
|
||||
engine_type="chatgpt",
|
||||
encrypted_key="encrypted_key_1",
|
||||
key_hint="sk-...1",
|
||||
)
|
||||
key2 = APIKey(
|
||||
user_id=test_user.id,
|
||||
engine_type="kimi",
|
||||
encrypted_key="encrypted_key_2",
|
||||
key_hint="sk-...2",
|
||||
)
|
||||
async_session.add(key1)
|
||||
async_session.add(key2)
|
||||
await async_session.commit()
|
||||
|
||||
result = await async_session.execute(
|
||||
select(APIKey).where(APIKey.user_id == test_user.id)
|
||||
)
|
||||
keys = result.scalars().all()
|
||||
|
||||
assert len(keys) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_query_by_user_and_engine(self, async_session, test_user):
|
||||
"""Test querying API keys by user ID and engine type."""
|
||||
key1 = APIKey(
|
||||
user_id=test_user.id,
|
||||
engine_type="chatgpt",
|
||||
encrypted_key="encrypted_chatgpt_key",
|
||||
key_hint="sk-...chat",
|
||||
)
|
||||
key2 = APIKey(
|
||||
user_id=test_user.id,
|
||||
engine_type="kimi",
|
||||
encrypted_key="encrypted_kimi_key",
|
||||
key_hint="sk-...kimi",
|
||||
)
|
||||
async_session.add(key1)
|
||||
async_session.add(key2)
|
||||
await async_session.commit()
|
||||
|
||||
result = await async_session.execute(
|
||||
select(APIKey).where(
|
||||
and_(
|
||||
APIKey.user_id == test_user.id,
|
||||
APIKey.engine_type == "chatgpt"
|
||||
)
|
||||
)
|
||||
)
|
||||
keys = result.scalars().all()
|
||||
|
||||
assert len(keys) == 1
|
||||
assert keys[0].engine_type == "chatgpt"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_timestamps(self, async_session, test_user):
|
||||
"""Test API key created_at and updated_at timestamps."""
|
||||
api_key = APIKey(
|
||||
user_id=test_user.id,
|
||||
engine_type="qwen",
|
||||
encrypted_key="encrypted_qwen_key",
|
||||
key_hint="sk-...qwen",
|
||||
)
|
||||
async_session.add(api_key)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(api_key)
|
||||
|
||||
assert api_key.created_at is not None
|
||||
assert api_key.updated_at is not None
|
||||
assert isinstance(api_key.created_at, datetime)
|
||||
assert isinstance(api_key.updated_at, datetime)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_update(self, async_session, test_user):
|
||||
"""Test updating API key fields."""
|
||||
api_key = APIKey(
|
||||
user_id=test_user.id,
|
||||
engine_type="wenxin",
|
||||
encrypted_key="encrypted_wenxin_key",
|
||||
key_hint="sk-...wenxin",
|
||||
status="active",
|
||||
priority=0,
|
||||
)
|
||||
async_session.add(api_key)
|
||||
await async_session.commit()
|
||||
|
||||
api_key.status = "invalid"
|
||||
api_key.priority = 5
|
||||
api_key.last_verified_at = datetime.now(timezone.utc)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(api_key)
|
||||
|
||||
assert api_key.status == "invalid"
|
||||
assert api_key.priority == 5
|
||||
assert api_key.last_verified_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_delete(self, async_session, test_user):
|
||||
"""Test deleting an API key."""
|
||||
api_key = APIKey(
|
||||
user_id=test_user.id,
|
||||
engine_type="doubao",
|
||||
encrypted_key="encrypted_doubao_key",
|
||||
key_hint="sk-...doubao",
|
||||
)
|
||||
async_session.add(api_key)
|
||||
await async_session.commit()
|
||||
key_id = api_key.id
|
||||
|
||||
await async_session.delete(api_key)
|
||||
await async_session.commit()
|
||||
|
||||
result = await async_session.execute(
|
||||
select(APIKey).where(APIKey.id == key_id)
|
||||
)
|
||||
deleted_key = result.scalar_one_or_none()
|
||||
|
||||
assert deleted_key is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_status_values(self, async_session, test_user):
|
||||
"""Test different status values for API key."""
|
||||
statuses = ["active", "invalid", "expired", "rate_limited", "unknown"]
|
||||
created_keys = []
|
||||
|
||||
for i, status in enumerate(statuses):
|
||||
api_key = APIKey(
|
||||
user_id=test_user.id,
|
||||
engine_type=f"engine_{i}",
|
||||
encrypted_key=f"encrypted_key_{i}",
|
||||
key_hint=f"sk-...{i}",
|
||||
status=status,
|
||||
)
|
||||
async_session.add(api_key)
|
||||
created_keys.append(api_key)
|
||||
|
||||
await async_session.commit()
|
||||
|
||||
for key, status in zip(created_keys, statuses):
|
||||
assert key.status == status
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_priority_ordering(self, async_session, test_user):
|
||||
"""Test API keys with different priorities."""
|
||||
keys_data = [
|
||||
{"engine_type": "engine_a", "priority": 0, "key_hint": "a..."},
|
||||
{"engine_type": "engine_b", "priority": 10, "key_hint": "b..."},
|
||||
{"engine_type": "engine_c", "priority": 5, "key_hint": "c..."},
|
||||
]
|
||||
|
||||
for data in keys_data:
|
||||
api_key = APIKey(
|
||||
user_id=test_user.id,
|
||||
engine_type=data["engine_type"],
|
||||
encrypted_key=f"encrypted_{data['engine_type']}",
|
||||
key_hint=data["key_hint"],
|
||||
priority=data["priority"],
|
||||
)
|
||||
async_session.add(api_key)
|
||||
|
||||
await async_session.commit()
|
||||
|
||||
result = await async_session.execute(
|
||||
select(APIKey)
|
||||
.where(APIKey.user_id == test_user.id)
|
||||
.order_by(APIKey.priority.desc())
|
||||
)
|
||||
keys = result.scalars().all()
|
||||
|
||||
assert keys[0].priority == 10
|
||||
assert keys[1].priority == 5
|
||||
assert keys[2].priority == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_user_id_index(self, async_session, test_user):
|
||||
"""Test that user_id field has an index."""
|
||||
for i in range(5):
|
||||
api_key = APIKey(
|
||||
user_id=test_user.id,
|
||||
engine_type=f"engine_{i}",
|
||||
encrypted_key=f"encrypted_key_{i}",
|
||||
key_hint=f"hint_{i}",
|
||||
)
|
||||
async_session.add(api_key)
|
||||
|
||||
await async_session.commit()
|
||||
|
||||
result = await async_session.execute(
|
||||
select(APIKey).where(APIKey.user_id == test_user.id)
|
||||
)
|
||||
keys = result.scalars().all()
|
||||
|
||||
assert len(keys) == 5
|
||||
|
|
@ -0,0 +1,406 @@
|
|||
"""Tests for UsageRecord model."""
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select, func, and_
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from app.models.usage_record import UsageRecord
|
||||
|
||||
|
||||
class TestUsageRecordModel:
|
||||
"""Test cases for UsageRecord model."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_record_create(self, async_session, test_user):
|
||||
"""Test creating a new usage record."""
|
||||
record = UsageRecord(
|
||||
user_id=test_user.id,
|
||||
engine_type="chatgpt",
|
||||
query="What is SEO optimization?",
|
||||
input_tokens=100,
|
||||
output_tokens=200,
|
||||
cost=0.015,
|
||||
extra_data={"model": "gpt-4"},
|
||||
)
|
||||
async_session.add(record)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(record)
|
||||
|
||||
assert record.id is not None
|
||||
assert record.user_id == test_user.id
|
||||
assert record.engine_type == "chatgpt"
|
||||
assert record.query == "What is SEO optimization?"
|
||||
assert record.input_tokens == 100
|
||||
assert record.output_tokens == 200
|
||||
assert record.cost == 0.015
|
||||
assert record.extra_data == {"model": "gpt-4"}
|
||||
assert record.timestamp is not None
|
||||
assert record.created_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_record_default_values(self, async_session, test_user):
|
||||
"""Test usage record default values."""
|
||||
record = UsageRecord(
|
||||
user_id=test_user.id,
|
||||
engine_type="kimi",
|
||||
query="Test query",
|
||||
)
|
||||
async_session.add(record)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(record)
|
||||
|
||||
assert record.input_tokens == 0
|
||||
assert record.output_tokens == 0
|
||||
assert record.cost == 0.0
|
||||
assert record.extra_data == {}
|
||||
assert record.brand_id is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_record_query_by_user_id(self, async_session, test_user):
|
||||
"""Test querying usage records by user ID."""
|
||||
user_id = test_user.id
|
||||
for i in range(3):
|
||||
record = UsageRecord(
|
||||
user_id=user_id,
|
||||
engine_type="deepseek",
|
||||
query=f"Test query {i}",
|
||||
cost=float(i),
|
||||
)
|
||||
async_session.add(record)
|
||||
await async_session.commit()
|
||||
|
||||
result = await async_session.execute(
|
||||
select(UsageRecord).where(UsageRecord.user_id == user_id)
|
||||
)
|
||||
records = result.scalars().all()
|
||||
|
||||
assert len(records) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_record_query_by_user_and_engine(self, async_session, test_user):
|
||||
"""Test querying usage records by user ID and engine type."""
|
||||
user_id = test_user.id
|
||||
record1 = UsageRecord(
|
||||
user_id=user_id,
|
||||
engine_type="chatgpt",
|
||||
query="Test chatgpt",
|
||||
cost=1.0,
|
||||
)
|
||||
record2 = UsageRecord(
|
||||
user_id=user_id,
|
||||
engine_type="kimi",
|
||||
query="Test kimi",
|
||||
cost=2.0,
|
||||
)
|
||||
async_session.add(record1)
|
||||
async_session.add(record2)
|
||||
await async_session.commit()
|
||||
|
||||
result = await async_session.execute(
|
||||
select(UsageRecord).where(
|
||||
and_(
|
||||
UsageRecord.user_id == user_id,
|
||||
UsageRecord.engine_type == "chatgpt"
|
||||
)
|
||||
)
|
||||
)
|
||||
records = result.scalars().all()
|
||||
|
||||
assert len(records) == 1
|
||||
assert records[0].engine_type == "chatgpt"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_record_query_by_time_range(self, async_session, test_user):
|
||||
"""Test querying usage records by time range."""
|
||||
user_id = test_user.id
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
old_record = UsageRecord(
|
||||
user_id=user_id,
|
||||
engine_type="gemini",
|
||||
query="Old query",
|
||||
cost=1.0,
|
||||
timestamp=now - timedelta(days=10),
|
||||
)
|
||||
new_record = UsageRecord(
|
||||
user_id=user_id,
|
||||
engine_type="gemini",
|
||||
query="New query",
|
||||
cost=2.0,
|
||||
timestamp=now,
|
||||
)
|
||||
async_session.add(old_record)
|
||||
async_session.add(new_record)
|
||||
await async_session.commit()
|
||||
|
||||
cutoff = now - timedelta(days=7)
|
||||
result = await async_session.execute(
|
||||
select(UsageRecord).where(
|
||||
and_(
|
||||
UsageRecord.user_id == user_id,
|
||||
UsageRecord.timestamp >= cutoff
|
||||
)
|
||||
)
|
||||
)
|
||||
records = result.scalars().all()
|
||||
|
||||
assert len(records) == 1
|
||||
assert records[0].query == "New query"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_record_aggregate_by_user(self, async_session, test_user):
|
||||
"""Test aggregating usage records by user."""
|
||||
user_id = test_user.id
|
||||
records_data = [
|
||||
{"engine": "chatgpt", "input_tokens": 100, "output_tokens": 200, "cost": 0.01},
|
||||
{"engine": "kimi", "input_tokens": 150, "output_tokens": 300, "cost": 0.02},
|
||||
{"engine": "chatgpt", "input_tokens": 200, "output_tokens": 400, "cost": 0.03},
|
||||
]
|
||||
for data in records_data:
|
||||
record = UsageRecord(
|
||||
user_id=user_id,
|
||||
engine_type=data["engine"],
|
||||
query=f"Query for {data['engine']}",
|
||||
input_tokens=data["input_tokens"],
|
||||
output_tokens=data["output_tokens"],
|
||||
cost=data["cost"],
|
||||
)
|
||||
async_session.add(record)
|
||||
await async_session.commit()
|
||||
|
||||
result = await async_session.execute(
|
||||
select(
|
||||
UsageRecord.engine_type,
|
||||
func.count(UsageRecord.id).label("count"),
|
||||
func.sum(UsageRecord.input_tokens).label("total_input"),
|
||||
func.sum(UsageRecord.output_tokens).label("total_output"),
|
||||
func.sum(UsageRecord.cost).label("total_cost"),
|
||||
).where(UsageRecord.user_id == user_id).group_by(UsageRecord.engine_type)
|
||||
)
|
||||
aggregates = result.all()
|
||||
|
||||
assert len(aggregates) == 2
|
||||
|
||||
chatgpt_agg = next(a for a in aggregates if a.engine_type == "chatgpt")
|
||||
assert chatgpt_agg.count == 2
|
||||
assert chatgpt_agg.total_input == 300
|
||||
assert chatgpt_agg.total_output == 600
|
||||
assert chatgpt_agg.total_cost == 0.04
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_record_aggregate_by_day(self, async_session, test_user):
|
||||
"""Test aggregating usage records by day."""
|
||||
user_id = test_user.id
|
||||
now = datetime.now(timezone.utc)
|
||||
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
record1 = UsageRecord(
|
||||
user_id=user_id,
|
||||
engine_type="qwen",
|
||||
query="Query today 1",
|
||||
cost=1.0,
|
||||
timestamp=today_start + timedelta(hours=10),
|
||||
)
|
||||
record2 = UsageRecord(
|
||||
user_id=user_id,
|
||||
engine_type="qwen",
|
||||
query="Query today 2",
|
||||
cost=2.0,
|
||||
timestamp=today_start + timedelta(hours=14),
|
||||
)
|
||||
record3 = UsageRecord(
|
||||
user_id=user_id,
|
||||
engine_type="qwen",
|
||||
query="Query yesterday",
|
||||
cost=3.0,
|
||||
timestamp=today_start - timedelta(days=1),
|
||||
)
|
||||
async_session.add_all([record1, record2, record3])
|
||||
await async_session.commit()
|
||||
|
||||
result = await async_session.execute(
|
||||
select(
|
||||
func.date(UsageRecord.timestamp).label("date"),
|
||||
func.count(UsageRecord.id).label("count"),
|
||||
func.sum(UsageRecord.cost).label("total_cost"),
|
||||
).where(
|
||||
UsageRecord.user_id == user_id
|
||||
).group_by(func.date(UsageRecord.timestamp))
|
||||
)
|
||||
aggregates = result.all()
|
||||
|
||||
assert len(aggregates) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_record_with_brand_id(self, async_session, test_user):
|
||||
"""Test usage record with brand association."""
|
||||
brand_id = uuid.uuid4()
|
||||
record = UsageRecord(
|
||||
user_id=test_user.id,
|
||||
brand_id=brand_id,
|
||||
engine_type="wenxin",
|
||||
query="Brand query",
|
||||
cost=5.0,
|
||||
)
|
||||
async_session.add(record)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(record)
|
||||
|
||||
assert record.brand_id == brand_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_record_index_user_engine(self, async_session, test_user):
|
||||
"""Test composite index on user_id and engine_type."""
|
||||
user_id = test_user.id
|
||||
for i in range(5):
|
||||
record = UsageRecord(
|
||||
user_id=user_id,
|
||||
engine_type="doubao",
|
||||
query=f"Doubao query {i}",
|
||||
cost=float(i),
|
||||
)
|
||||
async_session.add(record)
|
||||
await async_session.commit()
|
||||
|
||||
result = await async_session.execute(
|
||||
select(UsageRecord).where(
|
||||
and_(
|
||||
UsageRecord.user_id == user_id,
|
||||
UsageRecord.engine_type == "doubao"
|
||||
)
|
||||
)
|
||||
)
|
||||
records = result.scalars().all()
|
||||
|
||||
assert len(records) == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_record_update(self, async_session, test_user):
|
||||
"""Test updating usage record fields."""
|
||||
record = UsageRecord(
|
||||
user_id=test_user.id,
|
||||
engine_type="xinghuo",
|
||||
query="Original query",
|
||||
cost=1.0,
|
||||
)
|
||||
async_session.add(record)
|
||||
await async_session.commit()
|
||||
|
||||
record.cost = 10.0
|
||||
record.extra_data = {"updated": True}
|
||||
await async_session.commit()
|
||||
await async_session.refresh(record)
|
||||
|
||||
assert record.cost == 10.0
|
||||
assert record.extra_data == {"updated": True}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_record_delete(self, async_session, test_user):
|
||||
"""Test deleting a usage record."""
|
||||
record = UsageRecord(
|
||||
user_id=test_user.id,
|
||||
engine_type="yuanbao",
|
||||
query="Delete me",
|
||||
cost=1.0,
|
||||
)
|
||||
async_session.add(record)
|
||||
await async_session.commit()
|
||||
record_id = record.id
|
||||
|
||||
await async_session.delete(record)
|
||||
await async_session.commit()
|
||||
|
||||
result = await async_session.execute(
|
||||
select(UsageRecord).where(UsageRecord.id == record_id)
|
||||
)
|
||||
deleted_record = result.scalar_one_or_none()
|
||||
|
||||
assert deleted_record is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_record_timestamps(self, async_session, test_user):
|
||||
"""Test usage record timestamp fields."""
|
||||
record = UsageRecord(
|
||||
user_id=test_user.id,
|
||||
engine_type="perplexity",
|
||||
query="Timestamp test",
|
||||
cost=1.0,
|
||||
)
|
||||
async_session.add(record)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(record)
|
||||
|
||||
assert record.timestamp is not None
|
||||
assert record.created_at is not None
|
||||
assert isinstance(record.timestamp, datetime)
|
||||
assert isinstance(record.created_at, datetime)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_record_query_multiple_users(self, async_session, test_user):
|
||||
"""Test querying records for multiple users."""
|
||||
other_user_id = uuid.uuid4()
|
||||
|
||||
user1_record = UsageRecord(
|
||||
user_id=test_user.id,
|
||||
engine_type="chatgpt",
|
||||
query="User 1 query",
|
||||
cost=1.0,
|
||||
)
|
||||
user2_record = UsageRecord(
|
||||
user_id=other_user_id,
|
||||
engine_type="kimi",
|
||||
query="User 2 query",
|
||||
cost=2.0,
|
||||
)
|
||||
async_session.add(user1_record)
|
||||
async_session.add(user2_record)
|
||||
await async_session.commit()
|
||||
|
||||
result_user1 = await async_session.execute(
|
||||
select(UsageRecord).where(UsageRecord.user_id == test_user.id)
|
||||
)
|
||||
result_user2 = await async_session.execute(
|
||||
select(UsageRecord).where(UsageRecord.user_id == other_user_id)
|
||||
)
|
||||
|
||||
assert len(result_user1.scalars().all()) == 1
|
||||
assert len(result_user2.scalars().all()) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_record_empty_query_field(self, async_session, test_user):
|
||||
"""Test usage record with empty query field."""
|
||||
record = UsageRecord(
|
||||
user_id=test_user.id,
|
||||
engine_type="deepseek",
|
||||
query="",
|
||||
cost=0.0,
|
||||
)
|
||||
async_session.add(record)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(record)
|
||||
|
||||
assert record.query == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_record_large_metadata(self, async_session, test_user):
|
||||
"""Test usage record with large metadata dict."""
|
||||
large_metadata = {
|
||||
"model": "gpt-4-turbo",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 1000,
|
||||
"system_prompt": "You are a helpful assistant." * 100,
|
||||
"nested": {"key": {"deep": {"value": [1, 2, 3]}}},
|
||||
}
|
||||
record = UsageRecord(
|
||||
user_id=test_user.id,
|
||||
engine_type="chatgpt",
|
||||
query="Large metadata test",
|
||||
extra_data=large_metadata,
|
||||
)
|
||||
async_session.add(record)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(record)
|
||||
|
||||
assert record.extra_data == large_metadata
|
||||
|
|
@ -0,0 +1,377 @@
|
|||
"""Tests for by_day aggregation and user quota service."""
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.database import Base
|
||||
from app.models.user import User
|
||||
from app.models.usage_record import UsageRecord
|
||||
from app.repositories.usage_repository import UsageRepository
|
||||
from app.services.user_quota_service import UserQuotaService, PLAN_MONTHLY_LIMITS
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_engine():
|
||||
"""Create async engine for testing with SQLite."""
|
||||
engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_session(async_engine):
|
||||
"""Create async session for testing."""
|
||||
async_session_maker = async_sessionmaker(
|
||||
async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
autocommit=False,
|
||||
)
|
||||
async with async_session_maker() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user_free(async_session):
|
||||
"""Create a free plan test user."""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="free@example.com",
|
||||
password_hash="hashed_password",
|
||||
name="Free User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user_basic(async_session):
|
||||
"""Create a basic plan test user."""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="basic@example.com",
|
||||
password_hash="hashed_password",
|
||||
name="Basic User",
|
||||
plan="basic",
|
||||
max_queries=50,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user_pro(async_session):
|
||||
"""Create a pro plan test user."""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="pro@example.com",
|
||||
password_hash="hashed_password",
|
||||
name="Pro User",
|
||||
plan="pro",
|
||||
max_queries=500,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
class TestByDayAggregation:
|
||||
"""Test cases for by_day aggregation in usage summary."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_by_day_returns_data(self, async_session, test_user_free):
|
||||
"""Test that by_day aggregation returns non-empty data when records exist."""
|
||||
repo = UsageRepository(async_session)
|
||||
|
||||
for i in range(3):
|
||||
await repo.create({
|
||||
"user_id": test_user_free.id,
|
||||
"engine_type": "chatgpt",
|
||||
"query": f"Query {i}",
|
||||
"cost": 0.01,
|
||||
"input_tokens": 100,
|
||||
"output_tokens": 200,
|
||||
})
|
||||
|
||||
summary = await repo.get_summary(str(test_user_free.id), period="month")
|
||||
|
||||
assert "by_day" in summary
|
||||
assert len(summary["by_day"]) > 0, "by_day should not be empty when records exist"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_by_day_groups_by_date(self, async_session, test_user_free):
|
||||
"""Test that records are correctly grouped by date."""
|
||||
repo = UsageRepository(async_session)
|
||||
|
||||
for i in range(5):
|
||||
await repo.create({
|
||||
"user_id": test_user_free.id,
|
||||
"engine_type": "deepseek",
|
||||
"query": f"Query {i}",
|
||||
"cost": 0.02,
|
||||
})
|
||||
|
||||
summary = await repo.get_summary(str(test_user_free.id), period="month")
|
||||
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
|
||||
assert today in summary["by_day"], f"Today's date {today} should be in by_day"
|
||||
assert summary["by_day"][today]["queries"] == 5
|
||||
assert summary["by_day"][today]["cost"] == 0.10
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_by_day_aggregates_tokens(self, async_session, test_user_free):
|
||||
"""Test that by_day correctly aggregates tokens."""
|
||||
repo = UsageRepository(async_session)
|
||||
|
||||
await repo.create({
|
||||
"user_id": test_user_free.id,
|
||||
"engine_type": "qwen",
|
||||
"query": "Query 1",
|
||||
"input_tokens": 100,
|
||||
"output_tokens": 200,
|
||||
"cost": 0.01,
|
||||
})
|
||||
await repo.create({
|
||||
"user_id": test_user_free.id,
|
||||
"engine_type": "qwen",
|
||||
"query": "Query 2",
|
||||
"input_tokens": 150,
|
||||
"output_tokens": 300,
|
||||
"cost": 0.02,
|
||||
})
|
||||
|
||||
summary = await repo.get_summary(str(test_user_free.id), period="month")
|
||||
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
|
||||
assert summary["by_day"][today]["input_tokens"] == 250
|
||||
assert summary["by_day"][today]["output_tokens"] == 500
|
||||
assert summary["by_day"][today]["cost"] == 0.03
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_by_day_empty_when_no_records(self, async_session, test_user_free):
|
||||
"""Test that by_day is empty when no records exist."""
|
||||
repo = UsageRepository(async_session)
|
||||
|
||||
summary = await repo.get_summary(str(test_user_free.id), period="month")
|
||||
|
||||
assert summary["by_day"] == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_by_day_multiple_days(self, async_session, test_user_free):
|
||||
"""Test by_day when records span multiple days."""
|
||||
repo = UsageRepository(async_session)
|
||||
yesterday = datetime.now(timezone.utc) - timedelta(days=1)
|
||||
|
||||
await repo.create({
|
||||
"user_id": test_user_free.id,
|
||||
"engine_type": "gemini",
|
||||
"query": "Yesterday query",
|
||||
"cost": 0.05,
|
||||
"timestamp": yesterday,
|
||||
})
|
||||
await repo.create({
|
||||
"user_id": test_user_free.id,
|
||||
"engine_type": "gemini",
|
||||
"query": "Today query",
|
||||
"cost": 0.05,
|
||||
})
|
||||
|
||||
summary = await repo.get_summary(str(test_user_free.id), period="week")
|
||||
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
yesterday_str = yesterday.strftime("%Y-%m-%d")
|
||||
|
||||
assert today in summary["by_day"]
|
||||
assert yesterday_str in summary["by_day"]
|
||||
|
||||
|
||||
class TestUserQuotaService:
|
||||
"""Test cases for UserQuotaService with plan-based monthly limits."""
|
||||
|
||||
def test_plan_monthly_limits_defined(self):
|
||||
"""Test that all plan monthly limits are defined."""
|
||||
assert "free" in PLAN_MONTHLY_LIMITS
|
||||
assert "basic" in PLAN_MONTHLY_LIMITS
|
||||
assert "pro" in PLAN_MONTHLY_LIMITS
|
||||
assert "enterprise" in PLAN_MONTHLY_LIMITS
|
||||
|
||||
def test_free_plan_monthly_limit(self):
|
||||
"""Test that free plan has 10 yuan monthly limit."""
|
||||
assert PLAN_MONTHLY_LIMITS["free"] == 10.0
|
||||
|
||||
def test_basic_plan_monthly_limit(self):
|
||||
"""Test that basic plan has 50 yuan monthly limit."""
|
||||
assert PLAN_MONTHLY_LIMITS["basic"] == 50.0
|
||||
|
||||
def test_pro_plan_monthly_limit(self):
|
||||
"""Test that pro plan has 200 yuan monthly limit."""
|
||||
assert PLAN_MONTHLY_LIMITS["pro"] == 200.0
|
||||
|
||||
def test_enterprise_plan_monthly_limit(self):
|
||||
"""Test that enterprise plan has 1000 yuan monthly limit."""
|
||||
assert PLAN_MONTHLY_LIMITS["enterprise"] == 1000.0
|
||||
|
||||
def test_unknown_plan_defaults_to_free(self):
|
||||
"""Test that unknown plan defaults to free plan limit."""
|
||||
service = UserQuotaService()
|
||||
limit = service.get_monthly_limit("unknown_plan")
|
||||
assert limit == 10.0
|
||||
|
||||
def test_get_monthly_limit_free(self):
|
||||
"""Test get_monthly_limit for free plan."""
|
||||
service = UserQuotaService()
|
||||
assert service.get_monthly_limit("free") == 10.0
|
||||
|
||||
def test_get_monthly_limit_basic(self):
|
||||
"""Test get_monthly_limit for basic plan."""
|
||||
service = UserQuotaService()
|
||||
assert service.get_monthly_limit("basic") == 50.0
|
||||
|
||||
def test_get_monthly_limit_pro(self):
|
||||
"""Test get_monthly_limit for pro plan."""
|
||||
service = UserQuotaService()
|
||||
assert service.get_monthly_limit("pro") == 200.0
|
||||
|
||||
def test_get_monthly_limit_enterprise(self):
|
||||
"""Test get_monthly_limit for enterprise plan."""
|
||||
service = UserQuotaService()
|
||||
assert service.get_monthly_limit("enterprise") == 1000.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_quota_with_free_plan(self, async_session, test_user_free):
|
||||
"""Test quota check uses free plan limit (10 yuan)."""
|
||||
repo = UsageRepository(async_session)
|
||||
|
||||
for i in range(5):
|
||||
await repo.create({
|
||||
"user_id": test_user_free.id,
|
||||
"engine_type": "chatgpt",
|
||||
"query": f"Query {i}",
|
||||
"cost": 1.0,
|
||||
})
|
||||
|
||||
service = UserQuotaService(session=async_session)
|
||||
result = await service.check_quota_with_plan(
|
||||
user_id=str(test_user_free.id),
|
||||
user_plan="free"
|
||||
)
|
||||
|
||||
assert result["limit"] == 10.0
|
||||
assert result["used"] == 5.0
|
||||
assert result["usage_percentage"] == 50.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_quota_with_basic_plan(self, async_session, test_user_basic):
|
||||
"""Test quota check uses basic plan limit (50 yuan)."""
|
||||
repo = UsageRepository(async_session)
|
||||
|
||||
for i in range(10):
|
||||
await repo.create({
|
||||
"user_id": test_user_basic.id,
|
||||
"engine_type": "deepseek",
|
||||
"query": f"Query {i}",
|
||||
"cost": 2.0,
|
||||
})
|
||||
|
||||
service = UserQuotaService(session=async_session)
|
||||
result = await service.check_quota_with_plan(
|
||||
user_id=str(test_user_basic.id),
|
||||
user_plan="basic"
|
||||
)
|
||||
|
||||
assert result["limit"] == 50.0
|
||||
assert result["used"] == 20.0
|
||||
assert result["usage_percentage"] == 40.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_quota_with_pro_plan(self, async_session, test_user_pro):
|
||||
"""Test quota check uses pro plan limit (200 yuan)."""
|
||||
repo = UsageRepository(async_session)
|
||||
|
||||
for i in range(10):
|
||||
await repo.create({
|
||||
"user_id": test_user_pro.id,
|
||||
"engine_type": "qwen",
|
||||
"query": f"Query {i}",
|
||||
"cost": 10.0,
|
||||
})
|
||||
|
||||
service = UserQuotaService(session=async_session)
|
||||
result = await service.check_quota_with_plan(
|
||||
user_id=str(test_user_pro.id),
|
||||
user_plan="pro"
|
||||
)
|
||||
|
||||
assert result["limit"] == 200.0
|
||||
assert result["used"] == 100.0
|
||||
assert result["usage_percentage"] == 50.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_quota_exceeded_status(self, async_session, test_user_free):
|
||||
"""Test that exceeded status works correctly with plan limits."""
|
||||
repo = UsageRepository(async_session)
|
||||
|
||||
await repo.create({
|
||||
"user_id": test_user_free.id,
|
||||
"engine_type": "gemini",
|
||||
"query": "Expensive query",
|
||||
"cost": 15.0,
|
||||
})
|
||||
|
||||
service = UserQuotaService(session=async_session)
|
||||
result = await service.check_quota_with_plan(
|
||||
user_id=str(test_user_free.id),
|
||||
user_plan="free"
|
||||
)
|
||||
|
||||
assert result["limit"] == 10.0
|
||||
assert result["used"] == 15.0
|
||||
assert result["usage_percentage"] == 150.0
|
||||
assert result["status"] == "exceeded"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_quota_warning_status(self, async_session, test_user_basic):
|
||||
"""Test that warning status works correctly with plan limits."""
|
||||
repo = UsageRepository(async_session)
|
||||
|
||||
await repo.create({
|
||||
"user_id": test_user_basic.id,
|
||||
"engine_type": "kimi",
|
||||
"query": "Moderate query",
|
||||
"cost": 45.0,
|
||||
})
|
||||
|
||||
service = UserQuotaService(session=async_session)
|
||||
result = await service.check_quota_with_plan(
|
||||
user_id=str(test_user_basic.id),
|
||||
user_plan="basic"
|
||||
)
|
||||
|
||||
assert result["limit"] == 50.0
|
||||
assert result["used"] == 45.0
|
||||
assert result["usage_percentage"] == 90.0
|
||||
assert result["status"] == "warning"
|
||||
|
|
@ -0,0 +1,371 @@
|
|||
"""Tests for UsageRepository."""
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.database import Base
|
||||
from app.models.user import User
|
||||
from app.models.usage_record import UsageRecord
|
||||
from app.repositories.usage_repository import UsageRepository
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_engine():
|
||||
"""Create async engine for testing with SQLite."""
|
||||
engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_session(async_engine):
|
||||
"""Create async session for testing."""
|
||||
async_session_maker = async_sessionmaker(
|
||||
async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
autocommit=False,
|
||||
)
|
||||
async with async_session_maker() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user(async_session):
|
||||
"""Create a test user."""
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="test@example.com",
|
||||
password_hash="hashed_password",
|
||||
name="Test User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
class TestUsageRepository:
|
||||
"""Test cases for UsageRepository."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create(self, async_session, test_user):
|
||||
"""Test creating a usage record."""
|
||||
repo = UsageRepository(async_session)
|
||||
|
||||
data = {
|
||||
"user_id": test_user.id,
|
||||
"engine_type": "chatgpt",
|
||||
"query": "Test query",
|
||||
"input_tokens": 100,
|
||||
"output_tokens": 200,
|
||||
"cost": 0.015,
|
||||
"extra_data": {"model": "gpt-4"},
|
||||
}
|
||||
|
||||
record = await repo.create(data)
|
||||
|
||||
assert record.id is not None
|
||||
assert record.user_id == test_user.id
|
||||
assert record.engine_type == "chatgpt"
|
||||
assert record.query == "Test query"
|
||||
assert record.input_tokens == 100
|
||||
assert record.output_tokens == 200
|
||||
assert record.cost == 0.015
|
||||
assert record.extra_data == {"model": "gpt-4"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_minimal(self, async_session, test_user):
|
||||
"""Test creating a usage record with minimal data."""
|
||||
repo = UsageRepository(async_session)
|
||||
|
||||
data = {
|
||||
"user_id": test_user.id,
|
||||
"engine_type": "deepseek",
|
||||
"query": "Minimal query",
|
||||
}
|
||||
|
||||
record = await repo.create(data)
|
||||
|
||||
assert record.id is not None
|
||||
assert record.user_id == test_user.id
|
||||
assert record.engine_type == "deepseek"
|
||||
assert record.query == "Minimal query"
|
||||
assert record.input_tokens == 0
|
||||
assert record.output_tokens == 0
|
||||
assert record.cost == 0.0
|
||||
assert record.extra_data == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_summary(self, async_session, test_user):
|
||||
"""Test getting usage summary."""
|
||||
repo = UsageRepository(async_session)
|
||||
|
||||
for i in range(3):
|
||||
await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"engine_type": "chatgpt",
|
||||
"query": f"Query {i}",
|
||||
"input_tokens": 100,
|
||||
"output_tokens": 200,
|
||||
"cost": 0.01,
|
||||
})
|
||||
|
||||
summary = await repo.get_summary(str(test_user.id), period="month")
|
||||
|
||||
assert summary["period"] == "month"
|
||||
assert summary["total_queries"] == 3
|
||||
assert summary["total_input_tokens"] == 300
|
||||
assert summary["total_output_tokens"] == 600
|
||||
assert summary["total_cost"] == 0.03
|
||||
assert "chatgpt" in summary["by_engine"]
|
||||
assert summary["by_engine"]["chatgpt"]["queries"] == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_summary_by_engine(self, async_session, test_user):
|
||||
"""Test getting usage summary grouped by engine."""
|
||||
repo = UsageRepository(async_session)
|
||||
|
||||
await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"engine_type": "chatgpt",
|
||||
"query": "ChatGPT query",
|
||||
"cost": 0.02,
|
||||
})
|
||||
await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"engine_type": "deepseek",
|
||||
"query": "DeepSeek query",
|
||||
"cost": 0.01,
|
||||
})
|
||||
|
||||
summary = await repo.get_summary(str(test_user.id), period="month")
|
||||
|
||||
assert len(summary["by_engine"]) == 2
|
||||
assert summary["by_engine"]["chatgpt"]["queries"] == 1
|
||||
assert summary["by_engine"]["chatgpt"]["cost"] == 0.02
|
||||
assert summary["by_engine"]["deepseek"]["queries"] == 1
|
||||
assert summary["by_engine"]["deepseek"]["cost"] == 0.01
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_summary_by_day(self, async_session, test_user):
|
||||
"""Test getting usage summary grouped by day."""
|
||||
repo = UsageRepository(async_session)
|
||||
|
||||
for i in range(2):
|
||||
await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"engine_type": "qwen",
|
||||
"query": f"Query {i}",
|
||||
"cost": 0.01,
|
||||
})
|
||||
|
||||
summary = await repo.get_summary(str(test_user.id), period="month")
|
||||
|
||||
assert len(summary["by_day"]) >= 1
|
||||
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
assert today in summary["by_day"]
|
||||
assert summary["by_day"][today]["queries"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_summary_with_brand_filter(self, async_session, test_user):
|
||||
"""Test getting usage summary filtered by brand."""
|
||||
repo = UsageRepository(async_session)
|
||||
brand_id = uuid.uuid4()
|
||||
|
||||
await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"brand_id": brand_id,
|
||||
"engine_type": "kimi",
|
||||
"query": "Brand query",
|
||||
"cost": 0.05,
|
||||
})
|
||||
await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"engine_type": "kimi",
|
||||
"query": "No brand query",
|
||||
"cost": 0.03,
|
||||
})
|
||||
|
||||
summary = await repo.get_summary(
|
||||
str(test_user.id),
|
||||
period="month",
|
||||
brand_id=str(brand_id),
|
||||
)
|
||||
|
||||
assert summary["total_queries"] == 1
|
||||
assert summary["total_cost"] == 0.05
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_quota(self, async_session, test_user):
|
||||
"""Test checking quota usage."""
|
||||
repo = UsageRepository(async_session)
|
||||
|
||||
for i in range(5):
|
||||
await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"engine_type": "gemini",
|
||||
"query": f"Query {i}",
|
||||
"cost": 1.0,
|
||||
})
|
||||
|
||||
result = await repo.check_quota(str(test_user.id), monthly_limit=100.0)
|
||||
|
||||
assert result["used"] == 5.0
|
||||
assert result["limit"] == 100.0
|
||||
assert result["usage_percentage"] == 5.0
|
||||
assert result["status"] == "ok"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_quota_warning(self, async_session, test_user):
|
||||
"""Test quota warning status."""
|
||||
repo = UsageRepository(async_session)
|
||||
|
||||
await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"engine_type": "chatgpt",
|
||||
"query": "Expensive query",
|
||||
"cost": 85.0,
|
||||
})
|
||||
|
||||
result = await repo.check_quota(str(test_user.id), monthly_limit=100.0)
|
||||
|
||||
assert result["status"] == "warning"
|
||||
assert result["usage_percentage"] == 85.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_quota_exceeded(self, async_session, test_user):
|
||||
"""Test quota exceeded status."""
|
||||
repo = UsageRepository(async_session)
|
||||
|
||||
await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"engine_type": "chatgpt",
|
||||
"query": "Very expensive query",
|
||||
"cost": 120.0,
|
||||
})
|
||||
|
||||
result = await repo.check_quota(str(test_user.id), monthly_limit=100.0)
|
||||
|
||||
assert result["status"] == "exceeded"
|
||||
assert result["usage_percentage"] == 120.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id(self, async_session, test_user):
|
||||
"""Test getting a usage record by ID."""
|
||||
repo = UsageRepository(async_session)
|
||||
|
||||
created = await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"engine_type": "wenxin",
|
||||
"query": "Get by ID test",
|
||||
"cost": 0.5,
|
||||
})
|
||||
|
||||
fetched = await repo.get_by_id(created.id)
|
||||
|
||||
assert fetched is not None
|
||||
assert fetched.id == created.id
|
||||
assert fetched.engine_type == "wenxin"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_user(self, async_session, test_user):
|
||||
"""Test getting usage records by user."""
|
||||
repo = UsageRepository(async_session)
|
||||
|
||||
for i in range(3):
|
||||
await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"engine_type": "doubao",
|
||||
"query": f"User query {i}",
|
||||
"cost": 0.1,
|
||||
})
|
||||
|
||||
records = await repo.get_by_user(str(test_user.id))
|
||||
|
||||
assert len(records) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_user_and_engine(self, async_session, test_user):
|
||||
"""Test getting usage records by user and engine."""
|
||||
repo = UsageRepository(async_session)
|
||||
|
||||
await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"engine_type": "xinghuo",
|
||||
"query": "Xinghuo query",
|
||||
"cost": 0.1,
|
||||
})
|
||||
await repo.create({
|
||||
"user_id": test_user.id,
|
||||
"engine_type": "perplexity",
|
||||
"query": "Perplexity query",
|
||||
"cost": 0.2,
|
||||
})
|
||||
|
||||
records = await repo.get_by_user_and_engine(
|
||||
str(test_user.id),
|
||||
"xinghuo",
|
||||
)
|
||||
|
||||
assert len(records) == 1
|
||||
assert records[0].engine_type == "xinghuo"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_summary(self, async_session, test_user):
|
||||
"""Test getting summary when no records exist."""
|
||||
repo = UsageRepository(async_session)
|
||||
|
||||
summary = await repo.get_summary(str(test_user.id), period="month")
|
||||
|
||||
assert summary["total_queries"] == 0
|
||||
assert summary["total_input_tokens"] == 0
|
||||
assert summary["total_output_tokens"] == 0
|
||||
assert summary["total_cost"] == 0.0
|
||||
assert summary["by_engine"] == {}
|
||||
assert summary["by_day"] == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_quota_check(self, async_session, test_user):
|
||||
"""Test checking quota when no records exist."""
|
||||
repo = UsageRepository(async_session)
|
||||
|
||||
result = await repo.check_quota(str(test_user.id), monthly_limit=100.0)
|
||||
|
||||
assert result["used"] == 0.0
|
||||
assert result["usage_percentage"] == 0.0
|
||||
assert result["status"] == "ok"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uuid_handling(self, async_session, test_user):
|
||||
"""Test that UUID handling works correctly."""
|
||||
repo = UsageRepository(async_session)
|
||||
|
||||
data = {
|
||||
"user_id": test_user.id,
|
||||
"engine_type": "yuanbao",
|
||||
"query": "UUID test",
|
||||
"cost": 0.5,
|
||||
}
|
||||
|
||||
record = await repo.create(data)
|
||||
summary = await repo.get_summary(test_user.id, period="month")
|
||||
|
||||
assert summary["total_queries"] == 1
|
||||
assert record.user_id == test_user.id
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
import os
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.ai_engine.base import AIEngineAdapter, EngineType
|
||||
from app.services.api_key_manager import APIKeyManager, KeySource
|
||||
|
||||
|
||||
class MockAdapter(AIEngineAdapter):
|
||||
def __init__(self, api_key: str | None = None, **kwargs):
|
||||
super().__init__(api_key=api_key, **kwargs)
|
||||
|
||||
async def query(self, query: str, brand_name: str, competitor_names: list[str] | None = None):
|
||||
pass
|
||||
|
||||
def get_engine_type(self) -> EngineType:
|
||||
return EngineType.CHATGPT
|
||||
|
||||
def _get_env_key(self) -> str | None:
|
||||
return os.getenv("OPENAI_API_KEY", "")
|
||||
|
||||
|
||||
class TestAdapterKeySource:
|
||||
def test_adapter_accepts_key_manager_parameter(self):
|
||||
key_manager = MagicMock(spec=APIKeyManager)
|
||||
adapter = MockAdapter(key_manager=key_manager, user_id="user123")
|
||||
assert adapter._key_manager is key_manager
|
||||
assert adapter._user_id == "user123"
|
||||
|
||||
def test_adapter_accepts_api_key_parameter(self):
|
||||
adapter = MockAdapter(api_key="direct-key-123")
|
||||
assert adapter.api_key == "direct-key-123"
|
||||
|
||||
def test_resolve_key_from_direct_api_key(self):
|
||||
adapter = MockAdapter(api_key="direct-key-123")
|
||||
assert adapter.api_key == "direct-key-123"
|
||||
|
||||
def test_resolve_key_from_key_manager_user_key(self):
|
||||
key_manager = MagicMock(spec=APIKeyManager)
|
||||
key_manager.get_key.return_value = "manager-key-456"
|
||||
|
||||
adapter = MockAdapter(key_manager=key_manager, user_id="user123")
|
||||
|
||||
key_manager.get_key.assert_called_once_with("chatgpt", user_id="user123")
|
||||
assert adapter.api_key == "manager-key-456"
|
||||
|
||||
def test_resolve_key_fallback_to_env_when_no_manager(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "env-key-789")
|
||||
|
||||
adapter = MockAdapter()
|
||||
assert adapter.api_key == "env-key-789"
|
||||
|
||||
def test_direct_api_key_priority_over_key_manager(self):
|
||||
key_manager = MagicMock(spec=APIKeyManager)
|
||||
key_manager.get_key.return_value = "manager-key-456"
|
||||
|
||||
adapter = MockAdapter(api_key="direct-key-123", key_manager=key_manager, user_id="user123")
|
||||
|
||||
key_manager.get_key.assert_not_called()
|
||||
assert adapter.api_key == "direct-key-123"
|
||||
|
||||
def test_user_key_priority_over_system_key(self):
|
||||
key_manager = MagicMock(spec=APIKeyManager)
|
||||
key_manager.get_key.return_value = "user-key-from-manager"
|
||||
|
||||
adapter = MockAdapter(key_manager=key_manager, user_id="user123")
|
||||
assert adapter.api_key == "user-key-from-manager"
|
||||
|
||||
def test_no_key_available_returns_empty(self):
|
||||
key_manager = MagicMock(spec=APIKeyManager)
|
||||
key_manager.get_key.return_value = None
|
||||
|
||||
adapter = MockAdapter(key_manager=key_manager, user_id="user123")
|
||||
assert adapter.api_key == ""
|
||||
|
||||
def test_adapter_with_all_parameters(self):
|
||||
key_manager = MagicMock(spec=APIKeyManager)
|
||||
key_manager.get_key.return_value = "full-key"
|
||||
|
||||
adapter = MockAdapter(
|
||||
api_key=None,
|
||||
rate_limiter=MagicMock(),
|
||||
proxy="http://proxy:8080",
|
||||
key_manager=key_manager,
|
||||
user_id="test-user"
|
||||
)
|
||||
assert adapter._key_manager is key_manager
|
||||
assert adapter._user_id == "test-user"
|
||||
assert adapter.proxy == "http://proxy:8080"
|
||||
|
||||
|
||||
class TestBatchQueryServiceKeyIntegration:
|
||||
def test_build_adapters_with_key_manager(self):
|
||||
from app.services.ai_engine.batch_query import BatchQueryService
|
||||
|
||||
key_manager = MagicMock(spec=APIKeyManager)
|
||||
key_manager.get_key.return_value = "batch-test-key"
|
||||
|
||||
adapters = {
|
||||
"chatgpt": MockAdapter(key_manager=key_manager, user_id="batch-user")
|
||||
}
|
||||
service = BatchQueryService(adapters)
|
||||
service.set_user_context(user_id="batch-user", brand_id="brand-123")
|
||||
|
||||
assert service._user_id == "batch-user"
|
||||
assert service._brand_id == "brand-123"
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
import pytest
|
||||
|
||||
from app.services.ai_engine.base import EngineType
|
||||
|
||||
|
||||
class TestAllAdaptersRegistered:
|
||||
def test_all_nine_adapters_in_adapter_classes(self):
|
||||
from app.services.ai_engine.batch_query import _ADAPTER_CLASSES
|
||||
|
||||
registered_types = set(_ADAPTER_CLASSES.keys())
|
||||
expected_types = {
|
||||
EngineType.CHATGPT,
|
||||
EngineType.PERPLEXITY,
|
||||
EngineType.KIMI,
|
||||
EngineType.WENXIN,
|
||||
EngineType.DOUBAO,
|
||||
EngineType.YUANBAO,
|
||||
EngineType.DEEPSEEK,
|
||||
EngineType.QWEN,
|
||||
EngineType.GEMINI,
|
||||
}
|
||||
|
||||
assert registered_types == expected_types, (
|
||||
f"Missing adapters: {expected_types - registered_types}. "
|
||||
f"Extra adapters: {registered_types - expected_types}"
|
||||
)
|
||||
|
||||
def test_adapter_classes_has_nine_entries(self):
|
||||
from app.services.ai_engine.batch_query import _ADAPTER_CLASSES
|
||||
|
||||
assert len(_ADAPTER_CLASSES) == 9, (
|
||||
f"Expected 9 adapters, got {len(_ADAPTER_CLASSES)}"
|
||||
)
|
||||
|
||||
def test_deepseek_adapter_registered(self):
|
||||
from app.services.ai_engine.batch_query import _ADAPTER_CLASSES
|
||||
from app.services.ai_engine.deepseek import DeepSeekAdapter
|
||||
|
||||
assert EngineType.DEEPSEEK in _ADAPTER_CLASSES
|
||||
assert _ADAPTER_CLASSES[EngineType.DEEPSEEK] == DeepSeekAdapter
|
||||
|
||||
def test_qwen_adapter_registered(self):
|
||||
from app.services.ai_engine.batch_query import _ADAPTER_CLASSES
|
||||
from app.services.ai_engine.qwen import QwenAdapter
|
||||
|
||||
assert EngineType.QWEN in _ADAPTER_CLASSES
|
||||
assert _ADAPTER_CLASSES[EngineType.QWEN] == QwenAdapter
|
||||
|
||||
def test_gemini_adapter_registered(self):
|
||||
from app.services.ai_engine.batch_query import _ADAPTER_CLASSES
|
||||
from app.services.ai_engine.gemini import GeminiAdapter
|
||||
|
||||
assert EngineType.GEMINI in _ADAPTER_CLASSES
|
||||
assert _ADAPTER_CLASSES[EngineType.GEMINI] == GeminiAdapter
|
||||
|
||||
def test_chatgpt_adapter_registered(self):
|
||||
from app.services.ai_engine.batch_query import _ADAPTER_CLASSES
|
||||
from app.services.ai_engine.chatgpt import ChatGPTAdapter
|
||||
|
||||
assert EngineType.CHATGPT in _ADAPTER_CLASSES
|
||||
assert _ADAPTER_CLASSES[EngineType.CHATGPT] == ChatGPTAdapter
|
||||
|
||||
def test_perplexity_adapter_registered(self):
|
||||
from app.services.ai_engine.batch_query import _ADAPTER_CLASSES
|
||||
from app.services.ai_engine.perplexity import PerplexityAdapter
|
||||
|
||||
assert EngineType.PERPLEXITY in _ADAPTER_CLASSES
|
||||
assert _ADAPTER_CLASSES[EngineType.PERPLEXITY] == PerplexityAdapter
|
||||
|
||||
def test_kimi_adapter_registered(self):
|
||||
from app.services.ai_engine.batch_query import _ADAPTER_CLASSES
|
||||
from app.services.ai_engine.kimi import KimiAdapter
|
||||
|
||||
assert EngineType.KIMI in _ADAPTER_CLASSES
|
||||
assert _ADAPTER_CLASSES[EngineType.KIMI] == KimiAdapter
|
||||
|
||||
def test_wenxin_adapter_registered(self):
|
||||
from app.services.ai_engine.batch_query import _ADAPTER_CLASSES
|
||||
from app.services.ai_engine.wenxin import WenxinAdapter
|
||||
|
||||
assert EngineType.WENXIN in _ADAPTER_CLASSES
|
||||
assert _ADAPTER_CLASSES[EngineType.WENXIN] == WenxinAdapter
|
||||
|
||||
def test_doubao_adapter_registered(self):
|
||||
from app.services.ai_engine.batch_query import _ADAPTER_CLASSES
|
||||
from app.services.ai_engine.doubao import DoubaoAdapter
|
||||
|
||||
assert EngineType.DOUBAO in _ADAPTER_CLASSES
|
||||
assert _ADAPTER_CLASSES[EngineType.DOUBAO] == DoubaoAdapter
|
||||
|
||||
def test_yuanbao_adapter_registered(self):
|
||||
from app.services.ai_engine.batch_query import _ADAPTER_CLASSES
|
||||
from app.services.ai_engine.yuanbao import YuanbaoAdapter
|
||||
|
||||
assert EngineType.YUANBAO in _ADAPTER_CLASSES
|
||||
assert _ADAPTER_CLASSES[EngineType.YUANBAO] == YuanbaoAdapter
|
||||
|
||||
|
||||
class TestBatchServiceWithAllAdapters:
|
||||
def test_batch_service_builds_all_adapters(self):
|
||||
from app.services.ai_engine.batch_query import _build_adapters
|
||||
|
||||
adapters = _build_adapters()
|
||||
|
||||
expected_engines = [
|
||||
"chatgpt", "perplexity", "kimi", "wenxin", "doubao",
|
||||
"yuanbao", "deepseek", "qwen", "gemini"
|
||||
]
|
||||
|
||||
for engine in expected_engines:
|
||||
assert engine in adapters, f"Missing adapter: {engine}"
|
||||
|
||||
assert len(adapters) == 9, (
|
||||
f"Expected 9 adapters built, got {len(adapters)}"
|
||||
)
|
||||
|
|
@ -128,6 +128,9 @@ class TestAIEngineAdapterBase:
|
|||
def get_engine_type(self):
|
||||
return EngineType.CHATGPT
|
||||
|
||||
def _get_env_key(self) -> str | None:
|
||||
return None
|
||||
|
||||
adapter = ConcreteAdapter(api_key="test-key")
|
||||
has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
|
||||
"BrandX is the best insurance company",
|
||||
|
|
@ -147,6 +150,9 @@ class TestAIEngineAdapterBase:
|
|||
def get_engine_type(self):
|
||||
return EngineType.CHATGPT
|
||||
|
||||
def _get_env_key(self) -> str | None:
|
||||
return None
|
||||
|
||||
adapter = ConcreteAdapter(api_key="test-key")
|
||||
has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
|
||||
"CompetitorY is also a good choice for insurance",
|
||||
|
|
@ -166,6 +172,9 @@ class TestAIEngineAdapterBase:
|
|||
def get_engine_type(self):
|
||||
return EngineType.CHATGPT
|
||||
|
||||
def _get_env_key(self) -> str | None:
|
||||
return None
|
||||
|
||||
adapter = ConcreteAdapter(api_key="test-key")
|
||||
has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
|
||||
"BrandX and CompetitorY are both good insurance options",
|
||||
|
|
@ -185,6 +194,9 @@ class TestAIEngineAdapterBase:
|
|||
def get_engine_type(self):
|
||||
return EngineType.CHATGPT
|
||||
|
||||
def _get_env_key(self) -> str | None:
|
||||
return None
|
||||
|
||||
adapter = ConcreteAdapter(api_key="test-key")
|
||||
has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
|
||||
"Some random text without brand names",
|
||||
|
|
@ -204,6 +216,9 @@ class TestAIEngineAdapterBase:
|
|||
def get_engine_type(self):
|
||||
return EngineType.CHATGPT
|
||||
|
||||
def _get_env_key(self) -> str | None:
|
||||
return None
|
||||
|
||||
adapter = ConcreteAdapter(api_key="test-key")
|
||||
has_brand, _, _, _ = adapter._detect_citations(
|
||||
"brandx is great",
|
||||
|
|
|
|||
|
|
@ -171,9 +171,14 @@ class TestKeyVerification:
|
|||
return APIKeyManager()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_valid_key(self, manager):
|
||||
async def test_verify_calls_factory(self, manager):
|
||||
from unittest.mock import patch, AsyncMock
|
||||
|
||||
with patch('app.services.api_key_manager.KeyVerifierFactory.verify', new_callable=AsyncMock) as mock_verify:
|
||||
mock_verify.return_value = KeyStatus.ACTIVE
|
||||
status = await manager.verify_key("chatgpt", "sk-valid-key-1234567890")
|
||||
assert status == KeyStatus.ACTIVE
|
||||
mock_verify.assert_called_once_with("chatgpt", "sk-valid-key-1234567890")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_empty_key(self, manager):
|
||||
|
|
|
|||
|
|
@ -42,6 +42,9 @@ class _StubAdapter(AIEngineAdapter):
|
|||
def get_engine_type(self) -> EngineType:
|
||||
return self._engine_type
|
||||
|
||||
def _get_env_key(self) -> str | None:
|
||||
return ""
|
||||
|
||||
|
||||
class TestBatchQueryServiceInit:
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -0,0 +1,164 @@
|
|||
import pytest
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
|
||||
class TestKeyEncryptionBasic:
|
||||
"""基础加密解密测试"""
|
||||
|
||||
def test_encrypt_decrypt_roundtrip(self):
|
||||
from app.services.key_encryption import KeyEncryption
|
||||
|
||||
encryption = KeyEncryption(encryption_key="test-key-for-encryption-12345")
|
||||
plaintext = "sk-1234567890abcdef"
|
||||
ciphertext = encryption.encrypt(plaintext)
|
||||
decrypted = encryption.decrypt(ciphertext)
|
||||
assert decrypted == plaintext
|
||||
|
||||
def test_encrypted_content_differs_from_plaintext(self):
|
||||
from app.services.key_encryption import KeyEncryption
|
||||
|
||||
encryption = KeyEncryption(encryption_key="test-key-for-encryption-12345")
|
||||
plaintext = "sk-1234567890abcdef"
|
||||
ciphertext = encryption.encrypt(plaintext)
|
||||
assert ciphertext != plaintext
|
||||
assert ciphertext != plaintext.encode()
|
||||
|
||||
def test_same_plaintext_produces_different_ciphertext(self):
|
||||
from app.services.key_encryption import KeyEncryption
|
||||
|
||||
encryption = KeyEncryption(encryption_key="test-key-for-encryption-12345")
|
||||
plaintext = "sk-1234567890abcdef"
|
||||
ciphertext1 = encryption.encrypt(plaintext)
|
||||
ciphertext2 = encryption.encrypt(plaintext)
|
||||
assert ciphertext1 != ciphertext2
|
||||
|
||||
|
||||
class TestKeyEncryptionEdgeCases:
|
||||
"""边界情况测试"""
|
||||
|
||||
def test_encrypt_empty_string_raises_error(self):
|
||||
from app.services.key_encryption import KeyEncryption
|
||||
|
||||
encryption = KeyEncryption(encryption_key="test-key-for-encryption-12345")
|
||||
with pytest.raises(ValueError, match="Cannot encrypt empty string"):
|
||||
encryption.encrypt("")
|
||||
|
||||
def test_decrypt_empty_string_raises_error(self):
|
||||
from app.services.key_encryption import KeyEncryption
|
||||
|
||||
encryption = KeyEncryption(encryption_key="test-key-for-encryption-12345")
|
||||
with pytest.raises(ValueError, match="Cannot decrypt empty string"):
|
||||
encryption.decrypt("")
|
||||
|
||||
def test_encrypt_special_characters(self):
|
||||
from app.services.key_encryption import KeyEncryption
|
||||
|
||||
encryption = KeyEncryption(encryption_key="test-key-for-encryption-12345")
|
||||
plaintext = "sk-中文测试!@#$%^&*()_+-=[]{}|;':\",./<>?"
|
||||
ciphertext = encryption.encrypt(plaintext)
|
||||
decrypted = encryption.decrypt(ciphertext)
|
||||
assert decrypted == plaintext
|
||||
|
||||
def test_encrypt_unicode_characters(self):
|
||||
from app.services.key_encryption import KeyEncryption
|
||||
|
||||
encryption = KeyEncryption(encryption_key="test-key-for-encryption-12345")
|
||||
plaintext = "sk-日本語テスト한국어"
|
||||
ciphertext = encryption.encrypt(plaintext)
|
||||
decrypted = encryption.decrypt(ciphertext)
|
||||
assert decrypted == plaintext
|
||||
|
||||
|
||||
class TestKeyEncryptionInvalidInputs:
|
||||
"""无效输入测试"""
|
||||
|
||||
def test_decrypt_invalid_ciphertext_raises_error(self):
|
||||
from app.services.key_encryption import KeyEncryption
|
||||
|
||||
encryption = KeyEncryption(encryption_key="test-key-for-encryption-12345")
|
||||
with pytest.raises(ValueError, match="Invalid ciphertext or key"):
|
||||
encryption.decrypt("invalid-ciphertext-that-cannot-be-decrypted")
|
||||
|
||||
def test_decrypt_with_wrong_key_raises_error(self):
|
||||
from app.services.key_encryption import KeyEncryption
|
||||
|
||||
encryption1 = KeyEncryption(encryption_key="key-one-for-encryption-1234")
|
||||
encryption2 = KeyEncryption(encryption_key="key-two-for-encryption-1234")
|
||||
|
||||
plaintext = "sk-1234567890abcdef"
|
||||
ciphertext = encryption1.encrypt(plaintext)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid ciphertext or key"):
|
||||
encryption2.decrypt(ciphertext)
|
||||
|
||||
|
||||
class TestKeyEncryptionKeyFormats:
|
||||
"""密钥格式测试"""
|
||||
|
||||
def test_valid_fernet_key_format(self):
|
||||
from app.services.key_encryption import KeyEncryption
|
||||
|
||||
valid_key = Fernet.generate_key().decode()
|
||||
encryption = KeyEncryption(encryption_key=valid_key)
|
||||
plaintext = "sk-1234567890abcdef"
|
||||
ciphertext = encryption.encrypt(plaintext)
|
||||
decrypted = encryption.decrypt(ciphertext)
|
||||
assert decrypted == plaintext
|
||||
|
||||
def test_password_based_key_derivation(self):
|
||||
from app.services.key_encryption import KeyEncryption
|
||||
|
||||
password = "my-secret-password-12345"
|
||||
encryption = KeyEncryption(encryption_key=password)
|
||||
plaintext = "sk-1234567890abcdef"
|
||||
ciphertext = encryption.encrypt(plaintext)
|
||||
decrypted = encryption.decrypt(ciphertext)
|
||||
assert decrypted == plaintext
|
||||
|
||||
|
||||
class TestKeyEncryptionSingleton:
|
||||
"""全局实例测试"""
|
||||
|
||||
def test_get_key_encryption_returns_instance(self):
|
||||
from app.services.key_encryption import get_key_encryption
|
||||
|
||||
instance = get_key_encryption()
|
||||
assert instance is not None
|
||||
assert hasattr(instance, "encrypt")
|
||||
assert hasattr(instance, "decrypt")
|
||||
|
||||
def test_reset_key_encryption_clears_singleton(self):
|
||||
from app.services.key_encryption import get_key_encryption, reset_key_encryption
|
||||
|
||||
instance1 = get_key_encryption()
|
||||
reset_key_encryption()
|
||||
instance2 = get_key_encryption()
|
||||
assert instance1 is not instance2
|
||||
|
||||
|
||||
class TestAPIKeyManagerEncryption:
|
||||
"""APIKeyManager集成测试"""
|
||||
|
||||
def test_api_key_manager_uses_fernet_encryption(self):
|
||||
from app.services.api_key_manager import APIKeyManager
|
||||
|
||||
manager = APIKeyManager()
|
||||
original_key = "sk-1234567890abcdef"
|
||||
config = manager.add_key("chatgpt", original_key)
|
||||
|
||||
assert config.encrypted_key != original_key
|
||||
assert not config.encrypted_key.startswith("sk-")
|
||||
|
||||
decrypted = manager.get_key("chatgpt")
|
||||
assert decrypted == original_key
|
||||
|
||||
def test_encrypted_key_is_not_base64_of_plaintext(self):
|
||||
from app.services.api_key_manager import APIKeyManager
|
||||
import base64
|
||||
|
||||
manager = APIKeyManager()
|
||||
original_key = "sk-1234567890abcdef"
|
||||
config = manager.add_key("chatgpt", original_key)
|
||||
|
||||
plain_base64 = base64.b64encode(original_key.encode()).decode()
|
||||
assert config.encrypted_key != plain_base64
|
||||
|
|
@ -0,0 +1,354 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
import httpx
|
||||
|
||||
from app.services.key_verifier import KeyStatus
|
||||
from app.services.key_verifier import (
|
||||
KeyVerifier,
|
||||
KeyVerifierFactory,
|
||||
ChatGPTKeyVerifier,
|
||||
DeepSeekKeyVerifier,
|
||||
KimiKeyVerifier,
|
||||
QwenKeyVerifier,
|
||||
PerplexityKeyVerifier,
|
||||
GeminiKeyVerifier,
|
||||
WenxinKeyVerifier,
|
||||
DoubaoKeyVerifier,
|
||||
YuanbaoKeyVerifier,
|
||||
DefaultKeyVerifier,
|
||||
)
|
||||
|
||||
|
||||
def create_mock_client(mock_response: MagicMock) -> AsyncMock:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_client.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
return mock_client
|
||||
|
||||
|
||||
class TestKeyVerifierFactory:
|
||||
def test_get_chatgpt_verifier(self):
|
||||
verifier = KeyVerifierFactory.get_verifier("chatgpt")
|
||||
assert isinstance(verifier, ChatGPTKeyVerifier)
|
||||
|
||||
def test_get_deepseek_verifier(self):
|
||||
verifier = KeyVerifierFactory.get_verifier("deepseek")
|
||||
assert isinstance(verifier, DeepSeekKeyVerifier)
|
||||
|
||||
def test_get_kimi_verifier(self):
|
||||
verifier = KeyVerifierFactory.get_verifier("kimi")
|
||||
assert isinstance(verifier, KimiKeyVerifier)
|
||||
|
||||
def test_get_qwen_verifier(self):
|
||||
verifier = KeyVerifierFactory.get_verifier("qwen")
|
||||
assert isinstance(verifier, QwenKeyVerifier)
|
||||
|
||||
def test_get_perplexity_verifier(self):
|
||||
verifier = KeyVerifierFactory.get_verifier("perplexity")
|
||||
assert isinstance(verifier, PerplexityKeyVerifier)
|
||||
|
||||
def test_get_gemini_verifier(self):
|
||||
verifier = KeyVerifierFactory.get_verifier("gemini")
|
||||
assert isinstance(verifier, GeminiKeyVerifier)
|
||||
|
||||
def test_get_unknown_verifier_returns_default(self):
|
||||
verifier = KeyVerifierFactory.get_verifier("unknown_engine")
|
||||
assert isinstance(verifier, DefaultKeyVerifier)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_delegates_to_correct_verifier(self):
|
||||
with patch.object(ChatGPTKeyVerifier, 'verify', new_callable=AsyncMock) as mock_verify:
|
||||
mock_verify.return_value = KeyStatus.ACTIVE
|
||||
status = await KeyVerifierFactory.verify("chatgpt", "sk-test-key-12345")
|
||||
mock_verify.assert_called_once_with("sk-test-key-12345")
|
||||
assert status == KeyStatus.ACTIVE
|
||||
|
||||
|
||||
class TestDefaultKeyVerifier:
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_verifier_returns_unknown(self):
|
||||
verifier = DefaultKeyVerifier()
|
||||
status = await verifier.verify("any-key")
|
||||
assert status == KeyStatus.UNKNOWN
|
||||
|
||||
|
||||
class TestChatGPTKeyVerifier:
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_active_key_returns_active(self):
|
||||
verifier = ChatGPTKeyVerifier()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
mock_client = create_mock_client(mock_response)
|
||||
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
status = await verifier.verify("sk-test-key-12345")
|
||||
assert status == KeyStatus.ACTIVE
|
||||
mock_client.post.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_invalid_key_returns_invalid(self):
|
||||
verifier = ChatGPTKeyVerifier()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.text = "Invalid API key"
|
||||
|
||||
mock_client = create_mock_client(mock_response)
|
||||
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
status = await verifier.verify("sk-invalid-key")
|
||||
assert status == KeyStatus.INVALID
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_rate_limited_key_returns_rate_limited(self):
|
||||
verifier = ChatGPTKeyVerifier()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 429
|
||||
mock_response.text = "Rate limit exceeded"
|
||||
|
||||
mock_client = create_mock_client(mock_response)
|
||||
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
status = await verifier.verify("sk-rate-limited-key")
|
||||
assert status == KeyStatus.RATE_LIMITED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_network_error_returns_unknown(self):
|
||||
verifier = ChatGPTKeyVerifier()
|
||||
|
||||
mock_client = create_mock_client(MagicMock())
|
||||
mock_client.post = AsyncMock(side_effect=httpx.ConnectError("Connection failed"))
|
||||
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
status = await verifier.verify("sk-test-key")
|
||||
assert status == KeyStatus.UNKNOWN
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_timeout_returns_unknown(self):
|
||||
verifier = ChatGPTKeyVerifier()
|
||||
|
||||
mock_client = create_mock_client(MagicMock())
|
||||
mock_client.post = AsyncMock(side_effect=httpx.TimeoutException("Request timeout"))
|
||||
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
status = await verifier.verify("sk-test-key")
|
||||
assert status == KeyStatus.UNKNOWN
|
||||
|
||||
|
||||
class TestDeepSeekKeyVerifier:
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_active_key_returns_active(self):
|
||||
verifier = DeepSeekKeyVerifier()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
mock_client = create_mock_client(mock_response)
|
||||
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
status = await verifier.verify("sk-test-key-12345")
|
||||
assert status == KeyStatus.ACTIVE
|
||||
mock_client.post.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_invalid_key_returns_invalid(self):
|
||||
verifier = DeepSeekKeyVerifier()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 401
|
||||
|
||||
mock_client = create_mock_client(mock_response)
|
||||
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
status = await verifier.verify("sk-invalid-key")
|
||||
assert status == KeyStatus.INVALID
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_rate_limited_key_returns_rate_limited(self):
|
||||
verifier = DeepSeekKeyVerifier()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 429
|
||||
|
||||
mock_client = create_mock_client(mock_response)
|
||||
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
status = await verifier.verify("sk-rate-limited-key")
|
||||
assert status == KeyStatus.RATE_LIMITED
|
||||
|
||||
|
||||
class TestKimiKeyVerifier:
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_active_key_returns_active(self):
|
||||
verifier = KimiKeyVerifier()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
mock_client = create_mock_client(mock_response)
|
||||
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
status = await verifier.verify("sk-test-key-12345")
|
||||
assert status == KeyStatus.ACTIVE
|
||||
mock_client.post.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_invalid_key_returns_invalid(self):
|
||||
verifier = KimiKeyVerifier()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 401
|
||||
|
||||
mock_client = create_mock_client(mock_response)
|
||||
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
status = await verifier.verify("sk-invalid-key")
|
||||
assert status == KeyStatus.INVALID
|
||||
|
||||
|
||||
class TestQwenKeyVerifier:
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_active_key_returns_active(self):
|
||||
verifier = QwenKeyVerifier()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
mock_client = create_mock_client(mock_response)
|
||||
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
status = await verifier.verify("sk-test-key-12345")
|
||||
assert status == KeyStatus.ACTIVE
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_invalid_key_returns_invalid(self):
|
||||
verifier = QwenKeyVerifier()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 401
|
||||
|
||||
mock_client = create_mock_client(mock_response)
|
||||
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
status = await verifier.verify("sk-invalid-key")
|
||||
assert status == KeyStatus.INVALID
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_forbidden_key_returns_expired(self):
|
||||
verifier = QwenKeyVerifier()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 403
|
||||
|
||||
mock_client = create_mock_client(mock_response)
|
||||
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
status = await verifier.verify("sk-expired-key")
|
||||
assert status == KeyStatus.EXPIRED
|
||||
|
||||
|
||||
class TestPerplexityKeyVerifier:
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_active_key_returns_active(self):
|
||||
verifier = PerplexityKeyVerifier()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
mock_client = create_mock_client(mock_response)
|
||||
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
status = await verifier.verify("sk-test-key-12345")
|
||||
assert status == KeyStatus.ACTIVE
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_invalid_key_returns_invalid(self):
|
||||
verifier = PerplexityKeyVerifier()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 401
|
||||
|
||||
mock_client = create_mock_client(mock_response)
|
||||
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
status = await verifier.verify("sk-invalid-key")
|
||||
assert status == KeyStatus.INVALID
|
||||
|
||||
|
||||
class TestGeminiKeyVerifier:
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_active_key_returns_active(self):
|
||||
verifier = GeminiKeyVerifier()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
mock_client = create_mock_client(mock_response)
|
||||
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
status = await verifier.verify("sk-test-key-12345")
|
||||
assert status == KeyStatus.ACTIVE
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_invalid_key_returns_invalid(self):
|
||||
verifier = GeminiKeyVerifier()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 403
|
||||
|
||||
mock_client = create_mock_client(mock_response)
|
||||
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
status = await verifier.verify("sk-invalid-key")
|
||||
assert status == KeyStatus.INVALID
|
||||
|
||||
|
||||
class TestOtherEngineVerifiers:
|
||||
@pytest.mark.asyncio
|
||||
async def test_wenxin_verifier_returns_unknown(self):
|
||||
verifier = WenxinKeyVerifier()
|
||||
status = await verifier.verify("test-key")
|
||||
assert status == KeyStatus.UNKNOWN
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_doubao_verifier_returns_unknown(self):
|
||||
verifier = DoubaoKeyVerifier()
|
||||
status = await verifier.verify("test-key")
|
||||
assert status == KeyStatus.UNKNOWN
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_yuanbao_verifier_returns_unknown(self):
|
||||
verifier = YuanbaoKeyVerifier()
|
||||
status = await verifier.verify("test-key")
|
||||
assert status == KeyStatus.UNKNOWN
|
||||
|
||||
|
||||
class TestKeyVerifierIntegrationWithAPIManager:
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_manager_verify_uses_factory(self):
|
||||
from app.services.api_key_manager import APIKeyManager
|
||||
|
||||
manager = APIKeyManager()
|
||||
|
||||
with patch.object(KeyVerifierFactory, 'verify', new_callable=AsyncMock) as mock_verify:
|
||||
mock_verify.return_value = KeyStatus.ACTIVE
|
||||
status = await manager.verify_key("chatgpt", "sk-test-key-1234567890")
|
||||
assert status == KeyStatus.ACTIVE
|
||||
mock_verify.assert_called_once_with("chatgpt", "sk-test-key-1234567890")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_manager_verify_handles_factory_error(self):
|
||||
from app.services.api_key_manager import APIKeyManager
|
||||
|
||||
manager = APIKeyManager()
|
||||
|
||||
with patch.object(KeyVerifierFactory, 'verify', new_callable=AsyncMock) as mock_verify:
|
||||
mock_verify.side_effect = Exception("Network error")
|
||||
status = await manager.verify_key("chatgpt", "sk-test-key-1234567890")
|
||||
assert status == KeyStatus.UNKNOWN
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_manager_verify_short_key_returns_invalid(self):
|
||||
from app.services.api_key_manager import APIKeyManager
|
||||
|
||||
manager = APIKeyManager()
|
||||
status = await manager.verify_key("chatgpt", "short")
|
||||
assert status == KeyStatus.INVALID
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_manager_verify_empty_key_returns_invalid(self):
|
||||
from app.services.api_key_manager import APIKeyManager
|
||||
|
||||
manager = APIKeyManager()
|
||||
status = await manager.verify_key("chatgpt", "")
|
||||
assert status == KeyStatus.INVALID
|
||||
|
|
@ -17,6 +17,9 @@ class _ConcreteAdapter(AIEngineAdapter):
|
|||
def get_engine_type(self):
|
||||
return EngineType.CHATGPT
|
||||
|
||||
def _get_env_key(self) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
class TestProxySupportBase:
|
||||
def test_base_init_with_proxy(self):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,165 @@
|
|||
import pytest
|
||||
|
||||
from app.services.api_key_manager import APIKeyManager, KeySource
|
||||
from app.services.smart_router import ENGINE_COST_PROFILES, CostTier, SmartRouter
|
||||
from app.services.engine_selector import EngineSelector
|
||||
|
||||
|
||||
class TestSmartRouterWithKeyManager:
|
||||
"""SmartRouter与APIKeyManager集成测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def key_manager(self):
|
||||
km = APIKeyManager()
|
||||
km.add_key("deepseek", "test-deepseek-key", source=KeySource.SYSTEM, priority=1)
|
||||
km.add_key("qwen", "test-qwen-key", source=KeySource.SYSTEM, priority=1)
|
||||
km.add_key("gemini", "test-gemini-key", source=KeySource.ENV, priority=0)
|
||||
return km
|
||||
|
||||
@pytest.fixture
|
||||
def key_manager_with_user_key(self):
|
||||
km = APIKeyManager()
|
||||
km.add_key("deepseek", "system-deepseek-key", source=KeySource.SYSTEM, priority=0)
|
||||
km.add_key("deepseek", "user-deepseek-key", source=KeySource.USER, priority=10, user_id="user_123")
|
||||
km.add_key("qwen", "system-qwen-key", source=KeySource.SYSTEM, priority=1)
|
||||
return km
|
||||
|
||||
@pytest.fixture
|
||||
def key_manager_no_keys(self):
|
||||
return APIKeyManager()
|
||||
|
||||
def test_router_accepts_key_manager(self, key_manager):
|
||||
router = SmartRouter(key_manager=key_manager)
|
||||
assert router._key_manager is key_manager
|
||||
|
||||
def test_set_key_manager_after_init(self):
|
||||
router = SmartRouter()
|
||||
assert router._key_manager is None
|
||||
key_manager = APIKeyManager()
|
||||
router.set_key_manager(key_manager)
|
||||
assert router._key_manager is key_manager
|
||||
|
||||
def test_filter_by_available_keys(self, key_manager):
|
||||
router = SmartRouter(key_manager=key_manager)
|
||||
all_engines = list(ENGINE_COST_PROFILES.keys())
|
||||
filtered = router._filter_by_available_keys(all_engines)
|
||||
assert "deepseek" in filtered
|
||||
assert "qwen" in filtered
|
||||
assert "gemini" in filtered
|
||||
assert "chatgpt" not in filtered
|
||||
assert "perplexity" not in filtered
|
||||
assert "wenxin" not in filtered
|
||||
|
||||
def test_filter_by_available_keys_no_manager(self):
|
||||
router = SmartRouter()
|
||||
all_engines = list(ENGINE_COST_PROFILES.keys())
|
||||
filtered = router._filter_by_available_keys(all_engines)
|
||||
assert set(filtered) == set(all_engines)
|
||||
|
||||
def test_select_engines_filters_no_key(self, key_manager):
|
||||
router = SmartRouter(key_manager=key_manager)
|
||||
selected = router.select_engines(max_engines=10)
|
||||
for engine in selected:
|
||||
key = key_manager.get_any_available_key(engine)
|
||||
assert key is not None, f"Selected engine {engine} has no available key"
|
||||
|
||||
def test_select_engines_returns_subset(self, key_manager):
|
||||
router = SmartRouter(key_manager=key_manager)
|
||||
selected = router.select_engines(max_engines=2)
|
||||
assert len(selected) <= 2
|
||||
assert len(selected) > 0
|
||||
|
||||
def test_select_engines_user_key_priority(self, key_manager_with_user_key):
|
||||
router = SmartRouter(key_manager=key_manager_with_user_key)
|
||||
selected = router.select_engines(max_engines=10)
|
||||
assert "deepseek" in selected
|
||||
key = key_manager_with_user_key.get_any_available_key("deepseek")
|
||||
assert key == "user-deepseek-key"
|
||||
|
||||
def test_get_available_engines(self, key_manager):
|
||||
router = SmartRouter(key_manager=key_manager)
|
||||
available = router.get_available_engines()
|
||||
for engine in available:
|
||||
key = key_manager.get_any_available_key(engine)
|
||||
assert key is not None
|
||||
|
||||
def test_get_engines_by_cost_tier_with_filter(self, key_manager):
|
||||
router = SmartRouter(key_manager=key_manager)
|
||||
free_engines = router.get_engines_by_cost_tier(CostTier.FREE)
|
||||
for engine in free_engines:
|
||||
profile = ENGINE_COST_PROFILES[engine]
|
||||
assert profile.cost_tier == CostTier.FREE
|
||||
key = key_manager.get_any_available_key(engine)
|
||||
assert key is not None
|
||||
|
||||
def test_all_engines_no_keys_returns_empty(self, key_manager_no_keys):
|
||||
router = SmartRouter(key_manager=key_manager_no_keys)
|
||||
selected = router.select_engines(max_engines=5)
|
||||
assert selected == [] or all(
|
||||
ENGINE_COST_PROFILES[e].cost_tier in [CostTier.FREE, CostTier.LOW_COST]
|
||||
and not ENGINE_COST_PROFILES[e].requires_own_key
|
||||
for e in selected
|
||||
)
|
||||
|
||||
|
||||
class TestEngineSelector:
|
||||
"""EngineSelector智能引擎选择器测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def key_manager(self):
|
||||
km = APIKeyManager()
|
||||
km.add_key("deepseek", "test-deepseek-key", source=KeySource.SYSTEM)
|
||||
km.add_key("qwen", "test-qwen-key", source=KeySource.SYSTEM)
|
||||
km.add_key("gemini", "test-gemini-key", source=KeySource.ENV)
|
||||
km.add_key("wenxin", "test-wenxin-key", source=KeySource.SYSTEM)
|
||||
return km
|
||||
|
||||
@pytest.fixture
|
||||
def selector(self, key_manager):
|
||||
return EngineSelector(key_manager=key_manager)
|
||||
|
||||
def test_initialization(self, selector, key_manager):
|
||||
assert selector.key_manager is key_manager
|
||||
assert selector.router is not None
|
||||
assert selector.router._key_manager is key_manager
|
||||
|
||||
def test_select_engines_returns_valid_engines(self, selector, key_manager):
|
||||
selected = selector.select_engines(max_engines=5)
|
||||
for engine in selected:
|
||||
key = key_manager.get_any_available_key(engine)
|
||||
assert key is not None, f"Engine {engine} selected but has no key"
|
||||
|
||||
def test_select_engines_respects_max_engines(self, selector):
|
||||
selected = selector.select_engines(max_engines=2)
|
||||
assert len(selected) <= 2
|
||||
|
||||
def test_select_engines_with_min_cost_tier(self, selector):
|
||||
selected = selector.select_engines(max_engines=10, min_cost_tier="free")
|
||||
for engine in selected:
|
||||
profile = ENGINE_COST_PROFILES[engine]
|
||||
assert profile.cost_tier in [CostTier.FREE, CostTier.LOW_COST, CostTier.MID_COST, CostTier.HIGH_COST]
|
||||
tier_order = ["free", "low_cost", "mid_cost", "high_cost"]
|
||||
assert tier_order.index(profile.cost_tier.value) >= tier_order.index("free")
|
||||
|
||||
def test_get_best_value_engine(self, selector, key_manager):
|
||||
best = selector.get_best_value_engine()
|
||||
if best:
|
||||
profile = ENGINE_COST_PROFILES[best]
|
||||
key = key_manager.get_any_available_key(best)
|
||||
assert key is not None
|
||||
assert profile.cost_tier in [CostTier.FREE, CostTier.LOW_COST]
|
||||
|
||||
def test_get_best_value_returns_none_when_no_keys(self):
|
||||
km = APIKeyManager()
|
||||
selector = EngineSelector(key_manager=km)
|
||||
best = selector.get_best_value_engine()
|
||||
assert best is None
|
||||
|
||||
def test_select_engines_priority_domestic(self, selector):
|
||||
selected = selector.select_engines(max_engines=10, prefer_domestic=True)
|
||||
domestic = [e for e in selected if ENGINE_COST_PROFILES[e].domestic]
|
||||
international = [e for e in selected if not ENGINE_COST_PROFILES[e].domestic]
|
||||
if domestic and international:
|
||||
last_domestic_idx = max(selected.index(e) for e in domestic)
|
||||
first_intl_idx = min(selected.index(e) for e in international)
|
||||
assert last_domestic_idx < first_intl_idx
|
||||
|
|
@ -0,0 +1,320 @@
|
|||
import pytest
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType
|
||||
from app.services.usage_tracker import UsageTracker
|
||||
|
||||
|
||||
def _make_result_with_tokens(
|
||||
engine_type: EngineType,
|
||||
query: str = "test query",
|
||||
input_tokens: int = 100,
|
||||
output_tokens: int = 200,
|
||||
has_brand: bool = False,
|
||||
) -> AIQueryResult:
|
||||
return AIQueryResult(
|
||||
engine_type=engine_type,
|
||||
query=query,
|
||||
raw_response="some response with content",
|
||||
citations=[],
|
||||
has_brand_citation=has_brand,
|
||||
has_competitor_citation=False,
|
||||
brand_context="brand context" if has_brand else None,
|
||||
competitor_contexts=[],
|
||||
response_time_ms=100,
|
||||
timestamp=datetime.now(UTC),
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
)
|
||||
|
||||
|
||||
class _MockAdapter(AIEngineAdapter):
|
||||
def __init__(self, engine_type: EngineType, result: AIQueryResult | None = None):
|
||||
super().__init__(api_key="test-key")
|
||||
self._engine_type = engine_type
|
||||
self._result = result
|
||||
|
||||
async def query(self, query: str, brand_name: str, competitor_names: list[str] | None = None) -> AIQueryResult:
|
||||
return self._result
|
||||
|
||||
def get_engine_type(self) -> EngineType:
|
||||
return self._engine_type
|
||||
|
||||
def _get_env_key(self) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
class TestUsageRecordingOnQuery:
|
||||
"""测试查询后自动记录用量"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_single_records_usage(self):
|
||||
from app.services.ai_engine.batch_query import BatchQueryService
|
||||
from app.services.usage_recorder import UsageRecorder
|
||||
|
||||
result = _make_result_with_tokens(EngineType.CHATGPT, input_tokens=100, output_tokens=200)
|
||||
adapters = {"chatgpt": _MockAdapter(EngineType.CHATGPT, result=result)}
|
||||
service = BatchQueryService(adapters)
|
||||
service.set_user_context("user123", "brand456")
|
||||
|
||||
await service.query_single(EngineType.CHATGPT, "test query", "BrandX")
|
||||
|
||||
summary = service.get_usage_summary()
|
||||
assert summary.total_queries == 1
|
||||
assert summary.total_input_tokens == 100
|
||||
assert summary.total_output_tokens == 200
|
||||
assert summary.total_cost > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_query_records_usage_for_all(self):
|
||||
from app.services.ai_engine.batch_query import BatchQueryService
|
||||
|
||||
r1 = _make_result_with_tokens(EngineType.CHATGPT, input_tokens=100, output_tokens=200)
|
||||
r2 = _make_result_with_tokens(EngineType.DEEPSEEK, input_tokens=150, output_tokens=300)
|
||||
adapters = {
|
||||
"chatgpt": _MockAdapter(EngineType.CHATGPT, result=r1),
|
||||
"deepseek": _MockAdapter(EngineType.DEEPSEEK, result=r2),
|
||||
}
|
||||
service = BatchQueryService(adapters)
|
||||
service.set_user_context("user123")
|
||||
|
||||
await service.query_batch(
|
||||
[EngineType.CHATGPT, EngineType.DEEPSEEK],
|
||||
"test query",
|
||||
"BrandX",
|
||||
)
|
||||
|
||||
summary = service.get_usage_summary()
|
||||
assert summary.total_queries == 2
|
||||
assert summary.total_input_tokens == 250
|
||||
assert summary.total_output_tokens == 500
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_user_context_no_recording(self):
|
||||
from app.services.ai_engine.batch_query import BatchQueryService
|
||||
|
||||
result = _make_result_with_tokens(EngineType.CHATGPT, input_tokens=100, output_tokens=200)
|
||||
adapters = {"chatgpt": _MockAdapter(EngineType.CHATGPT, result=result)}
|
||||
service = BatchQueryService(adapters)
|
||||
|
||||
await service.query_single(EngineType.CHATGPT, "test query", "BrandX")
|
||||
|
||||
summary = service.get_usage_summary()
|
||||
assert summary.total_queries == 0
|
||||
|
||||
|
||||
class TestUsageRecorderCostCalculation:
|
||||
"""测试不同引擎记录正确的成本"""
|
||||
|
||||
def test_chatgpt_cost_calculation(self):
|
||||
from app.services.usage_recorder import UsageRecorder
|
||||
|
||||
tracker = UsageTracker()
|
||||
recorder = UsageRecorder(tracker)
|
||||
|
||||
cost = recorder.calculate_cost("chatgpt", 1000, 2000)
|
||||
expected_input = (1000 / 1_000_000) * 1.0
|
||||
expected_output = (2000 / 1_000_000) * 4.0
|
||||
expected = round(expected_input + expected_output, 6)
|
||||
assert cost == expected
|
||||
|
||||
def test_deepseek_cost_calculation(self):
|
||||
from app.services.usage_recorder import UsageRecorder
|
||||
|
||||
tracker = UsageTracker()
|
||||
recorder = UsageRecorder(tracker)
|
||||
|
||||
cost = recorder.calculate_cost("deepseek", 1000, 2000)
|
||||
expected_input = (1000 / 1_000_000) * 0.25
|
||||
expected_output = (2000 / 1_000_000) * 6.0
|
||||
expected = round(expected_input + expected_output, 6)
|
||||
assert cost == expected
|
||||
|
||||
def test_kimi_cost_calculation(self):
|
||||
from app.services.usage_recorder import UsageRecorder
|
||||
|
||||
tracker = UsageTracker()
|
||||
recorder = UsageRecorder(tracker)
|
||||
|
||||
cost = recorder.calculate_cost("kimi", 1000, 2000)
|
||||
expected_input = (1000 / 1_000_000) * 12.0
|
||||
expected_output = (2000 / 1_000_000) * 12.0
|
||||
expected = round(expected_input + expected_output, 6)
|
||||
assert cost == expected
|
||||
|
||||
def test_unknown_engine_zero_cost(self):
|
||||
from app.services.usage_recorder import UsageRecorder
|
||||
|
||||
tracker = UsageTracker()
|
||||
recorder = UsageRecorder(tracker)
|
||||
|
||||
cost = recorder.calculate_cost("unknown_engine", 1000, 2000)
|
||||
assert cost == 0.0
|
||||
|
||||
def test_record_with_cost(self):
|
||||
from app.services.usage_recorder import UsageRecorder
|
||||
|
||||
tracker = UsageTracker()
|
||||
recorder = UsageRecorder(tracker)
|
||||
|
||||
recorder.record(
|
||||
user_id="user123",
|
||||
brand_id="brand456",
|
||||
engine_type="chatgpt",
|
||||
query="test query",
|
||||
input_tokens=1000,
|
||||
output_tokens=2000,
|
||||
)
|
||||
|
||||
summary = tracker.get_summary(user_id="user123")
|
||||
assert summary.total_queries == 1
|
||||
assert summary.total_input_tokens == 1000
|
||||
assert summary.total_output_tokens == 2000
|
||||
assert summary.total_cost > 0
|
||||
|
||||
|
||||
class TestUsageSummaryCalculation:
|
||||
"""测试用量汇总计算正确"""
|
||||
|
||||
def test_summary_by_engine(self):
|
||||
from app.services.usage_recorder import UsageRecorder
|
||||
|
||||
tracker = UsageTracker()
|
||||
recorder = UsageRecorder(tracker)
|
||||
|
||||
recorder.record("user1", "brand1", "chatgpt", "q1", 100, 200)
|
||||
recorder.record("user1", "brand1", "chatgpt", "q2", 150, 250)
|
||||
recorder.record("user1", "brand1", "deepseek", "q3", 200, 300)
|
||||
|
||||
summary = tracker.get_summary(user_id="user1")
|
||||
assert summary.by_engine["chatgpt"]["queries"] == 2
|
||||
assert summary.by_engine["chatgpt"]["input_tokens"] == 250
|
||||
assert summary.by_engine["deepseek"]["queries"] == 1
|
||||
|
||||
def test_summary_by_day(self):
|
||||
from app.services.usage_recorder import UsageRecorder
|
||||
|
||||
tracker = UsageTracker()
|
||||
recorder = UsageRecorder(tracker)
|
||||
|
||||
recorder.record("user1", "brand1", "chatgpt", "q1", 100, 200)
|
||||
recorder.record("user1", "brand1", "chatgpt", "q2", 150, 250)
|
||||
|
||||
summary = tracker.get_summary(user_id="user1")
|
||||
today = datetime.now(UTC).strftime("%Y-%m-%d")
|
||||
assert today in summary.by_day
|
||||
assert summary.by_day[today]["queries"] == 2
|
||||
|
||||
def test_summary_filter_by_brand(self):
|
||||
from app.services.usage_recorder import UsageRecorder
|
||||
|
||||
tracker = UsageTracker()
|
||||
recorder = UsageRecorder(tracker)
|
||||
|
||||
recorder.record("user1", "brand1", "chatgpt", "q1", 100, 200)
|
||||
recorder.record("user1", "brand2", "chatgpt", "q2", 150, 250)
|
||||
|
||||
summary = tracker.get_summary(user_id="user1", brand_id="brand1")
|
||||
assert summary.total_queries == 1
|
||||
assert summary.total_input_tokens == 100
|
||||
|
||||
|
||||
class TestQuotaCheck:
|
||||
"""测试配额检查触发"""
|
||||
|
||||
def test_quota_check_warning(self):
|
||||
from app.services.usage_recorder import UsageRecorder
|
||||
|
||||
tracker = UsageTracker()
|
||||
recorder = UsageRecorder(tracker)
|
||||
|
||||
cost = recorder.calculate_cost("chatgpt", 1000000, 2000000)
|
||||
recorder.record("user1", "", "chatgpt", "q1", 1000000, 2000000)
|
||||
quota = tracker.check_quota("user1", monthly_limit=cost / 2)
|
||||
|
||||
assert quota["status"] == "exceeded"
|
||||
assert quota["used"] > 0
|
||||
|
||||
def test_quota_check_ok(self):
|
||||
tracker = UsageTracker()
|
||||
quota = tracker.check_quota("new_user", monthly_limit=100.0)
|
||||
assert quota["status"] == "ok"
|
||||
assert quota["used"] == 0
|
||||
|
||||
|
||||
class TestBatchQueryServiceUsageIntegration:
|
||||
"""测试BatchQueryService与用量记录的集成"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_has_usage_recorder(self):
|
||||
from app.services.ai_engine.batch_query import BatchQueryService
|
||||
|
||||
adapters = {"chatgpt": _MockAdapter(EngineType.CHATGPT)}
|
||||
service = BatchQueryService(adapters)
|
||||
assert hasattr(service, "_recorder")
|
||||
assert hasattr(service, "_tracker")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_user_context_stores_context(self):
|
||||
from app.services.ai_engine.batch_query import BatchQueryService
|
||||
|
||||
adapters = {"chatgpt": _MockAdapter(EngineType.CHATGPT)}
|
||||
service = BatchQueryService(adapters)
|
||||
service.set_user_context("user123", "brand456")
|
||||
|
||||
assert service._user_id == "user123"
|
||||
assert service._brand_id == "brand456"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_usage_summary_returns_summary(self):
|
||||
from app.services.ai_engine.batch_query import BatchQueryService
|
||||
|
||||
adapters = {"chatgpt": _MockAdapter(EngineType.CHATGPT)}
|
||||
service = BatchQueryService(adapters)
|
||||
|
||||
summary = service.get_usage_summary()
|
||||
assert hasattr(summary, "total_queries")
|
||||
assert hasattr(summary, "total_input_tokens")
|
||||
assert hasattr(summary, "total_output_tokens")
|
||||
assert hasattr(summary, "total_cost")
|
||||
|
||||
|
||||
class TestAIQueryResultTokens:
|
||||
"""测试AIQueryResult包含token统计字段"""
|
||||
|
||||
def test_result_has_token_fields(self):
|
||||
result = AIQueryResult(
|
||||
engine_type=EngineType.CHATGPT,
|
||||
query="test",
|
||||
raw_response="response",
|
||||
citations=[],
|
||||
has_brand_citation=False,
|
||||
has_competitor_citation=False,
|
||||
brand_context=None,
|
||||
competitor_contexts=[],
|
||||
response_time_ms=100,
|
||||
timestamp=datetime.now(UTC),
|
||||
input_tokens=100,
|
||||
output_tokens=200,
|
||||
)
|
||||
assert result.input_tokens == 100
|
||||
assert result.output_tokens == 200
|
||||
assert result.total_tokens == 300
|
||||
|
||||
def test_result_token_fields_default_to_zero(self):
|
||||
result = AIQueryResult(
|
||||
engine_type=EngineType.CHATGPT,
|
||||
query="test",
|
||||
raw_response="response",
|
||||
citations=[],
|
||||
has_brand_citation=False,
|
||||
has_competitor_citation=False,
|
||||
brand_context=None,
|
||||
competitor_contexts=[],
|
||||
response_time_ms=100,
|
||||
timestamp=datetime.now(UTC),
|
||||
)
|
||||
assert result.input_tokens == 0
|
||||
assert result.output_tokens == 0
|
||||
assert result.total_tokens == 0
|
||||
|
|
@ -19,6 +19,7 @@ import {
|
|||
CheckCircle,
|
||||
Loader2,
|
||||
RefreshCw,
|
||||
AlertCircle,
|
||||
} from "lucide-react";
|
||||
import {
|
||||
LineChart,
|
||||
|
|
@ -43,6 +44,18 @@ const TIME_RANGE_LABELS: Record<TimeRange, string> = {
|
|||
month: "本月",
|
||||
};
|
||||
|
||||
const ENGINE_LABELS: Record<string, string> = {
|
||||
deepseek: "DeepSeek",
|
||||
chatgpt: "ChatGPT",
|
||||
qwen: "通义千问",
|
||||
gemini: "Google Gemini",
|
||||
kimi: "Kimi",
|
||||
wenxin: "文心一言",
|
||||
doubao: "豆包",
|
||||
yuanbao: "腾讯元宝",
|
||||
perplexity: "Perplexity",
|
||||
};
|
||||
|
||||
const PIE_COLORS = [
|
||||
"#3b82f6",
|
||||
"#8b5cf6",
|
||||
|
|
@ -55,6 +68,37 @@ const PIE_COLORS = [
|
|||
"#f97316",
|
||||
];
|
||||
|
||||
interface QuotaAPIResponse {
|
||||
used: number;
|
||||
limit: number;
|
||||
usage_percentage: number;
|
||||
status: "ok" | "warning" | "exceeded";
|
||||
}
|
||||
|
||||
interface SummaryAPIResponse {
|
||||
period: string;
|
||||
start_date: string;
|
||||
end_date: string;
|
||||
total_queries: number;
|
||||
total_input_tokens: number;
|
||||
total_output_tokens: number;
|
||||
total_cost: number;
|
||||
by_engine: Record<string, {
|
||||
queries: number;
|
||||
input_tokens: number;
|
||||
output_tokens: number;
|
||||
cost: number;
|
||||
}>;
|
||||
}
|
||||
|
||||
interface ByEngineAPIResponse {
|
||||
engines: Array<{
|
||||
type: string;
|
||||
queries: number;
|
||||
cost: number;
|
||||
}>;
|
||||
}
|
||||
|
||||
interface QuotaData {
|
||||
used: number;
|
||||
limit: number;
|
||||
|
|
@ -83,37 +127,6 @@ interface UsageData {
|
|||
engineUsage: EngineUsageItem[];
|
||||
}
|
||||
|
||||
const MOCK_USAGE_DATA: UsageData = {
|
||||
quota: { used: 45.6, limit: 100, percentage: 45.6, status: "normal" },
|
||||
trends: [
|
||||
{ date: "05-19", cost: 5.2 },
|
||||
{ date: "05-20", cost: 6.8 },
|
||||
{ date: "05-21", cost: 4.5 },
|
||||
{ date: "05-22", cost: 7.3 },
|
||||
{ date: "05-23", cost: 8.1 },
|
||||
{ date: "05-24", cost: 6.9 },
|
||||
{ date: "05-25", cost: 7.0 },
|
||||
],
|
||||
engineDistribution: [
|
||||
{ engine: "deepseek", label: "DeepSeek", cost: 15.2 },
|
||||
{ engine: "chatgpt", label: "ChatGPT", cost: 12.8 },
|
||||
{ engine: "qwen", label: "通义千问", cost: 8.5 },
|
||||
{ engine: "gemini", label: "Google Gemini", cost: 5.3 },
|
||||
{ engine: "kimi", label: "Kimi", cost: 3.8 },
|
||||
],
|
||||
engineUsage: [
|
||||
{ engine: "deepseek", label: "DeepSeek", queries: 1250, inputTokens: 3200000, outputTokens: 1800000, cost: 15.2 },
|
||||
{ engine: "chatgpt", label: "ChatGPT", queries: 890, inputTokens: 2100000, outputTokens: 1500000, cost: 12.8 },
|
||||
{ engine: "qwen", label: "通义千问", queries: 720, inputTokens: 1800000, outputTokens: 900000, cost: 8.5 },
|
||||
{ engine: "gemini", label: "Google Gemini", queries: 450, inputTokens: 1100000, outputTokens: 600000, cost: 5.3 },
|
||||
{ engine: "kimi", label: "Kimi", queries: 320, inputTokens: 800000, outputTokens: 400000, cost: 3.8 },
|
||||
{ engine: "wenxin", label: "文心一言", queries: 280, inputTokens: 700000, outputTokens: 350000, cost: 0.02 },
|
||||
{ engine: "doubao", label: "豆包", queries: 210, inputTokens: 520000, outputTokens: 280000, cost: 0.5 },
|
||||
{ engine: "yuanbao", label: "腾讯元宝", queries: 150, inputTokens: 380000, outputTokens: 200000, cost: 0.7 },
|
||||
{ engine: "perplexity", label: "Perplexity", queries: 30, inputTokens: 75000, outputTokens: 40000, cost: 2.6 },
|
||||
],
|
||||
};
|
||||
|
||||
function formatTokenCount(count: number): string {
|
||||
if (count >= 1000000) return `${(count / 1000000).toFixed(1)}M`;
|
||||
if (count >= 1000) return `${(count / 1000).toFixed(0)}K`;
|
||||
|
|
@ -186,19 +199,79 @@ function QuotaStatusBadge({ status }: { status: "normal" | "warning" | "exceeded
|
|||
export default function UsagePage() {
|
||||
const [timeRange, setTimeRange] = useState<TimeRange>("7d");
|
||||
|
||||
const { data: usageData, isLoading, refresh } = useApi<UsageData>(
|
||||
const { data: quotaData, isLoading: isQuotaLoading, error: quotaError } = useApi<QuotaAPIResponse>("/api/v1/usage/quota");
|
||||
const { data: summaryData, isLoading: isSummaryLoading, error: summaryError, refresh: refreshSummary } = useApi<SummaryAPIResponse>(
|
||||
`/api/v1/usage/summary?period=${timeRange === "7d" ? "week" : timeRange === "30d" ? "month" : "month"}`
|
||||
);
|
||||
const { data: byEngineData, isLoading: isByEngineLoading, error: byEngineError } = useApi<ByEngineAPIResponse>("/api/v1/usage/by-engine");
|
||||
|
||||
const data = usageData || MOCK_USAGE_DATA;
|
||||
const isLoading = isQuotaLoading || isSummaryLoading || isByEngineLoading;
|
||||
const error = quotaError || summaryError || byEngineError;
|
||||
const refresh = refreshSummary;
|
||||
|
||||
const usageData = useMemo((): UsageData | null => {
|
||||
if (!quotaData || !summaryData || !byEngineData) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const quota: QuotaData = {
|
||||
used: quotaData.used,
|
||||
limit: quotaData.limit,
|
||||
percentage: quotaData.usage_percentage,
|
||||
status: quotaData.status === "ok" ? "normal" : quotaData.status,
|
||||
};
|
||||
|
||||
const trends: TrendItem[] = [];
|
||||
if (summaryData.start_date && summaryData.end_date) {
|
||||
const startDate = new Date(summaryData.start_date);
|
||||
const endDate = new Date(summaryData.end_date);
|
||||
const daysDiff = Math.ceil((endDate.getTime() - startDate.getTime()) / (1000 * 60 * 60 * 24));
|
||||
const avgDailyCost = daysDiff > 0 ? summaryData.total_cost / daysDiff : 0;
|
||||
|
||||
for (let i = 0; i < Math.min(daysDiff, 7); i++) {
|
||||
const date = new Date(startDate);
|
||||
date.setDate(date.getDate() + i);
|
||||
const monthDay = `${String(date.getMonth() + 1).padStart(2, "0")}-${String(date.getDate()).padStart(2, "0")}`;
|
||||
trends.push({ date: monthDay, cost: avgDailyCost });
|
||||
}
|
||||
}
|
||||
|
||||
const engineDistribution = byEngineData.engines?.map((e) => ({
|
||||
engine: e.type,
|
||||
label: ENGINE_LABELS[e.type] || e.type,
|
||||
cost: e.cost,
|
||||
})) || [];
|
||||
|
||||
const engineUsage: EngineUsageItem[] = Object.entries(summaryData.by_engine || {}).map(([engine, data]) => ({
|
||||
engine,
|
||||
label: ENGINE_LABELS[engine] || engine,
|
||||
queries: data.queries,
|
||||
inputTokens: data.input_tokens,
|
||||
outputTokens: data.output_tokens,
|
||||
cost: data.cost,
|
||||
}));
|
||||
|
||||
return {
|
||||
quota,
|
||||
trends,
|
||||
engineDistribution,
|
||||
engineUsage,
|
||||
};
|
||||
}, [quotaData, summaryData, byEngineData]);
|
||||
|
||||
const data = usageData;
|
||||
|
||||
const totalCost = useMemo(() => {
|
||||
if (!data) return 0;
|
||||
return data.engineUsage.reduce((sum, item) => sum + item.cost, 0);
|
||||
}, [data.engineUsage]);
|
||||
}, [data]);
|
||||
|
||||
const totalQueries = useMemo(() => {
|
||||
if (!data) return 0;
|
||||
return data.engineUsage.reduce((sum, item) => sum + item.queries, 0);
|
||||
}, [data.engineUsage]);
|
||||
}, [data]);
|
||||
|
||||
const hasData = data && data.engineUsage.length > 0;
|
||||
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
|
|
@ -230,6 +303,22 @@ export default function UsagePage() {
|
|||
<Loader2 className="h-8 w-8 animate-spin text-muted-foreground" />
|
||||
<p className="mt-4 text-sm text-muted-foreground">正在加载用量数据...</p>
|
||||
</div>
|
||||
) : error ? (
|
||||
<div className="flex flex-col items-center justify-center py-12">
|
||||
<AlertCircle className="h-8 w-8 text-destructive" />
|
||||
<p className="mt-4 text-sm text-muted-foreground">加载失败: {error.message}</p>
|
||||
<Button variant="outline" size="sm" onClick={refresh} className="mt-4">
|
||||
<RefreshCw className="mr-1 h-3 w-3" />
|
||||
重试
|
||||
</Button>
|
||||
</div>
|
||||
) : !hasData ? (
|
||||
<div className="flex flex-col items-center justify-center py-12">
|
||||
<div className="rounded-lg bg-muted p-4">
|
||||
<DollarSign className="h-8 w-8 text-muted-foreground" />
|
||||
</div>
|
||||
<p className="mt-4 text-sm text-muted-foreground">暂无用量数据,开始查询后会自动记录</p>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<div className="grid gap-4 md:grid-cols-3">
|
||||
|
|
|
|||
Loading…
Reference in New Issue