91 lines
2.7 KiB
Python
91 lines
2.7 KiB
Python
import logging
|
|
import os
|
|
import time
|
|
from datetime import UTC, datetime
|
|
|
|
import httpx
|
|
|
|
from .base import AIEngineAdapter, AIQueryResult, EngineType
|
|
|
|
# Module-level logger; records are emitted under this module's dotted name.
logger: logging.Logger = logging.getLogger(__name__)

# Fallbacks used by DeepSeekAdapter when neither a constructor argument nor the
# matching DEEPSEEK_* environment variable supplies a value.
_DEFAULT_MODEL: str = "deepseek-chat"
_DEFAULT_BASE_URL: str = "https://api.deepseek.com/v1"
|
|
|
|
|
|
class DeepSeekAdapter(AIEngineAdapter):
    """Adapter that queries DeepSeek's OpenAI-compatible chat-completions API.

    Configuration comes from constructor arguments first and falls back to the
    ``DEEPSEEK_API_KEY``, ``DEEPSEEK_MODEL``, ``DEEPSEEK_BASE_URL`` and
    ``DEEPSEEK_PROXY`` environment variables.
    """

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        base_url: str | None = None,
        rate_limiter=None,
        proxy: str | None = None,
        *,
        temperature: float = 0.7,
        max_tokens: int = 4096,
    ):
        """Resolve configuration and build the HTTP client used for requests.

        Args:
            api_key: DeepSeek API key; falls back to ``DEEPSEEK_API_KEY``
                (empty string if that is unset too).
            model: Model name; falls back to ``DEEPSEEK_MODEL``, then
                ``_DEFAULT_MODEL``.
            base_url: API root; falls back to ``DEEPSEEK_BASE_URL``, then
                ``_DEFAULT_BASE_URL``. Trailing slashes are stripped so the
                endpoint path can be appended safely.
            rate_limiter: Passed through unchanged to the base adapter.
            proxy: Proxy URL; falls back to ``DEEPSEEK_PROXY``.
            temperature: Sampling temperature sent with every request.
            max_tokens: Maximum number of tokens requested per completion.
        """
        super().__init__(
            api_key=api_key or os.getenv("DEEPSEEK_API_KEY", ""),
            rate_limiter=rate_limiter,
            proxy=proxy or os.getenv("DEEPSEEK_PROXY"),
        )
        self._model = model or os.getenv("DEEPSEEK_MODEL", _DEFAULT_MODEL)
        self._base_url = (
            base_url or os.getenv("DEEPSEEK_BASE_URL", _DEFAULT_BASE_URL)
        ).rstrip("/")
        self._endpoint = f"{self._base_url}/chat/completions"
        self._temperature = temperature
        self._max_tokens = max_tokens

        # NOTE(review): this client is never closed inside this class; confirm
        # that the base adapter or the owner eventually calls
        # ``self._client.aclose()`` to release its connection pool.
        self._client = httpx.AsyncClient(
            **self._client_kwargs(),
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
        )

    def get_engine_type(self) -> EngineType:
        """Return the engine identifier for this adapter."""
        return EngineType.DEEPSEEK

    async def query(
        self,
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> AIQueryResult:
        """Send *query* as a single-turn chat and check the answer for mentions.

        Args:
            query: Prompt sent verbatim as the ``user`` message.
            brand_name: Brand to look for in the model's answer.
            competitor_names: Optional competitor names to look for.

        Returns:
            An ``AIQueryResult`` holding the answer text, the mention flags and
            contexts produced by ``_detect_citations``, and the wall-clock time
            spent in the whole ``_request_with_retry`` call plus response
            parsing. ``citations`` is always an empty list.

        Raises:
            KeyError, IndexError: if the response lacks ``choices[0].message``.
            Whatever ``_request_with_retry`` raises on transport/API failure.
        """
        start_time = time.perf_counter()

        payload = {
            "model": self._model,
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": query},
            ],
            "temperature": self._temperature,
            "max_tokens": self._max_tokens,
        }

        data = await self._request_with_retry(payload)
        choice = data["choices"][0]
        # The chat-completions schema allows ``content`` to be null (or absent);
        # normalize to "" so citation detection and ``raw_response`` get a str.
        content = choice["message"].get("content") or ""
        if not content:
            logger.warning(
                "[deepseek] empty response content (finish_reason=%s)",
                choice.get("finish_reason"),
            )

        elapsed_ms = int((time.perf_counter() - start_time) * 1000)

        has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
            content, brand_name, competitor_names
        )

        # Lazy %-style args; only mark the preview as truncated when it was.
        preview = query if len(query) <= 50 else f"{query[:50]}..."
        logger.info(
            "[deepseek] query='%s' brand=%s competitor=%s time=%dms",
            preview,
            has_brand,
            has_comp,
            elapsed_ms,
        )

        return AIQueryResult(
            engine_type=self.get_engine_type(),
            query=query,
            raw_response=content,
            citations=[],
            has_brand_citation=has_brand,
            has_competitor_citation=has_comp,
            brand_context=brand_ctx,
            competitor_contexts=comp_ctx,
            response_time_ms=elapsed_ms,
            timestamp=datetime.now(UTC),
            metadata={"model": data.get("model", self._model)},
        )
|