"""Unit tests for the AI engine adapters (Kimi, Wenxin, Doubao, Yuanbao).

No real network traffic is produced: the Kimi/Doubao tests patch the
adapter's ``_request_with_retry`` method, and the Wenxin tests swap the
adapter's ``_client`` for an ``AsyncMock`` whose ``post`` calls return canned
responses built by ``_make_mock_response``.
"""

from unittest.mock import AsyncMock, Mock, patch, MagicMock

import pytest

from app.services.ai_engine.base import (
    AIEngineAdapter,
    AIQueryResult,
    CitationInfo,
    EngineType,
)
from app.services.ai_engine.kimi import KimiAdapter
from app.services.ai_engine.wenxin import WenxinAdapter
from app.services.ai_engine.doubao import DoubaoAdapter
from app.services.ai_engine.yuanbao import YuanbaoAdapter


def _make_mock_response(status_code=200, json_data=None, text="", headers=None):
    """Build a ``Mock`` standing in for an HTTP response object.

    Args:
        status_code: Value exposed as ``.status_code``.
        json_data: Value returned by ``.json()``; ``None`` means ``{}``.
        text: Value exposed as ``.text``.
        headers: Value exposed as ``.headers``; ``None`` means ``{}``.

    Returns:
        A ``Mock`` with ``status_code``, ``json()``, ``text`` and ``headers``
        configured.
    """
    mock_resp = Mock()
    mock_resp.status_code = status_code
    # Compare against None so falsy-but-valid payloads (e.g. [] or {}) are kept
    # as given instead of being replaced by a fresh dict.
    mock_resp.json.return_value = {} if json_data is None else json_data
    mock_resp.text = text
    mock_resp.headers = {} if headers is None else headers
    return mock_resp


@pytest.fixture
def fresh_wenxin_token_cache(monkeypatch):
    """Clear the Wenxin module-level access-token cache for a single test.

    The tests below expect the first ``post`` call to be the token request,
    so any previously cached token must be dropped. ``monkeypatch`` records
    the original values and restores them at teardown, so a token cached
    while the test runs does not leak into later tests. ``raising=False``
    keeps this tolerant if the module does not define these names.
    """
    import app.services.ai_engine.wenxin as wenxin_mod

    monkeypatch.setattr(wenxin_mod, "_cached_token", None, raising=False)
    monkeypatch.setattr(wenxin_mod, "_token_expires_at", 0.0, raising=False)


class TestKimiAdapter:
    """Tests for ``KimiAdapter``; HTTP is stubbed via ``_request_with_retry``."""

    @pytest.mark.asyncio
    async def test_initialization(self):
        adapter = KimiAdapter(api_key="test-key")
        assert adapter.api_key == "test-key"
        assert adapter.get_engine_type() == EngineType.KIMI

    @pytest.mark.asyncio
    async def test_query_returns_ai_query_result(self):
        adapter = KimiAdapter(api_key="test-key")
        mock_response_data = {
            "choices": [{"message": {"content": "华为是全球领先的ICT公司"}}],
            "usage": {"total_tokens": 100},
            "model": "moonshot-v1-8k",
        }
        with patch.object(adapter, "_request_with_retry", return_value=mock_response_data):
            result = await adapter.query("华为公司", brand_name="华为")
        assert isinstance(result, AIQueryResult)
        assert result.engine_type == EngineType.KIMI
        assert "华为" in result.raw_response
        assert result.has_brand_citation is True
        assert result.metadata.get("model") == "moonshot-v1-8k"

    @pytest.mark.asyncio
    async def test_api_error_handling(self):
        adapter = KimiAdapter(api_key="test-key")
        with patch.object(
            adapter,
            "_request_with_retry",
            side_effect=Exception("HTTP 500: Internal Server Error"),
        ):
            with pytest.raises(Exception, match="HTTP 500"):
                await adapter.query("测试问题", brand_name="华为")

    @pytest.mark.asyncio
    async def test_rate_limit_handling(self):
        adapter = KimiAdapter(api_key="test-key")
        with patch.object(
            adapter,
            "_request_with_retry",
            side_effect=Exception("HTTP 429: rate limited"),
        ):
            with pytest.raises(Exception, match="429"):
                await adapter.query("测试问题", brand_name="华为")


