# fischer-agentkit/tests/unit/test_llm_protocol.py
#
# 150 lines
# 5.1 KiB
# Python

"""Tests for the LLM protocol data classes."""
import pytest
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage, ToolCall
class TestTokenUsage:
    """Tests for the TokenUsage data class."""

    def test_default_values(self):
        """A TokenUsage built without arguments reports zero for every counter."""
        empty = TokenUsage()

        assert empty.prompt_tokens == 0
        assert empty.completion_tokens == 0
        assert empty.total_tokens == 0

    def test_custom_values(self):
        """Explicit prompt/completion counts are stored and reflected in the total."""
        counts = TokenUsage(prompt_tokens=100, completion_tokens=50)

        assert counts.prompt_tokens == 100
        assert counts.completion_tokens == 50
        assert counts.total_tokens == 150

    def test_total_tokens_computed(self):
        """total_tokens is 150 when only the 100/50 component counts are supplied."""
        counts = TokenUsage(prompt_tokens=100, completion_tokens=50)

        assert counts.total_tokens == 150
class TestToolCall:
    """Tests for the ToolCall data class."""

    def test_tool_call_creation(self):
        """The id, name and arguments passed to the constructor are readable back."""
        call = ToolCall(id="call_123", name="get_weather", arguments={"city": "Beijing"})

        assert call.id == "call_123"
        assert call.name == "get_weather"
        assert call.arguments == {"city": "Beijing"}

    def test_tool_call_with_empty_arguments(self):
        """An empty arguments dict is accepted and preserved."""
        call = ToolCall(id="call_456", name="list_items", arguments={})

        assert call.arguments == {}
class TestLLMRequest:
    """Tests for the LLMRequest data class."""

    def test_basic_request(self):
        """With only messages and model supplied, the other fields hold their defaults."""
        user_message = {"role": "user", "content": "Hello"}
        req = LLMRequest(messages=[user_message], model="gpt-4o-mini")

        assert len(req.messages) == 1
        assert req.model == "gpt-4o-mini"
        # Fields that were not passed to the constructor.
        assert req.tools is None
        assert req.tool_choice == "auto"
        assert req.temperature == 0.7
        assert req.max_tokens == 2000

    def test_request_with_tools(self):
        """A tool list and explicit temperature/max_tokens are kept on the request."""
        city_schema = {"type": "object", "properties": {"city": {"type": "string"}}}
        weather_tool = {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get weather",
                "parameters": city_schema,
            },
        }
        req = LLMRequest(
            messages=[{"role": "user", "content": "What's the weather?"}],
            model="gpt-4o",
            tools=[weather_tool],
            tool_choice="auto",
            temperature=0.0,
            max_tokens=1000,
        )

        assert req.tools is not None
        assert len(req.tools) == 1
        assert req.temperature == 0.0
        assert req.max_tokens == 1000

    def test_request_with_extra_kwargs(self):
        """Constructing a request with an additional top_p keyword succeeds."""
        user_message = {"role": "user", "content": "Hello"}
        req = LLMRequest(messages=[user_message], model="gpt-4o", top_p=0.9)

        assert req.model == "gpt-4o"
class TestLLMResponse:
    """Tests for the LLMResponse data class."""

    def test_basic_response(self):
        """A response built from content, model and usage has no tool calls and zero latency."""
        token_usage = TokenUsage(prompt_tokens=10, completion_tokens=20)
        resp = LLMResponse(content="Hello!", model="gpt-4o-mini", usage=token_usage)

        assert resp.content == "Hello!"
        assert resp.model == "gpt-4o-mini"
        assert resp.usage.total_tokens == 30
        # Fields that were not passed to the constructor.
        assert resp.tool_calls == []
        assert resp.latency_ms == 0.0

    def test_response_with_tool_calls(self):
        """Tool calls and latency passed to the constructor are readable back."""
        token_usage = TokenUsage(prompt_tokens=10, completion_tokens=20)
        weather_call = ToolCall(id="call_1", name="get_weather", arguments={"city": "Beijing"})
        resp = LLMResponse(
            content="",
            model="gpt-4o",
            usage=token_usage,
            tool_calls=[weather_call],
            latency_ms=150.5,
        )

        assert len(resp.tool_calls) == 1
        assert resp.tool_calls[0].name == "get_weather"
        assert resp.latency_ms == 150.5

    def test_has_tool_calls_true(self):
        """has_tool_calls is True when the response carries one tool call."""
        token_usage = TokenUsage(prompt_tokens=10, completion_tokens=20)
        search_call = ToolCall(id="call_1", name="search", arguments={"q": "test"})
        resp = LLMResponse(
            content="", model="gpt-4o", usage=token_usage, tool_calls=[search_call]
        )

        assert resp.has_tool_calls is True

    def test_has_tool_calls_false(self):
        """has_tool_calls is False when no tool calls were supplied."""
        token_usage = TokenUsage(prompt_tokens=10, completion_tokens=20)
        resp = LLMResponse(content="Hello!", model="gpt-4o-mini", usage=token_usage)

        assert resp.has_tool_calls is False
class TestLLMProvider:
    """Tests for the LLMProvider abstract base class."""

    def test_cannot_instantiate_directly(self):
        """Instantiating LLMProvider itself raises TypeError."""
        with pytest.raises(TypeError):
            LLMProvider()

    def test_subclass_must_implement_chat(self):
        """A subclass that defines no methods raises TypeError on instantiation."""

        class IncompleteProvider(LLMProvider):
            pass

        with pytest.raises(TypeError):
            IncompleteProvider()

    # Explicit marker: under pytest-asyncio's strict mode an unmarked coroutine
    # test is not run as an asyncio test, so the assertion below would never
    # execute. The marker is redundant but harmless in auto mode.
    @pytest.mark.asyncio
    async def test_subclass_with_chat_works(self):
        """A subclass that implements chat can be instantiated and awaited."""

        class DummyProvider(LLMProvider):
            async def chat(self, request: LLMRequest) -> LLMResponse:
                usage = TokenUsage(prompt_tokens=5, completion_tokens=10)
                return LLMResponse(content="hi", model=request.model, usage=usage)

        provider = DummyProvider()
        request = LLMRequest(messages=[{"role": "user", "content": "hi"}], model="test")

        response = await provider.chat(request)

        assert response.content == "hi"