"""Adapter for Baidu Wenxin (ERNIE) chat models served through the Qianfan API."""

import logging
import os
import time
from datetime import UTC, datetime

import httpx

from .base import AIEngineAdapter, AIQueryResult, EngineType

logger = logging.getLogger(__name__)

_DEFAULT_MODEL = "completions_pro"
_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
_CHAT_URL_TEMPLATE = (
    "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model}"
    "?access_token={token}"
)

# Lifetime assumed when the token endpoint omits ``expires_in`` (30 days).
_DEFAULT_TOKEN_TTL_SECONDS = 2592000
# Refresh this many seconds before the reported expiry so a token is never
# used right at the edge of its validity window.
_TOKEN_EXPIRY_MARGIN_SECONDS = 300
# Baidu open-platform error codes for an invalid (110) or expired (111)
# access token; the cached token is useless once either is returned.
_TOKEN_REJECTED_ERROR_CODES = frozenset({110, 111})

# Process-wide token cache keyed by the credential pair, so adapters built with
# different keys never receive each other's token.
# Maps (api_key, secret_key) -> (access_token, monotonic expiry deadline).
_token_cache: dict[tuple[str | None, str], tuple[str, float]] = {}


class WenxinAdapter(AIEngineAdapter):
    """Query Baidu Wenxin through the ``wenxinworkshop`` chat endpoint.

    Authentication is a two-step flow: the API key / secret key pair is
    exchanged for an OAuth access token, which is then passed as a query
    parameter on every chat request. Tokens are cached per credential pair
    for the lifetime of the process.
    """

    def __init__(
        self,
        api_key: str | None = None,
        secret_key: str | None = None,
        rate_limiter=None,
        proxy: str | None = None,
        key_manager=None,
        user_id: str | None = None,
    ):
        """Create the adapter and its HTTP client.

        Args:
            api_key: Qianfan API key, forwarded to the base class.
            secret_key: Qianfan secret key; falls back to the
                ``BAIDU_QIANFAN_SECRET_KEY`` environment variable, then ``""``.
            rate_limiter: Optional object with an awaitable ``acquire()``,
                awaited before each chat request.
            proxy: Forwarded to the base class.
            key_manager: Forwarded to the base class.
            user_id: Forwarded to the base class.
        """
        super().__init__(
            api_key=api_key,
            rate_limiter=rate_limiter,
            proxy=proxy,
            key_manager=key_manager,
            user_id=user_id,
        )
        self.secret_key = secret_key or os.getenv("BAIDU_QIANFAN_SECRET_KEY", "")
        self._model = _DEFAULT_MODEL
        # NOTE(review): the proxy setting is not passed to this client, and the
        # client is never closed in this file — confirm whether the base class
        # handles either.
        self._client = httpx.AsyncClient(
            timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0),
        )

    def get_engine_type(self) -> EngineType:
        """Return the engine identifier for this adapter."""
        return EngineType.WENXIN

    def _get_env_key(self) -> str | None:
        """Return the API key from ``BAIDU_QIANFAN_API_KEY`` (``""`` if unset)."""
        return os.getenv("BAIDU_QIANFAN_API_KEY", "")

    def _load_proxy(self) -> str | None:
        """Return the first proxy URL set among the supported env variables."""
        return (
            os.getenv("BAIDU_PROXY")
            or os.getenv("HTTPS_PROXY")
            or os.getenv("https_proxy")
        )

    def _token_cache_key(self) -> tuple[str | None, str]:
        """Return the cache key identifying this adapter's credential pair."""
        return (self.api_key, self.secret_key)

    def _invalidate_access_token(self) -> None:
        """Drop this adapter's cached token so the next call fetches a new one."""
        _token_cache.pop(self._token_cache_key(), None)

    async def _get_access_token(self) -> str:
        """Return a valid access token, fetching a new one when necessary.

        Returns:
            The access token for this adapter's credential pair.

        Raises:
            RuntimeError: The token endpoint answered with a non-200 status or
                a body without ``access_token``.
        """
        cache_key = self._token_cache_key()
        now = time.monotonic()

        cached = _token_cache.get(cache_key)
        if cached is not None and now < cached[1]:
            return cached[0]

        response = await self._client.post(
            _TOKEN_URL,
            params={
                "grant_type": "client_credentials",
                "client_id": self.api_key,
                "client_secret": self.secret_key,
            },
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"文心一言获取 access_token 失败: {response.status_code} {response.text[:300]}"
            )

        data = response.json()
        token = data.get("access_token")
        if not token:
            error_desc = data.get("error_description", "未知错误")
            raise RuntimeError(f"文心一言获取 access_token 失败: {error_desc}")

        expires_in = data.get("expires_in", _DEFAULT_TOKEN_TTL_SECONDS)
        expires_at = now + expires_in - _TOKEN_EXPIRY_MARGIN_SECONDS
        _token_cache[cache_key] = (token, expires_at)
        logger.info("[wenxin] access_token 获取成功")
        return token

    async def query(
        self,
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> AIQueryResult:
        """Send ``query`` to Wenxin and report brand / competitor mentions.

        Args:
            query: The user question, sent as a single user message.
            brand_name: Brand to look for in the answer.
            competitor_names: Optional competitor names to look for.

        Returns:
            An ``AIQueryResult`` holding the raw answer, the citation flags and
            contexts from ``_detect_citations``, token usage and timing. The
            reported time spans the whole call, including token retrieval and
            any rate-limiter wait.

        Raises:
            RuntimeError: Token retrieval failed, the API was rate limited
                (HTTP 429), returned a non-200 status, reported an
                ``error_code``, or produced an empty result.
        """
        start_time = time.perf_counter()

        access_token = await self._get_access_token()
        chat_url = _CHAT_URL_TEMPLATE.format(
            model=self._model,
            token=access_token,
        )
        payload = {
            "messages": [{"role": "user", "content": query}],
            "system": "你是一个专业的AI搜索助手。请基于你的知识,详细回答用户的问题。如果引用了外部来源,请在回答中标注来源URL或出处名称。",
            "temperature": 0.7,
            "max_output_tokens": 2000,
        }

        if self.rate_limiter:
            await self.rate_limiter.acquire()

        response = await self._client.post(chat_url, json=payload)

        if response.status_code == 429:
            raise RuntimeError("文心一言 API 限流")
        if response.status_code != 200:
            error_body = response.text[:500]
            raise RuntimeError(
                f"文心一言 API 返回错误 {response.status_code}: {error_body}"
            )

        data = response.json()

        # Application-level errors are reported in the body of a 200 response.
        error_code = data.get("error_code")
        if error_code:
            if error_code in _TOKEN_REJECTED_ERROR_CODES:
                # Without this the rejected token would be reused until its
                # local expiry, failing every subsequent call.
                self._invalidate_access_token()
            error_msg = data.get("error_msg", "未知错误")
            raise RuntimeError(f"文心一言 API 错误 {error_code}: {error_msg}")

        content = data.get("result", "")
        if not content:
            raise RuntimeError("文心一言 API 返回空内容")

        elapsed_ms = int((time.perf_counter() - start_time) * 1000)

        has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
            content, brand_name, competitor_names
        )

        # ``or {}`` also covers an explicit null, which ``get(..., {})`` does not.
        usage = data.get("usage") or {}
        input_tokens = usage.get("prompt_tokens", 0)
        output_tokens = usage.get("completion_tokens", 0)

        logger.info(
            "[wenxin] query='%s...' brand=%s competitor=%s time=%sms",
            query[:50],
            has_brand,
            has_comp,
            elapsed_ms,
        )

        return AIQueryResult(
            engine_type=self.get_engine_type(),
            query=query,
            raw_response=content,
            citations=[],
            has_brand_citation=has_brand,
            has_competitor_citation=has_comp,
            brand_context=brand_ctx,
            competitor_contexts=comp_ctx,
            response_time_ms=elapsed_ms,
            timestamp=datetime.now(UTC),
            metadata={"model": self._model, "usage": usage},
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        )