162 lines
5.4 KiB
Python
162 lines
5.4 KiB
Python
import logging
|
|
from typing import TYPE_CHECKING
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|
from pydantic import BaseModel, Field
|
|
|
|
from app.api.deps import get_current_user, get_key_manager
|
|
from app.models.user import User
|
|
from app.services.ai_engine.base import AIQueryResult, EngineType
|
|
from app.services.ai_engine.batch_query import BatchQueryService, get_batch_service as _get_batch_service
|
|
from app.services.ai_engine.chatgpt import ChatGPTAdapter
|
|
from app.services.ai_engine.doubao import DoubaoAdapter
|
|
from app.services.ai_engine.kimi import KimiAdapter
|
|
from app.services.ai_engine.perplexity import PerplexityAdapter
|
|
from app.services.ai_engine.wenxin import WenxinAdapter
|
|
from app.services.ai_engine.yuanbao import YuanbaoAdapter
|
|
|
|
if TYPE_CHECKING:
|
|
from app.services.api_key_manager import APIKeyManager
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Router holding this module's AI-query endpoints (/query, /query-batch, /results).
# NOTE(review): presumably included into the app under a prefix elsewhere — confirm.
router = APIRouter()
|
|
|
|
|
|
class SingleQueryRequest(BaseModel):
    # Request body for POST /query: run one query against a single engine.

    # Engine identifier; converted with _parse_engine in the handler, so an
    # unknown name yields HTTP 400 there rather than a validation error here.
    engine: str
    # Query text forwarded to the engine (1-500 characters).
    query: str = Field(min_length=1, max_length=500)
    # Brand name forwarded to the query service (1-200 characters).
    brand_name: str = Field(min_length=1, max_length=200)
    # Optional competitor names forwarded as-is; no count or per-item length
    # limits are enforced on this field.
    competitor_names: list[str] | None = None
|
|
|
|
|
|
class BatchQueryRequest(BaseModel):
    # Request body for POST /query-batch: run the same query on several engines.

    # At least one engine name (min_length=1); each entry is converted with
    # _parse_engine in the handler.
    engines: list[str] = Field(min_length=1)
    # Query text forwarded to every engine (1-500 characters).
    query: str = Field(min_length=1, max_length=500)
    # Brand name forwarded to the query service (1-200 characters).
    brand_name: str = Field(min_length=1, max_length=200)
    # Optional competitor names forwarded as-is; no count or per-item length
    # limits are enforced on this field.
    competitor_names: list[str] | None = None
|
|
|
|
|
|
class QueryResultResponse(BaseModel):
    # API view of a single engine's result. _result_to_response builds it from an
    # AIQueryResult, copying fields unchanged except engine_type (enum -> .value).

    # EngineType value string, e.g. the same text accepted in request `engine` fields.
    engine_type: str
    # The query text that was sent.
    query: str
    # Response text as returned by the service for this engine.
    raw_response: str
    # Whether the service reported the brand as cited in the response.
    has_brand_citation: bool
    # Whether the service reported any competitor as cited in the response.
    has_competitor_citation: bool
    # Optional brand context; may be None (e.g. presumably when the brand is not cited — confirm).
    brand_context: str | None
    # Competitor context snippets reported by the service.
    competitor_contexts: list[str]
    # Response time, presumably in milliseconds per the field name.
    response_time_ms: int

    # Lets the model also be validated from attribute-bearing objects, not just dicts.
    model_config = {"from_attributes": True}
|
|
|
|
|
|
class CitationRateResponse(BaseModel):
    # Aggregate citation statistics for a batch. _execute_batch builds it by
    # unpacking the dict from BatchQueryService.calculate_citation_rate as
    # keyword arguments, so that dict must provide these field names.

    # Number of engines counted in the batch.
    total_engines: int
    # Number of engine results in which the brand was cited.
    brand_citation_count: int
    # NOTE(review): unclear from here whether rates are fractions (0-1) or
    # percentages — confirm against calculate_citation_rate.
    brand_citation_rate: float
    # Number of engine results in which a competitor was cited.
    competitor_citation_count: int
    competitor_citation_rate: float
|
|
|
|
|
|
class BatchQueryResponse(BaseModel):
    # Response for the batch endpoints (POST /query-batch and GET /results).

    # One converted entry per result returned by BatchQueryService.query_batch.
    results: list[QueryResultResponse]
    # Citation statistics computed over those same results.
    citation_rate: CitationRateResponse
|
|
|
|
|
|
def _result_to_response(r: AIQueryResult) -> QueryResultResponse:
    """Convert a service-level ``AIQueryResult`` into the API response model.

    Every field is copied verbatim except ``engine_type``: the result holds an
    ``EngineType`` enum member, and the response exposes its ``.value``.

    Args:
        r: Result object produced by the batch query service.

    Returns:
        The equivalent ``QueryResultResponse``.
    """
    engine_value = r.engine_type.value
    # Fields whose value is carried over unchanged, in the response model's order.
    passthrough_fields = (
        "query",
        "raw_response",
        "has_brand_citation",
        "has_competitor_citation",
        "brand_context",
        "competitor_contexts",
        "response_time_ms",
    )
    copied = {field: getattr(r, field) for field in passthrough_fields}
    return QueryResultResponse(engine_type=engine_value, **copied)
