# fischer-agentkit/tests/unit/test_streaming.py
# (source-viewer metadata: 432 lines, 15 KiB, Python)

"""Streaming System 单元测试 - U8/U9/U10
覆盖:
- StreamChunk 数据类
- LLMProvider.chat_stream 默认回退
- Gateway.chat_stream 流式 + 用量追踪
- ReActEvent 数据类
- ReActEngine.execute_stream 事件流
- SSE 端点 /tasks/stream
"""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agentkit.llm.protocol import LLMRequest, LLMResponse, TokenUsage, ToolCall
from agentkit.tools.base import Tool
# ── Test Helpers ──────────────────────────────────────────
class FakeTool(Tool):
    """Minimal ``Tool`` stub whose ``execute`` returns a canned result.

    Args:
        name: Tool name forwarded to ``Tool.__init__``.
        description: Tool description forwarded to ``Tool.__init__``.
        result: Dict returned by every ``execute`` call. ``None`` (the
            default) means ``{"status": "ok"}``; any other value, including
            an empty dict, is returned exactly as given.
    """

    def __init__(
        self,
        name: str = "fake_tool",
        description: str = "A fake tool for testing",
        result: dict | None = None,
    ):
        super().__init__(name=name, description=description)
        # Compare against None instead of relying on truthiness, so an
        # explicit empty result ({}) is honored rather than replaced.
        self._result = {"status": "ok"} if result is None else result

    async def execute(self, **kwargs) -> dict:
        """Ignore ``kwargs`` and return the configured result dict."""
        return self._result
def make_response(
    content: str = "",
    tool_calls: list[ToolCall] | None = None,
    prompt_tokens: int = 10,
    completion_tokens: int = 20,
) -> LLMResponse:
    """Build an ``LLMResponse`` for tests, always with model ``"test-model"``.

    Args:
        content: Text content of the response.
        tool_calls: Tool calls to attach; ``None`` or an empty list yields a
            fresh empty list.
        prompt_tokens: Prompt token count for the attached ``TokenUsage``.
        completion_tokens: Completion token count for the attached ``TokenUsage``.
    """
    usage = TokenUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    calls = [] if not tool_calls else tool_calls
    return LLMResponse(
        content=content,
        model="test-model",
        usage=usage,
        tool_calls=calls,
    )
# ══════════════════════════════════════════════════════════
# U8: StreamChunk + chat_stream
# ══════════════════════════════════════════════════════════
class TestStreamChunk:
    """Tests for the ``StreamChunk`` dataclass: defaults and optional fields."""

    def test_creation_with_content(self):
        from agentkit.llm.protocol import StreamChunk

        piece = StreamChunk(content="Hello", model="gpt-4o")

        # Explicitly provided fields round-trip unchanged.
        assert piece.content == "Hello"
        assert piece.model == "gpt-4o"
        # Fields that were not provided fall back to their defaults.
        assert piece.tool_calls == []
        assert piece.usage is None
        assert piece.is_final is False

    def test_with_tool_calls_and_is_final(self):
        from agentkit.llm.protocol import StreamChunk

        search_call = ToolCall(id="tc_1", name="search", arguments={"q": "test"})
        piece = StreamChunk(
            content="",
            model="gpt-4o",
            tool_calls=[search_call],
            is_final=True,
        )

        assert len(piece.tool_calls) == 1
        [only_call] = piece.tool_calls
        assert only_call.name == "search"
        assert piece.is_final is True

    def test_with_usage(self):
        from agentkit.llm.protocol import StreamChunk

        piece = StreamChunk(
            content="",
            model="gpt-4o",
            usage=TokenUsage(prompt_tokens=100, completion_tokens=50),
            is_final=True,
        )

        assert piece.usage is not None
        # total_tokens is expected to equal prompt + completion tokens.
        assert piece.usage.total_tokens == 150
        assert piece.is_final is True
class TestLLMProviderChatStreamDefault:
    """The default ``LLMProvider.chat_stream`` falls back to ``chat()``."""

    async def test_default_chat_stream_yields_single_chunk(self):
        from agentkit.llm.protocol import LLMProvider, StreamChunk

        # Provider that implements only ``chat``; ``chat_stream`` is inherited.
        class SimpleProvider(LLMProvider):
            async def chat(self, request: LLMRequest) -> LLMResponse:
                return LLMResponse(
                    content="hello",
                    model="test",
                    usage=TokenUsage(prompt_tokens=5, completion_tokens=10),
                )

        provider = SimpleProvider()
        request = LLMRequest(
            messages=[{"role": "user", "content": "hi"}],
            model="test",
        )

        chunks = [chunk async for chunk in provider.chat_stream(request)]

        # The complete ``chat()`` response arrives as one final StreamChunk.
        assert len(chunks) == 1
        (chunk,) = chunks
        assert isinstance(chunk, StreamChunk)
        assert chunk.content == "hello"
        assert chunk.is_final is True
        assert chunk.usage.total_tokens == 15
