"""Adapter registry and batch query service for the AI engine adapters."""

import asyncio
import logging
from functools import lru_cache
from typing import TYPE_CHECKING

import httpx

from .base import AIEngineAdapter, AIQueryResult, EngineType
from app.services.usage_recorder import UsageRecorder
from app.services.usage_tracker import UsageTracker

if TYPE_CHECKING:
    from app.services.api_key_manager import APIKeyManager

logger = logging.getLogger(__name__)

# Maps each engine type to the adapter class that serves it.
_ADAPTER_CLASSES: dict[EngineType, type[AIEngineAdapter]] = {}


def register_adapter(cls: type[AIEngineAdapter]) -> None:
    """Register *cls* in the module registry under its engine type.

    The engine type is discovered by building a throwaway instance with no
    arguments and calling ``get_engine_type()`` on it. Registration is
    best-effort: any failure is logged as a warning and the class is simply
    left out of the registry, so this function never raises.

    Args:
        cls: Adapter class to register. Must be constructible with no
            arguments for registration to succeed.
    """
    try:
        # The assignment stays inside the ``try`` so that an unusable
        # (e.g. unhashable) engine type is also reported instead of raised.
        _ADAPTER_CLASSES[cls().get_engine_type()] = cls
    except Exception as e:
        logger.warning(
            "Failed to register adapter %s: %s",
            getattr(cls, "__name__", cls),
            e,
        )


# NOTE(review): these imports sit below ``register_adapter`` in the original
# file; the order is preserved in case the adapter modules import from this
# module -- confirm before moving them to the top.
from .chatgpt import ChatGPTAdapter
from .perplexity import PerplexityAdapter
from .kimi import KimiAdapter
from .wenxin import WenxinAdapter
from .doubao import DoubaoAdapter
from .yuanbao import YuanbaoAdapter
from .deepseek import DeepSeekAdapter
from .qwen import QwenAdapter
from .gemini import GeminiAdapter

register_adapter(ChatGPTAdapter)
register_adapter(PerplexityAdapter)
register_adapter(KimiAdapter)
register_adapter(WenxinAdapter)
register_adapter(DoubaoAdapter)
register_adapter(YuanbaoAdapter)
register_adapter(DeepSeekAdapter)
register_adapter(QwenAdapter)
register_adapter(GeminiAdapter)


def get_batch_service(
    key_manager: "APIKeyManager | None" = None,
    user_id: str | None = None,
) -> "BatchQueryService":
    """Build a ``BatchQueryService`` over every registered adapter.

    Args:
        key_manager: When given, each adapter is constructed with this key
            manager and ``user_id``. When ``None``, the cached default
            adapters (constructed with no arguments) are used.
        user_id: Passed to the adapters only together with ``key_manager``.

    Returns:
        A service wrapping the adapters that could be constructed.
    """
    # Compare against None explicitly: a key manager that happens to be
    # falsy must not silently fall back to the default adapters.
    if key_manager is not None:
        adapters = _build_adapters_with_key_manager(
            key_manager=key_manager,
            user_id=user_id,
        )
    else:
        adapters = _build_adapters()
    return BatchQueryService(adapters)


def _instantiate_adapters(**adapter_kwargs: object) -> dict[str, AIEngineAdapter]:
    """Instantiate every registered adapter class with *adapter_kwargs*.

    A class whose constructor raises is logged and skipped, so the result
    may contain fewer entries than the registry.

    Returns:
        Adapters keyed by ``engine_type.value``.
    """
    adapters: dict[str, AIEngineAdapter] = {}
    for engine_type, cls in _ADAPTER_CLASSES.items():
        try:
            adapters[engine_type.value] = cls(**adapter_kwargs)
        except httpx.HTTPError as e:
            logger.warning("HTTP error from %s: %s", engine_type.value, e)
        except asyncio.TimeoutError as e:
            logger.warning("Timeout from %s: %s", engine_type.value, e)
        except Exception as e:
            logger.error(
                "Unexpected error from %s: %s",
                engine_type.value,
                e,
                exc_info=True,
            )
    return adapters


@lru_cache(maxsize=1)
def _build_adapters() -> dict[str, AIEngineAdapter]:
    """Return the default adapters, each constructed with no arguments.

    The result is cached, so every caller receives the *same* dict object
    and must not mutate it. Adapters that failed to construct on the first
    call stay absent until ``_build_adapters.cache_clear()`` is called.
    """
    return _instantiate_adapters()


def _build_adapters_with_key_manager(
    key_manager: "APIKeyManager | None" = None,
    user_id: str | None = None,
) -> dict[str, AIEngineAdapter]:
    """Return freshly built adapters bound to *key_manager* and *user_id*.

    Unlike ``_build_adapters`` this is not cached: every call constructs new
    adapter instances.
    """
    return _instantiate_adapters(key_manager=key_manager, user_id=user_id)


