# fischer-agentkit/tests/unit/test_llm_gateway.py
#
# 335 lines
# 12 KiB
# Python

"""Unit tests for the LLM Gateway."""
import pytest
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
from agentkit.llm.config import LLMConfig, ProviderConfig
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, StreamChunk, TokenUsage
class FakeProvider(LLMProvider):
    """In-memory stand-in for a real provider, used by the gateway tests.

    Args:
        name: Label for this provider; it appears in the canned response
            text and in the error raised when failing.
        should_fail: If true, every ``chat`` call raises
            ``LLMProviderError`` instead of answering.
    """

    def __init__(self, name: str = "fake", should_fail: bool = False):
        self._name = name
        self._should_fail = should_fail
        # Request passed to the most recent ``chat`` call, so tests can
        # inspect what the gateway forwarded (e.g. the stripped model name).
        self.last_request: LLMRequest | None = None

    async def chat(self, request: LLMRequest) -> LLMResponse:
        """Remember *request*, then raise or return ``"response from <name>"``.

        The response echoes ``request.model`` and reports a fixed usage of
        10 prompt tokens and 20 completion tokens.

        Raises:
            LLMProviderError: when built with ``should_fail=True``.
        """
        self.last_request = request
        if not self._should_fail:
            canned_usage = TokenUsage(prompt_tokens=10, completion_tokens=20)
            return LLMResponse(
                content=f"response from {self._name}",
                model=request.model,
                usage=canned_usage,
            )
        raise LLMProviderError(self._name, "API error")
class FakeStreamProvider(FakeProvider):
    """Fake provider with configurable streaming behaviour.

    The non-streaming ``chat`` method (and the ``last_request`` tracking)
    is inherited unchanged from ``FakeProvider``; this class only adds
    ``chat_stream``, which yields the chunks ``"Hello"``, ``" from "`` and
    the provider name.

    Args:
        name: Provider name; streamed as the last chunk and used in errors.
        should_fail: When true, both ``chat`` and ``chat_stream`` raise
            ``LLMProviderError("API error")`` before producing any output.
        fail_after_chunks: When positive, ``chat_stream`` raises
            ``LLMProviderError("Stream interrupted")`` after yielding this
            many chunks (values >= 3 never trigger, since only three chunks
            exist). ``0`` (the default) disables mid-stream failure.
    """

    def __init__(
        self,
        name: str = "fake",
        should_fail: bool = False,
        fail_after_chunks: int = 0,
    ):
        super().__init__(name, should_fail)
        self._fail_after_chunks = fail_after_chunks

    async def chat_stream(self, request: LLMRequest):
        """Record *request* and stream ``"Hello"``, ``" from "``, ``<name>``.

        Only the last chunk has ``is_final=True`` and carries usage
        (10 prompt + 20 completion tokens); earlier chunks have
        ``usage=None``. Every chunk echoes ``request.model``.

        Raises:
            LLMProviderError: "API error" before any chunk when
                ``should_fail`` is set, or "Stream interrupted" once
                ``fail_after_chunks`` chunks have been yielded.
        """
        self.last_request = request
        if self._should_fail:
            raise LLMProviderError(self._name, "API error")
        chunks = ["Hello", " from ", self._name]
        last_index = len(chunks) - 1
        for i, text in enumerate(chunks):
            # A falsy (0) threshold disables the simulated mid-stream failure.
            if self._fail_after_chunks and i >= self._fail_after_chunks:
                raise LLMProviderError(self._name, "Stream interrupted")
            is_final = i == last_index
            usage = TokenUsage(prompt_tokens=10, completion_tokens=20) if is_final else None
            yield StreamChunk(
                content=text,
                model=request.model,
                usage=usage,
                is_final=is_final,
            )
class TestLLMGatewayRegister:
    """Tests for provider registration."""

    def test_register_provider(self):
        gateway = LLMGateway()
        gateway.register_provider("openai", FakeProvider("openai"))
        assert "openai" in gateway._providers

    def test_register_multiple_providers(self):
        gateway = LLMGateway()
        for provider_name in ("openai", "deepseek"):
            gateway.register_provider(provider_name, FakeProvider(provider_name))
        assert len(gateway._providers) == 2
