175 lines
5.8 KiB
Python
175 lines
5.8 KiB
Python
import asyncio
|
|
import logging
|
|
from functools import lru_cache
|
|
from typing import TYPE_CHECKING
|
|
|
|
import httpx
|
|
|
|
from .base import AIEngineAdapter, AIQueryResult, EngineType
|
|
from app.services.usage_recorder import UsageRecorder
|
|
from app.services.usage_tracker import UsageTracker
|
|
|
|
if TYPE_CHECKING:
|
|
from app.services.api_key_manager import APIKeyManager
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)

# Registry of adapter classes keyed by the engine type each one reports;
# populated by register_adapter() and read by the _build_adapters* helpers.
_ADAPTER_CLASSES: dict[EngineType, type[AIEngineAdapter]] = dict()
|
|
|
|
|
|
def register_adapter(cls: type[AIEngineAdapter]) -> None:
    """Register an adapter class under the engine type it reports.

    The class is instantiated once with no arguments so that
    ``get_engine_type()`` can be called on the instance; the returned engine
    type becomes the key in ``_ADAPTER_CLASSES``. If another class was already
    registered for the same engine type it is replaced (last registration
    wins), and the replacement is logged.

    Registration is best-effort: any exception raised while instantiating the
    class, querying its engine type, or storing it is logged together with the
    class name (so the failing adapter can be identified) and the class is
    skipped rather than breaking module import.

    Args:
        cls: Adapter class to register.
    """
    adapter_name = getattr(cls, "__name__", repr(cls))
    try:
        engine_type = cls().get_engine_type()
        previous = _ADAPTER_CLASSES.get(engine_type)
        if previous is not None and previous is not cls:
            # Keep the original "last one wins" behaviour, but don't let a
            # misconfigured engine type silently hide another adapter.
            logger.warning(
                "Adapter %s replaces %s for engine type %s",
                adapter_name,
                getattr(previous, "__name__", repr(previous)),
                engine_type,
            )
        _ADAPTER_CLASSES[engine_type] = cls
    except Exception as e:
        logger.warning("Failed to register adapter %s: %s", adapter_name, e)
|
|
|
|
|
|
from .chatgpt import ChatGPTAdapter
|
|
from .perplexity import PerplexityAdapter
|
|
from .kimi import KimiAdapter
|
|
from .wenxin import WenxinAdapter
|
|
from .doubao import DoubaoAdapter
|
|
from .yuanbao import YuanbaoAdapter
|
|
from .deepseek import DeepSeekAdapter
|
|
from .qwen import QwenAdapter
|
|
from .gemini import GeminiAdapter
|
|
|
|
# Register every built-in adapter, in this fixed order (later registrations
# for the same engine type replace earlier ones inside register_adapter).
for _adapter_cls in (
    ChatGPTAdapter,
    PerplexityAdapter,
    KimiAdapter,
    WenxinAdapter,
    DoubaoAdapter,
    YuanbaoAdapter,
    DeepSeekAdapter,
    QwenAdapter,
    GeminiAdapter,
):
    register_adapter(_adapter_cls)
# Don't leave the loop variable behind as a module attribute.
del _adapter_cls
|
|
|
|
|
|
def get_batch_service(
    key_manager: "APIKeyManager | None" = None,
    user_id: str | None = None,
) -> "BatchQueryService":
    """Create a BatchQueryService wired to every registered adapter.

    Args:
        key_manager: When not ``None``, adapters are freshly constructed via
            ``_build_adapters_with_key_manager`` with this key manager and
            ``user_id``. When ``None``, the cached default adapters from
            ``_build_adapters()`` are used and ``user_id`` is not consulted.
        user_id: Forwarded to the adapter constructors together with
            ``key_manager``. This function does not call
            ``set_user_context`` on the returned service.

    Returns:
        A new BatchQueryService instance.
    """
    # Explicit None check: a key manager object that happens to be falsy
    # (e.g. defines __len__/__bool__) must still be honoured.
    if key_manager is not None:
        adapters = _build_adapters_with_key_manager(key_manager=key_manager, user_id=user_id)
    else:
        # _build_adapters() is lru_cached and returns the same dict every
        # time; hand each service its own shallow copy so mutating
        # service.adapters cannot corrupt the shared cache. The adapter
        # instances themselves are still shared.
        adapters = dict(_build_adapters())
    return BatchQueryService(adapters)
