feat: 补齐AI引擎适配器 - 9引擎全覆盖+代理支持

后端(TDD):
- 基类添加proxy支持(构造函数>引擎专属环境变量>HTTPS_PROXY)
- ChatGPT/Perplexity适配器添加proxy参数
- 新增DeepSeek适配器(国内,OpenAI兼容)
- 新增通义千问适配器(国内,DashScope API)
- 新增Google Gemini适配器(国外,Google专有API,支持proxy)
- 新增腾讯元宝适配器(国内,OpenAI兼容)
- EngineType枚举新增GEMINI/YUANBAO
- 76个测试全部通过

前端:
- AI引擎选项扩展为9个(3国际+6国内)
- 引擎选择按国内外分组显示
- 类型定义更新(AIEngineOption.group)
This commit is contained in:
chiguyong 2026-05-25 12:16:16 +08:00
parent 9d67a801be
commit af3a184c0b
16 changed files with 1275 additions and 40 deletions

View File

@ -9,10 +9,11 @@ from app.models.user import User
from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType
from app.services.ai_engine.batch_query import BatchQueryService from app.services.ai_engine.batch_query import BatchQueryService
from app.services.ai_engine.chatgpt import ChatGPTAdapter from app.services.ai_engine.chatgpt import ChatGPTAdapter
from app.services.ai_engine.perplexity import PerplexityAdapter
from app.services.ai_engine.kimi import KimiAdapter
from app.services.ai_engine.wenxin import WenxinAdapter
from app.services.ai_engine.doubao import DoubaoAdapter from app.services.ai_engine.doubao import DoubaoAdapter
from app.services.ai_engine.kimi import KimiAdapter
from app.services.ai_engine.perplexity import PerplexityAdapter
from app.services.ai_engine.wenxin import WenxinAdapter
from app.services.ai_engine.yuanbao import YuanbaoAdapter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -65,6 +66,7 @@ _ADAPTER_CLASSES: dict[EngineType, type[AIEngineAdapter]] = {
EngineType.KIMI: KimiAdapter, EngineType.KIMI: KimiAdapter,
EngineType.WENXIN: WenxinAdapter, EngineType.WENXIN: WenxinAdapter,
EngineType.DOUBAO: DoubaoAdapter, EngineType.DOUBAO: DoubaoAdapter,
EngineType.YUANBAO: YuanbaoAdapter,
} }

View File

