feat: Phase1 Week1-2 - AI引擎查询分析完整实现

后端(TDD):
- AI引擎适配器框架(基类+5个适配器)
- ChatGPT/Perplexity/Kimi/文心一言/豆包适配器
- 批量并行查询服务(asyncio.gather)
- AI引擎查询API端点(3个)
- 51+14=65个测试全部通过

前端:
- AI引擎分析页面(引用率/引擎结果/上下文详情)
- AI引擎API客户端+类型定义
- Mock数据降级支持
This commit is contained in:
chiguyong 2026-05-25 10:29:20 +08:00
parent 65e2f3c380
commit 1ec5ea42da
17 changed files with 2896 additions and 0 deletions

View File

@ -0,0 +1,175 @@
import logging
from functools import lru_cache
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel, Field
from app.api.deps import get_current_user
from app.models.user import User
from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType
from app.services.ai_engine.batch_query import BatchQueryService
from app.services.ai_engine.chatgpt import ChatGPTAdapter
from app.services.ai_engine.perplexity import PerplexityAdapter
from app.services.ai_engine.kimi import KimiAdapter
from app.services.ai_engine.wenxin import WenxinAdapter
from app.services.ai_engine.doubao import DoubaoAdapter
logger = logging.getLogger(__name__)
router = APIRouter()
class SingleQueryRequest(BaseModel):
engine: str
query: str = Field(min_length=1, max_length=500)
brand_name: str = Field(min_length=1, max_length=200)
competitor_names: list[str] | None = None
class BatchQueryRequest(BaseModel):
engines: list[str] = Field(min_length=1)
query: str = Field(min_length=1, max_length=500)
brand_name: str = Field(min_length=1, max_length=200)
competitor_names: list[str] | None = None
class QueryResultResponse(BaseModel):
engine_type: str
query: str
raw_response: str
has_brand_citation: bool
has_competitor_citation: bool
brand_context: str | None
competitor_contexts: list[str]
response_time_ms: int
model_config = {"from_attributes": True}
class CitationRateResponse(BaseModel):
total_engines: int
brand_citation_count: int
brand_citation_rate: float
competitor_citation_count: int
competitor_citation_rate: float
class BatchQueryResponse(BaseModel):
results: list[QueryResultResponse]
citation_rate: CitationRateResponse
_ADAPTER_CLASSES: dict[EngineType, type[AIEngineAdapter]] = {
EngineType.CHATGPT: ChatGPTAdapter,
EngineType.PERPLEXITY: PerplexityAdapter,
EngineType.KIMI: KimiAdapter,
EngineType.WENXIN: WenxinAdapter,
EngineType.DOUBAO: DoubaoAdapter,
}
@lru_cache(maxsize=1)
def _build_adapters() -> dict[str, AIEngineAdapter]:
adapters: dict[str, AIEngineAdapter] = {}
for engine_type, cls in _ADAPTER_CLASSES.items():
try:
adapters[engine_type.value] = cls()
except Exception:
logger.warning(f"Failed to initialize {engine_type.value} adapter")
return adapters
def get_batch_service() -> BatchQueryService:
return BatchQueryService(_build_adapters())
def _result_to_response(r: AIQueryResult) -> QueryResultResponse:
return QueryResultResponse(
engine_type=r.engine_type.value,
query=r.query,
raw_response=r.raw_response,
has_brand_citation=r.has_brand_citation,
has_competitor_citation=r.has_competitor_citation,
brand_context=r.brand_context,
competitor_contexts=r.competitor_contexts,
response_time_ms=r.response_time_ms,
)
def _parse_engine(value: str) -> EngineType:
try:
return EngineType(value)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unknown engine: {value}",
)
async def _execute_batch(
service: BatchQueryService,
engine_types: list[EngineType],
query: str,
brand_name: str,
competitor_names: list[str] | None,
) -> BatchQueryResponse:
results = await service.query_batch(
engine_types, query, brand_name, competitor_names
)
citation_rate = service.calculate_citation_rate(results)
return BatchQueryResponse(
results=[_result_to_response(r) for r in results],
citation_rate=CitationRateResponse(**citation_rate),
)
@router.post("/query", response_model=QueryResultResponse)
async def query_single_engine(
request: SingleQueryRequest,
current_user: User = Depends(get_current_user),
):
service = get_batch_service()
engine_type = _parse_engine(request.engine)
try:
result = await service.query_single(
engine_type, request.query, request.brand_name, request.competitor_names
)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
return _result_to_response(result)
@router.post("/query-batch", response_model=BatchQueryResponse)
async def query_batch_engines(
request: BatchQueryRequest,
current_user: User = Depends(get_current_user),
):
service = get_batch_service()
engine_types = [_parse_engine(e) for e in request.engines]
return await _execute_batch(
service, engine_types, request.query, request.brand_name, request.competitor_names
)
@router.get("/results", response_model=BatchQueryResponse)
async def get_query_results(
engines: str = Query(..., description="Comma-separated engine names"),
query: str = Query(..., min_length=1, max_length=500),
brand_name: str = Query(..., min_length=1, max_length=200),
competitor_names: str | None = Query(None, description="Comma-separated competitor names"),
current_user: User = Depends(get_current_user),
):
service = get_batch_service()
engine_list = [e.strip() for e in engines.split(",") if e.strip()]
engine_types = [_parse_engine(e) for e in engine_list]
comp_names = (
[n.strip() for n in competitor_names.split(",") if n.strip()]
if competitor_names
else None
)
return await _execute_batch(
service, engine_types, query, brand_name, comp_names
)

View File

@ -38,6 +38,7 @@ from app.api.platforms import router as platforms_router
from app.api.platform_rules import router as platform_rules_router
from app.api.image import router as image_router
from app.api.knowledge_graph import router as knowledge_graph_router
from app.api.ai_engines import router as ai_engines_router
from app.config import settings
from app.database import engine, Base
from app.schemas.common import ErrorResponse, ErrorCode
@ -165,6 +166,7 @@ app.include_router(platforms_router, prefix="/api/v1")
app.include_router(platform_rules_router)
app.include_router(image_router, prefix="/api/v1")
app.include_router(knowledge_graph_router, prefix="/api/v1/knowledge-bases")
app.include_router(ai_engines_router, prefix="/api/v1/ai-engines", tags=["AI引擎查询"])
@app.get("/health", tags=["可观测性"])

View File

@ -0,0 +1,20 @@
from .base import AIEngineAdapter, AIQueryResult, CitationInfo, EngineType
from .chatgpt import ChatGPTAdapter
from .perplexity import PerplexityAdapter
from .kimi import KimiAdapter
from .wenxin import WenxinAdapter
from .doubao import DoubaoAdapter
from .batch_query import BatchQueryService
__all__ = [
"AIEngineAdapter",
"AIQueryResult",
"CitationInfo",
"EngineType",
"ChatGPTAdapter",
"PerplexityAdapter",
"KimiAdapter",
"WenxinAdapter",
"DoubaoAdapter",
"BatchQueryService",
]

View File

@ -0,0 +1,147 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
from typing import Any
import httpx
logger = logging.getLogger(__name__)
_MAX_RETRIES = 3
_RETRYABLE_STATUS = {429, 500, 502, 503}
class EngineType(str, Enum):
CHATGPT = "chatgpt"
PERPLEXITY = "perplexity"
KIMI = "kimi"
WENXIN = "wenxin"
DOUBAO = "doubao"
DEEPSEEK = "deepseek"
QWEN = "qwen"
@dataclass
class CitationInfo:
source_url: str | None
source_title: str | None
citation_context: str
confidence: float
position: int
@dataclass
class AIQueryResult:
engine_type: EngineType
query: str
raw_response: str
citations: list[CitationInfo]
has_brand_citation: bool
has_competitor_citation: bool
brand_context: str | None
competitor_contexts: list[str]
response_time_ms: int
timestamp: datetime
metadata: dict[str, Any] = field(default_factory=dict)
class AIEngineAdapter(ABC):
def __init__(self, api_key: str, rate_limiter=None):
self.api_key = api_key
self.rate_limiter = rate_limiter
self._client: httpx.AsyncClient | None = None
@abstractmethod
async def query(
self,
query: str,
brand_name: str,
competitor_names: list[str] | None = None,
) -> AIQueryResult:
pass
@abstractmethod
def get_engine_type(self) -> EngineType:
pass
def _detect_citations(
self,
response: str,
brand_name: str,
competitor_names: list[str] | None,
) -> tuple[bool, bool, str | None, list[str]]:
has_brand = brand_name.lower() in response.lower()
brand_context = None
if has_brand:
idx = response.lower().find(brand_name.lower())
start = max(0, idx - 100)
end = min(len(response), idx + len(brand_name) + 100)
brand_context = response[start:end]
has_competitor = False
competitor_contexts = []
if competitor_names:
for name in competitor_names:
if name.lower() in response.lower():
has_competitor = True
idx = response.lower().find(name.lower())
start = max(0, idx - 100)
end = min(len(response), idx + len(name) + 100)
competitor_contexts.append(response[start:end])
return has_brand, has_competitor, brand_context, competitor_contexts
async def _request_with_retry(self, payload: dict) -> dict:
if self.rate_limiter:
await self.rate_limiter.acquire()
engine_name = self.get_engine_type().value
last_error: Exception | None = None
for attempt in range(_MAX_RETRIES):
try:
response = await self._client.post(self._endpoint, json=payload)
if response.status_code == 200:
return response.json()
if response.status_code in _RETRYABLE_STATUS:
retry_after = response.headers.get("retry-after")
wait = float(retry_after) if retry_after else 2**attempt
logger.warning(
f"[{engine_name}] HTTP {response.status_code}, "
f"retry {attempt + 1}/{_MAX_RETRIES} in {wait:.1f}s"
)
last_error = Exception(
f"HTTP {response.status_code}: {response.text[:300]}"
)
await asyncio.sleep(wait)
continue
raise Exception(
f"HTTP {response.status_code}: {response.text[:300]}"
)
except httpx.TransportError as exc:
logger.warning(
f"[{engine_name}] Transport error: {exc}, "
f"retry {attempt + 1}/{_MAX_RETRIES}"
)
last_error = Exception(f"Network error: {exc}")
await asyncio.sleep(2**attempt)
continue
raise last_error or Exception("Max retries exceeded")
async def close(self) -> None:
if self._client:
await self._client.aclose()
async def __aenter__(self) -> "AIEngineAdapter":
return self
async def __aexit__(self, *exc) -> None:
await self.close()

