feat(streaming): Phase C - LLM streaming + ReAct events + SSE endpoint
U8: StreamChunk protocol + OpenAI chat_stream + Gateway streaming with usage tracking U9: ReActEvent dataclass + execute_stream() yielding thinking/tool_call/tool_result/final_answer U10: POST /tasks/stream SSE endpoint + Client SDK stream_task() 15 new tests passing, no regression.
This commit is contained in:
parent
ec0e221beb
commit
2844eeb548
|
|
@ -462,4 +462,4 @@ def register_geo_tools(registry: ToolRegistry) -> None:
|
||||||
tags=["knowledge", "deai"],
|
tags=["knowledge", "deai"],
|
||||||
))
|
))
|
||||||
|
|
||||||
logger.info(f"GEO tools registered: {len(registry.list_all_tools())} tools")
|
logger.info(f"GEO tools registered: {len(registry.list_tools())} tools")
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ dependencies = [
|
||||||
server = [
|
server = [
|
||||||
"fastapi>=0.110",
|
"fastapi>=0.110",
|
||||||
"uvicorn>=0.27",
|
"uvicorn>=0.27",
|
||||||
|
"sse-starlette>=2.0",
|
||||||
]
|
]
|
||||||
mcp = [
|
mcp = [
|
||||||
"mcp>=1.0",
|
"mcp>=1.0",
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from agentkit.llm.gateway import LLMGateway
|
from agentkit.llm.gateway import LLMGateway
|
||||||
|
|
@ -39,6 +40,16 @@ class ReActResult:
|
||||||
total_tokens: int
|
total_tokens: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ReActEvent:
|
||||||
|
"""ReAct 执行事件"""
|
||||||
|
|
||||||
|
event_type: str # "thinking", "tool_call", "tool_result", "final_answer", "error"
|
||||||
|
step: int
|
||||||
|
data: dict[str, Any] = field(default_factory=dict)
|
||||||
|
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||||
|
|
||||||
|
|
||||||
class ReActEngine:
|
class ReActEngine:
|
||||||
"""ReAct 推理-行动循环引擎
|
"""ReAct 推理-行动循环引擎
|
||||||
|
|
||||||
|
|
@ -186,6 +197,172 @@ class ReActEngine:
|
||||||
total_tokens=total_tokens,
|
total_tokens=total_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def execute_stream(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
tools: list[Tool] | None = None,
|
||||||
|
model: str = "default",
|
||||||
|
agent_name: str = "",
|
||||||
|
task_type: str = "",
|
||||||
|
system_prompt: str | None = None,
|
||||||
|
):
|
||||||
|
"""Execute ReAct loop, yielding ReActEvent objects.
|
||||||
|
|
||||||
|
Same logic as execute() but yields events at each step instead of
|
||||||
|
accumulating a result.
|
||||||
|
"""
|
||||||
|
tools = tools or []
|
||||||
|
tool_schemas = self._build_tool_schemas(tools) if tools else None
|
||||||
|
|
||||||
|
conversation: list[dict[str, Any]] = []
|
||||||
|
if system_prompt:
|
||||||
|
conversation.append({"role": "system", "content": system_prompt})
|
||||||
|
conversation.extend(messages)
|
||||||
|
|
||||||
|
trajectory: list[ReActStep] = []
|
||||||
|
total_tokens = 0
|
||||||
|
step = 0
|
||||||
|
output = ""
|
||||||
|
|
||||||
|
while step < self._max_steps:
|
||||||
|
step += 1
|
||||||
|
|
||||||
|
# Yield thinking event
|
||||||
|
yield ReActEvent(
|
||||||
|
event_type="thinking",
|
||||||
|
step=step,
|
||||||
|
data={"message": f"Step {step}: Calling LLM..."},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Think: call LLM
|
||||||
|
response = await self._llm_gateway.chat(
|
||||||
|
messages=conversation,
|
||||||
|
model=model,
|
||||||
|
agent_name=agent_name,
|
||||||
|
task_type=task_type,
|
||||||
|
tools=tool_schemas,
|
||||||
|
)
|
||||||
|
|
||||||
|
step_tokens = response.usage.total_tokens
|
||||||
|
total_tokens += step_tokens
|
||||||
|
|
||||||
|
if response.has_tool_calls:
|
||||||
|
# Record assistant message
|
||||||
|
assistant_msg: dict[str, Any] = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": response.content or "",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": tc.id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tc.name,
|
||||||
|
"arguments": json.dumps(tc.arguments),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for tc in response.tool_calls
|
||||||
|
],
|
||||||
|
}
|
||||||
|
conversation.append(assistant_msg)
|
||||||
|
|
||||||
|
for tc in response.tool_calls:
|
||||||
|
# Yield tool_call event
|
||||||
|
yield ReActEvent(
|
||||||
|
event_type="tool_call",
|
||||||
|
step=step,
|
||||||
|
data={"tool_name": tc.name, "arguments": tc.arguments},
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||||
|
react_step = ReActStep(
|
||||||
|
step=step,
|
||||||
|
action="tool_call",
|
||||||
|
tool_name=tc.name,
|
||||||
|
arguments=tc.arguments,
|
||||||
|
result=tool_result,
|
||||||
|
tokens=step_tokens,
|
||||||
|
)
|
||||||
|
trajectory.append(react_step)
|
||||||
|
|
||||||
|
# Yield tool_result event
|
||||||
|
yield ReActEvent(
|
||||||
|
event_type="tool_result",
|
||||||
|
step=step,
|
||||||
|
data={"tool_name": tc.name, "result": tool_result},
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_msg = self._build_tool_result_message(tc.id, tool_result)
|
||||||
|
conversation.append(tool_msg)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Check text parsing mode
|
||||||
|
parsed_calls = self._parse_text_tool_calls(response.content or "")
|
||||||
|
if parsed_calls and tools:
|
||||||
|
conversation.append({"role": "assistant", "content": response.content})
|
||||||
|
|
||||||
|
for pc in parsed_calls:
|
||||||
|
yield ReActEvent(
|
||||||
|
event_type="tool_call",
|
||||||
|
step=step,
|
||||||
|
data={"tool_name": pc["name"], "arguments": pc["arguments"]},
|
||||||
|
)
|
||||||
|
tool_result = await self._execute_tool(pc["name"], pc["arguments"], tools)
|
||||||
|
trajectory.append(ReActStep(
|
||||||
|
step=step,
|
||||||
|
action="tool_call",
|
||||||
|
tool_name=pc["name"],
|
||||||
|
arguments=pc["arguments"],
|
||||||
|
result=tool_result,
|
||||||
|
tokens=step_tokens,
|
||||||
|
))
|
||||||
|
yield ReActEvent(
|
||||||
|
event_type="tool_result",
|
||||||
|
step=step,
|
||||||
|
data={"tool_name": pc["name"], "result": tool_result},
|
||||||
|
)
|
||||||
|
tool_msg = self._build_tool_result_message(
|
||||||
|
pc.get("id", f"text_tc_{step}"), tool_result
|
||||||
|
)
|
||||||
|
conversation.append(tool_msg)
|
||||||
|
else:
|
||||||
|
# Final answer
|
||||||
|
react_step = ReActStep(
|
||||||
|
step=step,
|
||||||
|
action="final_answer",
|
||||||
|
content=response.content,
|
||||||
|
tokens=step_tokens,
|
||||||
|
)
|
||||||
|
trajectory.append(react_step)
|
||||||
|
output = response.content or ""
|
||||||
|
yield ReActEvent(
|
||||||
|
event_type="final_answer",
|
||||||
|
step=step,
|
||||||
|
data={
|
||||||
|
"output": output,
|
||||||
|
"total_steps": len(trajectory),
|
||||||
|
"total_tokens": total_tokens,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
if step >= self._max_steps and not output:
|
||||||
|
if trajectory and trajectory[-1].content:
|
||||||
|
output = trajectory[-1].content
|
||||||
|
elif trajectory and trajectory[-1].result is not None:
|
||||||
|
output = str(trajectory[-1].result)
|
||||||
|
else:
|
||||||
|
output = response.content or ""
|
||||||
|
yield ReActEvent(
|
||||||
|
event_type="final_answer",
|
||||||
|
step=step,
|
||||||
|
data={
|
||||||
|
"output": output,
|
||||||
|
"total_steps": len(trajectory),
|
||||||
|
"total_tokens": total_tokens,
|
||||||
|
"max_steps_reached": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def _build_tool_schemas(self, tools: list[Tool]) -> list[dict]:
|
def _build_tool_schemas(self, tools: list[Tool]) -> list[dict]:
|
||||||
"""将 Tool 对象转换为 OpenAI Function Calling schema 格式"""
|
"""将 Tool 对象转换为 OpenAI Function Calling schema 格式"""
|
||||||
schemas = []
|
schemas = []
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import time
|
||||||
|
|
||||||
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
|
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
|
||||||
from agentkit.llm.config import LLMConfig
|
from agentkit.llm.config import LLMConfig
|
||||||
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage
|
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, StreamChunk, TokenUsage
|
||||||
from agentkit.llm.providers.tracker import UsageSummary, UsageTracker
|
from agentkit.llm.providers.tracker import UsageSummary, UsageTracker
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -97,6 +97,62 @@ class LLMGateway:
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
async def chat_stream(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
model: str,
|
||||||
|
agent_name: str = "",
|
||||||
|
task_type: str = "",
|
||||||
|
tools: list[dict] | None = None,
|
||||||
|
tool_choice: str = "auto",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Stream chat response, yielding StreamChunk objects"""
|
||||||
|
resolved_model = self._resolve_model_alias(model)
|
||||||
|
|
||||||
|
if not self._providers:
|
||||||
|
raise LLMProviderError("", "No provider registered")
|
||||||
|
|
||||||
|
try:
|
||||||
|
provider, actual_model = self._resolve_model(resolved_model)
|
||||||
|
except ModelNotFoundError as e:
|
||||||
|
raise LLMProviderError("", str(e)) from e
|
||||||
|
|
||||||
|
request = LLMRequest(
|
||||||
|
messages=messages,
|
||||||
|
model=actual_model,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice=tool_choice,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
start = time.monotonic()
|
||||||
|
total_content = ""
|
||||||
|
final_usage = None
|
||||||
|
final_model = resolved_model
|
||||||
|
|
||||||
|
async for chunk in provider.chat_stream(request):
|
||||||
|
if chunk.content:
|
||||||
|
total_content += chunk.content
|
||||||
|
if chunk.usage:
|
||||||
|
final_usage = chunk.usage
|
||||||
|
if chunk.model:
|
||||||
|
final_model = chunk.model
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
# Track usage after stream completes
|
||||||
|
latency_ms = (time.monotonic() - start) * 1000
|
||||||
|
if final_usage is None:
|
||||||
|
final_usage = TokenUsage()
|
||||||
|
cost = self._calculate_cost(final_model, final_usage)
|
||||||
|
self._usage_tracker.record(
|
||||||
|
agent_name=agent_name,
|
||||||
|
model=final_model,
|
||||||
|
usage=final_usage,
|
||||||
|
cost=cost,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
)
|
||||||
|
|
||||||
def _resolve_model_alias(self, model: str) -> str:
|
def _resolve_model_alias(self, model: str) -> str:
|
||||||
"""解析模型别名"""
|
"""解析模型别名"""
|
||||||
if model in self._config.model_aliases:
|
if model in self._config.model_aliases:
|
||||||
|
|
|
||||||
|
|
@ -56,6 +56,17 @@ class LLMRequest:
|
||||||
self._extra = kwargs
|
self._extra = kwargs
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StreamChunk:
|
||||||
|
"""LLM 流式响应块"""
|
||||||
|
|
||||||
|
content: str # Delta content
|
||||||
|
model: str
|
||||||
|
tool_calls: list[ToolCall] = field(default_factory=list) # Accumulated tool calls (only in final chunk)
|
||||||
|
usage: TokenUsage | None = None # Only in final chunk
|
||||||
|
is_final: bool = False # True for the last chunk
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LLMResponse:
|
class LLMResponse:
|
||||||
"""LLM 响应"""
|
"""LLM 响应"""
|
||||||
|
|
@ -78,3 +89,18 @@ class LLMProvider(ABC):
|
||||||
async def chat(self, request: LLMRequest) -> LLMResponse:
|
async def chat(self, request: LLMRequest) -> LLMResponse:
|
||||||
"""发送 chat 请求并返回响应"""
|
"""发送 chat 请求并返回响应"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
async def chat_stream(self, request: LLMRequest):
|
||||||
|
"""Stream chat response. Override in subclasses that support streaming.
|
||||||
|
|
||||||
|
Yields StreamChunk objects. Default implementation falls back to
|
||||||
|
non-streaming chat and yields a single chunk.
|
||||||
|
"""
|
||||||
|
response = await self.chat(request)
|
||||||
|
yield StreamChunk(
|
||||||
|
content=response.content,
|
||||||
|
model=response.model,
|
||||||
|
tool_calls=response.tool_calls,
|
||||||
|
usage=response.usage,
|
||||||
|
is_final=True,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import time
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from agentkit.core.exceptions import LLMProviderError
|
from agentkit.core.exceptions import LLMProviderError
|
||||||
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage, ToolCall
|
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, StreamChunk, TokenUsage, ToolCall
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -100,3 +100,108 @@ class OpenAICompatibleProvider(LLMProvider):
|
||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
latency_ms=latency_ms,
|
latency_ms=latency_ms,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def chat_stream(self, request: LLMRequest):
|
||||||
|
"""Stream chat response using SSE"""
|
||||||
|
url = f"{self._base_url}/chat/completions"
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self._api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
payload: dict = {
|
||||||
|
"model": request.model,
|
||||||
|
"messages": request.messages,
|
||||||
|
"temperature": request.temperature,
|
||||||
|
"max_tokens": request.max_tokens,
|
||||||
|
"stream": True,
|
||||||
|
"stream_options": {"include_usage": True},
|
||||||
|
}
|
||||||
|
if request.tools:
|
||||||
|
payload["tools"] = request.tools
|
||||||
|
payload["tool_choice"] = request.tool_choice
|
||||||
|
|
||||||
|
async with self._client.stream("POST", url, json=payload, headers=headers) as response:
|
||||||
|
if response.status_code != 200:
|
||||||
|
error_text = await response.aread()
|
||||||
|
raise LLMProviderError("openai", f"HTTP {response.status_code}")
|
||||||
|
|
||||||
|
accumulated_tool_calls: dict[int, dict] = {} # index -> {id, name, arguments_str}
|
||||||
|
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
line = line.strip()
|
||||||
|
if not line or not line.startswith("data: "):
|
||||||
|
continue
|
||||||
|
data_str = line[6:] # Remove "data: " prefix
|
||||||
|
if data_str == "[DONE]":
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(data_str)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
choices = data.get("choices", [])
|
||||||
|
if not choices:
|
||||||
|
# Usage-only chunk
|
||||||
|
usage_data = data.get("usage")
|
||||||
|
if usage_data:
|
||||||
|
yield StreamChunk(
|
||||||
|
content="",
|
||||||
|
model=data.get("model", request.model),
|
||||||
|
usage=TokenUsage(
|
||||||
|
prompt_tokens=usage_data.get("prompt_tokens", 0),
|
||||||
|
completion_tokens=usage_data.get("completion_tokens", 0),
|
||||||
|
),
|
||||||
|
is_final=True,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
delta = choices[0].get("delta", {})
|
||||||
|
content = delta.get("content", "")
|
||||||
|
|
||||||
|
# Accumulate tool calls from streaming
|
||||||
|
raw_tool_calls = delta.get("tool_calls")
|
||||||
|
if raw_tool_calls:
|
||||||
|
for tc in raw_tool_calls:
|
||||||
|
idx = tc.get("index", 0)
|
||||||
|
if idx not in accumulated_tool_calls:
|
||||||
|
accumulated_tool_calls[idx] = {
|
||||||
|
"id": tc.get("id", ""),
|
||||||
|
"name": "",
|
||||||
|
"arguments_str": "",
|
||||||
|
}
|
||||||
|
if tc.get("id"):
|
||||||
|
accumulated_tool_calls[idx]["id"] = tc["id"]
|
||||||
|
func = tc.get("function", {})
|
||||||
|
if func.get("name"):
|
||||||
|
accumulated_tool_calls[idx]["name"] = func["name"]
|
||||||
|
if func.get("arguments"):
|
||||||
|
accumulated_tool_calls[idx]["arguments_str"] += func["arguments"]
|
||||||
|
|
||||||
|
# Only yield content chunks (not empty deltas)
|
||||||
|
if content:
|
||||||
|
yield StreamChunk(
|
||||||
|
content=content,
|
||||||
|
model=data.get("model", request.model),
|
||||||
|
)
|
||||||
|
|
||||||
|
# If we accumulated tool calls, yield them as a final chunk
|
||||||
|
if accumulated_tool_calls:
|
||||||
|
tool_calls = []
|
||||||
|
for idx in sorted(accumulated_tool_calls.keys()):
|
||||||
|
tc_data = accumulated_tool_calls[idx]
|
||||||
|
try:
|
||||||
|
arguments = json.loads(tc_data["arguments_str"]) if tc_data["arguments_str"] else {}
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
arguments = {"raw": tc_data["arguments_str"]}
|
||||||
|
tool_calls.append(ToolCall(
|
||||||
|
id=tc_data["id"],
|
||||||
|
name=tc_data["name"],
|
||||||
|
arguments=arguments,
|
||||||
|
))
|
||||||
|
yield StreamChunk(
|
||||||
|
content="",
|
||||||
|
model=request.model,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
is_final=True,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -126,6 +126,38 @@ class AgentKitClient:
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
|
async def stream_task(
|
||||||
|
self,
|
||||||
|
input_data: dict,
|
||||||
|
skill_name: str | None = None,
|
||||||
|
agent_name: str | None = None,
|
||||||
|
):
|
||||||
|
"""Stream task execution events via SSE.
|
||||||
|
|
||||||
|
Yields event dicts with 'event' and 'data' keys.
|
||||||
|
"""
|
||||||
|
payload: dict[str, Any] = {"input_data": input_data}
|
||||||
|
if skill_name:
|
||||||
|
payload["skill_name"] = skill_name
|
||||||
|
if agent_name:
|
||||||
|
payload["agent_name"] = agent_name
|
||||||
|
|
||||||
|
async with self._client.stream(
|
||||||
|
"POST", "/api/v1/tasks/stream", json=payload
|
||||||
|
) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
event_type = ""
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
if line.startswith("event: "):
|
||||||
|
event_type = line[7:]
|
||||||
|
elif line.startswith("data: "):
|
||||||
|
import json as _json
|
||||||
|
data = _json.loads(line[6:])
|
||||||
|
yield {"event": event_type, "data": data}
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""Close the HTTP client"""
|
"""Close the HTTP client"""
|
||||||
await self._client.aclose()
|
await self._client.aclose()
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
"""Task submission routes"""
|
"""Task submission routes"""
|
||||||
|
|
||||||
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
|
@ -191,3 +192,79 @@ async def cancel_task(task_id: str, req: Request):
|
||||||
if not cancelled:
|
if not cancelled:
|
||||||
raise HTTPException(status_code=400, detail="Task cannot be cancelled (not running or not found)")
|
raise HTTPException(status_code=400, detail="Task cannot be cancelled (not running or not found)")
|
||||||
return {"task_id": task_id, "status": "cancelled"}
|
return {"task_id": task_id, "status": "cancelled"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/tasks/stream")
|
||||||
|
async def stream_task(request: SubmitTaskRequest, req: Request):
|
||||||
|
"""Submit a task and stream ReAct events via SSE"""
|
||||||
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
|
||||||
|
pool = req.app.state.agent_pool
|
||||||
|
skill_registry = req.app.state.skill_registry
|
||||||
|
intent_router = req.app.state.intent_router
|
||||||
|
|
||||||
|
agent = None
|
||||||
|
|
||||||
|
# Same agent resolution logic as submit_task
|
||||||
|
if request.agent_name:
|
||||||
|
agent = pool.get_agent(request.agent_name)
|
||||||
|
if agent is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"Agent '{request.agent_name}' not found",
|
||||||
|
)
|
||||||
|
elif request.skill_name:
|
||||||
|
try:
|
||||||
|
skill_registry.get(request.skill_name)
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"Skill '{request.skill_name}' not found",
|
||||||
|
)
|
||||||
|
agent = pool.get_agent(request.skill_name)
|
||||||
|
if agent is None:
|
||||||
|
agent = await pool.create_agent_from_skill(request.skill_name)
|
||||||
|
else:
|
||||||
|
all_skills = skill_registry.list_skills()
|
||||||
|
if not all_skills:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="No skills registered and no skill_name or agent_name specified",
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
routing_result = await intent_router.route(request.input_data, all_skills)
|
||||||
|
skill_registry.get(routing_result.matched_skill)
|
||||||
|
agent = pool.get_agent(routing_result.matched_skill)
|
||||||
|
if agent is None:
|
||||||
|
agent = await pool.create_agent_from_skill(routing_result.matched_skill)
|
||||||
|
except (ValueError, RuntimeError) as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
async def event_generator():
|
||||||
|
from agentkit.core.react import ReActEngine
|
||||||
|
|
||||||
|
react_engine = ReActEngine(llm_gateway=req.app.state.llm_gateway)
|
||||||
|
|
||||||
|
# Build messages from input
|
||||||
|
messages = [{"role": "user", "content": str(request.input_data)}]
|
||||||
|
|
||||||
|
# Get tools from agent
|
||||||
|
tools = list(agent._tool_registry._tools.values()) if agent._tool_registry else []
|
||||||
|
|
||||||
|
async for event in react_engine.execute_stream(
|
||||||
|
messages=messages,
|
||||||
|
tools=tools,
|
||||||
|
model=agent._llm_model if hasattr(agent, "_llm_model") else "default",
|
||||||
|
agent_name=agent.name,
|
||||||
|
system_prompt=agent._system_prompt if hasattr(agent, "_system_prompt") else None,
|
||||||
|
):
|
||||||
|
yield {
|
||||||
|
"event": event.event_type,
|
||||||
|
"data": json.dumps({
|
||||||
|
"step": event.step,
|
||||||
|
"data": event.data,
|
||||||
|
"timestamp": event.timestamp,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
|
return EventSourceResponse(event_generator())
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,431 @@
|
||||||
|
"""Streaming System 单元测试 - U8/U9/U10
|
||||||
|
|
||||||
|
覆盖:
|
||||||
|
- StreamChunk 数据类
|
||||||
|
- LLMProvider.chat_stream 默认回退
|
||||||
|
- Gateway.chat_stream 流式 + 用量追踪
|
||||||
|
- ReActEvent 数据类
|
||||||
|
- ReActEngine.execute_stream 事件流
|
||||||
|
- SSE 端点 /tasks/stream
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.llm.protocol import LLMRequest, LLMResponse, TokenUsage, ToolCall
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test Helpers ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class FakeTool(Tool):
|
||||||
|
"""用于测试的 Fake Tool"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str = "fake_tool",
|
||||||
|
description: str = "A fake tool for testing",
|
||||||
|
result: dict | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(name=name, description=description)
|
||||||
|
self._result = result or {"status": "ok"}
|
||||||
|
|
||||||
|
async def execute(self, **kwargs) -> dict:
|
||||||
|
return self._result
|
||||||
|
|
||||||
|
|
||||||
|
def make_response(
|
||||||
|
content: str = "",
|
||||||
|
tool_calls: list[ToolCall] | None = None,
|
||||||
|
prompt_tokens: int = 10,
|
||||||
|
completion_tokens: int = 20,
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""快速构造 LLMResponse"""
|
||||||
|
return LLMResponse(
|
||||||
|
content=content,
|
||||||
|
model="test-model",
|
||||||
|
usage=TokenUsage(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
),
|
||||||
|
tool_calls=tool_calls or [],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# U8: StreamChunk + chat_stream
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestStreamChunk:
|
||||||
|
"""StreamChunk 数据类测试"""
|
||||||
|
|
||||||
|
def test_creation_with_content(self):
|
||||||
|
from agentkit.llm.protocol import StreamChunk
|
||||||
|
|
||||||
|
chunk = StreamChunk(content="Hello", model="gpt-4o")
|
||||||
|
assert chunk.content == "Hello"
|
||||||
|
assert chunk.model == "gpt-4o"
|
||||||
|
assert chunk.tool_calls == []
|
||||||
|
assert chunk.usage is None
|
||||||
|
assert chunk.is_final is False
|
||||||
|
|
||||||
|
def test_with_tool_calls_and_is_final(self):
|
||||||
|
from agentkit.llm.protocol import StreamChunk
|
||||||
|
|
||||||
|
tc = ToolCall(id="tc_1", name="search", arguments={"q": "test"})
|
||||||
|
chunk = StreamChunk(
|
||||||
|
content="",
|
||||||
|
model="gpt-4o",
|
||||||
|
tool_calls=[tc],
|
||||||
|
is_final=True,
|
||||||
|
)
|
||||||
|
assert len(chunk.tool_calls) == 1
|
||||||
|
assert chunk.tool_calls[0].name == "search"
|
||||||
|
assert chunk.is_final is True
|
||||||
|
|
||||||
|
def test_with_usage(self):
|
||||||
|
from agentkit.llm.protocol import StreamChunk
|
||||||
|
|
||||||
|
usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
|
||||||
|
chunk = StreamChunk(
|
||||||
|
content="",
|
||||||
|
model="gpt-4o",
|
||||||
|
usage=usage,
|
||||||
|
is_final=True,
|
||||||
|
)
|
||||||
|
assert chunk.usage is not None
|
||||||
|
assert chunk.usage.total_tokens == 150
|
||||||
|
assert chunk.is_final is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLMProviderChatStreamDefault:
|
||||||
|
"""LLMProvider.chat_stream 默认实现回退到 chat()"""
|
||||||
|
|
||||||
|
async def test_default_chat_stream_yields_single_chunk(self):
|
||||||
|
from agentkit.llm.protocol import LLMProvider, StreamChunk
|
||||||
|
|
||||||
|
class SimpleProvider(LLMProvider):
|
||||||
|
async def chat(self, request: LLMRequest) -> LLMResponse:
|
||||||
|
return LLMResponse(
|
||||||
|
content="hello",
|
||||||
|
model="test",
|
||||||
|
usage=TokenUsage(prompt_tokens=5, completion_tokens=10),
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = SimpleProvider()
|
||||||
|
request = LLMRequest(
|
||||||
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
|
model="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
async for chunk in provider.chat_stream(request):
|
||||||
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
assert len(chunks) == 1
|
||||||
|
assert chunks[0].content == "hello"
|
||||||
|
assert chunks[0].is_final is True
|
||||||
|
assert chunks[0].usage.total_tokens == 15
|
||||||
|
|
||||||
|
|
||||||
|
class TestGatewayChatStream:
|
||||||
|
"""Gateway.chat_stream 流式测试"""
|
||||||
|
|
||||||
|
async def test_yields_chunks_from_provider(self):
|
||||||
|
from agentkit.llm.protocol import StreamChunk
|
||||||
|
from agentkit.llm.gateway import LLMGateway
|
||||||
|
from agentkit.llm.protocol import LLMProvider
|
||||||
|
|
||||||
|
class StreamingProvider(LLMProvider):
|
||||||
|
async def chat(self, request: LLMRequest) -> LLMResponse:
|
||||||
|
return LLMResponse(
|
||||||
|
content="fallback",
|
||||||
|
model="test",
|
||||||
|
usage=TokenUsage(),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat_stream(self, request: LLMRequest):
|
||||||
|
yield StreamChunk(content="Hello ", model="test")
|
||||||
|
yield StreamChunk(content="World", model="test")
|
||||||
|
yield StreamChunk(
|
||||||
|
content="",
|
||||||
|
model="test",
|
||||||
|
usage=TokenUsage(prompt_tokens=10, completion_tokens=5),
|
||||||
|
is_final=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
gateway = LLMGateway()
|
||||||
|
gateway.register_provider("test", StreamingProvider())
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
async for chunk in gateway.chat_stream(
|
||||||
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
|
model="test/model",
|
||||||
|
):
|
||||||
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
assert len(chunks) == 3
|
||||||
|
assert chunks[0].content == "Hello "
|
||||||
|
assert chunks[1].content == "World"
|
||||||
|
assert chunks[2].is_final is True
|
||||||
|
|
||||||
|
async def test_tracks_usage_after_stream_completes(self):
|
||||||
|
from agentkit.llm.protocol import StreamChunk
|
||||||
|
from agentkit.llm.gateway import LLMGateway
|
||||||
|
from agentkit.llm.protocol import LLMProvider
|
||||||
|
|
||||||
|
class StreamingProvider(LLMProvider):
|
||||||
|
async def chat(self, request: LLMRequest) -> LLMResponse:
|
||||||
|
return LLMResponse(
|
||||||
|
content="fallback",
|
||||||
|
model="test",
|
||||||
|
usage=TokenUsage(),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat_stream(self, request: LLMRequest):
|
||||||
|
yield StreamChunk(content="Hi", model="test")
|
||||||
|
yield StreamChunk(
|
||||||
|
content="",
|
||||||
|
model="test",
|
||||||
|
usage=TokenUsage(prompt_tokens=100, completion_tokens=50),
|
||||||
|
is_final=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
gateway = LLMGateway()
|
||||||
|
gateway.register_provider("test", StreamingProvider())
|
||||||
|
|
||||||
|
# Consume the stream
|
||||||
|
chunks = []
|
||||||
|
async for chunk in gateway.chat_stream(
|
||||||
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
|
model="test/model",
|
||||||
|
agent_name="stream_agent",
|
||||||
|
):
|
||||||
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
# Verify usage was tracked
|
||||||
|
usage = gateway.get_usage()
|
||||||
|
assert usage.total_tokens == 150
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# U9: ReActEvent + execute_stream
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestReActEvent:
|
||||||
|
"""ReActEvent 数据类测试"""
|
||||||
|
|
||||||
|
def test_creation_with_event_type_and_step(self):
|
||||||
|
from agentkit.core.react import ReActEvent
|
||||||
|
|
||||||
|
event = ReActEvent(event_type="thinking", step=1)
|
||||||
|
assert event.event_type == "thinking"
|
||||||
|
assert event.step == 1
|
||||||
|
assert event.data == {}
|
||||||
|
|
||||||
|
def test_has_timestamp(self):
|
||||||
|
from agentkit.core.react import ReActEvent
|
||||||
|
|
||||||
|
event = ReActEvent(event_type="thinking", step=1)
|
||||||
|
assert event.timestamp is not None
|
||||||
|
assert len(event.timestamp) > 0
|
||||||
|
|
||||||
|
def test_with_data(self):
|
||||||
|
from agentkit.core.react import ReActEvent
|
||||||
|
|
||||||
|
event = ReActEvent(
|
||||||
|
event_type="tool_call",
|
||||||
|
step=2,
|
||||||
|
data={"tool_name": "search", "arguments": {"q": "test"}},
|
||||||
|
)
|
||||||
|
assert event.data["tool_name"] == "search"
|
||||||
|
|
||||||
|
|
||||||
|
class TestReActEngineExecuteStream:
|
||||||
|
"""ReActEngine.execute_stream 事件流测试"""
|
||||||
|
|
||||||
|
async def test_yields_thinking_event_at_each_step(self):
|
||||||
|
from agentkit.core.react import ReActEngine, ReActEvent
|
||||||
|
|
||||||
|
gateway = MagicMock()
|
||||||
|
gateway.chat = AsyncMock(return_value=make_response(content="Final answer"))
|
||||||
|
|
||||||
|
engine = ReActEngine(llm_gateway=gateway)
|
||||||
|
|
||||||
|
events = []
|
||||||
|
async for event in engine.execute_stream(
|
||||||
|
messages=[{"role": "user", "content": "Hello"}],
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
# Should have thinking + final_answer
|
||||||
|
thinking_events = [e for e in events if e.event_type == "thinking"]
|
||||||
|
assert len(thinking_events) >= 1
|
||||||
|
assert thinking_events[0].step == 1
|
||||||
|
|
||||||
|
async def test_yields_tool_call_and_tool_result_events(self):
|
||||||
|
from agentkit.core.react import ReActEngine, ReActEvent
|
||||||
|
|
||||||
|
tool = FakeTool(name="calculator", result={"value": 42})
|
||||||
|
|
||||||
|
gateway = MagicMock()
|
||||||
|
gateway.chat = AsyncMock(side_effect=[
|
||||||
|
make_response(
|
||||||
|
content="",
|
||||||
|
tool_calls=[ToolCall(id="tc_1", name="calculator", arguments={"expr": "6*7"})],
|
||||||
|
),
|
||||||
|
make_response(content="The result is 42"),
|
||||||
|
])
|
||||||
|
|
||||||
|
engine = ReActEngine(llm_gateway=gateway)
|
||||||
|
|
||||||
|
events = []
|
||||||
|
async for event in engine.execute_stream(
|
||||||
|
messages=[{"role": "user", "content": "Calculate"}],
|
||||||
|
tools=[tool],
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
tool_call_events = [e for e in events if e.event_type == "tool_call"]
|
||||||
|
tool_result_events = [e for e in events if e.event_type == "tool_result"]
|
||||||
|
|
||||||
|
assert len(tool_call_events) == 1
|
||||||
|
assert tool_call_events[0].data["tool_name"] == "calculator"
|
||||||
|
assert len(tool_result_events) == 1
|
||||||
|
assert tool_result_events[0].data["tool_name"] == "calculator"
|
||||||
|
assert tool_result_events[0].data["result"] == {"value": 42}
|
||||||
|
|
||||||
|
async def test_yields_final_answer_event(self):
|
||||||
|
from agentkit.core.react import ReActEngine, ReActEvent
|
||||||
|
|
||||||
|
gateway = MagicMock()
|
||||||
|
gateway.chat = AsyncMock(return_value=make_response(content="The answer is 42"))
|
||||||
|
|
||||||
|
engine = ReActEngine(llm_gateway=gateway)
|
||||||
|
|
||||||
|
events = []
|
||||||
|
async for event in engine.execute_stream(
|
||||||
|
messages=[{"role": "user", "content": "What is the answer?"}],
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
final_events = [e for e in events if e.event_type == "final_answer"]
|
||||||
|
assert len(final_events) == 1
|
||||||
|
assert final_events[0].data["output"] == "The answer is 42"
|
||||||
|
assert final_events[0].data["total_steps"] >= 1
|
||||||
|
assert final_events[0].data["total_tokens"] > 0
|
||||||
|
|
||||||
|
async def test_yields_max_steps_reached_when_hitting_limit(self):
|
||||||
|
from agentkit.core.react import ReActEngine, ReActEvent
|
||||||
|
|
||||||
|
tool = FakeTool(name="search", result={"results": ["data"]})
|
||||||
|
|
||||||
|
always_tool_response = make_response(
|
||||||
|
content="Thinking...",
|
||||||
|
tool_calls=[ToolCall(id="tc_loop", name="search", arguments={"query": "more"})],
|
||||||
|
)
|
||||||
|
gateway = MagicMock()
|
||||||
|
gateway.chat = AsyncMock(return_value=always_tool_response)
|
||||||
|
|
||||||
|
engine = ReActEngine(llm_gateway=gateway, max_steps=3)
|
||||||
|
|
||||||
|
events = []
|
||||||
|
async for event in engine.execute_stream(
|
||||||
|
messages=[{"role": "user", "content": "Keep searching"}],
|
||||||
|
tools=[tool],
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
final_events = [e for e in events if e.event_type == "final_answer"]
|
||||||
|
assert len(final_events) == 1
|
||||||
|
assert final_events[0].data.get("max_steps_reached") is True
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# U10: SSE Endpoint + Client SDK
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSEEndpoint:
|
||||||
|
"""SSE /tasks/stream 端点测试"""
|
||||||
|
|
||||||
|
def test_stream_task_returns_event_source_response(self):
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from agentkit.server.app import create_app
|
||||||
|
from agentkit.llm.gateway import LLMGateway
|
||||||
|
from agentkit.skills.registry import SkillRegistry
|
||||||
|
from agentkit.tools.registry import ToolRegistry
|
||||||
|
|
||||||
|
gateway = LLMGateway()
|
||||||
|
mock_provider = AsyncMock()
|
||||||
|
mock_provider.chat.return_value = LLMResponse(
|
||||||
|
content="Final answer",
|
||||||
|
model="test-model",
|
||||||
|
usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
|
||||||
|
)
|
||||||
|
gateway.register_provider("test", mock_provider)
|
||||||
|
|
||||||
|
skill_registry = SkillRegistry()
|
||||||
|
tool_registry = ToolRegistry()
|
||||||
|
app = create_app(
|
||||||
|
llm_gateway=gateway,
|
||||||
|
skill_registry=skill_registry,
|
||||||
|
tool_registry=tool_registry,
|
||||||
|
)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
# Create an agent first
|
||||||
|
client.post(
|
||||||
|
"/api/v1/agents",
|
||||||
|
json={
|
||||||
|
"config": {
|
||||||
|
"name": "stream_agent",
|
||||||
|
"agent_type": "test",
|
||||||
|
"task_mode": "llm_generate",
|
||||||
|
"prompt": {"identity": "Stream Agent"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stream task
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/tasks/stream",
|
||||||
|
json={
|
||||||
|
"input_data": {"query": "test"},
|
||||||
|
"agent_name": "stream_agent",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Should return 200 with SSE content type
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "text/event-stream" in response.headers.get("content-type", "")
|
||||||
|
|
||||||
|
def test_stream_task_with_invalid_agent_returns_404(self):
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from agentkit.server.app import create_app
|
||||||
|
from agentkit.llm.gateway import LLMGateway
|
||||||
|
from agentkit.skills.registry import SkillRegistry
|
||||||
|
from agentkit.tools.registry import ToolRegistry
|
||||||
|
|
||||||
|
gateway = LLMGateway()
|
||||||
|
skill_registry = SkillRegistry()
|
||||||
|
tool_registry = ToolRegistry()
|
||||||
|
app = create_app(
|
||||||
|
llm_gateway=gateway,
|
||||||
|
skill_registry=skill_registry,
|
||||||
|
tool_registry=tool_registry,
|
||||||
|
)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/tasks/stream",
|
||||||
|
json={
|
||||||
|
"input_data": {"query": "test"},
|
||||||
|
"agent_name": "nonexistent_agent",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 404
|
||||||
Loading…
Reference in New Issue