"""Shared data types and the abstract base adapter for querying AI answer engines."""
import asyncio
import logging
import math
import os
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import UTC, datetime
from email.utils import parsedate_to_datetime
from enum import Enum
from typing import Any

import httpx

from app.services.api_key_manager import APIKeyManager


logger = logging.getLogger(name=__name__)

# Total number of HTTP attempts AIEngineAdapter._request_with_retry makes
# before giving up (the first attempt counts toward this limit).
_MAX_RETRIES = 3

# Status codes treated as transient and retried: 429 plus the 5xx codes
# 500/502/503. Any other non-200 status fails immediately.
_RETRYABLE_STATUS = {503, 502, 500, 429}


class EngineType(str, Enum):
    """Identifiers of the AI answer engines an adapter can target.

    The ``str`` mixin makes every member compare equal to its plain string
    value (``EngineType.KIMI == "kimi"``). Members can therefore be used
    wherever the raw engine id string is expected, and ``EngineType("kimi")``
    looks a member up by that id.
    """

    CHATGPT = "chatgpt"
    PERPLEXITY = "perplexity"
    KIMI = "kimi"
    WENXIN = "wenxin"
    DOUBAO = "doubao"
    DEEPSEEK = "deepseek"
    QWEN = "qwen"
    GEMINI = "gemini"
    YUANBAO = "yuanbao"


@dataclass
|
|
class CitationInfo:
|
|
source_url: str | None
|
|
source_title: str | None
|
|
citation_context: str
|
|
confidence: float
|
|
position: int
|
|
|
|
|
|
@dataclass
class AIQueryResult:
    """Everything recorded about one query sent to one AI engine.

    Holds the raw answer text, the extracted citations, the brand and
    competitor mention flags with their surrounding snippets, timing, and
    token usage.
    """

    engine_type: EngineType  # engine that produced the answer
    query: str  # prompt text that was sent
    raw_response: str  # answer text as returned
    citations: list[CitationInfo]
    has_brand_citation: bool  # brand found in the answer
    has_competitor_citation: bool  # at least one competitor found
    brand_context: str | None  # snippet around the brand mention, if any
    competitor_contexts: list[str]  # snippets around competitor mentions
    response_time_ms: int  # elapsed time, in milliseconds per the name
    timestamp: datetime  # when the query ran; tz-awareness not enforced here
    # Free-form extras; default_factory gives every instance its own dict.
    metadata: dict[str, Any] = field(default_factory=dict)
    input_tokens: int = 0  # prompt-side token count, 0 if not reported
    output_tokens: int = 0  # completion-side token count, 0 if not reported

    @property
    def total_tokens(self) -> int:
        """Combined token usage, i.e. ``input_tokens + output_tokens``."""
        return sum((self.input_tokens, self.output_tokens))