|
|
|
|
|
|
def _parse_engine(value: str) -> EngineType:
    """Convert a client-supplied engine name into an ``EngineType``.

    Args:
        value: Engine identifier, looked up via ``EngineType(value)``. No
            trimming or case folding is done here; callers strip if needed.

    Returns:
        The matching ``EngineType`` member.

    Raises:
        HTTPException: 400 Bad Request when ``value`` is not a known engine.
    """
    try:
        return EngineType(value)
    except ValueError:
        # The ValueError adds nothing beyond the client-facing message; drop the
        # implicit exception context so logs show one intentional 400 (ruff B904).
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unknown engine: {value}",
        ) from None
|
|
|
|
|
|
async def _execute_batch(
    service: BatchQueryService,
    engine_types: list[EngineType],
    query: str,
    brand_name: str,
    competitor_names: list[str] | None,
) -> BatchQueryResponse:
    """Run one query across several engines and package the batch response.

    Args:
        service: Batch query service; callers in this module set the user
            context on it before calling.
        engine_types: Engines to query.
        query: Query text sent to every engine.
        brand_name: Brand name forwarded to the service.
        competitor_names: Optional competitor names forwarded to the service.

    Returns:
        A ``BatchQueryResponse`` whose ``results`` are the converted entries from
        ``service.query_batch`` and whose ``citation_rate`` is built from the dict
        returned by ``service.calculate_citation_rate`` for those same results.
    """
    raw_results = await service.query_batch(
        engine_types, query, brand_name, competitor_names
    )
    rate_stats = service.calculate_citation_rate(raw_results)
    converted = list(map(_result_to_response, raw_results))
    return BatchQueryResponse(
        results=converted,
        citation_rate=CitationRateResponse(**rate_stats),
    )
|
|
|
|
|
|
@router.post("/query", response_model=QueryResultResponse)
async def query_single_engine(
    request: SingleQueryRequest,
    current_user: User = Depends(get_current_user),
    key_manager: "APIKeyManager | None" = Depends(get_key_manager),
):
    """Run one query against a single engine and return its result.

    Raises:
        HTTPException: 400 if ``request.engine`` is not a known engine, or if
            the service raises ``ValueError`` for this query (its message is
            returned as the detail).
    """
    # Validate client input before acquiring a service so bad names fail fast.
    engine_type = _parse_engine(request.engine)

    user_id = str(current_user.id)
    service = _get_batch_service(key_manager=key_manager, user_id=user_id)
    # Kept from the original flow: the user context is set explicitly even though
    # user_id is also passed to the factory.
    # NOTE(review): presumably needed if the factory returns a shared instance — confirm.
    service.set_user_context(user_id)

    try:
        result = await service.query_single(
            engine_type, request.query, request.brand_name, request.competitor_names
        )
    except ValueError as e:
        logger.warning(
            "Single-engine query failed (engine=%s): %s", engine_type.value, e
        )
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e
    return _result_to_response(result)
|
|
|
|
|
|
@router.post("/query-batch", response_model=BatchQueryResponse)
async def query_batch_engines(
    request: BatchQueryRequest,
    current_user: User = Depends(get_current_user),
    key_manager: "APIKeyManager | None" = Depends(get_key_manager),
):
    """Run the same query on several engines and aggregate citation statistics.

    Repeated engine names are collapsed, keeping the first occurrence's position.
    This ensures each engine is queried once and is not counted twice in the
    citation rate.

    Raises:
        HTTPException: 400 if any entry in ``request.engines`` is unknown.
    """
    # Parse (and reject unknown names) before acquiring a service; dict.fromkeys
    # de-duplicates while preserving the client's ordering.
    engine_types = list(dict.fromkeys(_parse_engine(e) for e in request.engines))

    user_id = str(current_user.id)
    service = _get_batch_service(key_manager=key_manager, user_id=user_id)
    service.set_user_context(user_id)

    return await _execute_batch(
        service, engine_types, request.query, request.brand_name, request.competitor_names
    )
|
|
|
|
|
|
@router.get("/results", response_model=BatchQueryResponse)
async def get_query_results(
    engines: str = Query(..., description="Comma-separated engine names"),
    query: str = Query(..., min_length=1, max_length=500),
    brand_name: str = Query(..., min_length=1, max_length=200),
    competitor_names: str | None = Query(None, description="Comma-separated competitor names"),
    current_user: User = Depends(get_current_user),
    key_manager: "APIKeyManager | None" = Depends(get_key_manager),
):
    """Query-string variant of POST /query-batch.

    ``engines`` and ``competitor_names`` are comma-separated. Items are stripped,
    and empty items are ignored. Repeated engine names are collapsed, keeping
    the first occurrence.

    Raises:
        HTTPException: 400 if ``engines`` names no engine at all (e.g. ``","``)
            or contains an unknown engine.
    """
    engine_names = [e.strip() for e in engines.split(",") if e.strip()]
    if not engine_names:
        # POST /query-batch forbids an empty list via Field(min_length=1); apply
        # the same rule here instead of running an empty batch.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="At least one engine must be specified",
        )
    engine_types = list(dict.fromkeys(_parse_engine(e) for e in engine_names))

    # Absent, empty, or all-blank input (e.g. ",") all mean "no competitors" -> None.
    comp_names = [
        n.strip() for n in (competitor_names or "").split(",") if n.strip()
    ] or None

    user_id = str(current_user.id)
    service = _get_batch_service(key_manager=key_manager, user_id=user_id)
    service.set_user_context(user_id)

    return await _execute_batch(
        service, engine_types, query, brand_name, comp_names
    )
|