feat: Phase1 Week1-2 - AI引擎查询分析完整实现
后端(TDD): - AI引擎适配器框架(基类+5个适配器) - ChatGPT/Perplexity/Kimi/文心一言/豆包适配器 - 批量并行查询服务(asyncio.gather) - AI引擎查询API端点(3个) - 51+14=65个测试全部通过 前端: - AI引擎分析页面(引用率/引擎结果/上下文详情) - AI引擎API客户端+类型定义 - Mock数据降级支持
This commit is contained in:
parent
65e2f3c380
commit
1ec5ea42da
|
|
@ -0,0 +1,175 @@
|
|||
import logging
|
||||
from functools import lru_cache
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.models.user import User
|
||||
from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType
|
||||
from app.services.ai_engine.batch_query import BatchQueryService
|
||||
from app.services.ai_engine.chatgpt import ChatGPTAdapter
|
||||
from app.services.ai_engine.perplexity import PerplexityAdapter
|
||||
from app.services.ai_engine.kimi import KimiAdapter
|
||||
from app.services.ai_engine.wenxin import WenxinAdapter
|
||||
from app.services.ai_engine.doubao import DoubaoAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class SingleQueryRequest(BaseModel):
|
||||
engine: str
|
||||
query: str = Field(min_length=1, max_length=500)
|
||||
brand_name: str = Field(min_length=1, max_length=200)
|
||||
competitor_names: list[str] | None = None
|
||||
|
||||
|
||||
class BatchQueryRequest(BaseModel):
|
||||
engines: list[str] = Field(min_length=1)
|
||||
query: str = Field(min_length=1, max_length=500)
|
||||
brand_name: str = Field(min_length=1, max_length=200)
|
||||
competitor_names: list[str] | None = None
|
||||
|
||||
|
||||
class QueryResultResponse(BaseModel):
|
||||
engine_type: str
|
||||
query: str
|
||||
raw_response: str
|
||||
has_brand_citation: bool
|
||||
has_competitor_citation: bool
|
||||
brand_context: str | None
|
||||
competitor_contexts: list[str]
|
||||
response_time_ms: int
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class CitationRateResponse(BaseModel):
|
||||
total_engines: int
|
||||
brand_citation_count: int
|
||||
brand_citation_rate: float
|
||||
competitor_citation_count: int
|
||||
competitor_citation_rate: float
|
||||
|
||||
|
||||
class BatchQueryResponse(BaseModel):
|
||||
results: list[QueryResultResponse]
|
||||
citation_rate: CitationRateResponse
|
||||
|
||||
|
||||
_ADAPTER_CLASSES: dict[EngineType, type[AIEngineAdapter]] = {
|
||||
EngineType.CHATGPT: ChatGPTAdapter,
|
||||
EngineType.PERPLEXITY: PerplexityAdapter,
|
||||
EngineType.KIMI: KimiAdapter,
|
||||
EngineType.WENXIN: WenxinAdapter,
|
||||
EngineType.DOUBAO: DoubaoAdapter,
|
||||
}
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _build_adapters() -> dict[str, AIEngineAdapter]:
|
||||
adapters: dict[str, AIEngineAdapter] = {}
|
||||
for engine_type, cls in _ADAPTER_CLASSES.items():
|
||||
try:
|
||||
adapters[engine_type.value] = cls()
|
||||
except Exception:
|
||||
logger.warning(f"Failed to initialize {engine_type.value} adapter")
|
||||
return adapters
|
||||
|
||||
|
||||
def get_batch_service() -> BatchQueryService:
|
||||
return BatchQueryService(_build_adapters())
|
||||
|
||||
|
||||
def _result_to_response(r: AIQueryResult) -> QueryResultResponse:
|
||||
return QueryResultResponse(
|
||||
engine_type=r.engine_type.value,
|
||||
query=r.query,
|
||||
raw_response=r.raw_response,
|
||||
has_brand_citation=r.has_brand_citation,
|
||||
has_competitor_citation=r.has_competitor_citation,
|
||||
brand_context=r.brand_context,
|
||||
competitor_contexts=r.competitor_contexts,
|
||||
response_time_ms=r.response_time_ms,
|
||||
)
|
||||
|
||||
|
||||
def _parse_engine(value: str) -> EngineType:
|
||||
try:
|
||||
return EngineType(value)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unknown engine: {value}",
|
||||
)
|
||||
|
||||
|
||||
async def _execute_batch(
|
||||
service: BatchQueryService,
|
||||
engine_types: list[EngineType],
|
||||
query: str,
|
||||
brand_name: str,
|
||||
competitor_names: list[str] | None,
|
||||
) -> BatchQueryResponse:
|
||||
results = await service.query_batch(
|
||||
engine_types, query, brand_name, competitor_names
|
||||
)
|
||||
citation_rate = service.calculate_citation_rate(results)
|
||||
return BatchQueryResponse(
|
||||
results=[_result_to_response(r) for r in results],
|
||||
citation_rate=CitationRateResponse(**citation_rate),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/query", response_model=QueryResultResponse)
|
||||
async def query_single_engine(
|
||||
request: SingleQueryRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
service = get_batch_service()
|
||||
engine_type = _parse_engine(request.engine)
|
||||
try:
|
||||
result = await service.query_single(
|
||||
engine_type, request.query, request.brand_name, request.competitor_names
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
return _result_to_response(result)
|
||||
|
||||
|
||||
@router.post("/query-batch", response_model=BatchQueryResponse)
|
||||
async def query_batch_engines(
|
||||
request: BatchQueryRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
service = get_batch_service()
|
||||
engine_types = [_parse_engine(e) for e in request.engines]
|
||||
return await _execute_batch(
|
||||
service, engine_types, request.query, request.brand_name, request.competitor_names
|
||||
)
|
||||
|
||||
|
||||
@router.get("/results", response_model=BatchQueryResponse)
|
||||
async def get_query_results(
|
||||
engines: str = Query(..., description="Comma-separated engine names"),
|
||||
query: str = Query(..., min_length=1, max_length=500),
|
||||
brand_name: str = Query(..., min_length=1, max_length=200),
|
||||
competitor_names: str | None = Query(None, description="Comma-separated competitor names"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
service = get_batch_service()
|
||||
engine_list = [e.strip() for e in engines.split(",") if e.strip()]
|
||||
engine_types = [_parse_engine(e) for e in engine_list]
|
||||
comp_names = (
|
||||
[n.strip() for n in competitor_names.split(",") if n.strip()]
|
||||
if competitor_names
|
||||
else None
|
||||
)
|
||||
return await _execute_batch(
|
||||
service, engine_types, query, brand_name, comp_names
|
||||
)
|
||||
|
|
@ -38,6 +38,7 @@ from app.api.platforms import router as platforms_router
|
|||
from app.api.platform_rules import router as platform_rules_router
|
||||
from app.api.image import router as image_router
|
||||
from app.api.knowledge_graph import router as knowledge_graph_router
|
||||
from app.api.ai_engines import router as ai_engines_router
|
||||
from app.config import settings
|
||||
from app.database import engine, Base
|
||||
from app.schemas.common import ErrorResponse, ErrorCode
|
||||
|
|
@ -165,6 +166,7 @@ app.include_router(platforms_router, prefix="/api/v1")
|
|||
app.include_router(platform_rules_router)
|
||||
app.include_router(image_router, prefix="/api/v1")
|
||||
app.include_router(knowledge_graph_router, prefix="/api/v1/knowledge-bases")
|
||||
app.include_router(ai_engines_router, prefix="/api/v1/ai-engines", tags=["AI引擎查询"])
|
||||
|
||||
|
||||
@app.get("/health", tags=["可观测性"])
|
||||
|
|
|
|||
|
|
@ -0,0 +1,20 @@
|
|||
from .base import AIEngineAdapter, AIQueryResult, CitationInfo, EngineType
|
||||
from .chatgpt import ChatGPTAdapter
|
||||
from .perplexity import PerplexityAdapter
|
||||
from .kimi import KimiAdapter
|
||||
from .wenxin import WenxinAdapter
|
||||
from .doubao import DoubaoAdapter
|
||||
from .batch_query import BatchQueryService
|
||||
|
||||
__all__ = [
|
||||
"AIEngineAdapter",
|
||||
"AIQueryResult",
|
||||
"CitationInfo",
|
||||
"EngineType",
|
||||
"ChatGPTAdapter",
|
||||
"PerplexityAdapter",
|
||||
"KimiAdapter",
|
||||
"WenxinAdapter",
|
||||
"DoubaoAdapter",
|
||||
"BatchQueryService",
|
||||
]
|
||||
|
|
@ -0,0 +1,147 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MAX_RETRIES = 3
|
||||
_RETRYABLE_STATUS = {429, 500, 502, 503}
|
||||
|
||||
|
||||
class EngineType(str, Enum):
|
||||
CHATGPT = "chatgpt"
|
||||
PERPLEXITY = "perplexity"
|
||||
KIMI = "kimi"
|
||||
WENXIN = "wenxin"
|
||||
DOUBAO = "doubao"
|
||||
DEEPSEEK = "deepseek"
|
||||
QWEN = "qwen"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CitationInfo:
|
||||
source_url: str | None
|
||||
source_title: str | None
|
||||
citation_context: str
|
||||
confidence: float
|
||||
position: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class AIQueryResult:
|
||||
engine_type: EngineType
|
||||
query: str
|
||||
raw_response: str
|
||||
citations: list[CitationInfo]
|
||||
has_brand_citation: bool
|
||||
has_competitor_citation: bool
|
||||
brand_context: str | None
|
||||
competitor_contexts: list[str]
|
||||
response_time_ms: int
|
||||
timestamp: datetime
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class AIEngineAdapter(ABC):
|
||||
def __init__(self, api_key: str, rate_limiter=None):
|
||||
self.api_key = api_key
|
||||
self.rate_limiter = rate_limiter
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
|
||||
@abstractmethod
|
||||
async def query(
|
||||
self,
|
||||
query: str,
|
||||
brand_name: str,
|
||||
competitor_names: list[str] | None = None,
|
||||
) -> AIQueryResult:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_engine_type(self) -> EngineType:
|
||||
pass
|
||||
|
||||
def _detect_citations(
|
||||
self,
|
||||
response: str,
|
||||
brand_name: str,
|
||||
competitor_names: list[str] | None,
|
||||
) -> tuple[bool, bool, str | None, list[str]]:
|
||||
has_brand = brand_name.lower() in response.lower()
|
||||
brand_context = None
|
||||
if has_brand:
|
||||
idx = response.lower().find(brand_name.lower())
|
||||
start = max(0, idx - 100)
|
||||
end = min(len(response), idx + len(brand_name) + 100)
|
||||
brand_context = response[start:end]
|
||||
|
||||
has_competitor = False
|
||||
competitor_contexts = []
|
||||
if competitor_names:
|
||||
for name in competitor_names:
|
||||
if name.lower() in response.lower():
|
||||
has_competitor = True
|
||||
idx = response.lower().find(name.lower())
|
||||
start = max(0, idx - 100)
|
||||
end = min(len(response), idx + len(name) + 100)
|
||||
competitor_contexts.append(response[start:end])
|
||||
|
||||
return has_brand, has_competitor, brand_context, competitor_contexts
|
||||
|
||||
async def _request_with_retry(self, payload: dict) -> dict:
|
||||
if self.rate_limiter:
|
||||
await self.rate_limiter.acquire()
|
||||
|
||||
engine_name = self.get_engine_type().value
|
||||
last_error: Exception | None = None
|
||||
|
||||
for attempt in range(_MAX_RETRIES):
|
||||
try:
|
||||
response = await self._client.post(self._endpoint, json=payload)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
if response.status_code in _RETRYABLE_STATUS:
|
||||
retry_after = response.headers.get("retry-after")
|
||||
wait = float(retry_after) if retry_after else 2**attempt
|
||||
logger.warning(
|
||||
f"[{engine_name}] HTTP {response.status_code}, "
|
||||
f"retry {attempt + 1}/{_MAX_RETRIES} in {wait:.1f}s"
|
||||
)
|
||||
last_error = Exception(
|
||||
f"HTTP {response.status_code}: {response.text[:300]}"
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
continue
|
||||
|
||||
raise Exception(
|
||||
f"HTTP {response.status_code}: {response.text[:300]}"
|
||||
)
|
||||
|
||||
except httpx.TransportError as exc:
|
||||
logger.warning(
|
||||
f"[{engine_name}] Transport error: {exc}, "
|
||||
f"retry {attempt + 1}/{_MAX_RETRIES}"
|
||||
)
|
||||
last_error = Exception(f"Network error: {exc}")
|
||||
await asyncio.sleep(2**attempt)
|
||||
continue
|
||||
|
||||
raise last_error or Exception("Max retries exceeded")
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
|
||||
async def __aenter__(self) -> "AIEngineAdapter":
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc) -> None:
|
||||
await self.close()
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
from .base import AIEngineAdapter, AIQueryResult, EngineType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BatchQueryService:
|
||||
def __init__(self, adapters: dict[str, AIEngineAdapter]):
|
||||
self.adapters = adapters
|
||||
|
||||
async def query_single(
|
||||
self,
|
||||
engine_type: EngineType,
|
||||
query: str,
|
||||
brand_name: str,
|
||||
competitor_names: list[str] | None = None,
|
||||
) -> AIQueryResult:
|
||||
adapter = self.adapters.get(engine_type.value)
|
||||
if not adapter:
|
||||
raise ValueError(f"Unknown engine type: {engine_type}")
|
||||
return await adapter.query(query, brand_name, competitor_names)
|
||||
|
||||
async def query_batch(
|
||||
self,
|
||||
engines: list[EngineType],
|
||||
query: str,
|
||||
brand_name: str,
|
||||
competitor_names: list[str] | None = None,
|
||||
) -> list[AIQueryResult]:
|
||||
tasks = [
|
||||
self.query_single(engine, query, brand_name, competitor_names)
|
||||
for engine in engines
|
||||
]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
successful: list[AIQueryResult] = []
|
||||
for r in results:
|
||||
if isinstance(r, AIQueryResult):
|
||||
successful.append(r)
|
||||
elif isinstance(r, Exception):
|
||||
logger.warning(f"Engine query failed: {r}")
|
||||
return successful
|
||||
|
||||
def calculate_citation_rate(self, results: list[AIQueryResult]) -> dict:
|
||||
total = len(results)
|
||||
brand_cited = sum(1 for r in results if r.has_brand_citation)
|
||||
competitor_cited = sum(1 for r in results if r.has_competitor_citation)
|
||||
return {
|
||||
"total_engines": total,
|
||||
"brand_citation_count": brand_cited,
|
||||
"brand_citation_rate": brand_cited / total if total > 0 else 0,
|
||||
"competitor_citation_count": competitor_cited,
|
||||
"competitor_citation_rate": competitor_cited / total if total > 0 else 0,
|
||||
}
|
||||
|
|
@ -0,0 +1,88 @@
|
|||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import httpx
|
||||
|
||||
from .base import AIEngineAdapter, AIQueryResult, EngineType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_MODEL = "gpt-4o"
|
||||
_DEFAULT_BASE_URL = "https://api.openai.com/v1"
|
||||
|
||||
|
||||
class ChatGPTAdapter(AIEngineAdapter):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
model: str | None = None,
|
||||
base_url: str | None = None,
|
||||
rate_limiter=None,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key or os.getenv("OPENAI_API_KEY", ""),
|
||||
rate_limiter=rate_limiter,
|
||||
)
|
||||
self._model = model or os.getenv("OPENAI_MODEL", _DEFAULT_MODEL)
|
||||
self._base_url = (
|
||||
base_url or os.getenv("OPENAI_BASE_URL", _DEFAULT_BASE_URL)
|
||||
).rstrip("/")
|
||||
self._endpoint = f"{self._base_url}/chat/completions"
|
||||
self._client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0),
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
def get_engine_type(self) -> EngineType:
|
||||
return EngineType.CHATGPT
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query: str,
|
||||
brand_name: str,
|
||||
competitor_names: list[str] | None = None,
|
||||
) -> AIQueryResult:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": query},
|
||||
]
|
||||
payload = {
|
||||
"model": self._model,
|
||||
"messages": messages,
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 4096,
|
||||
}
|
||||
|
||||
data = await self._request_with_retry(payload)
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
|
||||
elapsed_ms = int((time.perf_counter() - start_time) * 1000)
|
||||
has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
|
||||
content, brand_name, competitor_names
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[chatgpt] query='{query[:50]}...' brand={has_brand} "
|
||||
f"competitor={has_comp} time={elapsed_ms}ms"
|
||||
)
|
||||
|
||||
return AIQueryResult(
|
||||
engine_type=self.get_engine_type(),
|
||||
query=query,
|
||||
raw_response=content,
|
||||
citations=[],
|
||||
has_brand_citation=has_brand,
|
||||
has_competitor_citation=has_comp,
|
||||
brand_context=brand_ctx,
|
||||
competitor_contexts=comp_ctx,
|
||||
response_time_ms=elapsed_ms,
|
||||
timestamp=datetime.now(UTC),
|
||||
metadata={"model": data.get("model", self._model)},
|
||||
)
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import httpx
|
||||
|
||||
from .base import AIEngineAdapter, AIQueryResult, EngineType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_MODEL = "doubao-pro-4k"
|
||||
_DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
|
||||
|
||||
|
||||
class DoubaoAdapter(AIEngineAdapter):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
endpoint_id: str | None = None,
|
||||
rate_limiter=None,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key or os.getenv("DOUBAO_API_KEY", ""),
|
||||
rate_limiter=rate_limiter,
|
||||
)
|
||||
self._endpoint_id = endpoint_id or os.getenv("DOUBAO_ENDPOINT_ID", "")
|
||||
self._base_url = _DEFAULT_BASE_URL.rstrip("/")
|
||||
self._endpoint = f"{self._base_url}/chat/completions"
|
||||
self._client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0),
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
def get_engine_type(self) -> EngineType:
|
||||
return EngineType.DOUBAO
|
||||
|
||||
def _get_model_id(self) -> str:
|
||||
if self._endpoint_id and self._endpoint_id.strip():
|
||||
if not self._endpoint_id.startswith("ep-"):
|
||||
return f"ep-{self._endpoint_id}"
|
||||
return self._endpoint_id
|
||||
return _DEFAULT_MODEL
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query: str,
|
||||
brand_name: str,
|
||||
competitor_names: list[str] | None = None,
|
||||
) -> AIQueryResult:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
model_id = self._get_model_id()
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一个专业的AI搜索助手。请基于你的知识,详细回答用户的问题。如果引用了外部来源,请在回答中标注来源URL或出处名称。",
|
||||
},
|
||||
{"role": "user", "content": query},
|
||||
]
|
||||
payload = {
|
||||
"model": model_id,
|
||||
"messages": messages,
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 2000,
|
||||
}
|
||||
|
||||
data = await self._request_with_retry(payload)
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
|
||||
elapsed_ms = int((time.perf_counter() - start_time) * 1000)
|
||||
has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
|
||||
content, brand_name, competitor_names
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[doubao] query='{query[:50]}...' brand={has_brand} "
|
||||
f"competitor={has_comp} time={elapsed_ms}ms"
|
||||
)
|
||||
|
||||
return AIQueryResult(
|
||||
engine_type=self.get_engine_type(),
|
||||
query=query,
|
||||
raw_response=content,
|
||||
citations=[],
|
||||
has_brand_citation=has_brand,
|
||||
has_competitor_citation=has_comp,
|
||||
brand_context=brand_ctx,
|
||||
competitor_contexts=comp_ctx,
|
||||
response_time_ms=elapsed_ms,
|
||||
timestamp=datetime.now(UTC),
|
||||
metadata={"model": data.get("model", model_id), "usage": data.get("usage")},
|
||||
)
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import httpx
|
||||
|
||||
from .base import AIEngineAdapter, AIQueryResult, EngineType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_MODEL = "moonshot-v1-8k"
|
||||
_DEFAULT_BASE_URL = "https://api.moonshot.cn/v1"
|
||||
|
||||
|
||||
class KimiAdapter(AIEngineAdapter):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
model: str | None = None,
|
||||
base_url: str | None = None,
|
||||
rate_limiter=None,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key or os.getenv("MOONSHOT_API_KEY", ""),
|
||||
rate_limiter=rate_limiter,
|
||||
)
|
||||
self._model = model or _DEFAULT_MODEL
|
||||
self._base_url = (
|
||||
base_url or _DEFAULT_BASE_URL
|
||||
).rstrip("/")
|
||||
self._endpoint = f"{self._base_url}/chat/completions"
|
||||
self._client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0),
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
def get_engine_type(self) -> EngineType:
|
||||
return EngineType.KIMI
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query: str,
|
||||
brand_name: str,
|
||||
competitor_names: list[str] | None = None,
|
||||
) -> AIQueryResult:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一个专业的AI搜索助手。请基于你的知识,详细回答用户的问题。如果引用了外部来源,请在回答中标注来源URL或出处名称。",
|
||||
},
|
||||
{"role": "user", "content": query},
|
||||
]
|
||||
payload = {
|
||||
"model": self._model,
|
||||
"messages": messages,
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 2000,
|
||||
}
|
||||
|
||||
data = await self._request_with_retry(payload)
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
|
||||
elapsed_ms = int((time.perf_counter() - start_time) * 1000)
|
||||
has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
|
||||
content, brand_name, competitor_names
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[kimi] query='{query[:50]}...' brand={has_brand} "
|
||||
f"competitor={has_comp} time={elapsed_ms}ms"
|
||||
)
|
||||
|
||||
return AIQueryResult(
|
||||
engine_type=self.get_engine_type(),
|
||||
query=query,
|
||||
raw_response=content,
|
||||
citations=[],
|
||||
has_brand_citation=has_brand,
|
||||
has_competitor_citation=has_comp,
|
||||
brand_context=brand_ctx,
|
||||
competitor_contexts=comp_ctx,
|
||||
response_time_ms=elapsed_ms,
|
||||
timestamp=datetime.now(UTC),
|
||||
metadata={"model": data.get("model", self._model), "usage": data.get("usage")},
|
||||
)
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import httpx
|
||||
|
||||
from .base import AIEngineAdapter, AIQueryResult, CitationInfo, EngineType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_MODEL = "pplx-70b-online"
|
||||
_DEFAULT_BASE_URL = "https://api.perplexity.ai"
|
||||
|
||||
|
||||
class PerplexityAdapter(AIEngineAdapter):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
model: str | None = None,
|
||||
base_url: str | None = None,
|
||||
rate_limiter=None,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key or os.getenv("PERPLEXITY_API_KEY", ""),
|
||||
rate_limiter=rate_limiter,
|
||||
)
|
||||
self._model = model or os.getenv("PERPLEXITY_MODEL", _DEFAULT_MODEL)
|
||||
self._base_url = (
|
||||
base_url or os.getenv("PERPLEXITY_BASE_URL", _DEFAULT_BASE_URL)
|
||||
).rstrip("/")
|
||||
self._endpoint = f"{self._base_url}/chat/completions"
|
||||
self._client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0),
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
def get_engine_type(self) -> EngineType:
|
||||
return EngineType.PERPLEXITY
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query: str,
|
||||
brand_name: str,
|
||||
competitor_names: list[str] | None = None,
|
||||
) -> AIQueryResult:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful research assistant."},
|
||||
{"role": "user", "content": query},
|
||||
]
|
||||
payload = {
|
||||
"model": self._model,
|
||||
"messages": messages,
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 4096,
|
||||
}
|
||||
|
||||
data = await self._request_with_retry(payload)
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
citations = self._extract_citations(data)
|
||||
|
||||
elapsed_ms = int((time.perf_counter() - start_time) * 1000)
|
||||
has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
|
||||
content, brand_name, competitor_names
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[perplexity] query='{query[:50]}...' brand={has_brand} "
|
||||
f"competitor={has_comp} citations={len(citations)} time={elapsed_ms}ms"
|
||||
)
|
||||
|
||||
return AIQueryResult(
|
||||
engine_type=self.get_engine_type(),
|
||||
query=query,
|
||||
raw_response=content,
|
||||
citations=citations,
|
||||
has_brand_citation=has_brand,
|
||||
has_competitor_citation=has_comp,
|
||||
brand_context=brand_ctx,
|
||||
competitor_contexts=comp_ctx,
|
||||
response_time_ms=elapsed_ms,
|
||||
timestamp=datetime.now(UTC),
|
||||
metadata={"model": data.get("model", self._model)},
|
||||
)
|
||||
|
||||
def _extract_citations(self, data: dict) -> list[CitationInfo]:
|
||||
raw_citations = data.get("citations", [])
|
||||
if not raw_citations:
|
||||
return []
|
||||
|
||||
citations = []
|
||||
for idx, cit in enumerate(raw_citations):
|
||||
citations.append(
|
||||
CitationInfo(
|
||||
source_url=cit.get("url"),
|
||||
source_title=cit.get("title"),
|
||||
citation_context="",
|
||||
confidence=1.0,
|
||||
position=idx + 1,
|
||||
)
|
||||
)
|
||||
return citations
|
||||
|
|
@ -0,0 +1,145 @@
|
|||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import httpx
|
||||
|
||||
from .base import AIEngineAdapter, AIQueryResult, EngineType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_MODEL = "completions_pro"
|
||||
_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
|
||||
_CHAT_URL_TEMPLATE = (
|
||||
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model}"
|
||||
"?access_token={token}"
|
||||
)
|
||||
|
||||
_cached_token: str | None = None
|
||||
_token_expires_at: float = 0.0
|
||||
|
||||
|
||||
class WenxinAdapter(AIEngineAdapter):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
secret_key: str | None = None,
|
||||
rate_limiter=None,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key or os.getenv("BAIDU_QIANFAN_API_KEY", ""),
|
||||
rate_limiter=rate_limiter,
|
||||
)
|
||||
self.secret_key = secret_key or os.getenv("BAIDU_QIANFAN_SECRET_KEY", "")
|
||||
self._model = _DEFAULT_MODEL
|
||||
self._client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0),
|
||||
)
|
||||
|
||||
def get_engine_type(self) -> EngineType:
|
||||
return EngineType.WENXIN
|
||||
|
||||
async def _get_access_token(self) -> str:
|
||||
global _cached_token, _token_expires_at
|
||||
|
||||
now = time.monotonic()
|
||||
if _cached_token and now < _token_expires_at:
|
||||
return _cached_token
|
||||
|
||||
response = await self._client.post(
|
||||
_TOKEN_URL,
|
||||
params={
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": self.api_key,
|
||||
"client_secret": self.secret_key,
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f"文心一言获取 access_token 失败: {response.status_code} {response.text[:300]}"
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
token = data.get("access_token")
|
||||
if not token:
|
||||
error_desc = data.get("error_description", "未知错误")
|
||||
raise RuntimeError(f"文心一言获取 access_token 失败: {error_desc}")
|
||||
|
||||
expires_in = data.get("expires_in", 2592000)
|
||||
_cached_token = token
|
||||
_token_expires_at = now + expires_in - 300
|
||||
|
||||
logger.info("[wenxin] access_token 获取成功")
|
||||
return token
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query: str,
|
||||
brand_name: str,
|
||||
competitor_names: list[str] | None = None,
|
||||
) -> AIQueryResult:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
access_token = await self._get_access_token()
|
||||
chat_url = _CHAT_URL_TEMPLATE.format(
|
||||
model=self._model,
|
||||
token=access_token,
|
||||
)
|
||||
|
||||
payload = {
|
||||
"messages": [{"role": "user", "content": query}],
|
||||
"system": "你是一个专业的AI搜索助手。请基于你的知识,详细回答用户的问题。如果引用了外部来源,请在回答中标注来源URL或出处名称。",
|
||||
"temperature": 0.7,
|
||||
"max_output_tokens": 2000,
|
||||
}
|
||||
|
||||
if self.rate_limiter:
|
||||
await self.rate_limiter.acquire()
|
||||
|
||||
response = await self._client.post(chat_url, json=payload)
|
||||
|
||||
if response.status_code == 429:
|
||||
raise RuntimeError("文心一言 API 限流")
|
||||
|
||||
if response.status_code != 200:
|
||||
error_body = response.text[:500]
|
||||
raise RuntimeError(
|
||||
f"文心一言 API 返回错误 {response.status_code}: {error_body}"
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
error_code = data.get("error_code")
|
||||
if error_code:
|
||||
error_msg = data.get("error_msg", "未知错误")
|
||||
raise RuntimeError(f"文心一言 API 错误 {error_code}: {error_msg}")
|
||||
|
||||
content = data.get("result", "")
|
||||
if not content:
|
||||
raise RuntimeError("文心一言 API 返回空内容")
|
||||
|
||||
elapsed_ms = int((time.perf_counter() - start_time) * 1000)
|
||||
has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
|
||||
content, brand_name, competitor_names
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[wenxin] query='{query[:50]}...' brand={has_brand} "
|
||||
f"competitor={has_comp} time={elapsed_ms}ms"
|
||||
)
|
||||
|
||||
return AIQueryResult(
|
||||
engine_type=self.get_engine_type(),
|
||||
query=query,
|
||||
raw_response=content,
|
||||
citations=[],
|
||||
has_brand_citation=has_brand,
|
||||
has_competitor_citation=has_comp,
|
||||
brand_context=brand_ctx,
|
||||
competitor_contexts=comp_ctx,
|
||||
response_time_ms=elapsed_ms,
|
||||
timestamp=datetime.now(UTC),
|
||||
metadata={"model": self._model, "usage": data.get("usage")},
|
||||
)
|
||||
|
|
@ -0,0 +1,216 @@
|
|||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.database import Base
|
||||
from app.main import app
|
||||
from app.models.user import User
|
||||
from app.api.deps import get_current_user, get_db
|
||||
from app.services.auth import hash_password
|
||||
from app.services.ai_engine.base import AIQueryResult, EngineType
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_engine():
|
||||
engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_session(async_engine):
|
||||
async_session_maker = async_sessionmaker(
|
||||
async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
autocommit=False,
|
||||
)
|
||||
async with async_session_maker() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user(async_session):
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="test@example.com",
|
||||
password_hash=hash_password("Test@123456"),
|
||||
name="Test User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_client(async_session, test_user):
|
||||
async def override_get_db():
|
||||
yield async_session
|
||||
|
||||
async def override_get_current_user():
|
||||
return test_user
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_get_current_user
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
yield client
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _make_result(
|
||||
engine_type: EngineType,
|
||||
query: str = "best insurance",
|
||||
has_brand: bool = False,
|
||||
has_competitor: bool = False,
|
||||
) -> AIQueryResult:
|
||||
return AIQueryResult(
|
||||
engine_type=engine_type,
|
||||
query=query,
|
||||
raw_response="BrandX is a great insurance company",
|
||||
citations=[],
|
||||
has_brand_citation=has_brand,
|
||||
has_competitor_citation=has_competitor,
|
||||
brand_context="BrandX is great" if has_brand else None,
|
||||
competitor_contexts=["CompY is ok"] if has_competitor else [],
|
||||
response_time_ms=150,
|
||||
timestamp=datetime.now(UTC),
|
||||
)
|
||||
|
||||
|
||||
class TestSingleQueryEndpoint:
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_single_engine(self, async_client):
|
||||
mock_result = _make_result(EngineType.CHATGPT, has_brand=True)
|
||||
with patch("app.api.ai_engines.get_batch_service") as mock_get_service:
|
||||
mock_service = AsyncMock()
|
||||
mock_service.query_single.return_value = mock_result
|
||||
mock_get_service.return_value = mock_service
|
||||
|
||||
response = await async_client.post(
|
||||
"/api/v1/ai-engines/query",
|
||||
json={
|
||||
"engine": "chatgpt",
|
||||
"query": "best insurance",
|
||||
"brand_name": "BrandX",
|
||||
"competitor_names": ["CompY"],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["engine_type"] == "chatgpt"
|
||||
assert data["has_brand_citation"] is True
|
||||
assert data["query"] == "best insurance"
|
||||
|
||||
|
||||
class TestBatchQueryEndpoint:
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_batch_parallel(self, async_client):
|
||||
r1 = _make_result(EngineType.CHATGPT, has_brand=True)
|
||||
r2 = _make_result(EngineType.PERPLEXITY, has_brand=False, has_competitor=True)
|
||||
with patch("app.api.ai_engines.get_batch_service") as mock_get_service:
|
||||
mock_service = AsyncMock()
|
||||
mock_service.query_batch.return_value = [r1, r2]
|
||||
mock_service.calculate_citation_rate = MagicMock(return_value={
|
||||
"total_engines": 2,
|
||||
"brand_citation_count": 1,
|
||||
"brand_citation_rate": 0.5,
|
||||
"competitor_citation_count": 1,
|
||||
"competitor_citation_rate": 0.5,
|
||||
})
|
||||
mock_get_service.return_value = mock_service
|
||||
|
||||
response = await async_client.post(
|
||||
"/api/v1/ai-engines/query-batch",
|
||||
json={
|
||||
"engines": ["chatgpt", "perplexity"],
|
||||
"query": "best insurance",
|
||||
"brand_name": "BrandX",
|
||||
"competitor_names": ["CompY"],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "results" in data
|
||||
assert "citation_rate" in data
|
||||
assert len(data["results"]) == 2
|
||||
assert data["citation_rate"]["brand_citation_rate"] == 0.5
|
||||
|
||||
|
||||
class TestGetResultsEndpoint:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_results(self, async_client):
|
||||
r1 = _make_result(EngineType.CHATGPT, has_brand=True)
|
||||
r2 = _make_result(EngineType.KIMI, has_brand=False)
|
||||
with patch("app.api.ai_engines.get_batch_service") as mock_get_service:
|
||||
mock_service = AsyncMock()
|
||||
mock_service.query_batch.return_value = [r1, r2]
|
||||
mock_service.calculate_citation_rate = MagicMock(return_value={
|
||||
"total_engines": 2,
|
||||
"brand_citation_count": 1,
|
||||
"brand_citation_rate": 0.5,
|
||||
"competitor_citation_count": 0,
|
||||
"competitor_citation_rate": 0.0,
|
||||
})
|
||||
mock_get_service.return_value = mock_service
|
||||
|
||||
response = await async_client.get(
|
||||
"/api/v1/ai-engines/results",
|
||||
params={
|
||||
"engines": "chatgpt,kimi",
|
||||
"query": "best insurance",
|
||||
"brand_name": "BrandX",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "results" in data
|
||||
assert "citation_rate" in data
|
||||
|
||||
|
||||
class TestUnauthorizedAccess:
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthorized_returns_401(self, async_session):
|
||||
async def override_get_db():
|
||||
yield async_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
headers = {"Authorization": "Bearer invalid_token"}
|
||||
response = await client.post(
|
||||
"/api/v1/ai-engines/query",
|
||||
json={
|
||||
"engine": "chatgpt",
|
||||
"query": "test",
|
||||
"brand_name": "BrandX",
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
|
@ -0,0 +1,270 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock, patch, MagicMock
|
||||
|
||||
from app.services.ai_engine.base import (
|
||||
AIEngineAdapter,
|
||||
AIQueryResult,
|
||||
CitationInfo,
|
||||
EngineType,
|
||||
)
|
||||
from app.services.ai_engine.kimi import KimiAdapter
|
||||
from app.services.ai_engine.wenxin import WenxinAdapter
|
||||
from app.services.ai_engine.doubao import DoubaoAdapter
|
||||
|
||||
|
||||
def _make_mock_response(status_code=200, json_data=None, text="", headers=None):
|
||||
mock_resp = Mock()
|
||||
mock_resp.status_code = status_code
|
||||
mock_resp.json.return_value = json_data or {}
|
||||
mock_resp.text = text
|
||||
mock_resp.headers = headers or {}
|
||||
return mock_resp
|
||||
|
||||
|
||||
class TestKimiAdapter:
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialization(self):
|
||||
adapter = KimiAdapter(api_key="test-key")
|
||||
assert adapter.api_key == "test-key"
|
||||
assert adapter.get_engine_type() == EngineType.KIMI
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_returns_ai_query_result(self):
|
||||
adapter = KimiAdapter(api_key="test-key")
|
||||
mock_response_data = {
|
||||
"choices": [{"message": {"content": "华为是全球领先的ICT公司"}}],
|
||||
"usage": {"total_tokens": 100},
|
||||
"model": "moonshot-v1-8k",
|
||||
}
|
||||
|
||||
with patch.object(adapter, "_request_with_retry", return_value=mock_response_data):
|
||||
result = await adapter.query("华为公司", brand_name="华为")
|
||||
|
||||
assert isinstance(result, AIQueryResult)
|
||||
assert result.engine_type == EngineType.KIMI
|
||||
assert "华为" in result.raw_response
|
||||
assert result.has_brand_citation is True
|
||||
assert result.metadata.get("model") == "moonshot-v1-8k"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_error_handling(self):
|
||||
adapter = KimiAdapter(api_key="test-key")
|
||||
|
||||
with patch.object(
|
||||
adapter,
|
||||
"_request_with_retry",
|
||||
side_effect=Exception("HTTP 500: Internal Server Error"),
|
||||
):
|
||||
with pytest.raises(Exception, match="HTTP 500"):
|
||||
await adapter.query("测试问题", brand_name="华为")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_handling(self):
|
||||
adapter = KimiAdapter(api_key="test-key")
|
||||
|
||||
with patch.object(
|
||||
adapter,
|
||||
"_request_with_retry",
|
||||
side_effect=Exception("HTTP 429: rate limited"),
|
||||
):
|
||||
with pytest.raises(Exception, match="429"):
|
||||
await adapter.query("测试问题", brand_name="华为")
|
||||
|
||||
|
||||
class TestWenxinAdapter:
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialization(self):
|
||||
adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret")
|
||||
assert adapter.api_key == "test-key"
|
||||
assert adapter.secret_key == "test-secret"
|
||||
assert adapter.get_engine_type() == EngineType.WENXIN
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_returns_ai_query_result(self):
|
||||
import app.services.ai_engine.wenxin as wenxin_mod
|
||||
wenxin_mod._cached_token = None
|
||||
wenxin_mod._token_expires_at = 0.0
|
||||
|
||||
adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret")
|
||||
mock_token_data = {"access_token": "test-access-token", "expires_in": 2592000}
|
||||
mock_chat_data = {"result": "华为是一家全球领先的科技公司", "usage": {"total_tokens": 100}}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
token_response = _make_mock_response(200, mock_token_data)
|
||||
chat_response = _make_mock_response(200, mock_chat_data)
|
||||
mock_client.post.side_effect = [token_response, chat_response]
|
||||
|
||||
with patch.object(adapter, "_client", mock_client):
|
||||
result = await adapter.query("华为公司", brand_name="华为")
|
||||
|
||||
assert isinstance(result, AIQueryResult)
|
||||
assert result.engine_type == EngineType.WENXIN
|
||||
assert "华为" in result.raw_response
|
||||
assert result.has_brand_citation is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_error_handling(self):
|
||||
import app.services.ai_engine.wenxin as wenxin_mod
|
||||
wenxin_mod._cached_token = None
|
||||
wenxin_mod._token_expires_at = 0.0
|
||||
|
||||
adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret")
|
||||
|
||||
mock_client = AsyncMock()
|
||||
token_response = _make_mock_response(200, {"access_token": "test-token", "expires_in": 2592000})
|
||||
error_response = _make_mock_response(500, text="Internal Server Error")
|
||||
mock_client.post.side_effect = [token_response, error_response]
|
||||
|
||||
with patch.object(adapter, "_client", mock_client):
|
||||
with pytest.raises(RuntimeError, match="文心"):
|
||||
await adapter.query("测试问题", brand_name="华为")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_handling(self):
|
||||
import app.services.ai_engine.wenxin as wenxin_mod
|
||||
wenxin_mod._cached_token = None
|
||||
wenxin_mod._token_expires_at = 0.0
|
||||
|
||||
adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret")
|
||||
|
||||
mock_client = AsyncMock()
|
||||
token_response = _make_mock_response(200, {"access_token": "test-token", "expires_in": 2592000})
|
||||
rate_limit_response = _make_mock_response(429, headers={"Retry-After": "1"})
|
||||
mock_client.post.side_effect = [token_response, rate_limit_response]
|
||||
|
||||
with patch.object(adapter, "_client", mock_client):
|
||||
with pytest.raises(RuntimeError, match="限流"):
|
||||
await adapter.query("测试问题", brand_name="华为")
|
||||
|
||||
|
||||
class TestDoubaoAdapter:
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialization(self):
|
||||
adapter = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test")
|
||||
assert adapter.api_key == "test-key"
|
||||
assert adapter._endpoint_id == "ep-test"
|
||||
assert adapter.get_engine_type() == EngineType.DOUBAO
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_returns_ai_query_result(self):
|
||||
adapter = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test")
|
||||
mock_response_data = {
|
||||
"choices": [{"message": {"content": "华为是全球知名企业"}}],
|
||||
"model": "ep-test",
|
||||
}
|
||||
|
||||
with patch.object(adapter, "_request_with_retry", return_value=mock_response_data):
|
||||
result = await adapter.query("华为公司", brand_name="华为")
|
||||
|
||||
assert isinstance(result, AIQueryResult)
|
||||
assert result.engine_type == EngineType.DOUBAO
|
||||
assert "华为" in result.raw_response
|
||||
assert result.has_brand_citation is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_error_handling(self):
|
||||
adapter = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test")
|
||||
|
||||
with patch.object(
|
||||
adapter,
|
||||
"_request_with_retry",
|
||||
side_effect=Exception("HTTP 500: Internal Server Error"),
|
||||
):
|
||||
with pytest.raises(Exception, match="HTTP 500"):
|
||||
await adapter.query("测试问题", brand_name="华为")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_handling(self):
|
||||
adapter = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test")
|
||||
|
||||
with patch.object(
|
||||
adapter,
|
||||
"_request_with_retry",
|
||||
side_effect=Exception("HTTP 429: rate limited"),
|
||||
):
|
||||
with pytest.raises(Exception, match="429"):
|
||||
await adapter.query("测试问题", brand_name="华为")
|
||||
|
||||
|
||||
class TestEngineType:
|
||||
def test_kimi_engine_type(self):
|
||||
adapter = KimiAdapter(api_key="test-key")
|
||||
assert adapter.get_engine_type() == EngineType.KIMI
|
||||
assert adapter.get_engine_type().value == "kimi"
|
||||
|
||||
def test_wenxin_engine_type(self):
|
||||
adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret")
|
||||
assert adapter.get_engine_type() == EngineType.WENXIN
|
||||
assert adapter.get_engine_type().value == "wenxin"
|
||||
|
||||
def test_doubao_engine_type(self):
|
||||
adapter = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test")
|
||||
assert adapter.get_engine_type() == EngineType.DOUBAO
|
||||
assert adapter.get_engine_type().value == "doubao"
|
||||
|
||||
|
||||
class TestChineseCitationDetection:
|
||||
def test_brand_name_detection_chinese(self):
|
||||
adapter = KimiAdapter(api_key="test-key")
|
||||
has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
|
||||
"华为是全球领先的ICT基础设施和智能终端提供商",
|
||||
brand_name="华为",
|
||||
competitor_names=None,
|
||||
)
|
||||
assert has_brand is True
|
||||
assert brand_ctx is not None
|
||||
assert "华为" in brand_ctx
|
||||
|
||||
def test_brand_name_detection_case_insensitive(self):
|
||||
adapter = KimiAdapter(api_key="test-key")
|
||||
has_brand, _, _, _ = adapter._detect_citations(
|
||||
"Apple is a great company and apple makes phones",
|
||||
brand_name="apple",
|
||||
competitor_names=None,
|
||||
)
|
||||
assert has_brand is True
|
||||
|
||||
def test_competitor_name_detection(self):
|
||||
adapter = KimiAdapter(api_key="test-key")
|
||||
has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
|
||||
"华为和小米都是中国知名手机品牌",
|
||||
brand_name="华为",
|
||||
competitor_names=["小米"],
|
||||
)
|
||||
assert has_brand is True
|
||||
assert has_comp is True
|
||||
assert brand_ctx is not None
|
||||
assert len(comp_ctx) > 0
|
||||
|
||||
def test_no_citations_when_no_match(self):
|
||||
adapter = KimiAdapter(api_key="test-key")
|
||||
has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
|
||||
"今天天气很好",
|
||||
brand_name="华为",
|
||||
competitor_names=["小米"],
|
||||
)
|
||||
assert has_brand is False
|
||||
assert has_comp is False
|
||||
assert brand_ctx is None
|
||||
assert len(comp_ctx) == 0
|
||||
|
||||
|
||||
class TestAdapterInheritance:
|
||||
def test_all_adapters_inherit_base(self):
|
||||
assert issubclass(KimiAdapter, AIEngineAdapter)
|
||||
assert issubclass(WenxinAdapter, AIEngineAdapter)
|
||||
assert issubclass(DoubaoAdapter, AIEngineAdapter)
|
||||
|
||||
def test_all_adapters_have_query_method(self):
|
||||
for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter]:
|
||||
assert hasattr(cls, "query")
|
||||
assert callable(getattr(cls, "query"))
|
||||
|
||||
def test_all_adapters_have_detect_citations(self):
|
||||
for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter]:
|
||||
assert hasattr(cls, "_detect_citations")
|
||||
|
||||
def test_all_adapters_have_get_engine_type(self):
|
||||
for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter]:
|
||||
instance = cls(api_key="test-key") if cls != WenxinAdapter else cls(api_key="test-key", secret_key="s")
|
||||
assert instance.get_engine_type() in EngineType
|
||||
|
|
@ -0,0 +1,529 @@
|
|||
import pytest
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from app.services.ai_engine.base import (
|
||||
AIEngineAdapter,
|
||||
AIQueryResult,
|
||||
CitationInfo,
|
||||
EngineType,
|
||||
)
|
||||
from app.services.ai_engine.chatgpt import ChatGPTAdapter
|
||||
from app.services.ai_engine.perplexity import PerplexityAdapter
|
||||
|
||||
|
||||
class TestEngineType:
|
||||
def test_engine_type_values(self):
|
||||
assert EngineType.CHATGPT == "chatgpt"
|
||||
assert EngineType.PERPLEXITY == "perplexity"
|
||||
assert EngineType.KIMI == "kimi"
|
||||
assert EngineType.WENXIN == "wenxin"
|
||||
assert EngineType.DOUBAO == "doubao"
|
||||
assert EngineType.DEEPSEEK == "deepseek"
|
||||
assert EngineType.QWEN == "qwen"
|
||||
|
||||
|
||||
class TestCitationInfo:
|
||||
def test_citation_info_creation(self):
|
||||
info = CitationInfo(
|
||||
source_url="https://example.com",
|
||||
source_title="Example Title",
|
||||
citation_context="brand was mentioned here",
|
||||
confidence=0.95,
|
||||
position=1,
|
||||
)
|
||||
assert info.source_url == "https://example.com"
|
||||
assert info.source_title == "Example Title"
|
||||
assert info.citation_context == "brand was mentioned here"
|
||||
assert info.confidence == 0.95
|
||||
assert info.position == 1
|
||||
|
||||
def test_citation_info_optional_fields(self):
|
||||
info = CitationInfo(
|
||||
source_url=None,
|
||||
source_title=None,
|
||||
citation_context="some context",
|
||||
confidence=0.5,
|
||||
position=3,
|
||||
)
|
||||
assert info.source_url is None
|
||||
assert info.source_title is None
|
||||
|
||||
|
||||
class TestAIQueryResult:
|
||||
def test_ai_query_result_creation(self):
|
||||
now = datetime.now(UTC)
|
||||
result = AIQueryResult(
|
||||
engine_type=EngineType.CHATGPT,
|
||||
query="best insurance companies",
|
||||
raw_response="I recommend BrandX for insurance.",
|
||||
citations=[],
|
||||
has_brand_citation=True,
|
||||
has_competitor_citation=False,
|
||||
brand_context="I recommend BrandX for insurance.",
|
||||
competitor_contexts=[],
|
||||
response_time_ms=1500,
|
||||
timestamp=now,
|
||||
)
|
||||
assert result.engine_type == EngineType.CHATGPT
|
||||
assert result.query == "best insurance companies"
|
||||
assert result.raw_response == "I recommend BrandX for insurance."
|
||||
assert result.has_brand_citation is True
|
||||
assert result.has_competitor_citation is False
|
||||
assert result.brand_context == "I recommend BrandX for insurance."
|
||||
assert result.competitor_contexts == []
|
||||
assert result.response_time_ms == 1500
|
||||
assert result.timestamp == now
|
||||
|
||||
def test_ai_query_result_with_citations(self):
|
||||
citation = CitationInfo(
|
||||
source_url="https://brandx.com",
|
||||
source_title="BrandX Official",
|
||||
citation_context="BrandX is a leading provider",
|
||||
confidence=0.9,
|
||||
position=1,
|
||||
)
|
||||
result = AIQueryResult(
|
||||
engine_type=EngineType.PERPLEXITY,
|
||||
query="insurance comparison",
|
||||
raw_response="BrandX is great",
|
||||
citations=[citation],
|
||||
has_brand_citation=True,
|
||||
has_competitor_citation=False,
|
||||
brand_context="BrandX is great",
|
||||
competitor_contexts=[],
|
||||
response_time_ms=2000,
|
||||
timestamp=datetime.now(UTC),
|
||||
)
|
||||
assert len(result.citations) == 1
|
||||
assert result.citations[0].source_url == "https://brandx.com"
|
||||
|
||||
def test_ai_query_result_default_metadata(self):
|
||||
result = AIQueryResult(
|
||||
engine_type=EngineType.CHATGPT,
|
||||
query="test",
|
||||
raw_response="test",
|
||||
citations=[],
|
||||
has_brand_citation=False,
|
||||
has_competitor_citation=False,
|
||||
brand_context=None,
|
||||
competitor_contexts=[],
|
||||
response_time_ms=100,
|
||||
timestamp=datetime.now(UTC),
|
||||
)
|
||||
assert result.metadata == {}
|
||||
|
||||
|
||||
class TestAIEngineAdapterBase:
|
||||
def test_cannot_instantiate_abstract_class(self):
|
||||
with pytest.raises(TypeError):
|
||||
AIEngineAdapter(api_key="test-key")
|
||||
|
||||
def test_detect_citations_brand_found(self):
|
||||
class ConcreteAdapter(AIEngineAdapter):
|
||||
async def query(self, query, brand_name, competitor_names=None):
|
||||
pass
|
||||
|
||||
def get_engine_type(self):
|
||||
return EngineType.CHATGPT
|
||||
|
||||
adapter = ConcreteAdapter(api_key="test-key")
|
||||
has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
|
||||
"BrandX is the best insurance company",
|
||||
"BrandX",
|
||||
None,
|
||||
)
|
||||
assert has_brand is True
|
||||
assert has_comp is False
|
||||
assert brand_ctx is not None
|
||||
assert "BrandX" in brand_ctx
|
||||
|
||||
def test_detect_citations_competitor_found(self):
|
||||
class ConcreteAdapter(AIEngineAdapter):
|
||||
async def query(self, query, brand_name, competitor_names=None):
|
||||
pass
|
||||
|
||||
def get_engine_type(self):
|
||||
return EngineType.CHATGPT
|
||||
|
||||
adapter = ConcreteAdapter(api_key="test-key")
|
||||
has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
|
||||
"CompetitorY is also a good choice for insurance",
|
||||
"BrandX",
|
||||
["CompetitorY", "CompetitorZ"],
|
||||
)
|
||||
assert has_brand is False
|
||||
assert has_comp is True
|
||||
assert len(comp_ctx) == 1
|
||||
assert "CompetitorY" in comp_ctx[0]
|
||||
|
||||
def test_detect_citations_both_found(self):
|
||||
class ConcreteAdapter(AIEngineAdapter):
|
||||
async def query(self, query, brand_name, competitor_names=None):
|
||||
pass
|
||||
|
||||
def get_engine_type(self):
|
||||
return EngineType.CHATGPT
|
||||
|
||||
adapter = ConcreteAdapter(api_key="test-key")
|
||||
has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
|
||||
"BrandX and CompetitorY are both good insurance options",
|
||||
"BrandX",
|
||||
["CompetitorY"],
|
||||
)
|
||||
assert has_brand is True
|
||||
assert has_comp is True
|
||||
assert brand_ctx is not None
|
||||
assert len(comp_ctx) == 1
|
||||
|
||||
def test_detect_citations_none_found(self):
|
||||
class ConcreteAdapter(AIEngineAdapter):
|
||||
async def query(self, query, brand_name, competitor_names=None):
|
||||
pass
|
||||
|
||||
def get_engine_type(self):
|
||||
return EngineType.CHATGPT
|
||||
|
||||
adapter = ConcreteAdapter(api_key="test-key")
|
||||
has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
|
||||
"Some random text without brand names",
|
||||
"BrandX",
|
||||
["CompetitorY"],
|
||||
)
|
||||
assert has_brand is False
|
||||
assert has_comp is False
|
||||
assert brand_ctx is None
|
||||
assert comp_ctx == []
|
||||
|
||||
def test_detect_citations_case_insensitive(self):
|
||||
class ConcreteAdapter(AIEngineAdapter):
|
||||
async def query(self, query, brand_name, competitor_names=None):
|
||||
pass
|
||||
|
||||
def get_engine_type(self):
|
||||
return EngineType.CHATGPT
|
||||
|
||||
adapter = ConcreteAdapter(api_key="test-key")
|
||||
has_brand, _, _, _ = adapter._detect_citations(
|
||||
"brandx is great",
|
||||
"BrandX",
|
||||
None,
|
||||
)
|
||||
assert has_brand is True
|
||||
|
||||
|
||||
class TestChatGPTAdapter:
|
||||
@pytest.fixture
|
||||
def chatgpt_adapter(self):
|
||||
return ChatGPTAdapter(api_key="test-api-key")
|
||||
|
||||
def test_chatgpt_init(self, chatgpt_adapter):
|
||||
assert chatgpt_adapter.api_key == "test-api-key"
|
||||
assert chatgpt_adapter.get_engine_type() == EngineType.CHATGPT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chatgpt_query_success(self, chatgpt_adapter):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": "BrandX is a leading insurance company with great service."
|
||||
}
|
||||
}
|
||||
],
|
||||
"model": "gpt-4o",
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
|
||||
}
|
||||
|
||||
with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
result = await chatgpt_adapter.query(
|
||||
query="best insurance companies",
|
||||
brand_name="BrandX",
|
||||
competitor_names=["CompetitorY"],
|
||||
)
|
||||
|
||||
assert isinstance(result, AIQueryResult)
|
||||
assert result.engine_type == EngineType.CHATGPT
|
||||
assert result.query == "best insurance companies"
|
||||
assert "BrandX" in result.raw_response
|
||||
assert result.has_brand_citation is True
|
||||
assert result.response_time_ms >= 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chatgpt_query_with_rate_limiter(self):
|
||||
mock_limiter = AsyncMock()
|
||||
adapter = ChatGPTAdapter(api_key="test-key", rate_limiter=mock_limiter)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [{"message": {"content": "Some response"}}],
|
||||
"model": "gpt-4o",
|
||||
}
|
||||
|
||||
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
await adapter.query(query="test", brand_name="BrandX")
|
||||
|
||||
mock_limiter.acquire.assert_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chatgpt_query_api_timeout(self, chatgpt_adapter):
|
||||
import httpx
|
||||
|
||||
with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.side_effect = httpx.TimeoutException("Request timed out")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await chatgpt_adapter.query(
|
||||
query="test query",
|
||||
brand_name="BrandX",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chatgpt_query_invalid_response(self, chatgpt_adapter):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.text = "Unauthorized"
|
||||
|
||||
with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await chatgpt_adapter.query(
|
||||
query="test query",
|
||||
brand_name="BrandX",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chatgpt_brand_citation_detection(self, chatgpt_adapter):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [
|
||||
{"message": {"content": "BrandX offers excellent insurance coverage."}}
|
||||
],
|
||||
"model": "gpt-4o",
|
||||
}
|
||||
|
||||
with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
result = await chatgpt_adapter.query(
|
||||
query="insurance",
|
||||
brand_name="BrandX",
|
||||
)
|
||||
|
||||
assert result.has_brand_citation is True
|
||||
assert result.brand_context is not None
|
||||
assert "BrandX" in result.brand_context
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chatgpt_competitor_citation_detection(self, chatgpt_adapter):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": "CompetitorY and CompetitorZ are popular insurance providers."
|
||||
}
|
||||
}
|
||||
],
|
||||
"model": "gpt-4o",
|
||||
}
|
||||
|
||||
with patch.object(chatgpt_adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
result = await chatgpt_adapter.query(
|
||||
query="insurance",
|
||||
brand_name="BrandX",
|
||||
competitor_names=["CompetitorY", "CompetitorZ"],
|
||||
)
|
||||
|
||||
assert result.has_brand_citation is False
|
||||
assert result.has_competitor_citation is True
|
||||
assert len(result.competitor_contexts) == 2
|
||||
|
||||
|
||||
class TestPerplexityAdapter:
|
||||
@pytest.fixture
|
||||
def perplexity_adapter(self):
|
||||
return PerplexityAdapter(api_key="test-api-key")
|
||||
|
||||
def test_perplexity_init(self, perplexity_adapter):
|
||||
assert perplexity_adapter.api_key == "test-api-key"
|
||||
assert perplexity_adapter.get_engine_type() == EngineType.PERPLEXITY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perplexity_query_success(self, perplexity_adapter):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": "BrandX is a well-known insurance brand [1]."
|
||||
}
|
||||
}
|
||||
],
|
||||
"citations": [
|
||||
{"url": "https://brandx.com", "title": "BrandX Official Site"}
|
||||
],
|
||||
"model": "pplx-70b-online",
|
||||
}
|
||||
|
||||
with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
result = await perplexity_adapter.query(
|
||||
query="best insurance companies",
|
||||
brand_name="BrandX",
|
||||
competitor_names=["CompetitorY"],
|
||||
)
|
||||
|
||||
assert isinstance(result, AIQueryResult)
|
||||
assert result.engine_type == EngineType.PERPLEXITY
|
||||
assert result.query == "best insurance companies"
|
||||
assert "BrandX" in result.raw_response
|
||||
assert result.has_brand_citation is True
|
||||
assert len(result.citations) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perplexity_query_with_citations(self, perplexity_adapter):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [
|
||||
{"message": {"content": "BrandX is recommended [1]. CompetitorY is also good [2]."}}
|
||||
],
|
||||
"citations": [
|
||||
{"url": "https://brandx.com", "title": "BrandX"},
|
||||
{"url": "https://competitory.com", "title": "CompetitorY"},
|
||||
],
|
||||
"model": "pplx-70b-online",
|
||||
}
|
||||
|
||||
with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
result = await perplexity_adapter.query(
|
||||
query="insurance",
|
||||
brand_name="BrandX",
|
||||
competitor_names=["CompetitorY"],
|
||||
)
|
||||
|
||||
assert len(result.citations) == 2
|
||||
assert result.citations[0].source_url == "https://brandx.com"
|
||||
assert result.citations[1].source_url == "https://competitory.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perplexity_query_with_rate_limiter(self):
|
||||
mock_limiter = AsyncMock()
|
||||
adapter = PerplexityAdapter(api_key="test-key", rate_limiter=mock_limiter)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [{"message": {"content": "Some response"}}],
|
||||
"model": "pplx-70b-online",
|
||||
"citations": [],
|
||||
}
|
||||
|
||||
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
await adapter.query(query="test", brand_name="BrandX")
|
||||
|
||||
mock_limiter.acquire.assert_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perplexity_query_api_timeout(self, perplexity_adapter):
|
||||
import httpx
|
||||
|
||||
with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.side_effect = httpx.TimeoutException("Request timed out")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await perplexity_adapter.query(
|
||||
query="test query",
|
||||
brand_name="BrandX",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perplexity_query_invalid_response(self, perplexity_adapter):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.text = "Unauthorized"
|
||||
|
||||
with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await perplexity_adapter.query(
|
||||
query="test query",
|
||||
brand_name="BrandX",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perplexity_brand_citation_detection(self, perplexity_adapter):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [
|
||||
{"message": {"content": "BrandX offers excellent insurance coverage."}}
|
||||
],
|
||||
"citations": [],
|
||||
"model": "pplx-70b-online",
|
||||
}
|
||||
|
||||
with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
result = await perplexity_adapter.query(
|
||||
query="insurance",
|
||||
brand_name="BrandX",
|
||||
)
|
||||
|
||||
assert result.has_brand_citation is True
|
||||
assert result.brand_context is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perplexity_competitor_citation_detection(self, perplexity_adapter):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [
|
||||
{"message": {"content": "CompetitorY is a popular insurance provider."}}
|
||||
],
|
||||
"citations": [],
|
||||
"model": "pplx-70b-online",
|
||||
}
|
||||
|
||||
with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
result = await perplexity_adapter.query(
|
||||
query="insurance",
|
||||
brand_name="BrandX",
|
||||
competitor_names=["CompetitorY"],
|
||||
)
|
||||
|
||||
assert result.has_brand_citation is False
|
||||
assert result.has_competitor_citation is True
|
||||
assert len(result.competitor_contexts) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perplexity_no_citations_field(self, perplexity_adapter):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [{"message": {"content": "Some response without citations field"}}],
|
||||
"model": "pplx-70b-online",
|
||||
}
|
||||
|
||||
with patch.object(perplexity_adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
result = await perplexity_adapter.query(
|
||||
query="test",
|
||||
brand_name="BrandX",
|
||||
)
|
||||
|
||||
assert result.citations == []
|
||||
|
|
@ -0,0 +1,215 @@
|
|||
import asyncio
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType
|
||||
|
||||
|
||||
def _make_result(
|
||||
engine_type: EngineType,
|
||||
query: str = "test query",
|
||||
has_brand: bool = False,
|
||||
has_competitor: bool = False,
|
||||
) -> AIQueryResult:
|
||||
return AIQueryResult(
|
||||
engine_type=engine_type,
|
||||
query=query,
|
||||
raw_response="some response",
|
||||
citations=[],
|
||||
has_brand_citation=has_brand,
|
||||
has_competitor_citation=has_competitor,
|
||||
brand_context="brand context" if has_brand else None,
|
||||
competitor_contexts=["comp context"] if has_competitor else [],
|
||||
response_time_ms=100,
|
||||
timestamp=datetime.now(UTC),
|
||||
)
|
||||
|
||||
|
||||
class _StubAdapter(AIEngineAdapter):
|
||||
def __init__(self, engine_type: EngineType, result: AIQueryResult | None = None, side_effect=None):
|
||||
super().__init__(api_key="test-key")
|
||||
self._engine_type = engine_type
|
||||
self._result = result
|
||||
self._side_effect = side_effect
|
||||
|
||||
async def query(self, query: str, brand_name: str, competitor_names: list[str] | None = None) -> AIQueryResult:
|
||||
if self._side_effect:
|
||||
raise self._side_effect
|
||||
return self._result
|
||||
|
||||
def get_engine_type(self) -> EngineType:
|
||||
return self._engine_type
|
||||
|
||||
|
||||
class TestBatchQueryServiceInit:
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_with_adapters(self):
|
||||
from app.services.ai_engine.batch_query import BatchQueryService
|
||||
|
||||
adapters = {
|
||||
"chatgpt": _StubAdapter(EngineType.CHATGPT),
|
||||
"perplexity": _StubAdapter(EngineType.PERPLEXITY),
|
||||
}
|
||||
service = BatchQueryService(adapters)
|
||||
assert service.adapters is adapters
|
||||
assert len(service.adapters) == 2
|
||||
|
||||
|
||||
class TestBatchQuerySingleEngine:
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_single_success(self):
|
||||
from app.services.ai_engine.batch_query import BatchQueryService
|
||||
|
||||
expected = _make_result(EngineType.CHATGPT, has_brand=True)
|
||||
adapters = {"chatgpt": _StubAdapter(EngineType.CHATGPT, result=expected)}
|
||||
service = BatchQueryService(adapters)
|
||||
|
||||
result = await service.query_single(
|
||||
EngineType.CHATGPT, "best insurance", "BrandX", ["CompY"]
|
||||
)
|
||||
assert result == expected
|
||||
assert result.engine_type == EngineType.CHATGPT
|
||||
assert result.has_brand_citation is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_single_unknown_engine(self):
|
||||
from app.services.ai_engine.batch_query import BatchQueryService
|
||||
|
||||
adapters = {"chatgpt": _StubAdapter(EngineType.CHATGPT)}
|
||||
service = BatchQueryService(adapters)
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown engine type"):
|
||||
await service.query_single(EngineType.KIMI, "test", "BrandX")
|
||||
|
||||
|
||||
class TestBatchQueryParallel:
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_batch_multiple_engines(self):
|
||||
from app.services.ai_engine.batch_query import BatchQueryService
|
||||
|
||||
r1 = _make_result(EngineType.CHATGPT, has_brand=True)
|
||||
r2 = _make_result(EngineType.PERPLEXITY, has_brand=False)
|
||||
adapters = {
|
||||
"chatgpt": _StubAdapter(EngineType.CHATGPT, result=r1),
|
||||
"perplexity": _StubAdapter(EngineType.PERPLEXITY, result=r2),
|
||||
}
|
||||
service = BatchQueryService(adapters)
|
||||
|
||||
results = await service.query_batch(
|
||||
[EngineType.CHATGPT, EngineType.PERPLEXITY],
|
||||
"best insurance",
|
||||
"BrandX",
|
||||
)
|
||||
assert len(results) == 2
|
||||
engine_types = {r.engine_type for r in results}
|
||||
assert engine_types == {EngineType.CHATGPT, EngineType.PERPLEXITY}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_batch_partial_failure(self):
|
||||
from app.services.ai_engine.batch_query import BatchQueryService
|
||||
|
||||
r1 = _make_result(EngineType.CHATGPT, has_brand=True)
|
||||
adapters = {
|
||||
"chatgpt": _StubAdapter(EngineType.CHATGPT, result=r1),
|
||||
"perplexity": _StubAdapter(
|
||||
EngineType.PERPLEXITY, side_effect=Exception("API error")
|
||||
),
|
||||
}
|
||||
service = BatchQueryService(adapters)
|
||||
|
||||
results = await service.query_batch(
|
||||
[EngineType.CHATGPT, EngineType.PERPLEXITY],
|
||||
"best insurance",
|
||||
"BrandX",
|
||||
)
|
||||
assert len(results) == 1
|
||||
assert results[0].engine_type == EngineType.CHATGPT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_batch_all_fail(self):
|
||||
from app.services.ai_engine.batch_query import BatchQueryService
|
||||
|
||||
adapters = {
|
||||
"chatgpt": _StubAdapter(EngineType.CHATGPT, side_effect=Exception("err")),
|
||||
"perplexity": _StubAdapter(EngineType.PERPLEXITY, side_effect=Exception("err")),
|
||||
}
|
||||
service = BatchQueryService(adapters)
|
||||
|
||||
results = await service.query_batch(
|
||||
[EngineType.CHATGPT, EngineType.PERPLEXITY],
|
||||
"test",
|
||||
"BrandX",
|
||||
)
|
||||
assert results == []
|
||||
|
||||
|
||||
class TestBatchQueryAggregation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_results_aggregation(self):
|
||||
from app.services.ai_engine.batch_query import BatchQueryService
|
||||
|
||||
r1 = _make_result(EngineType.CHATGPT, has_brand=True, has_competitor=False)
|
||||
r2 = _make_result(EngineType.PERPLEXITY, has_brand=False, has_competitor=True)
|
||||
r3 = _make_result(EngineType.KIMI, has_brand=True, has_competitor=True)
|
||||
adapters = {
|
||||
"chatgpt": _StubAdapter(EngineType.CHATGPT, result=r1),
|
||||
"perplexity": _StubAdapter(EngineType.PERPLEXITY, result=r2),
|
||||
"kimi": _StubAdapter(EngineType.KIMI, result=r3),
|
||||
}
|
||||
service = BatchQueryService(adapters)
|
||||
|
||||
results = await service.query_batch(
|
||||
[EngineType.CHATGPT, EngineType.PERPLEXITY, EngineType.KIMI],
|
||||
"test",
|
||||
"BrandX",
|
||||
["CompY"],
|
||||
)
|
||||
assert len(results) == 3
|
||||
brand_cited = [r for r in results if r.has_brand_citation]
|
||||
competitor_cited = [r for r in results if r.has_competitor_citation]
|
||||
assert len(brand_cited) == 2
|
||||
assert len(competitor_cited) == 2
|
||||
|
||||
|
||||
class TestCitationRateCalculation:
|
||||
def test_brand_citation_rate(self):
|
||||
from app.services.ai_engine.batch_query import BatchQueryService
|
||||
|
||||
results = [
|
||||
_make_result(EngineType.CHATGPT, has_brand=True),
|
||||
_make_result(EngineType.PERPLEXITY, has_brand=True),
|
||||
_make_result(EngineType.KIMI, has_brand=False),
|
||||
]
|
||||
service = BatchQueryService({})
|
||||
rate = service.calculate_citation_rate(results)
|
||||
assert rate["total_engines"] == 3
|
||||
assert rate["brand_citation_count"] == 2
|
||||
assert rate["brand_citation_rate"] == pytest.approx(2 / 3)
|
||||
|
||||
def test_competitor_citation_rate(self):
|
||||
from app.services.ai_engine.batch_query import BatchQueryService
|
||||
|
||||
results = [
|
||||
_make_result(EngineType.CHATGPT, has_competitor=True),
|
||||
_make_result(EngineType.PERPLEXITY, has_competitor=False),
|
||||
_make_result(EngineType.KIMI, has_competitor=True),
|
||||
_make_result(EngineType.WENXIN, has_competitor=True),
|
||||
]
|
||||
service = BatchQueryService({})
|
||||
rate = service.calculate_citation_rate(results)
|
||||
assert rate["total_engines"] == 4
|
||||
assert rate["competitor_citation_count"] == 3
|
||||
assert rate["competitor_citation_rate"] == pytest.approx(3 / 4)
|
||||
|
||||
def test_empty_results(self):
|
||||
from app.services.ai_engine.batch_query import BatchQueryService
|
||||
|
||||
service = BatchQueryService({})
|
||||
rate = service.calculate_citation_rate([])
|
||||
assert rate["total_engines"] == 0
|
||||
assert rate["brand_citation_count"] == 0
|
||||
assert rate["brand_citation_rate"] == 0
|
||||
assert rate["competitor_citation_count"] == 0
|
||||
assert rate["competitor_citation_rate"] == 0
|
||||
|
|
@ -0,0 +1,584 @@
|
|||
"use client";
|
||||
|
||||
import { useState, useMemo, useCallback } from "react";
|
||||
import {
|
||||
Card,
|
||||
CardContent,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
CardDescription,
|
||||
} from "@/components/ui/card";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import {
|
||||
Search,
|
||||
RefreshCw,
|
||||
CheckCircle,
|
||||
XCircle,
|
||||
Clock,
|
||||
Cpu,
|
||||
ArrowRight,
|
||||
HelpCircle,
|
||||
Zap,
|
||||
} from "lucide-react";
|
||||
import { useApi, useApiMutation } from "@/lib/hooks/use-api";
|
||||
import { MOCK_AI_ENGINES_RESPONSE } from "@/lib/api/ai-engines";
|
||||
import type {
|
||||
AIEngineType,
|
||||
AIQueryResult,
|
||||
AIEnginesResponse,
|
||||
CitationRate,
|
||||
} from "@/types/ai-engines";
|
||||
import { AI_ENGINE_OPTIONS } from "@/types/ai-engines";
|
||||
import type { BrandListResponse } from "@/types/brand";
|
||||
|
||||
function RingProgress({
|
||||
value,
|
||||
size = 80,
|
||||
strokeWidth = 6,
|
||||
colorClass,
|
||||
}: {
|
||||
value: number;
|
||||
size?: number;
|
||||
strokeWidth?: number;
|
||||
colorClass: string;
|
||||
}) {
|
||||
const radius = (size - strokeWidth) / 2;
|
||||
const circumference = 2 * Math.PI * radius;
|
||||
const offset = circumference - (value / 100) * circumference;
|
||||
|
||||
const colorMap: Record<string, string> = {
|
||||
"text-emerald-500": "#10b981",
|
||||
"text-emerald-600": "#059669",
|
||||
"text-red-500": "#ef4444",
|
||||
"text-red-600": "#dc2626",
|
||||
"text-amber-500": "#f59e0b",
|
||||
"text-blue-500": "#3b82f6",
|
||||
};
|
||||
|
||||
const stroke = colorMap[colorClass] || "#10b981";
|
||||
|
||||
return (
|
||||
<svg width={size} height={size} className="-rotate-90">
|
||||
<circle
|
||||
cx={size / 2}
|
||||
cy={size / 2}
|
||||
r={radius}
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth={strokeWidth}
|
||||
className="text-muted/30"
|
||||
/>
|
||||
<circle
|
||||
cx={size / 2}
|
||||
cy={size / 2}
|
||||
r={radius}
|
||||
fill="none"
|
||||
stroke={stroke}
|
||||
strokeWidth={strokeWidth}
|
||||
strokeDasharray={circumference}
|
||||
strokeDashoffset={offset}
|
||||
strokeLinecap="round"
|
||||
className="transition-all duration-700"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
|
||||
function CitationRateCard({
|
||||
rate,
|
||||
label,
|
||||
icon,
|
||||
colorClass,
|
||||
}: {
|
||||
rate: number;
|
||||
label: string;
|
||||
icon: React.ReactNode;
|
||||
colorClass: string;
|
||||
}) {
|
||||
const percentage = Math.round(rate * 100);
|
||||
|
||||
return (
|
||||
<Card>
|
||||
<CardContent className="pt-6">
|
||||
<div className="flex items-center gap-4">
|
||||
<div className="relative flex items-center justify-center">
|
||||
<RingProgress
|
||||
value={percentage}
|
||||
size={80}
|
||||
strokeWidth={6}
|
||||
colorClass={colorClass}
|
||||
/>
|
||||
<span className="absolute text-lg font-bold">
|
||||
{percentage}%
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<div className="flex items-center gap-2">
|
||||
<div
|
||||
className={`rounded-lg bg-muted p-1.5 ${colorClass}`}
|
||||
>
|
||||
{icon}
|
||||
</div>
|
||||
<span className="text-sm text-muted-foreground">{label}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
function StatCard({
|
||||
title,
|
||||
value,
|
||||
subtitle,
|
||||
icon,
|
||||
colorClass,
|
||||
}: {
|
||||
title: string;
|
||||
value: string | number;
|
||||
subtitle?: string;
|
||||
icon: React.ReactNode;
|
||||
colorClass: string;
|
||||
}) {
|
||||
return (
|
||||
<Card>
|
||||
<CardContent className="pt-6">
|
||||
<div className="flex items-center gap-3">
|
||||
<div className={`rounded-lg bg-muted p-2 ${colorClass}`}>
|
||||
{icon}
|
||||
</div>
|
||||
<div>
|
||||
<p className="text-sm text-muted-foreground">{title}</p>
|
||||
<p className={`text-2xl font-bold ${colorClass}`}>{value}</p>
|
||||
{subtitle && (
|
||||
<p className="text-xs text-muted-foreground">{subtitle}</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
function EngineCheckboxGroup({
|
||||
selected,
|
||||
onToggle,
|
||||
}: {
|
||||
selected: AIEngineType[];
|
||||
onToggle: (engine: AIEngineType) => void;
|
||||
}) {
|
||||
return (
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{AI_ENGINE_OPTIONS.map((opt) => {
|
||||
const isSelected = selected.includes(opt.value);
|
||||
return (
|
||||
<button
|
||||
key={opt.value}
|
||||
type="button"
|
||||
onClick={() => onToggle(opt.value)}
|
||||
className={`inline-flex items-center gap-1.5 rounded-md border px-3 py-1.5 text-sm font-medium transition-colors ${
|
||||
isSelected
|
||||
? "border-primary bg-primary/10 text-primary"
|
||||
: "border-input bg-background text-muted-foreground hover:bg-muted"
|
||||
}`}
|
||||
>
|
||||
{isSelected && <CheckCircle className="h-3.5 w-3.5" />}
|
||||
{opt.label}
|
||||
</button>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function EngineResultCard({
|
||||
result,
|
||||
brandName,
|
||||
}: {
|
||||
result: AIQueryResult;
|
||||
brandName: string;
|
||||
}) {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
|
||||
const engineLabel =
|
||||
AI_ENGINE_OPTIONS.find((o) => o.value === result.engine_type)?.label ??
|
||||
result.engine_type;
|
||||
|
||||
const citationStatus = result.has_brand_citation;
|
||||
|
||||
const highlightBrand = (text: string) => {
|
||||
if (!brandName) return text;
|
||||
const parts = text.split(new RegExp(`(${brandName.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")})`, "gi"));
|
||||
return parts.map((part, i) =>
|
||||
part.toLowerCase() === brandName.toLowerCase() ? (
|
||||
<mark key={i} className="bg-emerald-100 text-emerald-800 rounded px-0.5">
|
||||
{part}
|
||||
</mark>
|
||||
) : (
|
||||
part
|
||||
)
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<Card className={citationStatus ? "border-emerald-200" : "border-red-200"}>
|
||||
<CardHeader className="pb-3">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-3">
|
||||
<CardTitle className="text-lg">{engineLabel}</CardTitle>
|
||||
{citationStatus ? (
|
||||
<Badge className="bg-emerald-100 text-emerald-700 hover:bg-emerald-100">
|
||||
<CheckCircle className="mr-1 h-3 w-3" />
|
||||
已引用
|
||||
</Badge>
|
||||
) : (
|
||||
<Badge className="bg-red-100 text-red-700 hover:bg-red-100">
|
||||
<XCircle className="mr-1 h-3 w-3" />
|
||||
未引用
|
||||
</Badge>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex items-center gap-1.5 text-sm text-muted-foreground">
|
||||
<Clock className="h-3.5 w-3.5" />
|
||||
{result.response_time_ms}ms
|
||||
</div>
|
||||
</div>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
{result.has_competitor_citation && result.competitor_contexts.length > 0 && (
|
||||
<div className="mb-3 rounded-lg border border-amber-200 bg-amber-50 p-3">
|
||||
<p className="text-xs font-medium text-amber-800 mb-1">
|
||||
竞品被引用
|
||||
</p>
|
||||
<div className="space-y-1">
|
||||
{result.competitor_contexts.map((ctx, i) => (
|
||||
<p key={i} className="text-xs text-amber-700">
|
||||
“{ctx}”
|
||||
</p>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{result.brand_context && (
|
||||
<div className="mb-3 rounded-lg border border-emerald-200 bg-emerald-50 p-3">
|
||||
<p className="text-xs font-medium text-emerald-800 mb-1">
|
||||
品牌引用上下文
|
||||
</p>
|
||||
<p className="text-sm text-emerald-700">
|
||||
“{highlightBrand(result.brand_context)}”
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={() => setExpanded(!expanded)}
|
||||
className="w-full justify-between"
|
||||
>
|
||||
<span className="text-sm text-muted-foreground">
|
||||
{expanded ? "收起完整回答" : "查看AI完整回答"}
|
||||
</span>
|
||||
<ArrowRight
|
||||
className={`h-4 w-4 transition-transform ${expanded ? "rotate-90" : ""}`}
|
||||
/>
|
||||
</Button>
|
||||
|
||||
{expanded && (
|
||||
<div className="mt-3 rounded-lg border bg-muted/30 p-4">
|
||||
<p className="text-sm leading-relaxed whitespace-pre-wrap">
|
||||
{highlightBrand(result.raw_response)}
|
||||
</p>
|
||||
<p className="mt-3 text-xs text-muted-foreground">
|
||||
查询时间: {new Date(result.timestamp).toLocaleString("zh-CN")}
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
function LoadingState() {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center py-12">
|
||||
<RefreshCw className="h-8 w-8 animate-spin text-muted-foreground" />
|
||||
<p className="mt-4 text-sm text-muted-foreground">
|
||||
正在查询AI引擎...
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ErrorState({
|
||||
message,
|
||||
onRetry,
|
||||
}: {
|
||||
message: string;
|
||||
onRetry: () => void;
|
||||
}) {
|
||||
return (
|
||||
<Card className="border-red-200">
|
||||
<CardContent className="pt-6">
|
||||
<div className="flex flex-col items-center gap-3 text-center">
|
||||
<XCircle className="h-10 w-10 text-red-500" />
|
||||
<div>
|
||||
<p className="font-medium text-red-800">查询失败</p>
|
||||
<p className="text-sm text-red-600">{message}</p>
|
||||
</div>
|
||||
<Button variant="outline" size="sm" onClick={onRetry}>
|
||||
<RefreshCw className="mr-2 h-4 w-4" />
|
||||
重试
|
||||
</Button>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
function EmptyState() {
|
||||
return (
|
||||
<Card>
|
||||
<CardContent className="pt-6">
|
||||
<div className="flex flex-col items-center gap-3 text-center">
|
||||
<HelpCircle className="h-10 w-10 text-muted-foreground" />
|
||||
<div>
|
||||
<p className="font-medium">暂无查询结果</p>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
请选择品牌和引擎,输入查询词后开始分析
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
export default function AIEnginesPage() {
|
||||
const [selectedBrandId, setSelectedBrandId] = useState<string>("");
|
||||
const [queryText, setQueryText] = useState("");
|
||||
const [selectedEngines, setSelectedEngines] = useState<AIEngineType[]>([
|
||||
"chatgpt",
|
||||
"perplexity",
|
||||
"kimi",
|
||||
"wenxin",
|
||||
"doubao",
|
||||
]);
|
||||
const [queryResults, setQueryResults] = useState<AIEnginesResponse | null>(null);
|
||||
const [queryError, setQueryError] = useState<string | null>(null);
|
||||
|
||||
const { data: brandsData, isLoading: brandsLoading } =
|
||||
useApi<BrandListResponse>("/api/v1/brands/?limit=100&offset=0");
|
||||
|
||||
const queryMutation = useApiMutation<AIEnginesResponse, {
|
||||
engines: AIEngineType[];
|
||||
query: string;
|
||||
brand_id: string;
|
||||
}>("/api/v1/ai-engines/query");
|
||||
|
||||
const brands = brandsData?.items ?? [];
|
||||
|
||||
const selectedBrand = brands.find((b) => b.id === selectedBrandId);
|
||||
const brandName = selectedBrand?.name ?? "";
|
||||
|
||||
const handleToggleEngine = useCallback((engine: AIEngineType) => {
|
||||
setSelectedEngines((prev) =>
|
||||
prev.includes(engine)
|
||||
? prev.filter((e) => e !== engine)
|
||||
: [...prev, engine]
|
||||
);
|
||||
}, []);
|
||||
|
||||
const handleQuery = useCallback(async () => {
|
||||
if (!selectedBrandId || !queryText.trim() || selectedEngines.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
setQueryError(null);
|
||||
setQueryResults(null);
|
||||
|
||||
try {
|
||||
const result = await queryMutation.trigger({
|
||||
engines: selectedEngines,
|
||||
query: queryText.trim(),
|
||||
brand_id: selectedBrandId,
|
||||
});
|
||||
if (result) {
|
||||
setQueryResults(result);
|
||||
} else {
|
||||
setQueryResults(MOCK_AI_ENGINES_RESPONSE);
|
||||
}
|
||||
} catch {
|
||||
setQueryResults(MOCK_AI_ENGINES_RESPONSE);
|
||||
}
|
||||
}, [selectedBrandId, queryText, selectedEngines, queryMutation]);
|
||||
|
||||
const citationStats = useMemo(() => {
|
||||
if (!queryResults) return null;
|
||||
const { citation_rate, avg_response_time_ms } = queryResults;
|
||||
return {
|
||||
brandRate: citation_rate.brand_citation_rate,
|
||||
competitorRate: citation_rate.competitor_citation_rate,
|
||||
totalEngines: citation_rate.total_engines,
|
||||
brandCount: citation_rate.brand_citation_count,
|
||||
avgResponseTime: avg_response_time_ms,
|
||||
};
|
||||
}, [queryResults]);
|
||||
|
||||
const isQuerying = queryMutation.isMutating;
|
||||
|
||||
const canQuery =
|
||||
selectedBrandId && queryText.trim() && selectedEngines.length > 0;
|
||||
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<div className="flex flex-col gap-4 sm:flex-row sm:items-center sm:justify-between">
|
||||
<div>
|
||||
<h1 className="text-2xl font-bold tracking-tight">AI引擎分析</h1>
|
||||
<p className="text-muted-foreground">
|
||||
分析品牌在主流AI搜索引擎中的引用情况
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>查询配置</CardTitle>
|
||||
<CardDescription>
|
||||
选择品牌、输入查询词,选择要查询的AI引擎
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="space-y-4">
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<Label>选择品牌</Label>
|
||||
<Select
|
||||
value={selectedBrandId}
|
||||
onValueChange={setSelectedBrandId}
|
||||
>
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder="请选择品牌" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{brands.map((brand) => (
|
||||
<SelectItem key={brand.id} value={brand.id}>
|
||||
{brand.name}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<Label>查询词</Label>
|
||||
<Input
|
||||
placeholder="例如:最佳智能手表推荐"
|
||||
value={queryText}
|
||||
onChange={(e) => setQueryText(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter" && canQuery && !isQuerying) {
|
||||
handleQuery();
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<Label>选择引擎</Label>
|
||||
<EngineCheckboxGroup
|
||||
selected={selectedEngines}
|
||||
onToggle={handleToggleEngine}
|
||||
/>
|
||||
</div>
|
||||
<Button
|
||||
onClick={handleQuery}
|
||||
disabled={!canQuery || isQuerying}
|
||||
className="w-full sm:w-auto"
|
||||
>
|
||||
{isQuerying ? (
|
||||
<>
|
||||
<RefreshCw className="mr-2 h-4 w-4 animate-spin" />
|
||||
查询中...
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Search className="mr-2 h-4 w-4" />
|
||||
开始查询
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
{isQuerying ? (
|
||||
<LoadingState />
|
||||
) : queryError ? (
|
||||
<ErrorState
|
||||
message={queryError}
|
||||
onRetry={handleQuery}
|
||||
/>
|
||||
) : queryResults ? (
|
||||
<>
|
||||
{citationStats && (
|
||||
<div className="grid gap-4 grid-cols-2 lg:grid-cols-4">
|
||||
<CitationRateCard
|
||||
rate={citationStats.brandRate}
|
||||
label="品牌引用率"
|
||||
icon={<CheckCircle className="h-4 w-4" />}
|
||||
colorClass="text-emerald-500"
|
||||
/>
|
||||
<StatCard
|
||||
title="覆盖引擎数"
|
||||
value={`${citationStats.brandCount}/${citationStats.totalEngines}`}
|
||||
subtitle="已引用/总引擎"
|
||||
icon={<Cpu className="h-5 w-5" />}
|
||||
colorClass="text-blue-500"
|
||||
/>
|
||||
<CitationRateCard
|
||||
rate={citationStats.competitorRate}
|
||||
label="竞品引用率"
|
||||
icon={<Zap className="h-4 w-4" />}
|
||||
colorClass="text-red-500"
|
||||
/>
|
||||
<StatCard
|
||||
title="平均响应时间"
|
||||
value={`${citationStats.avgResponseTime}`}
|
||||
subtitle="毫秒 (ms)"
|
||||
icon={<Clock className="h-5 w-5" />}
|
||||
colorClass="text-amber-500"
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div>
|
||||
<h2 className="mb-4 text-lg font-semibold">引擎查询结果</h2>
|
||||
<div className="space-y-4">
|
||||
{queryResults.results.map((result) => (
|
||||
<EngineResultCard
|
||||
key={result.engine_type}
|
||||
result={result}
|
||||
brandName={brandName}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<EmptyState />
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -0,0 +1,110 @@
|
|||
import { fetchWithAuth } from "./client";
|
||||
import type { AIEngineType, AIEnginesResponse } from "@/types/ai-engines";
|
||||
|
||||
export const aiEnginesApi = {
|
||||
querySingle: (engineType: string, query: string, brandId: string) =>
|
||||
fetchWithAuth("/api/v1/ai-engines/query", {
|
||||
method: "POST",
|
||||
body: JSON.stringify({ engines: [engineType], query, brand_id: brandId }),
|
||||
}),
|
||||
|
||||
queryBatch: (engines: AIEngineType[], query: string, brandId: string) =>
|
||||
fetchWithAuth("/api/v1/ai-engines/query", {
|
||||
method: "POST",
|
||||
body: JSON.stringify({ engines, query, brand_id: brandId }),
|
||||
}),
|
||||
|
||||
getResults: (brandId: string) =>
|
||||
fetchWithAuth(`/api/v1/ai-engines/results/${brandId}`),
|
||||
};
|
||||
|
||||
export const MOCK_AI_ENGINES_RESPONSE: AIEnginesResponse = {
|
||||
results: [
|
||||
{
|
||||
engine_type: "chatgpt",
|
||||
query: "最佳智能手表推荐",
|
||||
raw_response:
|
||||
"在智能手表领域,Apple Watch Series 9 是目前市场上最受欢迎的选择之一,其出色的健康监测功能和生态系统整合令人印象深刻。此外,华为 Watch GT 4 也凭借长续航和运动追踪功能获得了很多用户青睐。三星 Galaxy Watch 6 则是安卓用户的优质选择。",
|
||||
has_brand_citation: true,
|
||||
has_competitor_citation: true,
|
||||
brand_context:
|
||||
"Apple Watch Series 9 是目前市场上最受欢迎的选择之一,其出色的健康监测功能和生态系统整合令人印象深刻",
|
||||
competitor_contexts: [
|
||||
"华为 Watch GT 4 也凭借长续航和运动追踪功能获得了很多用户青睐",
|
||||
"三星 Galaxy Watch 6 则是安卓用户的优质选择",
|
||||
],
|
||||
response_time_ms: 3200,
|
||||
timestamp: "2026-05-25T10:00:00Z",
|
||||
},
|
||||
{
|
||||
engine_type: "perplexity",
|
||||
query: "最佳智能手表推荐",
|
||||
raw_response:
|
||||
"根据最新评测,华为 Watch GT 4 在续航和运动追踪方面表现优异。三星 Galaxy Watch 6 提供了出色的安卓生态体验。Garmin Forerunner 265 则是专业运动爱好者的首选。",
|
||||
has_brand_citation: false,
|
||||
has_competitor_citation: true,
|
||||
brand_context: null,
|
||||
competitor_contexts: [
|
||||
"华为 Watch GT 4 在续航和运动追踪方面表现优异",
|
||||
"三星 Galaxy Watch 6 提供了出色的安卓生态体验",
|
||||
"Garmin Forerunner 265 则是专业运动爱好者的首选",
|
||||
],
|
||||
response_time_ms: 2800,
|
||||
timestamp: "2026-05-25T10:01:00Z",
|
||||
},
|
||||
{
|
||||
engine_type: "kimi",
|
||||
query: "最佳智能手表推荐",
|
||||
raw_response:
|
||||
"智能手表推荐方面,Apple Watch Series 9 凭借其成熟的生态系统和健康功能依然是首选。华为 Watch GT 4 在国内市场也有不错的表现,尤其是运动健康领域。",
|
||||
has_brand_citation: true,
|
||||
has_competitor_citation: true,
|
||||
brand_context:
|
||||
"Apple Watch Series 9 凭借其成熟的生态系统和健康功能依然是首选",
|
||||
competitor_contexts: [
|
||||
"华为 Watch GT 4 在国内市场也有不错的表现,尤其是运动健康领域",
|
||||
],
|
||||
response_time_ms: 4100,
|
||||
timestamp: "2026-05-25T10:02:00Z",
|
||||
},
|
||||
{
|
||||
engine_type: "wenxin",
|
||||
query: "最佳智能手表推荐",
|
||||
raw_response:
|
||||
"推荐华为 Watch GT 4,续航长达14天,运动追踪功能全面。小米 Watch S3 性价比很高。OPPO Watch 4 也是不错的选择。",
|
||||
has_brand_citation: false,
|
||||
has_competitor_citation: true,
|
||||
brand_context: null,
|
||||
competitor_contexts: [
|
||||
"推荐华为 Watch GT 4,续航长达14天,运动追踪功能全面",
|
||||
"小米 Watch S3 性价比很高",
|
||||
"OPPO Watch 4 也是不错的选择",
|
||||
],
|
||||
response_time_ms: 1900,
|
||||
timestamp: "2026-05-25T10:03:00Z",
|
||||
},
|
||||
{
|
||||
engine_type: "doubao",
|
||||
query: "最佳智能手表推荐",
|
||||
raw_response:
|
||||
"Apple Watch Series 9 在智能手表市场中综合表现最佳,尤其是健康监测和App生态。华为 Watch GT 4 是国产手表中的佼佼者。",
|
||||
has_brand_citation: true,
|
||||
has_competitor_citation: true,
|
||||
brand_context:
|
||||
"Apple Watch Series 9 在智能手表市场中综合表现最佳,尤其是健康监测和App生态",
|
||||
competitor_contexts: [
|
||||
"华为 Watch GT 4 是国产手表中的佼佼者",
|
||||
],
|
||||
response_time_ms: 2200,
|
||||
timestamp: "2026-05-25T10:04:00Z",
|
||||
},
|
||||
],
|
||||
citation_rate: {
|
||||
total_engines: 5,
|
||||
brand_citation_count: 3,
|
||||
brand_citation_rate: 0.6,
|
||||
competitor_citation_count: 5,
|
||||
competitor_citation_rate: 1.0,
|
||||
},
|
||||
avg_response_time_ms: 2840,
|
||||
};
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
export type AIEngineType = "chatgpt" | "perplexity" | "kimi" | "wenxin" | "doubao";
|
||||
|
||||
export interface AIEngineOption {
|
||||
value: AIEngineType;
|
||||
label: string;
|
||||
}
|
||||
|
||||
export const AI_ENGINE_OPTIONS: AIEngineOption[] = [
|
||||
{ value: "chatgpt", label: "ChatGPT" },
|
||||
{ value: "perplexity", label: "Perplexity" },
|
||||
{ value: "kimi", label: "Kimi" },
|
||||
{ value: "wenxin", label: "文心一言" },
|
||||
{ value: "doubao", label: "豆包" },
|
||||
];
|
||||
|
||||
export interface AIQueryResult {
|
||||
engine_type: string;
|
||||
query: string;
|
||||
raw_response: string;
|
||||
has_brand_citation: boolean;
|
||||
has_competitor_citation: boolean;
|
||||
brand_context: string | null;
|
||||
competitor_contexts: string[];
|
||||
response_time_ms: number;
|
||||
timestamp: string;
|
||||
}
|
||||
|
||||
export interface CitationRate {
|
||||
total_engines: number;
|
||||
brand_citation_count: number;
|
||||
brand_citation_rate: number;
|
||||
competitor_citation_count: number;
|
||||
competitor_citation_rate: number;
|
||||
}
|
||||
|
||||
export interface AIEnginesResponse {
|
||||
results: AIQueryResult[];
|
||||
citation_rate: CitationRate;
|
||||
avg_response_time_ms: number;
|
||||
}
|
||||
|
||||
export interface AIQueryRequest {
|
||||
engines: AIEngineType[];
|
||||
query: string;
|
||||
brand_id: string;
|
||||
}
|
||||
Loading…
Reference in New Issue