View File

@ -0,0 +1,55 @@
import asyncio
import logging
from .base import AIEngineAdapter, AIQueryResult, EngineType
logger = logging.getLogger(__name__)
class BatchQueryService:
def __init__(self, adapters: dict[str, AIEngineAdapter]):
self.adapters = adapters
async def query_single(
self,
engine_type: EngineType,
query: str,
brand_name: str,
competitor_names: list[str] | None = None,
) -> AIQueryResult:
adapter = self.adapters.get(engine_type.value)
if not adapter:
raise ValueError(f"Unknown engine type: {engine_type}")
return await adapter.query(query, brand_name, competitor_names)
async def query_batch(
self,
engines: list[EngineType],
query: str,
brand_name: str,
competitor_names: list[str] | None = None,
) -> list[AIQueryResult]:
tasks = [
self.query_single(engine, query, brand_name, competitor_names)
for engine in engines
]
results = await asyncio.gather(*tasks, return_exceptions=True)
successful: list[AIQueryResult] = []
for r in results:
if isinstance(r, AIQueryResult):
successful.append(r)
elif isinstance(r, Exception):
logger.warning(f"Engine query failed: {r}")
return successful
def calculate_citation_rate(self, results: list[AIQueryResult]) -> dict:
total = len(results)
brand_cited = sum(1 for r in results if r.has_brand_citation)
competitor_cited = sum(1 for r in results if r.has_competitor_citation)
return {
"total_engines": total,
"brand_citation_count": brand_cited,
"brand_citation_rate": brand_cited / total if total > 0 else 0,
"competitor_citation_count": competitor_cited,
"competitor_citation_rate": competitor_cited / total if total > 0 else 0,
}

View File

@ -0,0 +1,88 @@
import logging
import os
import time
from datetime import UTC, datetime
import httpx
from .base import AIEngineAdapter, AIQueryResult, EngineType
logger = logging.getLogger(__name__)
_DEFAULT_MODEL = "gpt-4o"
_DEFAULT_BASE_URL = "https://api.openai.com/v1"
class ChatGPTAdapter(AIEngineAdapter):
def __init__(
self,
api_key: str | None = None,
model: str | None = None,
base_url: str | None = None,
rate_limiter=None,
):
super().__init__(
api_key=api_key or os.getenv("OPENAI_API_KEY", ""),
rate_limiter=rate_limiter,
)
self._model = model or os.getenv("OPENAI_MODEL", _DEFAULT_MODEL)
self._base_url = (
base_url or os.getenv("OPENAI_BASE_URL", _DEFAULT_BASE_URL)
).rstrip("/")
self._endpoint = f"{self._base_url}/chat/completions"
self._client = httpx.AsyncClient(
timeout=httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0),
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
)
def get_engine_type(self) -> EngineType:
return EngineType.CHATGPT
async def query(
self,
query: str,
brand_name: str,
competitor_names: list[str] | None = None,
) -> AIQueryResult:
start_time = time.perf_counter()
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": query},
]
payload = {
"model": self._model,
"messages": messages,
"temperature": 0.7,
"max_tokens": 4096,
}
data = await self._request_with_retry(payload)
content = data["choices"][0]["message"]["content"]
elapsed_ms = int((time.perf_counter() - start_time) * 1000)
has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
content, brand_name, competitor_names
)
logger.info(
f"[chatgpt] query='{query[:50]}...' brand={has_brand} "
f"competitor={has_comp} time={elapsed_ms}ms"
)
return AIQueryResult(
engine_type=self.get_engine_type(),
query=query,
raw_response=content,
citations=[],
has_brand_citation=has_brand,
has_competitor_citation=has_comp,
brand_context=brand_ctx,
competitor_contexts=comp_ctx,
response_time_ms=elapsed_ms,
timestamp=datetime.now(UTC),
metadata={"model": data.get("model", self._model)},
)

View File

@ -0,0 +1,96 @@
import logging
import os
import time
from datetime import UTC, datetime
import httpx
from .base import AIEngineAdapter, AIQueryResult, EngineType
logger = logging.getLogger(__name__)
_DEFAULT_MODEL = "doubao-pro-4k"
_DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
class DoubaoAdapter(AIEngineAdapter):
def __init__(
self,
api_key: str | None = None,
endpoint_id: str | None = None,
rate_limiter=None,
):
super().__init__(
api_key=api_key or os.getenv("DOUBAO_API_KEY", ""),
rate_limiter=rate_limiter,
)
self._endpoint_id = endpoint_id or os.getenv("DOUBAO_ENDPOINT_ID", "")
self._base_url = _DEFAULT_BASE_URL.rstrip("/")
self._endpoint = f"{self._base_url}/chat/completions"
self._client = httpx.AsyncClient(
timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0),
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
)
def get_engine_type(self) -> EngineType:
return EngineType.DOUBAO
def _get_model_id(self) -> str:
if self._endpoint_id and self._endpoint_id.strip():
if not self._endpoint_id.startswith("ep-"):
return f"ep-{self._endpoint_id}"
return self._endpoint_id
return _DEFAULT_MODEL
async def query(
self,
query: str,
brand_name: str,
competitor_names: list[str] | None = None,
) -> AIQueryResult:
start_time = time.perf_counter()
model_id = self._get_model_id()
messages = [
{
"role": "system",
"content": "你是一个专业的AI搜索助手。请基于你的知识详细回答用户的问题。如果引用了外部来源请在回答中标注来源URL或出处名称。",
},
{"role": "user", "content": query},
]
payload = {
"model": model_id,
"messages": messages,
"temperature": 0.7,
"max_tokens": 2000,
}
data = await self._request_with_retry(payload)
content = data["choices"][0]["message"]["content"]
elapsed_ms = int((time.perf_counter() - start_time) * 1000)
has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
content, brand_name, competitor_names
)
logger.info(
f"[doubao] query='{query[:50]}...' brand={has_brand} "
f"competitor={has_comp} time={elapsed_ms}ms"
)
return AIQueryResult(
engine_type=self.get_engine_type(),
query=query,
raw_response=content,
citations=[],
has_brand_citation=has_brand,
has_competitor_citation=has_comp,
brand_context=brand_ctx,
competitor_contexts=comp_ctx,
response_time_ms=elapsed_ms,
timestamp=datetime.now(UTC),
metadata={"model": data.get("model", model_id), "usage": data.get("usage")},
)

View File

@ -0,0 +1,91 @@
import logging
import os
import time
from datetime import UTC, datetime
import httpx
from .base import AIEngineAdapter, AIQueryResult, EngineType
logger = logging.getLogger(__name__)
_DEFAULT_MODEL = "moonshot-v1-8k"
_DEFAULT_BASE_URL = "https://api.moonshot.cn/v1"
class KimiAdapter(AIEngineAdapter):
def __init__(
self,
api_key: str | None = None,
model: str | None = None,
base_url: str | None = None,
rate_limiter=None,
):
super().__init__(
api_key=api_key or os.getenv("MOONSHOT_API_KEY", ""),
rate_limiter=rate_limiter,
)
self._model = model or _DEFAULT_MODEL
self._base_url = (
base_url or _DEFAULT_BASE_URL
).rstrip("/")
self._endpoint = f"{self._base_url}/chat/completions"
self._client = httpx.AsyncClient(
timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0),
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
)
def get_engine_type(self) -> EngineType:
return EngineType.KIMI
async def query(
self,
query: str,
brand_name: str,
competitor_names: list[str] | None = None,
) -> AIQueryResult:
start_time = time.perf_counter()
messages = [
{
"role": "system",
"content": "你是一个专业的AI搜索助手。请基于你的知识详细回答用户的问题。如果引用了外部来源请在回答中标注来源URL或出处名称。",
},
{"role": "user", "content": query},
]
payload = {
"model": self._model,
"messages": messages,
"temperature": 0.7,
"max_tokens": 2000,
}
data = await self._request_with_retry(payload)
content = data["choices"][0]["message"]["content"]
elapsed_ms = int((time.perf_counter() - start_time) * 1000)
has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
content, brand_name, competitor_names
)
logger.info(
f"[kimi] query='{query[:50]}...' brand={has_brand} "
f"competitor={has_comp} time={elapsed_ms}ms"
)
return AIQueryResult(
engine_type=self.get_engine_type(),
query=query,
raw_response=content,
citations=[],
has_brand_citation=has_brand,
has_competitor_citation=has_comp,
brand_context=brand_ctx,
competitor_contexts=comp_ctx,
response_time_ms=elapsed_ms,
timestamp=datetime.now(UTC),
metadata={"model": data.get("model", self._model), "usage": data.get("usage")},
)

View File

