geo/backend/app/services/ai_engine/wenxin.py

164 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import os
import time
from datetime import UTC, datetime
import httpx
from .base import AIEngineAdapter, AIQueryResult, EngineType
logger = logging.getLogger(__name__)
_DEFAULT_MODEL = "completions_pro"
_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
_CHAT_URL_TEMPLATE = (
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model}"
"?access_token={token}"
)
_cached_token: str | None = None
_token_expires_at: float = 0.0
class WenxinAdapter(AIEngineAdapter):
    """Adapter for Baidu Wenxin (ERNIE) through the Qianfan chat API.

    Exchanges an API key / secret key pair for an OAuth access token (cached
    per credential pair), sends single-turn chat requests to the Wenxin
    Workshop endpoint selected by ``self._model``, and checks the answer for
    brand / competitor mentions.
    """

    # Access tokens shared by every adapter in the process, keyed per
    # credential pair: {(api_key, secret_key): (access_token, deadline)}, where
    # ``deadline`` is a time.monotonic() value. Keying by credentials keeps
    # adapters built for different users / keys from reusing each other's token.
    _token_cache = {}

    # ``error_code`` values (returned with HTTP 200) meaning the access token
    # was rejected as invalid (110) or expired (111). Compared as strings so
    # both int and str encodings in the JSON body match.
    _TOKEN_ERROR_CODES = frozenset({"110", "111"})

    def __init__(
        self,
        api_key: str | None = None,
        secret_key: str | None = None,
        rate_limiter=None,
        proxy: str | None = None,
        key_manager=None,
        user_id: str | None = None,
    ):
        """Create the adapter and its HTTP client.

        Args:
            api_key: Qianfan API key; forwarded to the base class.
            secret_key: Qianfan secret key; falls back to the
                ``BAIDU_QIANFAN_SECRET_KEY`` environment variable.
            rate_limiter: Optional object whose ``acquire()`` coroutine is
                awaited before each chat request.
            proxy: Forwarded to the base class.
            key_manager: Forwarded to the base class.
            user_id: Forwarded to the base class.
        """
        super().__init__(
            api_key=api_key,
            rate_limiter=rate_limiter,
            proxy=proxy,
            key_manager=key_manager,
            user_id=user_id,
        )
        self.secret_key = secret_key or os.getenv("BAIDU_QIANFAN_SECRET_KEY", "")
        self._model = _DEFAULT_MODEL
        # NOTE(review): ``proxy`` / ``_load_proxy`` are handled by the base
        # class, but this client is built without an explicit proxy, so only
        # httpx's own environment handling (trust_env) applies and BAIDU_PROXY
        # is not read here. Confirm whether that is intended.
        self._client = httpx.AsyncClient(
            timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0),
        )

    def get_engine_type(self) -> EngineType:
        """Return ``EngineType.WENXIN``."""
        return EngineType.WENXIN

    def _get_env_key(self) -> str | None:
        """Return the API key from ``BAIDU_QIANFAN_API_KEY`` ("" when unset)."""
        return os.getenv("BAIDU_QIANFAN_API_KEY", "")

    def _load_proxy(self) -> str | None:
        """Return the proxy URL from BAIDU_PROXY, HTTPS_PROXY or https_proxy.

        The first non-empty variable wins, in that order.
        """
        return os.getenv("BAIDU_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")

    def _credential_key(self) -> tuple[str, str]:
        """Return the token-cache key identifying this adapter's credentials."""
        return (self.api_key or "", self.secret_key or "")

    def _invalidate_access_token(self) -> None:
        """Drop the cached access token for this adapter's credentials, if any."""
        self._token_cache.pop(self._credential_key(), None)

    @classmethod
    def _is_token_rejected(cls, data: dict) -> bool:
        """Return True if a chat response body reports an invalid/expired token."""
        return str(data.get("error_code")) in cls._TOKEN_ERROR_CODES

    async def _get_access_token(self) -> str:
        """Return an access token for this adapter's credentials.

        A cached token is reused until its deadline passes; otherwise a new
        token is requested from the OAuth endpoint and cached.

        Raises:
            RuntimeError: if the token endpoint answers with a non-200 status
                or a body without ``access_token``.
        """
        cache_key = self._credential_key()
        now = time.monotonic()
        cached = self._token_cache.get(cache_key)
        if cached is not None and now < cached[1]:
            return cached[0]
        response = await self._client.post(
            _TOKEN_URL,
            params={
                "grant_type": "client_credentials",
                "client_id": self.api_key,
                "client_secret": self.secret_key,
            },
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"文心一言获取 access_token 失败: {response.status_code} {response.text[:300]}"
            )
        data = response.json()
        token = data.get("access_token")
        if not token:
            error_desc = data.get("error_description", "未知错误")
            raise RuntimeError(f"文心一言获取 access_token 失败: {error_desc}")
        expires_in = data.get("expires_in", 2592000)
        # Treat the token as stale 300 s before the server-reported expiry so
        # it is refreshed before it can expire mid-request.
        self._token_cache[cache_key] = (token, now + expires_in - 300)
        logger.info("[wenxin] access_token 获取成功")
        return token

    async def _post_chat(self, payload: dict) -> dict:
        """Send one chat request with the current access token.

        Returns:
            The decoded JSON body. It may still carry an ``error_code``.

        Raises:
            RuntimeError: on HTTP 429 or any other non-200 status, or on
                failure to obtain an access token.
        """
        access_token = await self._get_access_token()
        chat_url = _CHAT_URL_TEMPLATE.format(
            model=self._model,
            token=access_token,
        )
        if self.rate_limiter:
            await self.rate_limiter.acquire()
        response = await self._client.post(chat_url, json=payload)
        if response.status_code == 429:
            raise RuntimeError("文心一言 API 限流")
        if response.status_code != 200:
            error_body = response.text[:500]
            raise RuntimeError(
                f"文心一言 API 返回错误 {response.status_code}: {error_body}"
            )
        return response.json()

    async def _request_chat(self, payload: dict) -> dict:
        """Send *payload* and return a successful JSON body.

        If the API rejects the cached access token (it was revoked or expired
        before the locally computed deadline), the token is discarded and the
        request is retried once with a freshly issued one.

        Raises:
            RuntimeError: for HTTP failures (see ``_post_chat``) or when the
                body carries a non-empty ``error_code``.
        """
        data = await self._post_chat(payload)
        if self._is_token_rejected(data):
            self._invalidate_access_token()
            logger.warning(
                "[wenxin] access_token 已失效 (error_code=%s)，重新获取后重试",
                data.get("error_code"),
            )
            data = await self._post_chat(payload)
            if self._is_token_rejected(data):
                # Never leave a token the server just rejected in the cache.
                self._invalidate_access_token()
        error_code = data.get("error_code")
        if error_code:
            error_msg = data.get("error_msg", "未知错误")
            raise RuntimeError(f"文心一言 API 错误 {error_code}: {error_msg}")
        return data

    async def query(
        self,
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> AIQueryResult:
        """Ask Wenxin *query* and analyse the answer for brand citations.

        Args:
            query: Question sent as the single user message.
            brand_name: Brand to look for in the answer.
            competitor_names: Optional competitor names to look for.

        Returns:
            An AIQueryResult with the raw answer, the flags and contexts from
            ``_detect_citations``, elapsed time (including any token fetch and
            retry), and token usage.

        Raises:
            RuntimeError: on token failure, rate limiting (HTTP 429), any other
                HTTP error, an API ``error_code``, or an empty answer.
        """
        start_time = time.perf_counter()
        payload = {
            "messages": [{"role": "user", "content": query}],
            "system": "你是一个专业的AI搜索助手。请基于你的知识详细回答用户的问题。如果引用了外部来源请在回答中标注来源URL或出处名称。",
            "temperature": 0.7,
            "max_output_tokens": 2000,
        }
        data = await self._request_chat(payload)
        content = data.get("result", "")
        if not content:
            raise RuntimeError("文心一言 API 返回空内容")
        elapsed_ms = int((time.perf_counter() - start_time) * 1000)
        has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
            content, brand_name, competitor_names
        )
        # ``usage`` may be missing or null; both mean "no usage reported".
        usage = data.get("usage") or {}
        input_tokens = usage.get("prompt_tokens", 0)
        output_tokens = usage.get("completion_tokens", 0)
        logger.info(
            "[wenxin] query='%s...' brand=%s competitor=%s time=%sms",
            query[:50],
            has_brand,
            has_comp,
            elapsed_ms,
        )
        return AIQueryResult(
            engine_type=self.get_engine_type(),
            query=query,
            raw_response=content,
            # No structured citation list is extracted from this response.
            citations=[],
            has_brand_citation=has_brand,
            has_competitor_citation=has_comp,
            brand_context=brand_ctx,
            competitor_contexts=comp_ctx,
            response_time_ms=elapsed_ms,
            timestamp=datetime.now(UTC),
            metadata={"model": self._model, "usage": usage},
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        )

    async def aclose(self) -> None:
        """Close the underlying HTTP client and release its pooled connections.

        NOTE(review): the base class's lifecycle hooks are not visible here;
        wire this into whatever shutdown path owns adapter instances.
        """
        await self._client.aclose()