|
|
|
|
|
|
@lru_cache(maxsize=1)
def _build_adapters() -> dict[str, AIEngineAdapter]:
    """Instantiate every registered adapter using its no-argument constructor.

    Because of ``lru_cache(maxsize=1)``, every call after the first returns the
    very same dict object (and the same adapter instances). An adapter whose
    constructor raises is logged and left out of the mapping; httpx and timeout
    errors are logged as warnings, anything else as an error with traceback.

    NOTE(review): the httpx/timeout branches only matter if an adapter
    constructor performs I/O — confirm whether any does.

    Returns:
        Mapping of ``EngineType.value`` to adapter instance.
    """
    built: dict[str, AIEngineAdapter] = {}
    for kind, adapter_cls in _ADAPTER_CLASSES.items():
        name = kind.value
        try:
            instance = adapter_cls()
        except httpx.HTTPError as exc:
            logger.warning(f"HTTP error from {name}: {exc}")
        except asyncio.TimeoutError as exc:
            logger.warning(f"Timeout from {name}: {exc}")
        except Exception as exc:
            logger.error(f"Unexpected error from {name}: {exc}", exc_info=True)
        else:
            built[name] = instance
    return built
|
|
|
|
|
|
def _build_adapters_with_key_manager(
    key_manager: "APIKeyManager | None" = None,
    user_id: str | None = None,
) -> dict[str, AIEngineAdapter]:
    """Instantiate every registered adapter bound to a key manager and user.

    This function carries no cache decorator, so each call constructs new
    adapter instances, passing ``key_manager`` and ``user_id`` as keyword
    arguments to each adapter class. An adapter whose constructor raises is
    logged and omitted; httpx and timeout errors are logged as warnings,
    anything else as an error with traceback.

    Args:
        key_manager: Passed through to each adapter constructor.
        user_id: Passed through to each adapter constructor.

    Returns:
        Mapping of ``EngineType.value`` to adapter instance.
    """
    ctor_kwargs = {"key_manager": key_manager, "user_id": user_id}
    built: dict[str, AIEngineAdapter] = {}
    for kind, adapter_cls in _ADAPTER_CLASSES.items():
        name = kind.value
        try:
            instance = adapter_cls(**ctor_kwargs)
        except httpx.HTTPError as exc:
            logger.warning(f"HTTP error from {name}: {exc}")
        except asyncio.TimeoutError as exc:
            logger.warning(f"Timeout from {name}: {exc}")
        except Exception as exc:
            logger.error(f"Unexpected error from {name}: {exc}", exc_info=True)
        else:
            built[name] = instance
    return built