@ -1,10 +1,14 @@
from .base import AIEngineAdapter, AIQueryResult, CitationInfo, EngineType from .base import AIEngineAdapter, AIQueryResult, CitationInfo, EngineType
from .chatgpt import ChatGPTAdapter
from .perplexity import PerplexityAdapter
from .kimi import KimiAdapter
from .wenxin import WenxinAdapter
from .doubao import DoubaoAdapter
from .batch_query import BatchQueryService from .batch_query import BatchQueryService
from .chatgpt import ChatGPTAdapter
from .deepseek import DeepSeekAdapter
from .doubao import DoubaoAdapter
from .gemini import GeminiAdapter
from .kimi import KimiAdapter
from .perplexity import PerplexityAdapter
from .qwen import QwenAdapter
from .wenxin import WenxinAdapter
from .yuanbao import YuanbaoAdapter
__all__ = [ __all__ = [
"AIEngineAdapter", "AIEngineAdapter",
@ -12,9 +16,13 @@ __all__ = [
"CitationInfo", "CitationInfo",
"EngineType", "EngineType",
"ChatGPTAdapter", "ChatGPTAdapter",
"DeepSeekAdapter",
"PerplexityAdapter", "PerplexityAdapter",
"KimiAdapter", "KimiAdapter",
"WenxinAdapter", "WenxinAdapter",
"DoubaoAdapter", "DoubaoAdapter",
"YuanbaoAdapter",
"QwenAdapter",
"GeminiAdapter",
"BatchQueryService", "BatchQueryService",
] ]

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
import logging import logging
import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import UTC, datetime from datetime import UTC, datetime
@ -22,6 +23,8 @@ class EngineType(str, Enum):
DOUBAO = "doubao" DOUBAO = "doubao"
DEEPSEEK = "deepseek" DEEPSEEK = "deepseek"
QWEN = "qwen" QWEN = "qwen"
GEMINI = "gemini"
YUANBAO = "yuanbao"
@dataclass @dataclass
@ -49,9 +52,10 @@ class AIQueryResult:
class AIEngineAdapter(ABC): class AIEngineAdapter(ABC):
def __init__(self, api_key: str, rate_limiter=None): def __init__(self, api_key: str, rate_limiter=None, proxy: str | None = None):
self.api_key = api_key self.api_key = api_key
self.rate_limiter = rate_limiter self.rate_limiter = rate_limiter
self.proxy = proxy or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")
self._client: httpx.AsyncClient | None = None self._client: httpx.AsyncClient | None = None
@abstractmethod @abstractmethod
@ -67,6 +71,19 @@ class AIEngineAdapter(ABC):
def get_engine_type(self) -> EngineType: def get_engine_type(self) -> EngineType:
pass pass
async def _get_client(self) -> httpx.AsyncClient:
if self._client is None or self._client.is_closed:
self._client = httpx.AsyncClient(**self._client_kwargs())
return self._client
def _client_kwargs(self) -> dict[str, Any]:
kwargs: dict[str, Any] = {
"timeout": httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0),
}
if self.proxy:
kwargs["proxy"] = self.proxy
return kwargs
def _detect_citations( def _detect_citations(
self, self,
response: str, response: str,
@ -98,12 +115,13 @@ class AIEngineAdapter(ABC):
if self.rate_limiter: if self.rate_limiter:
await self.rate_limiter.acquire() await self.rate_limiter.acquire()
client = await self._get_client()
engine_name = self.get_engine_type().value engine_name = self.get_engine_type().value
last_error: Exception | None = None last_error: Exception | None = None
for attempt in range(_MAX_RETRIES): for attempt in range(_MAX_RETRIES):
try: try:
response = await self._client.post(self._endpoint, json=payload) response = await client.post(self._endpoint, json=payload)
if response.status_code == 200: if response.status_code == 200:
return response.json() return response.json()

View File

@ -20,10 +20,12 @@ class ChatGPTAdapter(AIEngineAdapter):
model: str | None = None, model: str | None = None,
base_url: str | None = None, base_url: str | None = None,
rate_limiter=None, rate_limiter=None,
proxy: str | None = None,
): ):
super().__init__( super().__init__(
api_key=api_key or os.getenv("OPENAI_API_KEY", ""), api_key=api_key or os.getenv("OPENAI_API_KEY", ""),
rate_limiter=rate_limiter, rate_limiter=rate_limiter,
proxy=proxy or os.getenv("OPENAI_PROXY"),
) )
self._model = model or os.getenv("OPENAI_MODEL", _DEFAULT_MODEL) self._model = model or os.getenv("OPENAI_MODEL", _DEFAULT_MODEL)
self._base_url = ( self._base_url = (
@ -31,7 +33,7 @@ class ChatGPTAdapter(AIEngineAdapter):
).rstrip("/") ).rstrip("/")
self._endpoint = f"{self._base_url}/chat/completions" self._endpoint = f"{self._base_url}/chat/completions"
self._client = httpx.AsyncClient( self._client = httpx.AsyncClient(
timeout=httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0), **self._client_kwargs(),
headers={ headers={
"Authorization": f"Bearer {self.api_key}", "Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json", "Content-Type": "application/json",

View File

@ -0,0 +1,90 @@
import logging
import os
import time
from datetime import UTC, datetime
import httpx
from .base import AIEngineAdapter, AIQueryResult, EngineType
logger = logging.getLogger(__name__)
_DEFAULT_MODEL = "deepseek-chat"
_DEFAULT_BASE_URL = "https://api.deepseek.com/v1"
class DeepSeekAdapter(AIEngineAdapter):
def __init__(
self,
api_key: str | None = None,
model: str | None = None,
base_url: str | None = None,
rate_limiter=None,
proxy: str | None = None,
):
super().__init__(
api_key=api_key or os.getenv("DEEPSEEK_API_KEY", ""),
rate_limiter=rate_limiter,
proxy=proxy or os.getenv("DEEPSEEK_PROXY"),
)
self._model = model or os.getenv("DEEPSEEK_MODEL", _DEFAULT_MODEL)
self._base_url = (
base_url or os.getenv("DEEPSEEK_BASE_URL", _DEFAULT_BASE_URL)
).rstrip("/")
self._endpoint = f"{self._base_url}/chat/completions"
self._client = httpx.AsyncClient(
**self._client_kwargs(),
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
)
def get_engine_type(self) -> EngineType:
return EngineType.DEEPSEEK
async def query(
self,
query: str,
brand_name: str,
competitor_names: list[str] | None = None,
) -> AIQueryResult:
start_time = time.perf_counter()
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": query},
]
payload = {
"model": self._model,
"messages": messages,
"temperature": 0.7,
"max_tokens": 4096,
}
data = await self._request_with_retry(payload)
content = data["choices"][0]["message"]["content"]
elapsed_ms = int((time.perf_counter() - start_time) * 1000)
has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
content, brand_name, competitor_names
)
logger.info(
f"[deepseek] query='{query[:50]}...' brand={has_brand} "
f"competitor={has_comp} time={elapsed_ms}ms"
)
return AIQueryResult(
engine_type=self.get_engine_type(),
query=query,
raw_response=content,
citations=[],
has_brand_citation=has_brand,
has_competitor_citation=has_comp,
brand_context=brand_ctx,
competitor_contexts=comp_ctx,
response_time_ms=elapsed_ms,
timestamp=datetime.now(UTC),
metadata={"model": data.get("model", self._model)},
)

View File

@ -0,0 +1,137 @@
import asyncio
import logging
import os
import time
from datetime import UTC, datetime
import httpx
from .base import AIEngineAdapter, AIQueryResult, EngineType
logger = logging.getLogger(__name__)
_DEFAULT_MODEL = "gemini-pro"
_DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
_MAX_RETRIES = 3
_RETRYABLE_STATUS = {429, 500, 502, 503}
class GeminiAdapter(AIEngineAdapter):
def __init__(
self,
api_key: str | None = None,
model: str | None = None,
rate_limiter=None,
proxy: str | None = None,
):
super().__init__(
api_key=api_key or os.getenv("GOOGLE_API_KEY", ""),
rate_limiter=rate_limiter,
)
self._model = model or os.getenv("GEMINI_MODEL", _DEFAULT_MODEL)
self._base_url = os.getenv("GEMINI_BASE_URL", _DEFAULT_BASE_URL).rstrip("/")
self._proxy = proxy or os.getenv("GEMINI_PROXY")
self._endpoint = (
f"{self._base_url}/models/{self._model}:generateContent"
f"?key={self.api_key}"
)
client_kwargs = {
"timeout": httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0),
"headers": {"Content-Type": "application/json"},
}
if self._proxy:
client_kwargs["proxy"] = self._proxy
self._client = httpx.AsyncClient(**client_kwargs)
def get_engine_type(self) -> EngineType:
return EngineType.GEMINI
async def _request_with_retry(self, payload: dict) -> dict:
if self.rate_limiter:
await self.rate_limiter.acquire()
last_error: Exception | None = None
for attempt in range(_MAX_RETRIES):
try:
response = await self._client.post(self._endpoint, json=payload)
if response.status_code == 200:
return response.json()
if response.status_code in _RETRYABLE_STATUS:
retry_after = response.headers.get("retry-after")
wait = float(retry_after) if retry_after else 2**attempt
logger.warning(
f"[gemini] HTTP {response.status_code}, "
f"retry {attempt + 1}/{_MAX_RETRIES} in {wait:.1f}s"
)
last_error = Exception(
f"HTTP {response.status_code}: {response.text[:300]}"
)
await asyncio.sleep(wait)
continue
raise RuntimeError(
f"Gemini API 返回错误 {response.status_code}: {response.text[:300]}"
)
except httpx.TransportError as exc:
logger.warning(
f"[gemini] Transport error: {exc}, "
f"retry {attempt + 1}/{_MAX_RETRIES}"
)
last_error = RuntimeError(f"Gemini 网络错误: {exc}")
await asyncio.sleep(2**attempt)
continue
raise last_error or RuntimeError("Gemini API 超过最大重试次数")
async def query(
self,
query: str,
brand_name: str,
competitor_names: list[str] | None = None,
) -> AIQueryResult:
start_time = time.perf_counter()
payload = {
"contents": [{"parts": [{"text": query}]}],
"generationConfig": {"temperature": 0.7},
}
data = await self._request_with_retry(payload)
candidates = data.get("candidates", [])
if not candidates:
raise RuntimeError("Gemini API 返回空候选内容")
parts = candidates[0].get("content", {}).get("parts", [])
if not parts:
raise RuntimeError("Gemini API 返回空内容")
content = parts[0].get("text", "")
elapsed_ms = int((time.perf_counter() - start_time) * 1000)
has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
content, brand_name, competitor_names
)
logger.info(
f"[gemini] query='{query[:50]}...' brand={has_brand} "
f"competitor={has_comp} time={elapsed_ms}ms"
)
return AIQueryResult(
engine_type=self.get_engine_type(),
query=query,
raw_response=content,
citations=[],
has_brand_citation=has_brand,
has_competitor_citation=has_comp,
brand_context=brand_ctx,
competitor_contexts=comp_ctx,
response_time_ms=elapsed_ms,
timestamp=datetime.now(UTC),
metadata={"model": self._model},
)

