"""HTTP routes for querying AI engines and reporting brand/competitor citations."""

import logging
from typing import TYPE_CHECKING

from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel, Field

from app.api.deps import get_current_user, get_key_manager
from app.models.user import User
from app.services.ai_engine.base import AIQueryResult, EngineType
from app.services.ai_engine.batch_query import (
    BatchQueryService,
    get_batch_service as _get_batch_service,
)

# NOTE(review): the adapter classes below are not referenced anywhere in this
# module. They may be imported for a side effect (e.g. registration) -- confirm
# before removing them.
from app.services.ai_engine.chatgpt import ChatGPTAdapter  # noqa: F401
from app.services.ai_engine.doubao import DoubaoAdapter  # noqa: F401
from app.services.ai_engine.kimi import KimiAdapter  # noqa: F401
from app.services.ai_engine.perplexity import PerplexityAdapter  # noqa: F401
from app.services.ai_engine.wenxin import WenxinAdapter  # noqa: F401
from app.services.ai_engine.yuanbao import YuanbaoAdapter  # noqa: F401

if TYPE_CHECKING:
    # Only used inside string annotations ("APIKeyManager | None"), so it is
    # never imported at runtime.
    from app.services.api_key_manager import APIKeyManager

logger = logging.getLogger(__name__)

router = APIRouter()


class SingleQueryRequest(BaseModel):
    """Request body for the ``/query`` endpoint: one query sent to one engine."""

    engine: str  # must be an ``EngineType`` value; checked by ``_parse_engine``
    query: str = Field(min_length=1, max_length=500)
    brand_name: str = Field(min_length=1, max_length=200)
    competitor_names: list[str] | None = None


class BatchQueryRequest(BaseModel):
    """Request body for the ``/query-batch`` endpoint: one query sent to several engines."""

    engines: list[str] = Field(min_length=1)  # at least one engine name is required
    query: str = Field(min_length=1, max_length=500)
    brand_name: str = Field(min_length=1, max_length=200)
    competitor_names: list[str] | None = None


class QueryResultResponse(BaseModel):
    """Serialized form of a single ``AIQueryResult``."""

    engine_type: str  # the ``EngineType`` value string
    query: str
    raw_response: str
    has_brand_citation: bool
    has_competitor_citation: bool
    brand_context: str | None
    competitor_contexts: list[str]
    response_time_ms: int
    timestamp: str  # produced by ``timestamp.isoformat()`` in ``_result_to_response``

    # Pydantic v2: allow validation from arbitrary objects with matching attributes.
    model_config = {"from_attributes": True}


class CitationRateResponse(BaseModel):
    """Citation statistics across a batch of engine results.

    Built by unpacking the dict returned by
    ``BatchQueryService.calculate_citation_rate`` as keyword arguments, so that
    dict must supply every field below.
    """

    total_engines: int
    brand_citation_count: int
    # NOTE(review): the rate scale (fraction vs. percent) is defined by the
    # service, not here -- confirm before documenting it for clients.
    brand_citation_rate: float
    competitor_citation_count: int
    competitor_citation_rate: float


class BatchQueryResponse(BaseModel):
    """Per-engine results plus aggregate citation and timing statistics."""

    results: list[QueryResultResponse]
    citation_rate: CitationRateResponse
    avg_response_time_ms: float = 0.0  # mean ``response_time_ms``; 0.0 when there are no results


def _result_to_response(r: AIQueryResult) -> QueryResultResponse:
    """Convert a service-level ``AIQueryResult`` into its API response model."""
    return QueryResultResponse(
        engine_type=r.engine_type.value,
        query=r.query,
        raw_response=r.raw_response,
        has_brand_citation=r.has_brand_citation,
        has_competitor_citation=r.has_competitor_citation,
        brand_context=r.brand_context,
        competitor_contexts=r.competitor_contexts,
        response_time_ms=r.response_time_ms,
        timestamp=r.timestamp.isoformat(),
    )


