geo/backend/app/services/ai_engine/gemini.py

155 lines
5.2 KiB
Python

import asyncio
import logging
import os
import time
from datetime import UTC, datetime
import httpx
from .base import AIEngineAdapter, AIQueryResult, EngineType
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Model name used when neither the ``model`` argument nor GEMINI_MODEL is set.
_DEFAULT_MODEL = "gemini-pro"
# API root used when GEMINI_BASE_URL is not set.
_DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
# Total number of request attempts per call (the first try counts as one).
_MAX_RETRIES = 3
# HTTP status codes that lead to another attempt instead of an immediate error.
_RETRYABLE_STATUS = {429, 500, 502, 503}
class GeminiAdapter(AIEngineAdapter):
    """Adapter that sends prompts to the Gemini ``generateContent`` REST endpoint.

    The adapter owns a single ``httpx.AsyncClient`` created in ``__init__`` and
    reuses it for every request.
    """

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        rate_limiter=None,
        proxy: str | None = None,
        key_manager=None,
        user_id: str | None = None,
    ):
        """Configure the model, endpoint URL and HTTP client.

        Args:
            api_key: Forwarded unchanged to the base-class initializer.
            model: Model name; falls back to the ``GEMINI_MODEL`` environment
                variable and then to ``_DEFAULT_MODEL``.
            rate_limiter: Forwarded to the base class. When truthy, its
                ``acquire()`` coroutine is awaited once per request call.
            proxy: Proxy URL for the HTTP client; falls back to the
                ``GEMINI_PROXY`` environment variable.
            key_manager: Forwarded unchanged to the base-class initializer.
            user_id: Forwarded unchanged to the base-class initializer.
        """
        super().__init__(
            api_key=api_key,
            rate_limiter=rate_limiter,
            proxy=proxy,
            key_manager=key_manager,
            user_id=user_id,
        )
        self._model = model or os.getenv("GEMINI_MODEL", _DEFAULT_MODEL)
        self._base_url = os.getenv("GEMINI_BASE_URL", _DEFAULT_BASE_URL).rstrip("/")
        self._proxy = proxy or os.getenv("GEMINI_PROXY")
        # The endpoint is built once here, so it captures whatever
        # ``self.api_key`` holds right after the base initializer ran
        # (presumably resolved there -- confirm against the base class).
        # NOTE(review): the key travels in the URL query string and can
        # therefore show up anywhere request URLs are logged; consider the
        # ``x-goog-api-key`` request header instead.
        self._endpoint = (
            f"{self._base_url}/models/{self._model}:generateContent"
            f"?key={self.api_key}"
        )
        client_kwargs = {
            "timeout": httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0),
            "headers": {"Content-Type": "application/json"},
        }
        if self._proxy:
            client_kwargs["proxy"] = self._proxy
        # NOTE(review): nothing in this class closes the client; confirm the
        # base class or the owner of this adapter does.
        self._client = httpx.AsyncClient(**client_kwargs)

    def get_engine_type(self) -> EngineType:
        """Return ``EngineType.GEMINI``."""
        return EngineType.GEMINI

    def _get_env_key(self) -> str | None:
        """Return the ``GOOGLE_API_KEY`` environment variable, or ``""`` if unset."""
        return os.getenv("GOOGLE_API_KEY", "")

    def _load_proxy(self) -> str | None:
        """Return the first non-empty proxy environment variable, else ``None``.

        Checked in order: ``GOOGLE_PROXY``, ``HTTPS_PROXY``, ``https_proxy``.

        NOTE(review): ``__init__`` reads ``GEMINI_PROXY`` directly and does not
        call this method; presumably the base class does -- confirm.
        """
        return os.getenv("GOOGLE_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")

    @staticmethod
    def _retry_delay(retry_after: str | None, fallback: float) -> float:
        """Turn a ``Retry-After`` header value into a sleep duration in seconds.

        Args:
            retry_after: Raw header value, or ``None`` when the header is absent.
            fallback: Delay to use when the header is missing or unusable.

        Returns:
            The header value as seconds when it is a finite, non-negative
            number; otherwise ``fallback``. Non-numeric forms (for example an
            HTTP-date) fall back rather than raising.
        """
        if not retry_after:
            return fallback
        try:
            seconds = float(retry_after)
        except ValueError:
            return fallback
        # The chained comparison is False for NaN, infinity and negatives.
        if 0 <= seconds < float("inf"):
            return seconds
        return fallback

    async def _request_with_retry(self, payload: dict) -> dict:
        """POST ``payload`` to the endpoint, retrying transient failures.

        Up to ``_MAX_RETRIES`` attempts are made. Statuses listed in
        ``_RETRYABLE_STATUS`` and ``httpx.TransportError`` are retried after a
        delay (``Retry-After`` when usable, otherwise ``2**attempt`` seconds).
        No delay is inserted after the final attempt.

        Args:
            payload: JSON-serializable request body.

        Returns:
            The decoded JSON body of the first HTTP 200 response.

        Raises:
            RuntimeError: On a non-retryable status, or when every attempt
                failed (the error describes the last failure).
        """
        # The limiter is acquired once per call, not once per attempt.
        if self.rate_limiter:
            await self.rate_limiter.acquire()

        last_error: Exception | None = None
        for attempt in range(_MAX_RETRIES):
            backoff = float(2**attempt)
            try:
                response = await self._client.post(self._endpoint, json=payload)
            except httpx.TransportError as exc:
                logger.warning(
                    "[gemini] Transport error: %s, retry %d/%d",
                    exc,
                    attempt + 1,
                    _MAX_RETRIES,
                )
                last_error = RuntimeError(f"Gemini 网络错误: {exc}")
                wait = backoff
            else:
                if response.status_code == 200:
                    return response.json()
                if response.status_code not in _RETRYABLE_STATUS:
                    raise RuntimeError(
                        f"Gemini API 返回错误 {response.status_code}: {response.text[:300]}"
                    )
                wait = self._retry_delay(response.headers.get("retry-after"), backoff)
                logger.warning(
                    "[gemini] HTTP %s, retry %d/%d in %.1fs",
                    response.status_code,
                    attempt + 1,
                    _MAX_RETRIES,
                    wait,
                )
                # RuntimeError keeps this path consistent with the other
                # failure paths raised by this method.
                last_error = RuntimeError(
                    f"HTTP {response.status_code}: {response.text[:300]}"
                )
            # Sleeping after the last attempt would only delay the error.
            if attempt + 1 < _MAX_RETRIES:
                await asyncio.sleep(wait)

        raise last_error or RuntimeError("Gemini API 超过最大重试次数")

    async def query(
        self,
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> AIQueryResult:
        """Send ``query`` to Gemini and wrap the answer in an ``AIQueryResult``.

        Args:
            query: Prompt text, sent as the only part of the only content item.
            brand_name: Passed to ``self._detect_citations`` together with the
                response text.
            competitor_names: Passed to ``self._detect_citations`` unchanged.

        Returns:
            An ``AIQueryResult`` whose ``raw_response`` is the text of the
            first part of the first candidate. ``citations`` is always an
            empty list; token counts come from ``usageMetadata`` and default
            to 0 when absent.

        Raises:
            RuntimeError: When the response has no candidates or no parts, or
                when the request itself fails.
        """
        start_time = time.perf_counter()
        payload = {
            "contents": [{"parts": [{"text": query}]}],
            "generationConfig": {"temperature": 0.7},
        }
        data = await self._request_with_retry(payload)

        candidates = data.get("candidates", [])
        if not candidates:
            raise RuntimeError("Gemini API 返回空候选内容")
        parts = candidates[0].get("content", {}).get("parts", [])
        if not parts:
            raise RuntimeError("Gemini API 返回空内容")
        content = parts[0].get("text", "")

        # Elapsed time covers the request and response unpacking only; the
        # citation detection below is deliberately left outside the measurement.
        elapsed_ms = int((time.perf_counter() - start_time) * 1000)

        has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
            content, brand_name, competitor_names
        )

        usage_metadata = data.get("usageMetadata", {})
        input_tokens = usage_metadata.get("promptTokenCount", 0)
        output_tokens = usage_metadata.get("candidatesTokenCount", 0)

        logger.info(
            "[gemini] query='%s...' brand=%s competitor=%s time=%sms",
            query[:50],
            has_brand,
            has_comp,
            elapsed_ms,
        )
        return AIQueryResult(
            engine_type=self.get_engine_type(),
            query=query,
            raw_response=content,
            citations=[],
            has_brand_citation=has_brand,
            has_competitor_citation=has_comp,
            brand_context=brand_ctx,
            competitor_contexts=comp_ctx,
            response_time_ms=elapsed_ms,
            timestamp=datetime.now(UTC),
            metadata={"model": self._model, "usage": usage_metadata},
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        )