class AIEngineAdapter(ABC):
    """Abstract base for adapters that query one AI engine over HTTP.

    Subclasses must supply the following:

    - the engine identity (``get_engine_type``);
    - the environment-variable fallback for the API key (``_get_env_key``);
    - the query implementation (``query``);
    - the ``_endpoint`` URL that ``_request_with_retry`` POSTs to.

    The base class handles API-key resolution, proxy configuration, a
    lazily created ``httpx.AsyncClient``, retrying requests, and
    case-insensitive brand/competitor mention detection. Instances are async
    context managers; leaving ``async with`` closes the HTTP client.
    """

    # URL that _request_with_retry POSTs to. The base class never assigns
    # it, so every concrete subclass must provide it.
    _endpoint: str

    # Characters of surrounding text kept on each side of a detected mention.
    _CONTEXT_WINDOW = 100

    def __init__(
        self,
        api_key: str | None = None,
        rate_limiter=None,
        proxy: str | None = None,
        key_manager: APIKeyManager | None = None,
        user_id: str | None = None,
    ):
        """Configure credentials, rate limiting and networking.

        Args:
            api_key: Explicit API key. It takes precedence over every other
                source when it contains non-whitespace characters.
            rate_limiter: Optional object whose awaitable ``acquire()`` is
                called before every HTTP attempt.
            proxy: Proxy URL for the HTTP client. Defaults to the
                ``HTTPS_PROXY`` / ``https_proxy`` environment variable.
            key_manager: Optional key store consulted when no explicit key
                is given.
            user_id: Forwarded to ``key_manager.get_key``.
        """
        self._key_manager = key_manager
        self._user_id = user_id
        self.api_key = self._resolve_api_key(api_key, key_manager, user_id)
        self.rate_limiter = rate_limiter
        self.proxy = proxy or self._load_proxy()
        # Created on first use by _get_client().
        self._client: httpx.AsyncClient | None = None

    def _load_proxy(self) -> str | None:
        """Return the proxy URL from ``HTTPS_PROXY``/``https_proxy``, if set."""
        return os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")

    def _resolve_api_key(
        self,
        direct_key: str | None,
        key_manager: APIKeyManager | None,
        user_id: str | None,
    ) -> str:
        """Pick the API key to use, in priority order.

        1. ``direct_key``, if it contains non-whitespace characters. It is
           returned with surrounding whitespace stripped.
        2. ``key_manager.get_key(<engine id>, user_id=user_id)``, if that
           returns a truthy key.
        3. ``_get_env_key()``, or ``""`` if that yields nothing.

        Note: this runs from ``__init__``, and step 2 calls
        ``get_engine_type()``. That method must therefore not depend on
        state a subclass sets after ``super().__init__()``.
        """
        if direct_key and direct_key.strip():
            # Whitespace-only keys are already rejected above. Stripping here
            # also stops a key pasted with a trailing newline/space from being
            # used verbatim.
            return direct_key.strip()

        if key_manager:
            key = key_manager.get_key(self.get_engine_type().value, user_id=user_id)
            if key:
                return key

        return self._get_env_key() or ""

    @abstractmethod
    def _get_env_key(self) -> str | None:
        """Return the API key from this engine's environment variable, if any."""

    @abstractmethod
    async def query(
        self,
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> AIQueryResult:
        """Send ``query`` to the engine and analyse the answer for mentions."""

    @abstractmethod
    def get_engine_type(self) -> EngineType:
        """Return which engine this adapter talks to."""

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the shared HTTP client, creating it if missing or closed."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(**self._client_kwargs())
        return self._client

    def _client_kwargs(self) -> dict[str, Any]:
        """Build ``httpx.AsyncClient`` kwargs: timeouts plus optional proxy."""
        kwargs: dict[str, Any] = {
            # Generous read timeout: answers can take a long time to generate.
            "timeout": httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0),
        }
        if self.proxy:
            kwargs["proxy"] = self.proxy
        return kwargs

    @classmethod
    def _mention_context(cls, text: str, name: str | None) -> str | None:
        """Return the snippet around the first case-insensitive match of ``name``.

        The snippet spans up to ``_CONTEXT_WINDOW`` characters on each side.
        Returns None when ``name`` is empty/blank or does not occur in ``text``.
        """
        if not name or not name.strip():
            # An empty pattern "matches" everywhere; never treat it as a mention.
            return None
        # Searching the original text case-insensitively keeps match offsets
        # aligned with `text`. Slicing `text` with an index found in
        # `text.lower()` can drift when lowercasing changes the length
        # (e.g. "İ").
        match = re.search(re.escape(name), text, flags=re.IGNORECASE)
        if match is None:
            return None
        start = max(0, match.start() - cls._CONTEXT_WINDOW)
        end = min(len(text), match.end() + cls._CONTEXT_WINDOW)
        return text[start:end]

    def _detect_citations(
        self,
        response: str,
        brand_name: str,
        competitor_names: list[str] | None,
    ) -> tuple[bool, bool, str | None, list[str]]:
        """Find case-insensitive mentions of the brand and competitors.

        Returns:
            A tuple ``(has_brand, has_competitor, brand_context,
            competitor_contexts)``.

            - Each context is the text around the *first* mention of a name.
            - ``competitor_contexts`` follows the order of
              ``competitor_names``; names that are not found are skipped.
            - Empty or whitespace-only names never count as mentioned.
        """
        brand_context = self._mention_context(response, brand_name)
        competitor_contexts = [
            context
            for name in competitor_names or []
            if (context := self._mention_context(response, name)) is not None
        ]
        return (
            brand_context is not None,
            bool(competitor_contexts),
            brand_context,
            competitor_contexts,
        )

    @staticmethod
    def _retry_delay(retry_after: str | None, attempt: int) -> float:
        """Return the number of seconds to wait before the next attempt.

        A ``Retry-After`` value is honored in either form allowed by HTTP:
        delay-seconds or an HTTP-date. The delay falls back to
        ``2**attempt`` when the header is absent, unparseable, or not a
        finite number. Negative values and dates in the past clamp to 0.
        """
        backoff = float(2**attempt)
        if not retry_after:
            return backoff
        try:
            seconds = float(retry_after)
        except ValueError:
            # Not a number: try the HTTP-date form, e.g.
            # "Wed, 21 Oct 2015 07:28:00 GMT".
            try:
                when = parsedate_to_datetime(retry_after)
            except (TypeError, ValueError, IndexError):
                return backoff
            if when.tzinfo is None:
                when = when.replace(tzinfo=UTC)
            seconds = (when - datetime.now(UTC)).total_seconds()
        if not math.isfinite(seconds):
            # "nan"/"inf" parse as floats but are not usable sleep durations.
            return backoff
        return max(0.0, seconds)

    async def _request_with_retry(self, payload: dict) -> dict:
        """POST ``payload`` as JSON to ``self._endpoint`` and return the JSON body.

        At most ``_MAX_RETRIES`` attempts are made.

        - Transport errors are retried after an exponential backoff of
          1s, 2s, 4s, ...
        - Statuses in ``_RETRYABLE_STATUS`` are retried after the server's
          ``Retry-After`` when usable, otherwise after the same backoff.
        - No delay is spent after the final attempt.
        - The rate limiter, if configured, is acquired before every attempt,
          because each retry is a new request to the provider.

        Raises:
            Exception: on a non-retryable HTTP status (immediately), or once
                all attempts are exhausted. In the latter case the last
                transport error, if any, is chained as the cause.
        """
        client = await self._get_client()
        engine_name = self.get_engine_type().value
        last_error: Exception | None = None
        last_cause: BaseException | None = None

        for attempt in range(_MAX_RETRIES):
            final_attempt = attempt == _MAX_RETRIES - 1
            if self.rate_limiter:
                await self.rate_limiter.acquire()

            try:
                response = await client.post(self._endpoint, json=payload)
            except httpx.TransportError as exc:
                last_error = Exception(f"Network error: {exc}")
                last_cause = exc
                if final_attempt:
                    logger.warning(
                        "[%s] Transport error: %s, giving up after %d attempts",
                        engine_name, exc, _MAX_RETRIES,
                    )
                    break
                logger.warning(
                    "[%s] Transport error: %s, retry %d/%d",
                    engine_name, exc, attempt + 1, _MAX_RETRIES,
                )
                await asyncio.sleep(2**attempt)
                continue

            if response.status_code == 200:
                return response.json()

            error = Exception(f"HTTP {response.status_code}: {response.text[:300]}")
            if response.status_code not in _RETRYABLE_STATUS:
                raise error

            last_error, last_cause = error, None
            if final_attempt:
                logger.warning(
                    "[%s] HTTP %d, giving up after %d attempts",
                    engine_name, response.status_code, _MAX_RETRIES,
                )
                break
            wait = self._retry_delay(response.headers.get("retry-after"), attempt)
            logger.warning(
                "[%s] HTTP %d, retry %d/%d in %.1fs",
                engine_name, response.status_code, attempt + 1, _MAX_RETRIES, wait,
            )
            await asyncio.sleep(wait)

        if last_error is None:
            # Only reachable when _MAX_RETRIES <= 0.
            raise Exception("Max retries exceeded")
        raise last_error from last_cause

    async def close(self) -> None:
        """Close the HTTP client if one was created; safe to call repeatedly."""
        client, self._client = self._client, None
        if client is not None:
            await client.aclose()

    async def __aenter__(self) -> "AIEngineAdapter":
        """Enter ``async with``; the client itself is created lazily on first request."""
        return self

    async def __aexit__(self, *exc) -> None:
        """Close the HTTP client when leaving ``async with``; exceptions propagate."""
        await self.close()