class BatchQueryService:
    """Run a query against one or several engine adapters and record usage."""

    def __init__(self, adapters: dict[str, AIEngineAdapter]):
        """Store *adapters* (keyed by engine type value) and set up usage tracking."""
        self.adapters = adapters
        self._tracker = UsageTracker()
        self._recorder = UsageRecorder(self._tracker)
        # Usage is only recorded once a user id has been set.
        self._user_id: str | None = None
        self._brand_id: str | None = None

    def set_user_context(self, user_id: str, brand_id: str | None = None) -> None:
        """Set the user (and optional brand) that later usage is recorded for."""
        self._user_id = user_id
        self._brand_id = brand_id

    def get_usage_summary(self):
        """Return the tracker's summary for the current user context."""
        return self._tracker.get_summary(user_id=self._user_id)

    async def query_single(
        self,
        engine_type: EngineType,
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> AIQueryResult:
        """Query one engine and, when a user context is set, record the usage.

        Args:
            engine_type: Engine to query; looked up by ``engine_type.value``.
            query: Query text handed to the adapter.
            brand_name: Brand name handed to the adapter.
            competitor_names: Optional competitor names handed to the adapter.

        Returns:
            The adapter's result.

        Raises:
            ValueError: If no adapter is available for ``engine_type``.
        """
        adapter = self.adapters.get(engine_type.value)
        if adapter is None:
            raise ValueError(f"Unknown engine type: {engine_type}")

        result = await adapter.query(query, brand_name, competitor_names)

        if self._user_id:
            # Usage recording is best-effort; it must never fail the query.
            try:
                self._recorder.record(
                    user_id=self._user_id,
                    brand_id=self._brand_id,
                    engine_type=engine_type.value,
                    query=query,
                    input_tokens=result.input_tokens,
                    output_tokens=result.output_tokens,
                    metadata={
                        "brand_name": brand_name,
                        "response_time_ms": result.response_time_ms,
                    },
                )
            except Exception as e:
                logger.warning("Failed to record usage: %s", e)
        return result

    async def query_batch(
        self,
        engines: list[EngineType],
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> list[AIQueryResult]:
        """Query several engines concurrently and return the successful results.

        Failures are logged as warnings and left out of the result, so the
        returned list may be shorter than ``engines``.

        Args:
            engines: Engines to query.
            query: Query text handed to each adapter.
            brand_name: Brand name handed to each adapter.
            competitor_names: Optional competitor names handed to each adapter.

        Returns:
            The ``AIQueryResult`` objects of the engines that succeeded, in
            the order of ``engines``.
        """
        # Materialise once so the engines can be paired with their outcomes.
        engines = list(engines)
        tasks = [
            self.query_single(engine, query, brand_name, competitor_names)
            for engine in engines
        ]
        # gather() returns outcomes in the order the awaitables were passed.
        outcomes = await asyncio.gather(*tasks, return_exceptions=True)

        successful: list[AIQueryResult] = []
        for engine, outcome in zip(engines, outcomes):
            engine_name = getattr(engine, "value", engine)
            if isinstance(outcome, AIQueryResult):
                successful.append(outcome)
            elif isinstance(outcome, BaseException):
                # BaseException rather than Exception: a cancelled child is
                # reported as asyncio.CancelledError, which is not an
                # Exception subclass and would otherwise vanish silently.
                logger.warning("Engine query failed (%s): %r", engine_name, outcome)
            else:
                logger.warning(
                    "Engine query for %s returned unexpected %s; skipping",
                    engine_name,
                    type(outcome).__name__,
                )
        return successful

    def calculate_citation_rate(
        self, results: list[AIQueryResult]
    ) -> dict[str, int | float]:
        """Summarise how many results cite the brand and its competitors.

        Args:
            results: Results to summarise.

        Returns:
            A dict with ``total_engines``, ``brand_citation_count``,
            ``brand_citation_rate``, ``competitor_citation_count`` and
            ``competitor_citation_rate``. Both rates are ``0`` when
            ``results`` is empty.
        """
        total = len(results)
        brand_cited = sum(1 for r in results if r.has_brand_citation)
        competitor_cited = sum(1 for r in results if r.has_competitor_citation)
        return {
            "total_engines": total,
            "brand_citation_count": brand_cited,
            "brand_citation_rate": brand_cited / total if total > 0 else 0,
            "competitor_citation_count": competitor_cited,
            "competitor_citation_rate": competitor_cited / total if total > 0 else 0,
        }