# geo/backend/app/services/ai_engine/base.py
#
# 148 lines
# 4.4 KiB
# Python

import asyncio
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
from typing import Any
import httpx
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Bound of the attempt loop in AIEngineAdapter._request_with_retry, i.e. the
# total number of attempts (the first request counts as one of them).
_MAX_RETRIES = 3
# HTTP status codes after which _request_with_retry tries the request again
# instead of raising immediately.
_RETRYABLE_STATUS = {429, 500, 502, 503}
class EngineType(str, Enum):
    """Identifiers of the AI engines an adapter can report.

    The ``str`` mix-in makes every member compare equal to its plain string
    value (``EngineType.KIMI == "kimi"``).
    """

    CHATGPT = "chatgpt"
    PERPLEXITY = "perplexity"
    KIMI = "kimi"
    WENXIN = "wenxin"
    DOUBAO = "doubao"
    DEEPSEEK = "deepseek"
    QWEN = "qwen"
@dataclass
class CitationInfo:
source_url: str | None
source_title: str | None
citation_context: str
confidence: float
position: int
@dataclass
class AIQueryResult:
    """Outcome of sending one query to one AI engine.

    Every field except ``metadata`` is required; ``metadata`` defaults to a
    fresh empty dict for each instance.
    """

    # Which engine produced this result.
    engine_type: EngineType
    # The query text that was sent.
    query: str
    # The engine's response text.
    raw_response: str
    # Citation records associated with the response.
    citations: list[CitationInfo]
    # Whether the brand was detected in the response.
    has_brand_citation: bool
    # Whether at least one competitor was detected in the response.
    has_competitor_citation: bool
    # Text surrounding the brand mention, or None.
    brand_context: str | None
    # Text surrounding each detected competitor mention.
    competitor_contexts: list[str]
    # Response time in milliseconds.
    response_time_ms: int
    # Timestamp associated with the result.
    timestamp: datetime
    # Free-form extra data; never shared between instances.
    metadata: dict[str, Any] = field(default_factory=dict)
class AIEngineAdapter(ABC):
def __init__(self, api_key: str, rate_limiter=None):
self.api_key = api_key
self.rate_limiter = rate_limiter
self._client: httpx.AsyncClient | None = None
@abstractmethod
async def query(
self,
query: str,
brand_name: str,
competitor_names: list[str] | None = None,
) -> AIQueryResult:
pass
@abstractmethod
def get_engine_type(self) -> EngineType:
pass
def _detect_citations(
self,
response: str,
brand_name: str,
competitor_names: list[str] | None,
) -> tuple[bool, bool, str | None, list[str]]:
has_brand = brand_name.lower() in response.lower()
brand_context = None
if has_brand:
idx = response.lower().find(brand_name.lower())
start = max(0, idx - 100)
end = min(len(response), idx + len(brand_name) + 100)
brand_context = response[start:end]
has_competitor = False
competitor_contexts = []
if competitor_names:
for name in competitor_names:
if name.lower() in response.lower():
has_competitor = True
idx = response.lower().find(name.lower())
start = max(0, idx - 100)
end = min(len(response), idx + len(name) + 100)
competitor_contexts.append(response[start:end])
return has_brand, has_competitor, brand_context, competitor_contexts
async def _request_with_retry(self, payload: dict) -> dict:
if self.rate_limiter:
await self.rate_limiter.acquire()
engine_name = self.get_engine_type().value
last_error: Exception | None = None
for attempt in range(_MAX_RETRIES):
try:
response = await self._client.post(self._endpoint, json=payload)
if response.status_code == 200:
return response.json()
if response.status_code in _RETRYABLE_STATUS:
retry_after = response.headers.get("retry-after")
wait = float(retry_after) if retry_after else 2**attempt
logger.warning(
f"[{engine_name}] HTTP {response.status_code}, "
f"retry {attempt + 1}/{_MAX_RETRIES} in {wait:.1f}s"
)
last_error = Exception(
f"HTTP {response.status_code}: {response.text[:300]}"
)
await asyncio.sleep(wait)
continue
raise Exception(
f"HTTP {response.status_code}: {response.text[:300]}"
)
except httpx.TransportError as exc:
logger.warning(
f"[{engine_name}] Transport error: {exc}, "
f"retry {attempt + 1}/{_MAX_RETRIES}"
)
last_error = Exception(f"Network error: {exc}")
await asyncio.sleep(2**attempt)
continue
raise last_error or Exception("Max retries exceeded")
async def close(self) -> None:
if self._client:
await self._client.aclose()
async def __aenter__(self) -> "AIEngineAdapter":
return self
async def __aexit__(self, *exc) -> None:
await self.close()