geo/backend/app/services/ai_engine/batch_query.py

162 lines
5.2 KiB
Python

import asyncio
import logging
from functools import lru_cache
from typing import TYPE_CHECKING
from .base import AIEngineAdapter, AIQueryResult, EngineType
from app.services.usage_recorder import UsageRecorder
from app.services.usage_tracker import UsageTracker
if TYPE_CHECKING:
from app.services.api_key_manager import APIKeyManager
logger = logging.getLogger(__name__)

# Registry of adapter classes keyed by the EngineType each class reports.
# Filled in at import time by the register_adapter() calls in this module.
_ADAPTER_CLASSES: dict[EngineType, type[AIEngineAdapter]] = {}


def register_adapter(cls: type[AIEngineAdapter]) -> type[AIEngineAdapter]:
    """Record ``cls`` in ``_ADAPTER_CLASSES`` under the engine type it reports.

    The engine type is discovered by constructing ``cls`` with no arguments and
    calling ``get_engine_type()`` on that throwaway instance. If construction or
    the lookup raises, the class is NOT registered. The failure is logged as a
    warning that names the class, and it is swallowed so that one broken adapter
    does not stop the module from importing.

    NOTE(review): registration relies on a no-argument constructor. An adapter
    whose no-arg construction fails (e.g. because it needs credentials) is
    therefore also missing from the key-manager path, which only iterates
    ``_ADAPTER_CLASSES``. Confirm this is intended.

    Args:
        cls: The adapter class to register.

    Returns:
        ``cls`` unchanged, so this can also be used as a class decorator.
        Existing callers that ignore the return value are unaffected.
    """
    try:
        engine_type = cls().get_engine_type()
        # A later registration for the same engine type replaces the earlier one.
        _ADAPTER_CLASSES[engine_type] = cls
    except Exception as exc:
        logger.warning("Failed to register adapter %s: %s", cls.__name__, exc)
    return cls
from .chatgpt import ChatGPTAdapter
from .perplexity import PerplexityAdapter
from .kimi import KimiAdapter
from .wenxin import WenxinAdapter
from .doubao import DoubaoAdapter
from .yuanbao import YuanbaoAdapter
from .deepseek import DeepSeekAdapter
from .qwen import QwenAdapter
from .gemini import GeminiAdapter
# Register each built-in adapter in a fixed order. register_adapter() catches
# and logs per-adapter failures, so one bad adapter does not block the rest.
for _adapter_cls in (
    ChatGPTAdapter,
    PerplexityAdapter,
    KimiAdapter,
    WenxinAdapter,
    DoubaoAdapter,
    YuanbaoAdapter,
    DeepSeekAdapter,
    QwenAdapter,
    GeminiAdapter,
):
    register_adapter(_adapter_cls)
# Keep the module namespace identical to the explicit-call form.
del _adapter_cls
def get_batch_service(
    key_manager: "APIKeyManager | None" = None,
    user_id: str | None = None,
) -> "BatchQueryService":
    """Create a new BatchQueryService wired to the available engine adapters.

    Args:
        key_manager: When given, a fresh set of adapters is constructed with
            this key manager and ``user_id``. When ``None``, the cached
            default adapters from ``_build_adapters()`` are used.
        user_id: Forwarded to adapter constructors on the key-manager path
            only. It does NOT set the service's usage-recording context; call
            ``BatchQueryService.set_user_context`` for that.

    Returns:
        A new BatchQueryService instance.
    """
    if key_manager is not None:
        adapters = _build_adapters_with_key_manager(key_manager=key_manager, user_id=user_id)
    else:
        # _build_adapters() is lru_cached and returns the SAME dict object on
        # every call. Give each service its own shallow copy so that mutating
        # ``service.adapters`` cannot corrupt the shared cache. The adapter
        # instances themselves are still shared, which is the point of caching.
        adapters = dict(_build_adapters())
    return BatchQueryService(adapters)
def _instantiate_adapters(**adapter_kwargs) -> dict[str, AIEngineAdapter]:
    """Instantiate every registered adapter class with ``adapter_kwargs``.

    Returns a new dict keyed by ``EngineType.value``. An adapter whose
    constructor raises is left out of the result, and a warning that includes
    the error is logged, so one misconfigured engine does not disable the rest.
    """
    adapters: dict[str, AIEngineAdapter] = {}
    for engine_type, cls in _ADAPTER_CLASSES.items():
        try:
            adapters[engine_type.value] = cls(**adapter_kwargs)
        except Exception as exc:
            logger.warning("Failed to initialize %s adapter: %s", engine_type.value, exc)
    return adapters


