feat(tools): add AskHumanTool + token streaming in ReAct execute_stream

- AskHumanTool: Human-in-the-Loop tool for Chat mode, pushes questions
  via WebSocket callback and waits for user reply via asyncio.Future
- Token streaming: execute_stream() now uses chat_stream() instead of
  chat(), yielding a token-type ReActEvent for each StreamChunk that
  carries non-empty content
- _build_response_from_stream() static method constructs LLMResponse
  from accumulated stream data
- Export AskHumanTool from tools/__init__.py
- 12 new tests (7 AskHumanTool + 5 token streaming), all passing
This commit is contained in:
chiguyong 2026-06-07 23:40:43 +08:00
parent 6013d5189b
commit 7054ac02b6
5 changed files with 398 additions and 3 deletions

View File

@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Any
from agentkit.core.exceptions import TaskCancelledError, TaskTimeoutError from agentkit.core.exceptions import TaskCancelledError, TaskTimeoutError
from agentkit.core.protocol import CancellationToken from agentkit.core.protocol import CancellationToken
from agentkit.llm.gateway import LLMGateway from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMResponse
from agentkit.tools.base import Tool from agentkit.tools.base import Tool
from agentkit.telemetry.tracing import get_tracer, start_span, _OTEL_AVAILABLE from agentkit.telemetry.tracing import get_tracer, start_span, _OTEL_AVAILABLE
from agentkit.telemetry.metrics import ( from agentkit.telemetry.metrics import (
@ -59,7 +60,7 @@ class ReActResult:
class ReActEvent: class ReActEvent:
"""ReAct 执行事件""" """ReAct 执行事件"""
event_type: str # "thinking", "tool_call", "tool_result", "final_answer", "error" event_type: str # "thinking", "token", "tool_call", "tool_result", "final_answer", "error"
step: int step: int
data: dict[str, Any] = field(default_factory=dict) data: dict[str, Any] = field(default_factory=dict)
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
@ -533,14 +534,42 @@ class ReActEngine:
data={"message": f"Step {step}: Calling LLM..."}, data={"message": f"Step {step}: Calling LLM..."},
) )
# Think: call LLM # Think: call LLM (with optional token streaming)
llm_start = time.monotonic() llm_start = time.monotonic()
response = await self._llm_gateway.chat(
# Use streaming for token-by-token output
stream_content = ""
stream_usage = None
stream_tool_calls: list[Any] = []
stream_model = model
async for chunk in self._llm_gateway.chat_stream(
messages=conversation, messages=conversation,
model=model, model=model,
agent_name=agent_name, agent_name=agent_name,
task_type=task_type, task_type=task_type,
tools=tool_schemas, tools=tool_schemas,
):
if chunk.content:
stream_content += chunk.content
yield ReActEvent(
event_type="token",
step=step,
data={"content": chunk.content},
)
if chunk.usage:
stream_usage = chunk.usage
if chunk.tool_calls:
stream_tool_calls = chunk.tool_calls
if chunk.model:
stream_model = chunk.model
# Build response-like object from stream
response = self._build_response_from_stream(
content=stream_content,
tool_calls=stream_tool_calls,
usage=stream_usage,
model=stream_model,
) )
llm_duration_ms = int((time.monotonic() - llm_start) * 1000) llm_duration_ms = int((time.monotonic() - llm_start) * 1000)
@ -776,6 +805,24 @@ class ReActEngine:
schemas.append(schema) schemas.append(schema)
return schemas return schemas
@staticmethod
def _build_response_from_stream(
    content: str,
    tool_calls: list[Any],
    usage: Any,
    model: str,
) -> LLMResponse:
    """Assemble an LLMResponse out of the pieces gathered while streaming.

    Args:
        content: Text accumulated from the stream.
        tool_calls: Tool calls collected from the stream; passed through
            unchanged.
        usage: Usage object taken from the stream, or None when no chunk
            supplied one.
        model: Model name to record on the response.

    Returns:
        An LLMResponse carrying the given fields. When ``usage`` is None a
        default-constructed TokenUsage is substituted.
    """
    from agentkit.llm.protocol import LLMResponse, TokenUsage

    # Fall back to a default TokenUsage only when nothing was reported.
    effective_usage = TokenUsage() if usage is None else usage
    return LLMResponse(
        content=content,
        tool_calls=tool_calls,
        usage=effective_usage,
        model=model,
    )
def _find_tool(self, name: str, tools: list[Tool]) -> Tool | None: def _find_tool(self, name: str, tools: list[Tool]) -> Tool | None:
"""根据名称从可用工具中查找工具""" """根据名称从可用工具中查找工具"""
for tool in tools: for tool in tools:

View File

