110 lines
3.5 KiB
Python
110 lines
3.5 KiB
Python
import logging
|
||
import os
|
||
import time
|
||
from datetime import UTC, datetime
|
||
|
||
import httpx
|
||
|
||
from .base import AIEngineAdapter, AIQueryResult, EngineType
|
||
|
||
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
||
|
||
# Model name used by YuanbaoAdapter when no `model` argument is passed and the
# HUNYUAN_MODEL environment variable is unset.
_DEFAULT_MODEL = "hunyuan-lite"
|
||
# API base URL used by YuanbaoAdapter when no `base_url` argument is passed and
# the HUNYUAN_BASE_URL environment variable is unset; the adapter appends
# "/chat/completions" to it to form the request endpoint.
_DEFAULT_BASE_URL = "https://api.hunyuan.cloud.tencent.com/v1"
|
||
|
||
|
||
class YuanbaoAdapter(AIEngineAdapter):
    """Engine adapter that sends chat-completion requests to a Hunyuan endpoint.

    The adapter reports itself as ``EngineType.YUANBAO``. Each query is sent
    as a system + user message pair to ``<base_url>/chat/completions``, and
    the first choice's text is scanned for brand / competitor mentions with
    the inherited ``_detect_citations`` helper.

    Environment variables read by this class:
        HUNYUAN_MODEL, HUNYUAN_BASE_URL, HUNYUAN_API_KEY, HUNYUAN_PROXY,
        HTTPS_PROXY, https_proxy.
    """

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        base_url: str | None = None,
        rate_limiter=None,
        proxy: str | None = None,
        key_manager=None,
        user_id: str | None = None,
    ):
        """Configure the model, endpoint and HTTP client.

        Args:
            api_key: Forwarded to the base adapter unchanged.
            model: Model name; falls back to ``HUNYUAN_MODEL`` and then to
                ``_DEFAULT_MODEL`` when falsy.
            base_url: API base URL; falls back to ``HUNYUAN_BASE_URL`` and
                then to ``_DEFAULT_BASE_URL`` when falsy. Trailing slashes
                are stripped.
            rate_limiter: Forwarded to the base adapter unchanged.
            proxy: Forwarded to the base adapter unchanged.
            key_manager: Forwarded to the base adapter unchanged.
            user_id: Forwarded to the base adapter unchanged.
        """
        super().__init__(
            api_key=api_key,
            rate_limiter=rate_limiter,
            proxy=proxy,
            key_manager=key_manager,
            user_id=user_id,
        )
        self._model = model or os.getenv("HUNYUAN_MODEL", _DEFAULT_MODEL)
        self._base_url = (
            base_url or os.getenv("HUNYUAN_BASE_URL", _DEFAULT_BASE_URL)
        ).rstrip("/")
        self._endpoint = f"{self._base_url}/chat/completions"
        # NOTE(review): `self.api_key` is presumably populated by the base
        # class constructor above, and the Authorization header is fixed at
        # construction time -- confirm this is compatible with any key
        # rotation done through `key_manager`.
        # NOTE(review): no proxy is passed to this client here; presumably the
        # base class applies the proxy setting elsewhere -- confirm.
        self._client = httpx.AsyncClient(
            timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0),
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
        )

    def get_engine_type(self) -> EngineType:
        """Return the engine identifier for this adapter (``YUANBAO``)."""
        return EngineType.YUANBAO

    def _get_env_key(self) -> str | None:
        """Return the ``HUNYUAN_API_KEY`` environment variable.

        Returns an empty string (not ``None``) when the variable is unset.
        """
        return os.getenv("HUNYUAN_API_KEY", "")

    def _load_proxy(self) -> str | None:
        """Return the first non-empty proxy setting from the environment.

        Lookup order: ``HUNYUAN_PROXY``, ``HTTPS_PROXY``, ``https_proxy``.
        When none is non-empty, the value of ``https_proxy`` is returned
        as-is (``None`` if it is unset).
        """
        return os.getenv("HUNYUAN_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")

    async def query(
        self,
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> AIQueryResult:
        """Send ``query`` to the model and analyse the answer for mentions.

        Args:
            query: User question, sent verbatim as the user message.
            brand_name: Passed to ``_detect_citations`` together with the
                response text.
            competitor_names: Passed to ``_detect_citations``; may be None.

        Returns:
            An ``AIQueryResult`` carrying the raw response text, the
            detection results, elapsed time in milliseconds, a UTC
            timestamp, and token usage (0 when the response omits it).

        Raises:
            KeyError, IndexError: If the response has no
                ``choices[0].message.content`` entry.
            Exception: Whatever ``_request_with_retry`` raises is propagated.
        """
        start_time = time.perf_counter()

        payload = {
            "model": self._model,
            "messages": [
                {
                    "role": "system",
                    "content": "你是一个专业的AI搜索助手。请基于你的知识,详细回答用户的问题。如果引用了外部来源,请在回答中标注来源URL或出处名称。",
                },
                {"role": "user", "content": query},
            ],
            "temperature": 0.7,
            "max_tokens": 2000,
        }

        data = await self._request_with_retry(payload)
        # A null `content` is normalised to "" so the citation scan and the
        # result always receive a string.
        content = data["choices"][0]["message"]["content"] or ""

        elapsed_ms = int((time.perf_counter() - start_time) * 1000)
        has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
            content, brand_name, competitor_names
        )

        # `usage` may be missing or explicitly null; individual counters may
        # be null as well. Fall back to an empty mapping / zero in each case.
        usage = data.get("usage") or {}
        input_tokens = usage.get("prompt_tokens") or 0
        output_tokens = usage.get("completion_tokens") or 0

        # Lazy %-formatting: the message is only rendered if INFO is enabled.
        logger.info(
            "[yuanbao] query='%s...' brand=%s competitor=%s time=%sms",
            query[:50],
            has_brand,
            has_comp,
            elapsed_ms,
        )

        return AIQueryResult(
            engine_type=self.get_engine_type(),
            query=query,
            raw_response=content,
            citations=[],
            has_brand_citation=has_brand,
            has_competitor_citation=has_comp,
            brand_context=brand_ctx,
            competitor_contexts=comp_ctx,
            response_time_ms=elapsed_ms,
            timestamp=datetime.now(UTC),
            metadata={"model": data.get("model", self._model), "usage": usage},
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        )
|