"""Tests for the LLM protocol data classes."""

import asyncio

import pytest

from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage, ToolCall


class TestTokenUsage:
    """Tests for the TokenUsage data class."""

    def test_default_values(self):
        # With no arguments every counter starts at zero.
        usage = TokenUsage()
        assert usage.prompt_tokens == 0
        assert usage.completion_tokens == 0
        assert usage.total_tokens == 0

    def test_custom_values(self):
        usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
        assert usage.prompt_tokens == 100
        assert usage.completion_tokens == 50
        assert usage.total_tokens == 150

    def test_total_tokens_computed(self):
        """total_tokens equals prompt_tokens + completion_tokens for varied inputs."""
        # test_custom_values already covers 100/50; use different, asymmetric
        # pairs (including a zero on either side) so a hard-coded or
        # mis-computed total is caught rather than re-checking the same case.
        cases = [(1, 2), (0, 7), (7, 0), (123, 456)]
        for prompt, completion in cases:
            usage = TokenUsage(prompt_tokens=prompt, completion_tokens=completion)
            assert usage.prompt_tokens == prompt
            assert usage.completion_tokens == completion
            assert usage.total_tokens == prompt + completion, (prompt, completion)


class TestToolCall:
    """Tests for the ToolCall data class."""

    def test_tool_call_creation(self):
        call = ToolCall(id="call_123", name="get_weather", arguments={"city": "Beijing"})
        # Every constructor argument should be readable back unchanged.
        assert (call.id, call.name, call.arguments) == (
            "call_123",
            "get_weather",
            {"city": "Beijing"},
        )

    def test_tool_call_with_empty_arguments(self):
        # An empty argument mapping is accepted and preserved as-is.
        call = ToolCall(id="call_456", name="list_items", arguments={})
        assert call.arguments == {}


class TestLLMRequest:
    """Tests for the LLMRequest data class."""

    def test_basic_request(self):
        req = LLMRequest(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": "Hello"}],
        )
        # Explicitly supplied fields.
        assert req.model == "gpt-4o-mini"
        assert len(req.messages) == 1
        # Defaults for every field not supplied above.
        assert req.tools is None
        assert req.tool_choice == "auto"
        assert req.temperature == 0.7
        assert req.max_tokens == 2000

    def test_request_with_tools(self):
        weather_fn = {
            "name": "get_weather",
            "description": "Get weather",
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        }
        req = LLMRequest(
            messages=[{"role": "user", "content": "What's the weather?"}],
            model="gpt-4o",
            tools=[{"type": "function", "function": weather_fn}],
            tool_choice="auto",
            temperature=0.0,
            max_tokens=1000,
        )
        assert req.tools is not None
        assert len(req.tools) == 1
        assert req.temperature == 0.0
        assert req.max_tokens == 1000

    def test_request_with_extra_kwargs(self):
        # Passing an extra keyword (top_p) must not break construction.
        # NOTE(review): this only checks that construction succeeds; it does
        # not verify where/whether top_p is retained -- confirm the intended
        # storage for extra kwargs and assert on it.
        req = LLMRequest(
            messages=[{"role": "user", "content": "Hello"}],
            model="gpt-4o",
            top_p=0.9,
        )
        assert req.model == "gpt-4o"


class TestLLMResponse:
    """Tests for the LLMResponse data class."""

    @staticmethod
    def _make_usage():
        # Shared usage value for these tests: 10 prompt + 20 completion tokens.
        return TokenUsage(prompt_tokens=10, completion_tokens=20)

    def test_basic_response(self):
        resp = LLMResponse(content="Hello!", model="gpt-4o-mini", usage=self._make_usage())
        assert resp.content == "Hello!"
        assert resp.model == "gpt-4o-mini"
        assert resp.usage.total_tokens == 30
        # Optional fields fall back to their defaults.
        assert resp.tool_calls == []
        assert resp.latency_ms == 0.0

    def test_response_with_tool_calls(self):
        calls = [ToolCall(id="call_1", name="get_weather", arguments={"city": "Beijing"})]
        resp = LLMResponse(
            content="",
            model="gpt-4o",
            usage=self._make_usage(),
            tool_calls=calls,
            latency_ms=150.5,
        )
        assert len(resp.tool_calls) == 1
        first = resp.tool_calls[0]
        assert first.name == "get_weather"
        assert resp.latency_ms == 150.5

    def test_has_tool_calls_true(self):
        calls = [ToolCall(id="call_1", name="search", arguments={"q": "test"})]
        resp = LLMResponse(content="", model="gpt-4o", usage=self._make_usage(), tool_calls=calls)
        assert resp.has_tool_calls is True

    def test_has_tool_calls_false(self):
        resp = LLMResponse(content="Hello!", model="gpt-4o-mini", usage=self._make_usage())
        assert resp.has_tool_calls is False


class TestLLMProvider:
    """Tests for the LLMProvider abstract base class."""

    def test_cannot_instantiate_directly(self):
        # LLMProvider is abstract, so instantiating it directly must raise.
        with pytest.raises(TypeError):
            LLMProvider()

    def test_subclass_must_implement_chat(self):
        class IncompleteProvider(LLMProvider):
            pass

        # A subclass that does not override chat() is still abstract.
        with pytest.raises(TypeError):
            IncompleteProvider()

    def test_subclass_with_chat_works(self):
        """A subclass that implements chat() can be instantiated and called."""

        class DummyProvider(LLMProvider):
            async def chat(self, request: LLMRequest) -> LLMResponse:
                usage = TokenUsage(prompt_tokens=5, completion_tokens=10)
                return LLMResponse(content="hi", model=request.model, usage=usage)

        provider = DummyProvider()
        request = LLMRequest(messages=[{"role": "user", "content": "hi"}], model="test")
        # Drive the coroutine explicitly: a bare `async def` test only executes
        # when an async plugin (e.g. pytest-asyncio) handles it; otherwise
        # pytest skips or fails it without ever running the body.
        response = asyncio.run(provider.chat(request))
        assert response.content == "hi"
        # DummyProvider echoes request.model into the response.
        assert response.model == "test"