178 lines
5.5 KiB
Python
178 lines
5.5 KiB
Python
import logging
|
|
from functools import lru_cache
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|
from pydantic import BaseModel, Field
|
|
|
|
from app.api.deps import get_current_user
|
|
from app.models.user import User
|
|
from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType
|
|
from app.services.ai_engine.batch_query import BatchQueryService
|
|
from app.services.ai_engine.chatgpt import ChatGPTAdapter
|
|
from app.services.ai_engine.doubao import DoubaoAdapter
|
|
from app.services.ai_engine.kimi import KimiAdapter
|
|
from app.services.ai_engine.perplexity import PerplexityAdapter
|
|
from app.services.ai_engine.wenxin import WenxinAdapter
|
|
from app.services.ai_engine.yuanbao import YuanbaoAdapter
|
|
|
|
# Module logger; used below to report adapters that fail to initialize.
logger: logging.Logger = logging.getLogger(__name__)

# Router collecting the AI-engine query endpoints defined in this module.
router: APIRouter = APIRouter()
|
|
|
|
|
|
class SingleQueryRequest(BaseModel):
    # Request body for POST /query: one prompt sent to one engine.
    # (Comments, not a docstring: pydantic would publish a docstring as the
    # OpenAPI schema description.)

    # Engine identifier as sent by the client; converted to EngineType by the
    # endpoint (unknown names -> 400), not validated by this model.
    engine: str
    # Prompt forwarded to the engine, 1-500 characters.
    query: str = Field(max_length=500, min_length=1)
    # Brand name forwarded to the service for citation detection, 1-200 characters.
    brand_name: str = Field(max_length=200, min_length=1)
    # Optional competitor names forwarded alongside the brand; no size bounds.
    competitor_names: list[str] | None = Field(default=None)
|
|
|
|
|
|
class BatchQueryRequest(BaseModel):
    # Request body for POST /query-batch: one prompt fanned out to several engines.
    # (Comments, not a docstring, to keep the OpenAPI schema unchanged.)

    # At least one engine identifier; each is converted via EngineType by the
    # endpoint. Duplicates are not rejected by this model.
    engines: list[str] = Field(min_length=1)
    # Prompt forwarded to every engine, 1-500 characters.
    query: str = Field(max_length=500, min_length=1)
    # Brand name forwarded to the service for citation detection, 1-200 characters.
    brand_name: str = Field(max_length=200, min_length=1)
    # Optional competitor names forwarded alongside the brand; no size bounds.
    competitor_names: list[str] | None = Field(default=None)
|
|
|
|
|
|
class QueryResultResponse(BaseModel):
    # API view of one engine's answer; _result_to_response fills it from an
    # AIQueryResult. Field order is the serialization order, so keep it stable.

    # The EngineType's .value for the engine that answered.
    engine_type: str
    # Prompt that was sent, and the engine's reply (AIQueryResult.raw_response).
    query: str
    raw_response: str
    # Citation flags for the brand and for any competitor.
    has_brand_citation: bool
    has_competitor_citation: bool
    # Context snippets reported by the service; brand_context is required but
    # may be null.
    brand_context: str | None
    competitor_contexts: list[str]
    # Response time reported by the service (milliseconds, per the field name).
    response_time_ms: int

    # Allows model_validate() from attribute-bearing objects;
    # _result_to_response currently builds the model field by field instead.
    model_config = {"from_attributes": True}
|
|
|
|
|
|
class CitationRateResponse(BaseModel):
    # Aggregate citation statistics for one batch. Built by unpacking the dict
    # from BatchQueryService.calculate_citation_rate, so that dict's keys must
    # match these field names exactly.

    # Number of engines the statistics cover (per the field name).
    total_engines: int
    # Brand citations: how many results cited the brand, and the rate
    # (presumably count / total_engines - confirm in calculate_citation_rate).
    brand_citation_count: int
    brand_citation_rate: float
    # Same pair for competitor citations.
    competitor_citation_count: int
    competitor_citation_rate: float
