432 lines
15 KiB
Python
432 lines
15 KiB
Python
"""Streaming System 单元测试 - U8/U9/U10
|
|
|
|
Covers:
- StreamChunk dataclass
- LLMProvider.chat_stream default fallback
- Gateway.chat_stream streaming + usage tracking
- ReActEvent dataclass
- ReActEngine.execute_stream event stream
- SSE endpoint /tasks/stream
|
|
"""
|
|
|
|
import json
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from agentkit.llm.protocol import LLMRequest, LLMResponse, TokenUsage, ToolCall
|
|
from agentkit.tools.base import Tool
|
|
|
|
|
|
# ── Test Helpers ──────────────────────────────────────────
|
|
|
|
|
|
class FakeTool(Tool):
    """Minimal ``Tool`` stand-in whose ``execute`` returns a canned result.

    Args:
        name: Tool name forwarded to ``Tool.__init__``.
        description: Tool description forwarded to ``Tool.__init__``.
        result: Payload returned by every ``execute`` call. When omitted
            (``None``), ``{"status": "ok"}`` is used.
    """

    def __init__(
        self,
        name: str = "fake_tool",
        description: str = "A fake tool for testing",
        result: dict | None = None,
    ):
        super().__init__(name=name, description=description)
        # Compare against None instead of relying on truthiness so that an
        # explicitly supplied empty dict is returned as-is rather than being
        # silently replaced by the default payload.
        self._result = {"status": "ok"} if result is None else result

    async def execute(self, **kwargs) -> dict:
        """Return the configured result; keyword arguments are ignored."""
        return self._result
|
|
|
|
|
|
def make_response(
    content: str = "",
    tool_calls: list[ToolCall] | None = None,
    prompt_tokens: int = 10,
    completion_tokens: int = 20,
) -> LLMResponse:
    """Build an ``LLMResponse`` for tests, always using model ``"test-model"``.

    Args:
        content: Text placed in the response's ``content`` field.
        tool_calls: Tool calls to attach; a missing or empty value becomes ``[]``.
        prompt_tokens: Prompt token count recorded in the usage object.
        completion_tokens: Completion token count recorded in the usage object.

    Returns:
        The assembled ``LLMResponse``.
    """
    token_usage = TokenUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    attached_calls = tool_calls if tool_calls else []
    return LLMResponse(
        content=content,
        model="test-model",
        usage=token_usage,
        tool_calls=attached_calls,
    )
|
|
|
|
|
|
# ══════════════════════════════════════════════════════════
|
|
# U8: StreamChunk + chat_stream
|
|
# ══════════════════════════════════════════════════════════
|
|
|
|
|
|
class TestStreamChunk:
    """Tests for the StreamChunk dataclass."""

    def test_creation_with_content(self):
        """A chunk built from content and model alone reports the expected defaults."""
        from agentkit.llm.protocol import StreamChunk

        text_chunk = StreamChunk(content="Hello", model="gpt-4o")

        # Fields that were passed in explicitly.
        assert text_chunk.content == "Hello"
        assert text_chunk.model == "gpt-4o"
        # Fields that were left out of the constructor call.
        assert text_chunk.tool_calls == []
        assert text_chunk.usage is None
        assert text_chunk.is_final is False

    def test_with_tool_calls_and_is_final(self):
        """A chunk can carry a tool call and be flagged as final."""
        from agentkit.llm.protocol import StreamChunk

        search_call = ToolCall(id="tc_1", name="search", arguments={"q": "test"})
        final_chunk = StreamChunk(
            content="", model="gpt-4o", tool_calls=[search_call], is_final=True
        )

        assert len(final_chunk.tool_calls) == 1
        assert final_chunk.tool_calls[0].name == "search"
        assert final_chunk.is_final is True

    def test_with_usage(self):
        """A final chunk can carry token usage, whose total is exposed."""
        from agentkit.llm.protocol import StreamChunk

        token_usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
        final_chunk = StreamChunk(
            content="", model="gpt-4o", usage=token_usage, is_final=True
        )

        assert final_chunk.usage is not None
        assert final_chunk.usage.total_tokens == 150
        assert final_chunk.is_final is True
|
|
|
|
|
|
class TestLLMProviderChatStreamDefault:
    """The default LLMProvider.chat_stream implementation falls back to chat()."""

    async def test_default_chat_stream_yields_single_chunk(self):
        """A provider that only defines chat() streams exactly one final chunk."""
        # StreamChunk is imported alongside LLMProvider but not referenced below.
        from agentkit.llm.protocol import LLMProvider, StreamChunk

        class ChatOnlyProvider(LLMProvider):
            """Provider that overrides chat() and inherits chat_stream()."""

            async def chat(self, request: LLMRequest) -> LLMResponse:
                token_usage = TokenUsage(prompt_tokens=5, completion_tokens=10)
                return LLMResponse(content="hello", model="test", usage=token_usage)

        provider = ChatOnlyProvider()
        request = LLMRequest(
            messages=[{"role": "user", "content": "hi"}],
            model="test",
        )

        received = [piece async for piece in provider.chat_stream(request)]

        assert len(received) == 1
        only_chunk = received[0]
        assert only_chunk.content == "hello"
        assert only_chunk.is_final is True
        assert only_chunk.usage.total_tokens == 15