@ -9,6 +9,7 @@ from agentkit.tools.composition import SequentialChain, ParallelFanOut, DynamicS
from agentkit.tools.web_crawl import WebCrawlTool from agentkit.tools.web_crawl import WebCrawlTool
from agentkit.tools.schema_tools import SchemaExtractTool, SchemaGenerateTool from agentkit.tools.schema_tools import SchemaExtractTool, SchemaGenerateTool
from agentkit.tools.baidu_search import BaiduSearchTool from agentkit.tools.baidu_search import BaiduSearchTool
from agentkit.tools.ask_human import AskHumanTool
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor # Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
try: try:
@ -29,5 +30,6 @@ __all__ = [
"SchemaExtractTool", "SchemaExtractTool",
"SchemaGenerateTool", "SchemaGenerateTool",
"BaiduSearchTool", "BaiduSearchTool",
"AskHumanTool",
"HeadroomRetrieveTool", "HeadroomRetrieveTool",
] ]

View File

@ -0,0 +1,119 @@
"""AskHumanTool — Human-in-the-Loop tool for Chat mode.
When registered in a Chat-mode Agent, this tool allows the ReAct loop
to pause and ask the user a question, then wait for a reply via the
WebSocket connection.
"""
from __future__ import annotations
import asyncio
import logging
import uuid
from typing import Any
from agentkit.tools.base import Tool
logger = logging.getLogger(__name__)
class AskHumanTool(Tool):
    """Tool that asks the human user a question and waits for a reply.

    The tool only talks to a user after ``configure()`` has supplied both a
    shared ``pending_replies`` dict and an ``ask_callback``. Until then
    ``execute()`` logs a warning and returns a default reply, so the tool
    should not be registered where no such channel exists (Task mode).

    Usage in ReAct loop:
        The Agent calls this tool when it needs clarification or a decision
        from the user. The question is handed to ``ask_callback`` and the
        tool waits until the matching future in ``pending_replies`` is
        resolved or the timeout expires.
    """

    def __init__(self, timeout: float = 60.0):
        """Create the tool.

        Args:
            timeout: Seconds to wait for the user's reply before falling
                back to a default answer.
        """
        super().__init__(
            name="ask_human",
            description="Ask the human user a question and wait for their reply. "
            "Use this when you need clarification, a decision, or "
            "confirmation from the user before proceeding.",
        )
        self._timeout = timeout
        # Shared dict injected via configure(): request_id -> asyncio.Future
        self._pending_replies: dict[str, asyncio.Future] | None = None
        # Async callback that pushes the question to the client
        self._ask_callback: Any = None

    def configure(
        self,
        pending_replies: dict[str, asyncio.Future] | None = None,
        ask_callback: Any = None,
    ) -> None:
        """Configure the tool with its communication channels.

        Args:
            pending_replies: Dict mapping request_id to the Future that is
                resolved when the user replies. The tool adds an entry
                before asking and removes it when ``execute()`` finishes.
            ask_callback: Async callable(request_id, question, options)
                that pushes the question to the client.
        """
        self._pending_replies = pending_replies
        self._ask_callback = ask_callback

    @property
    def parameters(self) -> dict[str, Any]:
        """Return the JSON-schema description of the tool's arguments."""
        return {
            "type": "object",
            "properties": {
                "question": {
                    "type": "string",
                    "description": "The question to ask the user",
                },
                "options": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Optional list of choices for the user",
                },
            },
            "required": ["question"],
        }

    async def execute(self, **kwargs: Any) -> dict:
        """Ask the user a question and wait for their reply.

        Args:
            question: The question to ask (defaults to an empty string).
            options: Optional list of choices.

        Returns:
            Dict with a "reply" key. It holds ``str()`` of the value the
            future was resolved with; on timeout, or when the tool is not
            configured, it holds the first option if options were given,
            otherwise a fixed fallback string.

        Raises:
            Whatever ``ask_callback`` raises is propagated to the caller;
            the pending entry is still removed in that case.
        """
        question = kwargs.get("question", "")
        options = kwargs.get("options")

        # Snapshot both channels so a configure() call made while we are
        # waiting cannot turn the cleanup below into an AttributeError.
        pending = self._pending_replies
        ask_callback = self._ask_callback

        if pending is None or ask_callback is None:
            # Not configured (no chat channel) -- return a default response
            logger.warning("AskHumanTool called outside Chat mode, returning default response")
            default = options[0] if options else "confirmed"
            return {"reply": default}

        request_id = str(uuid.uuid4())[:8]

        # Register the future BEFORE calling the callback so the callback
        # (or any concurrent task) can resolve it immediately.
        future = asyncio.get_running_loop().create_future()
        pending[request_id] = future

        try:
            # The callback runs inside the try so that a failure while
            # pushing the question does not leave a stale entry behind.
            await ask_callback(request_id, question, options)
            try:
                reply = await asyncio.wait_for(future, timeout=self._timeout)
            except asyncio.TimeoutError:
                # Only a timeout of the wait itself is treated as
                # "no response"; errors from the callback propagate.
                logger.warning("AskHumanTool timeout for request %s", request_id)
                default = options[0] if options else "timeout — no response received"
                return {"reply": default}
            return {"reply": str(reply)}
        finally:
            pending.pop(request_id, None)

