155 lines
5.2 KiB
Python
155 lines
5.2 KiB
Python
import asyncio
|
|
import logging
|
|
import os
|
|
import time
|
|
from datetime import UTC, datetime
|
|
|
|
import httpx
|
|
|
|
from .base import AIEngineAdapter, AIQueryResult, EngineType
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
_DEFAULT_MODEL = "gemini-pro"
|
|
_DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
|
|
_MAX_RETRIES = 3
|
|
_RETRYABLE_STATUS = {429, 500, 502, 503}
|
|
|
|
|
|
class GeminiAdapter(AIEngineAdapter):
    """Adapter that queries the Google Gemini ``generateContent`` REST API.

    Each :meth:`query` call sends a single-turn prompt. It extracts the text
    of the first returned candidate and runs brand/competitor citation
    detection (``_detect_citations``, inherited). The outcome is wrapped in
    an ``AIQueryResult``.

    Configuration (explicit argument first, then environment variable):

    * model    -- ``model`` / ``GEMINI_MODEL`` (falls back to ``_DEFAULT_MODEL``)
    * base URL -- ``GEMINI_BASE_URL`` (falls back to ``_DEFAULT_BASE_URL``)
    * proxy    -- ``proxy`` / ``GEMINI_PROXY``

    The adapter owns an ``httpx.AsyncClient``. Call :meth:`aclose` when the
    adapter is no longer needed so pooled connections are released.
    """

    # Upper bound (seconds) on a server-supplied Retry-After delay, so a bogus
    # or very large value cannot stall the caller indefinitely.
    _MAX_RETRY_WAIT_S = 60.0

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        rate_limiter=None,
        proxy: str | None = None,
        key_manager=None,
        user_id: str | None = None,
    ):
        """Create the adapter and its HTTP client.

        Args:
            api_key: Gemini API key; passed through to the base adapter.
            model: Model name; defaults to GEMINI_MODEL or ``_DEFAULT_MODEL``.
            rate_limiter: Optional object whose ``acquire()`` coroutine is
                awaited once before each logical request.
            proxy: Proxy URL; defaults to GEMINI_PROXY when not given.
            key_manager: Passed through to the base adapter.
            user_id: Passed through to the base adapter.
        """
        super().__init__(
            api_key=api_key,
            rate_limiter=rate_limiter,
            proxy=proxy,
            key_manager=key_manager,
            user_id=user_id,
        )

        self._model = model or os.getenv("GEMINI_MODEL", _DEFAULT_MODEL)
        self._base_url = os.getenv("GEMINI_BASE_URL", _DEFAULT_BASE_URL).rstrip("/")
        self._proxy = proxy or os.getenv("GEMINI_PROXY")

        # The API key is deliberately NOT embedded as a ``?key=`` query
        # parameter. URLs can end up in HTTP client logs and exception
        # messages, so the key travels in the ``x-goog-api-key`` header
        # instead (see _request_with_retry).
        self._endpoint = f"{self._base_url}/models/{self._model}:generateContent"

        client_kwargs = {
            "timeout": httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0),
            "headers": {"Content-Type": "application/json"},
        }
        if self._proxy:
            client_kwargs["proxy"] = self._proxy

        self._client = httpx.AsyncClient(**client_kwargs)

    def get_engine_type(self) -> EngineType:
        """Return the engine identifier for this adapter (``EngineType.GEMINI``)."""
        return EngineType.GEMINI

    def _get_env_key(self) -> str | None:
        """Return the API key from the environment.

        Reads GOOGLE_API_KEY first, then GEMINI_API_KEY. Returns ``""`` when
        neither is set.
        """
        return os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY", "")

    def _load_proxy(self) -> str | None:
        """Return a proxy URL from GOOGLE_PROXY, HTTPS_PROXY or https_proxy.

        NOTE(review): ``__init__`` above resolves its proxy from the ``proxy``
        argument / GEMINI_PROXY and does not call this method. Presumably the
        base class uses it; verify.
        """
        return os.getenv("GOOGLE_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")

    async def aclose(self) -> None:
        """Close the underlying HTTP client and release its pooled connections."""
        await self._client.aclose()

    @classmethod
    def _retry_delay(cls, retry_after: str | None, attempt: int) -> float:
        """Return how long to wait (seconds) before the next retry.

        Uses the numeric ``Retry-After`` header value when present and sane,
        capped at ``_MAX_RETRY_WAIT_S``. Otherwise it falls back to
        exponential backoff (``2 ** attempt``). The fallback covers a missing
        header, an HTTP-date value, a negative number and NaN.
        """
        backoff = float(2**attempt)
        if not retry_after:
            return backoff
        try:
            delay = float(retry_after)
        except ValueError:
            # Retry-After may legally be an HTTP-date; don't crash on it.
            return backoff
        if not delay >= 0:  # rejects negatives and NaN
            return backoff
        return min(delay, cls._MAX_RETRY_WAIT_S)

    async def _request_with_retry(self, payload: dict) -> dict:
        """POST *payload* to the generateContent endpoint, retrying transient failures.

        Makes up to ``_MAX_RETRIES`` attempts. Transport errors and statuses
        in ``_RETRYABLE_STATUS`` are retried after a backoff delay; no delay
        is inserted after the final attempt.

        Returns:
            The decoded JSON body of the first HTTP 200 response.

        Raises:
            RuntimeError: On a non-retryable HTTP status, or when every
                attempt failed with a retryable status or transport error.
        """
        # One rate-limiter slot covers the whole logical request, retries included.
        if self.rate_limiter:
            await self.rate_limiter.acquire()

        # Read the key per request so a key swapped on the instance is picked up.
        headers = {"x-goog-api-key": self.api_key or ""}
        last_error: Exception | None = None

        for attempt in range(_MAX_RETRIES):
            try:
                response = await self._client.post(
                    self._endpoint, json=payload, headers=headers
                )
            except httpx.TransportError as exc:
                reason = f"Transport error: {exc}"
                last_error = RuntimeError(f"Gemini 网络错误: {exc}")
                last_error.__cause__ = exc
                wait = float(2**attempt)
            else:
                if response.status_code == 200:
                    return response.json()

                if response.status_code not in _RETRYABLE_STATUS:
                    raise RuntimeError(
                        f"Gemini API 返回错误 {response.status_code}: {response.text[:300]}"
                    )

                reason = f"HTTP {response.status_code}"
                last_error = RuntimeError(
                    f"HTTP {response.status_code}: {response.text[:300]}"
                )
                wait = self._retry_delay(response.headers.get("retry-after"), attempt)

            if attempt == _MAX_RETRIES - 1:
                # Final attempt failed: report and stop without a pointless sleep.
                logger.warning(
                    "[gemini] %s, giving up after %d attempts", reason, _MAX_RETRIES
                )
                break

            logger.warning(
                "[gemini] %s, retry %d/%d in %.1fs",
                reason,
                attempt + 1,
                _MAX_RETRIES,
                wait,
            )
            await asyncio.sleep(wait)

        raise last_error or RuntimeError("Gemini API 超过最大重试次数")

    async def query(
        self,
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> AIQueryResult:
        """Send *query* to Gemini and analyse the answer for brand/competitor mentions.

        Args:
            query: Prompt text sent as a single user turn.
            brand_name: Brand to look for in the response.
            competitor_names: Optional competitor names to look for.

        Returns:
            An ``AIQueryResult`` with the response text, citation flags and
            contexts, wall-clock latency (including retries) and token usage
            from ``usageMetadata``.

        Raises:
            RuntimeError: When the request fails (see ``_request_with_retry``)
                or the response contains no candidates / no content parts.
        """
        start_time = time.perf_counter()

        payload = {
            "contents": [{"parts": [{"text": query}]}],
            "generationConfig": {"temperature": 0.7},
        }

        data = await self._request_with_retry(payload)

        candidates = data.get("candidates", [])
        if not candidates:
            raise RuntimeError("Gemini API 返回空候选内容")

        parts = candidates[0].get("content", {}).get("parts", [])
        if not parts:
            raise RuntimeError("Gemini API 返回空内容")

        # A candidate's answer may be split across several parts. Join every
        # text part, skipping any part flagged as a model "thought", rather
        # than keeping only the first one.
        content = "".join(
            part.get("text", "") for part in parts if not part.get("thought")
        )

        elapsed_ms = int((time.perf_counter() - start_time) * 1000)
        has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
            content, brand_name, competitor_names
        )

        usage_metadata = data.get("usageMetadata", {})
        input_tokens = usage_metadata.get("promptTokenCount", 0)
        output_tokens = usage_metadata.get("candidatesTokenCount", 0)

        logger.info(
            "[gemini] query='%s...' brand=%s competitor=%s time=%dms",
            query[:50],
            has_brand,
            has_comp,
            elapsed_ms,
        )

        return AIQueryResult(
            engine_type=self.get_engine_type(),
            query=query,
            raw_response=content,
            citations=[],
            has_brand_citation=has_brand,
            has_competitor_citation=has_comp,
            brand_context=brand_ctx,
            competitor_contexts=comp_ctx,
            response_time_ms=elapsed_ms,
            timestamp=datetime.now(UTC),
            metadata={"model": self._model, "usage": usage_metadata},
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        )
|