183 lines
6.2 KiB
Python
183 lines
6.2 KiB
Python
"""LLM Gateway 测试"""
|
|
|
|
import pytest
|
|
|
|
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
|
|
from agentkit.llm.config import LLMConfig, ProviderConfig
|
|
from agentkit.llm.gateway import LLMGateway
|
|
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage
|
|
|
|
|
|
class FakeProvider(LLMProvider):
    """In-memory stand-in for a real LLM provider, used only by these tests.

    It remembers the most recent request passed to ``chat`` and either returns
    a canned response or raises ``LLMProviderError`` when built with
    ``should_fail=True``.
    """

    def __init__(self, name: str = "fake", should_fail: bool = False):
        # Nothing has been received yet; ``chat`` overwrites this on every call.
        self.last_request: LLMRequest | None = None
        self._should_fail = should_fail
        self._name = name

    async def chat(self, request: LLMRequest) -> LLMResponse:
        """Record ``request``, then return a fixed reply or raise if failing."""
        self.last_request = request
        if self._should_fail:
            raise LLMProviderError(self._name, "API error")
        # The reply echoes the requested model and carries fixed token counts.
        return LLMResponse(
            content=f"response from {self._name}",
            model=request.model,
            usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
        )
|
|
|
|
|
|
class TestLLMGatewayRegister:
    """Provider registration tests."""

    def test_register_provider(self):
        """A registered provider shows up in the gateway's provider mapping."""
        gw = LLMGateway()
        gw.register_provider("openai", FakeProvider("openai"))
        assert "openai" in gw._providers

    def test_register_multiple_providers(self):
        """Registering two differently named providers keeps both entries."""
        gw = LLMGateway()
        for provider_name in ("openai", "deepseek"):
            gw.register_provider(provider_name, FakeProvider(provider_name))
        assert len(gw._providers) == 2
|
|
|
|
|
|
class TestLLMGatewayChat:
    """Tests for the gateway's ``chat()`` method."""

    async def test_chat_forwards_to_correct_provider(self):
        """A ``provider/model`` id is expected to reach that provider without the prefix."""
        provider = FakeProvider("openai")
        gw = LLMGateway()
        gw.register_provider("openai", provider)

        reply = await gw.chat(
            messages=[{"role": "user", "content": "Hello"}],
            model="openai/gpt-4o",
        )

        assert reply.content == "response from openai"
        # The fake must have been called before its recorded request is inspected.
        assert provider.last_request is not None
        assert provider.last_request.model == "gpt-4o"

    async def test_chat_records_usage(self):
        """After one chat call the aggregated usage is expected to be non-zero."""
        gw = LLMGateway()
        gw.register_provider("openai", FakeProvider("openai"))

        await gw.chat(
            messages=[{"role": "user", "content": "Hello"}],
            model="openai/gpt-4o",
            agent_name="test_agent",
        )

        totals = gw.get_usage()
        assert totals.total_tokens > 0

    async def test_chat_no_provider_raises_error(self):
        """With no provider registered, chat is expected to raise ``LLMProviderError``."""
        gw = LLMGateway()
        with pytest.raises(LLMProviderError):
            await gw.chat(
                messages=[{"role": "user", "content": "Hello"}],
                model="nonexistent/model",
            )
|
|
|
|
|
|
class TestLLMGatewayModelAlias:
    """Model alias resolution tests."""

    async def test_model_alias_resolves(self):
        """The alias ``fast`` is expected to resolve to ``openai/gpt-4o-mini``."""
        config = LLMConfig(
            providers={"openai": ProviderConfig(api_key="test", base_url="https://api.openai.com/v1")},
            model_aliases={"fast": "openai/gpt-4o-mini"},
        )
        gateway = LLMGateway(config=config)
        fake = FakeProvider("openai")
        gateway.register_provider("openai", fake)

        response = await gateway.chat(
            messages=[{"role": "user", "content": "Hello"}],
            model="fast",
        )

        assert response.content == "response from openai"
        # Check the provider was actually called before dereferencing the
        # recorded request, so a missing call fails as a clear assertion
        # instead of an AttributeError on None.
        assert fake.last_request is not None
        assert fake.last_request.model == "gpt-4o-mini"

    async def test_nonexistent_model_alias_raises_error(self):
        """An unknown alias is expected to raise ``LLMProviderError``.

        Two providers are registered so the failure cannot be explained by
        the gateway simply having no providers at all.
        """
        config = LLMConfig(
            model_aliases={"fast": "openai/gpt-4o-mini"},
        )
        gateway = LLMGateway(config=config)
        gateway.register_provider("openai", FakeProvider("openai"))
        gateway.register_provider("deepseek", FakeProvider("deepseek"))

        with pytest.raises(LLMProviderError):
            await gateway.chat(
                messages=[{"role": "user", "content": "Hello"}],
                model="nonexistent_alias",
            )
|
|
|
|
|
|
class TestLLMGatewayFallback:
    """Fallback strategy tests."""

    async def test_fallback_on_primary_failure(self):
        """When the primary provider fails, the configured fallback is expected to answer."""
        provider_configs = {
            "openai": ProviderConfig(api_key="test", base_url="https://api.openai.com/v1"),
            "deepseek": ProviderConfig(api_key="test", base_url="https://api.deepseek.com/v1"),
        }
        gw = LLMGateway(
            config=LLMConfig(
                providers=provider_configs,
                fallbacks={"openai/gpt-4o": ["deepseek/deepseek-chat"]},
            )
        )
        # The primary always raises; only the fallback provider can succeed.
        gw.register_provider("openai", FakeProvider("openai", should_fail=True))
        gw.register_provider("deepseek", FakeProvider("deepseek"))

        reply = await gw.chat(
            messages=[{"role": "user", "content": "Hello"}],
            model="openai/gpt-4o",
        )

        assert reply.content == "response from deepseek"

    async def test_no_fallback_raises_error(self):
        """With no fallback configured, the provider failure is expected to propagate."""
        gw = LLMGateway(
            config=LLMConfig(
                providers={
                    "openai": ProviderConfig(api_key="test", base_url="https://api.openai.com/v1"),
                },
            )
        )
        gw.register_provider("openai", FakeProvider("openai", should_fail=True))

        with pytest.raises(LLMProviderError):
            await gw.chat(
                messages=[{"role": "user", "content": "Hello"}],
                model="openai/gpt-4o",
            )
|
|
|
|
|
|
class TestLLMGatewayUsage:
    """Usage query tests."""

    async def test_get_usage_by_agent_name(self):
        """Usage filtered by agent name is expected to contain only that agent's records."""
        gateway = LLMGateway()
        gateway.register_provider("openai", FakeProvider("openai"))

        # One call per agent, so the filter has something to exclude.
        await gateway.chat(
            messages=[{"role": "user", "content": "Hello"}],
            model="openai/gpt-4o",
            agent_name="agent_a",
        )
        await gateway.chat(
            messages=[{"role": "user", "content": "Hello"}],
            model="openai/gpt-4o",
            agent_name="agent_b",
        )

        usage_a = gateway.get_usage(agent_name="agent_a")
        assert usage_a.total_tokens > 0
        # all() is True for an empty iterable, so require at least one record
        # first; otherwise the per-record check below could pass vacuously.
        assert len(usage_a.records) > 0
        assert all(r.agent_name == "agent_a" for r in usage_a.records)

    async def test_get_usage_empty(self):
        """A gateway that has handled no calls is expected to report zero usage."""
        gateway = LLMGateway()
        usage = gateway.get_usage()
        assert usage.total_tokens == 0
        assert usage.total_cost == 0.0
        assert len(usage.records) == 0
|