56 lines
1.9 KiB
Python
56 lines
1.9 KiB
Python
import asyncio
|
|
import logging
|
|
|
|
from .base import AIEngineAdapter, AIQueryResult, EngineType
|
|
|
|
# Module-level logger, named after this module via __name__.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BatchQueryService:
    """Send one query to several AI engine adapters and summarise brand citations.

    Adapters are looked up by ``engine_type.value``, so the keys of the
    ``adapters`` mapping are expected to be those string values.
    """

    def __init__(self, adapters: dict[str, AIEngineAdapter]):
        """Store the adapter registry used by the query methods.

        Args:
            adapters: Mapping from ``EngineType.value`` strings to adapters.
        """
        self.adapters = adapters

    async def query_single(
        self,
        engine_type: EngineType,
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> AIQueryResult:
        """Run ``query`` against the adapter registered for ``engine_type``.

        Args:
            engine_type: Engine to query; its ``.value`` selects the adapter.
            query: Query text forwarded to the adapter.
            brand_name: Brand name forwarded to the adapter.
            competitor_names: Optional competitor names forwarded unchanged.

        Returns:
            The result of awaiting ``adapter.query(...)``.

        Raises:
            ValueError: If no adapter is registered for ``engine_type.value``.
            Any exception raised by the adapter itself propagates unchanged.
        """
        adapter = self.adapters.get(engine_type.value)
        # Explicit None check: a registered adapter object that happens to be
        # falsy (e.g. defines __bool__/__len__) must not be reported as unknown.
        if adapter is None:
            raise ValueError(f"Unknown engine type: {engine_type}")
        return await adapter.query(query, brand_name, competitor_names)

    async def query_batch(
        self,
        engines: list[EngineType],
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> list[AIQueryResult]:
        """Query several engines concurrently and keep only the successful results.

        The batch is best-effort: an engine whose query raises (including an
        unknown engine type) is logged and skipped instead of failing the batch.

        Args:
            engines: Engines to query; each is passed to ``query_single``.
            query: Query text forwarded to every engine.
            brand_name: Brand name forwarded to every engine.
            competitor_names: Optional competitor names forwarded to every engine.

        Returns:
            Successful ``AIQueryResult`` objects, in the same relative order as
            ``engines``. The list is empty when every engine fails or ``engines``
            is empty.
        """
        # Materialize once so engines can be paired with their outcomes below.
        engine_list = list(engines)
        outcomes = await asyncio.gather(
            *(
                self.query_single(engine, query, brand_name, competitor_names)
                for engine in engine_list
            ),
            return_exceptions=True,
        )

        successful: list[AIQueryResult] = []
        # gather() returns outcomes in input order, so zip pairs each with its engine.
        for engine, outcome in zip(engine_list, outcomes):
            if isinstance(outcome, AIQueryResult):
                successful.append(outcome)
            elif isinstance(outcome, BaseException):
                # BaseException (not just Exception) so a child CancelledError is
                # logged too rather than silently discarded.
                logger.warning(
                    "Engine query failed for %s: %s",
                    engine,
                    outcome,
                    exc_info=outcome,
                )
            else:
                # An adapter returned something that is not an AIQueryResult;
                # skip it as before, but leave a trace.
                logger.warning(
                    "Engine %s returned unexpected result type %s; skipping",
                    engine,
                    type(outcome).__name__,
                )
        return successful

    def calculate_citation_rate(
        self, results: list[AIQueryResult]
    ) -> dict[str, int | float]:
        """Summarise how often the brand and its competitors were cited.

        Args:
            results: Query results; each must expose ``has_brand_citation`` and
                ``has_competitor_citation`` (read here as truthy flags).

        Returns:
            A dict with ``total_engines``, the brand/competitor citation counts,
            and the corresponding rates (count / total). Rates are ``0.0`` when
            ``results`` is empty, so they are always floats.
        """
        total = len(results)
        brand_cited = sum(1 for r in results if r.has_brand_citation)
        competitor_cited = sum(1 for r in results if r.has_competitor_citation)
        return {
            "total_engines": total,
            "brand_citation_count": brand_cited,
            "brand_citation_rate": brand_cited / total if total > 0 else 0.0,
            "competitor_citation_count": competitor_cited,
            "competitor_citation_rate": competitor_cited / total if total > 0 else 0.0,
        }
|