@ -0,0 +1,107 @@
import logging
import os
import time
from datetime import UTC, datetime
import httpx
from .base import AIEngineAdapter, AIQueryResult, CitationInfo, EngineType
logger = logging.getLogger(__name__)
_DEFAULT_MODEL = "pplx-70b-online"
_DEFAULT_BASE_URL = "https://api.perplexity.ai"
class PerplexityAdapter(AIEngineAdapter):
def __init__(
self,
api_key: str | None = None,
model: str | None = None,
base_url: str | None = None,
rate_limiter=None,
):
super().__init__(
api_key=api_key or os.getenv("PERPLEXITY_API_KEY", ""),
rate_limiter=rate_limiter,
)
self._model = model or os.getenv("PERPLEXITY_MODEL", _DEFAULT_MODEL)
self._base_url = (
base_url or os.getenv("PERPLEXITY_BASE_URL", _DEFAULT_BASE_URL)
).rstrip("/")
self._endpoint = f"{self._base_url}/chat/completions"
self._client = httpx.AsyncClient(
timeout=httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0),
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
)
def get_engine_type(self) -> EngineType:
return EngineType.PERPLEXITY
async def query(
self,
query: str,
brand_name: str,
competitor_names: list[str] | None = None,
) -> AIQueryResult:
start_time = time.perf_counter()
messages = [
{"role": "system", "content": "You are a helpful research assistant."},
{"role": "user", "content": query},
]
payload = {
"model": self._model,
"messages": messages,
"temperature": 0.7,
"max_tokens": 4096,
}
data = await self._request_with_retry(payload)
content = data["choices"][0]["message"]["content"]
citations = self._extract_citations(data)
elapsed_ms = int((time.perf_counter() - start_time) * 1000)
has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
content, brand_name, competitor_names
)
logger.info(
f"[perplexity] query='{query[:50]}...' brand={has_brand} "
f"competitor={has_comp} citations={len(citations)} time={elapsed_ms}ms"
)
return AIQueryResult(
engine_type=self.get_engine_type(),
query=query,
raw_response=content,
citations=citations,
has_brand_citation=has_brand,
has_competitor_citation=has_comp,
brand_context=brand_ctx,
competitor_contexts=comp_ctx,
response_time_ms=elapsed_ms,
timestamp=datetime.now(UTC),
metadata={"model": data.get("model", self._model)},
)
def _extract_citations(self, data: dict) -> list[CitationInfo]:
raw_citations = data.get("citations", [])
if not raw_citations:
return []
citations = []
for idx, cit in enumerate(raw_citations):
citations.append(
CitationInfo(
source_url=cit.get("url"),
source_title=cit.get("title"),
citation_context="",
confidence=1.0,
position=idx + 1,
)
)
return citations

View File

@ -0,0 +1,145 @@
import logging
import os
import time
from datetime import UTC, datetime
import httpx
from .base import AIEngineAdapter, AIQueryResult, EngineType
logger = logging.getLogger(__name__)
_DEFAULT_MODEL = "completions_pro"
_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
_CHAT_URL_TEMPLATE = (
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model}"
"?access_token={token}"
)
_cached_token: str | None = None
_token_expires_at: float = 0.0
class WenxinAdapter(AIEngineAdapter):
def __init__(
self,
api_key: str | None = None,
secret_key: str | None = None,
rate_limiter=None,
):
super().__init__(
api_key=api_key or os.getenv("BAIDU_QIANFAN_API_KEY", ""),
rate_limiter=rate_limiter,
)
self.secret_key = secret_key or os.getenv("BAIDU_QIANFAN_SECRET_KEY", "")
self._model = _DEFAULT_MODEL
self._client = httpx.AsyncClient(
timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0),
)
def get_engine_type(self) -> EngineType:
return EngineType.WENXIN
async def _get_access_token(self) -> str:
global _cached_token, _token_expires_at
now = time.monotonic()
if _cached_token and now < _token_expires_at:
return _cached_token
response = await self._client.post(
_TOKEN_URL,
params={
"grant_type": "client_credentials",
"client_id": self.api_key,
"client_secret": self.secret_key,
},
)
if response.status_code != 200:
raise RuntimeError(
f"文心一言获取 access_token 失败: {response.status_code} {response.text[:300]}"
)
data = response.json()
token = data.get("access_token")
if not token:
error_desc = data.get("error_description", "未知错误")
raise RuntimeError(f"文心一言获取 access_token 失败: {error_desc}")
expires_in = data.get("expires_in", 2592000)
_cached_token = token
_token_expires_at = now + expires_in - 300
logger.info("[wenxin] access_token 获取成功")
return token
async def query(
self,
query: str,
brand_name: str,
competitor_names: list[str] | None = None,
) -> AIQueryResult:
start_time = time.perf_counter()
access_token = await self._get_access_token()
chat_url = _CHAT_URL_TEMPLATE.format(
model=self._model,
token=access_token,
)
payload = {
"messages": [{"role": "user", "content": query}],
"system": "你是一个专业的AI搜索助手。请基于你的知识详细回答用户的问题。如果引用了外部来源请在回答中标注来源URL或出处名称。",
"temperature": 0.7,
"max_output_tokens": 2000,
}
if self.rate_limiter:
await self.rate_limiter.acquire()
response = await self._client.post(chat_url, json=payload)
if response.status_code == 429:
raise RuntimeError("文心一言 API 限流")
if response.status_code != 200:
error_body = response.text[:500]
raise RuntimeError(
f"文心一言 API 返回错误 {response.status_code}: {error_body}"
)
data = response.json()
error_code = data.get("error_code")
if error_code:
error_msg = data.get("error_msg", "未知错误")
raise RuntimeError(f"文心一言 API 错误 {error_code}: {error_msg}")
content = data.get("result", "")
if not content:
raise RuntimeError("文心一言 API 返回空内容")
elapsed_ms = int((time.perf_counter() - start_time) * 1000)
has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
content, brand_name, competitor_names
)
logger.info(
f"[wenxin] query='{query[:50]}...' brand={has_brand} "
f"competitor={has_comp} time={elapsed_ms}ms"
)
return AIQueryResult(
engine_type=self.get_engine_type(),
query=query,
raw_response=content,
citations=[],
has_brand_citation=has_brand,
has_competitor_citation=has_comp,
brand_context=brand_ctx,
competitor_contexts=comp_ctx,
response_time_ms=elapsed_ms,
timestamp=datetime.now(UTC),
metadata={"model": self._model, "usage": data.get("usage")},
)

View File

@ -0,0 +1,216 @@
import uuid
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.main import app
from app.models.user import User
from app.api.deps import get_current_user, get_db
from app.services.auth import hash_password
from app.services.ai_engine.base import AIQueryResult, EngineType
@pytest_asyncio.fixture
async def async_engine():
engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest_asyncio.fixture
async def async_session(async_engine):
async_session_maker = async_sessionmaker(
async_engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
autocommit=False,
)
async with async_session_maker() as session:
yield session
@pytest_asyncio.fixture
async def test_user(async_session):
user = User(
id=uuid.uuid4(),
email="test@example.com",
password_hash=hash_password("Test@123456"),
name="Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
)
async_session.add(user)
await async_session.commit()
await async_session.refresh(user)
return user
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
async def override_get_db():
yield async_session
async def override_get_current_user():
return test_user
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_get_current_user
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
yield client
app.dependency_overrides.clear()
def _make_result(
engine_type: EngineType,
query: str = "best insurance",
has_brand: bool = False,
has_competitor: bool = False,
) -> AIQueryResult:
return AIQueryResult(
engine_type=engine_type,
query=query,
raw_response="BrandX is a great insurance company",
citations=[],
has_brand_citation=has_brand,
has_competitor_citation=has_competitor,
brand_context="BrandX is great" if has_brand else None,
competitor_contexts=["CompY is ok"] if has_competitor else [],
response_time_ms=150,
timestamp=datetime.now(UTC),
)
class TestSingleQueryEndpoint:
@pytest.mark.asyncio
async def test_query_single_engine(self, async_client):
mock_result = _make_result(EngineType.CHATGPT, has_brand=True)
with patch("app.api.ai_engines.get_batch_service") as mock_get_service:
mock_service = AsyncMock()
mock_service.query_single.return_value = mock_result
mock_get_service.return_value = mock_service
response = await async_client.post(
"/api/v1/ai-engines/query",
json={
"engine": "chatgpt",
"query": "best insurance",
"brand_name": "BrandX",
"competitor_names": ["CompY"],
},
)
assert response.status_code == 200
data = response.json()
assert data["engine_type"] == "chatgpt"
assert data["has_brand_citation"] is True
assert data["query"] == "best insurance"
class TestBatchQueryEndpoint:
@pytest.mark.asyncio
async def test_query_batch_parallel(self, async_client):
r1 = _make_result(EngineType.CHATGPT, has_brand=True)
r2 = _make_result(EngineType.PERPLEXITY, has_brand=False, has_competitor=True)
with patch("app.api.ai_engines.get_batch_service") as mock_get_service:
mock_service = AsyncMock()
mock_service.query_batch.return_value = [r1, r2]
mock_service.calculate_citation_rate = MagicMock(return_value={
"total_engines": 2,
"brand_citation_count": 1,
"brand_citation_rate": 0.5,
"competitor_citation_count": 1,
"competitor_citation_rate": 0.5,
})
mock_get_service.return_value = mock_service
response = await async_client.post(
"/api/v1/ai-engines/query-batch",
json={
"engines": ["chatgpt", "perplexity"],
"query": "best insurance",
"brand_name": "BrandX",
"competitor_names": ["CompY"],
},
)
assert response.status_code == 200
data = response.json()
assert "results" in data
assert "citation_rate" in data
assert len(data["results"]) == 2
assert data["citation_rate"]["brand_citation_rate"] == 0.5
class TestGetResultsEndpoint:
@pytest.mark.asyncio
async def test_get_results(self, async_client):
r1 = _make_result(EngineType.CHATGPT, has_brand=True)
r2 = _make_result(EngineType.KIMI, has_brand=False)
with patch("app.api.ai_engines.get_batch_service") as mock_get_service:
mock_service = AsyncMock()
mock_service.query_batch.return_value = [r1, r2]
mock_service.calculate_citation_rate = MagicMock(return_value={
"total_engines": 2,
"brand_citation_count": 1,
"brand_citation_rate": 0.5,
"competitor_citation_count": 0,
"competitor_citation_rate": 0.0,
})
mock_get_service.return_value = mock_service
response = await async_client.get(
"/api/v1/ai-engines/results",
params={
"engines": "chatgpt,kimi",
"query": "best insurance",
"brand_name": "BrandX",
},
)
assert response.status_code == 200
data = response.json()
assert "results" in data
assert "citation_rate" in data
class TestUnauthorizedAccess:
@pytest.mark.asyncio
async def test_unauthorized_returns_401(self, async_session):
async def override_get_db():
yield async_session
app.dependency_overrides[get_db] = override_get_db
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
headers = {"Authorization": "Bearer invalid_token"}
response = await client.post(
"/api/v1/ai-engines/query",
json={
"engine": "chatgpt",
"query": "test",
"brand_name": "BrandX",
},
headers=headers,
)
assert response.status_code == 401
app.dependency_overrides.clear()

