feat(streaming): Phase C - LLM streaming + ReAct events + SSE endpoint

U8: StreamChunk protocol + OpenAI chat_stream + Gateway streaming with usage tracking
U9: ReActEvent dataclass + execute_stream() yielding thinking/tool_call/tool_result/final_answer
U10: POST /tasks/stream SSE endpoint + Client SDK stream_task()

15 new tests passing, no regression.
This commit is contained in:
chiguyong 2026-06-06 11:54:17 +08:00
parent ec0e221beb
commit 2844eeb548
9 changed files with 908 additions and 3 deletions

View File

@ -462,4 +462,4 @@ def register_geo_tools(registry: ToolRegistry) -> None:
tags=["knowledge", "deai"],
))
logger.info(f"GEO tools registered: {len(registry.list_all_tools())} tools")
logger.info(f"GEO tools registered: {len(registry.list_tools())} tools")

View File

@ -26,6 +26,7 @@ dependencies = [
server = [
"fastapi>=0.110",
"uvicorn>=0.27",
"sse-starlette>=2.0",
]
mcp = [
"mcp>=1.0",

View File

@ -8,6 +8,7 @@ import json
import logging
import re
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any
from agentkit.llm.gateway import LLMGateway
@ -39,6 +40,16 @@ class ReActResult:
total_tokens: int
@dataclass
class ReActEvent:
"""ReAct 执行事件"""
event_type: str # "thinking", "tool_call", "tool_result", "final_answer", "error"
step: int
data: dict[str, Any] = field(default_factory=dict)
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
class ReActEngine:
"""ReAct 推理-行动循环引擎
@ -186,6 +197,172 @@ class ReActEngine:
total_tokens=total_tokens,
)
async def execute_stream(
self,
messages: list[dict[str, str]],
tools: list[Tool] | None = None,
model: str = "default",
agent_name: str = "",
task_type: str = "",
system_prompt: str | None = None,
):
"""Execute ReAct loop, yielding ReActEvent objects.
Same logic as execute() but yields events at each step instead of
accumulating a result.
"""
tools = tools or []
tool_schemas = self._build_tool_schemas(tools) if tools else None
conversation: list[dict[str, Any]] = []
if system_prompt:
conversation.append({"role": "system", "content": system_prompt})
conversation.extend(messages)
trajectory: list[ReActStep] = []
total_tokens = 0
step = 0
output = ""
while step < self._max_steps:
step += 1
# Yield thinking event
yield ReActEvent(
event_type="thinking",
step=step,
data={"message": f"Step {step}: Calling LLM..."},
)
# Think: call LLM
response = await self._llm_gateway.chat(
messages=conversation,
model=model,
agent_name=agent_name,
task_type=task_type,
tools=tool_schemas,
)
step_tokens = response.usage.total_tokens
total_tokens += step_tokens
if response.has_tool_calls:
# Record assistant message
assistant_msg: dict[str, Any] = {
"role": "assistant",
"content": response.content or "",
"tool_calls": [
{
"id": tc.id,
"type": "function",
"function": {
"name": tc.name,
"arguments": json.dumps(tc.arguments),
},
}
for tc in response.tool_calls
],
}
conversation.append(assistant_msg)
for tc in response.tool_calls:
# Yield tool_call event
yield ReActEvent(
event_type="tool_call",
step=step,
data={"tool_name": tc.name, "arguments": tc.arguments},
)
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
react_step = ReActStep(
step=step,
action="tool_call",
tool_name=tc.name,
arguments=tc.arguments,
result=tool_result,
tokens=step_tokens,
)
trajectory.append(react_step)
# Yield tool_result event
yield ReActEvent(
event_type="tool_result",
step=step,
data={"tool_name": tc.name, "result": tool_result},
)
tool_msg = self._build_tool_result_message(tc.id, tool_result)
conversation.append(tool_msg)
else:
# Check text parsing mode
parsed_calls = self._parse_text_tool_calls(response.content or "")
if parsed_calls and tools:
conversation.append({"role": "assistant", "content": response.content})
for pc in parsed_calls:
yield ReActEvent(
event_type="tool_call",
step=step,
data={"tool_name": pc["name"], "arguments": pc["arguments"]},
)
tool_result = await self._execute_tool(pc["name"], pc["arguments"], tools)
trajectory.append(ReActStep(
step=step,
action="tool_call",
tool_name=pc["name"],
arguments=pc["arguments"],
result=tool_result,
tokens=step_tokens,
))
yield ReActEvent(
event_type="tool_result",
step=step,
data={"tool_name": pc["name"], "result": tool_result},
)
tool_msg = self._build_tool_result_message(
pc.get("id", f"text_tc_{step}"), tool_result
)
conversation.append(tool_msg)
else:
# Final answer
react_step = ReActStep(
step=step,
action="final_answer",
content=response.content,
tokens=step_tokens,
)
trajectory.append(react_step)
output = response.content or ""
yield ReActEvent(
event_type="final_answer",
step=step,
data={
"output": output,
"total_steps": len(trajectory),
"total_tokens": total_tokens,
},
)
break
if step >= self._max_steps and not output:
if trajectory and trajectory[-1].content:
output = trajectory[-1].content
elif trajectory and trajectory[-1].result is not None:
output = str(trajectory[-1].result)
else:
output = response.content or ""
yield ReActEvent(
event_type="final_answer",
step=step,
data={
"output": output,
"total_steps": len(trajectory),
"total_tokens": total_tokens,
"max_steps_reached": True,
},
)
def _build_tool_schemas(self, tools: list[Tool]) -> list[dict]:
"""将 Tool 对象转换为 OpenAI Function Calling schema 格式"""
schemas = []

View File

@ -5,7 +5,7 @@ import time
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
from agentkit.llm.config import LLMConfig
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, StreamChunk, TokenUsage
from agentkit.llm.providers.tracker import UsageSummary, UsageTracker
logger = logging.getLogger(__name__)
@ -97,6 +97,62 @@ class LLMGateway:
return response
async def chat_stream(
self,
messages: list[dict[str, str]],
model: str,
agent_name: str = "",
task_type: str = "",
tools: list[dict] | None = None,
tool_choice: str = "auto",
**kwargs,
):
"""Stream chat response, yielding StreamChunk objects"""
resolved_model = self._resolve_model_alias(model)
if not self._providers:
raise LLMProviderError("", "No provider registered")
try:
provider, actual_model = self._resolve_model(resolved_model)
except ModelNotFoundError as e:
raise LLMProviderError("", str(e)) from e
request = LLMRequest(
messages=messages,
model=actual_model,
tools=tools,
tool_choice=tool_choice,
**kwargs,
)
start = time.monotonic()
total_content = ""
final_usage = None
final_model = resolved_model
async for chunk in provider.chat_stream(request):
if chunk.content:
total_content += chunk.content
if chunk.usage:
final_usage = chunk.usage
if chunk.model:
final_model = chunk.model
yield chunk
# Track usage after stream completes
latency_ms = (time.monotonic() - start) * 1000
if final_usage is None:
final_usage = TokenUsage()
cost = self._calculate_cost(final_model, final_usage)
self._usage_tracker.record(
agent_name=agent_name,
model=final_model,
usage=final_usage,
cost=cost,
latency_ms=latency_ms,
)
def _resolve_model_alias(self, model: str) -> str:
"""解析模型别名"""
if model in self._config.model_aliases:

View File

@ -56,6 +56,17 @@ class LLMRequest:
self._extra = kwargs
@dataclass
class StreamChunk:
"""LLM 流式响应块"""
content: str # Delta content
model: str
tool_calls: list[ToolCall] = field(default_factory=list) # Accumulated tool calls (only in final chunk)
usage: TokenUsage | None = None # Only in final chunk
is_final: bool = False # True for the last chunk
@dataclass
class LLMResponse:
"""LLM 响应"""
@ -78,3 +89,18 @@ class LLMProvider(ABC):
async def chat(self, request: LLMRequest) -> LLMResponse:
"""发送 chat 请求并返回响应"""
...
async def chat_stream(self, request: LLMRequest):
"""Stream chat response. Override in subclasses that support streaming.
Yields StreamChunk objects. Default implementation falls back to
non-streaming chat and yields a single chunk.
"""
response = await self.chat(request)
yield StreamChunk(
content=response.content,
model=response.model,
tool_calls=response.tool_calls,
usage=response.usage,
is_final=True,
)

View File

@ -7,7 +7,7 @@ import time
import httpx
from agentkit.core.exceptions import LLMProviderError
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage, ToolCall
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, StreamChunk, TokenUsage, ToolCall
logger = logging.getLogger(__name__)
@ -100,3 +100,108 @@ class OpenAICompatibleProvider(LLMProvider):
tool_calls=tool_calls,
latency_ms=latency_ms,
)
async def chat_stream(self, request: LLMRequest):
"""Stream chat response using SSE"""
url = f"{self._base_url}/chat/completions"
headers = {
"Authorization": f"Bearer {self._api_key}",
"Content-Type": "application/json",
}
payload: dict = {
"model": request.model,
"messages": request.messages,
"temperature": request.temperature,
"max_tokens": request.max_tokens,
"stream": True,
"stream_options": {"include_usage": True},
}
if request.tools:
payload["tools"] = request.tools
payload["tool_choice"] = request.tool_choice
async with self._client.stream("POST", url, json=payload, headers=headers) as response:
if response.status_code != 200:
error_text = await response.aread()
raise LLMProviderError("openai", f"HTTP {response.status_code}")
accumulated_tool_calls: dict[int, dict] = {} # index -> {id, name, arguments_str}
async for line in response.aiter_lines():
line = line.strip()
if not line or not line.startswith("data: "):
continue
data_str = line[6:] # Remove "data: " prefix
if data_str == "[DONE]":
break
try:
data = json.loads(data_str)
except json.JSONDecodeError:
continue
choices = data.get("choices", [])
if not choices:
# Usage-only chunk
usage_data = data.get("usage")
if usage_data:
yield StreamChunk(
content="",
model=data.get("model", request.model),
usage=TokenUsage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
completion_tokens=usage_data.get("completion_tokens", 0),
),
is_final=True,
)
continue
delta = choices[0].get("delta", {})
content = delta.get("content", "")
# Accumulate tool calls from streaming
raw_tool_calls = delta.get("tool_calls")
if raw_tool_calls:
for tc in raw_tool_calls:
idx = tc.get("index", 0)
if idx not in accumulated_tool_calls:
accumulated_tool_calls[idx] = {
"id": tc.get("id", ""),
"name": "",
"arguments_str": "",
}
if tc.get("id"):
accumulated_tool_calls[idx]["id"] = tc["id"]
func = tc.get("function", {})
if func.get("name"):
accumulated_tool_calls[idx]["name"] = func["name"]
if func.get("arguments"):
accumulated_tool_calls[idx]["arguments_str"] += func["arguments"]
# Only yield content chunks (not empty deltas)
if content:
yield StreamChunk(
content=content,
model=data.get("model", request.model),
)
# If we accumulated tool calls, yield them as a final chunk
if accumulated_tool_calls:
tool_calls = []
for idx in sorted(accumulated_tool_calls.keys()):
tc_data = accumulated_tool_calls[idx]
try:
arguments = json.loads(tc_data["arguments_str"]) if tc_data["arguments_str"] else {}
except json.JSONDecodeError:
arguments = {"raw": tc_data["arguments_str"]}
tool_calls.append(ToolCall(
id=tc_data["id"],
name=tc_data["name"],
arguments=arguments,
))
yield StreamChunk(
content="",
model=request.model,
tool_calls=tool_calls,
is_final=True,
)

View File

@ -126,6 +126,38 @@ class AgentKitClient:
response.raise_for_status()
return response.json()
async def stream_task(
self,
input_data: dict,
skill_name: str | None = None,
agent_name: str | None = None,
):
"""Stream task execution events via SSE.
Yields event dicts with 'event' and 'data' keys.
"""
payload: dict[str, Any] = {"input_data": input_data}
if skill_name:
payload["skill_name"] = skill_name
if agent_name:
payload["agent_name"] = agent_name
async with self._client.stream(
"POST", "/api/v1/tasks/stream", json=payload
) as response:
response.raise_for_status()
event_type = ""
async for line in response.aiter_lines():
line = line.strip()
if not line:
continue
if line.startswith("event: "):
event_type = line[7:]
elif line.startswith("data: "):
import json as _json
data = _json.loads(line[6:])
yield {"event": event_type, "data": data}
async def close(self) -> None:
"""Close the HTTP client"""
await self._client.aclose()

View File

@ -1,5 +1,6 @@
"""Task submission routes"""
import json
import uuid
from datetime import datetime, timezone
@ -191,3 +192,79 @@ async def cancel_task(task_id: str, req: Request):
if not cancelled:
raise HTTPException(status_code=400, detail="Task cannot be cancelled (not running or not found)")
return {"task_id": task_id, "status": "cancelled"}
@router.post("/tasks/stream")
async def stream_task(request: SubmitTaskRequest, req: Request):
"""Submit a task and stream ReAct events via SSE"""
from sse_starlette.sse import EventSourceResponse
pool = req.app.state.agent_pool
skill_registry = req.app.state.skill_registry
intent_router = req.app.state.intent_router
agent = None
# Same agent resolution logic as submit_task
if request.agent_name:
agent = pool.get_agent(request.agent_name)
if agent is None:
raise HTTPException(
status_code=404,
detail=f"Agent '{request.agent_name}' not found",
)
elif request.skill_name:
try:
skill_registry.get(request.skill_name)
except Exception:
raise HTTPException(
status_code=404,
detail=f"Skill '{request.skill_name}' not found",
)
agent = pool.get_agent(request.skill_name)
if agent is None:
agent = await pool.create_agent_from_skill(request.skill_name)
else:
all_skills = skill_registry.list_skills()
if not all_skills:
raise HTTPException(
status_code=400,
detail="No skills registered and no skill_name or agent_name specified",
)
try:
routing_result = await intent_router.route(request.input_data, all_skills)
skill_registry.get(routing_result.matched_skill)
agent = pool.get_agent(routing_result.matched_skill)
if agent is None:
agent = await pool.create_agent_from_skill(routing_result.matched_skill)
except (ValueError, RuntimeError) as e:
raise HTTPException(status_code=400, detail=str(e))
async def event_generator():
from agentkit.core.react import ReActEngine
react_engine = ReActEngine(llm_gateway=req.app.state.llm_gateway)
# Build messages from input
messages = [{"role": "user", "content": str(request.input_data)}]
# Get tools from agent
tools = list(agent._tool_registry._tools.values()) if agent._tool_registry else []
async for event in react_engine.execute_stream(
messages=messages,
tools=tools,
model=agent._llm_model if hasattr(agent, "_llm_model") else "default",
agent_name=agent.name,
system_prompt=agent._system_prompt if hasattr(agent, "_system_prompt") else None,
):
yield {
"event": event.event_type,
"data": json.dumps({
"step": event.step,
"data": event.data,
"timestamp": event.timestamp,
}),
}
return EventSourceResponse(event_generator())

View File

@ -0,0 +1,431 @@
"""Streaming System 单元测试 - U8/U9/U10
覆盖:
- StreamChunk 数据类
- LLMProvider.chat_stream 默认回退
- Gateway.chat_stream 流式 + 用量追踪
- ReActEvent 数据类
- ReActEngine.execute_stream 事件流
- SSE 端点 /tasks/stream
"""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agentkit.llm.protocol import LLMRequest, LLMResponse, TokenUsage, ToolCall
from agentkit.tools.base import Tool
# ── Test Helpers ──────────────────────────────────────────
class FakeTool(Tool):
"""用于测试的 Fake Tool"""
def __init__(
self,
name: str = "fake_tool",
description: str = "A fake tool for testing",
result: dict | None = None,
):
super().__init__(name=name, description=description)
self._result = result or {"status": "ok"}
async def execute(self, **kwargs) -> dict:
return self._result
def make_response(
content: str = "",
tool_calls: list[ToolCall] | None = None,
prompt_tokens: int = 10,
completion_tokens: int = 20,
) -> LLMResponse:
"""快速构造 LLMResponse"""
return LLMResponse(
content=content,
model="test-model",
usage=TokenUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
),
tool_calls=tool_calls or [],
)
# ══════════════════════════════════════════════════════════
# U8: StreamChunk + chat_stream
# ══════════════════════════════════════════════════════════
class TestStreamChunk:
"""StreamChunk 数据类测试"""
def test_creation_with_content(self):
from agentkit.llm.protocol import StreamChunk
chunk = StreamChunk(content="Hello", model="gpt-4o")
assert chunk.content == "Hello"
assert chunk.model == "gpt-4o"
assert chunk.tool_calls == []
assert chunk.usage is None
assert chunk.is_final is False
def test_with_tool_calls_and_is_final(self):
from agentkit.llm.protocol import StreamChunk
tc = ToolCall(id="tc_1", name="search", arguments={"q": "test"})
chunk = StreamChunk(
content="",
model="gpt-4o",
tool_calls=[tc],
is_final=True,
)
assert len(chunk.tool_calls) == 1
assert chunk.tool_calls[0].name == "search"
assert chunk.is_final is True
def test_with_usage(self):
from agentkit.llm.protocol import StreamChunk
usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
chunk = StreamChunk(
content="",
model="gpt-4o",
usage=usage,
is_final=True,
)
assert chunk.usage is not None
assert chunk.usage.total_tokens == 150
assert chunk.is_final is True
class TestLLMProviderChatStreamDefault:
"""LLMProvider.chat_stream 默认实现回退到 chat()"""
async def test_default_chat_stream_yields_single_chunk(self):
from agentkit.llm.protocol import LLMProvider, StreamChunk
class SimpleProvider(LLMProvider):
async def chat(self, request: LLMRequest) -> LLMResponse:
return LLMResponse(
content="hello",
model="test",
usage=TokenUsage(prompt_tokens=5, completion_tokens=10),
)
provider = SimpleProvider()
request = LLMRequest(
messages=[{"role": "user", "content": "hi"}],
model="test",
)
chunks = []
async for chunk in provider.chat_stream(request):
chunks.append(chunk)
assert len(chunks) == 1
assert chunks[0].content == "hello"
assert chunks[0].is_final is True
assert chunks[0].usage.total_tokens == 15
class TestGatewayChatStream:
"""Gateway.chat_stream 流式测试"""
async def test_yields_chunks_from_provider(self):
from agentkit.llm.protocol import StreamChunk
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMProvider
class StreamingProvider(LLMProvider):
async def chat(self, request: LLMRequest) -> LLMResponse:
return LLMResponse(
content="fallback",
model="test",
usage=TokenUsage(),
)
async def chat_stream(self, request: LLMRequest):
yield StreamChunk(content="Hello ", model="test")
yield StreamChunk(content="World", model="test")
yield StreamChunk(
content="",
model="test",
usage=TokenUsage(prompt_tokens=10, completion_tokens=5),
is_final=True,
)
gateway = LLMGateway()
gateway.register_provider("test", StreamingProvider())
chunks = []
async for chunk in gateway.chat_stream(
messages=[{"role": "user", "content": "hi"}],
model="test/model",
):
chunks.append(chunk)
assert len(chunks) == 3
assert chunks[0].content == "Hello "
assert chunks[1].content == "World"
assert chunks[2].is_final is True
async def test_tracks_usage_after_stream_completes(self):
from agentkit.llm.protocol import StreamChunk
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMProvider
class StreamingProvider(LLMProvider):
async def chat(self, request: LLMRequest) -> LLMResponse:
return LLMResponse(
content="fallback",
model="test",
usage=TokenUsage(),
)
async def chat_stream(self, request: LLMRequest):
yield StreamChunk(content="Hi", model="test")
yield StreamChunk(
content="",
model="test",
usage=TokenUsage(prompt_tokens=100, completion_tokens=50),
is_final=True,
)
gateway = LLMGateway()
gateway.register_provider("test", StreamingProvider())
# Consume the stream
chunks = []
async for chunk in gateway.chat_stream(
messages=[{"role": "user", "content": "hi"}],
model="test/model",
agent_name="stream_agent",
):
chunks.append(chunk)
# Verify usage was tracked
usage = gateway.get_usage()
assert usage.total_tokens == 150
# ══════════════════════════════════════════════════════════
# U9: ReActEvent + execute_stream
# ══════════════════════════════════════════════════════════
class TestReActEvent:
"""ReActEvent 数据类测试"""
def test_creation_with_event_type_and_step(self):
from agentkit.core.react import ReActEvent
event = ReActEvent(event_type="thinking", step=1)
assert event.event_type == "thinking"
assert event.step == 1
assert event.data == {}
def test_has_timestamp(self):
from agentkit.core.react import ReActEvent
event = ReActEvent(event_type="thinking", step=1)
assert event.timestamp is not None
assert len(event.timestamp) > 0
def test_with_data(self):
from agentkit.core.react import ReActEvent
event = ReActEvent(
event_type="tool_call",
step=2,
data={"tool_name": "search", "arguments": {"q": "test"}},
)
assert event.data["tool_name"] == "search"
class TestReActEngineExecuteStream:
"""ReActEngine.execute_stream 事件流测试"""
async def test_yields_thinking_event_at_each_step(self):
from agentkit.core.react import ReActEngine, ReActEvent
gateway = MagicMock()
gateway.chat = AsyncMock(return_value=make_response(content="Final answer"))
engine = ReActEngine(llm_gateway=gateway)
events = []
async for event in engine.execute_stream(
messages=[{"role": "user", "content": "Hello"}],
):
events.append(event)
# Should have thinking + final_answer
thinking_events = [e for e in events if e.event_type == "thinking"]
assert len(thinking_events) >= 1
assert thinking_events[0].step == 1
async def test_yields_tool_call_and_tool_result_events(self):
from agentkit.core.react import ReActEngine, ReActEvent
tool = FakeTool(name="calculator", result={"value": 42})
gateway = MagicMock()
gateway.chat = AsyncMock(side_effect=[
make_response(
content="",
tool_calls=[ToolCall(id="tc_1", name="calculator", arguments={"expr": "6*7"})],
),
make_response(content="The result is 42"),
])
engine = ReActEngine(llm_gateway=gateway)
events = []
async for event in engine.execute_stream(
messages=[{"role": "user", "content": "Calculate"}],
tools=[tool],
):
events.append(event)
tool_call_events = [e for e in events if e.event_type == "tool_call"]
tool_result_events = [e for e in events if e.event_type == "tool_result"]
assert len(tool_call_events) == 1
assert tool_call_events[0].data["tool_name"] == "calculator"
assert len(tool_result_events) == 1
assert tool_result_events[0].data["tool_name"] == "calculator"
assert tool_result_events[0].data["result"] == {"value": 42}
async def test_yields_final_answer_event(self):
from agentkit.core.react import ReActEngine, ReActEvent
gateway = MagicMock()
gateway.chat = AsyncMock(return_value=make_response(content="The answer is 42"))
engine = ReActEngine(llm_gateway=gateway)
events = []
async for event in engine.execute_stream(
messages=[{"role": "user", "content": "What is the answer?"}],
):
events.append(event)
final_events = [e for e in events if e.event_type == "final_answer"]
assert len(final_events) == 1
assert final_events[0].data["output"] == "The answer is 42"
assert final_events[0].data["total_steps"] >= 1
assert final_events[0].data["total_tokens"] > 0
async def test_yields_max_steps_reached_when_hitting_limit(self):
from agentkit.core.react import ReActEngine, ReActEvent
tool = FakeTool(name="search", result={"results": ["data"]})
always_tool_response = make_response(
content="Thinking...",
tool_calls=[ToolCall(id="tc_loop", name="search", arguments={"query": "more"})],
)
gateway = MagicMock()
gateway.chat = AsyncMock(return_value=always_tool_response)
engine = ReActEngine(llm_gateway=gateway, max_steps=3)
events = []
async for event in engine.execute_stream(
messages=[{"role": "user", "content": "Keep searching"}],
tools=[tool],
):
events.append(event)
final_events = [e for e in events if e.event_type == "final_answer"]
assert len(final_events) == 1
assert final_events[0].data.get("max_steps_reached") is True
# ══════════════════════════════════════════════════════════
# U10: SSE Endpoint + Client SDK
# ══════════════════════════════════════════════════════════
class TestSSEEndpoint:
"""SSE /tasks/stream 端点测试"""
def test_stream_task_returns_event_source_response(self):
from fastapi.testclient import TestClient
from agentkit.server.app import create_app
from agentkit.llm.gateway import LLMGateway
from agentkit.skills.registry import SkillRegistry
from agentkit.tools.registry import ToolRegistry
gateway = LLMGateway()
mock_provider = AsyncMock()
mock_provider.chat.return_value = LLMResponse(
content="Final answer",
model="test-model",
usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
)
gateway.register_provider("test", mock_provider)
skill_registry = SkillRegistry()
tool_registry = ToolRegistry()
app = create_app(
llm_gateway=gateway,
skill_registry=skill_registry,
tool_registry=tool_registry,
)
client = TestClient(app)
# Create an agent first
client.post(
"/api/v1/agents",
json={
"config": {
"name": "stream_agent",
"agent_type": "test",
"task_mode": "llm_generate",
"prompt": {"identity": "Stream Agent"},
}
},
)
# Stream task
response = client.post(
"/api/v1/tasks/stream",
json={
"input_data": {"query": "test"},
"agent_name": "stream_agent",
},
)
# Should return 200 with SSE content type
assert response.status_code == 200
assert "text/event-stream" in response.headers.get("content-type", "")
def test_stream_task_with_invalid_agent_returns_404(self):
from fastapi.testclient import TestClient
from agentkit.server.app import create_app
from agentkit.llm.gateway import LLMGateway
from agentkit.skills.registry import SkillRegistry
from agentkit.tools.registry import ToolRegistry
gateway = LLMGateway()
skill_registry = SkillRegistry()
tool_registry = ToolRegistry()
app = create_app(
llm_gateway=gateway,
skill_registry=skill_registry,
tool_registry=tool_registry,
)
client = TestClient(app)
response = client.post(
"/api/v1/tasks/stream",
json={
"input_data": {"query": "test"},
"agent_name": "nonexistent_agent",
},
)
assert response.status_code == 404