class TestWenxinAdapter:
    """Tests for ``WenxinAdapter``; HTTP is stubbed by replacing ``_client``.

    Each query test queues two ``post`` responses: the access-token request
    followed by the chat request.
    """

    @pytest.mark.asyncio
    async def test_initialization(self):
        adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret")
        assert adapter.api_key == "test-key"
        assert adapter.secret_key == "test-secret"
        assert adapter.get_engine_type() == EngineType.WENXIN

    @pytest.mark.asyncio
    @pytest.mark.usefixtures("fresh_wenxin_token_cache")
    async def test_query_returns_ai_query_result(self):
        adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret")
        mock_token_data = {"access_token": "test-access-token", "expires_in": 2592000}
        mock_chat_data = {"result": "华为是一家全球领先的科技公司", "usage": {"total_tokens": 100}}

        mock_client = AsyncMock()
        token_response = _make_mock_response(200, mock_token_data)
        chat_response = _make_mock_response(200, mock_chat_data)
        mock_client.post.side_effect = [token_response, chat_response]

        with patch.object(adapter, "_client", mock_client):
            result = await adapter.query("华为公司", brand_name="华为")

        assert isinstance(result, AIQueryResult)
        assert result.engine_type == EngineType.WENXIN
        assert "华为" in result.raw_response
        assert result.has_brand_citation is True

    @pytest.mark.asyncio
    @pytest.mark.usefixtures("fresh_wenxin_token_cache")
    async def test_api_error_handling(self):
        adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret")
        mock_client = AsyncMock()
        token_response = _make_mock_response(200, {"access_token": "test-token", "expires_in": 2592000})
        error_response = _make_mock_response(500, text="Internal Server Error")
        mock_client.post.side_effect = [token_response, error_response]

        with patch.object(adapter, "_client", mock_client):
            with pytest.raises(RuntimeError, match="文心"):
                await adapter.query("测试问题", brand_name="华为")

    @pytest.mark.asyncio
    @pytest.mark.usefixtures("fresh_wenxin_token_cache")
    async def test_rate_limit_handling(self):
        adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret")
        mock_client = AsyncMock()
        token_response = _make_mock_response(200, {"access_token": "test-token", "expires_in": 2592000})
        rate_limit_response = _make_mock_response(429, headers={"Retry-After": "1"})
        # NOTE(review): only one chat response is queued, so this assumes the
        # adapter surfaces a 429 as an error rather than retrying — confirm
        # against WenxinAdapter's retry policy.
        mock_client.post.side_effect = [token_response, rate_limit_response]

        with patch.object(adapter, "_client", mock_client):
            with pytest.raises(RuntimeError, match="限流"):
                await adapter.query("测试问题", brand_name="华为")


class TestDoubaoAdapter:
    """Tests for ``DoubaoAdapter``; HTTP is stubbed via ``_request_with_retry``."""

    @pytest.mark.asyncio
    async def test_initialization(self):
        adapter = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test")
        assert adapter.api_key == "test-key"
        assert adapter._endpoint_id == "ep-test"
        assert adapter.get_engine_type() == EngineType.DOUBAO

    @pytest.mark.asyncio
    async def test_query_returns_ai_query_result(self):
        adapter = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test")
        mock_response_data = {
            "choices": [{"message": {"content": "华为是全球知名企业"}}],
            "model": "ep-test",
        }
        with patch.object(adapter, "_request_with_retry", return_value=mock_response_data):
            result = await adapter.query("华为公司", brand_name="华为")
        assert isinstance(result, AIQueryResult)
        assert result.engine_type == EngineType.DOUBAO
        assert "华为" in result.raw_response
        assert result.has_brand_citation is True

    @pytest.mark.asyncio
    async def test_api_error_handling(self):
        adapter = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test")
        with patch.object(
            adapter,
            "_request_with_retry",
            side_effect=Exception("HTTP 500: Internal Server Error"),
        ):
            with pytest.raises(Exception, match="HTTP 500"):
                await adapter.query("测试问题", brand_name="华为")

    @pytest.mark.asyncio
    async def test_rate_limit_handling(self):
        adapter = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test")
        with patch.object(
            adapter,
            "_request_with_retry",
            side_effect=Exception("HTTP 429: rate limited"),
        ):
            with pytest.raises(Exception, match="429"):
                await adapter.query("测试问题", brand_name="华为")