class TestGatewayChatStream:
    """Tests for ``LLMGateway.chat_stream``: chunk pass-through and usage tracking."""

    @staticmethod
    def _streaming_provider(chunks):
        """Return an ``LLMProvider`` whose ``chat_stream`` yields ``chunks`` in order.

        Its ``chat`` returns a distinguishable "fallback" response. None of
        the streaming assertions below expect to see that response.
        """
        from agentkit.llm.protocol import LLMProvider

        class StreamingProvider(LLMProvider):
            async def chat(self, request: LLMRequest) -> LLMResponse:
                return LLMResponse(
                    content="fallback",
                    model="test",
                    usage=TokenUsage(),
                )

            async def chat_stream(self, request: LLMRequest):
                for chunk in chunks:
                    yield chunk

        return StreamingProvider()

    async def test_yields_chunks_from_provider(self):
        from agentkit.llm.gateway import LLMGateway
        from agentkit.llm.protocol import StreamChunk

        gateway = LLMGateway()
        gateway.register_provider(
            "test",
            self._streaming_provider([
                StreamChunk(content="Hello ", model="test"),
                StreamChunk(content="World", model="test"),
                StreamChunk(
                    content="",
                    model="test",
                    usage=TokenUsage(prompt_tokens=10, completion_tokens=5),
                    is_final=True,
                ),
            ]),
        )

        # NOTE: "test/model" presumably routes to the provider registered as "test".
        chunks = [
            chunk
            async for chunk in gateway.chat_stream(
                messages=[{"role": "user", "content": "hi"}],
                model="test/model",
            )
        ]

        assert len(chunks) == 3
        assert chunks[0].content == "Hello "
        assert chunks[1].content == "World"
        assert chunks[2].is_final is True

    async def test_tracks_usage_after_stream_completes(self):
        from agentkit.llm.gateway import LLMGateway
        from agentkit.llm.protocol import StreamChunk

        gateway = LLMGateway()
        gateway.register_provider(
            "test",
            self._streaming_provider([
                StreamChunk(content="Hi", model="test"),
                StreamChunk(
                    content="",
                    model="test",
                    usage=TokenUsage(prompt_tokens=100, completion_tokens=50),
                    is_final=True,
                ),
            ]),
        )

        # Consume the whole stream; usage is checked only afterwards.
        chunks = [
            chunk
            async for chunk in gateway.chat_stream(
                messages=[{"role": "user", "content": "hi"}],
                model="test/model",
                agent_name="stream_agent",
            )
        ]

        # Every provider chunk was passed through, ending with the final one.
        assert len(chunks) == 2
        assert chunks[-1].is_final is True
        # The final chunk's usage (100 + 50) was recorded by the gateway.
        usage = gateway.get_usage()
        assert usage.total_tokens == 150
# ══════════════════════════════════════════════════════════
# U9: ReActEvent + execute_stream
# ══════════════════════════════════════════════════════════
class TestReActEvent:
    """Tests for the ``ReActEvent`` dataclass."""

    def test_creation_with_event_type_and_step(self):
        from agentkit.core.react import ReActEvent

        ev = ReActEvent(event_type="thinking", step=1)

        assert ev.event_type == "thinking"
        assert ev.step == 1
        # ``data`` defaults to an empty dict when not supplied.
        assert ev.data == {}

    def test_has_timestamp(self):
        from agentkit.core.react import ReActEvent

        ev = ReActEvent(event_type="thinking", step=1)
        stamp = ev.timestamp

        # A timestamp is filled in without being passed, and it is non-empty.
        assert stamp is not None
        assert len(stamp) > 0

    def test_with_data(self):
        from agentkit.core.react import ReActEvent

        payload = {"tool_name": "search", "arguments": {"q": "test"}}
        ev = ReActEvent(event_type="tool_call", step=2, data=payload)

        assert ev.data["tool_name"] == "search"
