geo/backend/app/api/ai_engines.py

162 lines
5.4 KiB
Python

import logging
from typing import TYPE_CHECKING
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel, Field
from app.api.deps import get_current_user, get_key_manager
from app.models.user import User
from app.services.ai_engine.base import AIQueryResult, EngineType
from app.services.ai_engine.batch_query import BatchQueryService, get_batch_service as _get_batch_service
from app.services.ai_engine.chatgpt import ChatGPTAdapter
from app.services.ai_engine.doubao import DoubaoAdapter
from app.services.ai_engine.kimi import KimiAdapter
from app.services.ai_engine.perplexity import PerplexityAdapter
from app.services.ai_engine.wenxin import WenxinAdapter
from app.services.ai_engine.yuanbao import YuanbaoAdapter
if TYPE_CHECKING:
from app.services.api_key_manager import APIKeyManager
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Router on which the endpoint decorators in this module register their routes.
router = APIRouter()
class SingleQueryRequest(BaseModel):
    # Request body accepted by the POST /query endpoint (one engine per call).

    # Engine identifier; the handler converts it to an EngineType.
    engine: str
    # Query text forwarded to the engine service; required, 1-500 characters.
    query: str = Field(..., min_length=1, max_length=500)
    # Brand name forwarded to the engine service; required, 1-200 characters.
    brand_name: str = Field(..., min_length=1, max_length=200)
    # Optional competitor names; None when the client omits the field.
    competitor_names: list[str] | None = None
class BatchQueryRequest(BaseModel):
    # Request body accepted by the POST /query-batch endpoint.

    # Engine identifiers; at least one entry is required.
    engines: list[str] = Field(..., min_length=1)
    # Query text forwarded to the engine service; required, 1-500 characters.
    query: str = Field(..., min_length=1, max_length=500)
    # Brand name forwarded to the engine service; required, 1-200 characters.
    brand_name: str = Field(..., min_length=1, max_length=200)
    # Optional competitor names; None when the client omits the field.
    competitor_names: list[str] | None = None
class QueryResultResponse(BaseModel):
    # Per-engine result payload returned by the query endpoints.

    # Allows the model to be populated from an object's attributes,
    # not only from a mapping.
    model_config = {"from_attributes": True}

    # String value of the engine's EngineType.
    engine_type: str
    query: str
    raw_response: str
    has_brand_citation: bool
    has_competitor_citation: bool
    # Required field, but its value may be None.
    brand_context: str | None
    competitor_contexts: list[str]
    response_time_ms: int
class CitationRateResponse(BaseModel):
    # Citation summary for a batch of engine results.
    #
    # Instances are built by unpacking the mapping returned from
    # ``service.calculate_citation_rate(...)``, so the field names below must
    # match that mapping's keys.

    total_engines: int

    brand_citation_count: int
    brand_citation_rate: float

    competitor_citation_count: int
    competitor_citation_rate: float
class BatchQueryResponse(BaseModel):
    # Response body of the batch endpoints: the individual results plus a
    # citation summary computed over them.

    # One entry per result returned by the batch service.
    results: list[QueryResultResponse]
    # Citation summary computed from the same results.
    citation_rate: CitationRateResponse
def _result_to_response(r: AIQueryResult) -> QueryResultResponse:
    """Convert an engine-layer ``AIQueryResult`` into its API response model.

    Every field is copied across unchanged except ``engine_type``, which is
    replaced by its ``.value``.
    """
    payload = {
        "engine_type": r.engine_type.value,
        "query": r.query,
        "raw_response": r.raw_response,
        "has_brand_citation": r.has_brand_citation,
        "has_competitor_citation": r.has_competitor_citation,
        "brand_context": r.brand_context,
        "competitor_contexts": r.competitor_contexts,
        "response_time_ms": r.response_time_ms,
    }
    return QueryResultResponse(**payload)
def _parse_engine(value: str) -> EngineType:
    """Translate a client-supplied engine name into an ``EngineType``.

    Args:
        value: Engine identifier exactly as received from the request.

    Returns:
        The ``EngineType`` constructed from ``value``.

    Raises:
        HTTPException: 400 Bad Request when ``EngineType(value)`` raises
            ``ValueError``. The original error is attached as ``__cause__``.
    """
    try:
        return EngineType(value)
    except ValueError as exc:
        # Chain explicitly so the traceback shows this as a deliberate
        # translation rather than a second failure inside the handler.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unknown engine: {value}",
        ) from exc