|
|
|
|
|
|
class BatchQueryService:
    """Send a query to one or many AI engine adapters and record usage.

    Usage is recorded through a ``UsageRecorder`` backed by this service's own
    ``UsageTracker``, but only after a user context has been set with
    ``set_user_context``.
    """

    def __init__(self, adapters: dict[str, AIEngineAdapter]):
        """Create the service.

        Args:
            adapters: Mapping of ``EngineType.value`` to adapter instance.
                Stored as given (not copied).
        """
        self.adapters = adapters
        self._tracker = UsageTracker()
        self._recorder = UsageRecorder(self._tracker)
        # No user context yet: query_single skips usage recording while
        # _user_id is falsy.
        self._user_id: str | None = None
        self._brand_id: str | None = None

    def set_user_context(self, user_id: str, brand_id: str | None = None) -> None:
        """Set the user (and optional brand) that subsequent usage is recorded for.

        Args:
            user_id: User to attribute usage to.
            brand_id: Optional brand to attribute usage to.
        """
        self._user_id = user_id
        self._brand_id = brand_id

    def get_usage_summary(self):
        """Return ``UsageTracker.get_summary`` for the current user context.

        ``user_id`` is ``None`` when no context has been set.
        """
        return self._tracker.get_summary(user_id=self._user_id)

    async def query_single(
        self,
        engine_type: EngineType,
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> AIQueryResult:
        """Run ``query`` against a single engine and record usage if possible.

        Args:
            engine_type: Engine to query; looked up by ``engine_type.value``.
            query: Query text passed to the adapter.
            brand_name: Brand passed to the adapter and stored in usage metadata.
            competitor_names: Optional competitor names passed to the adapter.

        Returns:
            The adapter's ``AIQueryResult``.

        Raises:
            ValueError: If no adapter is available for ``engine_type``.
            Any exception raised by the adapter's ``query`` propagates.
            Failures while recording usage are logged, not raised.
        """
        adapter = self.adapters.get(engine_type.value)
        # Explicit None check: only a missing adapter is an error, not one
        # that merely evaluates as falsy.
        if adapter is None:
            raise ValueError(f"Unknown engine type: {engine_type}")

        result = await adapter.query(query, brand_name, competitor_names)

        if self._user_id:
            # Usage recording is best-effort; it must never fail the query.
            try:
                self._recorder.record(
                    user_id=self._user_id,
                    brand_id=self._brand_id,
                    engine_type=engine_type.value,
                    query=query,
                    input_tokens=result.input_tokens,
                    output_tokens=result.output_tokens,
                    metadata={
                        "brand_name": brand_name,
                        "response_time_ms": result.response_time_ms,
                    },
                )
            except Exception as e:
                logger.warning("Failed to record usage: %s", e)

        return result

    async def query_batch(
        self,
        engines: list[EngineType],
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> list[AIQueryResult]:
        """Query several engines concurrently and return the successful results.

        Each engine is queried via ``query_single``. A failing engine does not
        affect the others: its exception is logged, naming the engine, and the
        engine is left out of the returned list.

        Args:
            engines: Engines to query.
            query: Query text passed to every adapter.
            brand_name: Brand passed to every adapter.
            competitor_names: Optional competitor names passed to every adapter.

        Returns:
            Successful results, in the order of ``engines`` (failed engines are
            skipped).
        """
        # Materialize once: the engines are iterated twice (to build the
        # tasks and to attribute results), which would break on a generator.
        engine_list = list(engines)
        tasks = [
            self.query_single(engine, query, brand_name, competitor_names)
            for engine in engine_list
        ]
        # gather preserves input order, so results line up with engine_list.
        results = await asyncio.gather(*tasks, return_exceptions=True)

        successful: list[AIQueryResult] = []
        for engine, outcome in zip(engine_list, results):
            if isinstance(outcome, AIQueryResult):
                successful.append(outcome)
            elif isinstance(outcome, BaseException):
                # BaseException (not just Exception) so that a child task's
                # CancelledError, which gather returns as a result, is
                # reported instead of silently vanishing.
                logger.warning(
                    "Engine query failed for %s: %s",
                    getattr(engine, "value", engine),
                    outcome,
                )
        return successful

    def calculate_citation_rate(self, results: list[AIQueryResult]) -> dict:
        """Summarize how often the brand and its competitors were cited.

        Args:
            results: Query results to summarize.

        Returns:
            Dict with ``total_engines``, ``brand_citation_count``,
            ``brand_citation_rate``, ``competitor_citation_count`` and
            ``competitor_citation_rate``. Rates are 0 when ``results`` is empty.
        """
        total = len(results)
        brand_cited = sum(1 for r in results if r.has_brand_citation)
        competitor_cited = sum(1 for r in results if r.has_competitor_citation)
        return {
            "total_engines": total,
            "brand_citation_count": brand_cited,
            "brand_citation_rate": brand_cited / total if total > 0 else 0,
            "competitor_citation_count": competitor_cited,
            "competitor_citation_rate": competitor_cited / total if total > 0 else 0,
        }
|