def _service_for_user(
    current_user: User, key_manager: "APIKeyManager | None"
) -> BatchQueryService:
    """Build the batch query service scoped to ``current_user``.

    The user id (as a string) is passed both to the service factory and to
    ``set_user_context``.
    """
    user_id = str(current_user.id)
    service = _get_batch_service(key_manager=key_manager, user_id=user_id)
    service.set_user_context(user_id)
    return service


def _parse_engine(value: str) -> EngineType:
    """Convert a client-supplied engine name into an ``EngineType``.

    Raises:
        HTTPException: 400 if ``value`` is not a valid ``EngineType`` value.
    """
    try:
        return EngineType(value)
    except ValueError:
        # The HTTP detail already explains the failure; drop the enum's
        # ValueError from the traceback context.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unknown engine: {value}",
        ) from None


async def _execute_batch(
    service: BatchQueryService,
    engine_types: list[EngineType],
    query: str,
    brand_name: str,
    competitor_names: list[str] | None,
) -> BatchQueryResponse:
    """Run ``query`` against ``engine_types`` and aggregate the results.

    Returns the per-engine results, the service-computed citation statistics,
    and the mean response time (0.0 when the service returns no results).
    """
    results = await service.query_batch(
        engine_types, query, brand_name, competitor_names
    )
    citation_rate = service.calculate_citation_rate(results)
    response_results = [_result_to_response(r) for r in results]
    # Guard the division: the service may return an empty list.
    avg_response_time = (
        sum(r.response_time_ms for r in results) / len(results) if results else 0.0
    )
    return BatchQueryResponse(
        results=response_results,
        citation_rate=CitationRateResponse(**citation_rate),
        avg_response_time_ms=avg_response_time,
    )


@router.post("/query", response_model=QueryResultResponse)
async def query_single_engine(
    request: SingleQueryRequest,
    current_user: User = Depends(get_current_user),
    key_manager: "APIKeyManager | None" = Depends(get_key_manager),
):
    """Send one query to one engine and report brand/competitor citations.

    Responds with 400 if the engine name is unknown or the query is rejected.
    """
    # Validate client input before building the per-user service.
    engine_type = _parse_engine(request.engine)
    service = _service_for_user(current_user, key_manager)
    try:
        result = await service.query_single(
            engine_type, request.query, request.brand_name, request.competitor_names
        )
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e
    return _result_to_response(result)


@router.post("/query-batch", response_model=BatchQueryResponse)
async def query_batch_engines(
    request: BatchQueryRequest,
    current_user: User = Depends(get_current_user),
    key_manager: "APIKeyManager | None" = Depends(get_key_manager),
):
    """Send one query to several engines and report aggregate citation statistics.

    Responds with 400 if any engine name is unknown.
    """
    # Validate client input before building the per-user service.
    engine_types = [_parse_engine(e) for e in request.engines]
    service = _service_for_user(current_user, key_manager)
    return await _execute_batch(
        service, engine_types, request.query, request.brand_name, request.competitor_names
    )


@router.get("/results", response_model=BatchQueryResponse)
async def get_query_results(
    engines: str = Query(..., description="Comma-separated engine names"),
    query: str = Query(..., min_length=1, max_length=500),
    brand_name: str = Query(..., min_length=1, max_length=200),
    competitor_names: str | None = Query(None, description="Comma-separated competitor names"),
    current_user: User = Depends(get_current_user),
    key_manager: "APIKeyManager | None" = Depends(get_key_manager),
):
    """Query-string variant of ``/query-batch``.

    Responds with 400 if no engine names remain after splitting ``engines`` on
    commas, or if any engine name is unknown.
    """
    # Split on commas and drop blank entries (e.g. from "a,,b" or a trailing comma).
    engine_list = [e.strip() for e in engines.split(",") if e.strip()]
    if not engine_list:
        # Mirror BatchQueryRequest.engines (min_length=1): never run an empty batch.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="At least one engine is required",
        )
    engine_types = [_parse_engine(e) for e in engine_list]
    comp_names = (
        [n.strip() for n in competitor_names.split(",") if n.strip()]
        if competitor_names
        else None
    )
    service = _service_for_user(current_user, key_manager)
    return await _execute_batch(service, engine_types, query, brand_name, comp_names)