335 lines
12 KiB
Python
335 lines
12 KiB
Python
"""LLM Gateway tests: provider registration, chat routing, model aliases, fallback, usage tracking, and streaming."""
|
|
|
|
import pytest
|
|
|
|
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
|
|
from agentkit.llm.config import LLMConfig, ProviderConfig
|
|
from agentkit.llm.gateway import LLMGateway
|
|
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, StreamChunk, TokenUsage
|
|
|
|
|
|
class FakeProvider(LLMProvider):
    """In-memory provider used by the tests.

    Remembers the most recent request it was given and either returns a canned
    response that names the provider, or raises ``LLMProviderError`` when
    constructed with ``should_fail=True``.
    """

    def __init__(self, name: str = "fake", should_fail: bool = False):
        self._name = name
        self._should_fail = should_fail
        # Most recent request passed to chat(); stays None until the first call.
        self.last_request: LLMRequest | None = None

    async def chat(self, request: LLMRequest) -> LLMResponse:
        """Record *request*, then either fail or answer with fixed token usage."""
        self.last_request = request
        if not self._should_fail:
            return LLMResponse(
                content=f"response from {self._name}",
                model=request.model,
                usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
            )
        raise LLMProviderError(self._name, "API error")
|
|
|
|
|
|
class FakeStreamProvider(FakeProvider):
    """Fake provider with configurable streaming behavior.

    ``chat()`` is inherited unchanged from ``FakeProvider``; only streaming is
    added here. ``chat_stream()`` yields three chunks -- ``"Hello"``,
    ``" from "`` and the provider name -- and only the last one carries token
    usage and ``is_final=True``.

    Args:
        name: Provider name, echoed back in the generated content.
        should_fail: Raise ``LLMProviderError`` before any chunk is produced.
        fail_after_chunks: When non-zero, raise ``LLMProviderError`` once this
            many chunks have been yielded. ``0`` disables mid-stream failure.
    """

    def __init__(
        self,
        name: str = "fake",
        should_fail: bool = False,
        fail_after_chunks: int = 0,
    ):
        # Reuse FakeProvider's state (name, failure flag, last_request) instead
        # of duplicating it.
        super().__init__(name, should_fail)
        self._fail_after_chunks = fail_after_chunks

    async def chat_stream(self, request: LLMRequest):
        """Record *request* and yield the canned chunks, failing as configured."""
        self.last_request = request
        if self._should_fail:
            raise LLMProviderError(self._name, "API error")

        chunks = ["Hello", " from ", self._name]
        for i, text in enumerate(chunks):
            # 0 is the "never fail mid-stream" sentinel, hence the truthiness test.
            if self._fail_after_chunks and i >= self._fail_after_chunks:
                raise LLMProviderError(self._name, "Stream interrupted")
            is_final = i == len(chunks) - 1
            # Usage is reported only once, on the final chunk.
            usage = TokenUsage(prompt_tokens=10, completion_tokens=20) if is_final else None
            yield StreamChunk(
                content=text,
                model=request.model,
                usage=usage,
                is_final=is_final,
            )
|
|
|
|
|
|
class TestLLMGatewayRegister:
    """Tests for registering providers on the gateway."""

    def test_register_provider(self):
        """A registered provider is stored under the given name."""
        gw = LLMGateway()
        gw.register_provider("openai", FakeProvider("openai"))
        assert "openai" in gw._providers

    def test_register_multiple_providers(self):
        """Each distinct name adds a separate provider entry."""
        gw = LLMGateway()
        for provider_name in ("openai", "deepseek"):
            gw.register_provider(provider_name, FakeProvider(provider_name))
        assert len(gw._providers) == 2
|
|
|
|
|
|
class TestLLMGatewayChat:
    """Tests for ``LLMGateway.chat()``."""

    async def test_chat_forwards_to_correct_provider(self):
        """A "provider/model" string routes to that provider with the prefix stripped."""
        provider = FakeProvider("openai")
        gw = LLMGateway()
        gw.register_provider("openai", provider)

        result = await gw.chat(
            messages=[{"role": "user", "content": "Hello"}],
            model="openai/gpt-4o",
        )

        assert result.content == "response from openai"
        assert provider.last_request is not None
        assert provider.last_request.model == "gpt-4o"

    async def test_chat_records_usage(self):
        """A successful call adds tokens to the gateway's usage totals."""
        gw = LLMGateway()
        gw.register_provider("openai", FakeProvider("openai"))

        await gw.chat(
            messages=[{"role": "user", "content": "Hello"}],
            model="openai/gpt-4o",
            agent_name="test_agent",
        )

        assert gw.get_usage().total_tokens > 0

    async def test_chat_no_provider_raises_error(self):
        """Routing to a provider that was never registered raises LLMProviderError."""
        gw = LLMGateway()
        with pytest.raises(LLMProviderError):
            await gw.chat(
                messages=[{"role": "user", "content": "Hello"}],
                model="nonexistent/model",
            )