View File

@ -20,10 +20,12 @@ class PerplexityAdapter(AIEngineAdapter):
model: str | None = None, model: str | None = None,
base_url: str | None = None, base_url: str | None = None,
rate_limiter=None, rate_limiter=None,
proxy: str | None = None,
): ):
super().__init__( super().__init__(
api_key=api_key or os.getenv("PERPLEXITY_API_KEY", ""), api_key=api_key or os.getenv("PERPLEXITY_API_KEY", ""),
rate_limiter=rate_limiter, rate_limiter=rate_limiter,
proxy=proxy or os.getenv("PERPLEXITY_PROXY"),
) )
self._model = model or os.getenv("PERPLEXITY_MODEL", _DEFAULT_MODEL) self._model = model or os.getenv("PERPLEXITY_MODEL", _DEFAULT_MODEL)
self._base_url = ( self._base_url = (
@ -31,7 +33,7 @@ class PerplexityAdapter(AIEngineAdapter):
).rstrip("/") ).rstrip("/")
self._endpoint = f"{self._base_url}/chat/completions" self._endpoint = f"{self._base_url}/chat/completions"
self._client = httpx.AsyncClient( self._client = httpx.AsyncClient(
timeout=httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0), **self._client_kwargs(),
headers={ headers={
"Authorization": f"Bearer {self.api_key}", "Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json", "Content-Type": "application/json",

View File

@ -0,0 +1,91 @@
import logging
import os
import time
from datetime import UTC, datetime
import httpx
from .base import AIEngineAdapter, AIQueryResult, EngineType
logger = logging.getLogger(__name__)
_DEFAULT_MODEL = "qwen-plus"
_DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
class QwenAdapter(AIEngineAdapter):
def __init__(
self,
api_key: str | None = None,
model: str | None = None,
base_url: str | None = None,
rate_limiter=None,
):
super().__init__(
api_key=api_key or os.getenv("DASHSCOPE_API_KEY", ""),
rate_limiter=rate_limiter,
)
self._model = model or os.getenv("QWEN_MODEL", _DEFAULT_MODEL)
self._base_url = (
base_url or os.getenv("QWEN_BASE_URL", _DEFAULT_BASE_URL)
).rstrip("/")
self._endpoint = f"{self._base_url}/chat/completions"
self._client = httpx.AsyncClient(
timeout=httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0),
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
)
def get_engine_type(self) -> EngineType:
return EngineType.QWEN
async def query(
self,
query: str,
brand_name: str,
competitor_names: list[str] | None = None,
) -> AIQueryResult:
start_time = time.perf_counter()
messages = [
{
"role": "system",
"content": "你是一个专业的AI搜索助手。请基于你的知识详细回答用户的问题。如果引用了外部来源请在回答中标注来源URL或出处名称。",
},
{"role": "user", "content": query},
]
payload = {
"model": self._model,
"messages": messages,
"temperature": 0.7,
"max_tokens": 2000,
}
data = await self._request_with_retry(payload)
content = data["choices"][0]["message"]["content"]
elapsed_ms = int((time.perf_counter() - start_time) * 1000)
has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
content, brand_name, competitor_names
)
logger.info(
f"[qwen] query='{query[:50]}...' brand={has_brand} "
f"competitor={has_comp} time={elapsed_ms}ms"
)
return AIQueryResult(
engine_type=self.get_engine_type(),
query=query,
raw_response=content,
citations=[],
has_brand_citation=has_brand,
has_competitor_citation=has_comp,
brand_context=brand_ctx,
competitor_contexts=comp_ctx,
response_time_ms=elapsed_ms,
timestamp=datetime.now(UTC),
metadata={"model": data.get("model", self._model), "usage": data.get("usage")},
)

View File

@ -0,0 +1,91 @@
import logging
import os
import time
from datetime import UTC, datetime
import httpx
from .base import AIEngineAdapter, AIQueryResult, EngineType
logger = logging.getLogger(__name__)
_DEFAULT_MODEL = "hunyuan-lite"
_DEFAULT_BASE_URL = "https://api.hunyuan.cloud.tencent.com/v1"
class YuanbaoAdapter(AIEngineAdapter):
def __init__(
self,
api_key: str | None = None,
model: str | None = None,
base_url: str | None = None,
rate_limiter=None,
):
super().__init__(
api_key=api_key or os.getenv("HUNYUAN_API_KEY", ""),
rate_limiter=rate_limiter,
)
self._model = model or os.getenv("HUNYUAN_MODEL", _DEFAULT_MODEL)
self._base_url = (
base_url or os.getenv("HUNYUAN_BASE_URL", _DEFAULT_BASE_URL)
).rstrip("/")
self._endpoint = f"{self._base_url}/chat/completions"
self._client = httpx.AsyncClient(
timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0),
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
)
def get_engine_type(self) -> EngineType:
return EngineType.YUANBAO
async def query(
self,
query: str,
brand_name: str,
competitor_names: list[str] | None = None,
) -> AIQueryResult:
start_time = time.perf_counter()
messages = [
{
"role": "system",
"content": "你是一个专业的AI搜索助手。请基于你的知识详细回答用户的问题。如果引用了外部来源请在回答中标注来源URL或出处名称。",
},
{"role": "user", "content": query},
]
payload = {
"model": self._model,
"messages": messages,
"temperature": 0.7,
"max_tokens": 2000,
}
data = await self._request_with_retry(payload)
content = data["choices"][0]["message"]["content"]
elapsed_ms = int((time.perf_counter() - start_time) * 1000)
has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
content, brand_name, competitor_names
)
logger.info(
f"[yuanbao] query='{query[:50]}...' brand={has_brand} "
f"competitor={has_comp} time={elapsed_ms}ms"
)
return AIQueryResult(
engine_type=self.get_engine_type(),
query=query,
raw_response=content,
citations=[],
has_brand_citation=has_brand,
has_competitor_citation=has_comp,
brand_context=brand_ctx,
competitor_contexts=comp_ctx,
response_time_ms=elapsed_ms,
timestamp=datetime.now(UTC),
metadata={"model": data.get("model", self._model), "usage": data.get("usage")},
)

