geo/backend/app/services/ai_engine/base.py

206 lines
6.0 KiB
Python

import asyncio
import logging
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
from typing import Any
import httpx
from app.services.api_key_manager import APIKeyManager
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
# Number of POST attempts AIEngineAdapter._request_with_retry makes before
# giving up (it loops over range(_MAX_RETRIES)).
_MAX_RETRIES = 3
# HTTP status codes that _request_with_retry retries; any other non-200
# status is raised immediately.
_RETRYABLE_STATUS = {429, 500, 502, 503}
class EngineType(str, Enum):
    """Identifiers of the AI engines an adapter can target.

    Members also inherit from ``str``, so a member compares equal to its raw
    string value (``EngineType.KIMI == "kimi"``) and ``EngineType("kimi")``
    looks a member up by that value.
    """

    CHATGPT = "chatgpt"
    PERPLEXITY = "perplexity"
    KIMI = "kimi"
    WENXIN = "wenxin"
    DOUBAO = "doubao"
    DEEPSEEK = "deepseek"
    QWEN = "qwen"
    GEMINI = "gemini"
    YUANBAO = "yuanbao"
@dataclass
class CitationInfo:
source_url: str | None
source_title: str | None
citation_context: str
confidence: float
position: int
@dataclass
class AIQueryResult:
    """Outcome of a single query sent to one AI engine.

    Plain value container; nothing is validated on construction.
    """

    # Engine that produced this result.
    engine_type: EngineType
    # The query text that was sent.
    query: str
    # Response text as returned by the engine.
    raw_response: str
    # Citations extracted from the response.
    citations: list[CitationInfo]
    # Whether the brand / any competitor was found in the response.
    has_brand_citation: bool
    has_competitor_citation: bool
    # Text surrounding the brand mention, or None when the brand is absent.
    brand_context: str | None
    # Text surrounding each competitor mention that was found.
    competitor_contexts: list[str]
    # Elapsed time of the request in milliseconds (per the field name).
    response_time_ms: int
    # When the result was produced.
    timestamp: datetime
    # Free-form extra data; each instance gets its own dict.
    metadata: dict[str, Any] = field(default_factory=dict)
    # Token usage; both default to 0 when the caller does not supply them.
    input_tokens: int = 0
    output_tokens: int = 0

    @property
    def total_tokens(self) -> int:
        """Return the sum of ``input_tokens`` and ``output_tokens``."""
        return self.input_tokens + self.output_tokens
class AIEngineAdapter(ABC):
def __init__(
self,
api_key: str | None = None,
rate_limiter=None,
proxy: str | None = None,
key_manager: APIKeyManager | None = None,
user_id: str | None = None,
):
self._key_manager = key_manager
self._user_id = user_id
self.api_key = self._resolve_api_key(api_key, key_manager, user_id)
self.rate_limiter = rate_limiter
self.proxy = proxy or self._load_proxy()
self._client: httpx.AsyncClient | None = None
def _load_proxy(self) -> str | None:
return os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")
def _resolve_api_key(
self,
direct_key: str | None,
key_manager: APIKeyManager | None,
user_id: str | None,
) -> str:
if direct_key and direct_key.strip():
return direct_key
if key_manager:
key = key_manager.get_key(self.get_engine_type().value, user_id=user_id)
if key:
return key
return self._get_env_key() or ""
@abstractmethod
def _get_env_key(self) -> str | None:
pass
@abstractmethod
async def query(
self,
query: str,
brand_name: str,
competitor_names: list[str] | None = None,
) -> AIQueryResult:
pass
@abstractmethod
def get_engine_type(self) -> EngineType:
pass
async def _get_client(self) -> httpx.AsyncClient:
if self._client is None or self._client.is_closed:
self._client = httpx.AsyncClient(**self._client_kwargs())
return self._client
def _client_kwargs(self) -> dict[str, Any]:
kwargs: dict[str, Any] = {
"timeout": httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0),
}
if self.proxy:
kwargs["proxy"] = self.proxy
return kwargs
def _detect_citations(
self,
response: str,
brand_name: str,
competitor_names: list[str] | None,
) -> tuple[bool, bool, str | None, list[str]]:
has_brand = brand_name.lower() in response.lower()
brand_context = None
if has_brand:
idx = response.lower().find(brand_name.lower())
start = max(0, idx - 100)
end = min(len(response), idx + len(brand_name) + 100)
brand_context = response[start:end]
has_competitor = False
competitor_contexts = []
if competitor_names:
for name in competitor_names:
if name.lower() in response.lower():
has_competitor = True
idx = response.lower().find(name.lower())
start = max(0, idx - 100)
end = min(len(response), idx + len(name) + 100)
competitor_contexts.append(response[start:end])
return has_brand, has_competitor, brand_context, competitor_contexts
async def _request_with_retry(self, payload: dict) -> dict:
if self.rate_limiter:
await self.rate_limiter.acquire()
client = await self._get_client()
engine_name = self.get_engine_type().value
last_error: Exception | None = None
for attempt in range(_MAX_RETRIES):
try:
response = await client.post(self._endpoint, json=payload)
if response.status_code == 200:
return response.json()
if response.status_code in _RETRYABLE_STATUS:
retry_after = response.headers.get("retry-after")
wait = float(retry_after) if retry_after else 2**attempt
logger.warning(
f"[{engine_name}] HTTP {response.status_code}, "
f"retry {attempt + 1}/{_MAX_RETRIES} in {wait:.1f}s"
)
last_error = Exception(
f"HTTP {response.status_code}: {response.text[:300]}"
)
await asyncio.sleep(wait)
continue
raise Exception(
f"HTTP {response.status_code}: {response.text[:300]}"
)
except httpx.TransportError as exc:
logger.warning(
f"[{engine_name}] Transport error: {exc}, "
f"retry {attempt + 1}/{_MAX_RETRIES}"
)
last_error = Exception(f"Network error: {exc}")
await asyncio.sleep(2**attempt)
continue
raise last_error or Exception("Max retries exceeded")
async def close(self) -> None:
if self._client:
await self._client.aclose()
async def __aenter__(self) -> "AIEngineAdapter":
return self
async def __aexit__(self, *exc) -> None:
await self.close()