class TestReActEngineExecuteStream:
    """Event-stream tests for ``ReActEngine.execute_stream``."""

    @staticmethod
    async def _collect_events(engine, **kwargs):
        """Drain ``engine.execute_stream(**kwargs)`` into a list.

        Also asserts that every yielded item is a ``ReActEvent``.
        """
        from agentkit.core.react import ReActEvent

        events = [event async for event in engine.execute_stream(**kwargs)]
        assert all(isinstance(event, ReActEvent) for event in events)
        return events

    async def test_yields_thinking_event_at_each_step(self):
        from agentkit.core.react import ReActEngine

        gateway = MagicMock()
        gateway.chat = AsyncMock(return_value=make_response(content="Final answer"))
        engine = ReActEngine(llm_gateway=gateway)

        events = await self._collect_events(
            engine,
            messages=[{"role": "user", "content": "Hello"}],
        )

        # At least one thinking event, and the first one is for step 1.
        thinking_events = [e for e in events if e.event_type == "thinking"]
        assert len(thinking_events) >= 1
        assert thinking_events[0].step == 1

    async def test_yields_tool_call_and_tool_result_events(self):
        from agentkit.core.react import ReActEngine

        tool = FakeTool(name="calculator", result={"value": 42})
        gateway = MagicMock()
        # First turn requests the tool; second turn answers with plain text.
        gateway.chat = AsyncMock(side_effect=[
            make_response(
                content="",
                tool_calls=[ToolCall(id="tc_1", name="calculator", arguments={"expr": "6*7"})],
            ),
            make_response(content="The result is 42"),
        ])
        engine = ReActEngine(llm_gateway=gateway)

        events = await self._collect_events(
            engine,
            messages=[{"role": "user", "content": "Calculate"}],
            tools=[tool],
        )

        tool_call_events = [e for e in events if e.event_type == "tool_call"]
        tool_result_events = [e for e in events if e.event_type == "tool_result"]
        assert len(tool_call_events) == 1
        assert tool_call_events[0].data["tool_name"] == "calculator"
        assert len(tool_result_events) == 1
        assert tool_result_events[0].data["tool_name"] == "calculator"
        assert tool_result_events[0].data["result"] == {"value": 42}

    async def test_yields_final_answer_event(self):
        from agentkit.core.react import ReActEngine

        gateway = MagicMock()
        gateway.chat = AsyncMock(return_value=make_response(content="The answer is 42"))
        engine = ReActEngine(llm_gateway=gateway)

        events = await self._collect_events(
            engine,
            messages=[{"role": "user", "content": "What is the answer?"}],
        )

        final_events = [e for e in events if e.event_type == "final_answer"]
        assert len(final_events) == 1
        assert final_events[0].data["output"] == "The answer is 42"
        assert final_events[0].data["total_steps"] >= 1
        assert final_events[0].data["total_tokens"] > 0

    async def test_yields_max_steps_reached_when_hitting_limit(self):
        from agentkit.core.react import ReActEngine

        tool = FakeTool(name="search", result={"results": ["data"]})
        # The LLM requests a tool call on every turn, so the loop can only
        # stop by hitting max_steps.
        always_tool_response = make_response(
            content="Thinking...",
            tool_calls=[ToolCall(id="tc_loop", name="search", arguments={"query": "more"})],
        )
        gateway = MagicMock()
        gateway.chat = AsyncMock(return_value=always_tool_response)
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)

        events = await self._collect_events(
            engine,
            messages=[{"role": "user", "content": "Keep searching"}],
            tools=[tool],
        )

        final_events = [e for e in events if e.event_type == "final_answer"]
        assert len(final_events) == 1
        assert final_events[0].data.get("max_steps_reached") is True
# ══════════════════════════════════════════════════════════
# U10: SSE Endpoint + Client SDK
# ══════════════════════════════════════════════════════════
class TestSSEEndpoint:
    """Tests for the SSE endpoint ``POST /api/v1/tasks/stream``."""

    @staticmethod
    def _make_client(gateway):
        """Build a ``TestClient`` for an app wired to ``gateway`` and empty registries."""
        from fastapi.testclient import TestClient
        from agentkit.server.app import create_app
        from agentkit.skills.registry import SkillRegistry
        from agentkit.tools.registry import ToolRegistry

        app = create_app(
            llm_gateway=gateway,
            skill_registry=SkillRegistry(),
            tool_registry=ToolRegistry(),
        )
        return TestClient(app)

    def test_stream_task_returns_event_source_response(self):
        from agentkit.llm.gateway import LLMGateway

        gateway = LLMGateway()
        mock_provider = AsyncMock()
        mock_provider.chat.return_value = LLMResponse(
            content="Final answer",
            model="test-model",
            usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
        )
        gateway.register_provider("test", mock_provider)
        client = self._make_client(gateway)

        # Create the agent the stream request refers to. Fail fast, with the
        # server's error body, if this setup step is rejected; otherwise the
        # failure would only show up later as a confusing status mismatch.
        create_response = client.post(
            "/api/v1/agents",
            json={
                "config": {
                    "name": "stream_agent",
                    "agent_type": "test",
                    "task_mode": "llm_generate",
                    "prompt": {"identity": "Stream Agent"},
                }
            },
        )
        assert create_response.status_code < 300, create_response.text

        response = client.post(
            "/api/v1/tasks/stream",
            json={
                "input_data": {"query": "test"},
                "agent_name": "stream_agent",
            },
        )

        # The endpoint answers 200 with an SSE content type.
        assert response.status_code == 200
        assert "text/event-stream" in response.headers.get("content-type", "")

    def test_stream_task_with_invalid_agent_returns_404(self):
        from agentkit.llm.gateway import LLMGateway

        client = self._make_client(LLMGateway())

        response = client.post(
            "/api/v1/tasks/stream",
            json={
                "input_data": {"query": "test"},
                "agent_name": "nonexistent_agent",
            },
        )

        assert response.status_code == 404