@lru_cache(maxsize=1)
def _build_adapters() -> dict[str, AIEngineAdapter]:
    """Build the default (no key manager) adapters once and cache the result.

    The result, including a partial one when some adapters failed to
    initialize, is kept for the life of the process. Call
    ``_build_adapters.cache_clear()`` to force a rebuild. Every caller receives
    the same dict object, so callers must not mutate it.
    """
    return _instantiate_adapters()


def _build_adapters_with_key_manager(
    key_manager: "APIKeyManager | None" = None,
    user_id: str | None = None,
) -> dict[str, AIEngineAdapter]:
    """Build a fresh, uncached set of adapters bound to ``key_manager``/``user_id``.

    Both values are passed to every adapter constructor as keyword arguments.
    """
    return _instantiate_adapters(key_manager=key_manager, user_id=user_id)
class BatchQueryService:
    """Fan a query out to several AI engine adapters and record token usage.

    Usage is recorded through a UsageRecorder only after ``set_user_context``
    has been called with a non-empty user id.
    """

    def __init__(self, adapters: dict[str, AIEngineAdapter]):
        """
        Args:
            adapters: Adapter instances keyed by ``EngineType.value``.
        """
        self.adapters = adapters
        self._tracker = UsageTracker()
        self._recorder = UsageRecorder(self._tracker)
        self._user_id: str | None = None
        self._brand_id: str | None = None

    def set_user_context(self, user_id: str, brand_id: str | None = None) -> None:
        """Attribute the usage of subsequent queries to ``user_id`` / ``brand_id``."""
        self._user_id = user_id
        self._brand_id = brand_id

    def get_usage_summary(self):
        """Return the tracker's usage summary for the current user context."""
        return self._tracker.get_summary(user_id=self._user_id)

    async def query_single(
        self,
        engine_type: EngineType,
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> AIQueryResult:
        """Query one engine and, if a user context is set, record its token usage.

        Raises:
            ValueError: if no adapter is configured for ``engine_type``.

        Exceptions raised by the adapter or by the usage recorder propagate
        unchanged.
        """
        adapter = self.adapters.get(engine_type.value)
        if adapter is None:
            raise ValueError(f"Unknown engine type: {engine_type}")
        result = await adapter.query(query, brand_name, competitor_names)
        if self._user_id:
            self._recorder.record(
                user_id=self._user_id,
                brand_id=self._brand_id,
                engine_type=engine_type.value,
                query=query,
                input_tokens=result.input_tokens,
                output_tokens=result.output_tokens,
                metadata={
                    "brand_name": brand_name,
                    "response_time_ms": result.response_time_ms,
                },
            )
        return result

    async def query_batch(
        self,
        engines: list[EngineType],
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> list[AIQueryResult]:
        """Query all ``engines`` concurrently and return the successful results.

        Per-engine failures do not abort the batch. Each one is logged together
        with the engine it came from and left out of the result. Successful
        results keep the order of ``engines``.
        """
        tasks = [
            self.query_single(engine, query, brand_name, competitor_names)
            for engine in engines
        ]
        # return_exceptions=True: one failing engine must not cancel the others.
        outcomes = await asyncio.gather(*tasks, return_exceptions=True)
        successful: list[AIQueryResult] = []
        # gather() returns outcomes in task order, so they pair with ``engines``.
        for engine, outcome in zip(engines, outcomes, strict=True):
            if isinstance(outcome, AIQueryResult):
                successful.append(outcome)
            elif isinstance(outcome, BaseException):
                # BaseException rather than Exception: a cancelled child comes
                # back as CancelledError, which is not an Exception subclass and
                # was previously dropped silently. %r keeps the exception type
                # visible even when its message is empty.
                logger.warning("Engine query failed for %s: %r", engine.value, outcome)
            else:
                logger.warning(
                    "Engine %s returned unexpected result type %s",
                    engine.value,
                    type(outcome).__name__,
                )
        return successful

    def calculate_citation_rate(self, results: list[AIQueryResult]) -> dict:
        """Summarize how many results cite the brand and how many cite competitors.

        Rates are always floats in [0, 1]. An empty ``results`` yields 0.0
        rates, matching the type used for non-empty input.
        """
        total = len(results)
        brand_cited = sum(1 for r in results if r.has_brand_citation)
        competitor_cited = sum(1 for r in results if r.has_competitor_citation)
        return {
            "total_engines": total,
            "brand_citation_count": brand_cited,
            "brand_citation_rate": brand_cited / total if total > 0 else 0.0,
            "competitor_citation_count": competitor_cited,
            "competitor_citation_rate": competitor_cited / total if total > 0 else 0.0,
        }