async def _execute_batch(
    service: BatchQueryService,
    engine_types: list[EngineType],
    query: str,
    brand_name: str,
    competitor_names: list[str] | None,
) -> BatchQueryResponse:
    """Run a batch query through ``service`` and package the outcome.

    Awaits ``service.query_batch`` with the given arguments, asks the same
    service for the citation-rate summary of the results, and returns both
    wrapped in a ``BatchQueryResponse``.
    """
    raw_results = await service.query_batch(
        engine_types, query, brand_name, competitor_names
    )
    rate_fields = service.calculate_citation_rate(raw_results)

    converted = list(map(_result_to_response, raw_results))
    # The summary mapping is unpacked directly into the response model.
    summary = CitationRateResponse(**rate_fields)
    return BatchQueryResponse(results=converted, citation_rate=summary)
@router.post("/query", response_model=QueryResultResponse)
async def query_single_engine(
    request: SingleQueryRequest,
    current_user: User = Depends(get_current_user),
    key_manager: "APIKeyManager | None" = Depends(get_key_manager),
):
    """Run one query against a single AI engine for the current user.

    Returns the engine result converted to ``QueryResultResponse``.

    Raises:
        HTTPException: 400 Bad Request when the engine name is unknown or
            when the service rejects the query with ``ValueError``.
    """
    user_id = str(current_user.id)
    service = _get_batch_service(key_manager=key_manager, user_id=user_id)
    service.set_user_context(user_id)

    engine_type = _parse_engine(request.engine)
    try:
        result = await service.query_single(
            engine_type, request.query, request.brand_name, request.competitor_names
        )
    except ValueError as e:
        # Surface service-level validation failures as a client error and keep
        # the original exception attached as the cause for debugging.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e
    return _result_to_response(result)
@router.post("/query-batch", response_model=BatchQueryResponse)
async def query_batch_engines(
    request: BatchQueryRequest,
    current_user: User = Depends(get_current_user),
    key_manager: "APIKeyManager | None" = Depends(get_key_manager),
):
    # Runs the same query against every requested engine for the current user
    # and returns the per-engine results together with the citation summary.
    batch_service = _get_batch_service(
        key_manager=key_manager,
        user_id=str(current_user.id),
    )
    batch_service.set_user_context(str(current_user.id))

    # Any unknown engine name aborts the request with a 400 from _parse_engine.
    parsed_engines = list(map(_parse_engine, request.engines))

    response = await _execute_batch(
        batch_service,
        parsed_engines,
        request.query,
        request.brand_name,
        request.competitor_names,
    )
    return response
@router.get("/results", response_model=BatchQueryResponse)
async def get_query_results(
    engines: str = Query(..., description="Comma-separated engine names"),
    query: str = Query(..., min_length=1, max_length=500),
    brand_name: str = Query(..., min_length=1, max_length=200),
    competitor_names: str | None = Query(None, description="Comma-separated competitor names"),
    current_user: User = Depends(get_current_user),
    key_manager: "APIKeyManager | None" = Depends(get_key_manager),
):
    """Run a batch query whose parameters arrive as query-string values.

    ``engines`` and ``competitor_names`` are comma-separated lists; entries
    are stripped of surrounding whitespace and blank entries are dropped.

    Raises:
        HTTPException: 400 Bad Request when ``engines`` contains no non-blank
            entry, or when any entry is not a known engine.
    """
    user_id = str(current_user.id)
    service = _get_batch_service(key_manager=key_manager, user_id=user_id)
    service.set_user_context(user_id)

    engine_list = [e.strip() for e in engines.split(",") if e.strip()]
    if not engine_list:
        # Inputs such as "," or " " satisfy the required-parameter check but
        # name no engine. Reject them here, matching the min_length=1 rule
        # that BatchQueryRequest applies to the POST endpoint.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="At least one engine is required",
        )
    engine_types = [_parse_engine(e) for e in engine_list]

    comp_names = (
        [n.strip() for n in competitor_names.split(",") if n.strip()]
        if competitor_names
        else None
    )
    return await _execute_batch(
        service, engine_types, query, brand_name, comp_names
    )