|
|
|
|
|
|
class TestGatewayChatStream:
    """Streaming tests for Gateway.chat_stream."""

    @staticmethod
    def _gateway_replaying(scripted_chunks):
        """Return an LLMGateway whose "test" provider streams *scripted_chunks* in order."""
        from agentkit.llm.gateway import LLMGateway
        from agentkit.llm.protocol import LLMProvider

        class StreamingProvider(LLMProvider):
            """Provider with a fixed chat() reply and a scripted chat_stream()."""

            async def chat(self, request: LLMRequest) -> LLMResponse:
                return LLMResponse(content="fallback", model="test", usage=TokenUsage())

            async def chat_stream(self, request: LLMRequest):
                for scripted in scripted_chunks:
                    yield scripted

        gateway = LLMGateway()
        gateway.register_provider("test", StreamingProvider())
        return gateway

    async def test_yields_chunks_from_provider(self):
        """The gateway yields the provider's chunks in the order they were produced."""
        from agentkit.llm.protocol import StreamChunk

        gateway = self._gateway_replaying(
            [
                StreamChunk(content="Hello ", model="test"),
                StreamChunk(content="World", model="test"),
                StreamChunk(
                    content="",
                    model="test",
                    usage=TokenUsage(prompt_tokens=10, completion_tokens=5),
                    is_final=True,
                ),
            ]
        )

        stream = gateway.chat_stream(
            messages=[{"role": "user", "content": "hi"}],
            model="test/model",
        )
        received = [piece async for piece in stream]

        assert len(received) == 3
        first, second, last = received
        assert first.content == "Hello "
        assert second.content == "World"
        assert last.is_final is True

    async def test_tracks_usage_after_stream_completes(self):
        """Usage from the final chunk is visible via get_usage() once the stream ends."""
        from agentkit.llm.protocol import StreamChunk

        gateway = self._gateway_replaying(
            [
                StreamChunk(content="Hi", model="test"),
                StreamChunk(
                    content="",
                    model="test",
                    usage=TokenUsage(prompt_tokens=100, completion_tokens=50),
                    is_final=True,
                ),
            ]
        )

        # Drain the stream; the individual chunks are not inspected here.
        async for _ in gateway.chat_stream(
            messages=[{"role": "user", "content": "hi"}],
            model="test/model",
            agent_name="stream_agent",
        ):
            pass

        # 100 prompt tokens + 50 completion tokens from the final chunk.
        assert gateway.get_usage().total_tokens == 150
|
|
|
|
|
|
# ══════════════════════════════════════════════════════════
|
|
# U9: ReActEvent + execute_stream
|
|
# ══════════════════════════════════════════════════════════
|
|
|
|
|
|
class TestReActEvent:
    """Tests for the ReActEvent dataclass."""

    def test_creation_with_event_type_and_step(self):
        """An event built from type and step alone has an empty data dict."""
        from agentkit.core.react import ReActEvent

        thinking = ReActEvent(event_type="thinking", step=1)

        assert thinking.event_type == "thinking"
        assert thinking.step == 1
        assert thinking.data == {}

    def test_has_timestamp(self):
        """A new event carries a non-empty timestamp."""
        from agentkit.core.react import ReActEvent

        stamp = ReActEvent(event_type="thinking", step=1).timestamp

        assert stamp is not None
        assert len(stamp) > 0

    def test_with_data(self):
        """A data payload passed to the constructor is readable from the event."""
        from agentkit.core.react import ReActEvent

        payload = {"tool_name": "search", "arguments": {"q": "test"}}
        tool_call = ReActEvent(event_type="tool_call", step=2, data=payload)

        assert tool_call.data["tool_name"] == "search"
