# geo/backend/app/services/ai_engine/doubao.py
#
# Note carried over from the code-hosting viewer: this file contains
# ambiguous Unicode characters (presumably the full-width Chinese punctuation
# in the system prompt used by DoubaoAdapter.query). They are intentional.

import logging
import os
import time
from datetime import UTC, datetime
import httpx
from .base import AIEngineAdapter, AIQueryResult, EngineType
# Module-level logger for this adapter.
logger: logging.Logger = logging.getLogger(__name__)

# Model name sent to Ark when no inference endpoint ID is configured.
_DEFAULT_MODEL: str = "doubao-pro-4k"
# Default root of the Volcengine Ark v3 API (Beijing region).
_DEFAULT_BASE_URL: str = "https://ark.cn-beijing.volces.com/api/v3"


class DoubaoAdapter(AIEngineAdapter):
    """Adapter for ByteDance Doubao models served through Volcengine Ark.

    Sends a single-turn request to Ark's ``/chat/completions`` endpoint and
    converts the reply into an :class:`AIQueryResult` with brand/competitor
    citation detection and token usage.
    """

    def __init__(
        self,
        api_key: str | None = None,
        endpoint_id: str | None = None,
        rate_limiter=None,
        proxy: str | None = None,
        key_manager=None,
        user_id: str | None = None,
        base_url: str | None = None,
    ):
        """Create the adapter and its HTTP client.

        Args:
            api_key: Ark API key, forwarded to the base class (which presumably
                falls back to ``_get_env_key`` when it is None — confirm in base).
            endpoint_id: Ark inference endpoint ID; falls back to the
                ``DOUBAO_ENDPOINT_ID`` environment variable.
            rate_limiter: Forwarded unchanged to :class:`AIEngineAdapter`.
            proxy: Forwarded unchanged to :class:`AIEngineAdapter`.
            key_manager: Forwarded unchanged to :class:`AIEngineAdapter`.
            user_id: Forwarded unchanged to :class:`AIEngineAdapter`.
            base_url: API root; falls back to ``DOUBAO_BASE_URL`` and then to
                the default Beijing-region Ark URL. A trailing slash is ignored.
        """
        super().__init__(
            api_key=api_key,
            rate_limiter=rate_limiter,
            proxy=proxy,
            key_manager=key_manager,
            user_id=user_id,
        )
        self._endpoint_id = endpoint_id or os.getenv("DOUBAO_ENDPOINT_ID", "")
        resolved_base_url = base_url or os.getenv("DOUBAO_BASE_URL") or _DEFAULT_BASE_URL
        self._base_url = resolved_base_url.rstrip("/")
        self._endpoint = f"{self._base_url}/chat/completions"
        # NOTE(review): the ``proxy`` argument / ``_load_proxy()`` result is not
        # passed to this client; httpx itself only honours HTTP(S)_PROXY-style
        # env vars (trust_env). Confirm whether the base class routes requests
        # through a proxied client before relying on DOUBAO_PROXY.
        # NOTE(review): the Authorization header is fixed here from
        # ``self.api_key`` (sent as "Bearer None" if no key was resolved) and the
        # client is never closed in this class — confirm how
        # ``_request_with_retry`` / key rotation / shutdown use ``self._client``.
        self._client = httpx.AsyncClient(
            timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0),
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
        )

    def get_engine_type(self) -> EngineType:
        """Identify this adapter as the Doubao engine."""
        return EngineType.DOUBAO

    def _get_env_key(self) -> str | None:
        """Return ``DOUBAO_API_KEY`` from the environment ("" when unset)."""
        return os.getenv("DOUBAO_API_KEY", "")

    def _load_proxy(self) -> str | None:
        """Return the first non-empty of DOUBAO_PROXY, HTTPS_PROXY, https_proxy.

        If none is non-empty, the raw value of ``https_proxy`` is returned
        (None when it is unset).
        """
        return os.getenv("DOUBAO_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")

    def _get_model_id(self) -> str:
        """Return the ``model`` value for the request payload.

        A configured endpoint ID wins. Surrounding whitespace (e.g. a trailing
        newline from an env file) is stripped, and an ``ep-`` prefix is added
        when missing. Without an endpoint ID, ``_DEFAULT_MODEL`` is used.
        """
        endpoint_id = (self._endpoint_id or "").strip()
        if not endpoint_id:
            return _DEFAULT_MODEL
        if endpoint_id.startswith("ep-"):
            return endpoint_id
        return f"ep-{endpoint_id}"

    async def query(
        self,
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> AIQueryResult:
        """Ask Doubao ``query`` and report whether the brand/competitors are cited.

        Args:
            query: User question, sent verbatim as the user message.
            brand_name: Brand name to look for in the reply.
            competitor_names: Optional competitor names to look for.

        Returns:
            AIQueryResult with the reply text, citation flags/contexts, latency
            in milliseconds (request + response parsing) and token usage.

        Raises:
            Whatever ``_request_with_retry`` raises on transport/API failure;
            KeyError/IndexError/TypeError if the response lacks
            ``choices[0].message``.
        """
        start_time = time.perf_counter()
        model_id = self._get_model_id()
        messages = [
            {
                "role": "system",
                # "You are a professional AI search assistant. Answer the user's
                # question in detail based on your knowledge. If you cite external
                # sources, note the source URL or source name in your answer."
                "content": "你是一个专业的AI搜索助手。请基于你的知识详细回答用户的问题。如果引用了外部来源请在回答中标注来源URL或出处名称。",
            },
            {"role": "user", "content": query},
        ]
        payload = {
            "model": model_id,
            "messages": messages,
            "temperature": 0.7,
            "max_tokens": 2000,
        }
        data = await self._request_with_retry(payload)
        message = data["choices"][0]["message"]
        # OpenAI-style chat responses allow a null ``content``; treat it as an
        # empty answer rather than crashing inside citation detection.
        content = message.get("content") or ""
        elapsed_ms = int((time.perf_counter() - start_time) * 1000)
        has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
            content, brand_name, competitor_names
        )
        # ``usage`` and its counters may be present but null; normalise so the
        # token fields are always ints.
        usage = data.get("usage") or {}
        input_tokens = usage.get("prompt_tokens") or 0
        output_tokens = usage.get("completion_tokens") or 0
        logger.info(
            "[doubao] query='%s...' brand=%s competitor=%s time=%dms",
            query[:50],
            has_brand,
            has_comp,
            elapsed_ms,
        )
        return AIQueryResult(
            engine_type=self.get_engine_type(),
            query=query,
            raw_response=content,
            # Source URLs requested by the system prompt are not parsed here.
            citations=[],
            has_brand_citation=has_brand,
            has_competitor_citation=has_comp,
            brand_context=brand_ctx,
            competitor_contexts=comp_ctx,
            response_time_ms=elapsed_ms,
            timestamp=datetime.now(UTC),
            metadata={"model": data.get("model", model_id), "usage": usage},
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        )