View File

@ -10,6 +10,7 @@ from app.services.ai_engine.base import (
from app.services.ai_engine.kimi import KimiAdapter from app.services.ai_engine.kimi import KimiAdapter
from app.services.ai_engine.wenxin import WenxinAdapter from app.services.ai_engine.wenxin import WenxinAdapter
from app.services.ai_engine.doubao import DoubaoAdapter from app.services.ai_engine.doubao import DoubaoAdapter
from app.services.ai_engine.yuanbao import YuanbaoAdapter
def _make_mock_response(status_code=200, json_data=None, text="", headers=None): def _make_mock_response(status_code=200, json_data=None, text="", headers=None):
@ -202,6 +203,11 @@ class TestEngineType:
assert adapter.get_engine_type() == EngineType.DOUBAO assert adapter.get_engine_type() == EngineType.DOUBAO
assert adapter.get_engine_type().value == "doubao" assert adapter.get_engine_type().value == "doubao"
def test_yuanbao_engine_type(self):
adapter = YuanbaoAdapter(api_key="test-key")
assert adapter.get_engine_type() == EngineType.YUANBAO
assert adapter.get_engine_type().value == "yuanbao"
class TestChineseCitationDetection: class TestChineseCitationDetection:
def test_brand_name_detection_chinese(self): def test_brand_name_detection_chinese(self):
@ -254,17 +260,18 @@ class TestAdapterInheritance:
assert issubclass(KimiAdapter, AIEngineAdapter) assert issubclass(KimiAdapter, AIEngineAdapter)
assert issubclass(WenxinAdapter, AIEngineAdapter) assert issubclass(WenxinAdapter, AIEngineAdapter)
assert issubclass(DoubaoAdapter, AIEngineAdapter) assert issubclass(DoubaoAdapter, AIEngineAdapter)
assert issubclass(YuanbaoAdapter, AIEngineAdapter)
def test_all_adapters_have_query_method(self): def test_all_adapters_have_query_method(self):
for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter]: for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter, YuanbaoAdapter]:
assert hasattr(cls, "query") assert hasattr(cls, "query")
assert callable(getattr(cls, "query")) assert callable(getattr(cls, "query"))
def test_all_adapters_have_detect_citations(self): def test_all_adapters_have_detect_citations(self):
for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter]: for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter, YuanbaoAdapter]:
assert hasattr(cls, "_detect_citations") assert hasattr(cls, "_detect_citations")
def test_all_adapters_have_get_engine_type(self): def test_all_adapters_have_get_engine_type(self):
for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter]: for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter, YuanbaoAdapter]:
instance = cls(api_key="test-key") if cls != WenxinAdapter else cls(api_key="test-key", secret_key="s") instance = cls(api_key="test-key") if cls != WenxinAdapter else cls(api_key="test-key", secret_key="s")
assert instance.get_engine_type() in EngineType assert instance.get_engine_type() in EngineType

View File

@ -21,6 +21,7 @@ class TestEngineType:
assert EngineType.DOUBAO == "doubao" assert EngineType.DOUBAO == "doubao"
assert EngineType.DEEPSEEK == "deepseek" assert EngineType.DEEPSEEK == "deepseek"
assert EngineType.QWEN == "qwen" assert EngineType.QWEN == "qwen"
assert EngineType.YUANBAO == "yuanbao"
class TestCitationInfo: class TestCitationInfo:

View File