|
|
|
|
|
|
class TestLLMGatewayModelAlias:
    """Tests for resolving model aliases configured in ``LLMConfig``."""

    async def test_model_alias_resolves(self):
        """An alias expands to its "provider/model" target before routing."""
        config = LLMConfig(
            providers={"openai": ProviderConfig(api_key="test", base_url="https://api.openai.com/v1")},
            model_aliases={"fast": "openai/gpt-4o-mini"},
        )
        gateway = LLMGateway(config=config)
        fake = FakeProvider("openai")
        gateway.register_provider("openai", fake)

        response = await gateway.chat(
            messages=[{"role": "user", "content": "Hello"}],
            model="fast",
        )

        assert response.content == "response from openai"
        # Check for None first (as TestLLMGatewayChat does) so a routing failure
        # reports as a clear assertion rather than an AttributeError on None.
        assert fake.last_request is not None
        assert fake.last_request.model == "gpt-4o-mini"

    async def test_nonexistent_model_alias_raises_error(self):
        """An unknown alias with no provider prefix raises LLMProviderError."""
        config = LLMConfig(
            model_aliases={"fast": "openai/gpt-4o-mini"},
        )
        gateway = LLMGateway(config=config)
        gateway.register_provider("openai", FakeProvider("openai"))
        gateway.register_provider("deepseek", FakeProvider("deepseek"))

        with pytest.raises(LLMProviderError):
            await gateway.chat(
                messages=[{"role": "user", "content": "Hello"}],
                model="nonexistent_alias",
            )
|
|
|
|
|
|
class TestLLMGatewayFallback:
    """Tests for the ``chat()`` fallback strategy."""

    async def test_fallback_on_primary_failure(self):
        """When the primary provider fails, the configured fallback answers."""
        config = LLMConfig(
            providers={
                "openai": ProviderConfig(api_key="test", base_url="https://api.openai.com/v1"),
                "deepseek": ProviderConfig(api_key="test", base_url="https://api.deepseek.com/v1"),
            },
            fallbacks={"openai/gpt-4o": ["deepseek/deepseek-chat"]},
        )
        gateway = LLMGateway(config=config)
        primary = FakeProvider("openai", should_fail=True)
        backup = FakeProvider("deepseek")
        gateway.register_provider("openai", primary)
        gateway.register_provider("deepseek", backup)

        response = await gateway.chat(
            messages=[{"role": "user", "content": "Hello"}],
            model="openai/gpt-4o",
        )

        assert response.content == "response from deepseek"
        # The primary must actually have been tried before falling back...
        assert primary.last_request is not None
        # ...and the fallback must receive its own model with the provider
        # prefix stripped, just as the primary route does.
        assert backup.last_request is not None
        assert backup.last_request.model == "deepseek-chat"

    async def test_no_fallback_raises_error(self):
        """Without a configured fallback, the primary's failure propagates."""
        config = LLMConfig(
            providers={
                "openai": ProviderConfig(api_key="test", base_url="https://api.openai.com/v1"),
            },
        )
        gateway = LLMGateway(config=config)
        gateway.register_provider("openai", FakeProvider("openai", should_fail=True))

        with pytest.raises(LLMProviderError):
            await gateway.chat(
                messages=[{"role": "user", "content": "Hello"}],
                model="openai/gpt-4o",
            )