View File

@ -0,0 +1,100 @@
"""Tests for AskHumanTool."""
import asyncio
import pytest
from agentkit.tools.ask_human import AskHumanTool
class TestAskHumanToolBasic:
    """Checks that need no configured reply channel."""

    def test_tool_properties(self):
        """The tool exposes its name and a schema requiring only `question`."""
        ask_tool = AskHumanTool()
        assert ask_tool.name == "ask_human"
        assert "question" in str(ask_tool.parameters)
        assert ask_tool.parameters["required"] == ["question"]

    @pytest.mark.asyncio
    async def test_no_chat_mode_returns_default(self):
        """Unconfigured and without options, the reply is the fixed default."""
        outcome = await AskHumanTool().execute(question="What should I do?")
        assert outcome == {"reply": "confirmed"}

    @pytest.mark.asyncio
    async def test_no_chat_mode_with_options(self):
        """Unconfigured but with options, the first option is returned."""
        choices = ["A", "B", "C"]
        outcome = await AskHumanTool().execute(question="Choose:", options=choices)
        assert outcome == {"reply": "A"}
class TestAskHumanToolChatMode:
    """Behaviour of AskHumanTool once configure() has supplied its channels."""

    @pytest.mark.asyncio
    async def test_ask_and_receive_reply(self):
        """The question is pushed, and resolving the future yields the reply."""
        tool = AskHumanTool(timeout=5.0)
        pending: dict[str, asyncio.Future] = {}
        ask_calls: list[tuple[str, str, list[str] | None]] = []
        asked = asyncio.Event()

        async def mock_ask_callback(request_id, question, options):
            ask_calls.append((request_id, question, options))
            asked.set()

        tool.configure(pending_replies=pending, ask_callback=mock_ask_callback)

        # Start the execute in a task
        task = asyncio.create_task(
            tool.execute(question="Continue?", options=["yes", "no"])
        )

        # Wait for the callback itself to signal that the question was
        # pushed, instead of sleeping for a fixed interval and hoping the
        # task has been scheduled by then.
        await asyncio.wait_for(asked.wait(), timeout=1.0)
        assert len(ask_calls) == 1
        request_id = ask_calls[0][0]
        assert ask_calls[0][1] == "Continue?"
        assert ask_calls[0][2] == ["yes", "no"]

        # Simulate user reply
        assert request_id in pending
        pending[request_id].set_result("yes")

        result = await task
        assert result == {"reply": "yes"}

    @pytest.mark.asyncio
    async def test_timeout_returns_default(self):
        """On timeout with options, the first option is used as the reply."""
        tool = AskHumanTool(timeout=0.1)
        pending: dict[str, asyncio.Future] = {}

        async def mock_ask_callback(request_id, question, options):
            pass  # Never reply

        tool.configure(pending_replies=pending, ask_callback=mock_ask_callback)
        result = await tool.execute(question="Continue?", options=["yes", "no"])
        assert result == {"reply": "yes"}
        # The timed-out request must not linger in the shared dict
        assert not pending

    @pytest.mark.asyncio
    async def test_timeout_no_options(self):
        """On timeout without options, the reply mentions the timeout."""
        tool = AskHumanTool(timeout=0.1)
        pending: dict[str, asyncio.Future] = {}

        async def mock_ask_callback(request_id, question, options):
            pass

        tool.configure(pending_replies=pending, ask_callback=mock_ask_callback)
        result = await tool.execute(question="Continue?")
        assert "timeout" in result["reply"]
        assert not pending

    @pytest.mark.asyncio
    async def test_cleanup_on_reply(self):
        """A reply delivered from inside the callback is returned and cleaned up."""
        tool = AskHumanTool(timeout=5.0)
        pending: dict[str, asyncio.Future] = {}

        async def mock_ask_callback(request_id, question, options):
            # Reply straight away: the future is registered before the
            # callback runs, so it can be resolved from here.
            pending[request_id].set_result("ok")

        tool.configure(pending_replies=pending, ask_callback=mock_ask_callback)
        result = await tool.execute(question="Test?")
        assert result == {"reply": "ok"}
        # Pending should be cleaned up
        assert len(pending) == 0

