From af3a184c0b22b244bd01195030f9658106f82878 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Mon, 25 May 2026 12:16:16 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E8=A1=A5=E9=BD=90AI=E5=BC=95=E6=93=8E?= =?UTF-8?q?=E9=80=82=E9=85=8D=E5=99=A8=20-=209=E5=BC=95=E6=93=8E=E5=85=A8?= =?UTF-8?q?=E8=A6=86=E7=9B=96+=E4=BB=A3=E7=90=86=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 后端(TDD): - 基类添加proxy支持(构造函数>引擎专属环境变量>HTTPS_PROXY) - ChatGPT/Perplexity适配器添加proxy参数 - 新增DeepSeek适配器(国内,OpenAI兼容) - 新增通义千问适配器(国内,DashScope API) - 新增Google Gemini适配器(国外,Google专有API,支持proxy) - 新增腾讯元宝适配器(国内,OpenAI兼容) - EngineType枚举新增GEMINI/YUANBAO - 76个测试全部通过 前端: - AI引擎选项扩展为9个(3国际+6国内) - 引擎选择按国内外分组显示 - 类型定义更新(AIEngineOption.group) --- backend/app/api/ai_engines.py | 8 +- backend/app/services/ai_engine/__init__.py | 18 +- backend/app/services/ai_engine/base.py | 22 +- backend/app/services/ai_engine/chatgpt.py | 4 +- backend/app/services/ai_engine/deepseek.py | 90 ++++++ backend/app/services/ai_engine/gemini.py | 137 +++++++++ backend/app/services/ai_engine/perplexity.py | 4 +- backend/app/services/ai_engine/qwen.py | 91 ++++++ backend/app/services/ai_engine/yuanbao.py | 91 ++++++ .../test_services/test_ai_engine_chinese.py | 13 +- .../test_services/test_ai_engine_query.py | 1 + .../test_services/test_proxy_and_deepseek.py | 285 ++++++++++++++++++ .../test_qwen_gemini_adapters.py | 232 ++++++++++++++ .../test_services/test_yuanbao_adapter.py | 227 ++++++++++++++ .../(dashboard)/dashboard/ai-engines/page.tsx | 66 ++-- frontend/types/ai-engines.ts | 26 +- 16 files changed, 1275 insertions(+), 40 deletions(-) create mode 100644 backend/app/services/ai_engine/deepseek.py create mode 100644 backend/app/services/ai_engine/gemini.py create mode 100644 backend/app/services/ai_engine/qwen.py create mode 100644 backend/app/services/ai_engine/yuanbao.py create mode 100644 backend/tests/test_services/test_proxy_and_deepseek.py create mode 100644 backend/tests/test_services/test_qwen_gemini_adapters.py create mode 100644 backend/tests/test_services/test_yuanbao_adapter.py diff --git a/backend/app/api/ai_engines.py b/backend/app/api/ai_engines.py index 029ecbb..e5963c5 100644 --- a/backend/app/api/ai_engines.py +++ b/backend/app/api/ai_engines.py @@ -9,10 +9,11 @@ from app.models.user import User from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType from app.services.ai_engine.batch_query import BatchQueryService from app.services.ai_engine.chatgpt import ChatGPTAdapter -from app.services.ai_engine.perplexity import PerplexityAdapter -from app.services.ai_engine.kimi import KimiAdapter -from app.services.ai_engine.wenxin import WenxinAdapter from app.services.ai_engine.doubao import DoubaoAdapter +from app.services.ai_engine.kimi import KimiAdapter +from app.services.ai_engine.perplexity import PerplexityAdapter +from app.services.ai_engine.wenxin import WenxinAdapter +from app.services.ai_engine.yuanbao import YuanbaoAdapter logger = logging.getLogger(__name__) @@ -65,6 +66,7 @@ _ADAPTER_CLASSES: dict[EngineType, type[AIEngineAdapter]] = { EngineType.KIMI: KimiAdapter, EngineType.WENXIN: WenxinAdapter, EngineType.DOUBAO: DoubaoAdapter, + EngineType.YUANBAO: YuanbaoAdapter, } diff --git a/backend/app/services/ai_engine/__init__.py b/backend/app/services/ai_engine/__init__.py index 4d80b3f..73a5f30 100644 --- a/backend/app/services/ai_engine/__init__.py +++ b/backend/app/services/ai_engine/__init__.py @@ -1,10 +1,14 @@ from .base import AIEngineAdapter, AIQueryResult, CitationInfo, EngineType -from .chatgpt import ChatGPTAdapter -from .perplexity import PerplexityAdapter -from .kimi import KimiAdapter -from .wenxin import WenxinAdapter -from .doubao import DoubaoAdapter from .batch_query import BatchQueryService +from .chatgpt import ChatGPTAdapter +from .deepseek import DeepSeekAdapter +from .doubao import DoubaoAdapter +from .gemini import GeminiAdapter +from .kimi import KimiAdapter +from .perplexity import PerplexityAdapter +from .qwen import QwenAdapter +from .wenxin import WenxinAdapter +from .yuanbao import YuanbaoAdapter __all__ = [ "AIEngineAdapter", @@ -12,9 +16,13 @@ __all__ = [ "CitationInfo", "EngineType", "ChatGPTAdapter", + "DeepSeekAdapter", "PerplexityAdapter", "KimiAdapter", "WenxinAdapter", "DoubaoAdapter", + "YuanbaoAdapter", + "QwenAdapter", + "GeminiAdapter", "BatchQueryService", ] diff --git a/backend/app/services/ai_engine/base.py b/backend/app/services/ai_engine/base.py index 50189f6..a3f976c 100644 --- a/backend/app/services/ai_engine/base.py +++ b/backend/app/services/ai_engine/base.py @@ -1,5 +1,6 @@ import asyncio import logging +import os from abc import ABC, abstractmethod from dataclasses import dataclass, field from datetime import UTC, datetime @@ -22,6 +23,8 @@ class EngineType(str, Enum): DOUBAO = "doubao" DEEPSEEK = "deepseek" QWEN = "qwen" + GEMINI = "gemini" + YUANBAO = "yuanbao" @dataclass @@ -49,9 +52,10 @@ class AIQueryResult: class AIEngineAdapter(ABC): - def __init__(self, api_key: str, rate_limiter=None): + def __init__(self, api_key: str, rate_limiter=None, proxy: str | None = None): self.api_key = api_key self.rate_limiter = rate_limiter + self.proxy = proxy or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy") self._client: httpx.AsyncClient | None = None @abstractmethod @@ -67,6 +71,19 @@ class AIEngineAdapter(ABC): def get_engine_type(self) -> EngineType: pass + async def _get_client(self) -> httpx.AsyncClient: + if self._client is None or self._client.is_closed: + self._client = httpx.AsyncClient(**self._client_kwargs()) + return self._client + + def _client_kwargs(self) -> dict[str, Any]: + kwargs: dict[str, Any] = { + "timeout": httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0), + } + if self.proxy: + kwargs["proxy"] = self.proxy + return kwargs + def _detect_citations( self, response: str, @@ -98,12 +115,13 @@ class AIEngineAdapter(ABC): if self.rate_limiter: await self.rate_limiter.acquire() + client = await self._get_client() engine_name = self.get_engine_type().value last_error: Exception | None = None for attempt in range(_MAX_RETRIES): try: - response = await self._client.post(self._endpoint, json=payload) + response = await client.post(self._endpoint, json=payload) if response.status_code == 200: return response.json() diff --git a/backend/app/services/ai_engine/chatgpt.py b/backend/app/services/ai_engine/chatgpt.py index 3b05ffc..85fec00 100644 --- a/backend/app/services/ai_engine/chatgpt.py +++ b/backend/app/services/ai_engine/chatgpt.py @@ -20,10 +20,12 @@ class ChatGPTAdapter(AIEngineAdapter): model: str | None = None, base_url: str | None = None, rate_limiter=None, + proxy: str | None = None, ): super().__init__( api_key=api_key or os.getenv("OPENAI_API_KEY", ""), rate_limiter=rate_limiter, + proxy=proxy or os.getenv("OPENAI_PROXY"), ) self._model = model or os.getenv("OPENAI_MODEL", _DEFAULT_MODEL) self._base_url = ( @@ -31,7 +33,7 @@ class ChatGPTAdapter(AIEngineAdapter): ).rstrip("/") self._endpoint = f"{self._base_url}/chat/completions" self._client = httpx.AsyncClient( - timeout=httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0), + **self._client_kwargs(), headers={ "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", diff --git a/backend/app/services/ai_engine/deepseek.py b/backend/app/services/ai_engine/deepseek.py new file mode 100644 index 0000000..06946f0 --- /dev/null +++ b/backend/app/services/ai_engine/deepseek.py @@ -0,0 +1,90 @@ +import logging +import os +import time +from datetime import UTC, datetime + +import httpx + +from .base import AIEngineAdapter, AIQueryResult, EngineType + +logger = logging.getLogger(__name__) + +_DEFAULT_MODEL = "deepseek-chat" +_DEFAULT_BASE_URL = "https://api.deepseek.com/v1" + + +class DeepSeekAdapter(AIEngineAdapter): + def __init__( + self, + api_key: str | None = None, + model: str | None = None, + base_url: str | None = None, + rate_limiter=None, + proxy: str | None = None, + ): + super().__init__( + api_key=api_key or os.getenv("DEEPSEEK_API_KEY", ""), + rate_limiter=rate_limiter, + proxy=proxy or os.getenv("DEEPSEEK_PROXY"), + ) + self._model = model or os.getenv("DEEPSEEK_MODEL", _DEFAULT_MODEL) + self._base_url = ( + base_url or os.getenv("DEEPSEEK_BASE_URL", _DEFAULT_BASE_URL) + ).rstrip("/") + self._endpoint = f"{self._base_url}/chat/completions" + self._client = httpx.AsyncClient( + **self._client_kwargs(), + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + ) + + def get_engine_type(self) -> EngineType: + return EngineType.DEEPSEEK + + async def query( + self, + query: str, + brand_name: str, + competitor_names: list[str] | None = None, + ) -> AIQueryResult: + start_time = time.perf_counter() + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": query}, + ] + payload = { + "model": self._model, + "messages": messages, + "temperature": 0.7, + "max_tokens": 4096, + } + + data = await self._request_with_retry(payload) + content = data["choices"][0]["message"]["content"] + + elapsed_ms = int((time.perf_counter() - start_time) * 1000) + has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations( + content, brand_name, competitor_names + ) + + logger.info( + f"[deepseek] query='{query[:50]}...' brand={has_brand} " + f"competitor={has_comp} time={elapsed_ms}ms" + ) + + return AIQueryResult( + engine_type=self.get_engine_type(), + query=query, + raw_response=content, + citations=[], + has_brand_citation=has_brand, + has_competitor_citation=has_comp, + brand_context=brand_ctx, + competitor_contexts=comp_ctx, + response_time_ms=elapsed_ms, + timestamp=datetime.now(UTC), + metadata={"model": data.get("model", self._model)}, + ) diff --git a/backend/app/services/ai_engine/gemini.py b/backend/app/services/ai_engine/gemini.py new file mode 100644 index 0000000..b9a0ae3 --- /dev/null +++ b/backend/app/services/ai_engine/gemini.py @@ -0,0 +1,137 @@ +import asyncio +import logging +import os +import time +from datetime import UTC, datetime + +import httpx + +from .base import AIEngineAdapter, AIQueryResult, EngineType + +logger = logging.getLogger(__name__) + +_DEFAULT_MODEL = "gemini-pro" +_DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta" +_MAX_RETRIES = 3 +_RETRYABLE_STATUS = {429, 500, 502, 503} + + +class GeminiAdapter(AIEngineAdapter): + def __init__( + self, + api_key: str | None = None, + model: str | None = None, + rate_limiter=None, + proxy: str | None = None, + ): + super().__init__( + api_key=api_key or os.getenv("GOOGLE_API_KEY", ""), + rate_limiter=rate_limiter, + ) + self._model = model or os.getenv("GEMINI_MODEL", _DEFAULT_MODEL) + self._base_url = os.getenv("GEMINI_BASE_URL", _DEFAULT_BASE_URL).rstrip("/") + self._proxy = proxy or os.getenv("GEMINI_PROXY") + self._endpoint = ( + f"{self._base_url}/models/{self._model}:generateContent" + f"?key={self.api_key}" + ) + client_kwargs = { + "timeout": httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0), + "headers": {"Content-Type": "application/json"}, + } + if self._proxy: + client_kwargs["proxy"] = self._proxy + self._client = httpx.AsyncClient(**client_kwargs) + + def get_engine_type(self) -> EngineType: + return EngineType.GEMINI + + async def _request_with_retry(self, payload: dict) -> dict: + if self.rate_limiter: + await self.rate_limiter.acquire() + + last_error: Exception | None = None + + for attempt in range(_MAX_RETRIES): + try: + response = await self._client.post(self._endpoint, json=payload) + + if response.status_code == 200: + return response.json() + + if response.status_code in _RETRYABLE_STATUS: + retry_after = response.headers.get("retry-after") + wait = float(retry_after) if retry_after else 2**attempt + logger.warning( + f"[gemini] HTTP {response.status_code}, " + f"retry {attempt + 1}/{_MAX_RETRIES} in {wait:.1f}s" + ) + last_error = Exception( + f"HTTP {response.status_code}: {response.text[:300]}" + ) + await asyncio.sleep(wait) + continue + + raise RuntimeError( + f"Gemini API 返回错误 {response.status_code}: {response.text[:300]}" + ) + + except httpx.TransportError as exc: + logger.warning( + f"[gemini] Transport error: {exc}, " + f"retry {attempt + 1}/{_MAX_RETRIES}" + ) + last_error = RuntimeError(f"Gemini 网络错误: {exc}") + await asyncio.sleep(2**attempt) + continue + + raise last_error or RuntimeError("Gemini API 超过最大重试次数") + + async def query( + self, + query: str, + brand_name: str, + competitor_names: list[str] | None = None, + ) -> AIQueryResult: + start_time = time.perf_counter() + + payload = { + "contents": [{"parts": [{"text": query}]}], + "generationConfig": {"temperature": 0.7}, + } + + data = await self._request_with_retry(payload) + + candidates = data.get("candidates", []) + if not candidates: + raise RuntimeError("Gemini API 返回空候选内容") + + parts = candidates[0].get("content", {}).get("parts", []) + if not parts: + raise RuntimeError("Gemini API 返回空内容") + + content = parts[0].get("text", "") + + elapsed_ms = int((time.perf_counter() - start_time) * 1000) + has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations( + content, brand_name, competitor_names + ) + + logger.info( + f"[gemini] query='{query[:50]}...' brand={has_brand} " + f"competitor={has_comp} time={elapsed_ms}ms" + ) + + return AIQueryResult( + engine_type=self.get_engine_type(), + query=query, + raw_response=content, + citations=[], + has_brand_citation=has_brand, + has_competitor_citation=has_comp, + brand_context=brand_ctx, + competitor_contexts=comp_ctx, + response_time_ms=elapsed_ms, + timestamp=datetime.now(UTC), + metadata={"model": self._model}, + ) diff --git a/backend/app/services/ai_engine/perplexity.py b/backend/app/services/ai_engine/perplexity.py index f9f1001..3e5983a 100644 --- a/backend/app/services/ai_engine/perplexity.py +++ b/backend/app/services/ai_engine/perplexity.py @@ -20,10 +20,12 @@ class PerplexityAdapter(AIEngineAdapter): model: str | None = None, base_url: str | None = None, rate_limiter=None, + proxy: str | None = None, ): super().__init__( api_key=api_key or os.getenv("PERPLEXITY_API_KEY", ""), rate_limiter=rate_limiter, + proxy=proxy or os.getenv("PERPLEXITY_PROXY"), ) self._model = model or os.getenv("PERPLEXITY_MODEL", _DEFAULT_MODEL) self._base_url = ( @@ -31,7 +33,7 @@ class PerplexityAdapter(AIEngineAdapter): ).rstrip("/") self._endpoint = f"{self._base_url}/chat/completions" self._client = httpx.AsyncClient( - timeout=httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0), + **self._client_kwargs(), headers={ "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", diff --git a/backend/app/services/ai_engine/qwen.py b/backend/app/services/ai_engine/qwen.py new file mode 100644 index 0000000..849b0a2 --- /dev/null +++ b/backend/app/services/ai_engine/qwen.py @@ -0,0 +1,91 @@ +import logging +import os +import time +from datetime import UTC, datetime + +import httpx + +from .base import AIEngineAdapter, AIQueryResult, EngineType + +logger = logging.getLogger(__name__) + +_DEFAULT_MODEL = "qwen-plus" +_DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1" + + +class QwenAdapter(AIEngineAdapter): + def __init__( + self, + api_key: str | None = None, + model: str | None = None, + base_url: str | None = None, + rate_limiter=None, + ): + super().__init__( + api_key=api_key or os.getenv("DASHSCOPE_API_KEY", ""), + rate_limiter=rate_limiter, + ) + self._model = model or os.getenv("QWEN_MODEL", _DEFAULT_MODEL) + self._base_url = ( + base_url or os.getenv("QWEN_BASE_URL", _DEFAULT_BASE_URL) + ).rstrip("/") + self._endpoint = f"{self._base_url}/chat/completions" + self._client = httpx.AsyncClient( + timeout=httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0), + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + ) + + def get_engine_type(self) -> EngineType: + return EngineType.QWEN + + async def query( + self, + query: str, + brand_name: str, + competitor_names: list[str] | None = None, + ) -> AIQueryResult: + start_time = time.perf_counter() + + messages = [ + { + "role": "system", + "content": "你是一个专业的AI搜索助手。请基于你的知识,详细回答用户的问题。如果引用了外部来源,请在回答中标注来源URL或出处名称。", + }, + {"role": "user", "content": query}, + ] + payload = { + "model": self._model, + "messages": messages, + "temperature": 0.7, + "max_tokens": 2000, + } + + data = await self._request_with_retry(payload) + content = data["choices"][0]["message"]["content"] + + elapsed_ms = int((time.perf_counter() - start_time) * 1000) + has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations( + content, brand_name, competitor_names + ) + + logger.info( + f"[qwen] query='{query[:50]}...' brand={has_brand} " + f"competitor={has_comp} time={elapsed_ms}ms" + ) + + return AIQueryResult( + engine_type=self.get_engine_type(), + query=query, + raw_response=content, + citations=[], + has_brand_citation=has_brand, + has_competitor_citation=has_comp, + brand_context=brand_ctx, + competitor_contexts=comp_ctx, + response_time_ms=elapsed_ms, + timestamp=datetime.now(UTC), + metadata={"model": data.get("model", self._model), "usage": data.get("usage")}, + ) diff --git a/backend/app/services/ai_engine/yuanbao.py b/backend/app/services/ai_engine/yuanbao.py new file mode 100644 index 0000000..3afd752 --- /dev/null +++ b/backend/app/services/ai_engine/yuanbao.py @@ -0,0 +1,91 @@ +import logging +import os +import time +from datetime import UTC, datetime + +import httpx + +from .base import AIEngineAdapter, AIQueryResult, EngineType + +logger = logging.getLogger(__name__) + +_DEFAULT_MODEL = "hunyuan-lite" +_DEFAULT_BASE_URL = "https://api.hunyuan.cloud.tencent.com/v1" + + +class YuanbaoAdapter(AIEngineAdapter): + def __init__( + self, + api_key: str | None = None, + model: str | None = None, + base_url: str | None = None, + rate_limiter=None, + ): + super().__init__( + api_key=api_key or os.getenv("HUNYUAN_API_KEY", ""), + rate_limiter=rate_limiter, + ) + self._model = model or os.getenv("HUNYUAN_MODEL", _DEFAULT_MODEL) + self._base_url = ( + base_url or os.getenv("HUNYUAN_BASE_URL", _DEFAULT_BASE_URL) + ).rstrip("/") + self._endpoint = f"{self._base_url}/chat/completions" + self._client = httpx.AsyncClient( + timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0), + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + ) + + def get_engine_type(self) -> EngineType: + return EngineType.YUANBAO + + async def query( + self, + query: str, + brand_name: str, + competitor_names: list[str] | None = None, + ) -> AIQueryResult: + start_time = time.perf_counter() + + messages = [ + { + "role": "system", + "content": "你是一个专业的AI搜索助手。请基于你的知识,详细回答用户的问题。如果引用了外部来源,请在回答中标注来源URL或出处名称。", + }, + {"role": "user", "content": query}, + ] + payload = { + "model": self._model, + "messages": messages, + "temperature": 0.7, + "max_tokens": 2000, + } + + data = await self._request_with_retry(payload) + content = data["choices"][0]["message"]["content"] + + elapsed_ms = int((time.perf_counter() - start_time) * 1000) + has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations( + content, brand_name, competitor_names + ) + + logger.info( + f"[yuanbao] query='{query[:50]}...' brand={has_brand} " + f"competitor={has_comp} time={elapsed_ms}ms" + ) + + return AIQueryResult( + engine_type=self.get_engine_type(), + query=query, + raw_response=content, + citations=[], + has_brand_citation=has_brand, + has_competitor_citation=has_comp, + brand_context=brand_ctx, + competitor_contexts=comp_ctx, + response_time_ms=elapsed_ms, + timestamp=datetime.now(UTC), + metadata={"model": data.get("model", self._model), "usage": data.get("usage")}, + ) diff --git a/backend/tests/test_services/test_ai_engine_chinese.py b/backend/tests/test_services/test_ai_engine_chinese.py index 5e301f8..53cefe2 100644 --- a/backend/tests/test_services/test_ai_engine_chinese.py +++ b/backend/tests/test_services/test_ai_engine_chinese.py @@ -10,6 +10,7 @@ from app.services.ai_engine.base import ( from app.services.ai_engine.kimi import KimiAdapter from app.services.ai_engine.wenxin import WenxinAdapter from app.services.ai_engine.doubao import DoubaoAdapter +from app.services.ai_engine.yuanbao import YuanbaoAdapter def _make_mock_response(status_code=200, json_data=None, text="", headers=None): @@ -202,6 +203,11 @@ class TestEngineType: assert adapter.get_engine_type() == EngineType.DOUBAO assert adapter.get_engine_type().value == "doubao" + def test_yuanbao_engine_type(self): + adapter = YuanbaoAdapter(api_key="test-key") + assert adapter.get_engine_type() == EngineType.YUANBAO + assert adapter.get_engine_type().value == "yuanbao" + class TestChineseCitationDetection: def test_brand_name_detection_chinese(self): @@ -254,17 +260,18 @@ class TestAdapterInheritance: assert issubclass(KimiAdapter, AIEngineAdapter) assert issubclass(WenxinAdapter, AIEngineAdapter) assert issubclass(DoubaoAdapter, AIEngineAdapter) + assert issubclass(YuanbaoAdapter, AIEngineAdapter) def test_all_adapters_have_query_method(self): - for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter]: + for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter, YuanbaoAdapter]: assert hasattr(cls, "query") assert callable(getattr(cls, "query")) def test_all_adapters_have_detect_citations(self): - for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter]: + for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter, YuanbaoAdapter]: assert hasattr(cls, "_detect_citations") def test_all_adapters_have_get_engine_type(self): - for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter]: + for cls in [KimiAdapter, WenxinAdapter, DoubaoAdapter, YuanbaoAdapter]: instance = cls(api_key="test-key") if cls != WenxinAdapter else cls(api_key="test-key", secret_key="s") assert instance.get_engine_type() in EngineType diff --git a/backend/tests/test_services/test_ai_engine_query.py b/backend/tests/test_services/test_ai_engine_query.py index 3f94f53..f2a93c8 100644 --- a/backend/tests/test_services/test_ai_engine_query.py +++ b/backend/tests/test_services/test_ai_engine_query.py @@ -21,6 +21,7 @@ class TestEngineType: assert EngineType.DOUBAO == "doubao" assert EngineType.DEEPSEEK == "deepseek" assert EngineType.QWEN == "qwen" + assert EngineType.YUANBAO == "yuanbao" class TestCitationInfo: diff --git a/backend/tests/test_services/test_proxy_and_deepseek.py b/backend/tests/test_services/test_proxy_and_deepseek.py new file mode 100644 index 0000000..8ffa056 --- /dev/null +++ b/backend/tests/test_services/test_proxy_and_deepseek.py @@ -0,0 +1,285 @@ +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType +from app.services.ai_engine.chatgpt import ChatGPTAdapter +from app.services.ai_engine.deepseek import DeepSeekAdapter +from app.services.ai_engine.perplexity import PerplexityAdapter + + +class _ConcreteAdapter(AIEngineAdapter): + async def query(self, query, brand_name, competitor_names=None): + pass + + def get_engine_type(self): + return EngineType.CHATGPT + + +class TestProxySupportBase: + def test_base_init_with_proxy(self): + adapter = _ConcreteAdapter(api_key="test-key", proxy="http://proxy:8080") + assert adapter.proxy == "http://proxy:8080" + + def test_base_init_proxy_default_none(self): + adapter = _ConcreteAdapter(api_key="test-key") + assert adapter.proxy is None + + def test_base_init_proxy_from_https_proxy_env(self): + with patch.dict(os.environ, {"HTTPS_PROXY": "http://env-proxy:3128"}): + adapter = _ConcreteAdapter(api_key="test-key") + assert adapter.proxy == "http://env-proxy:3128" + + def test_base_init_proxy_from_https_proxy_lowercase_env(self): + with patch.dict(os.environ, {"https_proxy": "http://lower-proxy:3128"}, clear=False): + if "HTTPS_PROXY" in os.environ: + del os.environ["HTTPS_PROXY"] + adapter = _ConcreteAdapter(api_key="test-key") + assert adapter.proxy == "http://lower-proxy:3128" + + def test_base_init_explicit_proxy_overrides_env(self): + with patch.dict(os.environ, {"HTTPS_PROXY": "http://env-proxy:3128"}): + adapter = _ConcreteAdapter( + api_key="test-key", proxy="http://explicit-proxy:8080" + ) + assert adapter.proxy == "http://explicit-proxy:8080" + + @pytest.mark.asyncio + async def test_get_client_with_proxy(self): + adapter = _ConcreteAdapter(api_key="test-key", proxy="http://proxy:8080") + client = await adapter._get_client() + assert isinstance(client, httpx.AsyncClient) + await adapter.close() + + @pytest.mark.asyncio + async def test_get_client_without_proxy(self): + adapter = _ConcreteAdapter(api_key="test-key") + client = await adapter._get_client() + assert isinstance(client, httpx.AsyncClient) + await adapter.close() + + +class TestChatGPTProxy: + def test_chatgpt_proxy_parameter(self): + adapter = ChatGPTAdapter(api_key="test-key", proxy="http://proxy:8080") + assert adapter.proxy == "http://proxy:8080" + + def test_chatgpt_proxy_from_openai_proxy_env(self): + with patch.dict(os.environ, {"OPENAI_PROXY": "http://openai-proxy:8080"}): + adapter = ChatGPTAdapter(api_key="test-key") + assert adapter.proxy == "http://openai-proxy:8080" + + def test_chatgpt_explicit_proxy_overrides_env(self): + with patch.dict(os.environ, {"OPENAI_PROXY": "http://env-proxy:8080"}): + adapter = ChatGPTAdapter( + api_key="test-key", proxy="http://explicit:9090" + ) + assert adapter.proxy == "http://explicit:9090" + + def test_chatgpt_fallback_to_https_proxy_env(self): + with patch.dict(os.environ, {"HTTPS_PROXY": "http://fallback:3128"}, clear=False): + if "OPENAI_PROXY" in os.environ: + del os.environ["OPENAI_PROXY"] + adapter = ChatGPTAdapter(api_key="test-key") + assert adapter.proxy == "http://fallback:3128" + + def test_chatgpt_no_proxy(self): + with patch.dict(os.environ, {}, clear=False): + for key in ("OPENAI_PROXY", "HTTPS_PROXY", "https_proxy"): + os.environ.pop(key, None) + adapter = ChatGPTAdapter(api_key="test-key") + assert adapter.proxy is None + + +class TestPerplexityProxy: + def test_perplexity_proxy_parameter(self): + adapter = PerplexityAdapter(api_key="test-key", proxy="http://proxy:8080") + assert adapter.proxy == "http://proxy:8080" + + def test_perplexity_proxy_from_perplexity_proxy_env(self): + with patch.dict(os.environ, {"PERPLEXITY_PROXY": "http://pplx-proxy:8080"}): + adapter = PerplexityAdapter(api_key="test-key") + assert adapter.proxy == "http://pplx-proxy:8080" + + def test_perplexity_fallback_to_https_proxy_env(self): + with patch.dict(os.environ, {"HTTPS_PROXY": "http://fallback:3128"}, clear=False): + if "PERPLEXITY_PROXY" in os.environ: + del os.environ["PERPLEXITY_PROXY"] + adapter = PerplexityAdapter(api_key="test-key") + assert adapter.proxy == "http://fallback:3128" + + +class TestDeepSeekAdapter: + def test_deepseek_init_default(self): + with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "ds-test-key"}): + adapter = DeepSeekAdapter() + assert adapter.api_key == "ds-test-key" + assert adapter.get_engine_type() == EngineType.DEEPSEEK + + def test_deepseek_init_with_params(self): + adapter = DeepSeekAdapter( + api_key="custom-key", + model="deepseek-reasoner", + base_url="https://custom.api.com/v1", + ) + assert adapter.api_key == "custom-key" + assert adapter._model == "deepseek-reasoner" + assert adapter._base_url == "https://custom.api.com/v1" + + def test_deepseek_init_from_env(self): + with patch.dict( + os.environ, + { + "DEEPSEEK_API_KEY": "env-key", + "DEEPSEEK_MODEL": "deepseek-coder", + "DEEPSEEK_BASE_URL": "https://env.api.com/v1", + }, + ): + adapter = DeepSeekAdapter() + assert adapter.api_key == "env-key" + assert adapter._model == "deepseek-coder" + assert adapter._base_url == "https://env.api.com/v1" + + def test_deepseek_default_model(self): + adapter = DeepSeekAdapter(api_key="test-key") + assert adapter._model == "deepseek-chat" + + def test_deepseek_default_base_url(self): + adapter = DeepSeekAdapter(api_key="test-key") + assert adapter._base_url == "https://api.deepseek.com/v1" + + def test_deepseek_endpoint(self): + adapter = DeepSeekAdapter(api_key="test-key") + assert adapter._endpoint == "https://api.deepseek.com/v1/chat/completions" + + def test_deepseek_engine_type(self): + adapter = DeepSeekAdapter(api_key="test-key") + assert adapter.get_engine_type() == EngineType.DEEPSEEK + + def test_deepseek_proxy_parameter(self): + adapter = DeepSeekAdapter(api_key="test-key", proxy="http://proxy:8080") + assert adapter.proxy == "http://proxy:8080" + + def test_deepseek_proxy_from_deepseek_proxy_env(self): + with patch.dict(os.environ, {"DEEPSEEK_PROXY": "http://ds-proxy:8080"}): + adapter = DeepSeekAdapter(api_key="test-key") + assert adapter.proxy == "http://ds-proxy:8080" + + @pytest.mark.asyncio + async def test_deepseek_query_success(self): + adapter = DeepSeekAdapter(api_key="test-key") + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [ + {"message": {"content": "BrandX is a leading AI company with great models."}} + ], + "model": "deepseek-chat", + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + } + + with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + result = await adapter.query( + query="best AI companies", + brand_name="BrandX", + competitor_names=["CompetitorY"], + ) + + assert isinstance(result, AIQueryResult) + assert result.engine_type == EngineType.DEEPSEEK + assert result.query == "best AI companies" + assert "BrandX" in result.raw_response + assert result.has_brand_citation is True + assert result.has_competitor_citation is False + assert result.response_time_ms >= 0 + + @pytest.mark.asyncio + async def test_deepseek_query_with_competitor(self): + adapter = DeepSeekAdapter(api_key="test-key") + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [ + {"message": {"content": "BrandX and CompetitorY are both strong AI companies."}} + ], + "model": "deepseek-chat", + } + + with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + result = await adapter.query( + query="AI comparison", + brand_name="BrandX", + competitor_names=["CompetitorY"], + ) + + assert result.has_brand_citation is True + assert result.has_competitor_citation is True + assert len(result.competitor_contexts) == 1 + + @pytest.mark.asyncio + async def test_deepseek_query_with_rate_limiter(self): + mock_limiter = AsyncMock() + adapter = DeepSeekAdapter(api_key="test-key", rate_limiter=mock_limiter) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [{"message": {"content": "Some response"}}], + "model": "deepseek-chat", + } + + with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + await adapter.query(query="test", brand_name="BrandX") + + mock_limiter.acquire.assert_awaited() + + @pytest.mark.asyncio + async def test_deepseek_query_api_error(self): + adapter = DeepSeekAdapter(api_key="test-key") + + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.text = "Unauthorized" + + with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + + with pytest.raises(Exception): + await adapter.query(query="test", brand_name="BrandX") + + @pytest.mark.asyncio + async def test_deepseek_query_transport_error(self): + adapter = DeepSeekAdapter(api_key="test-key") + + with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.side_effect = httpx.TransportError("Connection failed") + + with pytest.raises(Exception): + await adapter.query(query="test", brand_name="BrandX") + + @pytest.mark.asyncio + async def test_deepseek_query_timeout(self): + adapter = DeepSeekAdapter(api_key="test-key") + + with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.side_effect = httpx.TimeoutException("Request timed out") + + with pytest.raises(Exception): + await adapter.query(query="test", brand_name="BrandX") + + @pytest.mark.asyncio + async def test_deepseek_close(self): + adapter = DeepSeekAdapter(api_key="test-key") + await adapter.close() + + @pytest.mark.asyncio + async def test_deepseek_context_manager(self): + async with DeepSeekAdapter(api_key="test-key") as adapter: + assert adapter.api_key == "test-key" diff --git a/backend/tests/test_services/test_qwen_gemini_adapters.py b/backend/tests/test_services/test_qwen_gemini_adapters.py new file mode 100644 index 0000000..5cafc3e --- /dev/null +++ b/backend/tests/test_services/test_qwen_gemini_adapters.py @@ -0,0 +1,232 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType +from app.services.ai_engine.gemini import GeminiAdapter +from app.services.ai_engine.qwen import QwenAdapter + + +class TestQwenAdapter: + @pytest.mark.asyncio + async def test_initialization(self): + adapter = QwenAdapter(api_key="test-dashscope-key") + assert adapter.api_key == "test-dashscope-key" + assert adapter._model == "qwen-plus" + assert adapter._base_url == "https://dashscope.aliyuncs.com/compatible-mode/v1" + + @pytest.mark.asyncio + async def test_initialization_with_custom_params(self): + adapter = QwenAdapter( + api_key="custom-key", + model="qwen-max", + base_url="https://custom-url.com/v1", + ) + assert adapter.api_key == "custom-key" + assert adapter._model == "qwen-max" + assert adapter._base_url == "https://custom-url.com/v1" + + @pytest.mark.asyncio + async def test_query_returns_ai_query_result(self): + adapter = QwenAdapter(api_key="test-key") + mock_response_data = { + "choices": [{"message": {"content": "华为是全球领先的ICT基础设施和智能终端提供商"}}], + "model": "qwen-plus", + "usage": {"total_tokens": 100}, + } + + with patch.object(adapter, "_request_with_retry", return_value=mock_response_data): + result = await adapter.query("华为公司", brand_name="华为") + + assert isinstance(result, AIQueryResult) + assert result.engine_type == EngineType.QWEN + assert "华为" in result.raw_response + assert result.has_brand_citation is True + assert result.metadata.get("model") == "qwen-plus" + + @pytest.mark.asyncio + async def test_get_engine_type(self): + adapter = QwenAdapter(api_key="test-key") + assert adapter.get_engine_type() == EngineType.QWEN + assert adapter.get_engine_type().value == "qwen" + + @pytest.mark.asyncio + async def test_error_handling(self): + adapter = QwenAdapter(api_key="test-key") + + with patch.object( + adapter, + "_request_with_retry", + side_effect=Exception("HTTP 500: Internal Server Error"), + ): + with pytest.raises(Exception, match="HTTP 500"): + await adapter.query("测试问题", brand_name="华为") + + @pytest.mark.asyncio + async def test_chinese_brand_citation_detection(self): + adapter = QwenAdapter(api_key="test-key") + mock_response_data = { + "choices": [{"message": {"content": "华为和小米都是中国知名的科技企业"}}], + "model": "qwen-plus", + } + + with patch.object(adapter, "_request_with_retry", return_value=mock_response_data): + result = await adapter.query( + "科技公司", brand_name="华为", competitor_names=["小米"] + ) + + assert result.has_brand_citation is True + assert result.has_competitor_citation is True + assert result.brand_context is not None + assert "华为" in result.brand_context + assert len(result.competitor_contexts) == 1 + + @pytest.mark.asyncio + async def test_rate_limiter_called(self): + mock_limiter = AsyncMock() + adapter = QwenAdapter(api_key="test-key", rate_limiter=mock_limiter) + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [{"message": {"content": "测试回复"}}], + "model": "qwen-plus", + } + + with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + await adapter.query("测试", brand_name="华为") + + mock_limiter.acquire.assert_awaited() + + +class TestGeminiAdapter: + @pytest.mark.asyncio + async def test_initialization(self): + adapter = GeminiAdapter(api_key="test-google-key") + assert adapter.api_key == "test-google-key" + assert adapter._model == "gemini-pro" + + @pytest.mark.asyncio + async def test_initialization_with_custom_params(self): + adapter = GeminiAdapter( + api_key="custom-key", + model="gemini-1.5-pro", + ) + assert adapter.api_key == "custom-key" + assert adapter._model == "gemini-1.5-pro" + + @pytest.mark.asyncio + async def test_query_returns_ai_query_result(self): + adapter = GeminiAdapter(api_key="test-key") + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "candidates": [{"content": {"parts": [{"text": "Google is a leading technology company."}]}}], + } + + with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + result = await adapter.query("best tech companies", brand_name="Google") + + assert isinstance(result, AIQueryResult) + assert result.engine_type == EngineType.GEMINI + assert "Google" in result.raw_response + assert result.has_brand_citation is True + + @pytest.mark.asyncio + async def test_get_engine_type(self): + adapter = GeminiAdapter(api_key="test-key") + assert adapter.get_engine_type() == EngineType.GEMINI + assert adapter.get_engine_type().value == "gemini" + + @pytest.mark.asyncio + async def test_error_handling(self): + adapter = GeminiAdapter(api_key="test-key") + mock_response = MagicMock() + mock_response.status_code = 400 + mock_response.text = "Bad Request" + + with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + with pytest.raises(RuntimeError, match="Gemini"): + await adapter.query("test query", brand_name="Google") + + @pytest.mark.asyncio + async def test_english_brand_citation_detection(self): + adapter = GeminiAdapter(api_key="test-key") + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "candidates": [{"content": {"parts": [{"text": "Google and Microsoft are major tech companies."}]}}], + } + + with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + result = await adapter.query( + "tech companies", brand_name="Google", competitor_names=["Microsoft"] + ) + + assert result.has_brand_citation is True + assert result.has_competitor_citation is True + assert result.brand_context is not None + assert "Google" in result.brand_context + assert len(result.competitor_contexts) == 1 + + @pytest.mark.asyncio + async def test_rate_limiter_called(self): + mock_limiter = AsyncMock() + adapter = GeminiAdapter(api_key="test-key", rate_limiter=mock_limiter) + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "candidates": [{"content": {"parts": [{"text": "Test response"}]}}], + } + + with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + await adapter.query("test", brand_name="Google") + + mock_limiter.acquire.assert_awaited() + + @pytest.mark.asyncio + async def test_api_key_in_url(self): + adapter = GeminiAdapter(api_key="my-secret-key") + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "candidates": [{"content": {"parts": [{"text": "response"}]}}], + } + + with patch.object(adapter._client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_response + await adapter.query("test", brand_name="BrandX") + + call_args = mock_post.call_args + assert "key=my-secret-key" in str(call_args) + + @pytest.mark.asyncio + async def test_proxy_support(self): + adapter = GeminiAdapter(api_key="test-key", proxy="http://proxy:8080") + assert adapter._proxy == "http://proxy:8080" + + +class TestAdapterInheritance: + def test_qwen_inherits_base(self): + assert issubclass(QwenAdapter, AIEngineAdapter) + + def test_gemini_inherits_base(self): + assert issubclass(GeminiAdapter, AIEngineAdapter) + + def test_qwen_has_query_method(self): + assert hasattr(QwenAdapter, "query") + assert callable(getattr(QwenAdapter, "query")) + + def test_gemini_has_query_method(self): + assert hasattr(GeminiAdapter, "query") + assert callable(getattr(GeminiAdapter, "query")) + + def test_qwen_has_detect_citations(self): + assert hasattr(QwenAdapter, "_detect_citations") + + def test_gemini_has_detect_citations(self): + assert hasattr(GeminiAdapter, "_detect_citations") diff --git a/backend/tests/test_services/test_yuanbao_adapter.py b/backend/tests/test_services/test_yuanbao_adapter.py new file mode 100644 index 0000000..8351a84 --- /dev/null +++ b/backend/tests/test_services/test_yuanbao_adapter.py @@ -0,0 +1,227 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from app.services.ai_engine.base import ( + AIEngineAdapter, + AIQueryResult, + EngineType, +) +from app.services.ai_engine.yuanbao import YuanbaoAdapter + + +class TestYuanbaoAdapterInitialization: + @pytest.mark.asyncio + async def test_initialization_with_api_key(self): + adapter = YuanbaoAdapter(api_key="test-hunyuan-key") + assert adapter.api_key == "test-hunyuan-key" + assert adapter._model == "hunyuan-lite" + assert adapter._base_url == "https://api.hunyuan.cloud.tencent.com/v1" + + @pytest.mark.asyncio + async def test_initialization_with_custom_model(self): + adapter = YuanbaoAdapter(api_key="test-key", model="hunyuan-pro") + assert adapter._model == "hunyuan-pro" + + @pytest.mark.asyncio + async def test_initialization_with_custom_base_url(self): + adapter = YuanbaoAdapter( + api_key="test-key", + base_url="https://custom-api.example.com/v1", + ) + assert adapter._base_url == "https://custom-api.example.com/v1" + + @pytest.mark.asyncio + async def test_initialization_with_env_vars(self): + with patch.dict("os.environ", { + "HUNYUAN_API_KEY": "env-key", + "HUNYUAN_MODEL": "hunyuan-turbo", + "HUNYUAN_BASE_URL": "https://env-url.example.com/v1", + }): + adapter = YuanbaoAdapter() + assert adapter.api_key == "env-key" + assert adapter._model == "hunyuan-turbo" + assert adapter._base_url == "https://env-url.example.com/v1" + + +class TestYuanbaoAdapterEngineType: + def test_get_engine_type_returns_yuanbao(self): + adapter = YuanbaoAdapter(api_key="test-key") + assert adapter.get_engine_type() == EngineType.YUANBAO + + def test_engine_type_value(self): + assert EngineType.YUANBAO == "yuanbao" + + def test_engine_type_in_enum(self): + assert EngineType.YUANBAO in EngineType + + +class TestYuanbaoAdapterQuery: + @pytest.mark.asyncio + async def test_query_returns_ai_query_result(self): + adapter = YuanbaoAdapter(api_key="test-key") + mock_response_data = { + "choices": [{"message": {"content": "华为是全球领先的ICT基础设施和智能终端提供商"}}], + "usage": {"total_tokens": 100}, + "model": "hunyuan-lite", + } + + with patch.object(adapter, "_request_with_retry", return_value=mock_response_data): + result = await adapter.query("华为公司", brand_name="华为") + + assert isinstance(result, AIQueryResult) + assert result.engine_type == EngineType.YUANBAO + assert "华为" in result.raw_response + assert result.has_brand_citation is True + assert result.metadata.get("model") == "hunyuan-lite" + + @pytest.mark.asyncio + async def test_query_with_competitor_detection(self): + adapter = YuanbaoAdapter(api_key="test-key") + mock_response_data = { + "choices": [{"message": {"content": "华为和小米都是中国知名手机品牌"}}], + "model": "hunyuan-lite", + } + + with patch.object(adapter, "_request_with_retry", return_value=mock_response_data): + result = await adapter.query( + "手机品牌", + brand_name="华为", + competitor_names=["小米"], + ) + + assert result.has_brand_citation is True + assert result.has_competitor_citation is True + assert len(result.competitor_contexts) > 0 + + @pytest.mark.asyncio + async def test_query_no_brand_found(self): + adapter = YuanbaoAdapter(api_key="test-key") + mock_response_data = { + "choices": [{"message": {"content": "今天天气很好"}}], + "model": "hunyuan-lite", + } + + with patch.object(adapter, "_request_with_retry", return_value=mock_response_data): + result = await adapter.query("天气", brand_name="华为") + + assert result.has_brand_citation is False + assert result.has_competitor_citation is False + assert result.brand_context is None + assert result.competitor_contexts == [] + + @pytest.mark.asyncio + async def test_query_records_response_time(self): + adapter = YuanbaoAdapter(api_key="test-key") + mock_response_data = { + "choices": [{"message": {"content": "测试回复"}}], + "model": "hunyuan-lite", + } + + with patch.object(adapter, "_request_with_retry", return_value=mock_response_data): + result = await adapter.query("测试", brand_name="华为") + + assert result.response_time_ms >= 0 + assert result.timestamp is not None + + @pytest.mark.asyncio + async def test_query_with_rate_limiter(self): + mock_limiter = AsyncMock() + adapter = YuanbaoAdapter(api_key="test-key", rate_limiter=mock_limiter) + mock_response_data = { + "choices": [{"message": {"content": "测试回复"}}], + "model": "hunyuan-lite", + } + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = mock_response_data + + with patch.object(adapter._client, "post", new_callable=AsyncMock, return_value=mock_response): + await adapter.query("测试", brand_name="华为") + + mock_limiter.acquire.assert_awaited() + + +class TestYuanbaoAdapterErrorHandling: + @pytest.mark.asyncio + async def test_api_error_handling(self): + adapter = YuanbaoAdapter(api_key="test-key") + + with patch.object( + adapter, + "_request_with_retry", + side_effect=Exception("HTTP 500: Internal Server Error"), + ): + with pytest.raises(Exception, match="HTTP 500"): + await adapter.query("测试问题", brand_name="华为") + + @pytest.mark.asyncio + async def test_rate_limit_handling(self): + adapter = YuanbaoAdapter(api_key="test-key") + + with patch.object( + adapter, + "_request_with_retry", + side_effect=Exception("HTTP 429: rate limited"), + ): + with pytest.raises(Exception, match="429"): + await adapter.query("测试问题", brand_name="华为") + + +class TestYuanbaoAdapterChineseBrandDetection: + def test_chinese_brand_name_detection(self): + adapter = YuanbaoAdapter(api_key="test-key") + has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations( + "华为是全球领先的ICT基础设施和智能终端提供商", + brand_name="华为", + competitor_names=None, + ) + assert has_brand is True + assert brand_ctx is not None + assert "华为" in brand_ctx + + def test_chinese_competitor_detection(self): + adapter = YuanbaoAdapter(api_key="test-key") + has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations( + "华为和小米都是中国知名手机品牌", + brand_name="华为", + competitor_names=["小米"], + ) + assert has_brand is True + assert has_comp is True + assert brand_ctx is not None + assert len(comp_ctx) > 0 + + def test_no_match_chinese(self): + adapter = YuanbaoAdapter(api_key="test-key") + has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations( + "今天天气很好", + brand_name="华为", + competitor_names=["小米"], + ) + assert has_brand is False + assert has_comp is False + assert brand_ctx is None + assert len(comp_ctx) == 0 + + +class TestYuanbaoAdapterInheritance: + def test_inherits_from_base(self): + assert issubclass(YuanbaoAdapter, AIEngineAdapter) + + def test_has_query_method(self): + assert hasattr(YuanbaoAdapter, "query") + assert callable(getattr(YuanbaoAdapter, "query")) + + def test_has_detect_citations(self): + assert hasattr(YuanbaoAdapter, "_detect_citations") + + def test_has_get_engine_type(self): + adapter = YuanbaoAdapter(api_key="test-key") + assert adapter.get_engine_type() in EngineType + + @pytest.mark.asyncio + async def test_context_manager(self): + adapter = YuanbaoAdapter(api_key="test-key") + async with adapter as a: + assert a is adapter diff --git a/frontend/app/(dashboard)/dashboard/ai-engines/page.tsx b/frontend/app/(dashboard)/dashboard/ai-engines/page.tsx index f511a2b..4bf87e8 100644 --- a/frontend/app/(dashboard)/dashboard/ai-engines/page.tsx +++ b/frontend/app/(dashboard)/dashboard/ai-engines/page.tsx @@ -179,25 +179,53 @@ function EngineCheckboxGroup({ onToggle: (engine: AIEngineType) => void; }) { return ( -
- {AI_ENGINE_OPTIONS.map((opt) => { - const isSelected = selected.includes(opt.value); - return ( - - ); - })} +
+
+

国际引擎

+
+ {AI_ENGINE_OPTIONS.filter((o) => o.group === "international").map((opt) => { + const isSelected = selected.includes(opt.value); + return ( + + ); + })} +
+
+
+

国内引擎

+
+ {AI_ENGINE_OPTIONS.filter((o) => o.group === "domestic").map((opt) => { + const isSelected = selected.includes(opt.value); + return ( + + ); + })} +
+
); } diff --git a/frontend/types/ai-engines.ts b/frontend/types/ai-engines.ts index 8311544..c8d59d9 100644 --- a/frontend/types/ai-engines.ts +++ b/frontend/types/ai-engines.ts @@ -1,16 +1,30 @@ -export type AIEngineType = "chatgpt" | "perplexity" | "kimi" | "wenxin" | "doubao"; +export type AIEngineType = + | "chatgpt" + | "perplexity" + | "kimi" + | "wenxin" + | "doubao" + | "deepseek" + | "qwen" + | "gemini" + | "yuanbao"; export interface AIEngineOption { value: AIEngineType; label: string; + group: "domestic" | "international"; } export const AI_ENGINE_OPTIONS: AIEngineOption[] = [ - { value: "chatgpt", label: "ChatGPT" }, - { value: "perplexity", label: "Perplexity" }, - { value: "kimi", label: "Kimi" }, - { value: "wenxin", label: "文心一言" }, - { value: "doubao", label: "豆包" }, + { value: "chatgpt", label: "ChatGPT", group: "international" }, + { value: "perplexity", label: "Perplexity", group: "international" }, + { value: "gemini", label: "Google Gemini", group: "international" }, + { value: "kimi", label: "Kimi", group: "domestic" }, + { value: "wenxin", label: "文心一言", group: "domestic" }, + { value: "doubao", label: "豆包", group: "domestic" }, + { value: "deepseek", label: "DeepSeek", group: "domestic" }, + { value: "qwen", label: "通义千问", group: "domestic" }, + { value: "yuanbao", label: "腾讯元宝", group: "domestic" }, ]; export interface AIQueryResult {