class TestEngineType:
    """Each adapter reports its own ``EngineType`` member and string value."""

    def test_kimi_engine_type(self):
        adapter = KimiAdapter(api_key="test-key")
        assert adapter.get_engine_type() == EngineType.KIMI
        assert adapter.get_engine_type().value == "kimi"

    def test_wenxin_engine_type(self):
        adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret")
        assert adapter.get_engine_type() == EngineType.WENXIN
        assert adapter.get_engine_type().value == "wenxin"

    def test_doubao_engine_type(self):
        adapter = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test")
        assert adapter.get_engine_type() == EngineType.DOUBAO
        assert adapter.get_engine_type().value == "doubao"

    def test_yuanbao_engine_type(self):
        adapter = YuanbaoAdapter(api_key="test-key")
        assert adapter.get_engine_type() == EngineType.YUANBAO
        assert adapter.get_engine_type().value == "yuanbao"


class TestChineseCitationDetection:
    """Tests for ``_detect_citations``.

    ``_detect_citations`` returns a 4-tuple:
    ``(has_brand, has_competitor, brand_context, competitor_contexts)``.
    """

    def test_brand_name_detection_chinese(self):
        adapter = KimiAdapter(api_key="test-key")
        has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
            "华为是全球领先的ICT基础设施和智能终端提供商",
            brand_name="华为",
            competitor_names=None,
        )
        assert has_brand is True
        assert brand_ctx is not None
        assert "华为" in brand_ctx

    def test_brand_name_detection_case_insensitive(self):
        adapter = KimiAdapter(api_key="test-key")
        # The text contains the brand ONLY in a different case ("Apple" vs
        # "apple"); otherwise a case-sensitive matcher would also pass.
        has_brand, _, _, _ = adapter._detect_citations(
            "Apple is a great company that makes phones",
            brand_name="apple",
            competitor_names=None,
        )
        assert has_brand is True

    def test_competitor_name_detection(self):
        adapter = KimiAdapter(api_key="test-key")
        has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
            "华为和小米都是中国知名手机品牌",
            brand_name="华为",
            competitor_names=["小米"],
        )
        assert has_brand is True
        assert has_comp is True
        assert brand_ctx is not None
        assert len(comp_ctx) > 0

    def test_no_citations_when_no_match(self):
        adapter = KimiAdapter(api_key="test-key")
        has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
            "今天天气很好",
            brand_name="华为",
            competitor_names=["小米"],
        )
        assert has_brand is False
        assert has_comp is False
        assert brand_ctx is None
        assert len(comp_ctx) == 0


# Minimal constructor kwargs per adapter, matching how each adapter is built in
# the tests above.
_ADAPTER_KWARGS = {
    KimiAdapter: {"api_key": "test-key"},
    WenxinAdapter: {"api_key": "test-key", "secret_key": "s"},
    DoubaoAdapter: {"api_key": "test-key", "endpoint_id": "ep-test"},
    YuanbaoAdapter: {"api_key": "test-key"},
}


class TestAdapterInheritance:
    """Every adapter implements the ``AIEngineAdapter`` interface."""

    def test_all_adapters_inherit_base(self):
        assert issubclass(KimiAdapter, AIEngineAdapter)
        assert issubclass(WenxinAdapter, AIEngineAdapter)
        assert issubclass(DoubaoAdapter, AIEngineAdapter)
        assert issubclass(YuanbaoAdapter, AIEngineAdapter)

    def test_all_adapters_have_query_method(self):
        for cls in _ADAPTER_KWARGS:
            assert hasattr(cls, "query")
            assert callable(getattr(cls, "query"))

    def test_all_adapters_have_detect_citations(self):
        for cls in _ADAPTER_KWARGS:
            assert hasattr(cls, "_detect_citations")

    def test_all_adapters_have_get_engine_type(self):
        for cls, kwargs in _ADAPTER_KWARGS.items():
            instance = cls(**kwargs)
            assert instance.get_engine_type() in EngineType