View File

@ -0,0 +1,270 @@
import pytest
from unittest.mock import AsyncMock, Mock, patch, MagicMock
from app.services.ai_engine.base import (
AIEngineAdapter,
AIQueryResult,
CitationInfo,
EngineType,
)
from app.services.ai_engine.kimi import KimiAdapter
from app.services.ai_engine.wenxin import WenxinAdapter
from app.services.ai_engine.doubao import DoubaoAdapter
def _make_mock_response(status_code=200, json_data=None, text="", headers=None):
mock_resp = Mock()
mock_resp.status_code = status_code
mock_resp.json.return_value = json_data or {}
mock_resp.text = text
mock_resp.headers = headers or {}
return mock_resp
class TestKimiAdapter:
@pytest.mark.asyncio
async def test_initialization(self):
adapter = KimiAdapter(api_key="test-key")
assert adapter.api_key == "test-key"
assert adapter.get_engine_type() == EngineType.KIMI
@pytest.mark.asyncio
async def test_query_returns_ai_query_result(self):
adapter = KimiAdapter(api_key="test-key")
mock_response_data = {
"choices": [{"message": {"content": "华为是全球领先的ICT公司"}}],
"usage": {"total_tokens": 100},
"model": "moonshot-v1-8k",
}
with patch.object(adapter, "_request_with_retry", return_value=mock_response_data):
result = await adapter.query("华为公司", brand_name="华为")
assert isinstance(result, AIQueryResult)
assert result.engine_type == EngineType.KIMI
assert "华为" in result.raw_response
assert result.has_brand_citation is True
assert result.metadata.get("model") == "moonshot-v1-8k"
@pytest.mark.asyncio
async def test_api_error_handling(self):
adapter = KimiAdapter(api_key="test-key")
with patch.object(
adapter,
"_request_with_retry",
side_effect=Exception("HTTP 500: Internal Server Error"),
):
with pytest.raises(Exception, match="HTTP 500"):
await adapter.query("测试问题", brand_name="华为")
@pytest.mark.asyncio
async def test_rate_limit_handling(self):
adapter = KimiAdapter(api_key="test-key")
with patch.object(
adapter,
"_request_with_retry",
side_effect=Exception("HTTP 429: rate limited"),
):
with pytest.raises(Exception, match="429"):
await adapter.query("测试问题", brand_name="华为")
class TestWenxinAdapter:
@pytest.mark.asyncio
async def test_initialization(self):
adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret")
assert adapter.api_key == "test-key"
assert adapter.secret_key == "test-secret"
assert adapter.get_engine_type() == EngineType.WENXIN
@pytest.mark.asyncio
async def test_query_returns_ai_query_result(self):
import app.services.ai_engine.wenxin as wenxin_mod
wenxin_mod._cached_token = None
wenxin_mod._token_expires_at = 0.0
adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret")
mock_token_data = {"access_token": "test-access-token", "expires_in": 2592000}
mock_chat_data = {"result": "华为是一家全球领先的科技公司", "usage": {"total_tokens": 100}}
mock_client = AsyncMock()
token_response = _make_mock_response(200, mock_token_data)
chat_response = _make_mock_response(200, mock_chat_data)
mock_client.post.side_effect = [token_response, chat_response]
with patch.object(adapter, "_client", mock_client):
result = await adapter.query("华为公司", brand_name="华为")
assert isinstance(result, AIQueryResult)
assert result.engine_type == EngineType.WENXIN
assert "华为" in result.raw_response
assert result.has_brand_citation is True
@pytest.mark.asyncio
async def test_api_error_handling(self):
import app.services.ai_engine.wenxin as wenxin_mod
wenxin_mod._cached_token = None
wenxin_mod._token_expires_at = 0.0
adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret")
mock_client = AsyncMock()
token_response = _make_mock_response(200, {"access_token": "test-token", "expires_in": 2592000})
error_response = _make_mock_response(500, text="Internal Server Error")
mock_client.post.side_effect = [token_response, error_response]
with patch.object(adapter, "_client", mock_client):
with pytest.raises(RuntimeError, match="文心"):
await adapter.query("测试问题", brand_name="华为")
@pytest.mark.asyncio
async def test_rate_limit_handling(self):
import app.services.ai_engine.wenxin as wenxin_mod
wenxin_mod._cached_token = None
wenxin_mod._token_expires_at = 0.0
adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret")
mock_client = AsyncMock()
token_response = _make_mock_response(200, {"access_token": "test-token", "expires_in": 2592000})
rate_limit_response = _make_mock_response(429, headers={"Retry-After": "1"})
mock_client.post.side_effect = [token_response, rate_limit_response]
with patch.object(adapter, "_client", mock_client):
with pytest.raises(RuntimeError, match="限流"):
await adapter.query("测试问题", brand_name="华为")
class TestDoubaoAdapter:
@pytest.mark.asyncio
async def test_initialization(self):
adapter = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test")
assert adapter.api_key == "test-key"
assert adapter._endpoint_id == "ep-test"
assert adapter.get_engine_type() == EngineType.DOUBAO
@pytest.mark.asyncio
async def test_query_returns_ai_query_result(self):
adapter = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test")
mock_response_data = {
"choices": [{"message": {"content": "华为是全球知名企业"}}],
"model": "ep-test",
}
with patch.object(adapter, "_request_with_retry", return_value=mock_response_data):
result = await adapter.query("华为公司", brand_name="华为")
assert isinstance(result, AIQueryResult)
assert result.engine_type == EngineType.DOUBAO
assert "华为" in result.raw_response
assert result.has_brand_citation is True
@pytest.mark.asyncio
async def test_api_error_handling(self):
adapter = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test")
with patch.object(
adapter,
"_request_with_retry",
side_effect=Exception("HTTP 500: Internal Server Error"),
):
with pytest.raises(Exception, match="HTTP 500"):
await adapter.query("测试问题", brand_name="华为")
@pytest.mark.asyncio
async def test_rate_limit_handling(self):
adapter = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test")
with patch.object(
adapter,
"_request_with_retry",
side_effect=Exception("HTTP 429: rate limited"),
):
with pytest.raises(Exception, match="429"):
await adapter.query("测试问题", brand_name="华为")
class TestEngineType:
def test_kimi_engine_type(self):
adapter = KimiAdapter(api_key="test-key")
assert adapter.get_engine_type() == EngineType.KIMI
assert adapter.get_engine_type().value == "kimi"
def test_wenxin_engine_type(self):
adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret")
assert adapter.get_engine_type() == EngineType.WENXIN
assert adapter.get_engine_type().value == "wenxin"
def test_doubao_engine_type(self):
adapter = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test")
assert adapter.get_engine_type() == EngineType.DOUBAO
assert adapter.get_engine_type().value == "doubao"
class TestChineseCitationDetection:
def test_brand_name_detection_chinese(self):
adapter = KimiAdapter(api_key="test-key")
has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
"华为是全球领先的ICT基础设施和智能终端提供商",
brand_name="华为",
competitor_names=None,
)
assert has_brand is True
assert brand_ctx is not None
assert "华为" in brand_ctx
def test_brand_name_detection_case_insensitive(self):
adapter = KimiAdapter(api_key="test-key")
has_brand, _, _, _ = adapter._detect_citations(
"Apple is a great company and apple makes phones",
brand_name="apple",
competitor_names=None,
)
assert has_brand is True
def test_competitor_name_detection(self):
adapter = KimiAdapter(api_key="test-key")
has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
"华为和小米都是中国知名手机品牌",
brand_name="华为",
competitor_names=["小米"],
)
assert has_brand is True
assert has_comp is True
assert brand_ctx is not None
assert len(comp_ctx) > 0
def test_no_citations_when_no_match(self):
adapter = KimiAdapter(api_key="test-key")
has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
"今天天气很好",
brand_name="华为",
competitor_names=["小米"],
)
assert has_brand is False
assert has_comp is False
assert brand_ctx is None
assert len(comp_ctx) == 0
class TestAdapterInheritance:
def test_all_adapters_inherit_base(self):
assert issubclass(KimiAdapter, AIEngineAdapter)
assert issubclass(WenxinAdapter, AIEngineAdapter)
assert issubclass(DoubaoAdapter, AIEngineAdapter)
def test_all_adapters_have_query_method(self):
for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter]:
assert hasattr(cls, "query")
assert callable(getattr(cls, "query"))
def test_all_adapters_have_detect_citations(self):
for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter]:
assert hasattr(cls, "_detect_citations")
def test_all_adapters_have_get_engine_type(self):
for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter]:
instance = cls(api_key="test-key") if cls != WenxinAdapter else cls(api_key="test-key", secret_key="s")
assert instance.get_engine_type() in EngineType

View File