|
|
|
|
|
|
class TestReActEngineExecuteStream:
    """Event-stream tests for ReActEngine.execute_stream."""

    @staticmethod
    def _of_type(events, event_type):
        """Return the events whose ``event_type`` equals *event_type*, in order."""
        return [evt for evt in events if evt.event_type == event_type]

    async def test_yields_thinking_event_at_each_step(self):
        """At least one "thinking" event is emitted and the first is for step 1."""
        from agentkit.core.react import ReActEngine, ReActEvent

        gateway = MagicMock()
        gateway.chat = AsyncMock(return_value=make_response(content="Final answer"))
        engine = ReActEngine(llm_gateway=gateway)

        stream = engine.execute_stream(
            messages=[{"role": "user", "content": "Hello"}],
        )
        collected = [evt async for evt in stream]

        thinking = self._of_type(collected, "thinking")
        assert len(thinking) >= 1
        assert thinking[0].step == 1

    async def test_yields_tool_call_and_tool_result_events(self):
        """One tool call yields one "tool_call" and one "tool_result" event."""
        from agentkit.core.react import ReActEngine, ReActEvent

        calculator = FakeTool(name="calculator", result={"value": 42})

        # First LLM reply requests the tool, second one gives the final text.
        scripted_replies = [
            make_response(
                content="",
                tool_calls=[
                    ToolCall(id="tc_1", name="calculator", arguments={"expr": "6*7"})
                ],
            ),
            make_response(content="The result is 42"),
        ]
        gateway = MagicMock()
        gateway.chat = AsyncMock(side_effect=scripted_replies)
        engine = ReActEngine(llm_gateway=gateway)

        stream = engine.execute_stream(
            messages=[{"role": "user", "content": "Calculate"}],
            tools=[calculator],
        )
        collected = [evt async for evt in stream]

        calls = self._of_type(collected, "tool_call")
        results = self._of_type(collected, "tool_result")

        assert len(calls) == 1
        assert calls[0].data["tool_name"] == "calculator"
        assert len(results) == 1
        assert results[0].data["tool_name"] == "calculator"
        assert results[0].data["result"] == {"value": 42}

    async def test_yields_final_answer_event(self):
        """A plain text reply yields exactly one "final_answer" event with stats."""
        from agentkit.core.react import ReActEngine, ReActEvent

        gateway = MagicMock()
        gateway.chat = AsyncMock(return_value=make_response(content="The answer is 42"))
        engine = ReActEngine(llm_gateway=gateway)

        stream = engine.execute_stream(
            messages=[{"role": "user", "content": "What is the answer?"}],
        )
        collected = [evt async for evt in stream]

        finals = self._of_type(collected, "final_answer")
        assert len(finals) == 1
        summary = finals[0].data
        assert summary["output"] == "The answer is 42"
        assert summary["total_steps"] >= 1
        assert summary["total_tokens"] > 0

    async def test_yields_max_steps_reached_when_hitting_limit(self):
        """With an LLM that always asks for a tool, the final event flags the step limit."""
        from agentkit.core.react import ReActEngine, ReActEvent

        search = FakeTool(name="search", result={"results": ["data"]})

        # Every LLM reply requests the tool again, so the loop never ends by itself.
        looping_reply = make_response(
            content="Thinking...",
            tool_calls=[ToolCall(id="tc_loop", name="search", arguments={"query": "more"})],
        )
        gateway = MagicMock()
        gateway.chat = AsyncMock(return_value=looping_reply)
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)

        stream = engine.execute_stream(
            messages=[{"role": "user", "content": "Keep searching"}],
            tools=[search],
        )
        collected = [evt async for evt in stream]

        finals = self._of_type(collected, "final_answer")
        assert len(finals) == 1
        assert finals[0].data.get("max_steps_reached") is True
|
|
|
|
|
|
# ══════════════════════════════════════════════════════════
|
|
# U10: SSE Endpoint + Client SDK
|
|
# ══════════════════════════════════════════════════════════
|
|
|
|
|
|
class TestSSEEndpoint:
    """Tests for the SSE endpoint ``/api/v1/tasks/stream``."""

    @staticmethod
    def _make_client(gateway):
        """Return a TestClient for an app built from *gateway* and fresh, empty registries."""
        from fastapi.testclient import TestClient
        from agentkit.server.app import create_app
        from agentkit.skills.registry import SkillRegistry
        from agentkit.tools.registry import ToolRegistry

        app = create_app(
            llm_gateway=gateway,
            skill_registry=SkillRegistry(),
            tool_registry=ToolRegistry(),
        )
        return TestClient(app)

    def test_stream_task_returns_event_source_response(self):
        """Streaming a task for an existing agent answers 200 with an SSE content type."""
        from agentkit.llm.gateway import LLMGateway

        gateway = LLMGateway()
        mock_provider = AsyncMock()
        mock_provider.chat.return_value = LLMResponse(
            content="Final answer",
            model="test-model",
            usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
        )
        gateway.register_provider("test", mock_provider)
        client = self._make_client(gateway)

        # Create an agent first.
        created = client.post(
            "/api/v1/agents",
            json={
                "config": {
                    "name": "stream_agent",
                    "agent_type": "test",
                    "task_mode": "llm_generate",
                    "prompt": {"identity": "Stream Agent"},
                }
            },
        )
        # Fail here, with the server's own message, if setup did not succeed;
        # otherwise a broken setup would only surface below as a confusing
        # status-code mismatch on the stream request.
        assert 200 <= created.status_code < 300, created.text

        # Stream task.
        response = client.post(
            "/api/v1/tasks/stream",
            json={
                "input_data": {"query": "test"},
                "agent_name": "stream_agent",
            },
        )

        # Should return 200 with SSE content type.
        assert response.status_code == 200
        assert "text/event-stream" in response.headers.get("content-type", "")

    def test_stream_task_with_invalid_agent_returns_404(self):
        """Streaming a task for an agent that was never created answers 404."""
        from agentkit.llm.gateway import LLMGateway

        client = self._make_client(LLMGateway())

        response = client.post(
            "/api/v1/tasks/stream",
            json={
                "input_data": {"query": "test"},
                "agent_name": "nonexistent_agent",
            },
        )

        assert response.status_code == 404
|