diff --git a/backend/app/api/ai_engines.py b/backend/app/api/ai_engines.py new file mode 100644 index 0000000..029ecbb --- /dev/null +++ b/backend/app/api/ai_engines.py @@ -0,0 +1,175 @@ +import logging +from functools import lru_cache + +from fastapi import APIRouter, Depends, HTTPException, Query, status +from pydantic import BaseModel, Field + +from app.api.deps import get_current_user +from app.models.user import User +from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType +from app.services.ai_engine.batch_query import BatchQueryService +from app.services.ai_engine.chatgpt import ChatGPTAdapter +from app.services.ai_engine.perplexity import PerplexityAdapter +from app.services.ai_engine.kimi import KimiAdapter +from app.services.ai_engine.wenxin import WenxinAdapter +from app.services.ai_engine.doubao import DoubaoAdapter + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +class SingleQueryRequest(BaseModel): + engine: str + query: str = Field(min_length=1, max_length=500) + brand_name: str = Field(min_length=1, max_length=200) + competitor_names: list[str] | None = None + + +class BatchQueryRequest(BaseModel): + engines: list[str] = Field(min_length=1) + query: str = Field(min_length=1, max_length=500) + brand_name: str = Field(min_length=1, max_length=200) + competitor_names: list[str] | None = None + + +class QueryResultResponse(BaseModel): + engine_type: str + query: str + raw_response: str + has_brand_citation: bool + has_competitor_citation: bool + brand_context: str | None + competitor_contexts: list[str] + response_time_ms: int + + model_config = {"from_attributes": True} + + +class CitationRateResponse(BaseModel): + total_engines: int + brand_citation_count: int + brand_citation_rate: float + competitor_citation_count: int + competitor_citation_rate: float + + +class BatchQueryResponse(BaseModel): + results: list[QueryResultResponse] + citation_rate: CitationRateResponse + + +_ADAPTER_CLASSES: dict[EngineType, type[AIEngineAdapter]] = { + EngineType.CHATGPT: ChatGPTAdapter, + EngineType.PERPLEXITY: PerplexityAdapter, + EngineType.KIMI: KimiAdapter, + EngineType.WENXIN: WenxinAdapter, + EngineType.DOUBAO: DoubaoAdapter, +} + + +@lru_cache(maxsize=1) +def _build_adapters() -> dict[str, AIEngineAdapter]: + adapters: dict[str, AIEngineAdapter] = {} + for engine_type, cls in _ADAPTER_CLASSES.items(): + try: + adapters[engine_type.value] = cls() + except Exception: + logger.warning(f"Failed to initialize {engine_type.value} adapter") + return adapters + + +def get_batch_service() -> BatchQueryService: + return BatchQueryService(_build_adapters()) + + +def _result_to_response(r: AIQueryResult) -> QueryResultResponse: + return QueryResultResponse( + engine_type=r.engine_type.value, + query=r.query, + raw_response=r.raw_response, + has_brand_citation=r.has_brand_citation, + has_competitor_citation=r.has_competitor_citation, + brand_context=r.brand_context, + competitor_contexts=r.competitor_contexts, + response_time_ms=r.response_time_ms, + ) + + +def _parse_engine(value: str) -> EngineType: + try: + return EngineType(value) + except ValueError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unknown engine: {value}", + ) + + +async def _execute_batch( + service: BatchQueryService, + engine_types: list[EngineType], + query: str, + brand_name: str, + competitor_names: list[str] | None, +) -> BatchQueryResponse: + results = await service.query_batch( + engine_types, query, brand_name, competitor_names + ) + citation_rate = service.calculate_citation_rate(results) + return BatchQueryResponse( + results=[_result_to_response(r) for r in results], + citation_rate=CitationRateResponse(**citation_rate), + ) + + +@router.post("/query", response_model=QueryResultResponse) +async def query_single_engine( + request: SingleQueryRequest, + current_user: User = Depends(get_current_user), +): + service = get_batch_service() + engine_type = _parse_engine(request.engine) + try: + result = await service.query_single( + engine_type, request.query, request.brand_name, request.competitor_names + ) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) + return _result_to_response(result) + + +@router.post("/query-batch", response_model=BatchQueryResponse) +async def query_batch_engines( + request: BatchQueryRequest, + current_user: User = Depends(get_current_user), +): + service = get_batch_service() + engine_types = [_parse_engine(e) for e in request.engines] + return await _execute_batch( + service, engine_types, request.query, request.brand_name, request.competitor_names + ) + + +@router.get("/results", response_model=BatchQueryResponse) +async def get_query_results( + engines: str = Query(..., description="Comma-separated engine names"), + query: str = Query(..., min_length=1, max_length=500), + brand_name: str = Query(..., min_length=1, max_length=200), + competitor_names: str | None = Query(None, description="Comma-separated competitor names"), + current_user: User = Depends(get_current_user), +): + service = get_batch_service() + engine_list = [e.strip() for e in engines.split(",") if e.strip()] + engine_types = [_parse_engine(e) for e in engine_list] + comp_names = ( + [n.strip() for n in competitor_names.split(",") if n.strip()] + if competitor_names + else None + ) + return await _execute_batch( + service, engine_types, query, brand_name, comp_names + ) diff --git a/backend/app/main.py b/backend/app/main.py index 304bbe8..2337bb9 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -38,6 +38,7 @@ from app.api.platforms import router as platforms_router from app.api.platform_rules import router as platform_rules_router from app.api.image import router as image_router from app.api.knowledge_graph import router as knowledge_graph_router +from app.api.ai_engines import router as ai_engines_router from app.config import settings from app.database import engine, Base from app.schemas.common import ErrorResponse, ErrorCode @@ -165,6 +166,7 @@ app.include_router(platforms_router, prefix="/api/v1") app.include_router(platform_rules_router) app.include_router(image_router, prefix="/api/v1") app.include_router(knowledge_graph_router, prefix="/api/v1/knowledge-bases") +app.include_router(ai_engines_router, prefix="/api/v1/ai-engines", tags=["AI引擎查询"]) @app.get("/health", tags=["可观测性"]) diff --git a/backend/app/services/ai_engine/__init__.py b/backend/app/services/ai_engine/__init__.py new file mode 100644 index 0000000..4d80b3f --- /dev/null +++ b/backend/app/services/ai_engine/__init__.py @@ -0,0 +1,20 @@ +from .base import AIEngineAdapter, AIQueryResult, CitationInfo, EngineType +from .chatgpt import ChatGPTAdapter +from .perplexity import PerplexityAdapter +from .kimi import KimiAdapter +from .wenxin import WenxinAdapter +from .doubao import DoubaoAdapter +from .batch_query import BatchQueryService + +__all__ = [ + "AIEngineAdapter", + "AIQueryResult", + "CitationInfo", + "EngineType", + "ChatGPTAdapter", + "PerplexityAdapter", + "KimiAdapter", + "WenxinAdapter", + "DoubaoAdapter", + "BatchQueryService", +] diff --git a/backend/app/services/ai_engine/base.py b/backend/app/services/ai_engine/base.py new file mode 100644 index 0000000..50189f6 --- /dev/null +++ b/backend/app/services/ai_engine/base.py @@ -0,0 +1,147 @@ +import asyncio +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import UTC, datetime +from enum import Enum +from typing import Any + +import httpx + +logger = logging.getLogger(__name__) + +_MAX_RETRIES = 3 +_RETRYABLE_STATUS = {429, 500, 502, 503} + + +class EngineType(str, Enum): + CHATGPT = "chatgpt" + PERPLEXITY = "perplexity" + KIMI = "kimi" + WENXIN = "wenxin" + DOUBAO = "doubao" + DEEPSEEK = "deepseek" + QWEN = "qwen" + + +@dataclass +class CitationInfo: + source_url: str | None + source_title: str | None + citation_context: str + confidence: float + position: int + + +@dataclass +class AIQueryResult: + engine_type: EngineType + query: str + raw_response: str + citations: list[CitationInfo] + has_brand_citation: bool + has_competitor_citation: bool + brand_context: str | None + competitor_contexts: list[str] + response_time_ms: int + timestamp: datetime + metadata: dict[str, Any] = field(default_factory=dict) + + +class AIEngineAdapter(ABC): + def __init__(self, api_key: str, rate_limiter=None): + self.api_key = api_key + self.rate_limiter = rate_limiter + self._client: httpx.AsyncClient | None = None + + @abstractmethod + async def query( + self, + query: str, + brand_name: str, + competitor_names: list[str] | None = None, + ) -> AIQueryResult: + pass + + @abstractmethod + def get_engine_type(self) -> EngineType: + pass + + def _detect_citations( + self, + response: str, + brand_name: str, + competitor_names: list[str] | None, + ) -> tuple[bool, bool, str | None, list[str]]: + has_brand = brand_name.lower() in response.lower() + brand_context = None + if has_brand: + idx = response.lower().find(brand_name.lower()) + start = max(0, idx - 100) + end = min(len(response), idx + len(brand_name) + 100) + brand_context = response[start:end] + + has_competitor = False + competitor_contexts = [] + if competitor_names: + for name in competitor_names: + if name.lower() in response.lower(): + has_competitor = True + idx = response.lower().find(name.lower()) + start = max(0, idx - 100) + end = min(len(response), idx + len(name) + 100) + competitor_contexts.append(response[start:end]) + + return has_brand, has_competitor, brand_context, competitor_contexts + + async def _request_with_retry(self, payload: dict) -> dict: + if self.rate_limiter: + await self.rate_limiter.acquire() + + engine_name = self.get_engine_type().value + last_error: Exception | None = None + + for attempt in range(_MAX_RETRIES): + try: + response = await self._client.post(self._endpoint, json=payload) + + if response.status_code == 200: + return response.json() + + if response.status_code in _RETRYABLE_STATUS: + retry_after = response.headers.get("retry-after") + wait = float(retry_after) if retry_after else 2**attempt + logger.warning( + f"[{engine_name}] HTTP {response.status_code}, " + f"retry {attempt + 1}/{_MAX_RETRIES} in {wait:.1f}s" + ) + last_error = Exception( + f"HTTP {response.status_code}: {response.text[:300]}" + ) + await asyncio.sleep(wait) + continue + + raise Exception( + f"HTTP {response.status_code}: {response.text[:300]}" + ) + + except httpx.TransportError as exc: + logger.warning( + f"[{engine_name}] Transport error: {exc}, " + f"retry {attempt + 1}/{_MAX_RETRIES}" + ) + last_error = Exception(f"Network error: {exc}") + await asyncio.sleep(2**attempt) + continue + + raise last_error or Exception("Max retries exceeded") + + async def close(self) -> None: + if self._client: + await self._client.aclose() + + async def __aenter__(self) -> "AIEngineAdapter": + return self + + async def __aexit__(self, *exc) -> None: + await self.close() diff --git a/backend/app/services/ai_engine/batch_query.py b/backend/app/services/ai_engine/batch_query.py new file mode 100644 index 0000000..e47be21 --- /dev/null +++ b/backend/app/services/ai_engine/batch_query.py @@ -0,0 +1,55 @@ +import asyncio +import logging + +from .base import AIEngineAdapter, AIQueryResult, EngineType + +logger = logging.getLogger(__name__) + + +class BatchQueryService: + def __init__(self, adapters: dict[str, AIEngineAdapter]): + self.adapters = adapters + + async def query_single( + self, + engine_type: EngineType, + query: str, + brand_name: str, + competitor_names: list[str] | None = None, + ) -> AIQueryResult: + adapter = self.adapters.get(engine_type.value) + if not adapter: + raise ValueError(f"Unknown engine type: {engine_type}") + return await adapter.query(query, brand_name, competitor_names) + + async def query_batch( + self, + engines: list[EngineType], + query: str, + brand_name: str, + competitor_names: list[str] | None = None, + ) -> list[AIQueryResult]: + tasks = [ + self.query_single(engine, query, brand_name, competitor_names) + for engine in engines + ] + results = await asyncio.gather(*tasks, return_exceptions=True) + successful: list[AIQueryResult] = [] + for r in results: + if isinstance(r, AIQueryResult): + successful.append(r) + elif isinstance(r, Exception): + logger.warning(f"Engine query failed: {r}") + return successful + + def calculate_citation_rate(self, results: list[AIQueryResult]) -> dict: + total = len(results) + brand_cited = sum(1 for r in results if r.has_brand_citation) + competitor_cited = sum(1 for r in results if r.has_competitor_citation) + return { + "total_engines": total, + "brand_citation_count": brand_cited, + "brand_citation_rate": brand_cited / total if total > 0 else 0, + "competitor_citation_count": competitor_cited, + "competitor_citation_rate": competitor_cited / total if total > 0 else 0, + } diff --git a/backend/app/services/ai_engine/chatgpt.py b/backend/app/services/ai_engine/chatgpt.py new file mode 100644 index 0000000..3b05ffc --- /dev/null +++ b/backend/app/services/ai_engine/chatgpt.py @@ -0,0 +1,88 @@ +import logging +import os +import time +from datetime import UTC, datetime + +import httpx + +from .base import AIEngineAdapter, AIQueryResult, EngineType + +logger = logging.getLogger(__name__) + +_DEFAULT_MODEL = "gpt-4o" +_DEFAULT_BASE_URL = "https://api.openai.com/v1" + + +class ChatGPTAdapter(AIEngineAdapter): + def __init__( + self, + api_key: str | None = None, + model: str | None = None, + base_url: str | None = None, + rate_limiter=None, + ): + super().__init__( + api_key=api_key or os.getenv("OPENAI_API_KEY", ""), + rate_limiter=rate_limiter, + ) + self._model = model or os.getenv("OPENAI_MODEL", _DEFAULT_MODEL) + self._base_url = ( + base_url or os.getenv("OPENAI_BASE_URL", _DEFAULT_BASE_URL) + ).rstrip("/") + self._endpoint = f"{self._base_url}/chat/completions" + self._client = httpx.AsyncClient( + timeout=httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0), + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + ) + + def get_engine_type(self) -> EngineType: + return EngineType.CHATGPT + + async def query( + self, + query: str, + brand_name: str, + competitor_names: list[str] | None = None, + ) -> AIQueryResult: + start_time = time.perf_counter() + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": query}, + ] + payload = { + "model": self._model, + "messages": messages, + "temperature": 0.7, + "max_tokens": 4096, + } + + data = await self._request_with_retry(payload) + content = data["choices"][0]["message"]["content"] + + elapsed_ms = int((time.perf_counter() - start_time) * 1000) + has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations( + content, brand_name, competitor_names + ) + + logger.info( + f"[chatgpt] query='{query[:50]}...' brand={has_brand} " + f"competitor={has_comp} time={elapsed_ms}ms" + ) + + return AIQueryResult( + engine_type=self.get_engine_type(), + query=query, + raw_response=content, + citations=[], + has_brand_citation=has_brand, + has_competitor_citation=has_comp, + brand_context=brand_ctx, + competitor_contexts=comp_ctx, + response_time_ms=elapsed_ms, + timestamp=datetime.now(UTC), + metadata={"model": data.get("model", self._model)}, + ) diff --git a/backend/app/services/ai_engine/doubao.py b/backend/app/services/ai_engine/doubao.py new file mode 100644 index 0000000..1cd699d --- /dev/null +++ b/backend/app/services/ai_engine/doubao.py @@ -0,0 +1,96 @@ +import logging +import os +import time +from datetime import UTC, datetime + +import httpx + +from .base import AIEngineAdapter, AIQueryResult, EngineType + +logger = logging.getLogger(__name__) + +_DEFAULT_MODEL = "doubao-pro-4k" +_DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3" + + +class DoubaoAdapter(AIEngineAdapter): + def __init__( + self, + api_key: str | None = None, + endpoint_id: str | None = None, + rate_limiter=None, + ): + super().__init__( + api_key=api_key or os.getenv("DOUBAO_API_KEY", ""), + rate_limiter=rate_limiter, + ) + self._endpoint_id = endpoint_id or os.getenv("DOUBAO_ENDPOINT_ID", "") + self._base_url = _DEFAULT_BASE_URL.rstrip("/") + self._endpoint = f"{self._base_url}/chat/completions" + self._client = httpx.AsyncClient( + timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0), + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + ) + + def get_engine_type(self) -> EngineType: + return EngineType.DOUBAO + + def _get_model_id(self) -> str: + if self._endpoint_id and self._endpoint_id.strip(): + if not self._endpoint_id.startswith("ep-"): + return f"ep-{self._endpoint_id}" + return self._endpoint_id + return _DEFAULT_MODEL + + async def query( + self, + query: str, + brand_name: str, + competitor_names: list[str] | None = None, + ) -> AIQueryResult: + start_time = time.perf_counter() + + model_id = self._get_model_id() + messages = [ + { + "role": "system", + "content": "你是一个专业的AI搜索助手。请基于你的知识,详细回答用户的问题。如果引用了外部来源,请在回答中标注来源URL或出处名称。", + }, + {"role": "user", "content": query}, + ] + payload = { + "model": model_id, + "messages": messages, + "temperature": 0.7, + "max_tokens": 2000, + } + + data = await self._request_with_retry(payload) + content = data["choices"][0]["message"]["content"] + + elapsed_ms = int((time.perf_counter() - start_time) * 1000) + has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations( + content, brand_name, competitor_names + ) + + logger.info( + f"[doubao] query='{query[:50]}...' brand={has_brand} " + f"competitor={has_comp} time={elapsed_ms}ms" + ) + + return AIQueryResult( + engine_type=self.get_engine_type(), + query=query, + raw_response=content, + citations=[], + has_brand_citation=has_brand, + has_competitor_citation=has_comp, + brand_context=brand_ctx, + competitor_contexts=comp_ctx, + response_time_ms=elapsed_ms, + timestamp=datetime.now(UTC), + metadata={"model": data.get("model", model_id), "usage": data.get("usage")}, + ) diff --git a/backend/app/services/ai_engine/kimi.py b/backend/app/services/ai_engine/kimi.py new file mode 100644 index 0000000..757718c --- /dev/null +++ b/backend/app/services/ai_engine/kimi.py @@ -0,0 +1,91 @@ +import logging +import os +import time +from datetime import UTC, datetime + +import httpx + +from .base import AIEngineAdapter, AIQueryResult, EngineType + +logger = logging.getLogger(__name__) + +_DEFAULT_MODEL = "moonshot-v1-8k" +_DEFAULT_BASE_URL = "https://api.moonshot.cn/v1" + + +class KimiAdapter(AIEngineAdapter): + def __init__( + self, + api_key: str | None = None, + model: str | None = None, + base_url: str | None = None, + rate_limiter=None, + ): + super().__init__( + api_key=api_key or os.getenv("MOONSHOT_API_KEY", ""), + rate_limiter=rate_limiter, + ) + self._model = model or _DEFAULT_MODEL + self._base_url = ( + base_url or _DEFAULT_BASE_URL + ).rstrip("/") + self._endpoint = f"{self._base_url}/chat/completions" + self._client = httpx.AsyncClient( + timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0), + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + ) + + def get_engine_type(self) -> EngineType: + return EngineType.KIMI + + async def query( + self, + query: str, + brand_name: str, + competitor_names: list[str] | None = None, + ) -> AIQueryResult: + start_time = time.perf_counter() + + messages = [ + { + "role": "system", + "content": "你是一个专业的AI搜索助手。请基于你的知识,详细回答用户的问题。如果引用了外部来源,请在回答中标注来源URL或出处名称。", + }, + {"role": "user", "content": query}, + ] + payload = { + "model": self._model, + "messages": messages, + "temperature": 0.7, + "max_tokens": 2000, + } + + data = await self._request_with_retry(payload) + content = data["choices"][0]["message"]["content"] + + elapsed_ms = int((time.perf_counter() - start_time) * 1000) + has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations( + content, brand_name, competitor_names + ) + + logger.info( + f"[kimi] query='{query[:50]}...' brand={has_brand} " + f"competitor={has_comp} time={elapsed_ms}ms" + ) + + return AIQueryResult( + engine_type=self.get_engine_type(), + query=query, + raw_response=content, + citations=[], + has_brand_citation=has_brand, + has_competitor_citation=has_comp, + brand_context=brand_ctx, + competitor_contexts=comp_ctx, + response_time_ms=elapsed_ms, + timestamp=datetime.now(UTC), + metadata={"model": data.get("model", self._model), "usage": data.get("usage")}, + ) diff --git a/backend/app/services/ai_engine/perplexity.py b/backend/app/services/ai_engine/perplexity.py new file mode 100644 index 0000000..f9f1001 --- /dev/null +++ b/backend/app/services/ai_engine/perplexity.py @@ -0,0 +1,107 @@ +import logging +import os +import time +from datetime import UTC, datetime + +import httpx + +from .base import AIEngineAdapter, AIQueryResult, CitationInfo, EngineType + +logger = logging.getLogger(__name__) + +_DEFAULT_MODEL = "pplx-70b-online" +_DEFAULT_BASE_URL = "https://api.perplexity.ai" + + +class PerplexityAdapter(AIEngineAdapter): + def __init__( + self, + api_key: str | None = None, + model: str | None = None, + base_url: str | None = None, + rate_limiter=None, + ): + super().__init__( + api_key=api_key or os.getenv("PERPLEXITY_API_KEY", ""), + rate_limiter=rate_limiter, + ) + self._model = model or os.getenv("PERPLEXITY_MODEL", _DEFAULT_MODEL) + self._base_url = ( + base_url or os.getenv("PERPLEXITY_BASE_URL", _DEFAULT_BASE_URL) + ).rstrip("/") + self._endpoint = f"{self._base_url}/chat/completions" + self._client = httpx.AsyncClient( + timeout=httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0), + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + ) + + def get_engine_type(self) -> EngineType: + return EngineType.PERPLEXITY + + async def query( + self, + query: str, + brand_name: str, + competitor_names: list[str] | None = None, + ) -> AIQueryResult: + start_time = time.perf_counter() + + messages = [ + {"role": "system", "content": "You are a helpful research assistant."}, + {"role": "user", "content": query}, + ] + payload = { + "model": self._model, + "messages": messages, + "temperature": 0.7, + "max_tokens": 4096, + } + + data = await self._request_with_retry(payload) + content = data["choices"][0]["message"]["content"] + citations = self._extract_citations(data) + + elapsed_ms = int((time.perf_counter() - start_time) * 1000) + has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations( + content, brand_name, competitor_names + ) + + logger.info( + f"[perplexity] query='{query[:50]}...' brand={has_brand} " + f"competitor={has_comp} citations={len(citations)} time={elapsed_ms}ms" + ) + + return AIQueryResult( + engine_type=self.get_engine_type(), + query=query, + raw_response=content, + citations=citations, + has_brand_citation=has_brand, + has_competitor_citation=has_comp, + brand_context=brand_ctx, + competitor_contexts=comp_ctx, + response_time_ms=elapsed_ms, + timestamp=datetime.now(UTC), + metadata={"model": data.get("model", self._model)}, + ) + + def _extract_citations(self, data: dict) -> list[CitationInfo]: + raw_citations = data.get("citations", []) + if not raw_citations: + return [] + + citations = [] + for idx, cit in enumerate(raw_citations): + citations.append( + CitationInfo( + source_url=cit.get("url"), + source_title=cit.get("title"), + citation_context="", + confidence=1.0, + position=idx + 1, + ) + ) + return citations diff --git a/backend/app/services/ai_engine/wenxin.py b/backend/app/services/ai_engine/wenxin.py new file mode 100644 index 0000000..3160faf --- /dev/null +++ b/backend/app/services/ai_engine/wenxin.py @@ -0,0 +1,145 @@ +import logging +import os +import time +from datetime import UTC, datetime + +import httpx + +from .base import AIEngineAdapter, AIQueryResult, EngineType + +logger = logging.getLogger(__name__) + +_DEFAULT_MODEL = "completions_pro" +_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token" +_CHAT_URL_TEMPLATE = ( + "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model}" + "?access_token={token}" +) + +_cached_token: str | None = None +_token_expires_at: float = 0.0 + + +class WenxinAdapter(AIEngineAdapter): + def __init__( + self, + api_key: str | None = None, + secret_key: str | None = None, + rate_limiter=None, + ): + super().__init__( + api_key=api_key or os.getenv("BAIDU_QIANFAN_API_KEY", ""), + rate_limiter=rate_limiter, + ) + self.secret_key = secret_key or os.getenv("BAIDU_QIANFAN_SECRET_KEY", "") + self._model = _DEFAULT_MODEL + self._client = httpx.AsyncClient( + timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0), + ) + + def get_engine_type(self) -> EngineType: + return EngineType.WENXIN + + async def _get_access_token(self) -> str: + global _cached_token, _token_expires_at + + now = time.monotonic() + if _cached_token and now < _token_expires_at: + return _cached_token + + response = await self._client.post( + _TOKEN_URL, + params={ + "grant_type": "client_credentials", + "client_id": self.api_key, + "client_secret": self.secret_key, + }, + ) + + if response.status_code != 200: + raise RuntimeError( + f"文心一言获取 access_token 失败: {response.status_code} {response.text[:300]}" + ) + + data = response.json() + token = data.get("access_token") + if not token: + error_desc = data.get("error_description", "未知错误") + raise RuntimeError(f"文心一言获取 access_token 失败: {error_desc}") + + expires_in = data.get("expires_in", 2592000) + _cached_token = token + _token_expires_at = now + expires_in - 300 + + logger.info("[wenxin] access_token 获取成功") + return token + + async def query( + self, + query: str, + brand_name: str, + competitor_names: list[str] | None = None, + ) -> AIQueryResult: + start_time = time.perf_counter() + + access_token = await self._get_access_token() + chat_url = _CHAT_URL_TEMPLATE.format( + model=self._model, + token=access_token, + ) + + payload = { + "messages": [{"role": "user", "content": query}], + "system": "你是一个专业的AI搜索助手。请基于你的知识,详细回答用户的问题。如果引用了外部来源,请在回答中标注来源URL或出处名称。", + "temperature": 0.7, + "max_output_tokens": 2000, + } + + if self.rate_limiter: + await self.rate_limiter.acquire() + + response = await self._client.post(chat_url, json=payload) + + if response.status_code == 429: + raise RuntimeError("文心一言 API 限流") + + if response.status_code != 200: + error_body = response.text[:500] + raise RuntimeError( + f"文心一言 API 返回错误 {response.status_code}: {error_body}" + ) + + data = response.json() + + error_code = data.get("error_code") + if error_code: + error_msg = data.get("error_msg", "未知错误") + raise RuntimeError(f"文心一言 API 错误 {error_code}: {error_msg}") + + content = data.get("result", "") + if not content: + raise RuntimeError("文心一言 API 返回空内容") + + elapsed_ms = int((time.perf_counter() - start_time) * 1000) + has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations( + content, brand_name, competitor_names + ) + + logger.info( + f"[wenxin] query='{query[:50]}...' brand={has_brand} " + f"competitor={has_comp} time={elapsed_ms}ms" + ) + + return AIQueryResult( + engine_type=self.get_engine_type(), + query=query, + raw_response=content, + citations=[], + has_brand_citation=has_brand, + has_competitor_citation=has_comp, + brand_context=brand_ctx, + competitor_contexts=comp_ctx, + response_time_ms=elapsed_ms, + timestamp=datetime.now(UTC), + metadata={"model": self._model, "usage": data.get("usage")}, + ) diff --git a/backend/tests/test_api/test_ai_engines_api.py b/backend/tests/test_api/test_ai_engines_api.py new file mode 100644 index 0000000..10cbf8f --- /dev/null +++ b/backend/tests/test_api/test_ai_engines_api.py @@ -0,0 +1,216 @@ +import uuid +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import pytest_asyncio +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.pool import StaticPool + +from app.database import Base +from app.main import app +from app.models.user import User +from app.api.deps import get_current_user, get_db +from app.services.auth import hash_password +from app.services.ai_engine.base import AIQueryResult, EngineType + + +@pytest_asyncio.fixture +async def async_engine(): + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def async_session(async_engine): + async_session_maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + async with async_session_maker() as session: + yield session + + +@pytest_asyncio.fixture +async def test_user(async_session): + user = User( + id=uuid.uuid4(), + email="test@example.com", + password_hash=hash_password("Test@123456"), + name="Test User", + plan="free", + max_queries=5, + is_active=True, + email_verified=True, + ) + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def async_client(async_session, test_user): + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return test_user + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +def _make_result( + engine_type: EngineType, + query: str = "best insurance", + has_brand: bool = False, + has_competitor: bool = False, +) -> AIQueryResult: + return AIQueryResult( + engine_type=engine_type, + query=query, + raw_response="BrandX is a great insurance company", + citations=[], + has_brand_citation=has_brand, + has_competitor_citation=has_competitor, + brand_context="BrandX is great" if has_brand else None, + competitor_contexts=["CompY is ok"] if has_competitor else [], + response_time_ms=150, + timestamp=datetime.now(UTC), + ) + + +class TestSingleQueryEndpoint: + @pytest.mark.asyncio + async def test_query_single_engine(self, async_client): + mock_result = _make_result(EngineType.CHATGPT, has_brand=True) + with patch("app.api.ai_engines.get_batch_service") as mock_get_service: + mock_service = AsyncMock() + mock_service.query_single.return_value = mock_result + mock_get_service.return_value = mock_service + + response = await async_client.post( + "/api/v1/ai-engines/query", + json={ + "engine": "chatgpt", + "query": "best insurance", + "brand_name": "BrandX", + "competitor_names": ["CompY"], + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["engine_type"] == "chatgpt" + assert data["has_brand_citation"] is True + assert data["query"] == "best insurance" + + +class TestBatchQueryEndpoint: + @pytest.mark.asyncio + async def test_query_batch_parallel(self, async_client): + r1 = _make_result(EngineType.CHATGPT, has_brand=True) + r2 = _make_result(EngineType.PERPLEXITY, has_brand=False, has_competitor=True) + with patch("app.api.ai_engines.get_batch_service") as mock_get_service: + mock_service = AsyncMock() + mock_service.query_batch.return_value = [r1, r2] + mock_service.calculate_citation_rate = MagicMock(return_value={ + "total_engines": 2, + "brand_citation_count": 1, + "brand_citation_rate": 0.5, + "competitor_citation_count": 1, + "competitor_citation_rate": 0.5, + }) + mock_get_service.return_value = mock_service + + response = await async_client.post( + "/api/v1/ai-engines/query-batch", + json={ + "engines": ["chatgpt", "perplexity"], + "query": "best insurance", + "brand_name": "BrandX", + "competitor_names": ["CompY"], + }, + ) + + assert response.status_code == 200 + data = response.json() + assert "results" in data + assert "citation_rate" in data + assert len(data["results"]) == 2 + assert data["citation_rate"]["brand_citation_rate"] == 0.5 + + +class TestGetResultsEndpoint: + @pytest.mark.asyncio + async def test_get_results(self, async_client): + r1 = _make_result(EngineType.CHATGPT, has_brand=True) + r2 = _make_result(EngineType.KIMI, has_brand=False) + with patch("app.api.ai_engines.get_batch_service") as mock_get_service: + mock_service = AsyncMock() + mock_service.query_batch.return_value = [r1, r2] + mock_service.calculate_citation_rate = MagicMock(return_value={ + "total_engines": 2, + "brand_citation_count": 1, + "brand_citation_rate": 0.5, + "competitor_citation_count": 0, + "competitor_citation_rate": 0.0, + }) + mock_get_service.return_value = mock_service + + response = await async_client.get( + "/api/v1/ai-engines/results", + params={ + "engines": "chatgpt,kimi", + "query": "best insurance", + "brand_name": "BrandX", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert "results" in data + assert "citation_rate" in data + + +class TestUnauthorizedAccess: + @pytest.mark.asyncio + async def test_unauthorized_returns_401(self, async_session): + async def override_get_db(): + yield async_session + + app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + headers = {"Authorization": "Bearer invalid_token"} + response = await client.post( + "/api/v1/ai-engines/query", + json={ + "engine": "chatgpt", + "query": "test", + "brand_name": "BrandX", + }, + headers=headers, + ) + assert response.status_code == 401 + + app.dependency_overrides.clear() diff --git a/backend/tests/test_services/test_ai_engine_chinese.py b/backend/tests/test_services/test_ai_engine_chinese.py new file mode 100644 index 0000000..5e301f8 --- /dev/null +++ b/backend/tests/test_services/test_ai_engine_chinese.py @@ -0,0 +1,270 @@ +import pytest +from unittest.mock import AsyncMock, Mock, patch, MagicMock + +from app.services.ai_engine.base import ( + AIEngineAdapter, + AIQueryResult, + CitationInfo, + EngineType, +) +from app.services.ai_engine.kimi import KimiAdapter +from app.services.ai_engine.wenxin import WenxinAdapter +from app.services.ai_engine.doubao import DoubaoAdapter + + +def _make_mock_response(status_code=200, json_data=None, text="", headers=None): + mock_resp = Mock() + mock_resp.status_code = status_code + mock_resp.json.return_value = json_data or {} + mock_resp.text = text + mock_resp.headers = headers or {} + return mock_resp + + +class TestKimiAdapter: + @pytest.mark.asyncio + async def test_initialization(self): + adapter = KimiAdapter(api_key="test-key") + assert adapter.api_key == "test-key" + assert adapter.get_engine_type() == EngineType.KIMI + + @pytest.mark.asyncio + async def test_query_returns_ai_query_result(self): + adapter = KimiAdapter(api_key="test-key") + mock_response_data = { + "choices": [{"message": {"content": "华为是全球领先的ICT公司"}}], + "usage": {"total_tokens": 100}, + "model": "moonshot-v1-8k", + } + + with patch.object(adapter, "_request_with_retry", return_value=mock_response_data): + result = await adapter.query("华为公司", brand_name="华为") + + assert isinstance(result, AIQueryResult) + assert result.engine_type == EngineType.KIMI + assert "华为" in result.raw_response + assert result.has_brand_citation is True + assert result.metadata.get("model") == "moonshot-v1-8k" + + @pytest.mark.asyncio + async def test_api_error_handling(self): + adapter = KimiAdapter(api_key="test-key") + + with patch.object( + adapter, + "_request_with_retry", + side_effect=Exception("HTTP 500: Internal Server Error"), + ): + with pytest.raises(Exception, match="HTTP 500"): + await adapter.query("测试问题", brand_name="华为") + + @pytest.mark.asyncio + async def test_rate_limit_handling(self): + adapter = KimiAdapter(api_key="test-key") + + with patch.object( + adapter, + "_request_with_retry", + side_effect=Exception("HTTP 429: rate limited"), + ): + with pytest.raises(Exception, match="429"): + await adapter.query("测试问题", brand_name="华为") + + +class TestWenxinAdapter: + @pytest.mark.asyncio + async def test_initialization(self): + adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret") + assert adapter.api_key == "test-key" + assert adapter.secret_key == "test-secret" + assert adapter.get_engine_type() == EngineType.WENXIN + + @pytest.mark.asyncio + async def test_query_returns_ai_query_result(self): + import app.services.ai_engine.wenxin as wenxin_mod + wenxin_mod._cached_token = None + wenxin_mod._token_expires_at = 0.0 + + adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret") + mock_token_data = {"access_token": "test-access-token", "expires_in": 2592000} + mock_chat_data = {"result": "华为是一家全球领先的科技公司", "usage": {"total_tokens": 100}} + + mock_client = AsyncMock() + token_response = _make_mock_response(200, mock_token_data) + chat_response = _make_mock_response(200, mock_chat_data) + mock_client.post.side_effect = [token_response, chat_response] + + with patch.object(adapter, "_client", mock_client): + result = await adapter.query("华为公司", brand_name="华为") + + assert isinstance(result, AIQueryResult) + assert result.engine_type == EngineType.WENXIN + assert "华为" in result.raw_response + assert result.has_brand_citation is True + + @pytest.mark.asyncio + async def test_api_error_handling(self): + import app.services.ai_engine.wenxin as wenxin_mod + wenxin_mod._cached_token = None + wenxin_mod._token_expires_at = 0.0 + + adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret") + + mock_client = AsyncMock() + token_response = _make_mock_response(200, {"access_token": "test-token", "expires_in": 2592000}) + error_response = _make_mock_response(500, text="Internal Server Error") + mock_client.post.side_effect = [token_response, error_response] + + with patch.object(adapter, "_client", mock_client): + with pytest.raises(RuntimeError, match="文心"): + await adapter.query("测试问题", brand_name="华为") + + @pytest.mark.asyncio + async def test_rate_limit_handling(self): + import app.services.ai_engine.wenxin as wenxin_mod + wenxin_mod._cached_token = None + wenxin_mod._token_expires_at = 0.0 + + adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret") + + mock_client = AsyncMock() + token_response = _make_mock_response(200, {"access_token": "test-token", "expires_in": 2592000}) + rate_limit_response = _make_mock_response(429, headers={"Retry-After": "1"}) + mock_client.post.side_effect = [token_response, rate_limit_response] + + with patch.object(adapter, "_client", mock_client): + with pytest.raises(RuntimeError, match="限流"): + await adapter.query("测试问题", brand_name="华为") + + +class TestDoubaoAdapter: + @pytest.mark.asyncio + async def test_initialization(self): + adapter = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test") + assert adapter.api_key == "test-key" + assert adapter._endpoint_id == "ep-test" + assert adapter.get_engine_type() == EngineType.DOUBAO + + @pytest.mark.asyncio + async def test_query_returns_ai_query_result(self): + adapter = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test") + mock_response_data = { + "choices": [{"message": {"content": "华为是全球知名企业"}}], + "model": "ep-test", + } + + with patch.object(adapter, "_request_with_retry", return_value=mock_response_data): + result = await adapter.query("华为公司", brand_name="华为") + + assert isinstance(result, AIQueryResult) + assert result.engine_type == EngineType.DOUBAO + assert "华为" in result.raw_response + assert result.has_brand_citation is True + + @pytest.mark.asyncio + async def test_api_error_handling(self): + adapter = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test") + + with patch.object( + adapter, + "_request_with_retry", + side_effect=Exception("HTTP 500: Internal Server Error"), + ): + with pytest.raises(Exception, match="HTTP 500"): + await adapter.query("测试问题", brand_name="华为") + + @pytest.mark.asyncio + async def test_rate_limit_handling(self): + adapter = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test") + + with patch.object( + adapter, + "_request_with_retry", + side_effect=Exception("HTTP 429: rate limited"), + ): + with pytest.raises(Exception, match="429"): + await adapter.query("测试问题", brand_name="华为") + + +class TestEngineType: + def test_kimi_engine_type(self): + adapter = KimiAdapter(api_key="test-key") + assert adapter.get_engine_type() == EngineType.KIMI + assert adapter.get_engine_type().value == "kimi" + + def test_wenxin_engine_type(self): + adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret") + assert adapter.get_engine_type() == EngineType.WENXIN + assert adapter.get_engine_type().value == "wenxin" + + def test_doubao_engine_type(self): + adapter = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test") + assert adapter.get_engine_type() == EngineType.DOUBAO + assert adapter.get_engine_type().value == "doubao" + + +class TestChineseCitationDetection: + def test_brand_name_detection_chinese(self): + adapter = KimiAdapter(api_key="test-key") + has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations( + "华为是全球领先的ICT基础设施和智能终端提供商", + brand_name="华为", + competitor_names=None, + ) + assert has_brand is True + assert brand_ctx is not None + assert "华为" in brand_ctx + + def test_brand_name_detection_case_insensitive(self): + adapter = KimiAdapter(api_key="test-key") + has_brand, _, _, _ = adapter._detect_citations( + "Apple is a great company and apple makes phones", + brand_name="apple", + competitor_names=None, + ) + assert has_brand is True + + def test_competitor_name_detection(self): + adapter = KimiAdapter(api_key="test-key") + has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations( + "华为和小米都是中国知名手机品牌", + brand_name="华为", + competitor_names=["小米"], + ) + assert has_brand is True + assert has_comp is True + assert brand_ctx is not None + assert len(comp_ctx) > 0 + + def test_no_citations_when_no_match(self): + adapter = KimiAdapter(api_key="test-key") + has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations( + "今天天气很好", + brand_name="华为", + competitor_names=["小米"], + ) + assert has_brand is False + assert has_comp is False + assert brand_ctx is None + assert len(comp_ctx) == 0 + + +class TestAdapterInheritance: + def test_all_adapters_inherit_base(self): + assert issubclass(KimiAdapter, AIEngineAdapter) + assert issubclass(WenxinAdapter, AIEngineAdapter) + assert issubclass(DoubaoAdapter, AIEngineAdapter) + + def test_all_adapters_have_query_method(self): + for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter]: + assert hasattr(cls, "query") + assert callable(getattr(cls, "query")) + + def test_all_adapters_have_detect_citations(self): + for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter]: + assert hasattr(cls, "_detect_citations") + + def test_all_adapters_have_get_engine_type(self): + for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter]: + instance = cls(api_key="test-key") if cls != WenxinAdapter else cls(api_key="test-key", secret_key="s") + assert instance.get_engine_type() in EngineType diff --git a/backend/tests/test_services/test_ai_engine_query.py b/backend/tests/test_services/test_ai_engine_query.py new file mode 100644 index 0000000..3f94f53 --- /dev/null +++ b/backend/tests/test_services/test_ai_engine_query.py @@ -0,0 +1,529 @@ +import pytest +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock, patch + +from app.services.ai_engine.base import ( + AIEngineAdapter, + AIQueryResult, + CitationInfo, + EngineType, +) +from app.services.ai_engine.chatgpt import ChatGPTAdapter +from app.services.ai_engine.perplexity import PerplexityAdapter + + +class TestEngineType: + def test_engine_type_values(self): + assert EngineType.CHATGPT == "chatgpt" + assert EngineType.PERPLEXITY == "perplexity" + assert EngineType.KIMI == "kimi" + assert EngineType.WENXIN == "wenxin" + assert EngineType.DOUBAO == "doubao" + assert EngineType.DEEPSEEK == "deepseek" + assert EngineType.QWEN == "qwen" + + +class TestCitationInfo: + def test_citation_info_creation(self): + info = CitationInfo( + source_url="https://example.com", + source_title="Example Title", + citation_context="brand was mentioned here", + confidence=0.95, + position=1, + ) + assert info.source_url == "https://example.com" + assert info.source_title == "Example Title" + assert info.citation_context == "brand was mentioned here" + assert info.confidence == 0.95 + assert info.position == 1 + + def test_citation_info_optional_fields(self): + info = CitationInfo( + source_url=None, + source_title=None, + citation_context="some context", + confidence=0.5, + position=3, + ) + assert info.source_url is None + assert info.source_title is None + + +class TestAIQueryResult: + def test_ai_query_result_creation(self): + now = datetime.now(UTC) + result = AIQueryResult( + engine_type=EngineType.CHATGPT, + query="best insurance companies", + raw_response="I recommend BrandX for insurance.", + citations=[], + has_brand_citation=True, + has_competitor_citation=False, + brand_context="I recommend BrandX for insurance.", + competitor_contexts=[], + response_time_ms=1500, + timestamp=now, + ) + assert result.engine_type == EngineType.CHATGPT + assert result.query == "best insurance companies" + assert result.raw_response == "I recommend BrandX for insurance." + assert result.has_brand_citation is True + assert result.has_competitor_citation is False + assert result.brand_context == "I recommend BrandX for insurance." + assert result.competitor_contexts == [] + assert result.response_time_ms == 1500 + assert result.timestamp == now + + def test_ai_query_result_with_citations(self): + citation = CitationInfo( + source_url="https://brandx.com", + source_title="BrandX Official", + citation_context="BrandX is a leading provider", + confidence=0.9, + position=1, + ) + result = AIQueryResult( + engine_type=EngineType.PERPLEXITY, + query="insurance comparison", + raw_response="BrandX is great", + citations=[citation], + has_brand_citation=True, + has_competitor_citation=False, + brand_context="BrandX is great", + competitor_contexts=[], + response_time_ms=2000, + timestamp=datetime.now(UTC), + ) + assert len(result.citations) == 1 + assert result.citations[0].source_url == "https://brandx.com" + + def test_ai_query_result_default_metadata(self): + result = AIQueryResult( + engine_type=EngineType.CHATGPT, + query="test", + raw_response="test", + citations=[], + has_brand_citation=False, + has_competitor_citation=False, + brand_context=None, + competitor_contexts=[], + response_time_ms=100, + timestamp=datetime.now(UTC), + ) + assert result.metadata == {} + + +class TestAIEngineAdapterBase: + def test_cannot_instantiate_abstract_class(self): + with pytest.raises(TypeError): + AIEngineAdapter(api_key="test-key") + + def test_detect_citations_brand_found(self): + class ConcreteAdapter(AIEngineAdapter): + async def query(self, query, brand_name, competitor_names=None): + pass + + def get_engine_type(self): + return EngineType.CHATGPT + + adapter = ConcreteAdapter(api_key="test-key") + has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations( + "BrandX is the best insurance company", + "BrandX", + None, + ) + assert has_brand is True + assert has_comp is False + assert brand_ctx is not None + assert "BrandX" in brand_ctx + + def test_detect_citations_competitor_found(self): + class ConcreteAdapter(AIEngineAdapter): + async def query(self, query, brand_name, competitor_names=None): + pass + + def get_engine_type(self): + return EngineType.CHATGPT + + adapter = ConcreteAdapter(api_key="test-key") + has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations( + "CompetitorY is also a good choice for insurance", + "BrandX", + ["CompetitorY", "CompetitorZ"], + ) + assert has_brand is False + assert has_comp is True + assert len(comp_ctx) == 1 + assert "CompetitorY" in comp_ctx[0] + + def test_detect_citations_both_found(self): + class ConcreteAdapter(AIEngineAdapter): + async def query(self, query, brand_name, competitor_names=None): + pass + + def get_engine_type(self): + return EngineType.CHATGPT + + adapter = ConcreteAdapter(api_key="test-key") + has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations( + "BrandX and CompetitorY are both good insurance options", + "BrandX", + ["CompetitorY"], + ) + assert has_brand is True + assert has_comp is True + assert brand_ctx is not None + assert len(comp_ctx) == 1 + + def test_detect_citations_none_found(self): + class ConcreteAdapter(AIEngineAdapter): + async def query(self, query, brand_name, competitor_names=None): + pass + + def get_engine_type(self): + return EngineType.CHATGPT + + adapter = ConcreteAdapter(api_key="test-key") + has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations( + "Some random text without brand names", + "BrandX", + ["CompetitorY"], + ) + assert has_brand is False + assert has_comp is False + assert brand_ctx is None + assert comp_ctx == [] + + def test_detect_citations_case_insensitive(self): + class ConcreteAdapter(AIEngineAdapter): + async def query(self, query, brand_name, competitor_names=None): + pass + + def get_engine_type(self): + return EngineType.CHATGPT + + adapter = ConcreteAdapter(api_key="test-key") + has_brand, _, _, _ = adapter._detect_citations( + "brandx is great", + "BrandX", + None, + ) + assert has_brand is True + + +class TestChatGPTAdapter: + @pytest.fixture + def chatgpt_adapter(self): + return ChatGPTAdapter(api_key="test-api-key") + + def test_chatgpt_init(self, chatgpt_adapter): + assert chatgpt_adapter.api_key == "test-api-key" + assert chatgpt_adapter.get_engine_type() == EngineType.CHATGPT + + @pytest.mark.asyncio + async def test_chatgpt_query_success(self, chatgpt_adapter): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [ + { + "message": { + "content": "BrandX is a leading insurance company with great service." + } + } + ], + "model": "gpt-4o", + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + } + + with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + result = await chatgpt_adapter.query( + query="best insurance companies", + brand_name="BrandX", + competitor_names=["CompetitorY"], + ) + + assert isinstance(result, AIQueryResult) + assert result.engine_type == EngineType.CHATGPT + assert result.query == "best insurance companies" + assert "BrandX" in result.raw_response + assert result.has_brand_citation is True + assert result.response_time_ms >= 0 + + @pytest.mark.asyncio + async def test_chatgpt_query_with_rate_limiter(self): + mock_limiter = AsyncMock() + adapter = ChatGPTAdapter(api_key="test-key", rate_limiter=mock_limiter) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [{"message": {"content": "Some response"}}], + "model": "gpt-4o", + } + + with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + await adapter.query(query="test", brand_name="BrandX") + + mock_limiter.acquire.assert_awaited() + + @pytest.mark.asyncio + async def test_chatgpt_query_api_timeout(self, chatgpt_adapter): + import httpx + + with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.side_effect = httpx.TimeoutException("Request timed out") + + with pytest.raises(Exception): + await chatgpt_adapter.query( + query="test query", + brand_name="BrandX", + ) + + @pytest.mark.asyncio + async def test_chatgpt_query_invalid_response(self, chatgpt_adapter): + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.text = "Unauthorized" + + with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + + with pytest.raises(Exception): + await chatgpt_adapter.query( + query="test query", + brand_name="BrandX", + ) + + @pytest.mark.asyncio + async def test_chatgpt_brand_citation_detection(self, chatgpt_adapter): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [ + {"message": {"content": "BrandX offers excellent insurance coverage."}} + ], + "model": "gpt-4o", + } + + with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + result = await chatgpt_adapter.query( + query="insurance", + brand_name="BrandX", + ) + + assert result.has_brand_citation is True + assert result.brand_context is not None + assert "BrandX" in result.brand_context + + @pytest.mark.asyncio + async def test_chatgpt_competitor_citation_detection(self, chatgpt_adapter): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [ + { + "message": { + "content": "CompetitorY and CompetitorZ are popular insurance providers." + } + } + ], + "model": "gpt-4o", + } + + with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + result = await chatgpt_adapter.query( + query="insurance", + brand_name="BrandX", + competitor_names=["CompetitorY", "CompetitorZ"], + ) + + assert result.has_brand_citation is False + assert result.has_competitor_citation is True + assert len(result.competitor_contexts) == 2 + + +class TestPerplexityAdapter: + @pytest.fixture + def perplexity_adapter(self): + return PerplexityAdapter(api_key="test-api-key") + + def test_perplexity_init(self, perplexity_adapter): + assert perplexity_adapter.api_key == "test-api-key" + assert perplexity_adapter.get_engine_type() == EngineType.PERPLEXITY + + @pytest.mark.asyncio + async def test_perplexity_query_success(self, perplexity_adapter): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [ + { + "message": { + "content": "BrandX is a well-known insurance brand [1]." + } + } + ], + "citations": [ + {"url": "https://brandx.com", "title": "BrandX Official Site"} + ], + "model": "pplx-70b-online", + } + + with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + result = await perplexity_adapter.query( + query="best insurance companies", + brand_name="BrandX", + competitor_names=["CompetitorY"], + ) + + assert isinstance(result, AIQueryResult) + assert result.engine_type == EngineType.PERPLEXITY + assert result.query == "best insurance companies" + assert "BrandX" in result.raw_response + assert result.has_brand_citation is True + assert len(result.citations) >= 1 + + @pytest.mark.asyncio + async def test_perplexity_query_with_citations(self, perplexity_adapter): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [ + {"message": {"content": "BrandX is recommended [1]. CompetitorY is also good [2]."}} + ], + "citations": [ + {"url": "https://brandx.com", "title": "BrandX"}, + {"url": "https://competitory.com", "title": "CompetitorY"}, + ], + "model": "pplx-70b-online", + } + + with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + result = await perplexity_adapter.query( + query="insurance", + brand_name="BrandX", + competitor_names=["CompetitorY"], + ) + + assert len(result.citations) == 2 + assert result.citations[0].source_url == "https://brandx.com" + assert result.citations[1].source_url == "https://competitory.com" + + @pytest.mark.asyncio + async def test_perplexity_query_with_rate_limiter(self): + mock_limiter = AsyncMock() + adapter = PerplexityAdapter(api_key="test-key", rate_limiter=mock_limiter) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [{"message": {"content": "Some response"}}], + "model": "pplx-70b-online", + "citations": [], + } + + with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + await adapter.query(query="test", brand_name="BrandX") + + mock_limiter.acquire.assert_awaited() + + @pytest.mark.asyncio + async def test_perplexity_query_api_timeout(self, perplexity_adapter): + import httpx + + with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.side_effect = httpx.TimeoutException("Request timed out") + + with pytest.raises(Exception): + await perplexity_adapter.query( + query="test query", + brand_name="BrandX", + ) + + @pytest.mark.asyncio + async def test_perplexity_query_invalid_response(self, perplexity_adapter): + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.text = "Unauthorized" + + with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + + with pytest.raises(Exception): + await perplexity_adapter.query( + query="test query", + brand_name="BrandX", + ) + + @pytest.mark.asyncio + async def test_perplexity_brand_citation_detection(self, perplexity_adapter): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [ + {"message": {"content": "BrandX offers excellent insurance coverage."}} + ], + "citations": [], + "model": "pplx-70b-online", + } + + with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + result = await perplexity_adapter.query( + query="insurance", + brand_name="BrandX", + ) + + assert result.has_brand_citation is True + assert result.brand_context is not None + + @pytest.mark.asyncio + async def test_perplexity_competitor_citation_detection(self, perplexity_adapter): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [ + {"message": {"content": "CompetitorY is a popular insurance provider."}} + ], + "citations": [], + "model": "pplx-70b-online", + } + + with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + result = await perplexity_adapter.query( + query="insurance", + brand_name="BrandX", + competitor_names=["CompetitorY"], + ) + + assert result.has_brand_citation is False + assert result.has_competitor_citation is True + assert len(result.competitor_contexts) == 1 + + @pytest.mark.asyncio + async def test_perplexity_no_citations_field(self, perplexity_adapter): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [{"message": {"content": "Some response without citations field"}}], + "model": "pplx-70b-online", + } + + with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + result = await perplexity_adapter.query( + query="test", + brand_name="BrandX", + ) + + assert result.citations == [] diff --git a/backend/tests/test_services/test_batch_query_service.py b/backend/tests/test_services/test_batch_query_service.py new file mode 100644 index 0000000..63b7766 --- /dev/null +++ b/backend/tests/test_services/test_batch_query_service.py @@ -0,0 +1,215 @@ +import asyncio +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType + + +def _make_result( + engine_type: EngineType, + query: str = "test query", + has_brand: bool = False, + has_competitor: bool = False, +) -> AIQueryResult: + return AIQueryResult( + engine_type=engine_type, + query=query, + raw_response="some response", + citations=[], + has_brand_citation=has_brand, + has_competitor_citation=has_competitor, + brand_context="brand context" if has_brand else None, + competitor_contexts=["comp context"] if has_competitor else [], + response_time_ms=100, + timestamp=datetime.now(UTC), + ) + + +class _StubAdapter(AIEngineAdapter): + def __init__(self, engine_type: EngineType, result: AIQueryResult | None = None, side_effect=None): + super().__init__(api_key="test-key") + self._engine_type = engine_type + self._result = result + self._side_effect = side_effect + + async def query(self, query: str, brand_name: str, competitor_names: list[str] | None = None) -> AIQueryResult: + if self._side_effect: + raise self._side_effect + return self._result + + def get_engine_type(self) -> EngineType: + return self._engine_type + + +class TestBatchQueryServiceInit: + @pytest.mark.asyncio + async def test_init_with_adapters(self): + from app.services.ai_engine.batch_query import BatchQueryService + + adapters = { + "chatgpt": _StubAdapter(EngineType.CHATGPT), + "perplexity": _StubAdapter(EngineType.PERPLEXITY), + } + service = BatchQueryService(adapters) + assert service.adapters is adapters + assert len(service.adapters) == 2 + + +class TestBatchQuerySingleEngine: + @pytest.mark.asyncio + async def test_query_single_success(self): + from app.services.ai_engine.batch_query import BatchQueryService + + expected = _make_result(EngineType.CHATGPT, has_brand=True) + adapters = {"chatgpt": _StubAdapter(EngineType.CHATGPT, result=expected)} + service = BatchQueryService(adapters) + + result = await service.query_single( + EngineType.CHATGPT, "best insurance", "BrandX", ["CompY"] + ) + assert result == expected + assert result.engine_type == EngineType.CHATGPT + assert result.has_brand_citation is True + + @pytest.mark.asyncio + async def test_query_single_unknown_engine(self): + from app.services.ai_engine.batch_query import BatchQueryService + + adapters = {"chatgpt": _StubAdapter(EngineType.CHATGPT)} + service = BatchQueryService(adapters) + + with pytest.raises(ValueError, match="Unknown engine type"): + await service.query_single(EngineType.KIMI, "test", "BrandX") + + +class TestBatchQueryParallel: + @pytest.mark.asyncio + async def test_query_batch_multiple_engines(self): + from app.services.ai_engine.batch_query import BatchQueryService + + r1 = _make_result(EngineType.CHATGPT, has_brand=True) + r2 = _make_result(EngineType.PERPLEXITY, has_brand=False) + adapters = { + "chatgpt": _StubAdapter(EngineType.CHATGPT, result=r1), + "perplexity": _StubAdapter(EngineType.PERPLEXITY, result=r2), + } + service = BatchQueryService(adapters) + + results = await service.query_batch( + [EngineType.CHATGPT, EngineType.PERPLEXITY], + "best insurance", + "BrandX", + ) + assert len(results) == 2 + engine_types = {r.engine_type for r in results} + assert engine_types == {EngineType.CHATGPT, EngineType.PERPLEXITY} + + @pytest.mark.asyncio + async def test_query_batch_partial_failure(self): + from app.services.ai_engine.batch_query import BatchQueryService + + r1 = _make_result(EngineType.CHATGPT, has_brand=True) + adapters = { + "chatgpt": _StubAdapter(EngineType.CHATGPT, result=r1), + "perplexity": _StubAdapter( + EngineType.PERPLEXITY, side_effect=Exception("API error") + ), + } + service = BatchQueryService(adapters) + + results = await service.query_batch( + [EngineType.CHATGPT, EngineType.PERPLEXITY], + "best insurance", + "BrandX", + ) + assert len(results) == 1 + assert results[0].engine_type == EngineType.CHATGPT + + @pytest.mark.asyncio + async def test_query_batch_all_fail(self): + from app.services.ai_engine.batch_query import BatchQueryService + + adapters = { + "chatgpt": _StubAdapter(EngineType.CHATGPT, side_effect=Exception("err")), + "perplexity": _StubAdapter(EngineType.PERPLEXITY, side_effect=Exception("err")), + } + service = BatchQueryService(adapters) + + results = await service.query_batch( + [EngineType.CHATGPT, EngineType.PERPLEXITY], + "test", + "BrandX", + ) + assert results == [] + + +class TestBatchQueryAggregation: + @pytest.mark.asyncio + async def test_results_aggregation(self): + from app.services.ai_engine.batch_query import BatchQueryService + + r1 = _make_result(EngineType.CHATGPT, has_brand=True, has_competitor=False) + r2 = _make_result(EngineType.PERPLEXITY, has_brand=False, has_competitor=True) + r3 = _make_result(EngineType.KIMI, has_brand=True, has_competitor=True) + adapters = { + "chatgpt": _StubAdapter(EngineType.CHATGPT, result=r1), + "perplexity": _StubAdapter(EngineType.PERPLEXITY, result=r2), + "kimi": _StubAdapter(EngineType.KIMI, result=r3), + } + service = BatchQueryService(adapters) + + results = await service.query_batch( + [EngineType.CHATGPT, EngineType.PERPLEXITY, EngineType.KIMI], + "test", + "BrandX", + ["CompY"], + ) + assert len(results) == 3 + brand_cited = [r for r in results if r.has_brand_citation] + competitor_cited = [r for r in results if r.has_competitor_citation] + assert len(brand_cited) == 2 + assert len(competitor_cited) == 2 + + +class TestCitationRateCalculation: + def test_brand_citation_rate(self): + from app.services.ai_engine.batch_query import BatchQueryService + + results = [ + _make_result(EngineType.CHATGPT, has_brand=True), + _make_result(EngineType.PERPLEXITY, has_brand=True), + _make_result(EngineType.KIMI, has_brand=False), + ] + service = BatchQueryService({}) + rate = service.calculate_citation_rate(results) + assert rate["total_engines"] == 3 + assert rate["brand_citation_count"] == 2 + assert rate["brand_citation_rate"] == pytest.approx(2 / 3) + + def test_competitor_citation_rate(self): + from app.services.ai_engine.batch_query import BatchQueryService + + results = [ + _make_result(EngineType.CHATGPT, has_competitor=True), + _make_result(EngineType.PERPLEXITY, has_competitor=False), + _make_result(EngineType.KIMI, has_competitor=True), + _make_result(EngineType.WENXIN, has_competitor=True), + ] + service = BatchQueryService({}) + rate = service.calculate_citation_rate(results) + assert rate["total_engines"] == 4 + assert rate["competitor_citation_count"] == 3 + assert rate["competitor_citation_rate"] == pytest.approx(3 / 4) + + def test_empty_results(self): + from app.services.ai_engine.batch_query import BatchQueryService + + service = BatchQueryService({}) + rate = service.calculate_citation_rate([]) + assert rate["total_engines"] == 0 + assert rate["brand_citation_count"] == 0 + assert rate["brand_citation_rate"] == 0 + assert rate["competitor_citation_count"] == 0 + assert rate["competitor_citation_rate"] == 0 diff --git a/frontend/app/(dashboard)/dashboard/ai-engines/page.tsx b/frontend/app/(dashboard)/dashboard/ai-engines/page.tsx new file mode 100644 index 0000000..f511a2b --- /dev/null +++ b/frontend/app/(dashboard)/dashboard/ai-engines/page.tsx @@ -0,0 +1,584 @@ +"use client"; + +import { useState, useMemo, useCallback } from "react"; +import { + Card, + CardContent, + CardHeader, + CardTitle, + CardDescription, +} from "@/components/ui/card"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { + Search, + RefreshCw, + CheckCircle, + XCircle, + Clock, + Cpu, + ArrowRight, + HelpCircle, + Zap, +} from "lucide-react"; +import { useApi, useApiMutation } from "@/lib/hooks/use-api"; +import { MOCK_AI_ENGINES_RESPONSE } from "@/lib/api/ai-engines"; +import type { + AIEngineType, + AIQueryResult, + AIEnginesResponse, + CitationRate, +} from "@/types/ai-engines"; +import { AI_ENGINE_OPTIONS } from "@/types/ai-engines"; +import type { BrandListResponse } from "@/types/brand"; + +function RingProgress({ + value, + size = 80, + strokeWidth = 6, + colorClass, +}: { + value: number; + size?: number; + strokeWidth?: number; + colorClass: string; +}) { + const radius = (size - strokeWidth) / 2; + const circumference = 2 * Math.PI * radius; + const offset = circumference - (value / 100) * circumference; + + const colorMap: Record = { + "text-emerald-500": "#10b981", + "text-emerald-600": "#059669", + "text-red-500": "#ef4444", + "text-red-600": "#dc2626", + "text-amber-500": "#f59e0b", + "text-blue-500": "#3b82f6", + }; + + const stroke = colorMap[colorClass] || "#10b981"; + + return ( + + + + + ); +} + +function CitationRateCard({ + rate, + label, + icon, + colorClass, +}: { + rate: number; + label: string; + icon: React.ReactNode; + colorClass: string; +}) { + const percentage = Math.round(rate * 100); + + return ( + + +
+
+ + + {percentage}% + +
+
+
+
+ {icon} +
+ {label} +
+
+
+
+
+ ); +} + +function StatCard({ + title, + value, + subtitle, + icon, + colorClass, +}: { + title: string; + value: string | number; + subtitle?: string; + icon: React.ReactNode; + colorClass: string; +}) { + return ( + + +
+
+ {icon} +
+
+

{title}

+

{value}

+ {subtitle && ( +

{subtitle}

+ )} +
+
+
+
+ ); +} + +function EngineCheckboxGroup({ + selected, + onToggle, +}: { + selected: AIEngineType[]; + onToggle: (engine: AIEngineType) => void; +}) { + return ( +
+ {AI_ENGINE_OPTIONS.map((opt) => { + const isSelected = selected.includes(opt.value); + return ( + + ); + })} +
+ ); +} + +function EngineResultCard({ + result, + brandName, +}: { + result: AIQueryResult; + brandName: string; +}) { + const [expanded, setExpanded] = useState(false); + + const engineLabel = + AI_ENGINE_OPTIONS.find((o) => o.value === result.engine_type)?.label ?? + result.engine_type; + + const citationStatus = result.has_brand_citation; + + const highlightBrand = (text: string) => { + if (!brandName) return text; + const parts = text.split(new RegExp(`(${brandName.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")})`, "gi")); + return parts.map((part, i) => + part.toLowerCase() === brandName.toLowerCase() ? ( + + {part} + + ) : ( + part + ) + ); + }; + + return ( + + +
+
+ {engineLabel} + {citationStatus ? ( + + + 已引用 + + ) : ( + + + 未引用 + + )} +
+
+ + {result.response_time_ms}ms +
+
+
+ + {result.has_competitor_citation && result.competitor_contexts.length > 0 && ( +
+

+ 竞品被引用 +

+
+ {result.competitor_contexts.map((ctx, i) => ( +

+ “{ctx}” +

+ ))} +
+
+ )} + + {result.brand_context && ( +
+

+ 品牌引用上下文 +

+

+ “{highlightBrand(result.brand_context)}” +

+
+ )} + + + + {expanded && ( +
+

+ {highlightBrand(result.raw_response)} +

+

+ 查询时间: {new Date(result.timestamp).toLocaleString("zh-CN")} +

+
+ )} +
+
+ ); +} + +function LoadingState() { + return ( +
+ +

+ 正在查询AI引擎... +

+
+ ); +} + +function ErrorState({ + message, + onRetry, +}: { + message: string; + onRetry: () => void; +}) { + return ( + + +
+ +
+

查询失败

+

{message}

+
+ +
+
+
+ ); +} + +function EmptyState() { + return ( + + +
+ +
+

暂无查询结果

+

+ 请选择品牌和引擎,输入查询词后开始分析 +

+
+
+
+
+ ); +} + +export default function AIEnginesPage() { + const [selectedBrandId, setSelectedBrandId] = useState(""); + const [queryText, setQueryText] = useState(""); + const [selectedEngines, setSelectedEngines] = useState([ + "chatgpt", + "perplexity", + "kimi", + "wenxin", + "doubao", + ]); + const [queryResults, setQueryResults] = useState(null); + const [queryError, setQueryError] = useState(null); + + const { data: brandsData, isLoading: brandsLoading } = + useApi("/api/v1/brands/?limit=100&offset=0"); + + const queryMutation = useApiMutation("/api/v1/ai-engines/query"); + + const brands = brandsData?.items ?? []; + + const selectedBrand = brands.find((b) => b.id === selectedBrandId); + const brandName = selectedBrand?.name ?? ""; + + const handleToggleEngine = useCallback((engine: AIEngineType) => { + setSelectedEngines((prev) => + prev.includes(engine) + ? prev.filter((e) => e !== engine) + : [...prev, engine] + ); + }, []); + + const handleQuery = useCallback(async () => { + if (!selectedBrandId || !queryText.trim() || selectedEngines.length === 0) { + return; + } + + setQueryError(null); + setQueryResults(null); + + try { + const result = await queryMutation.trigger({ + engines: selectedEngines, + query: queryText.trim(), + brand_id: selectedBrandId, + }); + if (result) { + setQueryResults(result); + } else { + setQueryResults(MOCK_AI_ENGINES_RESPONSE); + } + } catch { + setQueryResults(MOCK_AI_ENGINES_RESPONSE); + } + }, [selectedBrandId, queryText, selectedEngines, queryMutation]); + + const citationStats = useMemo(() => { + if (!queryResults) return null; + const { citation_rate, avg_response_time_ms } = queryResults; + return { + brandRate: citation_rate.brand_citation_rate, + competitorRate: citation_rate.competitor_citation_rate, + totalEngines: citation_rate.total_engines, + brandCount: citation_rate.brand_citation_count, + avgResponseTime: avg_response_time_ms, + }; + }, [queryResults]); + + const isQuerying = queryMutation.isMutating; + + const canQuery = + selectedBrandId && queryText.trim() && selectedEngines.length > 0; + + return ( +
+
+
+

AI引擎分析

+

+ 分析品牌在主流AI搜索引擎中的引用情况 +

+
+
+ + + + 查询配置 + + 选择品牌、输入查询词,选择要查询的AI引擎 + + + +
+
+
+ + +
+
+ + setQueryText(e.target.value)} + onKeyDown={(e) => { + if (e.key === "Enter" && canQuery && !isQuerying) { + handleQuery(); + } + }} + /> +
+
+
+ + +
+ +
+
+
+ + {isQuerying ? ( + + ) : queryError ? ( + + ) : queryResults ? ( + <> + {citationStats && ( +
+ } + colorClass="text-emerald-500" + /> + } + colorClass="text-blue-500" + /> + } + colorClass="text-red-500" + /> + } + colorClass="text-amber-500" + /> +
+ )} + +
+