@ -0,0 +1,285 @@
import os
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType
from app.services.ai_engine.chatgpt import ChatGPTAdapter
from app.services.ai_engine.deepseek import DeepSeekAdapter
from app.services.ai_engine.perplexity import PerplexityAdapter
class _ConcreteAdapter(AIEngineAdapter):
async def query(self, query, brand_name, competitor_names=None):
pass
def get_engine_type(self):
return EngineType.CHATGPT
class TestProxySupportBase:
def test_base_init_with_proxy(self):
adapter = _ConcreteAdapter(api_key="test-key", proxy="http://proxy:8080")
assert adapter.proxy == "http://proxy:8080"
def test_base_init_proxy_default_none(self):
adapter = _ConcreteAdapter(api_key="test-key")
assert adapter.proxy is None
def test_base_init_proxy_from_https_proxy_env(self):
with patch.dict(os.environ, {"HTTPS_PROXY": "http://env-proxy:3128"}):
adapter = _ConcreteAdapter(api_key="test-key")
assert adapter.proxy == "http://env-proxy:3128"
def test_base_init_proxy_from_https_proxy_lowercase_env(self):
with patch.dict(os.environ, {"https_proxy": "http://lower-proxy:3128"}, clear=False):
if "HTTPS_PROXY" in os.environ:
del os.environ["HTTPS_PROXY"]
adapter = _ConcreteAdapter(api_key="test-key")
assert adapter.proxy == "http://lower-proxy:3128"
def test_base_init_explicit_proxy_overrides_env(self):
with patch.dict(os.environ, {"HTTPS_PROXY": "http://env-proxy:3128"}):
adapter = _ConcreteAdapter(
api_key="test-key", proxy="http://explicit-proxy:8080"
)
assert adapter.proxy == "http://explicit-proxy:8080"
@pytest.mark.asyncio
async def test_get_client_with_proxy(self):
adapter = _ConcreteAdapter(api_key="test-key", proxy="http://proxy:8080")
client = await adapter._get_client()
assert isinstance(client, httpx.AsyncClient)
await adapter.close()
@pytest.mark.asyncio
async def test_get_client_without_proxy(self):
adapter = _ConcreteAdapter(api_key="test-key")
client = await adapter._get_client()
assert isinstance(client, httpx.AsyncClient)
await adapter.close()
class TestChatGPTProxy:
def test_chatgpt_proxy_parameter(self):
adapter = ChatGPTAdapter(api_key="test-key", proxy="http://proxy:8080")
assert adapter.proxy == "http://proxy:8080"
def test_chatgpt_proxy_from_openai_proxy_env(self):
with patch.dict(os.environ, {"OPENAI_PROXY": "http://openai-proxy:8080"}):
adapter = ChatGPTAdapter(api_key="test-key")
assert adapter.proxy == "http://openai-proxy:8080"
def test_chatgpt_explicit_proxy_overrides_env(self):
with patch.dict(os.environ, {"OPENAI_PROXY": "http://env-proxy:8080"}):
adapter = ChatGPTAdapter(
api_key="test-key", proxy="http://explicit:9090"
)
assert adapter.proxy == "http://explicit:9090"
def test_chatgpt_fallback_to_https_proxy_env(self):
with patch.dict(os.environ, {"HTTPS_PROXY": "http://fallback:3128"}, clear=False):
if "OPENAI_PROXY" in os.environ:
del os.environ["OPENAI_PROXY"]
adapter = ChatGPTAdapter(api_key="test-key")
assert adapter.proxy == "http://fallback:3128"
def test_chatgpt_no_proxy(self):
with patch.dict(os.environ, {}, clear=False):
for key in ("OPENAI_PROXY", "HTTPS_PROXY", "https_proxy"):
os.environ.pop(key, None)
adapter = ChatGPTAdapter(api_key="test-key")
assert adapter.proxy is None
class TestPerplexityProxy:
def test_perplexity_proxy_parameter(self):
adapter = PerplexityAdapter(api_key="test-key", proxy="http://proxy:8080")
assert adapter.proxy == "http://proxy:8080"
def test_perplexity_proxy_from_perplexity_proxy_env(self):
with patch.dict(os.environ, {"PERPLEXITY_PROXY": "http://pplx-proxy:8080"}):
adapter = PerplexityAdapter(api_key="test-key")
assert adapter.proxy == "http://pplx-proxy:8080"
def test_perplexity_fallback_to_https_proxy_env(self):
with patch.dict(os.environ, {"HTTPS_PROXY": "http://fallback:3128"}, clear=False):
if "PERPLEXITY_PROXY" in os.environ:
del os.environ["PERPLEXITY_PROXY"]
adapter = PerplexityAdapter(api_key="test-key")
assert adapter.proxy == "http://fallback:3128"
class TestDeepSeekAdapter:
def test_deepseek_init_default(self):
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "ds-test-key"}):
adapter = DeepSeekAdapter()
assert adapter.api_key == "ds-test-key"
assert adapter.get_engine_type() == EngineType.DEEPSEEK
def test_deepseek_init_with_params(self):
adapter = DeepSeekAdapter(
api_key="custom-key",
model="deepseek-reasoner",
base_url="https://custom.api.com/v1",
)
assert adapter.api_key == "custom-key"
assert adapter._model == "deepseek-reasoner"
assert adapter._base_url == "https://custom.api.com/v1"
def test_deepseek_init_from_env(self):
with patch.dict(
os.environ,
{
"DEEPSEEK_API_KEY": "env-key",
"DEEPSEEK_MODEL": "deepseek-coder",
"DEEPSEEK_BASE_URL": "https://env.api.com/v1",
},
):
adapter = DeepSeekAdapter()
assert adapter.api_key == "env-key"
assert adapter._model == "deepseek-coder"
assert adapter._base_url == "https://env.api.com/v1"
def test_deepseek_default_model(self):
adapter = DeepSeekAdapter(api_key="test-key")
assert adapter._model == "deepseek-chat"
def test_deepseek_default_base_url(self):
adapter = DeepSeekAdapter(api_key="test-key")
assert adapter._base_url == "https://api.deepseek.com/v1"
def test_deepseek_endpoint(self):
adapter = DeepSeekAdapter(api_key="test-key")
assert adapter._endpoint == "https://api.deepseek.com/v1/chat/completions"
def test_deepseek_engine_type(self):
adapter = DeepSeekAdapter(api_key="test-key")
assert adapter.get_engine_type() == EngineType.DEEPSEEK
def test_deepseek_proxy_parameter(self):
adapter = DeepSeekAdapter(api_key="test-key", proxy="http://proxy:8080")
assert adapter.proxy == "http://proxy:8080"
def test_deepseek_proxy_from_deepseek_proxy_env(self):
with patch.dict(os.environ, {"DEEPSEEK_PROXY": "http://ds-proxy:8080"}):
adapter = DeepSeekAdapter(api_key="test-key")
assert adapter.proxy == "http://ds-proxy:8080"
@pytest.mark.asyncio
async def test_deepseek_query_success(self):
adapter = DeepSeekAdapter(api_key="test-key")
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [
{"message": {"content": "BrandX is a leading AI company with great models."}}
],
"model": "deepseek-chat",
"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
}
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
result = await adapter.query(
query="best AI companies",
brand_name="BrandX",
competitor_names=["CompetitorY"],
)
assert isinstance(result, AIQueryResult)
assert result.engine_type == EngineType.DEEPSEEK
assert result.query == "best AI companies"
assert "BrandX" in result.raw_response
assert result.has_brand_citation is True
assert result.has_competitor_citation is False
assert result.response_time_ms >= 0
@pytest.mark.asyncio
async def test_deepseek_query_with_competitor(self):
adapter = DeepSeekAdapter(api_key="test-key")
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [
{"message": {"content": "BrandX and CompetitorY are both strong AI companies."}}
],
"model": "deepseek-chat",
}
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
result = await adapter.query(
query="AI comparison",
brand_name="BrandX",
competitor_names=["CompetitorY"],
)
assert result.has_brand_citation is True
assert result.has_competitor_citation is True
assert len(result.competitor_contexts) == 1
@pytest.mark.asyncio
async def test_deepseek_query_with_rate_limiter(self):
mock_limiter = AsyncMock()
adapter = DeepSeekAdapter(api_key="test-key", rate_limiter=mock_limiter)
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [{"message": {"content": "Some response"}}],
"model": "deepseek-chat",
}
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
await adapter.query(query="test", brand_name="BrandX")
mock_limiter.acquire.assert_awaited()
@pytest.mark.asyncio
async def test_deepseek_query_api_error(self):
adapter = DeepSeekAdapter(api_key="test-key")
mock_response = MagicMock()
mock_response.status_code = 401
mock_response.text = "Unauthorized"
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
with pytest.raises(Exception):
await adapter.query(query="test", brand_name="BrandX")
@pytest.mark.asyncio
async def test_deepseek_query_transport_error(self):
adapter = DeepSeekAdapter(api_key="test-key")
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.side_effect = httpx.TransportError("Connection failed")
with pytest.raises(Exception):
await adapter.query(query="test", brand_name="BrandX")
@pytest.mark.asyncio
async def test_deepseek_query_timeout(self):
adapter = DeepSeekAdapter(api_key="test-key")
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.side_effect = httpx.TimeoutException("Request timed out")
with pytest.raises(Exception):
await adapter.query(query="test", brand_name="BrandX")
@pytest.mark.asyncio
async def test_deepseek_close(self):
adapter = DeepSeekAdapter(api_key="test-key")
await adapter.close()
@pytest.mark.asyncio
async def test_deepseek_context_manager(self):
async with DeepSeekAdapter(api_key="test-key") as adapter:
assert adapter.api_key == "test-key"