|
|
|
|
|
|
class BatchQueryResponse(BaseModel):
    # Response for POST /query-batch and GET /results: one entry per engine
    # queried, plus the aggregate citation statistics over those entries.
    results: list[QueryResultResponse]
    citation_rate: CitationRateResponse
|
|
|
|
|
|
# Registry of supported engines -> adapter class. _build_adapters instantiates
# these in this insertion order; an engine missing here can never be served
# even though _parse_engine accepts every EngineType member.
_ADAPTER_CLASSES: dict[EngineType, type[AIEngineAdapter]] = {
    # International engines
    EngineType.CHATGPT: ChatGPTAdapter,
    EngineType.PERPLEXITY: PerplexityAdapter,
    # Chinese-market engines
    EngineType.KIMI: KimiAdapter,
    EngineType.WENXIN: WenxinAdapter,
    EngineType.DOUBAO: DoubaoAdapter,
    EngineType.YUANBAO: YuanbaoAdapter,
}
|
|
|
|
|
|
@lru_cache(maxsize=1)
def _build_adapters() -> dict[str, AIEngineAdapter]:
    """Instantiate one adapter per registered engine, skipping any that fail.

    The function takes no arguments and is wrapped in ``lru_cache``, so the
    adapters are constructed on the first call only and the same mapping is
    returned to every later caller. An adapter whose constructor raises is
    logged, with its traceback, and left out of the mapping. Because of the
    cache it stays absent until the process restarts or
    ``_build_adapters.cache_clear()`` is called.

    Returns:
        Mapping of ``EngineType.value`` to the constructed adapter instance.
    """
    adapters: dict[str, AIEngineAdapter] = {}
    for engine_type, adapter_cls in _ADAPTER_CLASSES.items():
        try:
            adapters[engine_type.value] = adapter_cls()
        except Exception:
            # Deliberate best-effort: one misconfigured engine must not disable
            # the others. exc_info keeps the cause (e.g. a missing API key)
            # diagnosable instead of silently discarding it.
            logger.warning(
                "Failed to initialize %s adapter", engine_type.value, exc_info=True
            )
    return adapters
|
|
|
|
|
|
def get_batch_service() -> BatchQueryService:
    """Return a fresh BatchQueryService over the cached adapter mapping.

    A new service object is created per call, but the adapters it wraps come
    from ``_build_adapters()``, which is memoized, so they are shared.
    """
    shared_adapters = _build_adapters()
    return BatchQueryService(shared_adapters)
|
|
|
|
|
|
def _result_to_response(r: AIQueryResult) -> QueryResultResponse:
    """Map a service-level AIQueryResult onto the API response model.

    Every field is copied unchanged except ``engine_type``, which is exposed
    as the enum's ``.value``.
    """
    engine_name = r.engine_type.value
    return QueryResultResponse(
        engine_type=engine_name,
        query=r.query,
        raw_response=r.raw_response,
        response_time_ms=r.response_time_ms,
        # Brand citation details.
        has_brand_citation=r.has_brand_citation,
        brand_context=r.brand_context,
        # Competitor citation details.
        has_competitor_citation=r.has_competitor_citation,
        competitor_contexts=r.competitor_contexts,
    )
|
|
|
|
|
|
def _parse_engine(value: str) -> EngineType:
    """Convert a client-supplied engine name into an EngineType.

    Args:
        value: Engine identifier from the request. It is passed to
            ``EngineType(value)`` unchanged (no trimming is done here).

    Returns:
        The matching EngineType member.

    Raises:
        HTTPException: 400 Bad Request if ``value`` is not an EngineType value.
    """
    try:
        return EngineType(value)
    except ValueError:
        # ``from None``: an unknown name is an expected client error; chaining
        # the enum's internal ValueError only adds noise to tracebacks/logs.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unknown engine: {value}",
        ) from None