View File

@ -0,0 +1,127 @@
"""Tests for ReAct engine token streaming via execute_stream()."""
import pytest
from unittest.mock import AsyncMock, MagicMock
from agentkit.core.react import ReActEngine, ReActEvent
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMResponse, StreamChunk, TokenUsage
def _make_stream_chunks(
    content_parts: list[str],
    model: str = "test-model",
    usage: TokenUsage | None = None,
) -> list[StreamChunk]:
    """Simulate a streaming response as a list of StreamChunk objects.

    One content chunk is produced per entry in ``content_parts``, followed
    by a final empty-content chunk that carries the usage figures and has
    ``is_final=True``. When ``usage`` is falsy, a TokenUsage with
    10 prompt tokens and 20 completion tokens is used.
    """
    stream = [StreamChunk(content=text, model=model) for text in content_parts]
    # Closing chunk: no text, only usage and the final marker
    closing = StreamChunk(
        content="",
        model=model,
        usage=usage or TokenUsage(prompt_tokens=10, completion_tokens=20),
        is_final=True,
    )
    stream.append(closing)
    return stream
def _make_stream_gateway(chunks_list: list[list[StreamChunk]]) -> MagicMock:
    """Create a mock LLMGateway whose chat_stream replays scripted responses.

    Args:
        chunks_list: One list of StreamChunk objects per expected
            ``chat_stream`` call, in call order. The list is consumed:
            each call removes the response it replays.

    Returns:
        A ``MagicMock(spec=LLMGateway)`` whose ``chat_stream`` is an async
        generator function accepting arbitrary keyword arguments.
    """
    gateway = MagicMock(spec=LLMGateway)

    async def _stream(**kwargs):
        # Replay exactly one scripted response per call. Popping up front
        # avoids mutating chunks_list while iterating over it, which made
        # the previous loop skip or replay responses once more than one
        # was scripted. With nothing left, the stream is simply empty.
        if not chunks_list:
            return
        for chunk in chunks_list.pop(0):
            yield chunk

    gateway.chat_stream = _stream
    return gateway
class TestReActTokenStreaming:
    """Test that execute_stream() yields token events from chat_stream()."""

    @staticmethod
    async def _collect_events(content_parts: list[str], prompt: str) -> list[ReActEvent]:
        """Run execute_stream() against one scripted response and collect its events."""
        gateway = _make_stream_gateway([_make_stream_chunks(content_parts)])
        engine = ReActEngine(llm_gateway=gateway)
        return [
            event
            async for event in engine.execute_stream(
                messages=[{"role": "user", "content": prompt}],
            )
        ]

    @pytest.mark.asyncio
    async def test_yields_token_events(self):
        """execute_stream should yield one token event per content chunk."""
        events = await self._collect_events(["Hello", " world", "!"], "Hi")

        token_events = [e for e in events if e.event_type == "token"]
        assert len(token_events) == 3
        assert token_events[0].data["content"] == "Hello"
        assert token_events[1].data["content"] == " world"
        assert token_events[2].data["content"] == "!"

    @pytest.mark.asyncio
    async def test_final_answer_after_streaming(self):
        """After streaming tokens, a final_answer event should be yielded."""
        events = await self._collect_events(["The answer is 42"], "What?")

        final_events = [e for e in events if e.event_type == "final_answer"]
        assert len(final_events) == 1
        assert "42" in final_events[0].data.get("output", "")

    @pytest.mark.asyncio
    async def test_token_events_have_correct_step(self):
        """Token events should carry the current step number."""
        events = await self._collect_events(["Hi"], "Hi")

        token_events = [e for e in events if e.event_type == "token"]
        # Guard against the loop below passing vacuously when no token
        # event was produced at all.
        assert token_events
        for te in token_events:
            assert te.step >= 1

    def test_build_response_from_stream(self):
        """_build_response_from_stream should construct a valid LLMResponse."""
        usage = TokenUsage(prompt_tokens=5, completion_tokens=10)
        response = ReActEngine._build_response_from_stream(
            content="Hello world",
            tool_calls=[],
            usage=usage,
            model="test-model",
        )
        assert response.content == "Hello world"
        assert response.model == "test-model"
        assert response.usage.prompt_tokens == 5
        assert response.usage.completion_tokens == 10
        assert not response.has_tool_calls

    def test_build_response_from_stream_no_usage(self):
        """_build_response_from_stream should handle None usage gracefully."""
        response = ReActEngine._build_response_from_stream(
            content="test",
            tool_calls=[],
            usage=None,
            model="test-model",
        )
        assert response.content == "test"
        assert response.usage.prompt_tokens == 0