"""LLM Gateway 测试""" import pytest from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError from agentkit.llm.config import LLMConfig, ProviderConfig from agentkit.llm.gateway import LLMGateway from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, StreamChunk, TokenUsage class FakeProvider(LLMProvider): """用于测试的 Fake Provider""" def __init__(self, name: str = "fake", should_fail: bool = False): self._name = name self._should_fail = should_fail self.last_request: LLMRequest | None = None async def chat(self, request: LLMRequest) -> LLMResponse: self.last_request = request if self._should_fail: raise LLMProviderError(self._name, "API error") usage = TokenUsage(prompt_tokens=10, completion_tokens=20) return LLMResponse( content=f"response from {self._name}", model=request.model, usage=usage, ) class FakeStreamProvider(LLMProvider): """Fake Provider with configurable streaming behavior.""" def __init__( self, name: str = "fake", should_fail: bool = False, fail_after_chunks: int = 0, ): self._name = name self._should_fail = should_fail self._fail_after_chunks = fail_after_chunks self.last_request: LLMRequest | None = None async def chat(self, request: LLMRequest) -> LLMResponse: self.last_request = request if self._should_fail: raise LLMProviderError(self._name, "API error") usage = TokenUsage(prompt_tokens=10, completion_tokens=20) return LLMResponse( content=f"response from {self._name}", model=request.model, usage=usage, ) async def chat_stream(self, request: LLMRequest): self.last_request = request if self._should_fail: raise LLMProviderError(self._name, "API error") chunks = ["Hello", " from ", self._name] for i, text in enumerate(chunks): if self._fail_after_chunks and i >= self._fail_after_chunks: raise LLMProviderError(self._name, "Stream interrupted") is_final = i == len(chunks) - 1 usage = TokenUsage(prompt_tokens=10, completion_tokens=20) if is_final else None yield StreamChunk( content=text, model=request.model, usage=usage, is_final=is_final, ) class TestLLMGatewayRegister: """Provider 注册测试""" def test_register_provider(self): gateway = LLMGateway() provider = FakeProvider("openai") gateway.register_provider("openai", provider) assert "openai" in gateway._providers def test_register_multiple_providers(self): gateway = LLMGateway() gateway.register_provider("openai", FakeProvider("openai")) gateway.register_provider("deepseek", FakeProvider("deepseek")) assert len(gateway._providers) == 2 class TestLLMGatewayChat: """chat() 方法测试""" async def test_chat_forwards_to_correct_provider(self): gateway = LLMGateway() fake = FakeProvider("openai") gateway.register_provider("openai", fake) response = await gateway.chat( messages=[{"role": "user", "content": "Hello"}], model="openai/gpt-4o", ) assert response.content == "response from openai" assert fake.last_request is not None assert fake.last_request.model == "gpt-4o" async def test_chat_records_usage(self): gateway = LLMGateway() gateway.register_provider("openai", FakeProvider("openai")) await gateway.chat( messages=[{"role": "user", "content": "Hello"}], model="openai/gpt-4o", agent_name="test_agent", ) usage = gateway.get_usage() assert usage.total_tokens > 0 async def test_chat_no_provider_raises_error(self): gateway = LLMGateway() with pytest.raises(LLMProviderError): await gateway.chat( messages=[{"role": "user", "content": "Hello"}], model="nonexistent/model", ) class TestLLMGatewayModelAlias: """模型别名解析测试""" async def test_model_alias_resolves(self): config = LLMConfig( providers={"openai": ProviderConfig(api_key="test", base_url="https://api.openai.com/v1")}, model_aliases={"fast": "openai/gpt-4o-mini"}, ) gateway = LLMGateway(config=config) fake = FakeProvider("openai") gateway.register_provider("openai", fake) response = await gateway.chat( messages=[{"role": "user", "content": "Hello"}], model="fast", ) assert response.content == "response from openai" assert fake.last_request.model == "gpt-4o-mini" async def test_nonexistent_model_alias_raises_error(self): config = LLMConfig( model_aliases={"fast": "openai/gpt-4o-mini"}, ) gateway = LLMGateway(config=config) gateway.register_provider("openai", FakeProvider("openai")) gateway.register_provider("deepseek", FakeProvider("deepseek")) with pytest.raises(LLMProviderError): await gateway.chat( messages=[{"role": "user", "content": "Hello"}], model="nonexistent_alias", ) class TestLLMGatewayFallback: """Fallback 策略测试""" async def test_fallback_on_primary_failure(self): config = LLMConfig( providers={ "openai": ProviderConfig(api_key="test", base_url="https://api.openai.com/v1"), "deepseek": ProviderConfig(api_key="test", base_url="https://api.deepseek.com/v1"), }, fallbacks={"openai/gpt-4o": ["deepseek/deepseek-chat"]}, ) gateway = LLMGateway(config=config) gateway.register_provider("openai", FakeProvider("openai", should_fail=True)) gateway.register_provider("deepseek", FakeProvider("deepseek")) response = await gateway.chat( messages=[{"role": "user", "content": "Hello"}], model="openai/gpt-4o", ) assert response.content == "response from deepseek" async def test_no_fallback_raises_error(self): config = LLMConfig( providers={ "openai": ProviderConfig(api_key="test", base_url="https://api.openai.com/v1"), }, ) gateway = LLMGateway(config=config) gateway.register_provider("openai", FakeProvider("openai", should_fail=True)) with pytest.raises(LLMProviderError): await gateway.chat( messages=[{"role": "user", "content": "Hello"}], model="openai/gpt-4o", ) class TestLLMGatewayUsage: """Usage 查询测试""" async def test_get_usage_by_agent_name(self): gateway = LLMGateway() gateway.register_provider("openai", FakeProvider("openai")) await gateway.chat( messages=[{"role": "user", "content": "Hello"}], model="openai/gpt-4o", agent_name="agent_a", ) await gateway.chat( messages=[{"role": "user", "content": "Hello"}], model="openai/gpt-4o", agent_name="agent_b", ) usage_a = gateway.get_usage(agent_name="agent_a") assert usage_a.total_tokens > 0 assert all(r.agent_name == "agent_a" for r in usage_a.records) async def test_get_usage_empty(self): gateway = LLMGateway() usage = gateway.get_usage() assert usage.total_tokens == 0 assert usage.total_cost == 0.0 assert len(usage.records) == 0 class TestLLMGatewayStreamFallback: """chat_stream() fallback 策略测试""" async def test_stream_fallback_on_primary_failure(self): """Primary fails before any chunk, fallback succeeds.""" config = LLMConfig( providers={ "openai": ProviderConfig(api_key="test", base_url="https://api.openai.com/v1"), "deepseek": ProviderConfig(api_key="test", base_url="https://api.deepseek.com/v1"), }, fallbacks={"openai/gpt-4o": ["deepseek/deepseek-chat"]}, ) gateway = LLMGateway(config=config) gateway.register_provider("openai", FakeStreamProvider("openai", should_fail=True)) gateway.register_provider("deepseek", FakeStreamProvider("deepseek")) chunks = [] async for chunk in gateway.chat_stream( messages=[{"role": "user", "content": "Hello"}], model="openai/gpt-4o", ): chunks.append(chunk) content = "".join(c.content for c in chunks) assert "deepseek" in content assert any(c.is_final for c in chunks) async def test_stream_fails_after_chunks_graceful_termination(self): """Primary fails after chunks sent — yields error chunk and stops.""" config = LLMConfig( providers={ "openai": ProviderConfig(api_key="test", base_url="https://api.openai.com/v1"), "deepseek": ProviderConfig(api_key="test", base_url="https://api.deepseek.com/v1"), }, fallbacks={"openai/gpt-4o": ["deepseek/deepseek-chat"]}, ) gateway = LLMGateway(config=config) gateway.register_provider( "openai", FakeStreamProvider("openai", fail_after_chunks=1) ) gateway.register_provider("deepseek", FakeStreamProvider("deepseek")) chunks = [] async for chunk in gateway.chat_stream( messages=[{"role": "user", "content": "Hello"}], model="openai/gpt-4o", ): chunks.append(chunk) # Should have: 1 real chunk + 1 error termination chunk assert len(chunks) == 2 assert chunks[0].content == "Hello" # Error termination chunk assert chunks[1].content == "" assert chunks[1].is_final is True async def test_stream_all_models_fail(self): """All models fail — raises exception.""" config = LLMConfig( providers={ "openai": ProviderConfig(api_key="test", base_url="https://api.openai.com/v1"), "deepseek": ProviderConfig(api_key="test", base_url="https://api.deepseek.com/v1"), }, fallbacks={"openai/gpt-4o": ["deepseek/deepseek-chat"]}, ) gateway = LLMGateway(config=config) gateway.register_provider("openai", FakeStreamProvider("openai", should_fail=True)) gateway.register_provider("deepseek", FakeStreamProvider("deepseek", should_fail=True)) with pytest.raises(LLMProviderError): async for _ in gateway.chat_stream( messages=[{"role": "user", "content": "Hello"}], model="openai/gpt-4o", ): pass async def test_stream_single_model_no_fallback(self): """Single model with no fallback works normally.""" gateway = LLMGateway() gateway.register_provider("openai", FakeStreamProvider("openai")) chunks = [] async for chunk in gateway.chat_stream( messages=[{"role": "user", "content": "Hello"}], model="openai/gpt-4o", ): chunks.append(chunk) content = "".join(c.content for c in chunks) assert "openai" in content assert any(c.is_final for c in chunks) async def test_stream_records_usage(self): """Usage is tracked after successful stream.""" gateway = LLMGateway() gateway.register_provider("openai", FakeStreamProvider("openai")) async for _ in gateway.chat_stream( messages=[{"role": "user", "content": "Hello"}], model="openai/gpt-4o", agent_name="stream_agent", ): pass usage = gateway.get_usage() assert usage.total_tokens > 0