View File

@ -0,0 +1,232 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType
from app.services.ai_engine.gemini import GeminiAdapter
from app.services.ai_engine.qwen import QwenAdapter
class TestQwenAdapter:
@pytest.mark.asyncio
async def test_initialization(self):
adapter = QwenAdapter(api_key="test-dashscope-key")
assert adapter.api_key == "test-dashscope-key"
assert adapter._model == "qwen-plus"
assert adapter._base_url == "https://dashscope.aliyuncs.com/compatible-mode/v1"
@pytest.mark.asyncio
async def test_initialization_with_custom_params(self):
adapter = QwenAdapter(
api_key="custom-key",
model="qwen-max",
base_url="https://custom-url.com/v1",
)
assert adapter.api_key == "custom-key"
assert adapter._model == "qwen-max"
assert adapter._base_url == "https://custom-url.com/v1"
@pytest.mark.asyncio
async def test_query_returns_ai_query_result(self):
adapter = QwenAdapter(api_key="test-key")
mock_response_data = {
"choices": [{"message": {"content": "华为是全球领先的ICT基础设施和智能终端提供商"}}],
"model": "qwen-plus",
"usage": {"total_tokens": 100},
}
with patch.object(adapter, "_request_with_retry", return_value=mock_response_data):
result = await adapter.query("华为公司", brand_name="华为")
assert isinstance(result, AIQueryResult)
assert result.engine_type == EngineType.QWEN
assert "华为" in result.raw_response
assert result.has_brand_citation is True
assert result.metadata.get("model") == "qwen-plus"
@pytest.mark.asyncio
async def test_get_engine_type(self):
adapter = QwenAdapter(api_key="test-key")
assert adapter.get_engine_type() == EngineType.QWEN
assert adapter.get_engine_type().value == "qwen"
@pytest.mark.asyncio
async def test_error_handling(self):
adapter = QwenAdapter(api_key="test-key")
with patch.object(
adapter,
"_request_with_retry",
side_effect=Exception("HTTP 500: Internal Server Error"),
):
with pytest.raises(Exception, match="HTTP 500"):
await adapter.query("测试问题", brand_name="华为")
@pytest.mark.asyncio
async def test_chinese_brand_citation_detection(self):
adapter = QwenAdapter(api_key="test-key")
mock_response_data = {
"choices": [{"message": {"content": "华为和小米都是中国知名的科技企业"}}],
"model": "qwen-plus",
}
with patch.object(adapter, "_request_with_retry", return_value=mock_response_data):
result = await adapter.query(
"科技公司", brand_name="华为", competitor_names=["小米"]
)
assert result.has_brand_citation is True
assert result.has_competitor_citation is True
assert result.brand_context is not None
assert "华为" in result.brand_context
assert len(result.competitor_contexts) == 1
@pytest.mark.asyncio
async def test_rate_limiter_called(self):
mock_limiter = AsyncMock()
adapter = QwenAdapter(api_key="test-key", rate_limiter=mock_limiter)
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [{"message": {"content": "测试回复"}}],
"model": "qwen-plus",
}
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
await adapter.query("测试", brand_name="华为")
mock_limiter.acquire.assert_awaited()
class TestGeminiAdapter:
@pytest.mark.asyncio
async def test_initialization(self):
adapter = GeminiAdapter(api_key="test-google-key")
assert adapter.api_key == "test-google-key"
assert adapter._model == "gemini-pro"
@pytest.mark.asyncio
async def test_initialization_with_custom_params(self):
adapter = GeminiAdapter(
api_key="custom-key",
model="gemini-1.5-pro",
)
assert adapter.api_key == "custom-key"
assert adapter._model == "gemini-1.5-pro"
@pytest.mark.asyncio
async def test_query_returns_ai_query_result(self):
adapter = GeminiAdapter(api_key="test-key")
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"candidates": [{"content": {"parts": [{"text": "Google is a leading technology company."}]}}],
}
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
result = await adapter.query("best tech companies", brand_name="Google")
assert isinstance(result, AIQueryResult)
assert result.engine_type == EngineType.GEMINI
assert "Google" in result.raw_response
assert result.has_brand_citation is True
@pytest.mark.asyncio
async def test_get_engine_type(self):
adapter = GeminiAdapter(api_key="test-key")
assert adapter.get_engine_type() == EngineType.GEMINI
assert adapter.get_engine_type().value == "gemini"
@pytest.mark.asyncio
async def test_error_handling(self):
adapter = GeminiAdapter(api_key="test-key")
mock_response = MagicMock()
mock_response.status_code = 400
mock_response.text = "Bad Request"
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
with pytest.raises(RuntimeError, match="Gemini"):
await adapter.query("test query", brand_name="Google")
@pytest.mark.asyncio
async def test_english_brand_citation_detection(self):
adapter = GeminiAdapter(api_key="test-key")
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"candidates": [{"content": {"parts": [{"text": "Google and Microsoft are major tech companies."}]}}],
}
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
result = await adapter.query(
"tech companies", brand_name="Google", competitor_names=["Microsoft"]
)
assert result.has_brand_citation is True
assert result.has_competitor_citation is True
assert result.brand_context is not None
assert "Google" in result.brand_context
assert len(result.competitor_contexts) == 1
@pytest.mark.asyncio
async def test_rate_limiter_called(self):
mock_limiter = AsyncMock()
adapter = GeminiAdapter(api_key="test-key", rate_limiter=mock_limiter)
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"candidates": [{"content": {"parts": [{"text": "Test response"}]}}],
}
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
await adapter.query("test", brand_name="Google")
mock_limiter.acquire.assert_awaited()
@pytest.mark.asyncio
async def test_api_key_in_url(self):
adapter = GeminiAdapter(api_key="my-secret-key")
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"candidates": [{"content": {"parts": [{"text": "response"}]}}],
}
with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post:
mock_post.return_value = mock_response
await adapter.query("test", brand_name="BrandX")
call_args = mock_post.call_args
assert "key=my-secret-key" in str(call_args)
@pytest.mark.asyncio
async def test_proxy_support(self):
adapter = GeminiAdapter(api_key="test-key", proxy="http://proxy:8080")
assert adapter._proxy == "http://proxy:8080"
class TestAdapterInheritance:
def test_qwen_inherits_base(self):
assert issubclass(QwenAdapter, AIEngineAdapter)
def test_gemini_inherits_base(self):
assert issubclass(GeminiAdapter, AIEngineAdapter)
def test_qwen_has_query_method(self):
assert hasattr(QwenAdapter, "query")
assert callable(getattr(QwenAdapter, "query"))
def test_gemini_has_query_method(self):
assert hasattr(GeminiAdapter, "query")
assert callable(getattr(GeminiAdapter, "query"))
def test_qwen_has_detect_citations(self):
assert hasattr(QwenAdapter, "_detect_citations")
def test_gemini_has_detect_citations(self):
assert hasattr(GeminiAdapter, "_detect_citations")