|
|
|
|
|
|
async def _execute_batch(
    service: BatchQueryService,
    engine_types: list[EngineType],
    query: str,
    brand_name: str,
    competitor_names: list[str] | None,
) -> BatchQueryResponse:
    """Run ``query`` on several engines and summarize the citations found.

    Shared by POST ``/query-batch`` and GET ``/results``.

    Args:
        service: Batch service wrapping the configured adapters.
        engine_types: Engines to query. Duplicates are dropped (first
            occurrence kept, order preserved), so an engine listed twice is
            not queried twice and does not appear twice in the results passed
            to ``calculate_citation_rate``.
        query: Prompt sent to every engine.
        brand_name: Brand name forwarded to the service.
        competitor_names: Optional competitor names forwarded to the service.

    Returns:
        The per-engine results plus the aggregate citation rate.

    Raises:
        HTTPException: 400 when no engine is supplied, or when the service
            raises ``ValueError`` (same mapping ``/query`` applies).
    """
    unique_engines = list(dict.fromkeys(engine_types))
    if not unique_engines:
        # POST bodies enforce min_length=1, but GET /results can arrive here
        # empty after splitting (e.g. engines=",,").
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="At least one engine is required",
        )
    try:
        results = await service.query_batch(
            unique_engines, query, brand_name, competitor_names
        )
    except ValueError as e:
        # Mirrors query_single_engine, which treats a service ValueError as a
        # client error. Whether query_batch raises it is not visible here.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e
    citation_rate = service.calculate_citation_rate(results)
    return BatchQueryResponse(
        results=[_result_to_response(r) for r in results],
        citation_rate=CitationRateResponse(**citation_rate),
    )
|
|
|
|
|
|
@router.post("/query", response_model=QueryResultResponse)
async def query_single_engine(
    request: SingleQueryRequest,
    current_user: User = Depends(get_current_user),
):
    # Run one prompt against a single engine and return its citation result.
    # (Comments rather than a docstring: FastAPI would publish a docstring as
    # the OpenAPI operation description.)
    # current_user is not read in the body; the dependency is there for
    # authentication (presumably get_current_user rejects anonymous callers).
    service = get_batch_service()
    engine_type = _parse_engine(request.engine)  # 400 on an unknown engine name
    try:
        result = await service.query_single(
            engine_type, request.query, request.brand_name, request.competitor_names
        )
    except ValueError as e:
        # Service-level validation failures are surfaced as client errors;
        # ``from e`` keeps the original cause in the traceback.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e
    return _result_to_response(result)
|
|
|
|
|
|
@router.post("/query-batch", response_model=BatchQueryResponse)
async def query_batch_engines(
    request: BatchQueryRequest,
    current_user: User = Depends(get_current_user),
):
    # Fan one prompt out to every engine in the request body and return the
    # per-engine results plus the aggregate citation rate.
    # (Comments, not a docstring, so the OpenAPI description is unchanged.)
    service = get_batch_service()
    # Each name is validated in order; the first unknown one raises a 400.
    parsed_engines = list(map(_parse_engine, request.engines))
    return await _execute_batch(
        service=service,
        engine_types=parsed_engines,
        query=request.query,
        brand_name=request.brand_name,
        competitor_names=request.competitor_names,
    )
|
|
|
|
|
|
def _split_csv(raw: str) -> list[str]:
    """Split a comma-separated string, trimming each item and dropping blanks."""
    return [item.strip() for item in raw.split(",") if item.strip()]


@router.get("/results", response_model=BatchQueryResponse)
async def get_query_results(
    engines: str = Query(..., description="Comma-separated engine names"),
    query: str = Query(..., min_length=1, max_length=500),
    brand_name: str = Query(..., min_length=1, max_length=200),
    competitor_names: str | None = Query(None, description="Comma-separated competitor names"),
    current_user: User = Depends(get_current_user),
):
    # Query-string variant of /query-batch: engines and competitors arrive as
    # comma-separated strings. (Comments, not a docstring, to keep OpenAPI as-is.)
    # NOTE(review): despite the path name, nothing stored is read here; every
    # call runs the engine queries live through _execute_batch.
    service = get_batch_service()
    engine_types = list(map(_parse_engine, _split_csv(engines)))
    comp_names = None
    if competitor_names:
        comp_names = _split_csv(competitor_names)
    return await _execute_batch(service, engine_types, query, brand_name, comp_names)
|