|
|
|
|
|
|
class TestLLMGatewayUsage:
    """Tests for querying recorded usage."""

    async def test_get_usage_by_agent_name(self):
        """Filtering by agent_name returns only that agent's records and tokens."""
        gateway = LLMGateway()
        gateway.register_provider("openai", FakeProvider("openai"))

        for agent in ("agent_a", "agent_b"):
            await gateway.chat(
                messages=[{"role": "user", "content": "Hello"}],
                model="openai/gpt-4o",
                agent_name=agent,
            )

        usage_a = gateway.get_usage(agent_name="agent_a")
        assert usage_a.total_tokens > 0
        # all() is vacuously True on an empty sequence, so require records first.
        assert len(usage_a.records) > 0
        assert all(r.agent_name == "agent_a" for r in usage_a.records)
        # agent_b's tokens must not leak into agent_a's filtered total.
        assert usage_a.total_tokens < gateway.get_usage().total_tokens

    async def test_get_usage_empty(self):
        """A fresh gateway reports zero tokens, zero cost, and no records."""
        gateway = LLMGateway()
        usage = gateway.get_usage()
        assert usage.total_tokens == 0
        assert usage.total_cost == 0.0
        assert len(usage.records) == 0
|
|
|
|
|
|
class TestLLMGatewayStreamFallback:
    """Tests for the ``chat_stream()`` fallback strategy."""

    @staticmethod
    def _fallback_config() -> LLMConfig:
        """Build a config with openai + deepseek and a gpt-4o -> deepseek-chat fallback."""
        return LLMConfig(
            providers={
                "openai": ProviderConfig(api_key="test", base_url="https://api.openai.com/v1"),
                "deepseek": ProviderConfig(api_key="test", base_url="https://api.deepseek.com/v1"),
            },
            fallbacks={"openai/gpt-4o": ["deepseek/deepseek-chat"]},
        )

    @staticmethod
    async def _collect(gateway: LLMGateway, **kwargs) -> list:
        """Drain ``chat_stream`` for a single "Hello" message to openai/gpt-4o."""
        return [
            chunk
            async for chunk in gateway.chat_stream(
                messages=[{"role": "user", "content": "Hello"}],
                model="openai/gpt-4o",
                **kwargs,
            )
        ]

    async def test_stream_fallback_on_primary_failure(self):
        """Primary fails before any chunk, fallback succeeds."""
        gateway = LLMGateway(config=self._fallback_config())
        gateway.register_provider("openai", FakeStreamProvider("openai", should_fail=True))
        gateway.register_provider("deepseek", FakeStreamProvider("deepseek"))

        chunks = await self._collect(gateway)

        assert "".join(c.content for c in chunks) == "Hello from deepseek"
        # Exactly one terminal chunk, at the end. A premature is_final chunk
        # (e.g. an error terminator emitted before falling back) would make
        # consumers stop early, and any() alone would not catch it.
        assert chunks and chunks[-1].is_final
        assert not any(c.is_final for c in chunks[:-1])

    async def test_stream_fails_after_chunks_graceful_termination(self):
        """Primary fails after chunks sent — yields error chunk and stops."""
        gateway = LLMGateway(config=self._fallback_config())
        gateway.register_provider("openai", FakeStreamProvider("openai", fail_after_chunks=1))
        gateway.register_provider("deepseek", FakeStreamProvider("deepseek"))

        chunks = await self._collect(gateway)

        # Should have: 1 real chunk + 1 error termination chunk
        assert len(chunks) == 2
        assert chunks[0].content == "Hello"
        # Error termination chunk
        assert chunks[1].content == ""
        assert chunks[1].is_final is True

    async def test_stream_all_models_fail(self):
        """All models fail — raises exception."""
        gateway = LLMGateway(config=self._fallback_config())
        gateway.register_provider("openai", FakeStreamProvider("openai", should_fail=True))
        gateway.register_provider("deepseek", FakeStreamProvider("deepseek", should_fail=True))

        with pytest.raises(LLMProviderError):
            await self._collect(gateway)

    async def test_stream_single_model_no_fallback(self):
        """Single model with no fallback works normally."""
        gateway = LLMGateway()
        gateway.register_provider("openai", FakeStreamProvider("openai"))

        chunks = await self._collect(gateway)

        assert "".join(c.content for c in chunks) == "Hello from openai"
        assert chunks and chunks[-1].is_final

    async def test_stream_records_usage(self):
        """Usage is tracked after successful stream."""
        gateway = LLMGateway()
        gateway.register_provider("openai", FakeStreamProvider("openai"))

        await self._collect(gateway, agent_name="stream_agent")

        assert gateway.get_usage().total_tokens > 0
|