"""Gemini (Google Generative Language API) adapter for AI engine queries."""

import asyncio
import logging
import os
import time
from datetime import UTC, datetime

import httpx

from .base import AIEngineAdapter, AIQueryResult, EngineType

logger = logging.getLogger(__name__)

_DEFAULT_MODEL = "gemini-pro"
_DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
_MAX_RETRIES = 3
# HTTP statuses treated as transient; any other non-200 status fails immediately.
_RETRYABLE_STATUS = {429, 500, 502, 503}


def _retry_delay(retry_after: str | None, attempt: int) -> float:
    """Return the number of seconds to wait before the next attempt.

    A numeric ``Retry-After`` header value is honored (clamped to >= 0).
    ``Retry-After`` may also be an HTTP-date, which ``float()`` cannot parse;
    in that case, and when the header is absent, exponential backoff
    (1s, 2s, 4s, ... for attempt 0, 1, 2, ...) is used instead of crashing.
    """
    backoff = float(2**attempt)
    if not retry_after:
        return backoff
    try:
        return max(0.0, float(retry_after))
    except ValueError:
        return backoff


class GeminiAdapter(AIEngineAdapter):
    """Adapter that queries the Gemini ``generateContent`` REST endpoint.

    Settings are taken from constructor arguments first, then from the
    environment: ``GEMINI_MODEL``, ``GEMINI_BASE_URL`` and ``GEMINI_PROXY``.
    The API key is sent in the ``x-goog-api-key`` header rather than in the
    URL query string, so it does not appear in request URLs that httpx logs
    or that end up in error reports.
    """

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        rate_limiter=None,
        proxy: str | None = None,
        key_manager=None,
        user_id: str | None = None,
    ):
        super().__init__(
            api_key=api_key,
            rate_limiter=rate_limiter,
            proxy=proxy,
            key_manager=key_manager,
            user_id=user_id,
        )
        self._model = model or os.getenv("GEMINI_MODEL", _DEFAULT_MODEL)
        self._base_url = os.getenv("GEMINI_BASE_URL", _DEFAULT_BASE_URL).rstrip("/")
        self._proxy = proxy or os.getenv("GEMINI_PROXY")
        self._endpoint = f"{self._base_url}/models/{self._model}:generateContent"

        headers = {"Content-Type": "application/json"}
        # self.api_key is presumably resolved by the base class (argument,
        # key_manager or _get_env_key) -- it is only read here.
        if self.api_key:
            headers["x-goog-api-key"] = self.api_key

        client_kwargs = {
            # Long read timeout: generation can take a while; the other phases should be quick.
            "timeout": httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0),
            "headers": headers,
        }
        if self._proxy:
            client_kwargs["proxy"] = self._proxy
        self._client = httpx.AsyncClient(**client_kwargs)

    def get_engine_type(self) -> EngineType:
        """Return the engine identifier for this adapter."""
        return EngineType.GEMINI

    def _get_env_key(self) -> str | None:
        """Return the API key from ``GOOGLE_API_KEY``, else ``GEMINI_API_KEY``.

        Returns ``""`` when neither variable is set.
        """
        return os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY", "")

    def _load_proxy(self) -> str | None:
        """Return a proxy URL from ``GOOGLE_PROXY``/``HTTPS_PROXY``/``https_proxy``.

        NOTE(review): not called in this file (``__init__`` reads
        ``GEMINI_PROXY`` instead); presumably consulted by the base class --
        verify.
        """
        return os.getenv("GOOGLE_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")

    async def aclose(self) -> None:
        """Close the underlying HTTP client and release its pooled connections."""
        await self._client.aclose()

    async def _request_with_retry(self, payload: dict) -> dict:
        """POST ``payload`` to the generateContent endpoint, retrying transient failures.

        The rate limiter (if any) is acquired once per logical request, not per
        retry. Up to ``_MAX_RETRIES`` attempts are made for transport errors and
        ``_RETRYABLE_STATUS`` responses; there is no sleep after the final attempt.

        Returns:
            The decoded JSON body of the first HTTP 200 response.

        Raises:
            RuntimeError: on a non-retryable status, an undecodable 200 body,
                or when every attempt failed.
        """
        if self.rate_limiter:
            await self.rate_limiter.acquire()

        last_error: Exception | None = None
        for attempt in range(_MAX_RETRIES):
            is_last_attempt = attempt == _MAX_RETRIES - 1

            try:
                response = await self._client.post(self._endpoint, json=payload)
            except httpx.TransportError as exc:
                # Connection failures and timeouts (TimeoutException is a TransportError).
                logger.warning(
                    "[gemini] Transport error: %s, attempt %d/%d",
                    exc,
                    attempt + 1,
                    _MAX_RETRIES,
                )
                last_error = RuntimeError(f"Gemini 网络错误: {exc}")
                last_error.__cause__ = exc
                if not is_last_attempt:
                    await asyncio.sleep(2**attempt)
                continue

            status = response.status_code
            if status == 200:
                try:
                    return response.json()
                except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
                    raise RuntimeError(
                        f"Gemini API 返回无效 JSON: {response.text[:300]}"
                    ) from exc

            if status not in _RETRYABLE_STATUS:
                raise RuntimeError(
                    f"Gemini API 返回错误 {status}: {response.text[:300]}"
                )

            last_error = RuntimeError(f"HTTP {status}: {response.text[:300]}")
            if is_last_attempt:
                logger.warning(
                    "[gemini] HTTP %d on final attempt %d/%d",
                    status,
                    attempt + 1,
                    _MAX_RETRIES,
                )
                break

            wait = _retry_delay(response.headers.get("retry-after"), attempt)
            logger.warning(
                "[gemini] HTTP %d, retry %d/%d in %.1fs",
                status,
                attempt + 1,
                _MAX_RETRIES,
                wait,
            )
            await asyncio.sleep(wait)

        raise last_error or RuntimeError("Gemini API 超过最大重试次数")

    @staticmethod
    def _extract_text(data: dict) -> str:
        """Return the answer text of the first candidate in a generateContent response.

        All parts of the candidate are concatenated; parts flagged
        ``"thought": true`` (the Gemini ``Part.thought`` field) are skipped.

        Raises:
            RuntimeError: when there is no candidate, or the candidate has no
                parts. The block/finish reason is appended when the API gave one.
        """
        candidates = data.get("candidates") or []
        if not candidates:
            block_reason = (data.get("promptFeedback") or {}).get("blockReason")
            detail = f" (blockReason={block_reason})" if block_reason else ""
            raise RuntimeError(f"Gemini API 返回空候选内容{detail}")

        candidate = candidates[0]
        parts = (candidate.get("content") or {}).get("parts") or []
        if not parts:
            finish_reason = candidate.get("finishReason")
            detail = f" (finishReason={finish_reason})" if finish_reason else ""
            raise RuntimeError(f"Gemini API 返回空内容{detail}")

        return "".join(
            part.get("text") or "" for part in parts if not part.get("thought")
        )

    async def query(
        self,
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> AIQueryResult:
        """Send ``query`` to Gemini and check the reply for brand/competitor mentions.

        Args:
            query: Prompt text, sent as a single content turn.
            brand_name: Brand name passed to ``_detect_citations``.
            competitor_names: Optional competitor names passed to ``_detect_citations``.

        Returns:
            An ``AIQueryResult`` with the first candidate's text, the citation
            results from ``_detect_citations``, wall-clock time (including rate
            limiting and retries) and token counts from ``usageMetadata``.

        Raises:
            RuntimeError: if the request fails or the response has no usable content.
        """
        start_time = time.perf_counter()

        payload = {
            "contents": [{"parts": [{"text": query}]}],
            "generationConfig": {"temperature": 0.7},
        }
        data = await self._request_with_retry(payload)
        content = self._extract_text(data)

        elapsed_ms = int((time.perf_counter() - start_time) * 1000)

        has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
            content, brand_name, competitor_names
        )

        usage_metadata = data.get("usageMetadata") or {}
        input_tokens = usage_metadata.get("promptTokenCount", 0)
        output_tokens = usage_metadata.get("candidatesTokenCount", 0)

        logger.info(
            "[gemini] query='%s...' brand=%s competitor=%s time=%dms",
            query[:50],
            has_brand,
            has_comp,
            elapsed_ms,
        )

        return AIQueryResult(
            engine_type=self.get_engine_type(),
            query=query,
            raw_response=content,
            citations=[],
            has_brand_citation=has_brand,
            has_competitor_citation=has_comp,
            brand_context=brand_ctx,
            competitor_contexts=comp_ctx,
            response_time_ms=elapsed_ms,
            timestamp=datetime.now(UTC),
            metadata={"model": self._model, "usage": usage_metadata},
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        )