185 lines
6.0 KiB
Python
185 lines
6.0 KiB
Python
import logging
|
||
import os
|
||
import time
|
||
from datetime import UTC, datetime
|
||
|
||
import httpx
|
||
|
||
from app.services.api_key_manager import APIKeyManager, KeyCredentials
|
||
from .base import AIEngineAdapter, AIQueryResult, EngineType
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
_DEFAULT_MODEL = "completions_pro"
|
||
_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
|
||
_CHAT_URL_TEMPLATE = (
|
||
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model}"
|
||
"?access_token={token}"
|
||
)
|
||
|
||
_cached_token: str | None = None
|
||
_token_expires_at: float = 0.0
|
||
|
||
|
||
class WenxinAdapter(AIEngineAdapter):
    """Adapter that sends a query to Baidu Wenxin (文心一言) and reports brand mentions.

    Authentication is a two-step flow: the API key / secret key pair is first
    exchanged for an access token at ``_TOKEN_URL``, and that token is then
    passed as a query parameter on the chat endpoint.
    """

    # Access tokens shared by all instances, keyed by the (api_key, secret_key)
    # pair that obtained them.  Each value is (token, monotonic expiry deadline).
    # Keying by credentials keeps a token issued for one account from being
    # reused by an adapter that was built with different credentials.
    _access_token_cache: dict[tuple[str, str], tuple[str, float]] = {}

    # Lifetime assumed when the token response carries no ``expires_in`` (30 days).
    _DEFAULT_TOKEN_TTL_S = 2592000

    # Safety margin: refresh this many seconds before the reported expiry.
    _TOKEN_REFRESH_MARGIN_S = 300

    # API error codes meaning "access token invalid" (110) / "expired" (111).
    _TOKEN_REJECTED_CODES = frozenset({110, 111})

    def __init__(
        self,
        api_key: str | None = None,
        secret_key: str | None = None,
        rate_limiter=None,
        proxy: str | None = None,
        key_manager: APIKeyManager | None = None,
        user_id: str | None = None,
    ):
        """Store configuration, resolve credentials and create the HTTP client.

        Args:
            api_key: API key used as the OAuth ``client_id``.
            secret_key: Secret key used as the OAuth ``client_secret``.
            rate_limiter: Optional object whose awaitable ``acquire()`` is
                awaited before each chat request.
            proxy: Proxy URL for outgoing requests; when omitted, the
                environment is consulted (see ``_load_proxy``).
            key_manager: Optional credential store queried when no API key is
                passed directly.
            user_id: Forwarded to ``key_manager.get_credentials``.
        """
        self._key_manager = key_manager
        self._user_id = user_id
        self.rate_limiter = rate_limiter
        self.proxy = proxy or self._load_proxy()
        self.api_key = api_key or ""
        self.secret_key = secret_key or ""
        self._resolve_keys_from_manager(api_key, secret_key, key_manager, user_id)
        self._model = _DEFAULT_MODEL
        self._client = self._build_client()

    def _build_client(self) -> httpx.AsyncClient:
        """Create the async HTTP client, routed through ``self.proxy`` when set."""
        timeout = httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0)
        if not self.proxy:
            return httpx.AsyncClient(timeout=timeout)
        try:
            return httpx.AsyncClient(timeout=timeout, proxy=self.proxy)
        except TypeError:
            # httpx releases older than 0.26 only know the plural ``proxies`` keyword.
            return httpx.AsyncClient(timeout=timeout, proxies=self.proxy)

    def _resolve_keys_from_manager(
        self,
        direct_api_key: str | None,
        direct_secret_key: str | None,
        key_manager: APIKeyManager | None,
        user_id: str | None,
    ) -> None:
        """Fill missing credentials from ``key_manager``.

        Nothing happens when a non-blank API key was passed directly or when no
        key manager is available.  Otherwise the ``"wenxin"`` credentials are
        looked up and copied into whichever of ``self.api_key`` /
        ``self.secret_key`` is still empty.  ``direct_secret_key`` is accepted
        for signature compatibility but is not consulted.
        """
        if direct_api_key and direct_api_key.strip():
            return
        if not key_manager:
            return

        creds = key_manager.get_credentials("wenxin", user_id=user_id)
        if not creds:
            return
        if not self.api_key:
            self.api_key = creds.api_key
        if not self.secret_key:
            self.secret_key = creds.secret_key or ""

    def get_engine_type(self) -> EngineType:
        """Return the engine identifier for this adapter."""
        return EngineType.WENXIN

    def _get_env_key(self) -> str | None:
        """Return ``BAIDU_QIANFAN_API_KEY`` from the environment ("" when unset)."""
        return os.getenv("BAIDU_QIANFAN_API_KEY", "")

    def _get_env_secret_key(self) -> str | None:
        """Return ``BAIDU_QIANFAN_SECRET_KEY`` from the environment ("" when unset)."""
        return os.getenv("BAIDU_QIANFAN_SECRET_KEY", "")

    def _load_proxy(self) -> str | None:
        """Return the first proxy URL set among BAIDU_PROXY, HTTPS_PROXY, https_proxy."""
        return os.getenv("BAIDU_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")

    async def _get_access_token(self) -> str:
        """Return a valid access token for this adapter's credentials.

        A cached token is reused until shortly before its expiry; otherwise a
        new one is requested from ``_TOKEN_URL`` and cached.

        Raises:
            RuntimeError: If credentials are missing, the token endpoint
                answers with a non-200 status, or the response has no
                ``access_token``.
        """
        if not self.api_key or not self.secret_key:
            # Fail fast instead of sending a request that cannot succeed.
            raise RuntimeError("文心一言获取 access_token 失败: 未配置 API Key 或 Secret Key")

        cache = WenxinAdapter._access_token_cache
        cache_key = (self.api_key, self.secret_key)

        now = time.monotonic()
        cached = cache.get(cache_key)
        if cached is not None and now < cached[1]:
            return cached[0]

        response = await self._client.post(
            _TOKEN_URL,
            params={
                "grant_type": "client_credentials",
                "client_id": self.api_key,
                "client_secret": self.secret_key,
            },
        )

        if response.status_code != 200:
            raise RuntimeError(
                f"文心一言获取 access_token 失败: {response.status_code} {response.text[:300]}"
            )

        data = response.json()
        token = data.get("access_token")
        if not token:
            error_desc = data.get("error_description", "未知错误")
            raise RuntimeError(f"文心一言获取 access_token 失败: {error_desc}")

        expires_in = data.get("expires_in", self._DEFAULT_TOKEN_TTL_S)
        cache[cache_key] = (token, now + expires_in - self._TOKEN_REFRESH_MARGIN_S)

        logger.info("[wenxin] access_token 获取成功")
        return token

    async def query(
        self,
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> AIQueryResult:
        """Send ``query`` to the chat endpoint and analyse the answer.

        Args:
            query: User question sent as the single chat message.
            brand_name: Brand to look for in the answer.
            competitor_names: Optional competitor names to look for.

        Returns:
            An ``AIQueryResult`` holding the raw answer, the brand/competitor
            detection results, the elapsed time in milliseconds and the token
            usage reported by the API.

        Raises:
            RuntimeError: On rate limiting (HTTP 429), any other non-200
                status, an ``error_code`` in the response body, or an empty
                answer.
        """
        start_time = time.perf_counter()

        access_token = await self._get_access_token()
        chat_url = _CHAT_URL_TEMPLATE.format(model=self._model, token=access_token)

        payload = {
            "messages": [{"role": "user", "content": query}],
            "system": "你是一个专业的AI搜索助手。请基于你的知识,详细回答用户的问题。如果引用了外部来源,请在回答中标注来源URL或出处名称。",
            "temperature": 0.7,
            "max_output_tokens": 2000,
        }

        if self.rate_limiter:
            await self.rate_limiter.acquire()

        response = await self._client.post(chat_url, json=payload)

        if response.status_code == 429:
            raise RuntimeError("文心一言 API 限流")

        if response.status_code != 200:
            error_body = response.text[:500]
            raise RuntimeError(
                f"文心一言 API 返回错误 {response.status_code}: {error_body}"
            )

        data = response.json()

        error_code = data.get("error_code")
        if error_code:
            if error_code in self._TOKEN_REJECTED_CODES:
                # The cached token was rejected; drop it so the next call
                # requests a fresh one instead of failing until local expiry.
                WenxinAdapter._access_token_cache.pop(
                    (self.api_key, self.secret_key), None
                )
            error_msg = data.get("error_msg", "未知错误")
            raise RuntimeError(f"文心一言 API 错误 {error_code}: {error_msg}")

        content = data.get("result", "")
        if not content:
            raise RuntimeError("文心一言 API 返回空内容")

        elapsed_ms = int((time.perf_counter() - start_time) * 1000)
        # _detect_citations is not defined in this class; presumably inherited
        # from AIEngineAdapter.
        has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
            content, brand_name, competitor_names
        )

        # ``or {}`` also covers an explicit ``"usage": null`` in the response.
        usage = data.get("usage") or {}
        input_tokens = usage.get("prompt_tokens", 0)
        output_tokens = usage.get("completion_tokens", 0)

        logger.info(
            "[wenxin] query='%s...' brand=%s competitor=%s time=%sms",
            query[:50],
            has_brand,
            has_comp,
            elapsed_ms,
        )

        return AIQueryResult(
            engine_type=self.get_engine_type(),
            query=query,
            raw_response=content,
            citations=[],
            has_brand_citation=has_brand,
            has_competitor_citation=has_comp,
            brand_context=brand_ctx,
            competitor_contexts=comp_ctx,
            response_time_ms=elapsed_ms,
            timestamp=datetime.now(UTC),
            metadata={"model": self._model, "usage": usage},
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        )
|