128 lines
4.4 KiB
Python
128 lines
4.4 KiB
Python
"""Tests for ReAct engine token streaming via execute_stream()."""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
from agentkit.core.react import ReActEngine, ReActEvent
|
|
from agentkit.llm.gateway import LLMGateway
|
|
from agentkit.llm.protocol import LLMResponse, StreamChunk, TokenUsage
|
|
|
|
|
|
def _make_stream_chunks(
    content_parts: list[str],
    model: str = "test-model",
    usage: TokenUsage | None = None,
) -> list[StreamChunk]:
    """Simulate a streamed LLM reply as a list of ``StreamChunk`` objects.

    One content-bearing chunk is produced per entry of ``content_parts`` (in
    order). These are followed by a terminating chunk that has empty content,
    ``is_final=True`` and token usage attached.

    Args:
        content_parts: Text fragments, one per content chunk.
        model: Model name stamped on every chunk.
        usage: Usage attached to the final chunk. When falsy (e.g. ``None``),
            a default of 10 prompt / 20 completion tokens is used.

    Returns:
        The content chunks followed by the final usage-bearing chunk.
    """
    content_chunks = [StreamChunk(content=text, model=model) for text in content_parts]

    # The terminating chunk carries no text, only usage and the end-of-stream flag.
    final_usage = usage if usage else TokenUsage(prompt_tokens=10, completion_tokens=20)
    terminator = StreamChunk(
        content="",
        model=model,
        usage=final_usage,
        is_final=True,
    )
    return [*content_chunks, terminator]
|
|
|
|
|
|
def _make_stream_gateway(chunks_list: list[list[StreamChunk]]) -> MagicMock:
    """Create a mock LLMGateway whose ``chat_stream`` replays scripted responses.

    Each call to ``gateway.chat_stream(...)`` yields the chunks of exactly one
    scripted response, in order. The first call replays ``chunks_list[0]``,
    the second call replays the next remaining response, and so on.

    Consumed responses are removed from ``chunks_list``, which is mutated in
    place. Once every scripted response has been consumed, further calls
    yield nothing.

    Args:
        chunks_list: One list of stream chunks per expected LLM call.

    Returns:
        A ``MagicMock`` specced to ``LLMGateway`` whose ``chat_stream``
        attribute is replaced by an async generator function.
    """
    gateway = MagicMock(spec=LLMGateway)

    async def _stream(*args, **kwargs):
        # Nothing left to replay: behave like an empty stream.
        if not chunks_list:
            return
        # Pop *before* yielding, for two reasons:
        # (a) the scripted list is never mutated while being iterated, which
        #     previously made a single call skip or merge responses;
        # (b) the response counts as consumed even if the caller stops
        #     iterating early, since code after the last ``yield`` would
        #     never run in that case.
        response_chunks = chunks_list.pop(0)
        for chunk in response_chunks:
            yield chunk

    gateway.chat_stream = _stream
    return gateway
|
|
|
|
|
|
class TestReActTokenStreaming:
    """Test that execute_stream() yields token events from chat_stream()."""

    @staticmethod
    async def _collect_events(engine: ReActEngine, content: str) -> list:
        """Run ``engine.execute_stream`` on one user message and return every event."""
        return [
            event
            async for event in engine.execute_stream(
                messages=[{"role": "user", "content": content}],
            )
        ]

    @pytest.mark.asyncio
    async def test_yields_token_events(self):
        """execute_stream should yield token events for each stream chunk."""
        chunks = _make_stream_chunks(["Hello", " world", "!"])
        engine = ReActEngine(llm_gateway=_make_stream_gateway([chunks]))

        events = await self._collect_events(engine, "Hi")

        token_events = [e for e in events if e.event_type == "token"]
        # Comparing the whole sequence pins count and order at once, and gives
        # a clearer failure diff than separate per-index asserts.
        assert [e.data["content"] for e in token_events] == ["Hello", " world", "!"]

    @pytest.mark.asyncio
    async def test_final_answer_after_streaming(self):
        """After streaming tokens, a final_answer event should be yielded."""
        chunks = _make_stream_chunks(["The answer is 42"])
        engine = ReActEngine(llm_gateway=_make_stream_gateway([chunks]))

        events = await self._collect_events(engine, "What?")

        final_events = [e for e in events if e.event_type == "final_answer"]
        assert len(final_events) == 1
        assert "42" in final_events[0].data.get("output", "")

    @pytest.mark.asyncio
    async def test_token_events_have_correct_step(self):
        """Token events should carry the current step number."""
        chunks = _make_stream_chunks(["Hi"])
        engine = ReActEngine(llm_gateway=_make_stream_gateway([chunks]))

        events = await self._collect_events(engine, "Hi")

        token_events = [e for e in events if e.event_type == "token"]
        # Guard against a vacuous pass: without this check, an engine that
        # emits no token events would skip the loop below and still pass.
        assert token_events, "expected at least one token event"
        for te in token_events:
            assert te.step >= 1

    def test_build_response_from_stream(self):
        """_build_response_from_stream should construct a valid LLMResponse."""
        # Plain sync test: the method under test is called without ``await``,
        # so no event loop (and no asyncio marker) is needed.
        usage = TokenUsage(prompt_tokens=5, completion_tokens=10)
        response = ReActEngine._build_response_from_stream(
            content="Hello world",
            tool_calls=[],
            usage=usage,
            model="test-model",
        )
        assert response.content == "Hello world"
        assert response.model == "test-model"
        assert response.usage.prompt_tokens == 5
        assert response.usage.completion_tokens == 10
        assert not response.has_tool_calls

    def test_build_response_from_stream_no_usage(self):
        """_build_response_from_stream should handle None usage gracefully."""
        response = ReActEngine._build_response_from_stream(
            content="test",
            tool_calls=[],
            usage=None,
            model="test-model",
        )
        assert response.content == "test"
        assert response.usage.prompt_tokens == 0
|