126 lines
3.9 KiB
Python
126 lines
3.9 KiB
Python
import logging
|
|
import os
|
|
import time
|
|
from datetime import UTC, datetime
|
|
|
|
import httpx
|
|
|
|
from .base import AIEngineAdapter, AIQueryResult, CitationInfo, EngineType
|
|
|
|
logger = logging.getLogger(__name__)

# Model used when neither the ``model`` constructor argument nor the
# PERPLEXITY_MODEL environment variable is set. The previous default,
# "pplx-70b-online", has been retired by Perplexity and is rejected by the API;
# "sonar" is the current search-grounded model.
_DEFAULT_MODEL = "sonar"

# API root used when neither ``base_url`` nor PERPLEXITY_BASE_URL is set;
# the adapter appends "/chat/completions" to it.
_DEFAULT_BASE_URL = "https://api.perplexity.ai"
|
|
|
|
|
|
class PerplexityAdapter(AIEngineAdapter):
    """AI-engine adapter that queries Perplexity's chat-completions endpoint.

    Model and base URL fall back to the PERPLEXITY_MODEL / PERPLEXITY_BASE_URL
    environment variables and then to the module defaults. ``_get_env_key`` and
    ``_load_proxy`` read PERPLEXITY_API_KEY and the proxy variables.
    NOTE(review): those two hooks are presumably invoked by AIEngineAdapter;
    confirm against the base class.
    """

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        base_url: str | None = None,
        rate_limiter=None,
        proxy: str | None = None,
        key_manager=None,
        user_id: str | None = None,
    ):
        """Initialise the adapter and create its HTTP client.

        Args:
            api_key: Perplexity API key, forwarded to the base adapter.
            model: Model name. Defaults to $PERPLEXITY_MODEL, then
                ``_DEFAULT_MODEL``.
            base_url: API root. Defaults to $PERPLEXITY_BASE_URL, then
                ``_DEFAULT_BASE_URL``. A trailing "/" is stripped.
            rate_limiter: Forwarded unchanged to ``AIEngineAdapter``.
            proxy: Forwarded unchanged to ``AIEngineAdapter``.
            key_manager: Forwarded unchanged to ``AIEngineAdapter``.
            user_id: Forwarded unchanged to ``AIEngineAdapter``.
        """
        super().__init__(
            api_key=api_key,
            rate_limiter=rate_limiter,
            proxy=proxy,
            key_manager=key_manager,
            user_id=user_id,
        )
        # Use `or` rather than getenv's default so an env var that is set but
        # empty still falls back to the module default instead of yielding an
        # empty model name or a relative "/chat/completions" URL.
        self._model = model or os.getenv("PERPLEXITY_MODEL") or _DEFAULT_MODEL
        self._base_url = (
            base_url or os.getenv("PERPLEXITY_BASE_URL") or _DEFAULT_BASE_URL
        ).rstrip("/")
        self._endpoint = f"{self._base_url}/chat/completions"
        # NOTE(review): the bearer token is captured once, at construction. If
        # key_manager rotates self.api_key later, this header will not follow;
        # confirm.
        # NOTE(review): this class never closes the client; confirm the base
        # adapter or the owner calls `await self._client.aclose()`.
        self._client = httpx.AsyncClient(
            **self._client_kwargs(),
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
        )

    def get_engine_type(self) -> EngineType:
        """Return the engine identifier for this adapter."""
        return EngineType.PERPLEXITY

    def _get_env_key(self) -> str | None:
        """Return the API key from PERPLEXITY_API_KEY.

        Returns "" (not None) when the variable is unset.
        """
        return os.getenv("PERPLEXITY_API_KEY", "")

    def _load_proxy(self) -> str | None:
        """Return the proxy URL.

        Checks PERPLEXITY_PROXY first, then HTTPS_PROXY, then https_proxy.
        Returns None when none of them is set to a non-empty value.
        """
        return os.getenv("PERPLEXITY_PROXY") or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")

    async def query(
        self,
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> AIQueryResult:
        """Send *query* to Perplexity and analyse the answer for brand mentions.

        Args:
            query: User prompt sent as the single user message.
            brand_name: Brand to look for in the response text.
            competitor_names: Optional competitor names to look for.

        Returns:
            An ``AIQueryResult`` containing the response text, the extracted
            source citations, brand/competitor detection results, wall-clock
            latency (including any retries) and token usage.

        Raises:
            KeyError / IndexError if the response lacks ``choices[0].message.content``.
            Errors raised by ``_request_with_retry`` propagate unchanged.
        """
        start_time = time.perf_counter()

        messages = [
            {"role": "system", "content": "You are a helpful research assistant."},
            {"role": "user", "content": query},
        ]
        payload = {
            "model": self._model,
            "messages": messages,
            "temperature": 0.7,
            "max_tokens": 4096,
        }

        data = await self._request_with_retry(payload)
        content = data["choices"][0]["message"]["content"]
        citations = self._extract_citations(data)

        elapsed_ms = int((time.perf_counter() - start_time) * 1000)
        has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
            content, brand_name, competitor_names
        )

        # `or {}` also covers an explicit JSON null, which .get(key, {}) would
        # return as None.
        usage = data.get("usage") or {}
        input_tokens = usage.get("prompt_tokens", 0)
        output_tokens = usage.get("completion_tokens", 0)

        # Lazy %-formatting: the message is built only if INFO is enabled.
        logger.info(
            "[perplexity] query='%s...' brand=%s competitor=%s citations=%d time=%dms",
            query[:50],
            has_brand,
            has_comp,
            len(citations),
            elapsed_ms,
        )

        return AIQueryResult(
            engine_type=self.get_engine_type(),
            query=query,
            raw_response=content,
            citations=citations,
            has_brand_citation=has_brand,
            has_competitor_citation=has_comp,
            brand_context=brand_ctx,
            competitor_contexts=comp_ctx,
            response_time_ms=elapsed_ms,
            timestamp=datetime.now(UTC),
            metadata={"model": data.get("model", self._model)},
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        )

    def _extract_citations(self, data: dict) -> list[CitationInfo]:
        """Convert the response's source list into ``CitationInfo`` objects.

        Perplexity returns ``citations`` as a list of URL strings, with titles
        (when available) in a parallel ``search_results`` list of
        ``{"url", "title", ...}`` dicts. Dict-shaped citation entries are also
        accepted. When ``citations`` is absent or empty, ``search_results`` is
        used as the source list.

        ``position`` is the 1-based index in the source list, so it matches the
        ``[n]`` markers in the answer text even if malformed entries are skipped.
        """
        search_results = data.get("search_results") or []
        raw_citations = data.get("citations") or search_results
        if not raw_citations:
            return []

        # Map url -> title so plain-string citations can still carry a title.
        titles_by_url = {
            item.get("url"): item.get("title")
            for item in search_results
            if isinstance(item, dict) and item.get("url")
        }

        citations = []
        for idx, cit in enumerate(raw_citations):
            if isinstance(cit, str):
                url = cit
                title = titles_by_url.get(cit)
            elif isinstance(cit, dict):
                url = cit.get("url")
                title = cit.get("title")
                if title is None:
                    title = titles_by_url.get(url)
            else:
                logger.debug("[perplexity] skipping unrecognised citation entry: %r", cit)
                continue
            citations.append(
                CitationInfo(
                    source_url=url,
                    source_title=title,
                    citation_context="",
                    confidence=1.0,
                    position=idx + 1,
                )
            )
        return citations
|