View File

@ -0,0 +1,227 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.services.ai_engine.base import (
AIEngineAdapter,
AIQueryResult,
EngineType,
)
from app.services.ai_engine.yuanbao import YuanbaoAdapter
class TestYuanbaoAdapterInitialization:
@pytest.mark.asyncio
async def test_initialization_with_api_key(self):
adapter = YuanbaoAdapter(api_key="test-hunyuan-key")
assert adapter.api_key == "test-hunyuan-key"
assert adapter._model == "hunyuan-lite"
assert adapter._base_url == "https://api.hunyuan.cloud.tencent.com/v1"
@pytest.mark.asyncio
async def test_initialization_with_custom_model(self):
adapter = YuanbaoAdapter(api_key="test-key", model="hunyuan-pro")
assert adapter._model == "hunyuan-pro"
@pytest.mark.asyncio
async def test_initialization_with_custom_base_url(self):
adapter = YuanbaoAdapter(
api_key="test-key",
base_url="https://custom-api.example.com/v1",
)
assert adapter._base_url == "https://custom-api.example.com/v1"
@pytest.mark.asyncio
async def test_initialization_with_env_vars(self):
with patch.dict("os.environ", {
"HUNYUAN_API_KEY": "env-key",
"HUNYUAN_MODEL": "hunyuan-turbo",
"HUNYUAN_BASE_URL": "https://env-url.example.com/v1",
}):
adapter = YuanbaoAdapter()
assert adapter.api_key == "env-key"
assert adapter._model == "hunyuan-turbo"
assert adapter._base_url == "https://env-url.example.com/v1"
class TestYuanbaoAdapterEngineType:
def test_get_engine_type_returns_yuanbao(self):
adapter = YuanbaoAdapter(api_key="test-key")
assert adapter.get_engine_type() == EngineType.YUANBAO
def test_engine_type_value(self):
assert EngineType.YUANBAO == "yuanbao"
def test_engine_type_in_enum(self):
assert EngineType.YUANBAO in EngineType
class TestYuanbaoAdapterQuery:
@pytest.mark.asyncio
async def test_query_returns_ai_query_result(self):
adapter = YuanbaoAdapter(api_key="test-key")
mock_response_data = {
"choices": [{"message": {"content": "华为是全球领先的ICT基础设施和智能终端提供商"}}],
"usage": {"total_tokens": 100},
"model": "hunyuan-lite",
}
with patch.object(adapter, "_request_with_retry", return_value=mock_response_data):
result = await adapter.query("华为公司", brand_name="华为")
assert isinstance(result, AIQueryResult)
assert result.engine_type == EngineType.YUANBAO
assert "华为" in result.raw_response
assert result.has_brand_citation is True
assert result.metadata.get("model") == "hunyuan-lite"
@pytest.mark.asyncio
async def test_query_with_competitor_detection(self):
adapter = YuanbaoAdapter(api_key="test-key")
mock_response_data = {
"choices": [{"message": {"content": "华为和小米都是中国知名手机品牌"}}],
"model": "hunyuan-lite",
}
with patch.object(adapter, "_request_with_retry", return_value=mock_response_data):
result = await adapter.query(
"手机品牌",
brand_name="华为",
competitor_names=["小米"],
)
assert result.has_brand_citation is True
assert result.has_competitor_citation is True
assert len(result.competitor_contexts) > 0
@pytest.mark.asyncio
async def test_query_no_brand_found(self):
adapter = YuanbaoAdapter(api_key="test-key")
mock_response_data = {
"choices": [{"message": {"content": "今天天气很好"}}],
"model": "hunyuan-lite",
}
with patch.object(adapter, "_request_with_retry", return_value=mock_response_data):
result = await adapter.query("天气", brand_name="华为")
assert result.has_brand_citation is False
assert result.has_competitor_citation is False
assert result.brand_context is None
assert result.competitor_contexts == []
@pytest.mark.asyncio
async def test_query_records_response_time(self):
adapter = YuanbaoAdapter(api_key="test-key")
mock_response_data = {
"choices": [{"message": {"content": "测试回复"}}],
"model": "hunyuan-lite",
}
with patch.object(adapter, "_request_with_retry", return_value=mock_response_data):
result = await adapter.query("测试", brand_name="华为")
assert result.response_time_ms >= 0
assert result.timestamp is not None
@pytest.mark.asyncio
async def test_query_with_rate_limiter(self):
mock_limiter = AsyncMock()
adapter = YuanbaoAdapter(api_key="test-key", rate_limiter=mock_limiter)
mock_response_data = {
"choices": [{"message": {"content": "测试回复"}}],
"model": "hunyuan-lite",
}
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = mock_response_data
with patch.object(adapter._client, "post", new_callable=AsyncMock, return_value=mock_response):
await adapter.query("测试", brand_name="华为")
mock_limiter.acquire.assert_awaited()
class TestYuanbaoAdapterErrorHandling:
@pytest.mark.asyncio
async def test_api_error_handling(self):
adapter = YuanbaoAdapter(api_key="test-key")
with patch.object(
adapter,
"_request_with_retry",
side_effect=Exception("HTTP 500: Internal Server Error"),
):
with pytest.raises(Exception, match="HTTP 500"):
await adapter.query("测试问题", brand_name="华为")
@pytest.mark.asyncio
async def test_rate_limit_handling(self):
adapter = YuanbaoAdapter(api_key="test-key")
with patch.object(
adapter,
"_request_with_retry",
side_effect=Exception("HTTP 429: rate limited"),
):
with pytest.raises(Exception, match="429"):
await adapter.query("测试问题", brand_name="华为")
class TestYuanbaoAdapterChineseBrandDetection:
def test_chinese_brand_name_detection(self):
adapter = YuanbaoAdapter(api_key="test-key")
has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
"华为是全球领先的ICT基础设施和智能终端提供商",
brand_name="华为",
competitor_names=None,
)
assert has_brand is True
assert brand_ctx is not None
assert "华为" in brand_ctx
def test_chinese_competitor_detection(self):
adapter = YuanbaoAdapter(api_key="test-key")
has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
"华为和小米都是中国知名手机品牌",
brand_name="华为",
competitor_names=["小米"],
)
assert has_brand is True
assert has_comp is True
assert brand_ctx is not None
assert len(comp_ctx) > 0
def test_no_match_chinese(self):
adapter = YuanbaoAdapter(api_key="test-key")
has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
"今天天气很好",
brand_name="华为",
competitor_names=["小米"],
)
assert has_brand is False
assert has_comp is False
assert brand_ctx is None
assert len(comp_ctx) == 0
class TestYuanbaoAdapterInheritance:
def test_inherits_from_base(self):
assert issubclass(YuanbaoAdapter, AIEngineAdapter)
def test_has_query_method(self):
assert hasattr(YuanbaoAdapter, "query")
assert callable(getattr(YuanbaoAdapter, "query"))
def test_has_detect_citations(self):
assert hasattr(YuanbaoAdapter, "_detect_citations")
def test_has_get_engine_type(self):
adapter = YuanbaoAdapter(api_key="test-key")
assert adapter.get_engine_type() in EngineType
@pytest.mark.asyncio
async def test_context_manager(self):
adapter = YuanbaoAdapter(api_key="test-key")
async with adapter as a:
assert a is adapter