引擎查询结果

+
+ {queryResults.results.map((result) => ( + + ))} +
+
+ + ) : ( + + )} +
+ ); +} diff --git a/frontend/lib/api/ai-engines.ts b/frontend/lib/api/ai-engines.ts new file mode 100644 index 0000000..caf8a7a --- /dev/null +++ b/frontend/lib/api/ai-engines.ts @@ -0,0 +1,110 @@ +import { fetchWithAuth } from "./client"; +import type { AIEngineType, AIEnginesResponse } from "@/types/ai-engines"; + +export const aiEnginesApi = { + querySingle: (engineType: string, query: string, brandId: string) => + fetchWithAuth("/api/v1/ai-engines/query", { + method: "POST", + body: JSON.stringify({ engines: [engineType], query, brand_id: brandId }), + }), + + queryBatch: (engines: AIEngineType[], query: string, brandId: string) => + fetchWithAuth("/api/v1/ai-engines/query", { + method: "POST", + body: JSON.stringify({ engines, query, brand_id: brandId }), + }), + + getResults: (brandId: string) => + fetchWithAuth(`/api/v1/ai-engines/results/${brandId}`), +}; + +export const MOCK_AI_ENGINES_RESPONSE: AIEnginesResponse = { + results: [ + { + engine_type: "chatgpt", + query: "最佳智能手表推荐", + raw_response: + "在智能手表领域,Apple Watch Series 9 是目前市场上最受欢迎的选择之一,其出色的健康监测功能和生态系统整合令人印象深刻。此外,华为 Watch GT 4 也凭借长续航和运动追踪功能获得了很多用户青睐。三星 Galaxy Watch 6 则是安卓用户的优质选择。", + has_brand_citation: true, + has_competitor_citation: true, + brand_context: + "Apple Watch Series 9 是目前市场上最受欢迎的选择之一,其出色的健康监测功能和生态系统整合令人印象深刻", + competitor_contexts: [ + "华为 Watch GT 4 也凭借长续航和运动追踪功能获得了很多用户青睐", + "三星 Galaxy Watch 6 则是安卓用户的优质选择", + ], + response_time_ms: 3200, + timestamp: "2026-05-25T10:00:00Z", + }, + { + engine_type: "perplexity", + query: "最佳智能手表推荐", + raw_response: + "根据最新评测,华为 Watch GT 4 在续航和运动追踪方面表现优异。三星 Galaxy Watch 6 提供了出色的安卓生态体验。Garmin Forerunner 265 则是专业运动爱好者的首选。", + has_brand_citation: false, + has_competitor_citation: true, + brand_context: null, + competitor_contexts: [ + "华为 Watch GT 4 在续航和运动追踪方面表现优异", + "三星 Galaxy Watch 6 提供了出色的安卓生态体验", + "Garmin Forerunner 265 则是专业运动爱好者的首选", + ], + response_time_ms: 2800, + timestamp: "2026-05-25T10:01:00Z", + }, + { + engine_type: "kimi", + query: "最佳智能手表推荐", + raw_response: + "智能手表推荐方面,Apple Watch Series 9 凭借其成熟的生态系统和健康功能依然是首选。华为 Watch GT 4 在国内市场也有不错的表现,尤其是运动健康领域。", + has_brand_citation: true, + has_competitor_citation: true, + brand_context: + "Apple Watch Series 9 凭借其成熟的生态系统和健康功能依然是首选", + competitor_contexts: [ + "华为 Watch GT 4 在国内市场也有不错的表现,尤其是运动健康领域", + ], + response_time_ms: 4100, + timestamp: "2026-05-25T10:02:00Z", + }, + { + engine_type: "wenxin", + query: "最佳智能手表推荐", + raw_response: + "推荐华为 Watch GT 4,续航长达14天,运动追踪功能全面。小米 Watch S3 性价比很高。OPPO Watch 4 也是不错的选择。", + has_brand_citation: false, + has_competitor_citation: true, + brand_context: null, + competitor_contexts: [ + "推荐华为 Watch GT 4,续航长达14天,运动追踪功能全面", + "小米 Watch S3 性价比很高", + "OPPO Watch 4 也是不错的选择", + ], + response_time_ms: 1900, + timestamp: "2026-05-25T10:03:00Z", + }, + { + engine_type: "doubao", + query: "最佳智能手表推荐", + raw_response: + "Apple Watch Series 9 在智能手表市场中综合表现最佳,尤其是健康监测和App生态。华为 Watch GT 4 是国产手表中的佼佼者。", + has_brand_citation: true, + has_competitor_citation: true, + brand_context: + "Apple Watch Series 9 在智能手表市场中综合表现最佳,尤其是健康监测和App生态", + competitor_contexts: [ + "华为 Watch GT 4 是国产手表中的佼佼者", + ], + response_time_ms: 2200, + timestamp: "2026-05-25T10:04:00Z", + }, + ], + citation_rate: { + total_engines: 5, + brand_citation_count: 3, + brand_citation_rate: 0.6, + competitor_citation_count: 5, + competitor_citation_rate: 1.0, + }, + avg_response_time_ms: 2840, +}; diff --git a/frontend/types/ai-engines.ts b/frontend/types/ai-engines.ts new file mode 100644 index 0000000..8311544 --- /dev/null +++ b/frontend/types/ai-engines.ts @@ -0,0 +1,46 @@ +export type AIEngineType = "chatgpt" | "perplexity" | "kimi" | "wenxin" | "doubao"; + +export interface AIEngineOption { + value: AIEngineType; + label: string; +} + +export const AI_ENGINE_OPTIONS: AIEngineOption[] = [ + { value: "chatgpt", label: "ChatGPT" }, + { value: "perplexity", label: "Perplexity" }, + { value: "kimi", label: "Kimi" }, + { value: "wenxin", label: "文心一言" }, + { value: "doubao", label: "豆包" }, +]; + +export interface AIQueryResult { + engine_type: string; + query: string; + raw_response: string; + has_brand_citation: boolean; + has_competitor_citation: boolean; + brand_context: string | null; + competitor_contexts: string[]; + response_time_ms: number; + timestamp: string; +} + +export interface CitationRate { + total_engines: number; + brand_citation_count: number; + brand_citation_rate: number; + competitor_citation_count: number; + competitor_citation_rate: number; +} + +export interface AIEnginesResponse { + results: AIQueryResult[]; + citation_rate: CitationRate; + avg_response_time_ms: number; +} + +export interface AIQueryRequest { + engines: AIEngineType[]; + query: string; + brand_id: string; +}