feat(streaming): Phase C - LLM streaming + ReAct events + SSE endpoint
U8: StreamChunk protocol + OpenAI chat_stream + Gateway streaming with usage tracking. U9: ReActEvent dataclass + execute_stream() yielding thinking/tool_call/tool_result/final_answer events. U10: POST /tasks/stream SSE endpoint + Client SDK stream_task(). 15 new tests passing, no regressions.
This commit is contained in:
parent
ec0e221beb
commit
2844eeb548
|
|
@ -462,4 +462,4 @@ def register_geo_tools(registry: ToolRegistry) -> None:
|
|||
tags=["knowledge", "deai"],
|
||||
))
|
||||
|
||||
logger.info(f"GEO tools registered: {len(registry.list_all_tools())} tools")
|
||||
logger.info(f"GEO tools registered: {len(registry.list_tools())} tools")
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ dependencies = [
|
|||
server = [
|
||||
"fastapi>=0.110",
|
||||
"uvicorn>=0.27",
|
||||
"sse-starlette>=2.0",
|
||||
]
|
||||
mcp = [
|
||||
"mcp>=1.0",
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import json
|
|||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
|
|
@ -39,6 +40,16 @@ class ReActResult:
|
|||
total_tokens: int
|
||||
|
||||
|
||||
def _utc_timestamp() -> str:
    """Return the current UTC time formatted as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


@dataclass
class ReActEvent:
    """A single event emitted while a ReAct loop is executing."""

    # One of: "thinking", "tool_call", "tool_result", "final_answer", "error"
    event_type: str
    # Number of the ReAct step this event belongs to
    step: int
    # Event payload; its keys depend on event_type
    data: dict[str, Any] = field(default_factory=dict)
    # Creation time (UTC, ISO-8601), filled in automatically when omitted
    timestamp: str = field(default_factory=_utc_timestamp)
|
||||
|
||||
|
||||
class ReActEngine:
|
||||
"""ReAct 推理-行动循环引擎
|
||||
|
||||
|
|
@ -186,6 +197,172 @@ class ReActEngine:
|
|||
total_tokens=total_tokens,
|
||||
)
|
||||
|
||||
async def execute_stream(
    self,
    messages: list[dict[str, str]],
    tools: list[Tool] | None = None,
    model: str = "default",
    agent_name: str = "",
    task_type: str = "",
    system_prompt: str | None = None,
):
    """Run the ReAct loop as an async generator of ReActEvent objects.

    Each step yields a "thinking" event, calls the LLM, and then either
    runs the requested tools (yielding a "tool_call" and a "tool_result"
    event per tool) or finishes with a single "final_answer" event.
    Tool calls are taken from the structured response when present,
    otherwise from tool calls parsed out of the response text.

    Args:
        messages: Conversation to start from.
        tools: Tools the model may call; None or empty disables tool use.
        model: Model name passed to the LLM gateway.
        agent_name: Forwarded to the gateway with every call.
        task_type: Forwarded to the gateway with every call.
        system_prompt: If given, prepended as a system message.

    Yields:
        ReActEvent: exactly one "final_answer" event is the last event.
        When the step budget runs out first, its data carries
        ``"max_steps_reached": True``.
    """
    tools = tools or []
    tool_schemas = self._build_tool_schemas(tools) if tools else None

    conversation: list[dict[str, Any]] = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    conversation.extend(messages)

    trajectory: list[ReActStep] = []
    total_tokens = 0
    step = 0
    # Stays None when the loop body never runs (max_steps <= 0); the
    # fallback below must not assume an LLM response exists.
    response = None

    while step < self._max_steps:
        step += 1

        yield ReActEvent(
            event_type="thinking",
            step=step,
            data={"message": f"Step {step}: Calling LLM..."},
        )

        # Think: call the LLM
        response = await self._llm_gateway.chat(
            messages=conversation,
            model=model,
            agent_name=agent_name,
            task_type=task_type,
            tools=tool_schemas,
        )

        step_tokens = response.usage.total_tokens
        total_tokens += step_tokens

        # Both tool-calling modes are normalised to (call_id, name, arguments)
        # so that a single loop below executes them.
        pending_calls: list[tuple[str, str, Any]] = []

        if response.has_tool_calls:
            # Native function calling: echo the assistant turn with its calls
            conversation.append(
                {
                    "role": "assistant",
                    "content": response.content or "",
                    "tool_calls": [
                        {
                            "id": tc.id,
                            "type": "function",
                            "function": {
                                "name": tc.name,
                                "arguments": json.dumps(tc.arguments),
                            },
                        }
                        for tc in response.tool_calls
                    ],
                }
            )
            pending_calls = [(tc.id, tc.name, tc.arguments) for tc in response.tool_calls]
        else:
            # Text mode: the model may have written tool calls into its reply
            parsed_calls = self._parse_text_tool_calls(response.content or "")
            if parsed_calls and tools:
                conversation.append({"role": "assistant", "content": response.content})
                pending_calls = [
                    (pc.get("id", f"text_tc_{step}"), pc["name"], pc["arguments"])
                    for pc in parsed_calls
                ]
            else:
                # No tool calls at all: this response is the final answer
                output = response.content or ""
                trajectory.append(
                    ReActStep(
                        step=step,
                        action="final_answer",
                        content=response.content,
                        tokens=step_tokens,
                    )
                )
                yield ReActEvent(
                    event_type="final_answer",
                    step=step,
                    data={
                        "output": output,
                        "total_steps": len(trajectory),
                        "total_tokens": total_tokens,
                    },
                )
                # Return instead of falling through, so an empty final answer
                # on the last allowed step is not followed by a second,
                # "max_steps_reached" final_answer event.
                return

        # Act: run every requested tool and feed the result back to the model
        for call_id, tool_name, arguments in pending_calls:
            yield ReActEvent(
                event_type="tool_call",
                step=step,
                data={"tool_name": tool_name, "arguments": arguments},
            )

            tool_result = await self._execute_tool(tool_name, arguments, tools)
            trajectory.append(
                ReActStep(
                    step=step,
                    action="tool_call",
                    tool_name=tool_name,
                    arguments=arguments,
                    result=tool_result,
                    tokens=step_tokens,
                )
            )

            yield ReActEvent(
                event_type="tool_result",
                step=step,
                data={"tool_name": tool_name, "result": tool_result},
            )

            conversation.append(self._build_tool_result_message(call_id, tool_result))

    # Only reached when the step budget ran out without a final answer:
    # fall back to the most recent content, tool result, or LLM reply.
    if trajectory and trajectory[-1].content:
        output = trajectory[-1].content
    elif trajectory and trajectory[-1].result is not None:
        output = str(trajectory[-1].result)
    elif response is not None:
        output = response.content or ""
    else:
        output = ""

    yield ReActEvent(
        event_type="final_answer",
        step=step,
        data={
            "output": output,
            "total_steps": len(trajectory),
            "total_tokens": total_tokens,
            "max_steps_reached": True,
        },
    )
|
||||
|
||||
def _build_tool_schemas(self, tools: list[Tool]) -> list[dict]:
|
||||
"""将 Tool 对象转换为 OpenAI Function Calling schema 格式"""
|
||||
schemas = []
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import time
|
|||
|
||||
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
|
||||
from agentkit.llm.config import LLMConfig
|
||||
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage
|
||||
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, StreamChunk, TokenUsage
|
||||
from agentkit.llm.providers.tracker import UsageSummary, UsageTracker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -97,6 +97,62 @@ class LLMGateway:
|
|||
|
||||
return response
|
||||
|
||||
async def chat_stream(
    self,
    messages: list[dict[str, str]],
    model: str,
    agent_name: str = "",
    task_type: str = "",
    tools: list[dict] | None = None,
    tool_choice: str = "auto",
    **kwargs,
):
    """Stream a chat completion, yielding StreamChunk objects as they arrive.

    The model name is first passed through the alias table and then
    resolved to a registered provider. Chunks from the provider are
    yielded unchanged. After the provider's stream is exhausted, the last
    usage and model reported by any chunk are recorded in the usage
    tracker together with the computed cost and the elapsed time.

    Note that usage is recorded only once the stream has been consumed to
    the end; a consumer that stops iterating early records nothing.

    Args:
        messages: Chat messages to send.
        model: Model name or alias.
        agent_name: Name the usage record is attributed to.
        task_type: Currently not used by this method.
        tools: Optional tool schemas forwarded to the provider.
        tool_choice: Tool choice mode forwarded to the provider.
        **kwargs: Extra request options passed to LLMRequest.

    Raises:
        LLMProviderError: If no provider is registered or the model
            cannot be resolved to a provider.
    """
    resolved_model = self._resolve_model_alias(model)

    if not self._providers:
        raise LLMProviderError("", "No provider registered")

    try:
        provider, actual_model = self._resolve_model(resolved_model)
    except ModelNotFoundError as e:
        raise LLMProviderError("", str(e)) from e

    request = LLMRequest(
        messages=messages,
        model=actual_model,
        tools=tools,
        tool_choice=tool_choice,
        **kwargs,
    )

    start = time.monotonic()
    final_usage = None
    final_model = resolved_model

    async for chunk in provider.chat_stream(request):
        # Remember the most recent usage/model any chunk reported
        if chunk.usage:
            final_usage = chunk.usage
        if chunk.model:
            final_model = chunk.model
        yield chunk

    # Track usage after the stream completes
    latency_ms = (time.monotonic() - start) * 1000
    if final_usage is None:
        # Provider reported no usage at all: record an empty usage entry
        final_usage = TokenUsage()
    cost = self._calculate_cost(final_model, final_usage)
    self._usage_tracker.record(
        agent_name=agent_name,
        model=final_model,
        usage=final_usage,
        cost=cost,
        latency_ms=latency_ms,
    )
|
||||
|
||||
def _resolve_model_alias(self, model: str) -> str:
|
||||
"""解析模型别名"""
|
||||
if model in self._config.model_aliases:
|
||||
|
|
|
|||
|
|
@ -56,6 +56,17 @@ class LLMRequest:
|
|||
self._extra = kwargs
|
||||
|
||||
|
||||
@dataclass
class StreamChunk:
    """One chunk of a streamed LLM response."""

    # Delta content carried by this chunk
    content: str
    # Model name reported for this chunk
    model: str
    # Accumulated tool calls (only in final chunk)
    tool_calls: list[ToolCall] = field(default_factory=list)
    # Token usage (only in final chunk)
    usage: TokenUsage | None = None
    # True for the last chunk
    is_final: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""LLM 响应"""
|
||||
|
|
@ -78,3 +89,18 @@ class LLMProvider(ABC):
|
|||
async def chat(self, request: LLMRequest) -> LLMResponse:
|
||||
"""发送 chat 请求并返回响应"""
|
||||
...
|
||||
|
||||
async def chat_stream(self, request: LLMRequest):
    """Stream a chat response as StreamChunk objects.

    Fallback for providers without native streaming: performs one regular
    chat() call and emits the complete result as a single final chunk.
    Subclasses that support streaming should override this method.
    """
    full_response = await self.chat(request)
    only_chunk = StreamChunk(
        is_final=True,
        usage=full_response.usage,
        tool_calls=full_response.tool_calls,
        model=full_response.model,
        content=full_response.content,
    )
    yield only_chunk
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import time
|
|||
import httpx
|
||||
|
||||
from agentkit.core.exceptions import LLMProviderError
|
||||
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage, ToolCall
|
||||
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, StreamChunk, TokenUsage, ToolCall
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -100,3 +100,108 @@ class OpenAICompatibleProvider(LLMProvider):
|
|||
tool_calls=tool_calls,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
|
||||
async def chat_stream(self, request: LLMRequest):
    """Stream a chat completion from an OpenAI-compatible endpoint via SSE.

    The request is sent with ``stream=True`` and
    ``stream_options={"include_usage": True}``; the response is parsed
    line by line. Tool-call fragments are accumulated per index across
    deltas and decoded once the stream ends.

    Yields:
        StreamChunk: one non-final chunk per non-empty content delta,
        followed by exactly one final chunk (``is_final=True``) that
        carries the accumulated tool calls and the reported token usage
        (either may be empty/None if the server sent none).

    Raises:
        LLMProviderError: If the endpoint responds with a non-200 status.
    """
    url = f"{self._base_url}/chat/completions"
    headers = {
        "Authorization": f"Bearer {self._api_key}",
        "Content-Type": "application/json",
    }
    payload: dict = {
        "model": request.model,
        "messages": request.messages,
        "temperature": request.temperature,
        "max_tokens": request.max_tokens,
        "stream": True,
        "stream_options": {"include_usage": True},
    }
    if request.tools:
        payload["tools"] = request.tools
        payload["tool_choice"] = request.tool_choice

    async with self._client.stream("POST", url, json=payload, headers=headers) as response:
        if response.status_code != 200:
            # Surface (a bounded part of) the server's explanation instead
            # of reading the body and throwing it away.
            error_text = (await response.aread()).decode("utf-8", errors="replace")
            raise LLMProviderError(
                "openai", f"HTTP {response.status_code}: {error_text[:500]}"
            )

        accumulated_tool_calls: dict[int, dict] = {}  # index -> {id, name, arguments_str}
        final_usage = None
        last_model = request.model

        async for line in response.aiter_lines():
            line = line.strip()
            # Only "data" fields matter; the space after the colon is optional in SSE
            if not line.startswith("data:"):
                continue
            data_str = line[len("data:"):].strip()
            if data_str == "[DONE]":
                break

            try:
                data = json.loads(data_str)
            except json.JSONDecodeError:
                # Skip malformed or empty payloads rather than abort the stream
                continue
            if not isinstance(data, dict):
                continue

            last_model = data.get("model") or last_model

            # Usage normally arrives in a trailing chunk without choices, but
            # some compatible servers attach it to a chunk that has choices.
            usage_data = data.get("usage")
            if usage_data:
                final_usage = TokenUsage(
                    prompt_tokens=usage_data.get("prompt_tokens", 0),
                    completion_tokens=usage_data.get("completion_tokens", 0),
                )

            choices = data.get("choices") or []
            if not choices:
                continue

            delta = choices[0].get("delta") or {}

            # Accumulate tool-call fragments; id and name come once, the
            # arguments JSON arrives split over several deltas.
            for tc in delta.get("tool_calls") or []:
                idx = tc.get("index", 0)
                entry = accumulated_tool_calls.setdefault(
                    idx, {"id": "", "name": "", "arguments_str": ""}
                )
                if tc.get("id"):
                    entry["id"] = tc["id"]
                func = tc.get("function") or {}
                if func.get("name"):
                    entry["name"] = func["name"]
                if func.get("arguments"):
                    entry["arguments_str"] += func["arguments"]

            # Only yield content chunks (not empty deltas)
            content = delta.get("content")
            if content:
                yield StreamChunk(content=content, model=last_model)

        tool_calls = []
        for idx in sorted(accumulated_tool_calls):
            tc_data = accumulated_tool_calls[idx]
            try:
                arguments = json.loads(tc_data["arguments_str"]) if tc_data["arguments_str"] else {}
            except json.JSONDecodeError:
                # Keep undecodable arguments available to the caller
                arguments = {"raw": tc_data["arguments_str"]}
            tool_calls.append(
                ToolCall(
                    id=tc_data["id"],
                    name=tc_data["name"],
                    arguments=arguments,
                )
            )

        # A single final chunk carries both tool calls and usage, so a
        # consumer that stops at the first is_final chunk misses neither.
        yield StreamChunk(
            content="",
            model=last_model,
            tool_calls=tool_calls,
            usage=final_usage,
            is_final=True,
        )
|
||||
|
|
|
|||
|
|
@ -126,6 +126,38 @@ class AgentKitClient:
|
|||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def stream_task(
|
||||
self,
|
||||
input_data: dict,
|
||||
skill_name: str | None = None,
|
||||
agent_name: str | None = None,
|
||||
):
|
||||
"""Stream task execution events via SSE.
|
||||
|
||||
Yields event dicts with 'event' and 'data' keys.
|
||||
"""
|
||||
payload: dict[str, Any] = {"input_data": input_data}
|
||||
if skill_name:
|
||||
payload["skill_name"] = skill_name
|
||||
if agent_name:
|
||||
payload["agent_name"] = agent_name
|
||||
|
||||
async with self._client.stream(
|
||||
"POST", "/api/v1/tasks/stream", json=payload
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
event_type = ""
|
||||
async for line in response.aiter_lines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
if line.startswith("event: "):
|
||||
event_type = line[7:]
|
||||
elif line.startswith("data: "):
|
||||
import json as _json
|
||||
data = _json.loads(line[6:])
|
||||
yield {"event": event_type, "data": data}
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the HTTP client"""
|
||||
await self._client.aclose()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Task submission routes"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
|
@ -191,3 +192,79 @@ async def cancel_task(task_id: str, req: Request):
|
|||
if not cancelled:
|
||||
raise HTTPException(status_code=400, detail="Task cannot be cancelled (not running or not found)")
|
||||
return {"task_id": task_id, "status": "cancelled"}
|
||||
|
||||
|
||||
@router.post("/tasks/stream")
async def stream_task(request: SubmitTaskRequest, req: Request):
    """Submit a task and stream its ReAct events to the client via SSE.

    The agent is resolved in this order: explicit ``agent_name``, explicit
    ``skill_name`` (creating the agent from the skill if the pool has
    none), otherwise the intent router picks a skill from all registered
    skills. Resolution failures are reported as HTTP errors before the
    stream starts.

    Each ReAct event is sent as one SSE event whose name is the event type
    and whose data is a JSON object with ``step``, ``data`` and
    ``timestamp``. If execution fails after streaming has begun, a final
    ``error`` event with the same shape is sent instead of dropping the
    connection silently.

    Raises:
        HTTPException: 404 if the named agent or skill does not exist;
            400 if no skill is registered or routing fails.
    """
    from sse_starlette.sse import EventSourceResponse

    pool = req.app.state.agent_pool
    skill_registry = req.app.state.skill_registry
    intent_router = req.app.state.intent_router

    agent = None

    # Same agent resolution logic as submit_task
    if request.agent_name:
        agent = pool.get_agent(request.agent_name)
        if agent is None:
            raise HTTPException(
                status_code=404,
                detail=f"Agent '{request.agent_name}' not found",
            )
    elif request.skill_name:
        try:
            skill_registry.get(request.skill_name)
        except Exception as exc:
            raise HTTPException(
                status_code=404,
                detail=f"Skill '{request.skill_name}' not found",
            ) from exc
        agent = pool.get_agent(request.skill_name)
        if agent is None:
            agent = await pool.create_agent_from_skill(request.skill_name)
    else:
        all_skills = skill_registry.list_skills()
        if not all_skills:
            raise HTTPException(
                status_code=400,
                detail="No skills registered and no skill_name or agent_name specified",
            )
        try:
            routing_result = await intent_router.route(request.input_data, all_skills)
            skill_registry.get(routing_result.matched_skill)
            agent = pool.get_agent(routing_result.matched_skill)
            if agent is None:
                agent = await pool.create_agent_from_skill(routing_result.matched_skill)
        except (ValueError, RuntimeError) as e:
            raise HTTPException(status_code=400, detail=str(e)) from e

    async def event_generator():
        from agentkit.core.react import ReActEngine

        react_engine = ReActEngine(llm_gateway=req.app.state.llm_gateway)

        # Build messages from input
        messages = [{"role": "user", "content": str(request.input_data)}]

        # Get tools from the agent; guarded like the other private
        # attributes so an agent without a tool registry still streams.
        tool_registry = getattr(agent, "_tool_registry", None)
        tools = list(tool_registry._tools.values()) if tool_registry else []

        last_step = 0
        try:
            async for event in react_engine.execute_stream(
                messages=messages,
                tools=tools,
                model=getattr(agent, "_llm_model", "default"),
                agent_name=agent.name,
                system_prompt=getattr(agent, "_system_prompt", None),
            ):
                last_step = event.step
                yield {
                    "event": event.event_type,
                    # default=str: a tool result that is not JSON-serialisable
                    # is stringified instead of aborting the stream
                    "data": json.dumps(
                        {
                            "step": event.step,
                            "data": event.data,
                            "timestamp": event.timestamp,
                        },
                        default=str,
                    ),
                }
        except Exception as exc:
            # Top-level boundary of the stream: the HTTP status has already
            # been sent, so the failure is reported as an "error" event.
            yield {
                "event": "error",
                "data": json.dumps(
                    {
                        "step": last_step,
                        "data": {"error": str(exc), "error_type": type(exc).__name__},
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    }
                ),
            }

    return EventSourceResponse(event_generator())
|
||||
|
|
|
|||
|
|
@ -0,0 +1,431 @@
|
|||
"""Streaming System 单元测试 - U8/U9/U10
|
||||
|
||||
覆盖:
|
||||
- StreamChunk 数据类
|
||||
- LLMProvider.chat_stream 默认回退
|
||||
- Gateway.chat_stream 流式 + 用量追踪
|
||||
- ReActEvent 数据类
|
||||
- ReActEngine.execute_stream 事件流
|
||||
- SSE 端点 /tasks/stream
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.llm.protocol import LLMRequest, LLMResponse, TokenUsage, ToolCall
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
|
||||
# ── Test Helpers ──────────────────────────────────────────
|
||||
|
||||
|
||||
class FakeTool(Tool):
    """Tool stub for tests: execute() always returns a preset result."""

    def __init__(
        self,
        name: str = "fake_tool",
        description: str = "A fake tool for testing",
        result: dict | None = None,
    ):
        super().__init__(name=name, description=description)
        # A missing or empty result falls back to a generic success payload
        self._result = result if result else {"status": "ok"}

    async def execute(self, **kwargs) -> dict:
        """Ignore all arguments and return the preset result."""
        canned = self._result
        return canned
|
||||
|
||||
|
||||
def make_response(
    content: str = "",
    tool_calls: list[ToolCall] | None = None,
    prompt_tokens: int = 10,
    completion_tokens: int = 20,
) -> LLMResponse:
    """Build an LLMResponse for model "test-model" with the given usage."""
    usage = TokenUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    calls = tool_calls or []
    return LLMResponse(
        content=content,
        model="test-model",
        usage=usage,
        tool_calls=calls,
    )
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════
|
||||
# U8: StreamChunk + chat_stream
|
||||
# ══════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestStreamChunk:
    """Tests for the StreamChunk dataclass."""

    def test_creation_with_content(self):
        from agentkit.llm.protocol import StreamChunk

        piece = StreamChunk(content="Hello", model="gpt-4o")

        # Only content and model were given; the rest must be defaults
        assert piece.content == "Hello"
        assert piece.model == "gpt-4o"
        assert piece.tool_calls == []
        assert piece.usage is None
        assert piece.is_final is False

    def test_with_tool_calls_and_is_final(self):
        from agentkit.llm.protocol import StreamChunk

        call = ToolCall(id="tc_1", name="search", arguments={"q": "test"})
        piece = StreamChunk(content="", model="gpt-4o", tool_calls=[call], is_final=True)

        assert len(piece.tool_calls) == 1
        assert piece.tool_calls[0].name == "search"
        assert piece.is_final is True

    def test_with_usage(self):
        from agentkit.llm.protocol import StreamChunk

        token_usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
        piece = StreamChunk(content="", model="gpt-4o", usage=token_usage, is_final=True)

        assert piece.usage is not None
        assert piece.usage.total_tokens == 150
        assert piece.is_final is True
|
||||
|
||||
|
||||
class TestLLMProviderChatStreamDefault:
    """The default LLMProvider.chat_stream falls back to chat()."""

    async def test_default_chat_stream_yields_single_chunk(self):
        from agentkit.llm.protocol import LLMProvider, StreamChunk

        class SimpleProvider(LLMProvider):
            """Provider that implements chat() only."""

            async def chat(self, request: LLMRequest) -> LLMResponse:
                return LLMResponse(
                    content="hello",
                    model="test",
                    usage=TokenUsage(prompt_tokens=5, completion_tokens=10),
                )

        provider = SimpleProvider()
        request = LLMRequest(
            messages=[{"role": "user", "content": "hi"}],
            model="test",
        )

        received = [chunk async for chunk in provider.chat_stream(request)]

        # The whole chat() result arrives as one final chunk
        assert len(received) == 1
        only = received[0]
        assert only.content == "hello"
        assert only.is_final is True
        assert only.usage.total_tokens == 15
|
||||
|
||||
|
||||
class TestGatewayChatStream:
    """Streaming tests for LLMGateway.chat_stream."""

    @staticmethod
    def _gateway_streaming(chunks):
        """Return a gateway whose "test" provider streams *chunks* in order."""
        from agentkit.llm.gateway import LLMGateway
        from agentkit.llm.protocol import LLMProvider

        class StreamingProvider(LLMProvider):
            async def chat(self, request: LLMRequest) -> LLMResponse:
                return LLMResponse(
                    content="fallback",
                    model="test",
                    usage=TokenUsage(),
                )

            async def chat_stream(self, request: LLMRequest):
                for chunk in chunks:
                    yield chunk

        gateway = LLMGateway()
        gateway.register_provider("test", StreamingProvider())
        return gateway

    async def test_yields_chunks_from_provider(self):
        from agentkit.llm.protocol import StreamChunk

        gateway = self._gateway_streaming(
            [
                StreamChunk(content="Hello ", model="test"),
                StreamChunk(content="World", model="test"),
                StreamChunk(
                    content="",
                    model="test",
                    usage=TokenUsage(prompt_tokens=10, completion_tokens=5),
                    is_final=True,
                ),
            ]
        )

        received = [
            chunk
            async for chunk in gateway.chat_stream(
                messages=[{"role": "user", "content": "hi"}],
                model="test/model",
            )
        ]

        # Chunks pass through the gateway unchanged and in order
        assert len(received) == 3
        assert received[0].content == "Hello "
        assert received[1].content == "World"
        assert received[2].is_final is True

    async def test_tracks_usage_after_stream_completes(self):
        from agentkit.llm.protocol import StreamChunk

        gateway = self._gateway_streaming(
            [
                StreamChunk(content="Hi", model="test"),
                StreamChunk(
                    content="",
                    model="test",
                    usage=TokenUsage(prompt_tokens=100, completion_tokens=50),
                    is_final=True,
                ),
            ]
        )

        # Consume the stream completely so usage gets recorded
        async for _ in gateway.chat_stream(
            messages=[{"role": "user", "content": "hi"}],
            model="test/model",
            agent_name="stream_agent",
        ):
            pass

        recorded = gateway.get_usage()
        assert recorded.total_tokens == 150
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════
|
||||
# U9: ReActEvent + execute_stream
|
||||
# ══════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestReActEvent:
    """Tests for the ReActEvent dataclass."""

    def test_creation_with_event_type_and_step(self):
        from agentkit.core.react import ReActEvent

        created = ReActEvent(event_type="thinking", step=1)

        assert created.event_type == "thinking"
        assert created.step == 1
        # data defaults to an empty dict
        assert created.data == {}

    def test_has_timestamp(self):
        from agentkit.core.react import ReActEvent

        stamp = ReActEvent(event_type="thinking", step=1).timestamp

        assert stamp is not None
        assert len(stamp) > 0

    def test_with_data(self):
        from agentkit.core.react import ReActEvent

        payload = {"tool_name": "search", "arguments": {"q": "test"}}
        created = ReActEvent(event_type="tool_call", step=2, data=payload)

        assert created.data["tool_name"] == "search"
|
||||
|
||||
|
||||
class TestReActEngineExecuteStream:
    """Event-stream tests for ReActEngine.execute_stream."""

    @staticmethod
    async def _drain(engine, **kwargs):
        """Run execute_stream to completion and return all events as a list."""
        return [event async for event in engine.execute_stream(**kwargs)]

    @staticmethod
    def _of_type(events, event_type):
        """Return the events whose event_type equals *event_type*."""
        return [event for event in events if event.event_type == event_type]

    async def test_yields_thinking_event_at_each_step(self):
        from agentkit.core.react import ReActEngine, ReActEvent

        llm = MagicMock()
        llm.chat = AsyncMock(return_value=make_response(content="Final answer"))

        events = await self._drain(
            ReActEngine(llm_gateway=llm),
            messages=[{"role": "user", "content": "Hello"}],
        )

        # Should have thinking + final_answer
        thinking = self._of_type(events, "thinking")
        assert len(thinking) >= 1
        assert thinking[0].step == 1

    async def test_yields_tool_call_and_tool_result_events(self):
        from agentkit.core.react import ReActEngine, ReActEvent

        calculator = FakeTool(name="calculator", result={"value": 42})

        # First reply requests the tool, second reply is the final answer
        llm = MagicMock()
        llm.chat = AsyncMock(
            side_effect=[
                make_response(
                    content="",
                    tool_calls=[ToolCall(id="tc_1", name="calculator", arguments={"expr": "6*7"})],
                ),
                make_response(content="The result is 42"),
            ]
        )

        events = await self._drain(
            ReActEngine(llm_gateway=llm),
            messages=[{"role": "user", "content": "Calculate"}],
            tools=[calculator],
        )

        calls = self._of_type(events, "tool_call")
        results = self._of_type(events, "tool_result")

        assert len(calls) == 1
        assert calls[0].data["tool_name"] == "calculator"
        assert len(results) == 1
        assert results[0].data["tool_name"] == "calculator"
        assert results[0].data["result"] == {"value": 42}

    async def test_yields_final_answer_event(self):
        from agentkit.core.react import ReActEngine, ReActEvent

        llm = MagicMock()
        llm.chat = AsyncMock(return_value=make_response(content="The answer is 42"))

        events = await self._drain(
            ReActEngine(llm_gateway=llm),
            messages=[{"role": "user", "content": "What is the answer?"}],
        )

        finals = self._of_type(events, "final_answer")
        assert len(finals) == 1
        summary = finals[0].data
        assert summary["output"] == "The answer is 42"
        assert summary["total_steps"] >= 1
        assert summary["total_tokens"] > 0

    async def test_yields_max_steps_reached_when_hitting_limit(self):
        from agentkit.core.react import ReActEngine, ReActEvent

        search = FakeTool(name="search", result={"results": ["data"]})

        # The model asks for the tool on every turn, so the loop never ends by itself
        llm = MagicMock()
        llm.chat = AsyncMock(
            return_value=make_response(
                content="Thinking...",
                tool_calls=[ToolCall(id="tc_loop", name="search", arguments={"query": "more"})],
            )
        )

        events = await self._drain(
            ReActEngine(llm_gateway=llm, max_steps=3),
            messages=[{"role": "user", "content": "Keep searching"}],
            tools=[search],
        )

        finals = self._of_type(events, "final_answer")
        assert len(finals) == 1
        assert finals[0].data.get("max_steps_reached") is True
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════
|
||||
# U10: SSE Endpoint + Client SDK
|
||||
# ══════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestSSEEndpoint:
    """Tests for the SSE endpoint /tasks/stream."""

    @staticmethod
    def _client_for(gateway):
        """Create a TestClient for an app built on *gateway* and empty registries."""
        from fastapi.testclient import TestClient
        from agentkit.server.app import create_app
        from agentkit.skills.registry import SkillRegistry
        from agentkit.tools.registry import ToolRegistry

        app = create_app(
            llm_gateway=gateway,
            skill_registry=SkillRegistry(),
            tool_registry=ToolRegistry(),
        )
        return TestClient(app)

    def test_stream_task_returns_event_source_response(self):
        from agentkit.llm.gateway import LLMGateway

        gateway = LLMGateway()
        mock_provider = AsyncMock()
        mock_provider.chat.return_value = LLMResponse(
            content="Final answer",
            model="test-model",
            usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
        )
        gateway.register_provider("test", mock_provider)

        client = self._client_for(gateway)

        # Create an agent first
        agent_config = {
            "name": "stream_agent",
            "agent_type": "test",
            "task_mode": "llm_generate",
            "prompt": {"identity": "Stream Agent"},
        }
        client.post("/api/v1/agents", json={"config": agent_config})

        # Stream task
        task = {"input_data": {"query": "test"}, "agent_name": "stream_agent"}
        response = client.post("/api/v1/tasks/stream", json=task)

        # Should return 200 with SSE content type
        assert response.status_code == 200
        assert "text/event-stream" in response.headers.get("content-type", "")

    def test_stream_task_with_invalid_agent_returns_404(self):
        from agentkit.llm.gateway import LLMGateway

        client = self._client_for(LLMGateway())

        task = {"input_data": {"query": "test"}, "agent_name": "nonexistent_agent"}
        response = client.post("/api/v1/tasks/stream", json=task)

        assert response.status_code == 404
|
||||
Loading…
Reference in New Issue