feat: 补齐AI引擎适配器 - 9引擎全覆盖+代理支持
后端(TDD): - 基类添加proxy支持(构造函数>引擎专属环境变量>HTTPS_PROXY) - ChatGPT/Perplexity适配器添加proxy参数 - 新增DeepSeek适配器(国内,OpenAI兼容) - 新增通义千问适配器(国内,DashScope API) - 新增Google Gemini适配器(国外,Google专有API,支持proxy) - 新增腾讯元宝适配器(国内,OpenAI兼容) - EngineType枚举新增GEMINI/YUANBAO - 76个测试全部通过 前端: - AI引擎选项扩展为9个(3国际+6国内) - 引擎选择按国内外分组显示 - 类型定义更新(AIEngineOption.group)
This commit is contained in:
parent
9d67a801be
commit
af3a184c0b
|
|
@ -9,10 +9,11 @@ from app.models.user import User
|
|||
from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType
|
||||
from app.services.ai_engine.batch_query import BatchQueryService
|
||||
from app.services.ai_engine.chatgpt import ChatGPTAdapter
|
||||
from app.services.ai_engine.perplexity import PerplexityAdapter
|
||||
from app.services.ai_engine.kimi import KimiAdapter
|
||||
from app.services.ai_engine.wenxin import WenxinAdapter
|
||||
from app.services.ai_engine.doubao import DoubaoAdapter
|
||||
from app.services.ai_engine.kimi import KimiAdapter
|
||||
from app.services.ai_engine.perplexity import PerplexityAdapter
|
||||
from app.services.ai_engine.wenxin import WenxinAdapter
|
||||
from app.services.ai_engine.yuanbao import YuanbaoAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -65,6 +66,7 @@ _ADAPTER_CLASSES: dict[EngineType, type[AIEngineAdapter]] = {
|
|||
EngineType.KIMI: KimiAdapter,
|
||||
EngineType.WENXIN: WenxinAdapter,
|
||||
EngineType.DOUBAO: DoubaoAdapter,
|
||||
EngineType.YUANBAO: YuanbaoAdapter,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,14 @@
|
|||
from .base import AIEngineAdapter, AIQueryResult, CitationInfo, EngineType
|
||||
from .chatgpt import ChatGPTAdapter
|
||||
from .perplexity import PerplexityAdapter
|
||||
from .kimi import KimiAdapter
|
||||
from .wenxin import WenxinAdapter
|
||||
from .doubao import DoubaoAdapter
|
||||
from .batch_query import BatchQueryService
|
||||
from .chatgpt import ChatGPTAdapter
|
||||
from .deepseek import DeepSeekAdapter
|
||||
from .doubao import DoubaoAdapter
|
||||
from .gemini import GeminiAdapter
|
||||
from .kimi import KimiAdapter
|
||||
from .perplexity import PerplexityAdapter
|
||||
from .qwen import QwenAdapter
|
||||
from .wenxin import WenxinAdapter
|
||||
from .yuanbao import YuanbaoAdapter
|
||||
|
||||
__all__ = [
|
||||
"AIEngineAdapter",
|
||||
|
|
@ -12,9 +16,13 @@ __all__ = [
|
|||
"CitationInfo",
|
||||
"EngineType",
|
||||
"ChatGPTAdapter",
|
||||
"DeepSeekAdapter",
|
||||
"PerplexityAdapter",
|
||||
"KimiAdapter",
|
||||
"WenxinAdapter",
|
||||
"DoubaoAdapter",
|
||||
"YuanbaoAdapter",
|
||||
"QwenAdapter",
|
||||
"GeminiAdapter",
|
||||
"BatchQueryService",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
|
|
@ -22,6 +23,8 @@ class EngineType(str, Enum):
|
|||
DOUBAO = "doubao"
|
||||
DEEPSEEK = "deepseek"
|
||||
QWEN = "qwen"
|
||||
GEMINI = "gemini"
|
||||
YUANBAO = "yuanbao"
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -49,9 +52,10 @@ class AIQueryResult:
|
|||
|
||||
|
||||
class AIEngineAdapter(ABC):
|
||||
def __init__(self, api_key: str, rate_limiter=None):
|
||||
def __init__(self, api_key: str, rate_limiter=None, proxy: str | None = None):
|
||||
self.api_key = api_key
|
||||
self.rate_limiter = rate_limiter
|
||||
self.proxy = proxy or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
|
||||
@abstractmethod
|
||||
|
|
@ -67,6 +71,19 @@ class AIEngineAdapter(ABC):
|
|||
def get_engine_type(self) -> EngineType:
|
||||
pass
|
||||
|
||||
async def _get_client(self) -> httpx.AsyncClient:
|
||||
if self._client is None or self._client.is_closed:
|
||||
self._client = httpx.AsyncClient(**self._client_kwargs())
|
||||
return self._client
|
||||
|
||||
def _client_kwargs(self) -> dict[str, Any]:
|
||||
kwargs: dict[str, Any] = {
|
||||
"timeout": httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0),
|
||||
}
|
||||
if self.proxy:
|
||||
kwargs["proxy"] = self.proxy
|
||||
return kwargs
|
||||
|
||||
def _detect_citations(
|
||||
self,
|
||||
response: str,
|
||||
|
|
@ -98,12 +115,13 @@ class AIEngineAdapter(ABC):
|
|||
if self.rate_limiter:
|
||||
await self.rate_limiter.acquire()
|
||||
|
||||
client = await self._get_client()
|
||||
engine_name = self.get_engine_type().value
|
||||
last_error: Exception | None = None
|
||||
|
||||
for attempt in range(_MAX_RETRIES):
|
||||
try:
|
||||
response = await self._client.post(self._endpoint, json=payload)
|
||||
response = await client.post(self._endpoint, json=payload)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
|
|
|||
|
|
@ -20,10 +20,12 @@ class ChatGPTAdapter(AIEngineAdapter):
|
|||
model: str | None = None,
|
||||
base_url: str | None = None,
|
||||
rate_limiter=None,
|
||||
proxy: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key or os.getenv("OPENAI_API_KEY", ""),
|
||||
rate_limiter=rate_limiter,
|
||||
proxy=proxy or os.getenv("OPENAI_PROXY"),
|
||||
)
|
||||
self._model = model or os.getenv("OPENAI_MODEL", _DEFAULT_MODEL)
|
||||
self._base_url = (
|
||||
|
|
@ -31,7 +33,7 @@ class ChatGPTAdapter(AIEngineAdapter):
|
|||
).rstrip("/")
|
||||
self._endpoint = f"{self._base_url}/chat/completions"
|
||||
self._client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0),
|
||||
**self._client_kwargs(),
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,90 @@
|
|||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import httpx
|
||||
|
||||
from .base import AIEngineAdapter, AIQueryResult, EngineType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_MODEL = "deepseek-chat"
|
||||
_DEFAULT_BASE_URL = "https://api.deepseek.com/v1"
|
||||
|
||||
|
||||
class DeepSeekAdapter(AIEngineAdapter):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
model: str | None = None,
|
||||
base_url: str | None = None,
|
||||
rate_limiter=None,
|
||||
proxy: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key or os.getenv("DEEPSEEK_API_KEY", ""),
|
||||
rate_limiter=rate_limiter,
|
||||
proxy=proxy or os.getenv("DEEPSEEK_PROXY"),
|
||||
)
|
||||
self._model = model or os.getenv("DEEPSEEK_MODEL", _DEFAULT_MODEL)
|
||||
self._base_url = (
|
||||
base_url or os.getenv("DEEPSEEK_BASE_URL", _DEFAULT_BASE_URL)
|
||||
).rstrip("/")
|
||||
self._endpoint = f"{self._base_url}/chat/completions"
|
||||
self._client = httpx.AsyncClient(
|
||||
**self._client_kwargs(),
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
def get_engine_type(self) -> EngineType:
|
||||
return EngineType.DEEPSEEK
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query: str,
|
||||
brand_name: str,
|
||||
competitor_names: list[str] | None = None,
|
||||
) -> AIQueryResult:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": query},
|
||||
]
|
||||
payload = {
|
||||
"model": self._model,
|
||||
"messages": messages,
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 4096,
|
||||
}
|
||||
|
||||
data = await self._request_with_retry(payload)
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
|
||||
elapsed_ms = int((time.perf_counter() - start_time) * 1000)
|
||||
has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
|
||||
content, brand_name, competitor_names
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[deepseek] query='{query[:50]}...' brand={has_brand} "
|
||||
f"competitor={has_comp} time={elapsed_ms}ms"
|
||||
)
|
||||
|
||||
return AIQueryResult(
|
||||
engine_type=self.get_engine_type(),
|
||||
query=query,
|
||||
raw_response=content,
|
||||
citations=[],
|
||||
has_brand_citation=has_brand,
|
||||
has_competitor_citation=has_comp,
|
||||
brand_context=brand_ctx,
|
||||
competitor_contexts=comp_ctx,
|
||||
response_time_ms=elapsed_ms,
|
||||
timestamp=datetime.now(UTC),
|
||||
metadata={"model": data.get("model", self._model)},
|
||||
)
|
||||
|
|
@ -0,0 +1,137 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import httpx
|
||||
|
||||
from .base import AIEngineAdapter, AIQueryResult, EngineType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_MODEL = "gemini-pro"
|
||||
_DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
|
||||
_MAX_RETRIES = 3
|
||||
_RETRYABLE_STATUS = {429, 500, 502, 503}
|
||||
|
||||
|
||||
class GeminiAdapter(AIEngineAdapter):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
model: str | None = None,
|
||||
rate_limiter=None,
|
||||
proxy: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key or os.getenv("GOOGLE_API_KEY", ""),
|
||||
rate_limiter=rate_limiter,
|
||||
)
|
||||
self._model = model or os.getenv("GEMINI_MODEL", _DEFAULT_MODEL)
|
||||
self._base_url = os.getenv("GEMINI_BASE_URL", _DEFAULT_BASE_URL).rstrip("/")
|
||||
self._proxy = proxy or os.getenv("GEMINI_PROXY")
|
||||
self._endpoint = (
|
||||
f"{self._base_url}/models/{self._model}:generateContent"
|
||||
f"?key={self.api_key}"
|
||||
)
|
||||
client_kwargs = {
|
||||
"timeout": httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0),
|
||||
"headers": {"Content-Type": "application/json"},
|
||||
}
|
||||
if self._proxy:
|
||||
client_kwargs["proxy"] = self._proxy
|
||||
self._client = httpx.AsyncClient(**client_kwargs)
|
||||
|
||||
def get_engine_type(self) -> EngineType:
|
||||
return EngineType.GEMINI
|
||||
|
||||
async def _request_with_retry(self, payload: dict) -> dict:
|
||||
if self.rate_limiter:
|
||||
await self.rate_limiter.acquire()
|
||||
|
||||
last_error: Exception | None = None
|
||||
|
||||
for attempt in range(_MAX_RETRIES):
|
||||
try:
|
||||
response = await self._client.post(self._endpoint, json=payload)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
if response.status_code in _RETRYABLE_STATUS:
|
||||
retry_after = response.headers.get("retry-after")
|
||||
wait = float(retry_after) if retry_after else 2**attempt
|
||||
logger.warning(
|
||||
f"[gemini] HTTP {response.status_code}, "
|
||||
f"retry {attempt + 1}/{_MAX_RETRIES} in {wait:.1f}s"
|
||||
)
|
||||
last_error = Exception(
|
||||
f"HTTP {response.status_code}: {response.text[:300]}"
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
continue
|
||||
|
||||
raise RuntimeError(
|
||||
f"Gemini API 返回错误 {response.status_code}: {response.text[:300]}"
|
||||
)
|
||||
|
||||
except httpx.TransportError as exc:
|
||||
logger.warning(
|
||||
f"[gemini] Transport error: {exc}, "
|
||||
f"retry {attempt + 1}/{_MAX_RETRIES}"
|
||||
)
|
||||
last_error = RuntimeError(f"Gemini 网络错误: {exc}")
|
||||
await asyncio.sleep(2**attempt)
|
||||
continue
|
||||
|
||||
raise last_error or RuntimeError("Gemini API 超过最大重试次数")
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query: str,
|
||||
brand_name: str,
|
||||
competitor_names: list[str] | None = None,
|
||||
) -> AIQueryResult:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
payload = {
|
||||
"contents": [{"parts": [{"text": query}]}],
|
||||
"generationConfig": {"temperature": 0.7},
|
||||
}
|
||||
|
||||
data = await self._request_with_retry(payload)
|
||||
|
||||
candidates = data.get("candidates", [])
|
||||
if not candidates:
|
||||
raise RuntimeError("Gemini API 返回空候选内容")
|
||||
|
||||
parts = candidates[0].get("content", {}).get("parts", [])
|
||||
if not parts:
|
||||
raise RuntimeError("Gemini API 返回空内容")
|
||||
|
||||
content = parts[0].get("text", "")
|
||||
|
||||
elapsed_ms = int((time.perf_counter() - start_time) * 1000)
|
||||
has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
|
||||
content, brand_name, competitor_names
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[gemini] query='{query[:50]}...' brand={has_brand} "
|
||||
f"competitor={has_comp} time={elapsed_ms}ms"
|
||||
)
|
||||
|
||||
return AIQueryResult(
|
||||
engine_type=self.get_engine_type(),
|
||||
query=query,
|
||||
raw_response=content,
|
||||
citations=[],
|
||||
has_brand_citation=has_brand,
|
||||
has_competitor_citation=has_comp,
|
||||
brand_context=brand_ctx,
|
||||
competitor_contexts=comp_ctx,
|
||||
response_time_ms=elapsed_ms,
|
||||
timestamp=datetime.now(UTC),
|
||||
metadata={"model": self._model},
|
||||
)
|
||||
|
|
@ -20,10 +20,12 @@ class PerplexityAdapter(AIEngineAdapter):
|
|||
model: str | None = None,
|
||||
base_url: str | None = None,
|
||||
rate_limiter=None,
|
||||
proxy: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key or os.getenv("PERPLEXITY_API_KEY", ""),
|
||||
rate_limiter=rate_limiter,
|
||||
proxy=proxy or os.getenv("PERPLEXITY_PROXY"),
|
||||
)
|
||||
self._model = model or os.getenv("PERPLEXITY_MODEL", _DEFAULT_MODEL)
|
||||
self._base_url = (
|
||||
|
|
@ -31,7 +33,7 @@ class PerplexityAdapter(AIEngineAdapter):
|
|||
).rstrip("/")
|
||||
self._endpoint = f"{self._base_url}/chat/completions"
|
||||
self._client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0),
|
||||
**self._client_kwargs(),
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,91 @@
|
|||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import httpx
|
||||
|
||||
from .base import AIEngineAdapter, AIQueryResult, EngineType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_MODEL = "qwen-plus"
|
||||
_DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
|
||||
|
||||
class QwenAdapter(AIEngineAdapter):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
model: str | None = None,
|
||||
base_url: str | None = None,
|
||||
rate_limiter=None,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key or os.getenv("DASHSCOPE_API_KEY", ""),
|
||||
rate_limiter=rate_limiter,
|
||||
)
|
||||
self._model = model or os.getenv("QWEN_MODEL", _DEFAULT_MODEL)
|
||||
self._base_url = (
|
||||
base_url or os.getenv("QWEN_BASE_URL", _DEFAULT_BASE_URL)
|
||||
).rstrip("/")
|
||||
self._endpoint = f"{self._base_url}/chat/completions"
|
||||
self._client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0),
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
def get_engine_type(self) -> EngineType:
|
||||
return EngineType.QWEN
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query: str,
|
||||
brand_name: str,
|
||||
competitor_names: list[str] | None = None,
|
||||
) -> AIQueryResult:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一个专业的AI搜索助手。请基于你的知识,详细回答用户的问题。如果引用了外部来源,请在回答中标注来源URL或出处名称。",
|
||||
},
|
||||
{"role": "user", "content": query},
|
||||
]
|
||||
payload = {
|
||||
"model": self._model,
|
||||
"messages": messages,
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 2000,
|
||||
}
|
||||
|
||||
data = await self._request_with_retry(payload)
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
|
||||
elapsed_ms = int((time.perf_counter() - start_time) * 1000)
|
||||
has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
|
||||
content, brand_name, competitor_names
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[qwen] query='{query[:50]}...' brand={has_brand} "
|
||||
f"competitor={has_comp} time={elapsed_ms}ms"
|
||||
)
|
||||
|
||||
return AIQueryResult(
|
||||
engine_type=self.get_engine_type(),
|
||||
query=query,
|
||||
raw_response=content,
|
||||
citations=[],
|
||||
has_brand_citation=has_brand,
|
||||
has_competitor_citation=has_comp,
|
||||
brand_context=brand_ctx,
|
||||
competitor_contexts=comp_ctx,
|
||||
response_time_ms=elapsed_ms,
|
||||
timestamp=datetime.now(UTC),
|
||||
metadata={"model": data.get("model", self._model), "usage": data.get("usage")},
|
||||
)
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import httpx
|
||||
|
||||
from .base import AIEngineAdapter, AIQueryResult, EngineType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_MODEL = "hunyuan-lite"
|
||||
_DEFAULT_BASE_URL = "https://api.hunyuan.cloud.tencent.com/v1"
|
||||
|
||||
|
||||
class YuanbaoAdapter(AIEngineAdapter):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
model: str | None = None,
|
||||
base_url: str | None = None,
|
||||
rate_limiter=None,
|
||||
):
|
||||
super().__init__(
|
||||
api_key=api_key or os.getenv("HUNYUAN_API_KEY", ""),
|
||||
rate_limiter=rate_limiter,
|
||||
)
|
||||
self._model = model or os.getenv("HUNYUAN_MODEL", _DEFAULT_MODEL)
|
||||
self._base_url = (
|
||||
base_url or os.getenv("HUNYUAN_BASE_URL", _DEFAULT_BASE_URL)
|
||||
).rstrip("/")
|
||||
self._endpoint = f"{self._base_url}/chat/completions"
|
||||
self._client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0),
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
def get_engine_type(self) -> EngineType:
|
||||
return EngineType.YUANBAO
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query: str,
|
||||
brand_name: str,
|
||||
competitor_names: list[str] | None = None,
|
||||
) -> AIQueryResult:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一个专业的AI搜索助手。请基于你的知识,详细回答用户的问题。如果引用了外部来源,请在回答中标注来源URL或出处名称。",
|
||||
},
|
||||
{"role": "user", "content": query},
|
||||
]
|
||||
payload = {
|
||||
"model": self._model,
|
||||
"messages": messages,
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 2000,
|
||||
}
|
||||
|
||||
data = await self._request_with_retry(payload)
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
|
||||
elapsed_ms = int((time.perf_counter() - start_time) * 1000)
|
||||
has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
|
||||
content, brand_name, competitor_names
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[yuanbao] query='{query[:50]}...' brand={has_brand} "
|
||||
f"competitor={has_comp} time={elapsed_ms}ms"
|
||||
)
|
||||
|
||||
return AIQueryResult(
|
||||
engine_type=self.get_engine_type(),
|
||||
query=query,
|
||||
raw_response=content,
|
||||
citations=[],
|
||||
has_brand_citation=has_brand,
|
||||
has_competitor_citation=has_comp,
|
||||
brand_context=brand_ctx,
|
||||
competitor_contexts=comp_ctx,
|
||||
response_time_ms=elapsed_ms,
|
||||
timestamp=datetime.now(UTC),
|
||||
metadata={"model": data.get("model", self._model), "usage": data.get("usage")},
|
||||
)
|
||||
|
|
@ -10,6 +10,7 @@ from app.services.ai_engine.base import (
|
|||
from app.services.ai_engine.kimi import KimiAdapter
|
||||
from app.services.ai_engine.wenxin import WenxinAdapter
|
||||
from app.services.ai_engine.doubao import DoubaoAdapter
|
||||
from app.services.ai_engine.yuanbao import YuanbaoAdapter
|
||||
|
||||
|
||||
def _make_mock_response(status_code=200, json_data=None, text="", headers=None):
|
||||
|
|
@ -202,6 +203,11 @@ class TestEngineType:
|
|||
assert adapter.get_engine_type() == EngineType.DOUBAO
|
||||
assert adapter.get_engine_type().value == "doubao"
|
||||
|
||||
def test_yuanbao_engine_type(self):
|
||||
adapter = YuanbaoAdapter(api_key="test-key")
|
||||
assert adapter.get_engine_type() == EngineType.YUANBAO
|
||||
assert adapter.get_engine_type().value == "yuanbao"
|
||||
|
||||
|
||||
class TestChineseCitationDetection:
|
||||
def test_brand_name_detection_chinese(self):
|
||||
|
|
@ -254,17 +260,18 @@ class TestAdapterInheritance:
|
|||
assert issubclass(KimiAdapter, AIEngineAdapter)
|
||||
assert issubclass(WenxinAdapter, AIEngineAdapter)
|
||||
assert issubclass(DoubaoAdapter, AIEngineAdapter)
|
||||
assert issubclass(YuanbaoAdapter, AIEngineAdapter)
|
||||
|
||||
def test_all_adapters_have_query_method(self):
|
||||
for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter]:
|
||||
for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter, YuanbaoAdapter]:
|
||||
assert hasattr(cls, "query")
|
||||
assert callable(getattr(cls, "query"))
|
||||
|
||||
def test_all_adapters_have_detect_citations(self):
|
||||
for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter]:
|
||||
for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter, YuanbaoAdapter]:
|
||||
assert hasattr(cls, "_detect_citations")
|
||||
|
||||
def test_all_adapters_have_get_engine_type(self):
|
||||
for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter]:
|
||||
for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter, YuanbaoAdapter]:
|
||||
instance = cls(api_key="test-key") if cls != WenxinAdapter else cls(api_key="test-key", secret_key="s")
|
||||
assert instance.get_engine_type() in EngineType
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ class TestEngineType:
|
|||
assert EngineType.DOUBAO == "doubao"
|
||||
assert EngineType.DEEPSEEK == "deepseek"
|
||||
assert EngineType.QWEN == "qwen"
|
||||
assert EngineType.YUANBAO == "yuanbao"
|
||||
|
||||
|
||||
class TestCitationInfo:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,285 @@
|
|||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType
|
||||
from app.services.ai_engine.chatgpt import ChatGPTAdapter
|
||||
from app.services.ai_engine.deepseek import DeepSeekAdapter
|
||||
from app.services.ai_engine.perplexity import PerplexityAdapter
|
||||
|
||||
|
||||
class _ConcreteAdapter(AIEngineAdapter):
|
||||
async def query(self, query, brand_name, competitor_names=None):
|
||||
pass
|
||||
|
||||
def get_engine_type(self):
|
||||
return EngineType.CHATGPT
|
||||
|
||||
|
||||
class TestProxySupportBase:
|
||||
def test_base_init_with_proxy(self):
|
||||
adapter = _ConcreteAdapter(api_key="test-key", proxy="http://proxy:8080")
|
||||
assert adapter.proxy == "http://proxy:8080"
|
||||
|
||||
def test_base_init_proxy_default_none(self):
|
||||
adapter = _ConcreteAdapter(api_key="test-key")
|
||||
assert adapter.proxy is None
|
||||
|
||||
def test_base_init_proxy_from_https_proxy_env(self):
|
||||
with patch.dict(os.environ, {"HTTPS_PROXY": "http://env-proxy:3128"}):
|
||||
adapter = _ConcreteAdapter(api_key="test-key")
|
||||
assert adapter.proxy == "http://env-proxy:3128"
|
||||
|
||||
def test_base_init_proxy_from_https_proxy_lowercase_env(self):
|
||||
with patch.dict(os.environ, {"https_proxy": "http://lower-proxy:3128"}, clear=False):
|
||||
if "HTTPS_PROXY" in os.environ:
|
||||
del os.environ["HTTPS_PROXY"]
|
||||
adapter = _ConcreteAdapter(api_key="test-key")
|
||||
assert adapter.proxy == "http://lower-proxy:3128"
|
||||
|
||||
def test_base_init_explicit_proxy_overrides_env(self):
|
||||
with patch.dict(os.environ, {"HTTPS_PROXY": "http://env-proxy:3128"}):
|
||||
adapter = _ConcreteAdapter(
|
||||
api_key="test-key", proxy="http://explicit-proxy:8080"
|
||||
)
|
||||
assert adapter.proxy == "http://explicit-proxy:8080"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_client_with_proxy(self):
|
||||
adapter = _ConcreteAdapter(api_key="test-key", proxy="http://proxy:8080")
|
||||
client = await adapter._get_client()
|
||||
assert isinstance(client, httpx.AsyncClient)
|
||||
await adapter.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_client_without_proxy(self):
|
||||
adapter = _ConcreteAdapter(api_key="test-key")
|
||||
client = await adapter._get_client()
|
||||
assert isinstance(client, httpx.AsyncClient)
|
||||
await adapter.close()
|
||||
|
||||
|
||||
class TestChatGPTProxy:
|
||||
def test_chatgpt_proxy_parameter(self):
|
||||
adapter = ChatGPTAdapter(api_key="test-key", proxy="http://proxy:8080")
|
||||
assert adapter.proxy == "http://proxy:8080"
|
||||
|
||||
def test_chatgpt_proxy_from_openai_proxy_env(self):
|
||||
with patch.dict(os.environ, {"OPENAI_PROXY": "http://openai-proxy:8080"}):
|
||||
adapter = ChatGPTAdapter(api_key="test-key")
|
||||
assert adapter.proxy == "http://openai-proxy:8080"
|
||||
|
||||
def test_chatgpt_explicit_proxy_overrides_env(self):
|
||||
with patch.dict(os.environ, {"OPENAI_PROXY": "http://env-proxy:8080"}):
|
||||
adapter = ChatGPTAdapter(
|
||||
api_key="test-key", proxy="http://explicit:9090"
|
||||
)
|
||||
assert adapter.proxy == "http://explicit:9090"
|
||||
|
||||
def test_chatgpt_fallback_to_https_proxy_env(self):
|
||||
with patch.dict(os.environ, {"HTTPS_PROXY": "http://fallback:3128"}, clear=False):
|
||||
if "OPENAI_PROXY" in os.environ:
|
||||
del os.environ["OPENAI_PROXY"]
|
||||
adapter = ChatGPTAdapter(api_key="test-key")
|
||||
assert adapter.proxy == "http://fallback:3128"
|
||||
|
||||
def test_chatgpt_no_proxy(self):
|
||||
with patch.dict(os.environ, {}, clear=False):
|
||||
for key in ("OPENAI_PROXY", "HTTPS_PROXY", "https_proxy"):
|
||||
os.environ.pop(key, None)
|
||||
adapter = ChatGPTAdapter(api_key="test-key")
|
||||
assert adapter.proxy is None
|
||||
|
||||
|
||||
class TestPerplexityProxy:
|
||||
def test_perplexity_proxy_parameter(self):
|
||||
adapter = PerplexityAdapter(api_key="test-key", proxy="http://proxy:8080")
|
||||
assert adapter.proxy == "http://proxy:8080"
|
||||
|
||||
def test_perplexity_proxy_from_perplexity_proxy_env(self):
|
||||
with patch.dict(os.environ, {"PERPLEXITY_PROXY": "http://pplx-proxy:8080"}):
|
||||
adapter = PerplexityAdapter(api_key="test-key")
|
||||
assert adapter.proxy == "http://pplx-proxy:8080"
|
||||
|
||||
def test_perplexity_fallback_to_https_proxy_env(self):
|
||||
with patch.dict(os.environ, {"HTTPS_PROXY": "http://fallback:3128"}, clear=False):
|
||||
if "PERPLEXITY_PROXY" in os.environ:
|
||||
del os.environ["PERPLEXITY_PROXY"]
|
||||
adapter = PerplexityAdapter(api_key="test-key")
|
||||
assert adapter.proxy == "http://fallback:3128"
|
||||
|
||||
|
||||
class TestDeepSeekAdapter:
|
||||
def test_deepseek_init_default(self):
|
||||
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "ds-test-key"}):
|
||||
adapter = DeepSeekAdapter()
|
||||
assert adapter.api_key == "ds-test-key"
|
||||
assert adapter.get_engine_type() == EngineType.DEEPSEEK
|
||||
|
||||
def test_deepseek_init_with_params(self):
|
||||
adapter = DeepSeekAdapter(
|
||||
api_key="custom-key",
|
||||
model="deepseek-reasoner",
|
||||
base_url="https://custom.api.com/v1",
|
||||
)
|
||||
assert adapter.api_key == "custom-key"
|
||||
assert adapter._model == "deepseek-reasoner"
|
||||
assert adapter._base_url == "https://custom.api.com/v1"
|
||||
|
||||
def test_deepseek_init_from_env(self):
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"DEEPSEEK_API_KEY": "env-key",
|
||||
"DEEPSEEK_MODEL": "deepseek-coder",
|
||||
"DEEPSEEK_BASE_URL": "https://env.api.com/v1",
|
||||
},
|
||||
):
|
||||
adapter = DeepSeekAdapter()
|
||||
assert adapter.api_key == "env-key"
|
||||
assert adapter._model == "deepseek-coder"
|
||||
assert adapter._base_url == "https://env.api.com/v1"
|
||||
|
||||
def test_deepseek_default_model(self):
|
||||
adapter = DeepSeekAdapter(api_key="test-key")
|
||||
assert adapter._model == "deepseek-chat"
|
||||
|
||||
def test_deepseek_default_base_url(self):
|
||||
adapter = DeepSeekAdapter(api_key="test-key")
|
||||
assert adapter._base_url == "https://api.deepseek.com/v1"
|
||||
|
||||
def test_deepseek_endpoint(self):
|
||||
adapter = DeepSeekAdapter(api_key="test-key")
|
||||
assert adapter._endpoint == "https://api.deepseek.com/v1/chat/completions"
|
||||
|
||||
def test_deepseek_engine_type(self):
|
||||
adapter = DeepSeekAdapter(api_key="test-key")
|
||||
assert adapter.get_engine_type() == EngineType.DEEPSEEK
|
||||
|
||||
def test_deepseek_proxy_parameter(self):
|
||||
adapter = DeepSeekAdapter(api_key="test-key", proxy="http://proxy:8080")
|
||||
assert adapter.proxy == "http://proxy:8080"
|
||||
|
||||
def test_deepseek_proxy_from_deepseek_proxy_env(self):
|
||||
with patch.dict(os.environ, {"DEEPSEEK_PROXY": "http://ds-proxy:8080"}):
|
||||
adapter = DeepSeekAdapter(api_key="test-key")
|
||||
assert adapter.proxy == "http://ds-proxy:8080"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deepseek_query_success(self):
|
||||
adapter = DeepSeekAdapter(api_key="test-key")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [
|
||||
{"message": {"content": "BrandX is a leading AI company with great models."}}
|
||||
],
|
||||
"model": "deepseek-chat",
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
|
||||
}
|
||||
|
||||
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
result = await adapter.query(
|
||||
query="best AI companies",
|
||||
brand_name="BrandX",
|
||||
competitor_names=["CompetitorY"],
|
||||
)
|
||||
|
||||
assert isinstance(result, AIQueryResult)
|
||||
assert result.engine_type == EngineType.DEEPSEEK
|
||||
assert result.query == "best AI companies"
|
||||
assert "BrandX" in result.raw_response
|
||||
assert result.has_brand_citation is True
|
||||
assert result.has_competitor_citation is False
|
||||
assert result.response_time_ms >= 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deepseek_query_with_competitor(self):
|
||||
adapter = DeepSeekAdapter(api_key="test-key")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [
|
||||
{"message": {"content": "BrandX and CompetitorY are both strong AI companies."}}
|
||||
],
|
||||
"model": "deepseek-chat",
|
||||
}
|
||||
|
||||
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
result = await adapter.query(
|
||||
query="AI comparison",
|
||||
brand_name="BrandX",
|
||||
competitor_names=["CompetitorY"],
|
||||
)
|
||||
|
||||
assert result.has_brand_citation is True
|
||||
assert result.has_competitor_citation is True
|
||||
assert len(result.competitor_contexts) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deepseek_query_with_rate_limiter(self):
|
||||
mock_limiter = AsyncMock()
|
||||
adapter = DeepSeekAdapter(api_key="test-key", rate_limiter=mock_limiter)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [{"message": {"content": "Some response"}}],
|
||||
"model": "deepseek-chat",
|
||||
}
|
||||
|
||||
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
await adapter.query(query="test", brand_name="BrandX")
|
||||
|
||||
mock_limiter.acquire.assert_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deepseek_query_api_error(self):
|
||||
adapter = DeepSeekAdapter(api_key="test-key")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.text = "Unauthorized"
|
||||
|
||||
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await adapter.query(query="test", brand_name="BrandX")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deepseek_query_transport_error(self):
|
||||
adapter = DeepSeekAdapter(api_key="test-key")
|
||||
|
||||
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.side_effect = httpx.TransportError("Connection failed")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await adapter.query(query="test", brand_name="BrandX")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deepseek_query_timeout(self):
|
||||
adapter = DeepSeekAdapter(api_key="test-key")
|
||||
|
||||
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.side_effect = httpx.TimeoutException("Request timed out")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await adapter.query(query="test", brand_name="BrandX")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deepseek_close(self):
|
||||
adapter = DeepSeekAdapter(api_key="test-key")
|
||||
await adapter.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deepseek_context_manager(self):
|
||||
async with DeepSeekAdapter(api_key="test-key") as adapter:
|
||||
assert adapter.api_key == "test-key"
|
||||
|
|
@ -0,0 +1,232 @@
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType
|
||||
from app.services.ai_engine.gemini import GeminiAdapter
|
||||
from app.services.ai_engine.qwen import QwenAdapter
|
||||
|
||||
|
||||
class TestQwenAdapter:
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialization(self):
|
||||
adapter = QwenAdapter(api_key="test-dashscope-key")
|
||||
assert adapter.api_key == "test-dashscope-key"
|
||||
assert adapter._model == "qwen-plus"
|
||||
assert adapter._base_url == "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialization_with_custom_params(self):
|
||||
adapter = QwenAdapter(
|
||||
api_key="custom-key",
|
||||
model="qwen-max",
|
||||
base_url="https://custom-url.com/v1",
|
||||
)
|
||||
assert adapter.api_key == "custom-key"
|
||||
assert adapter._model == "qwen-max"
|
||||
assert adapter._base_url == "https://custom-url.com/v1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_returns_ai_query_result(self):
|
||||
adapter = QwenAdapter(api_key="test-key")
|
||||
mock_response_data = {
|
||||
"choices": [{"message": {"content": "华为是全球领先的ICT基础设施和智能终端提供商"}}],
|
||||
"model": "qwen-plus",
|
||||
"usage": {"total_tokens": 100},
|
||||
}
|
||||
|
||||
with patch.object(adapter, "_request_with_retry", return_value=mock_response_data):
|
||||
result = await adapter.query("华为公司", brand_name="华为")
|
||||
|
||||
assert isinstance(result, AIQueryResult)
|
||||
assert result.engine_type == EngineType.QWEN
|
||||
assert "华为" in result.raw_response
|
||||
assert result.has_brand_citation is True
|
||||
assert result.metadata.get("model") == "qwen-plus"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_engine_type(self):
|
||||
adapter = QwenAdapter(api_key="test-key")
|
||||
assert adapter.get_engine_type() == EngineType.QWEN
|
||||
assert adapter.get_engine_type().value == "qwen"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling(self):
|
||||
adapter = QwenAdapter(api_key="test-key")
|
||||
|
||||
with patch.object(
|
||||
adapter,
|
||||
"_request_with_retry",
|
||||
side_effect=Exception("HTTP 500: Internal Server Error"),
|
||||
):
|
||||
with pytest.raises(Exception, match="HTTP 500"):
|
||||
await adapter.query("测试问题", brand_name="华为")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chinese_brand_citation_detection(self):
|
||||
adapter = QwenAdapter(api_key="test-key")
|
||||
mock_response_data = {
|
||||
"choices": [{"message": {"content": "华为和小米都是中国知名的科技企业"}}],
|
||||
"model": "qwen-plus",
|
||||
}
|
||||
|
||||
with patch.object(adapter, "_request_with_retry", return_value=mock_response_data):
|
||||
result = await adapter.query(
|
||||
"科技公司", brand_name="华为", competitor_names=["小米"]
|
||||
)
|
||||
|
||||
assert result.has_brand_citation is True
|
||||
assert result.has_competitor_citation is True
|
||||
assert result.brand_context is not None
|
||||
assert "华为" in result.brand_context
|
||||
assert len(result.competitor_contexts) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limiter_called(self):
|
||||
mock_limiter = AsyncMock()
|
||||
adapter = QwenAdapter(api_key="test-key", rate_limiter=mock_limiter)
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [{"message": {"content": "测试回复"}}],
|
||||
"model": "qwen-plus",
|
||||
}
|
||||
|
||||
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
await adapter.query("测试", brand_name="华为")
|
||||
|
||||
mock_limiter.acquire.assert_awaited()
|
||||
|
||||
|
||||
class TestGeminiAdapter:
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialization(self):
|
||||
adapter = GeminiAdapter(api_key="test-google-key")
|
||||
assert adapter.api_key == "test-google-key"
|
||||
assert adapter._model == "gemini-pro"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialization_with_custom_params(self):
|
||||
adapter = GeminiAdapter(
|
||||
api_key="custom-key",
|
||||
model="gemini-1.5-pro",
|
||||
)
|
||||
assert adapter.api_key == "custom-key"
|
||||
assert adapter._model == "gemini-1.5-pro"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_returns_ai_query_result(self):
|
||||
adapter = GeminiAdapter(api_key="test-key")
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"candidates": [{"content": {"parts": [{"text": "Google is a leading technology company."}]}}],
|
||||
}
|
||||
|
||||
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
result = await adapter.query("best tech companies", brand_name="Google")
|
||||
|
||||
assert isinstance(result, AIQueryResult)
|
||||
assert result.engine_type == EngineType.GEMINI
|
||||
assert "Google" in result.raw_response
|
||||
assert result.has_brand_citation is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_engine_type(self):
|
||||
adapter = GeminiAdapter(api_key="test-key")
|
||||
assert adapter.get_engine_type() == EngineType.GEMINI
|
||||
assert adapter.get_engine_type().value == "gemini"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling(self):
|
||||
adapter = GeminiAdapter(api_key="test-key")
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 400
|
||||
mock_response.text = "Bad Request"
|
||||
|
||||
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
with pytest.raises(RuntimeError, match="Gemini"):
|
||||
await adapter.query("test query", brand_name="Google")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_english_brand_citation_detection(self):
|
||||
adapter = GeminiAdapter(api_key="test-key")
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"candidates": [{"content": {"parts": [{"text": "Google and Microsoft are major tech companies."}]}}],
|
||||
}
|
||||
|
||||
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
result = await adapter.query(
|
||||
"tech companies", brand_name="Google", competitor_names=["Microsoft"]
|
||||
)
|
||||
|
||||
assert result.has_brand_citation is True
|
||||
assert result.has_competitor_citation is True
|
||||
assert result.brand_context is not None
|
||||
assert "Google" in result.brand_context
|
||||
assert len(result.competitor_contexts) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limiter_called(self):
|
||||
mock_limiter = AsyncMock()
|
||||
adapter = GeminiAdapter(api_key="test-key", rate_limiter=mock_limiter)
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"candidates": [{"content": {"parts": [{"text": "Test response"}]}}],
|
||||
}
|
||||
|
||||
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
await adapter.query("test", brand_name="Google")
|
||||
|
||||
mock_limiter.acquire.assert_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_in_url(self):
|
||||
adapter = GeminiAdapter(api_key="my-secret-key")
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"candidates": [{"content": {"parts": [{"text": "response"}]}}],
|
||||
}
|
||||
|
||||
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
await adapter.query("test", brand_name="BrandX")
|
||||
|
||||
call_args = mock_post.call_args
|
||||
assert "key=my-secret-key" in str(call_args)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_support(self):
|
||||
adapter = GeminiAdapter(api_key="test-key", proxy="http://proxy:8080")
|
||||
assert adapter._proxy == "http://proxy:8080"
|
||||
|
||||
|
||||
class TestAdapterInheritance:
|
||||
def test_qwen_inherits_base(self):
|
||||
assert issubclass(QwenAdapter, AIEngineAdapter)
|
||||
|
||||
def test_gemini_inherits_base(self):
|
||||
assert issubclass(GeminiAdapter, AIEngineAdapter)
|
||||
|
||||
def test_qwen_has_query_method(self):
|
||||
assert hasattr(QwenAdapter, "query")
|
||||
assert callable(getattr(QwenAdapter, "query"))
|
||||
|
||||
def test_gemini_has_query_method(self):
|
||||
assert hasattr(GeminiAdapter, "query")
|
||||
assert callable(getattr(GeminiAdapter, "query"))
|
||||
|
||||
def test_qwen_has_detect_citations(self):
|
||||
assert hasattr(QwenAdapter, "_detect_citations")
|
||||
|
||||
def test_gemini_has_detect_citations(self):
|
||||
assert hasattr(GeminiAdapter, "_detect_citations")
|
||||
|
|
@ -0,0 +1,227 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from app.services.ai_engine.base import (
|
||||
AIEngineAdapter,
|
||||
AIQueryResult,
|
||||
EngineType,
|
||||
)
|
||||
from app.services.ai_engine.yuanbao import YuanbaoAdapter
|
||||
|
||||
|
||||
class TestYuanbaoAdapterInitialization:
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialization_with_api_key(self):
|
||||
adapter = YuanbaoAdapter(api_key="test-hunyuan-key")
|
||||
assert adapter.api_key == "test-hunyuan-key"
|
||||
assert adapter._model == "hunyuan-lite"
|
||||
assert adapter._base_url == "https://api.hunyuan.cloud.tencent.com/v1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialization_with_custom_model(self):
|
||||
adapter = YuanbaoAdapter(api_key="test-key", model="hunyuan-pro")
|
||||
assert adapter._model == "hunyuan-pro"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialization_with_custom_base_url(self):
|
||||
adapter = YuanbaoAdapter(
|
||||
api_key="test-key",
|
||||
base_url="https://custom-api.example.com/v1",
|
||||
)
|
||||
assert adapter._base_url == "https://custom-api.example.com/v1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialization_with_env_vars(self):
|
||||
with patch.dict("os.environ", {
|
||||
"HUNYUAN_API_KEY": "env-key",
|
||||
"HUNYUAN_MODEL": "hunyuan-turbo",
|
||||
"HUNYUAN_BASE_URL": "https://env-url.example.com/v1",
|
||||
}):
|
||||
adapter = YuanbaoAdapter()
|
||||
assert adapter.api_key == "env-key"
|
||||
assert adapter._model == "hunyuan-turbo"
|
||||
assert adapter._base_url == "https://env-url.example.com/v1"
|
||||
|
||||
|
||||
class TestYuanbaoAdapterEngineType:
|
||||
def test_get_engine_type_returns_yuanbao(self):
|
||||
adapter = YuanbaoAdapter(api_key="test-key")
|
||||
assert adapter.get_engine_type() == EngineType.YUANBAO
|
||||
|
||||
def test_engine_type_value(self):
|
||||
assert EngineType.YUANBAO == "yuanbao"
|
||||
|
||||
def test_engine_type_in_enum(self):
|
||||
assert EngineType.YUANBAO in EngineType
|
||||
|
||||
|
||||
class TestYuanbaoAdapterQuery:
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_returns_ai_query_result(self):
|
||||
adapter = YuanbaoAdapter(api_key="test-key")
|
||||
mock_response_data = {
|
||||
"choices": [{"message": {"content": "华为是全球领先的ICT基础设施和智能终端提供商"}}],
|
||||
"usage": {"total_tokens": 100},
|
||||
"model": "hunyuan-lite",
|
||||
}
|
||||
|
||||
with patch.object(adapter, "_request_with_retry", return_value=mock_response_data):
|
||||
result = await adapter.query("华为公司", brand_name="华为")
|
||||
|
||||
assert isinstance(result, AIQueryResult)
|
||||
assert result.engine_type == EngineType.YUANBAO
|
||||
assert "华为" in result.raw_response
|
||||
assert result.has_brand_citation is True
|
||||
assert result.metadata.get("model") == "hunyuan-lite"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_with_competitor_detection(self):
|
||||
adapter = YuanbaoAdapter(api_key="test-key")
|
||||
mock_response_data = {
|
||||
"choices": [{"message": {"content": "华为和小米都是中国知名手机品牌"}}],
|
||||
"model": "hunyuan-lite",
|
||||
}
|
||||
|
||||
with patch.object(adapter, "_request_with_retry", return_value=mock_response_data):
|
||||
result = await adapter.query(
|
||||
"手机品牌",
|
||||
brand_name="华为",
|
||||
competitor_names=["小米"],
|
||||
)
|
||||
|
||||
assert result.has_brand_citation is True
|
||||
assert result.has_competitor_citation is True
|
||||
assert len(result.competitor_contexts) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_no_brand_found(self):
|
||||
adapter = YuanbaoAdapter(api_key="test-key")
|
||||
mock_response_data = {
|
||||
"choices": [{"message": {"content": "今天天气很好"}}],
|
||||
"model": "hunyuan-lite",
|
||||
}
|
||||
|
||||
with patch.object(adapter, "_request_with_retry", return_value=mock_response_data):
|
||||
result = await adapter.query("天气", brand_name="华为")
|
||||
|
||||
assert result.has_brand_citation is False
|
||||
assert result.has_competitor_citation is False
|
||||
assert result.brand_context is None
|
||||
assert result.competitor_contexts == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_records_response_time(self):
|
||||
adapter = YuanbaoAdapter(api_key="test-key")
|
||||
mock_response_data = {
|
||||
"choices": [{"message": {"content": "测试回复"}}],
|
||||
"model": "hunyuan-lite",
|
||||
}
|
||||
|
||||
with patch.object(adapter, "_request_with_retry", return_value=mock_response_data):
|
||||
result = await adapter.query("测试", brand_name="华为")
|
||||
|
||||
assert result.response_time_ms >= 0
|
||||
assert result.timestamp is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_with_rate_limiter(self):
|
||||
mock_limiter = AsyncMock()
|
||||
adapter = YuanbaoAdapter(api_key="test-key", rate_limiter=mock_limiter)
|
||||
mock_response_data = {
|
||||
"choices": [{"message": {"content": "测试回复"}}],
|
||||
"model": "hunyuan-lite",
|
||||
}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_response_data
|
||||
|
||||
with patch.object(adapter._client, "post", new_callable=AsyncMock, return_value=mock_response):
|
||||
await adapter.query("测试", brand_name="华为")
|
||||
|
||||
mock_limiter.acquire.assert_awaited()
|
||||
|
||||
|
||||
class TestYuanbaoAdapterErrorHandling:
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_error_handling(self):
|
||||
adapter = YuanbaoAdapter(api_key="test-key")
|
||||
|
||||
with patch.object(
|
||||
adapter,
|
||||
"_request_with_retry",
|
||||
side_effect=Exception("HTTP 500: Internal Server Error"),
|
||||
):
|
||||
with pytest.raises(Exception, match="HTTP 500"):
|
||||
await adapter.query("测试问题", brand_name="华为")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_handling(self):
|
||||
adapter = YuanbaoAdapter(api_key="test-key")
|
||||
|
||||
with patch.object(
|
||||
adapter,
|
||||
"_request_with_retry",
|
||||
side_effect=Exception("HTTP 429: rate limited"),
|
||||
):
|
||||
with pytest.raises(Exception, match="429"):
|
||||
await adapter.query("测试问题", brand_name="华为")
|
||||
|
||||
|
||||
class TestYuanbaoAdapterChineseBrandDetection:
|
||||
def test_chinese_brand_name_detection(self):
|
||||
adapter = YuanbaoAdapter(api_key="test-key")
|
||||
has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
|
||||
"华为是全球领先的ICT基础设施和智能终端提供商",
|
||||
brand_name="华为",
|
||||
competitor_names=None,
|
||||
)
|
||||
assert has_brand is True
|
||||
assert brand_ctx is not None
|
||||
assert "华为" in brand_ctx
|
||||
|
||||
def test_chinese_competitor_detection(self):
|
||||
adapter = YuanbaoAdapter(api_key="test-key")
|
||||
has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
|
||||
"华为和小米都是中国知名手机品牌",
|
||||
brand_name="华为",
|
||||
competitor_names=["小米"],
|
||||
)
|
||||
assert has_brand is True
|
||||
assert has_comp is True
|
||||
assert brand_ctx is not None
|
||||
assert len(comp_ctx) > 0
|
||||
|
||||
def test_no_match_chinese(self):
|
||||
adapter = YuanbaoAdapter(api_key="test-key")
|
||||
has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
|
||||
"今天天气很好",
|
||||
brand_name="华为",
|
||||
competitor_names=["小米"],
|
||||
)
|
||||
assert has_brand is False
|
||||
assert has_comp is False
|
||||
assert brand_ctx is None
|
||||
assert len(comp_ctx) == 0
|
||||
|
||||
|
||||
class TestYuanbaoAdapterInheritance:
|
||||
def test_inherits_from_base(self):
|
||||
assert issubclass(YuanbaoAdapter, AIEngineAdapter)
|
||||
|
||||
def test_has_query_method(self):
|
||||
assert hasattr(YuanbaoAdapter, "query")
|
||||
assert callable(getattr(YuanbaoAdapter, "query"))
|
||||
|
||||
def test_has_detect_citations(self):
|
||||
assert hasattr(YuanbaoAdapter, "_detect_citations")
|
||||
|
||||
def test_has_get_engine_type(self):
|
||||
adapter = YuanbaoAdapter(api_key="test-key")
|
||||
assert adapter.get_engine_type() in EngineType
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_manager(self):
|
||||
adapter = YuanbaoAdapter(api_key="test-key")
|
||||
async with adapter as a:
|
||||
assert a is adapter
|
||||
|
|
@ -179,8 +179,11 @@ function EngineCheckboxGroup({
|
|||
onToggle: (engine: AIEngineType) => void;
|
||||
}) {
|
||||
return (
|
||||
<div className="space-y-3">
|
||||
<div>
|
||||
<p className="mb-2 text-xs font-medium text-muted-foreground">国际引擎</p>
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{AI_ENGINE_OPTIONS.map((opt) => {
|
||||
{AI_ENGINE_OPTIONS.filter((o) => o.group === "international").map((opt) => {
|
||||
const isSelected = selected.includes(opt.value);
|
||||
return (
|
||||
<button
|
||||
|
|
@ -199,6 +202,31 @@ function EngineCheckboxGroup({
|
|||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<p className="mb-2 text-xs font-medium text-muted-foreground">国内引擎</p>
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{AI_ENGINE_OPTIONS.filter((o) => o.group === "domestic").map((opt) => {
|
||||
const isSelected = selected.includes(opt.value);
|
||||
return (
|
||||
<button
|
||||
key={opt.value}
|
||||
type="button"
|
||||
onClick={() => onToggle(opt.value)}
|
||||
className={`inline-flex items-center gap-1.5 rounded-md border px-3 py-1.5 text-sm font-medium transition-colors ${
|
||||
isSelected
|
||||
? "border-primary bg-primary/10 text-primary"
|
||||
: "border-input bg-background text-muted-foreground hover:bg-muted"
|
||||
}`}
|
||||
>
|
||||
{isSelected && <CheckCircle className="h-3.5 w-3.5" />}
|
||||
{opt.label}
|
||||
</button>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,16 +1,30 @@
|
|||
export type AIEngineType = "chatgpt" | "perplexity" | "kimi" | "wenxin" | "doubao";
|
||||
export type AIEngineType =
|
||||
| "chatgpt"
|
||||
| "perplexity"
|
||||
| "kimi"
|
||||
| "wenxin"
|
||||
| "doubao"
|
||||
| "deepseek"
|
||||
| "qwen"
|
||||
| "gemini"
|
||||
| "yuanbao";
|
||||
|
||||
export interface AIEngineOption {
|
||||
value: AIEngineType;
|
||||
label: string;
|
||||
group: "domestic" | "international";
|
||||
}
|
||||
|
||||
export const AI_ENGINE_OPTIONS: AIEngineOption[] = [
|
||||
{ value: "chatgpt", label: "ChatGPT" },
|
||||
{ value: "perplexity", label: "Perplexity" },
|
||||
{ value: "kimi", label: "Kimi" },
|
||||
{ value: "wenxin", label: "文心一言" },
|
||||
{ value: "doubao", label: "豆包" },
|
||||
{ value: "chatgpt", label: "ChatGPT", group: "international" },
|
||||
{ value: "perplexity", label: "Perplexity", group: "international" },
|
||||
{ value: "gemini", label: "Google Gemini", group: "international" },
|
||||
{ value: "kimi", label: "Kimi", group: "domestic" },
|
||||
{ value: "wenxin", label: "文心一言", group: "domestic" },
|
||||
{ value: "doubao", label: "豆包", group: "domestic" },
|
||||
{ value: "deepseek", label: "DeepSeek", group: "domestic" },
|
||||
{ value: "qwen", label: "通义千问", group: "domestic" },
|
||||
{ value: "yuanbao", label: "腾讯元宝", group: "domestic" },
|
||||
];
|
||||
|
||||
export interface AIQueryResult {
|
||||
|
|
|
|||
Loading…
Reference in New Issue