View File

@ -179,25 +179,53 @@ function EngineCheckboxGroup({
onToggle: (engine: AIEngineType) => void; onToggle: (engine: AIEngineType) => void;
}) { }) {
return ( return (
<div className="flex flex-wrap gap-2"> <div className="space-y-3">
{AI_ENGINE_OPTIONS.map((opt) => { <div>
const isSelected = selected.includes(opt.value); <p className="mb-2 text-xs font-medium text-muted-foreground"></p>
return ( <div className="flex flex-wrap gap-2">
<button {AI_ENGINE_OPTIONS.filter((o) => o.group === "international").map((opt) => {
key={opt.value} const isSelected = selected.includes(opt.value);
type="button" return (
onClick={() => onToggle(opt.value)} <button
className={`inline-flex items-center gap-1.5 rounded-md border px-3 py-1.5 text-sm font-medium transition-colors ${ key={opt.value}
isSelected type="button"
? "border-primary bg-primary/10 text-primary" onClick={() => onToggle(opt.value)}
: "border-input bg-background text-muted-foreground hover:bg-muted" className={`inline-flex items-center gap-1.5 rounded-md border px-3 py-1.5 text-sm font-medium transition-colors ${
}`} isSelected
> ? "border-primary bg-primary/10 text-primary"
{isSelected && <CheckCircle className="h-3.5 w-3.5" />} : "border-input bg-background text-muted-foreground hover:bg-muted"
{opt.label} }`}
</button> >
); {isSelected && <CheckCircle className="h-3.5 w-3.5" />}
})} {opt.label}
</button>
);
})}
</div>
</div>
<div>
<p className="mb-2 text-xs font-medium text-muted-foreground"></p>
<div className="flex flex-wrap gap-2">
{AI_ENGINE_OPTIONS.filter((o) => o.group === "domestic").map((opt) => {
const isSelected = selected.includes(opt.value);
return (
<button
key={opt.value}
type="button"
onClick={() => onToggle(opt.value)}
className={`inline-flex items-center gap-1.5 rounded-md border px-3 py-1.5 text-sm font-medium transition-colors ${
isSelected
? "border-primary bg-primary/10 text-primary"
: "border-input bg-background text-muted-foreground hover:bg-muted"
}`}
>
{isSelected && <CheckCircle className="h-3.5 w-3.5" />}
{opt.label}
</button>
);
})}
</div>
</div>
</div> </div>
); );
} }

View File

@ -1,16 +1,30 @@
export type AIEngineType = "chatgpt" | "perplexity" | "kimi" | "wenxin" | "doubao"; export type AIEngineType =
| "chatgpt"
| "perplexity"
| "kimi"
| "wenxin"
| "doubao"
| "deepseek"
| "qwen"
| "gemini"
| "yuanbao";
export interface AIEngineOption { export interface AIEngineOption {
value: AIEngineType; value: AIEngineType;
label: string; label: string;
group: "domestic" | "international";
} }
export const AI_ENGINE_OPTIONS: AIEngineOption[] = [ export const AI_ENGINE_OPTIONS: AIEngineOption[] = [
{ value: "chatgpt", label: "ChatGPT" }, { value: "chatgpt", label: "ChatGPT", group: "international" },
{ value: "perplexity", label: "Perplexity" }, { value: "perplexity", label: "Perplexity", group: "international" },
{ value: "kimi", label: "Kimi" }, { value: "gemini", label: "Google Gemini", group: "international" },
{ value: "wenxin", label: "文心一言" }, { value: "kimi", label: "Kimi", group: "domestic" },
{ value: "doubao", label: "豆包" }, { value: "wenxin", label: "文心一言", group: "domestic" },
{ value: "doubao", label: "豆包", group: "domestic" },
{ value: "deepseek", label: "DeepSeek", group: "domestic" },
{ value: "qwen", label: "通义千问", group: "domestic" },
{ value: "yuanbao", label: "腾讯元宝", group: "domestic" },
]; ];
export interface AIQueryResult { export interface AIQueryResult {