fischer-agentkit/tests/unit/test_llm_gateway.py

183 lines
6.2 KiB
Python

"""LLM Gateway 测试"""
import pytest
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
from agentkit.llm.config import LLMConfig, ProviderConfig
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage
class FakeProvider(LLMProvider):
    """Test double for an LLM provider.

    Args:
        name: Label embedded in the response content and in the raised error.
        should_fail: When True, every ``chat`` call raises ``LLMProviderError``.
    """

    def __init__(self, name: str = "fake", should_fail: bool = False):
        self._name = name
        self._should_fail = should_fail
        # Most recent request handed to chat(); recorded even when the call fails.
        self.last_request: LLMRequest | None = None

    async def chat(self, request: LLMRequest) -> LLMResponse:
        """Record ``request``, then either echo a canned response or raise.

        The response echoes the requested model and reports a fixed usage of
        10 prompt tokens and 20 completion tokens.

        Raises:
            LLMProviderError: when the provider was built with ``should_fail=True``.
        """
        self.last_request = request
        if not self._should_fail:
            return LLMResponse(
                content=f"response from {self._name}",
                model=request.model,
                usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
            )
        raise LLMProviderError(self._name, "API error")
class TestLLMGatewayRegister:
    """Tests for registering providers on the gateway."""

    def test_register_provider(self):
        gw = LLMGateway()
        gw.register_provider("openai", FakeProvider("openai"))
        # White-box check against the gateway's internal provider registry.
        assert "openai" in gw._providers

    def test_register_multiple_providers(self):
        gw = LLMGateway()
        for provider_name in ("openai", "deepseek"):
            gw.register_provider(provider_name, FakeProvider(provider_name))
        assert len(gw._providers) == 2
class TestLLMGatewayChat:
    """Tests for ``LLMGateway.chat()``."""

    async def test_chat_forwards_to_correct_provider(self):
        gw = LLMGateway()
        provider = FakeProvider("openai")
        gw.register_provider("openai", provider)

        result = await gw.chat(
            messages=[{"role": "user", "content": "Hello"}],
            model="openai/gpt-4o",
        )

        assert result.content == "response from openai"
        forwarded = provider.last_request
        assert forwarded is not None
        # The test expects the "openai/" prefix to pick the provider and only
        # the bare model name to be forwarded to it.
        assert forwarded.model == "gpt-4o"

    async def test_chat_records_usage(self):
        gw = LLMGateway()
        gw.register_provider("openai", FakeProvider("openai"))
        hello = [{"role": "user", "content": "Hello"}]

        await gw.chat(messages=hello, model="openai/gpt-4o", agent_name="test_agent")

        assert gw.get_usage().total_tokens > 0

    async def test_chat_no_provider_raises_error(self):
        # No provider is registered at all, so "nonexistent/model" cannot be served.
        gw = LLMGateway()
        hello = [{"role": "user", "content": "Hello"}]
        with pytest.raises(LLMProviderError):
            await gw.chat(messages=hello, model="nonexistent/model")
class TestLLMGatewayModelAlias:
    """Tests for resolving model aliases."""

    async def test_model_alias_resolves(self):
        config = LLMConfig(
            providers={
                "openai": ProviderConfig(api_key="test", base_url="https://api.openai.com/v1"),
            },
            model_aliases={"fast": "openai/gpt-4o-mini"},
        )
        gateway = LLMGateway(config=config)
        fake = FakeProvider("openai")
        gateway.register_provider("openai", fake)

        response = await gateway.chat(
            messages=[{"role": "user", "content": "Hello"}],
            model="fast",
        )

        assert response.content == "response from openai"
        # last_request is Optional: check it before dereferencing so a missing
        # call fails as an assertion, not as an AttributeError.
        assert fake.last_request is not None
        # "fast" -> "openai/gpt-4o-mini"; the provider should see only the bare name.
        assert fake.last_request.model == "gpt-4o-mini"

    async def test_nonexistent_model_alias_raises_error(self):
        # "nonexistent_alias" is neither a configured alias nor "provider/model"
        # shaped. The test expects LLMProviderError even though providers exist.
        config = LLMConfig(
            model_aliases={"fast": "openai/gpt-4o-mini"},
        )
        gateway = LLMGateway(config=config)
        gateway.register_provider("openai", FakeProvider("openai"))
        gateway.register_provider("deepseek", FakeProvider("deepseek"))
        with pytest.raises(LLMProviderError):
            await gateway.chat(
                messages=[{"role": "user", "content": "Hello"}],
                model="nonexistent_alias",
            )
class TestLLMGatewayFallback:
    """Tests for the fallback strategy."""

    async def test_fallback_on_primary_failure(self):
        config = LLMConfig(
            providers={
                "openai": ProviderConfig(api_key="test", base_url="https://api.openai.com/v1"),
                "deepseek": ProviderConfig(api_key="test", base_url="https://api.deepseek.com/v1"),
            },
            fallbacks={"openai/gpt-4o": ["deepseek/deepseek-chat"]},
        )
        gateway = LLMGateway(config=config)
        primary = FakeProvider("openai", should_fail=True)
        backup = FakeProvider("deepseek")
        gateway.register_provider("openai", primary)
        gateway.register_provider("deepseek", backup)

        response = await gateway.chat(
            messages=[{"role": "user", "content": "Hello"}],
            model="openai/gpt-4o",
        )

        assert response.content == "response from deepseek"
        # Pin the fallback path itself: the primary must have been tried (FakeProvider
        # records the request before raising), and the backup must have served it.
        assert primary.last_request is not None
        assert backup.last_request is not None

    async def test_no_fallback_raises_error(self):
        # With no fallbacks configured, the primary's failure should propagate.
        config = LLMConfig(
            providers={
                "openai": ProviderConfig(api_key="test", base_url="https://api.openai.com/v1"),
            },
        )
        gateway = LLMGateway(config=config)
        gateway.register_provider("openai", FakeProvider("openai", should_fail=True))
        with pytest.raises(LLMProviderError):
            await gateway.chat(
                messages=[{"role": "user", "content": "Hello"}],
                model="openai/gpt-4o",
            )
class TestLLMGatewayUsage:
    """Tests for querying recorded usage."""

    async def test_get_usage_by_agent_name(self):
        gateway = LLMGateway()
        gateway.register_provider("openai", FakeProvider("openai"))
        for agent in ("agent_a", "agent_b"):
            await gateway.chat(
                messages=[{"role": "user", "content": "Hello"}],
                model="openai/gpt-4o",
                agent_name=agent,
            )

        usage_a = gateway.get_usage(agent_name="agent_a")
        assert usage_a.total_tokens > 0
        # all() is True for an empty sequence, so require at least one record
        # to make the per-agent filter check below meaningful.
        assert len(usage_a.records) > 0
        assert all(r.agent_name == "agent_a" for r in usage_a.records)

    async def test_get_usage_empty(self):
        # A fresh gateway with no chat calls reports zero usage and no records.
        gateway = LLMGateway()
        usage = gateway.get_usage()
        assert usage.total_tokens == 0
        assert usage.total_cost == 0.0
        assert len(usage.records) == 0