geo/backend/app/services/ai_engine/batch_query.py

56 lines
1.9 KiB
Python

import asyncio
import logging
from .base import AIEngineAdapter, AIQueryResult, EngineType
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class BatchQueryService:
    """Send one query to one or more AI engine adapters and summarize results.

    Adapters are looked up by ``EngineType.value``, so the ``adapters``
    mapping must be keyed by those values.
    """

    def __init__(self, adapters: dict[str, AIEngineAdapter]):
        """Store the adapter registry.

        Args:
            adapters: Mapping from engine identifier (the ``value`` of an
                ``EngineType`` member) to the adapter serving that engine.
                The mapping is kept by reference, not copied.
        """
        self.adapters = adapters

    async def query_single(
        self,
        engine_type: EngineType,
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> AIQueryResult:
        """Run ``query`` against a single engine.

        Args:
            engine_type: Engine to query; its ``value`` selects the adapter.
            query: Query text forwarded to the adapter.
            brand_name: Brand name forwarded to the adapter.
            competitor_names: Optional competitor names forwarded to the
                adapter unchanged (``None`` is passed through as ``None``).

        Returns:
            Whatever the adapter's ``query`` coroutine returns.

        Raises:
            ValueError: If no adapter is registered for ``engine_type``.
        """
        adapter = self.adapters.get(engine_type.value)
        # Compare against None explicitly: a registered adapter that happens
        # to be falsy (defines __len__/__bool__) must not be reported as
        # unknown.
        if adapter is None:
            raise ValueError(f"Unknown engine type: {engine_type}")
        return await adapter.query(query, brand_name, competitor_names)

    async def query_batch(
        self,
        engines: list[EngineType],
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> list[AIQueryResult]:
        """Query several engines concurrently and keep the successful results.

        This is best-effort: a failing engine never fails the batch. Each
        failure is logged as a warning together with the engine it belongs to.

        Args:
            engines: Engines to query. Each entry is queried once, so
                duplicates are queried more than once.
            query: Query text forwarded to every adapter.
            brand_name: Brand name forwarded to every adapter.
            competitor_names: Optional competitor names forwarded to every
                adapter.

        Returns:
            The ``AIQueryResult`` objects of the engines that succeeded, in
            the same relative order as ``engines``. Empty if every engine
            failed or ``engines`` is empty.
        """
        # Materialize once so every outcome can be paired with its engine,
        # even if the caller handed in a one-shot iterable.
        requested = list(engines)
        outcomes = await asyncio.gather(
            *(
                self.query_single(engine, query, brand_name, competitor_names)
                for engine in requested
            ),
            return_exceptions=True,
        )

        successful: list[AIQueryResult] = []
        for engine, outcome in zip(requested, outcomes):
            if isinstance(outcome, AIQueryResult):
                successful.append(outcome)
            elif isinstance(outcome, BaseException):
                # BaseException rather than Exception: a cancelled child is
                # returned by gather() as CancelledError, which does not
                # derive from Exception and would otherwise vanish silently.
                # %r keeps the exception type visible even when str(exc) is
                # empty (e.g. a bare TimeoutError()).
                logger.warning(
                    "Engine query failed for %s: %r", engine, outcome
                )
            else:
                # The adapter returned something that is not a result
                # object; drop it, but leave a trace.
                logger.warning(
                    "Engine query for %s returned unexpected %s; ignoring",
                    engine,
                    type(outcome).__name__,
                )
        return successful

    def calculate_citation_rate(
        self, results: list[AIQueryResult]
    ) -> dict[str, int | float]:
        """Summarize how often the brand and its competitors were cited.

        Args:
            results: Results to summarize. Each must expose truthy/falsy
                ``has_brand_citation`` and ``has_competitor_citation``
                attributes.

        Returns:
            A dict with:
              * ``total_engines``: ``len(results)``.
              * ``brand_citation_count`` / ``competitor_citation_count``:
                number of results with the respective flag set.
              * ``brand_citation_rate`` / ``competitor_citation_rate``:
                count divided by ``total_engines``, or ``0.0`` when
                ``results`` is empty.

        Note:
            The denominator is the number of results passed in. If the list
            comes from ``query_batch``, engines that failed are not counted.
        """
        total = len(results)
        brand_cited = sum(1 for r in results if r.has_brand_citation)
        competitor_cited = sum(1 for r in results if r.has_competitor_citation)
        # Rates are always floats; 0.0 (not int 0) keeps the value type
        # stable for serializers and type checkers on the empty case.
        return {
            "total_engines": total,
            "brand_citation_count": brand_cited,
            "brand_citation_rate": brand_cited / total if total else 0.0,
            "competitor_citation_count": competitor_cited,
            "competitor_citation_rate": (
                competitor_cited / total if total else 0.0
            ),
        }