@ -0,0 +1,529 @@
import pytest
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock, patch
from app.services.ai_engine.base import (
AIEngineAdapter,
AIQueryResult,
CitationInfo,
EngineType,
)
from app.services.ai_engine.chatgpt import ChatGPTAdapter
from app.services.ai_engine.perplexity import PerplexityAdapter
class TestEngineType:
def test_engine_type_values(self):
assert EngineType.CHATGPT == "chatgpt"
assert EngineType.PERPLEXITY == "perplexity"
assert EngineType.KIMI == "kimi"
assert EngineType.WENXIN == "wenxin"
assert EngineType.DOUBAO == "doubao"
assert EngineType.DEEPSEEK == "deepseek"
assert EngineType.QWEN == "qwen"
class TestCitationInfo:
def test_citation_info_creation(self):
info = CitationInfo(
source_url="https://example.com",
source_title="Example Title",
citation_context="brand was mentioned here",
confidence=0.95,
position=1,
)
assert info.source_url == "https://example.com"
assert info.source_title == "Example Title"
assert info.citation_context == "brand was mentioned here"
assert info.confidence == 0.95
assert info.position == 1
def test_citation_info_optional_fields(self):
info = CitationInfo(
source_url=None,
source_title=None,
citation_context="some context",
confidence=0.5,
position=3,
)
assert info.source_url is None
assert info.source_title is None
class TestAIQueryResult:
def test_ai_query_result_creation(self):
now = datetime.now(UTC)
result = AIQueryResult(
engine_type=EngineType.CHATGPT,
query="best insurance companies",
raw_response="I recommend BrandX for insurance.",
citations=[],
has_brand_citation=True,
has_competitor_citation=False,
brand_context="I recommend BrandX for insurance.",
competitor_contexts=[],
response_time_ms=1500,
timestamp=now,
)
assert result.engine_type == EngineType.CHATGPT
assert result.query == "best insurance companies"
assert result.raw_response == "I recommend BrandX for insurance."
assert result.has_brand_citation is True
assert result.has_competitor_citation is False
assert result.brand_context == "I recommend BrandX for insurance."
assert result.competitor_contexts == []
assert result.response_time_ms == 1500
assert result.timestamp == now
def test_ai_query_result_with_citations(self):
citation = CitationInfo(
source_url="https://brandx.com",
source_title="BrandX Official",
citation_context="BrandX is a leading provider",
confidence=0.9,
position=1,
)
result = AIQueryResult(
engine_type=EngineType.PERPLEXITY,
query="insurance comparison",
raw_response="BrandX is great",
citations=[citation],
has_brand_citation=True,
has_competitor_citation=False,
brand_context="BrandX is great",
competitor_contexts=[],
response_time_ms=2000,
timestamp=datetime.now(UTC),
)
assert len(result.citations) == 1
assert result.citations[0].source_url == "https://brandx.com"
def test_ai_query_result_default_metadata(self):
result = AIQueryResult(
engine_type=EngineType.CHATGPT,
query="test",
raw_response="test",
citations=[],
has_brand_citation=False,
has_competitor_citation=False,
brand_context=None,
competitor_contexts=[],
response_time_ms=100,
timestamp=datetime.now(UTC),
)
assert result.metadata == {}
class TestAIEngineAdapterBase:
def test_cannot_instantiate_abstract_class(self):
with pytest.raises(TypeError):
AIEngineAdapter(api_key="test-key")
def test_detect_citations_brand_found(self):
class ConcreteAdapter(AIEngineAdapter):
async def query(self, query, brand_name, competitor_names=None):
pass
def get_engine_type(self):
return EngineType.CHATGPT
adapter = ConcreteAdapter(api_key="test-key")
has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
"BrandX is the best insurance company",
"BrandX",
None,
)
assert has_brand is True
assert has_comp is False
assert brand_ctx is not None
assert "BrandX" in brand_ctx
def test_detect_citations_competitor_found(self):
class ConcreteAdapter(AIEngineAdapter):
async def query(self, query, brand_name, competitor_names=None):
pass
def get_engine_type(self):
return EngineType.CHATGPT
adapter = ConcreteAdapter(api_key="test-key")
has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
"CompetitorY is also a good choice for insurance",
"BrandX",
["CompetitorY", "CompetitorZ"],
)
assert has_brand is False
assert has_comp is True
assert len(comp_ctx) == 1
assert "CompetitorY" in comp_ctx[0]
def test_detect_citations_both_found(self):
class ConcreteAdapter(AIEngineAdapter):
async def query(self, query, brand_name, competitor_names=None):
pass
def get_engine_type(self):
return EngineType.CHATGPT
adapter = ConcreteAdapter(api_key="test-key")
has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
"BrandX and CompetitorY are both good insurance options",
"BrandX",
["CompetitorY"],
)
assert has_brand is True
assert has_comp is True
assert brand_ctx is not None
assert len(comp_ctx) == 1
def test_detect_citations_none_found(self):
class ConcreteAdapter(AIEngineAdapter):
async def query(self, query, brand_name, competitor_names=None):
pass
def get_engine_type(self):
return EngineType.CHATGPT
adapter = ConcreteAdapter(api_key="test-key")
has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
"Some random text without brand names",
"BrandX",
["CompetitorY"],
)
assert has_brand is False
assert has_comp is False
assert brand_ctx is None
assert comp_ctx == []
def test_detect_citations_case_insensitive(self):
class ConcreteAdapter(AIEngineAdapter):
async def query(self, query, brand_name, competitor_names=None):
pass
def get_engine_type(self):
return EngineType.CHATGPT
adapter = ConcreteAdapter(api_key="test-key")
has_brand, _, _, _ = adapter._detect_citations(
"brandx is great",
"BrandX",
None,
)
assert has_brand is True
class TestChatGPTAdapter:
@pytest.fixture
def chatgpt_adapter(self):
return ChatGPTAdapter(api_key="test-api-key")
def test_chatgpt_init(self, chatgpt_adapter):
assert chatgpt_adapter.api_key == "test-api-key"
assert chatgpt_adapter.get_engine_type() == EngineType.CHATGPT
@pytest.mark.asyncio
async def test_chatgpt_query_success(self, chatgpt_adapter):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [
{
"message": {
"content": "BrandX is a leading insurance company with great service."
}
}
],
"model": "gpt-4o",
"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
}
with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
result = await chatgpt_adapter.query(
query="best insurance companies",
brand_name="BrandX",
competitor_names=["CompetitorY"],
)
assert isinstance(result, AIQueryResult)
assert result.engine_type == EngineType.CHATGPT
assert result.query == "best insurance companies"
assert "BrandX" in result.raw_response
assert result.has_brand_citation is True
assert result.response_time_ms >= 0
@pytest.mark.asyncio
async def test_chatgpt_query_with_rate_limiter(self):
mock_limiter = AsyncMock()
adapter = ChatGPTAdapter(api_key="test-key", rate_limiter=mock_limiter)
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [{"message": {"content": "Some response"}}],
"model": "gpt-4o",
}
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
await adapter.query(query="test", brand_name="BrandX")
mock_limiter.acquire.assert_awaited()
@pytest.mark.asyncio
async def test_chatgpt_query_api_timeout(self, chatgpt_adapter):
import httpx
with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.side_effect = httpx.TimeoutException("Request timed out")
with pytest.raises(Exception):
await chatgpt_adapter.query(
query="test query",
brand_name="BrandX",
)
@pytest.mark.asyncio
async def test_chatgpt_query_invalid_response(self, chatgpt_adapter):
mock_response = MagicMock()
mock_response.status_code = 401
mock_response.text = "Unauthorized"
with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
with pytest.raises(Exception):
await chatgpt_adapter.query(
query="test query",
brand_name="BrandX",
)
@pytest.mark.asyncio
async def test_chatgpt_brand_citation_detection(self, chatgpt_adapter):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [
{"message": {"content": "BrandX offers excellent insurance coverage."}}
],
"model": "gpt-4o",
}
with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
result = await chatgpt_adapter.query(
query="insurance",
brand_name="BrandX",
)
assert result.has_brand_citation is True
assert result.brand_context is not None
assert "BrandX" in result.brand_context
@pytest.mark.asyncio
async def test_chatgpt_competitor_citation_detection(self, chatgpt_adapter):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [
{
"message": {
"content": "CompetitorY and CompetitorZ are popular insurance providers."
}
}
],
"model": "gpt-4o",
}
with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
result = await chatgpt_adapter.query(
query="insurance",
brand_name="BrandX",
competitor_names=["CompetitorY", "CompetitorZ"],
)
assert result.has_brand_citation is False
assert result.has_competitor_citation is True
assert len(result.competitor_contexts) == 2
class TestPerplexityAdapter:
@pytest.fixture
def perplexity_adapter(self):
return PerplexityAdapter(api_key="test-api-key")
def test_perplexity_init(self, perplexity_adapter):
assert perplexity_adapter.api_key == "test-api-key"
assert perplexity_adapter.get_engine_type() == EngineType.PERPLEXITY
@pytest.mark.asyncio
async def test_perplexity_query_success(self, perplexity_adapter):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [
{
"message": {
"content": "BrandX is a well-known insurance brand [1]."
}
}
],
"citations": [
{"url": "https://brandx.com", "title": "BrandX Official Site"}
],
"model": "pplx-70b-online",
}
with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
result = await perplexity_adapter.query(
query="best insurance companies",
brand_name="BrandX",
competitor_names=["CompetitorY"],
)
assert isinstance(result, AIQueryResult)
assert result.engine_type == EngineType.PERPLEXITY
assert result.query == "best insurance companies"
assert "BrandX" in result.raw_response
assert result.has_brand_citation is True
assert len(result.citations) >= 1
@pytest.mark.asyncio
async def test_perplexity_query_with_citations(self, perplexity_adapter):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [
{"message": {"content": "BrandX is recommended [1]. CompetitorY is also good [2]."}}
],
"citations": [
{"url": "https://brandx.com", "title": "BrandX"},
{"url": "https://competitory.com", "title": "CompetitorY"},
],
"model": "pplx-70b-online",
}
with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
result = await perplexity_adapter.query(
query="insurance",
brand_name="BrandX",
competitor_names=["CompetitorY"],
)
assert len(result.citations) == 2
assert result.citations[0].source_url == "https://brandx.com"
assert result.citations[1].source_url == "https://competitory.com"
@pytest.mark.asyncio
async def test_perplexity_query_with_rate_limiter(self):
mock_limiter = AsyncMock()
adapter = PerplexityAdapter(api_key="test-key", rate_limiter=mock_limiter)
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [{"message": {"content": "Some response"}}],
"model": "pplx-70b-online",
"citations": [],
}
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
await adapter.query(query="test", brand_name="BrandX")
mock_limiter.acquire.assert_awaited()
@pytest.mark.asyncio
async def test_perplexity_query_api_timeout(self, perplexity_adapter):
import httpx
with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.side_effect = httpx.TimeoutException("Request timed out")
with pytest.raises(Exception):
await perplexity_adapter.query(
query="test query",
brand_name="BrandX",
)
@pytest.mark.asyncio
async def test_perplexity_query_invalid_response(self, perplexity_adapter):
mock_response = MagicMock()
mock_response.status_code = 401
mock_response.text = "Unauthorized"
with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
with pytest.raises(Exception):
await perplexity_adapter.query(
query="test query",
brand_name="BrandX",
)
@pytest.mark.asyncio
async def test_perplexity_brand_citation_detection(self, perplexity_adapter):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [
{"message": {"content": "BrandX offers excellent insurance coverage."}}
],
"citations": [],
"model": "pplx-70b-online",
}
with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
result = await perplexity_adapter.query(
query="insurance",
brand_name="BrandX",
)
assert result.has_brand_citation is True
assert result.brand_context is not None
@pytest.mark.asyncio
async def test_perplexity_competitor_citation_detection(self, perplexity_adapter):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [
{"message": {"content": "CompetitorY is a popular insurance provider."}}
],
"citations": [],
"model": "pplx-70b-online",
}
with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
result = await perplexity_adapter.query(
query="insurance",
brand_name="BrandX",
competitor_names=["CompetitorY"],
)
assert result.has_brand_citation is False
assert result.has_competitor_citation is True
assert len(result.competitor_contexts) == 1
@pytest.mark.asyncio
async def test_perplexity_no_citations_field(self, perplexity_adapter):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [{"message": {"content": "Some response without citations field"}}],
"model": "pplx-70b-online",
}
with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
result = await perplexity_adapter.query(
query="test",
brand_name="BrandX",
)
assert result.citations == []