class TestLLMGatewayChat:
    """Tests for the ``chat()`` method."""

    @staticmethod
    def _user_hello():
        # Build a fresh message list on every call so tests never share state.
        return [{"role": "user", "content": "Hello"}]

    async def test_chat_forwards_to_correct_provider(self):
        gateway = LLMGateway()
        openai_fake = FakeProvider("openai")
        gateway.register_provider("openai", openai_fake)

        result = await gateway.chat(messages=self._user_hello(), model="openai/gpt-4o")

        assert result.content == "response from openai"
        forwarded = openai_fake.last_request
        assert forwarded is not None
        # The "openai/" provider prefix is stripped before forwarding.
        assert forwarded.model == "gpt-4o"

    async def test_chat_records_usage(self):
        gateway = LLMGateway()
        gateway.register_provider("openai", FakeProvider("openai"))

        await gateway.chat(
            messages=self._user_hello(),
            model="openai/gpt-4o",
            agent_name="test_agent",
        )

        assert gateway.get_usage().total_tokens > 0

    async def test_chat_no_provider_raises_error(self):
        gateway = LLMGateway()
        with pytest.raises(LLMProviderError):
            await gateway.chat(messages=self._user_hello(), model="nonexistent/model")
class TestLLMGatewayModelAlias:
    """Tests for model alias resolution."""

    async def test_model_alias_resolves(self):
        openai_cfg = ProviderConfig(api_key="test", base_url="https://api.openai.com/v1")
        gateway = LLMGateway(
            config=LLMConfig(
                providers={"openai": openai_cfg},
                model_aliases={"fast": "openai/gpt-4o-mini"},
            )
        )
        fake = FakeProvider("openai")
        gateway.register_provider("openai", fake)

        reply = await gateway.chat(
            messages=[{"role": "user", "content": "Hello"}],
            model="fast",
        )

        assert reply.content == "response from openai"
        # "fast" expands to "openai/gpt-4o-mini"; the provider prefix is stripped.
        assert fake.last_request.model == "gpt-4o-mini"

    async def test_nonexistent_model_alias_raises_error(self):
        gateway = LLMGateway(config=LLMConfig(model_aliases={"fast": "openai/gpt-4o-mini"}))
        for provider_name in ("openai", "deepseek"):
            gateway.register_provider(provider_name, FakeProvider(provider_name))

        with pytest.raises(LLMProviderError):
            await gateway.chat(
                messages=[{"role": "user", "content": "Hello"}],
                model="nonexistent_alias",
            )
class TestLLMGatewayFallback:
    """Tests for the ``chat()`` fallback strategy."""

    async def test_fallback_on_primary_failure(self):
        config = LLMConfig(
            providers={
                "openai": ProviderConfig(api_key="test", base_url="https://api.openai.com/v1"),
                "deepseek": ProviderConfig(api_key="test", base_url="https://api.deepseek.com/v1"),
            },
            fallbacks={"openai/gpt-4o": ["deepseek/deepseek-chat"]},
        )
        gateway = LLMGateway(config=config)
        primary = FakeProvider("openai", should_fail=True)
        fallback = FakeProvider("deepseek")
        gateway.register_provider("openai", primary)
        gateway.register_provider("deepseek", fallback)

        response = await gateway.chat(
            messages=[{"role": "user", "content": "Hello"}],
            model="openai/gpt-4o",
        )

        assert response.content == "response from deepseek"
        # The primary must really have been attempted before falling back ...
        assert primary.last_request is not None
        # ... and the fallback receives its own model with the prefix stripped.
        assert fallback.last_request is not None
        assert fallback.last_request.model == "deepseek-chat"

    async def test_no_fallback_raises_error(self):
        config = LLMConfig(
            providers={
                "openai": ProviderConfig(api_key="test", base_url="https://api.openai.com/v1"),
            },
        )
        gateway = LLMGateway(config=config)
        primary = FakeProvider("openai", should_fail=True)
        gateway.register_provider("openai", primary)

        with pytest.raises(LLMProviderError):
            await gateway.chat(
                messages=[{"role": "user", "content": "Hello"}],
                model="openai/gpt-4o",
            )
        # The error must come from the provider call itself, not from routing.
        assert primary.last_request is not None
class TestLLMGatewayUsage:
    """Tests for usage queries."""

    async def test_get_usage_by_agent_name(self):
        gateway = LLMGateway()
        gateway.register_provider("openai", FakeProvider("openai"))
        for agent in ("agent_a", "agent_b"):
            await gateway.chat(
                messages=[{"role": "user", "content": "Hello"}],
                model="openai/gpt-4o",
                agent_name=agent,
            )

        usage_a = gateway.get_usage(agent_name="agent_a")

        assert usage_a.total_tokens > 0
        assert all(record.agent_name == "agent_a" for record in usage_a.records)

    async def test_get_usage_empty(self):
        usage = LLMGateway().get_usage()
        assert usage.total_tokens == 0
        assert usage.total_cost == 0.0
        assert len(usage.records) == 0
