feat(tools): add AskHumanTool + token streaming in ReAct execute_stream
- AskHumanTool: Human-in-the-Loop tool for Chat mode; pushes questions to the client via a WebSocket callback and waits for the user's reply via an asyncio.Future
- Token streaming: execute_stream() now uses chat_stream() instead of chat(), yielding a token-type ReActEvent for each StreamChunk that carries content
- _build_response_from_stream(): static method that constructs an LLMResponse from accumulated stream data
- Export AskHumanTool from tools/__init__.py
- 12 new tests (7 AskHumanTool + 5 token streaming), all passing
This commit is contained in:
parent
6013d5189b
commit
7054ac02b6
|
|
@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Any
|
|||
from agentkit.core.exceptions import TaskCancelledError, TaskTimeoutError
|
||||
from agentkit.core.protocol import CancellationToken
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.llm.protocol import LLMResponse
|
||||
from agentkit.tools.base import Tool
|
||||
from agentkit.telemetry.tracing import get_tracer, start_span, _OTEL_AVAILABLE
|
||||
from agentkit.telemetry.metrics import (
|
||||
|
|
@ -59,7 +60,7 @@ class ReActResult:
|
|||
class ReActEvent:
|
||||
"""ReAct 执行事件"""
|
||||
|
||||
event_type: str # "thinking", "tool_call", "tool_result", "final_answer", "error"
|
||||
event_type: str # "thinking", "token", "tool_call", "tool_result", "final_answer", "error"
|
||||
step: int
|
||||
data: dict[str, Any] = field(default_factory=dict)
|
||||
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
|
|
@ -533,14 +534,42 @@ class ReActEngine:
|
|||
data={"message": f"Step {step}: Calling LLM..."},
|
||||
)
|
||||
|
||||
# Think: call LLM
|
||||
# Think: call LLM (with optional token streaming)
|
||||
llm_start = time.monotonic()
|
||||
response = await self._llm_gateway.chat(
|
||||
|
||||
# Use streaming for token-by-token output
|
||||
stream_content = ""
|
||||
stream_usage = None
|
||||
stream_tool_calls: list[Any] = []
|
||||
stream_model = model
|
||||
|
||||
async for chunk in self._llm_gateway.chat_stream(
|
||||
messages=conversation,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
tools=tool_schemas,
|
||||
):
|
||||
if chunk.content:
|
||||
stream_content += chunk.content
|
||||
yield ReActEvent(
|
||||
event_type="token",
|
||||
step=step,
|
||||
data={"content": chunk.content},
|
||||
)
|
||||
if chunk.usage:
|
||||
stream_usage = chunk.usage
|
||||
if chunk.tool_calls:
|
||||
stream_tool_calls = chunk.tool_calls
|
||||
if chunk.model:
|
||||
stream_model = chunk.model
|
||||
|
||||
# Build response-like object from stream
|
||||
response = self._build_response_from_stream(
|
||||
content=stream_content,
|
||||
tool_calls=stream_tool_calls,
|
||||
usage=stream_usage,
|
||||
model=stream_model,
|
||||
)
|
||||
llm_duration_ms = int((time.monotonic() - llm_start) * 1000)
|
||||
|
||||
|
|
@ -776,6 +805,24 @@ class ReActEngine:
|
|||
schemas.append(schema)
|
||||
return schemas
|
||||
|
||||
@staticmethod
def _build_response_from_stream(
    content: str,
    tool_calls: list[Any],
    usage: Any,
    model: str,
) -> LLMResponse:
    """Assemble an ``LLMResponse`` from values gathered while streaming.

    Args:
        content: Text accumulated from the stream.
        tool_calls: Tool calls collected from the stream (may be empty).
        usage: Usage object taken from the stream, or ``None`` if no chunk
            supplied one.
        model: Model name to record on the response.

    Returns:
        An ``LLMResponse`` carrying the given values; when ``usage`` is
        ``None`` a default-constructed ``TokenUsage`` is used instead.
    """
    from agentkit.llm.protocol import LLMResponse, TokenUsage

    # Fall back to an empty usage record so the response always has one.
    effective_usage = TokenUsage() if usage is None else usage
    return LLMResponse(
        content=content,
        tool_calls=tool_calls,
        usage=effective_usage,
        model=model,
    )
|
||||
|
||||
def _find_tool(self, name: str, tools: list[Tool]) -> Tool | None:
|
||||
"""根据名称从可用工具中查找工具"""
|
||||
for tool in tools:
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from agentkit.tools.composition import SequentialChain, ParallelFanOut, DynamicS
|
|||
from agentkit.tools.web_crawl import WebCrawlTool
|
||||
from agentkit.tools.schema_tools import SchemaExtractTool, SchemaGenerateTool
|
||||
from agentkit.tools.baidu_search import BaiduSearchTool
|
||||
from agentkit.tools.ask_human import AskHumanTool
|
||||
|
||||
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
|
||||
try:
|
||||
|
|
@ -29,5 +30,6 @@ __all__ = [
|
|||
"SchemaExtractTool",
|
||||
"SchemaGenerateTool",
|
||||
"BaiduSearchTool",
|
||||
"AskHumanTool",
|
||||
"HeadroomRetrieveTool",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,119 @@
|
|||
"""AskHumanTool — Human-in-the-Loop tool for Chat mode.
|
||||
|
||||
When registered in a Chat-mode Agent, this tool allows the ReAct loop
|
||||
to pause and ask the user a question, then wait for a reply via the
|
||||
WebSocket connection.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AskHumanTool(Tool):
    """Tool that asks the human user a question and waits for a reply.

    Only functional in Chat mode where a WebSocket connection exists.
    In Task mode, this tool should not be registered.

    Usage in ReAct loop:
        The Agent calls this tool when it needs clarification or
        a decision from the user. The question is pushed to the
        client via the configured callback, and the tool waits until
        the user replies or the timeout expires.
    """

    def __init__(self, timeout: float = 60.0):
        """Create the tool.

        Args:
            timeout: Seconds to wait for the user's reply before falling
                back to a default answer.
        """
        super().__init__(
            name="ask_human",
            description="Ask the human user a question and wait for their reply. "
            "Use this when you need clarification, a decision, or "
            "confirmation from the user before proceeding.",
        )
        self._timeout = timeout
        # Shared dict injected by the Chat WebSocket handler:
        # request_id -> asyncio.Future
        self._pending_replies: dict[str, asyncio.Future] | None = None
        # Async callback used to push the question to the client
        self._ask_callback: Any = None

    def configure(
        self,
        pending_replies: dict[str, asyncio.Future] | None = None,
        ask_callback: Any = None,
    ) -> None:
        """Configure the tool with WebSocket communication channels.

        Args:
            pending_replies: Dict mapping request_id to the Future that will
                be resolved when the user replies.
            ask_callback: Async callable(request_id, question, options)
                that pushes the question to the client.
        """
        self._pending_replies = pending_replies
        self._ask_callback = ask_callback

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON-schema description of the arguments accepted by ``execute``."""
        return {
            "type": "object",
            "properties": {
                "question": {
                    "type": "string",
                    "description": "The question to ask the user",
                },
                "options": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Optional list of choices for the user",
                },
            },
            "required": ["question"],
        }

    async def execute(self, **kwargs: Any) -> dict:
        """Ask the user a question and wait for their reply.

        Args:
            question: The question to ask (defaults to an empty string).
            options: Optional list of choices.

        Returns:
            Dict with a "reply" key. The value is the user's response
            converted to ``str``; when the tool is not configured, or the
            wait times out, it is the first option if options were given,
            otherwise a fixed fallback string.

        Raises:
            Any exception raised by the ask callback is propagated after the
            pending entry has been removed.
        """
        question = kwargs.get("question", "")
        options = kwargs.get("options")

        # Bind the channels once so a concurrent configure() call cannot
        # swap them out (or reset them to None) while we are waiting.
        pending = self._pending_replies
        ask_callback = self._ask_callback

        if pending is None or ask_callback is None:
            # Not in Chat mode — return a default response
            logger.warning("AskHumanTool called outside Chat mode, returning default response")
            default = options[0] if options else "confirmed"
            return {"reply": default}

        # NOTE(review): only 8 hex characters of the UUID are kept, so ids are
        # short but not collision-proof across many concurrent requests.
        request_id = str(uuid.uuid4())[:8]

        # Create and register the future BEFORE calling the callback so the
        # callback (or any concurrent task) can resolve it immediately.
        # get_running_loop() is the correct call inside a coroutine.
        future = asyncio.get_running_loop().create_future()
        pending[request_id] = future

        try:
            # Push question to client. This is inside the try block so that a
            # failing callback does not leave a stale entry in the shared dict.
            await ask_callback(request_id, question, options)

            try:
                reply = await asyncio.wait_for(future, timeout=self._timeout)
            except asyncio.TimeoutError:
                logger.warning("AskHumanTool timeout for request %s", request_id)
                default = options[0] if options else "timeout — no response received"
                return {"reply": default}

            return {"reply": str(reply)}
        finally:
            # Always drop the entry: on reply, on timeout, and on error.
            pending.pop(request_id, None)
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
"""Tests for AskHumanTool."""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
|
||||
from agentkit.tools.ask_human import AskHumanTool
|
||||
|
||||
|
||||
class TestAskHumanToolBasic:
    """Checks that need no configured chat channel."""

    def test_tool_properties(self):
        """The tool exposes its name and a schema that requires ``question``."""
        ask_tool = AskHumanTool()
        assert ask_tool.name == "ask_human"

        schema = ask_tool.parameters
        assert "question" in str(schema)
        assert schema["required"] == ["question"]

    @pytest.mark.asyncio
    async def test_no_chat_mode_returns_default(self):
        """Without configuration and without options the reply is "confirmed"."""
        ask_tool = AskHumanTool()
        outcome = await ask_tool.execute(question="What should I do?")
        assert outcome == {"reply": "confirmed"}

    @pytest.mark.asyncio
    async def test_no_chat_mode_with_options(self):
        """Without configuration the first option is returned as the reply."""
        ask_tool = AskHumanTool()
        choices = ["A", "B", "C"]
        outcome = await ask_tool.execute(question="Choose:", options=choices)
        assert outcome == {"reply": "A"}
|
||||
|
||||
|
||||
class TestAskHumanToolChatMode:
    """Behaviour of ``AskHumanTool.execute`` after ``configure`` has supplied
    a pending-replies dict and an ask callback."""

    @pytest.mark.asyncio
    async def test_ask_and_receive_reply(self):
        """The question is pushed through the callback and the reply returned."""
        tool = AskHumanTool(timeout=5.0)
        pending: dict[str, asyncio.Future] = {}
        ask_calls: list[tuple[str, str, list[str] | None]] = []
        # Signalled by the callback so the test does not depend on sleep timing.
        asked = asyncio.Event()

        async def mock_ask_callback(request_id, question, options):
            ask_calls.append((request_id, question, options))
            asked.set()

        tool.configure(pending_replies=pending, ask_callback=mock_ask_callback)

        # Start the execute in a task
        task = asyncio.create_task(
            tool.execute(question="Continue?", options=["yes", "no"])
        )

        # Wait (bounded) for the ask to be pushed
        await asyncio.wait_for(asked.wait(), timeout=1.0)
        assert len(ask_calls) == 1
        request_id = ask_calls[0][0]
        assert ask_calls[0][1] == "Continue?"
        assert ask_calls[0][2] == ["yes", "no"]

        # Simulate user reply
        assert request_id in pending
        pending[request_id].set_result("yes")

        result = await task
        assert result == {"reply": "yes"}

    @pytest.mark.asyncio
    async def test_timeout_returns_default(self):
        """On timeout with options, the first option is returned."""
        tool = AskHumanTool(timeout=0.1)
        pending: dict[str, asyncio.Future] = {}

        async def mock_ask_callback(request_id, question, options):
            pass  # Never reply

        tool.configure(pending_replies=pending, ask_callback=mock_ask_callback)

        result = await tool.execute(question="Continue?", options=["yes", "no"])
        assert result == {"reply": "yes"}
        # The timed-out request must not linger in the shared dict.
        assert len(pending) == 0

    @pytest.mark.asyncio
    async def test_timeout_no_options(self):
        """On timeout without options, the reply mentions the timeout."""
        tool = AskHumanTool(timeout=0.1)
        pending: dict[str, asyncio.Future] = {}

        async def mock_ask_callback(request_id, question, options):
            pass  # Never reply

        tool.configure(pending_replies=pending, ask_callback=mock_ask_callback)

        result = await tool.execute(question="Continue?")
        assert "timeout" in result["reply"]
        assert len(pending) == 0

    @pytest.mark.asyncio
    async def test_cleanup_on_reply(self):
        """A reply delivered from inside the callback is returned and cleaned up."""
        tool = AskHumanTool(timeout=5.0)
        pending: dict[str, asyncio.Future] = {}

        async def mock_ask_callback(request_id, question, options):
            # Reply immediately: the future is registered before the
            # callback is invoked, so it can be resolved right here.
            pending[request_id].set_result("ok")

        tool.configure(pending_replies=pending, ask_callback=mock_ask_callback)

        result = await tool.execute(question="Test?")
        assert result == {"reply": "ok"}
        # Pending should be cleaned up
        assert len(pending) == 0
|
||||
|
|
@ -0,0 +1,127 @@
|
|||
"""Tests for ReAct engine token streaming via execute_stream()."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from agentkit.core.react import ReActEngine, ReActEvent
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.llm.protocol import LLMResponse, StreamChunk, TokenUsage
|
||||
|
||||
|
||||
def _make_stream_chunks(
    content_parts: list[str],
    model: str = "test-model",
    usage: TokenUsage | None = None,
) -> list[StreamChunk]:
    """Return ``StreamChunk`` objects imitating one streamed LLM response.

    One content chunk is produced per entry of ``content_parts``, followed by
    a closing chunk with empty content, ``is_final=True`` and the usage
    record (``usage`` if truthy, otherwise 10 prompt / 20 completion tokens).
    """
    stream = [StreamChunk(content=text, model=model) for text in content_parts]

    # Closing chunk: carries the usage and marks the end of the stream.
    closing_usage = usage or TokenUsage(prompt_tokens=10, completion_tokens=20)
    stream.append(
        StreamChunk(
            content="",
            model=model,
            usage=closing_usage,
            is_final=True,
        )
    )
    return stream
|
||||
|
||||
|
||||
def _make_stream_gateway(chunks_list: list[list[StreamChunk]]) -> MagicMock:
    """Create a mock LLMGateway whose ``chat_stream`` replays scripted chunks.

    Each call to ``chat_stream`` removes the first inner list from
    ``chunks_list`` and yields its chunks in order, so successive calls replay
    successive scripted responses. When ``chunks_list`` is exhausted the
    stream yields nothing.

    Args:
        chunks_list: One list of chunks per expected ``chat_stream`` call.
            The list is consumed (mutated) as calls are made.

    Returns:
        A ``MagicMock`` with ``spec=LLMGateway`` and ``chat_stream`` replaced
        by the replaying async generator.
    """
    gateway = MagicMock(spec=LLMGateway)

    async def _stream(**kwargs):
        if not chunks_list:
            return
        # Take exactly one scripted response per call. Popping up front avoids
        # mutating the list while iterating over it, which previously made a
        # single call skip ahead and emit chunks of a later response too.
        for chunk in chunks_list.pop(0):
            yield chunk

    gateway.chat_stream = _stream
    return gateway
|
||||
|
||||
|
||||
class TestReActTokenStreaming:
    """Test that execute_stream() yields token events from chat_stream()."""

    @staticmethod
    async def _collect_events(engine, prompt: str) -> list:
        """Run ``engine.execute_stream`` for one user message and return all events."""
        return [
            event
            async for event in engine.execute_stream(
                messages=[{"role": "user", "content": prompt}],
            )
        ]

    @pytest.mark.asyncio
    async def test_yields_token_events(self):
        """execute_stream should yield one token event per content chunk."""
        chunks = _make_stream_chunks(["Hello", " world", "!"])
        gateway = _make_stream_gateway([chunks])
        engine = ReActEngine(llm_gateway=gateway)

        events = await self._collect_events(engine, "Hi")

        token_events = [e for e in events if e.event_type == "token"]
        assert len(token_events) == 3
        assert token_events[0].data["content"] == "Hello"
        assert token_events[1].data["content"] == " world"
        assert token_events[2].data["content"] == "!"

    @pytest.mark.asyncio
    async def test_final_answer_after_streaming(self):
        """After streaming tokens, a final_answer event should be yielded."""
        chunks = _make_stream_chunks(["The answer is 42"])
        gateway = _make_stream_gateway([chunks])
        engine = ReActEngine(llm_gateway=gateway)

        events = await self._collect_events(engine, "What?")

        final_events = [e for e in events if e.event_type == "final_answer"]
        assert len(final_events) == 1
        assert "42" in final_events[0].data.get("output", "")

    @pytest.mark.asyncio
    async def test_token_events_have_correct_step(self):
        """Token events should carry the current step number."""
        chunks = _make_stream_chunks(["Hi"])
        gateway = _make_stream_gateway([chunks])
        engine = ReActEngine(llm_gateway=gateway)

        events = await self._collect_events(engine, "Hi")

        token_events = [e for e in events if e.event_type == "token"]
        # Guard against a vacuous pass: the loop below checks nothing if no
        # token events were produced at all.
        assert token_events
        for te in token_events:
            assert te.step >= 1

    def test_build_response_from_stream(self):
        """_build_response_from_stream should construct a valid LLMResponse."""
        # Plain synchronous test: the helper under test is not a coroutine.
        usage = TokenUsage(prompt_tokens=5, completion_tokens=10)
        response = ReActEngine._build_response_from_stream(
            content="Hello world",
            tool_calls=[],
            usage=usage,
            model="test-model",
        )
        assert response.content == "Hello world"
        assert response.model == "test-model"
        assert response.usage.prompt_tokens == 5
        assert response.usage.completion_tokens == 10
        assert not response.has_tool_calls

    def test_build_response_from_stream_no_usage(self):
        """_build_response_from_stream should handle None usage gracefully."""
        response = ReActEngine._build_response_from_stream(
            content="test",
            tool_calls=[],
            usage=None,
            model="test-model",
        )
        assert response.content == "test"
        assert response.usage.prompt_tokens == 0
|
||||
Loading…
Reference in New Issue