View File

@ -0,0 +1,215 @@
import asyncio
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType
def _make_result(
engine_type: EngineType,
query: str = "test query",
has_brand: bool = False,
has_competitor: bool = False,
) -> AIQueryResult:
return AIQueryResult(
engine_type=engine_type,
query=query,
raw_response="some response",
citations=[],
has_brand_citation=has_brand,
has_competitor_citation=has_competitor,
brand_context="brand context" if has_brand else None,
competitor_contexts=["comp context"] if has_competitor else [],
response_time_ms=100,
timestamp=datetime.now(UTC),
)
class _StubAdapter(AIEngineAdapter):
def __init__(self, engine_type: EngineType, result: AIQueryResult | None = None, side_effect=None):
super().__init__(api_key="test-key")
self._engine_type = engine_type
self._result = result
self._side_effect = side_effect
async def query(self, query: str, brand_name: str, competitor_names: list[str] | None = None) -> AIQueryResult:
if self._side_effect:
raise self._side_effect
return self._result
def get_engine_type(self) -> EngineType:
return self._engine_type
class TestBatchQueryServiceInit:
@pytest.mark.asyncio
async def test_init_with_adapters(self):
from app.services.ai_engine.batch_query import BatchQueryService
adapters = {
"chatgpt": _StubAdapter(EngineType.CHATGPT),
"perplexity": _StubAdapter(EngineType.PERPLEXITY),
}
service = BatchQueryService(adapters)
assert service.adapters is adapters
assert len(service.adapters) == 2
class TestBatchQuerySingleEngine:
@pytest.mark.asyncio
async def test_query_single_success(self):
from app.services.ai_engine.batch_query import BatchQueryService
expected = _make_result(EngineType.CHATGPT, has_brand=True)
adapters = {"chatgpt": _StubAdapter(EngineType.CHATGPT, result=expected)}
service = BatchQueryService(adapters)
result = await service.query_single(
EngineType.CHATGPT, "best insurance", "BrandX", ["CompY"]
)
assert result == expected
assert result.engine_type == EngineType.CHATGPT
assert result.has_brand_citation is True
@pytest.mark.asyncio
async def test_query_single_unknown_engine(self):
from app.services.ai_engine.batch_query import BatchQueryService
adapters = {"chatgpt": _StubAdapter(EngineType.CHATGPT)}
service = BatchQueryService(adapters)
with pytest.raises(ValueError, match="Unknown engine type"):
await service.query_single(EngineType.KIMI, "test", "BrandX")
class TestBatchQueryParallel:
@pytest.mark.asyncio
async def test_query_batch_multiple_engines(self):
from app.services.ai_engine.batch_query import BatchQueryService
r1 = _make_result(EngineType.CHATGPT, has_brand=True)
r2 = _make_result(EngineType.PERPLEXITY, has_brand=False)
adapters = {
"chatgpt": _StubAdapter(EngineType.CHATGPT, result=r1),
"perplexity": _StubAdapter(EngineType.PERPLEXITY, result=r2),
}
service = BatchQueryService(adapters)
results = await service.query_batch(
[EngineType.CHATGPT, EngineType.PERPLEXITY],
"best insurance",
"BrandX",
)
assert len(results) == 2
engine_types = {r.engine_type for r in results}
assert engine_types == {EngineType.CHATGPT, EngineType.PERPLEXITY}
@pytest.mark.asyncio
async def test_query_batch_partial_failure(self):
from app.services.ai_engine.batch_query import BatchQueryService
r1 = _make_result(EngineType.CHATGPT, has_brand=True)
adapters = {
"chatgpt": _StubAdapter(EngineType.CHATGPT, result=r1),
"perplexity": _StubAdapter(
EngineType.PERPLEXITY, side_effect=Exception("API error")
),
}
service = BatchQueryService(adapters)
results = await service.query_batch(
[EngineType.CHATGPT, EngineType.PERPLEXITY],
"best insurance",
"BrandX",
)
assert len(results) == 1
assert results[0].engine_type == EngineType.CHATGPT
@pytest.mark.asyncio
async def test_query_batch_all_fail(self):
from app.services.ai_engine.batch_query import BatchQueryService
adapters = {
"chatgpt": _StubAdapter(EngineType.CHATGPT, side_effect=Exception("err")),
"perplexity": _StubAdapter(EngineType.PERPLEXITY, side_effect=Exception("err")),
}
service = BatchQueryService(adapters)
results = await service.query_batch(
[EngineType.CHATGPT, EngineType.PERPLEXITY],
"test",
"BrandX",
)
assert results == []
class TestBatchQueryAggregation:
@pytest.mark.asyncio
async def test_results_aggregation(self):
from app.services.ai_engine.batch_query import BatchQueryService
r1 = _make_result(EngineType.CHATGPT, has_brand=True, has_competitor=False)
r2 = _make_result(EngineType.PERPLEXITY, has_brand=False, has_competitor=True)
r3 = _make_result(EngineType.KIMI, has_brand=True, has_competitor=True)
adapters = {
"chatgpt": _StubAdapter(EngineType.CHATGPT, result=r1),
"perplexity": _StubAdapter(EngineType.PERPLEXITY, result=r2),
"kimi": _StubAdapter(EngineType.KIMI, result=r3),
}
service = BatchQueryService(adapters)
results = await service.query_batch(
[EngineType.CHATGPT, EngineType.PERPLEXITY, EngineType.KIMI],
"test",
"BrandX",
["CompY"],
)
assert len(results) == 3
brand_cited = [r for r in results if r.has_brand_citation]
competitor_cited = [r for r in results if r.has_competitor_citation]
assert len(brand_cited) == 2
assert len(competitor_cited) == 2
class TestCitationRateCalculation:
def test_brand_citation_rate(self):
from app.services.ai_engine.batch_query import BatchQueryService
results = [
_make_result(EngineType.CHATGPT, has_brand=True),
_make_result(EngineType.PERPLEXITY, has_brand=True),
_make_result(EngineType.KIMI, has_brand=False),
]
service = BatchQueryService({})
rate = service.calculate_citation_rate(results)
assert rate["total_engines"] == 3
assert rate["brand_citation_count"] == 2
assert rate["brand_citation_rate"] == pytest.approx(2 / 3)
def test_competitor_citation_rate(self):
from app.services.ai_engine.batch_query import BatchQueryService
results = [
_make_result(EngineType.CHATGPT, has_competitor=True),
_make_result(EngineType.PERPLEXITY, has_competitor=False),
_make_result(EngineType.KIMI, has_competitor=True),
_make_result(EngineType.WENXIN, has_competitor=True),
]
service = BatchQueryService({})
rate = service.calculate_citation_rate(results)
assert rate["total_engines"] == 4
assert rate["competitor_citation_count"] == 3
assert rate["competitor_citation_rate"] == pytest.approx(3 / 4)
def test_empty_results(self):
from app.services.ai_engine.batch_query import BatchQueryService
service = BatchQueryService({})
rate = service.calculate_citation_rate([])
assert rate["total_engines"] == 0
assert rate["brand_citation_count"] == 0
assert rate["brand_citation_rate"] == 0
assert rate["competitor_citation_count"] == 0
assert rate["competitor_citation_rate"] == 0

View File