class TestLLMGatewayStreamFallback:
    """Tests for the ``chat_stream()`` fallback strategy."""

    @staticmethod
    def _openai_with_deepseek_fallback() -> LLMConfig:
        """Config with openai + deepseek providers; openai/gpt-4o falls back to deepseek."""
        return LLMConfig(
            providers={
                "openai": ProviderConfig(api_key="test", base_url="https://api.openai.com/v1"),
                "deepseek": ProviderConfig(api_key="test", base_url="https://api.deepseek.com/v1"),
            },
            fallbacks={"openai/gpt-4o": ["deepseek/deepseek-chat"]},
        )

    @staticmethod
    async def _collect(gateway: LLMGateway, **kwargs) -> list:
        """Drain ``gateway.chat_stream`` for a single "Hello" user message into a list."""
        return [
            chunk
            async for chunk in gateway.chat_stream(
                messages=[{"role": "user", "content": "Hello"}],
                **kwargs,
            )
        ]

    async def test_stream_fallback_on_primary_failure(self):
        """Primary fails before any chunk, fallback succeeds."""
        gateway = LLMGateway(config=self._openai_with_deepseek_fallback())
        primary = FakeStreamProvider("openai", should_fail=True)
        fallback = FakeStreamProvider("deepseek")
        gateway.register_provider("openai", primary)
        gateway.register_provider("deepseek", fallback)

        chunks = await self._collect(gateway, model="openai/gpt-4o")

        content = "".join(c.content for c in chunks)
        assert "deepseek" in content
        assert any(c.is_final for c in chunks)
        # Primary was tried first; the fallback got its own unprefixed model.
        assert primary.last_request is not None
        assert fallback.last_request is not None
        assert fallback.last_request.model == "deepseek-chat"

    async def test_stream_fails_after_chunks_graceful_termination(self):
        """Primary fails after chunks sent — yields error chunk and stops."""
        gateway = LLMGateway(config=self._openai_with_deepseek_fallback())
        gateway.register_provider(
            "openai", FakeStreamProvider("openai", fail_after_chunks=1)
        )
        fallback = FakeStreamProvider("deepseek")
        gateway.register_provider("deepseek", fallback)

        chunks = await self._collect(gateway, model="openai/gpt-4o")

        # Should have: 1 real chunk + 1 error termination chunk
        assert len(chunks) == 2
        assert chunks[0].content == "Hello"
        # Error termination chunk
        assert chunks[1].content == ""
        assert chunks[1].is_final is True
        # Output had already been sent, so the gateway must not switch
        # to the fallback model mid-stream.
        assert fallback.last_request is None

    async def test_stream_all_models_fail(self):
        """All models fail — raises exception."""
        gateway = LLMGateway(config=self._openai_with_deepseek_fallback())
        primary = FakeStreamProvider("openai", should_fail=True)
        fallback = FakeStreamProvider("deepseek", should_fail=True)
        gateway.register_provider("openai", primary)
        gateway.register_provider("deepseek", fallback)

        with pytest.raises(LLMProviderError):
            await self._collect(gateway, model="openai/gpt-4o")
        # Both the primary and the fallback must have been attempted.
        assert primary.last_request is not None
        assert fallback.last_request is not None

    async def test_stream_single_model_no_fallback(self):
        """Single model with no fallback works normally."""
        gateway = LLMGateway()
        gateway.register_provider("openai", FakeStreamProvider("openai"))

        chunks = await self._collect(gateway, model="openai/gpt-4o")

        content = "".join(c.content for c in chunks)
        assert "openai" in content
        assert any(c.is_final for c in chunks)

    async def test_stream_records_usage(self):
        """Usage is tracked after successful stream."""
        gateway = LLMGateway()
        gateway.register_provider("openai", FakeStreamProvider("openai"))

        await self._collect(gateway, model="openai/gpt-4o", agent_name="stream_agent")

        usage = gateway.get_usage()
        assert usage.total_tokens > 0
        # Streamed usage must also be attributed to the calling agent.
        assert gateway.get_usage(agent_name="stream_agent").total_tokens > 0