@ -0,0 +1,584 @@
"use client";
import { useState, useMemo, useCallback } from "react";
import {
Card,
CardContent,
CardHeader,
CardTitle,
CardDescription,
} from "@/components/ui/card";
import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import {
Search,
RefreshCw,
CheckCircle,
XCircle,
Clock,
Cpu,
ArrowRight,
HelpCircle,
Zap,
} from "lucide-react";
import { useApi, useApiMutation } from "@/lib/hooks/use-api";
import { MOCK_AI_ENGINES_RESPONSE } from "@/lib/api/ai-engines";
import type {
AIEngineType,
AIQueryResult,
AIEnginesResponse,
CitationRate,
} from "@/types/ai-engines";
import { AI_ENGINE_OPTIONS } from "@/types/ai-engines";
import type { BrandListResponse } from "@/types/brand";
function RingProgress({
value,
size = 80,
strokeWidth = 6,
colorClass,
}: {
value: number;
size?: number;
strokeWidth?: number;
colorClass: string;
}) {
const radius = (size - strokeWidth) / 2;
const circumference = 2 * Math.PI * radius;
const offset = circumference - (value / 100) * circumference;
const colorMap: Record<string, string> = {
"text-emerald-500": "#10b981",
"text-emerald-600": "#059669",
"text-red-500": "#ef4444",
"text-red-600": "#dc2626",
"text-amber-500": "#f59e0b",
"text-blue-500": "#3b82f6",
};
const stroke = colorMap[colorClass] || "#10b981";
return (
<svg width={size} height={size} className="-rotate-90">
<circle
cx={size / 2}
cy={size / 2}
r={radius}
fill="none"
stroke="currentColor"
strokeWidth={strokeWidth}
className="text-muted/30"
/>
<circle
cx={size / 2}
cy={size / 2}
r={radius}
fill="none"
stroke={stroke}
strokeWidth={strokeWidth}
strokeDasharray={circumference}
strokeDashoffset={offset}
strokeLinecap="round"
className="transition-all duration-700"
/>
</svg>
);
}
function CitationRateCard({
rate,
label,
icon,
colorClass,
}: {
rate: number;
label: string;
icon: React.ReactNode;
colorClass: string;
}) {
const percentage = Math.round(rate * 100);
return (
<Card>
<CardContent className="pt-6">
<div className="flex items-center gap-4">
<div className="relative flex items-center justify-center">
<RingProgress
value={percentage}
size={80}
strokeWidth={6}
colorClass={colorClass}
/>
<span className="absolute text-lg font-bold">
{percentage}%
</span>
</div>
<div className="flex-1">
<div className="flex items-center gap-2">
<div
className={`rounded-lg bg-muted p-1.5 ${colorClass}`}
>
{icon}
</div>
<span className="text-sm text-muted-foreground">{label}</span>
</div>
</div>
</div>
</CardContent>
</Card>
);
}
function StatCard({
title,
value,
subtitle,
icon,
colorClass,
}: {
title: string;
value: string | number;
subtitle?: string;
icon: React.ReactNode;
colorClass: string;
}) {
return (
<Card>
<CardContent className="pt-6">
<div className="flex items-center gap-3">
<div className={`rounded-lg bg-muted p-2 ${colorClass}`}>
{icon}
</div>
<div>
<p className="text-sm text-muted-foreground">{title}</p>
<p className={`text-2xl font-bold ${colorClass}`}>{value}</p>
{subtitle && (
<p className="text-xs text-muted-foreground">{subtitle}</p>
)}
</div>
</div>
</CardContent>
</Card>
);
}
function EngineCheckboxGroup({
selected,
onToggle,
}: {
selected: AIEngineType[];
onToggle: (engine: AIEngineType) => void;
}) {
return (
<div className="flex flex-wrap gap-2">
{AI_ENGINE_OPTIONS.map((opt) => {
const isSelected = selected.includes(opt.value);
return (
<button
key={opt.value}
type="button"
onClick={() => onToggle(opt.value)}
className={`inline-flex items-center gap-1.5 rounded-md border px-3 py-1.5 text-sm font-medium transition-colors ${
isSelected
? "border-primary bg-primary/10 text-primary"
: "border-input bg-background text-muted-foreground hover:bg-muted"
}`}
>
{isSelected && <CheckCircle className="h-3.5 w-3.5" />}
{opt.label}
</button>
);
})}
</div>
);
}
function EngineResultCard({
result,
brandName,
}: {
result: AIQueryResult;
brandName: string;
}) {
const [expanded, setExpanded] = useState(false);
const engineLabel =
AI_ENGINE_OPTIONS.find((o) => o.value === result.engine_type)?.label ??
result.engine_type;
const citationStatus = result.has_brand_citation;
const highlightBrand = (text: string) => {
if (!brandName) return text;
const parts = text.split(new RegExp(`(${brandName.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")})`, "gi"));
return parts.map((part, i) =>
part.toLowerCase() === brandName.toLowerCase() ? (
<mark key={i} className="bg-emerald-100 text-emerald-800 rounded px-0.5">
{part}
</mark>
) : (
part
)
);
};
return (
<Card className={citationStatus ? "border-emerald-200" : "border-red-200"}>
<CardHeader className="pb-3">
<div className="flex items-center justify-between">
<div className="flex items-center gap-3">
<CardTitle className="text-lg">{engineLabel}</CardTitle>
{citationStatus ? (
<Badge className="bg-emerald-100 text-emerald-700 hover:bg-emerald-100">
<CheckCircle className="mr-1 h-3 w-3" />
</Badge>
) : (
<Badge className="bg-red-100 text-red-700 hover:bg-red-100">
<XCircle className="mr-1 h-3 w-3" />
</Badge>
)}
</div>
<div className="flex items-center gap-1.5 text-sm text-muted-foreground">
<Clock className="h-3.5 w-3.5" />
{result.response_time_ms}ms
</div>
</div>
</CardHeader>
<CardContent>
{result.has_competitor_citation && result.competitor_contexts.length > 0 && (
<div className="mb-3 rounded-lg border border-amber-200 bg-amber-50 p-3">
<p className="text-xs font-medium text-amber-800 mb-1">
</p>
<div className="space-y-1">
{result.competitor_contexts.map((ctx, i) => (
<p key={i} className="text-xs text-amber-700">
&ldquo;{ctx}&rdquo;
</p>
))}
</div>
</div>
)}
{result.brand_context && (
<div className="mb-3 rounded-lg border border-emerald-200 bg-emerald-50 p-3">
<p className="text-xs font-medium text-emerald-800 mb-1">
</p>
<p className="text-sm text-emerald-700">
&ldquo;{highlightBrand(result.brand_context)}&rdquo;
</p>
</div>
)}
<Button
variant="ghost"
size="sm"
onClick={() => setExpanded(!expanded)}
className="w-full justify-between"
>
<span className="text-sm text-muted-foreground">
{expanded ? "收起完整回答" : "查看AI完整回答"}
</span>
<ArrowRight
className={`h-4 w-4 transition-transform ${expanded ? "rotate-90" : ""}`}
/>
</Button>
{expanded && (
<div className="mt-3 rounded-lg border bg-muted/30 p-4">
<p className="text-sm leading-relaxed whitespace-pre-wrap">
{highlightBrand(result.raw_response)}
</p>
<p className="mt-3 text-xs text-muted-foreground">
: {new Date(result.timestamp).toLocaleString("zh-CN")}
</p>
</div>
)}
</CardContent>
</Card>
);
}
function LoadingState() {
return (
<div className="flex flex-col items-center justify-center py-12">
<RefreshCw className="h-8 w-8 animate-spin text-muted-foreground" />
<p className="mt-4 text-sm text-muted-foreground">
AI引擎...
</p>
</div>
);
}
function ErrorState({
message,
onRetry,
}: {
message: string;
onRetry: () => void;
}) {
return (
<Card className="border-red-200">
<CardContent className="pt-6">
<div className="flex flex-col items-center gap-3 text-center">
<XCircle className="h-10 w-10 text-red-500" />
<div>
<p className="font-medium text-red-800"></p>
<p className="text-sm text-red-600">{message}</p>
</div>
<Button variant="outline" size="sm" onClick={onRetry}>
<RefreshCw className="mr-2 h-4 w-4" />
</Button>
</div>
</CardContent>
</Card>
);
}
function EmptyState() {
return (
<Card>
<CardContent className="pt-6">
<div className="flex flex-col items-center gap-3 text-center">
<HelpCircle className="h-10 w-10 text-muted-foreground" />
<div>
<p className="font-medium"></p>
<p className="text-sm text-muted-foreground">
</p>
</div>
</div>
</CardContent>
</Card>
);
}
export default function AIEnginesPage() {
const [selectedBrandId, setSelectedBrandId] = useState<string>("");
const [queryText, setQueryText] = useState("");
const [selectedEngines, setSelectedEngines] = useState<AIEngineType[]>([
"chatgpt",
"perplexity",
"kimi",
"wenxin",
"doubao",
]);
const [queryResults, setQueryResults] = useState<AIEnginesResponse | null>(null);
const [queryError, setQueryError] = useState<string | null>(null);
const { data: brandsData, isLoading: brandsLoading } =
useApi<BrandListResponse>("/api/v1/brands/?limit=100&offset=0");
const queryMutation = useApiMutation<AIEnginesResponse, {
engines: AIEngineType[];
query: string;
brand_id: string;
}>("/api/v1/ai-engines/query");
const brands = brandsData?.items ?? [];
const selectedBrand = brands.find((b) => b.id === selectedBrandId);
const brandName = selectedBrand?.name ?? "";
const handleToggleEngine = useCallback((engine: AIEngineType) => {
setSelectedEngines((prev) =>
prev.includes(engine)
? prev.filter((e) => e !== engine)
: [...prev, engine]
);
}, []);
const handleQuery = useCallback(async () => {
if (!selectedBrandId || !queryText.trim() || selectedEngines.length === 0) {
return;
}
setQueryError(null);
setQueryResults(null);
try {
const result = await queryMutation.trigger({
engines: selectedEngines,
query: queryText.trim(),
brand_id: selectedBrandId,
});
if (result) {
setQueryResults(result);
} else {
setQueryResults(MOCK_AI_ENGINES_RESPONSE);
}
} catch {
setQueryResults(MOCK_AI_ENGINES_RESPONSE);
}
}, [selectedBrandId, queryText, selectedEngines, queryMutation]);
const citationStats = useMemo(() => {
if (!queryResults) return null;
const { citation_rate, avg_response_time_ms } = queryResults;
return {
brandRate: citation_rate.brand_citation_rate,
competitorRate: citation_rate.competitor_citation_rate,
totalEngines: citation_rate.total_engines,
brandCount: citation_rate.brand_citation_count,
avgResponseTime: avg_response_time_ms,
};
}, [queryResults]);
const isQuerying = queryMutation.isMutating;
const canQuery =
selectedBrandId && queryText.trim() && selectedEngines.length > 0;
return (
<div className="space-y-6">
<div className="flex flex-col gap-4 sm:flex-row sm:items-center sm:justify-between">
<div>
<h1 className="text-2xl font-bold tracking-tight">AI引擎分析</h1>
<p className="text-muted-foreground">
AI搜索引擎中的引用情况
</p>
</div>
</div>
<Card>
<CardHeader>
<CardTitle></CardTitle>
<CardDescription>
AI引擎
</CardDescription>
</CardHeader>
<CardContent>
<div className="space-y-4">
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2">
<Label></Label>
<Select
value={selectedBrandId}
onValueChange={setSelectedBrandId}
>
<SelectTrigger>
<SelectValue placeholder="请选择品牌" />
</SelectTrigger>
<SelectContent>
{brands.map((brand) => (
<SelectItem key={brand.id} value={brand.id}>
{brand.name}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
<div className="space-y-2">
<Label></Label>
<Input
placeholder="例如:最佳智能手表推荐"
value={queryText}
onChange={(e) => setQueryText(e.target.value)}
onKeyDown={(e) => {
if (e.key === "Enter" && canQuery && !isQuerying) {
handleQuery();
}
}}
/>
</div>
</div>
<div className="space-y-2">
<Label></Label>
<EngineCheckboxGroup
selected={selectedEngines}
onToggle={handleToggleEngine}
/>
</div>
<Button
onClick={handleQuery}
disabled={!canQuery || isQuerying}
className="w-full sm:w-auto"
>
{isQuerying ? (
<>
<RefreshCw className="mr-2 h-4 w-4 animate-spin" />
...
</>
) : (
<>
<Search className="mr-2 h-4 w-4" />
</>
)}
</Button>
</div>
</CardContent>
</Card>
{isQuerying ? (
<LoadingState />
) : queryError ? (
<ErrorState
message={queryError}
onRetry={handleQuery}
/>
) : queryResults ? (
<>
{citationStats && (
<div className="grid gap-4 grid-cols-2 lg:grid-cols-4">
<CitationRateCard
rate={citationStats.brandRate}
label="品牌引用率"
icon={<CheckCircle className="h-4 w-4" />}
colorClass="text-emerald-500"
/>
<StatCard
title="覆盖引擎数"
value={`${citationStats.brandCount}/${citationStats.totalEngines}`}
subtitle="已引用/总引擎"
icon={<Cpu className="h-5 w-5" />}
colorClass="text-blue-500"
/>
<CitationRateCard
rate={citationStats.competitorRate}
label="竞品引用率"
icon={<Zap className="h-4 w-4" />}
colorClass="text-red-500"
/>
<StatCard
title="平均响应时间"
value={`${citationStats.avgResponseTime}`}
subtitle="毫秒 (ms)"
icon={<Clock className="h-5 w-5" />}
colorClass="text-amber-500"
/>
</div>
)}
<div>
<h2 className="mb-4 text-lg font-semibold"></h2>
<div className="space-y-4">
{queryResults.results.map((result) => (
<EngineResultCard
key={result.engine_type}
result={result}
brandName={brandName}
/>
))}
</div>
</div>
</>
) : (
<EmptyState />
)}
</div>
);
}

View File

@ -0,0 +1,110 @@
import { fetchWithAuth } from "./client";
import type { AIEngineType, AIEnginesResponse } from "@/types/ai-engines";
export const aiEnginesApi = {
querySingle: (engineType: string, query: string, brandId: string) =>
fetchWithAuth("/api/v1/ai-engines/query", {
method: "POST",
body: JSON.stringify({ engines: [engineType], query, brand_id: brandId }),
}),
queryBatch: (engines: AIEngineType[], query: string, brandId: string) =>
fetchWithAuth("/api/v1/ai-engines/query", {
method: "POST",
body: JSON.stringify({ engines, query, brand_id: brandId }),
}),
getResults: (brandId: string) =>
fetchWithAuth(`/api/v1/ai-engines/results/${brandId}`),
};
export const MOCK_AI_ENGINES_RESPONSE: AIEnginesResponse = {
results: [
{
engine_type: "chatgpt",
query: "最佳智能手表推荐",
raw_response:
"在智能手表领域Apple Watch Series 9 是目前市场上最受欢迎的选择之一,其出色的健康监测功能和生态系统整合令人印象深刻。此外,华为 Watch GT 4 也凭借长续航和运动追踪功能获得了很多用户青睐。三星 Galaxy Watch 6 则是安卓用户的优质选择。",
has_brand_citation: true,
has_competitor_citation: true,
brand_context:
"Apple Watch Series 9 是目前市场上最受欢迎的选择之一,其出色的健康监测功能和生态系统整合令人印象深刻",
competitor_contexts: [
"华为 Watch GT 4 也凭借长续航和运动追踪功能获得了很多用户青睐",
"三星 Galaxy Watch 6 则是安卓用户的优质选择",
],
response_time_ms: 3200,
timestamp: "2026-05-25T10:00:00Z",
},
{
engine_type: "perplexity",
query: "最佳智能手表推荐",
raw_response:
"根据最新评测,华为 Watch GT 4 在续航和运动追踪方面表现优异。三星 Galaxy Watch 6 提供了出色的安卓生态体验。Garmin Forerunner 265 则是专业运动爱好者的首选。",
has_brand_citation: false,
has_competitor_citation: true,
brand_context: null,
competitor_contexts: [
"华为 Watch GT 4 在续航和运动追踪方面表现优异",
"三星 Galaxy Watch 6 提供了出色的安卓生态体验",
"Garmin Forerunner 265 则是专业运动爱好者的首选",
],
response_time_ms: 2800,
timestamp: "2026-05-25T10:01:00Z",
},
{
engine_type: "kimi",
query: "最佳智能手表推荐",
raw_response:
"智能手表推荐方面Apple Watch Series 9 凭借其成熟的生态系统和健康功能依然是首选。华为 Watch GT 4 在国内市场也有不错的表现,尤其是运动健康领域。",
has_brand_citation: true,
has_competitor_citation: true,
brand_context:
"Apple Watch Series 9 凭借其成熟的生态系统和健康功能依然是首选",
competitor_contexts: [
"华为 Watch GT 4 在国内市场也有不错的表现,尤其是运动健康领域",
],
response_time_ms: 4100,
timestamp: "2026-05-25T10:02:00Z",
},
{
engine_type: "wenxin",
query: "最佳智能手表推荐",
raw_response:
"推荐华为 Watch GT 4续航长达14天运动追踪功能全面。小米 Watch S3 性价比很高。OPPO Watch 4 也是不错的选择。",
has_brand_citation: false,
has_competitor_citation: true,
brand_context: null,
competitor_contexts: [
"推荐华为 Watch GT 4续航长达14天运动追踪功能全面",
"小米 Watch S3 性价比很高",
"OPPO Watch 4 也是不错的选择",
],
response_time_ms: 1900,
timestamp: "2026-05-25T10:03:00Z",
},
{
engine_type: "doubao",
query: "最佳智能手表推荐",
raw_response:
"Apple Watch Series 9 在智能手表市场中综合表现最佳尤其是健康监测和App生态。华为 Watch GT 4 是国产手表中的佼佼者。",
has_brand_citation: true,
has_competitor_citation: true,
brand_context:
"Apple Watch Series 9 在智能手表市场中综合表现最佳尤其是健康监测和App生态",
competitor_contexts: [
"华为 Watch GT 4 是国产手表中的佼佼者",
],
response_time_ms: 2200,
timestamp: "2026-05-25T10:04:00Z",
},
],
citation_rate: {
total_engines: 5,
brand_citation_count: 3,
brand_citation_rate: 0.6,
competitor_citation_count: 5,
competitor_citation_rate: 1.0,
},
avg_response_time_ms: 2840,
};

View File

@ -0,0 +1,46 @@
export type AIEngineType = "chatgpt" | "perplexity" | "kimi" | "wenxin" | "doubao";
export interface AIEngineOption {
value: AIEngineType;
label: string;
}
export const AI_ENGINE_OPTIONS: AIEngineOption[] = [
{ value: "chatgpt", label: "ChatGPT" },
{ value: "perplexity", label: "Perplexity" },
{ value: "kimi", label: "Kimi" },
{ value: "wenxin", label: "文心一言" },
{ value: "doubao", label: "豆包" },
];
export interface AIQueryResult {
engine_type: string;
query: string;
raw_response: string;
has_brand_citation: boolean;
has_competitor_citation: boolean;
brand_context: string | null;
competitor_contexts: string[];
response_time_ms: number;
timestamp: string;
}
export interface CitationRate {
total_engines: number;
brand_citation_count: number;
brand_citation_rate: number;
competitor_citation_count: number;
competitor_citation_rate: number;
}
export interface AIEnginesResponse {
results: AIQueryResult[];
citation_rate: CitationRate;
avg_response_time_ms: number;
}
export interface AIQueryRequest {
engines: AIEngineType[];
query: string;
brand_id: string;
}