feat(tools): add AskHumanTool + token streaming in ReAct execute_stream
- AskHumanTool: Human-in-the-Loop tool for Chat mode, pushes questions via WebSocket callback and waits for user reply via asyncio.Future - Token streaming: execute_stream() now uses chat_stream() instead of chat(), yielding token-type ReActEvents for each StreamChunk - _build_response_from_stream() static method constructs LLMResponse from accumulated stream data - Export AskHumanTool from tools/__init__.py - 12 new tests (7 AskHumanTool + 5 token streaming), all passing
This commit is contained in:
parent
6013d5189b
commit
7054ac02b6
|
|
@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Any
|
||||||
from agentkit.core.exceptions import TaskCancelledError, TaskTimeoutError
|
from agentkit.core.exceptions import TaskCancelledError, TaskTimeoutError
|
||||||
from agentkit.core.protocol import CancellationToken
|
from agentkit.core.protocol import CancellationToken
|
||||||
from agentkit.llm.gateway import LLMGateway
|
from agentkit.llm.gateway import LLMGateway
|
||||||
|
from agentkit.llm.protocol import LLMResponse
|
||||||
from agentkit.tools.base import Tool
|
from agentkit.tools.base import Tool
|
||||||
from agentkit.telemetry.tracing import get_tracer, start_span, _OTEL_AVAILABLE
|
from agentkit.telemetry.tracing import get_tracer, start_span, _OTEL_AVAILABLE
|
||||||
from agentkit.telemetry.metrics import (
|
from agentkit.telemetry.metrics import (
|
||||||
|
|
@ -59,7 +60,7 @@ class ReActResult:
|
||||||
class ReActEvent:
|
class ReActEvent:
|
||||||
"""ReAct 执行事件"""
|
"""ReAct 执行事件"""
|
||||||
|
|
||||||
event_type: str # "thinking", "tool_call", "tool_result", "final_answer", "error"
|
event_type: str # "thinking", "token", "tool_call", "tool_result", "final_answer", "error"
|
||||||
step: int
|
step: int
|
||||||
data: dict[str, Any] = field(default_factory=dict)
|
data: dict[str, Any] = field(default_factory=dict)
|
||||||
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||||
|
|
@ -533,14 +534,42 @@ class ReActEngine:
|
||||||
data={"message": f"Step {step}: Calling LLM..."},
|
data={"message": f"Step {step}: Calling LLM..."},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Think: call LLM
|
# Think: call LLM (with optional token streaming)
|
||||||
llm_start = time.monotonic()
|
llm_start = time.monotonic()
|
||||||
response = await self._llm_gateway.chat(
|
|
||||||
|
# Use streaming for token-by-token output
|
||||||
|
stream_content = ""
|
||||||
|
stream_usage = None
|
||||||
|
stream_tool_calls: list[Any] = []
|
||||||
|
stream_model = model
|
||||||
|
|
||||||
|
async for chunk in self._llm_gateway.chat_stream(
|
||||||
messages=conversation,
|
messages=conversation,
|
||||||
model=model,
|
model=model,
|
||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
task_type=task_type,
|
task_type=task_type,
|
||||||
tools=tool_schemas,
|
tools=tool_schemas,
|
||||||
|
):
|
||||||
|
if chunk.content:
|
||||||
|
stream_content += chunk.content
|
||||||
|
yield ReActEvent(
|
||||||
|
event_type="token",
|
||||||
|
step=step,
|
||||||
|
data={"content": chunk.content},
|
||||||
|
)
|
||||||
|
if chunk.usage:
|
||||||
|
stream_usage = chunk.usage
|
||||||
|
if chunk.tool_calls:
|
||||||
|
stream_tool_calls = chunk.tool_calls
|
||||||
|
if chunk.model:
|
||||||
|
stream_model = chunk.model
|
||||||
|
|
||||||
|
# Build response-like object from stream
|
||||||
|
response = self._build_response_from_stream(
|
||||||
|
content=stream_content,
|
||||||
|
tool_calls=stream_tool_calls,
|
||||||
|
usage=stream_usage,
|
||||||
|
model=stream_model,
|
||||||
)
|
)
|
||||||
llm_duration_ms = int((time.monotonic() - llm_start) * 1000)
|
llm_duration_ms = int((time.monotonic() - llm_start) * 1000)
|
||||||
|
|
||||||
|
|
@ -776,6 +805,24 @@ class ReActEngine:
|
||||||
schemas.append(schema)
|
schemas.append(schema)
|
||||||
return schemas
|
return schemas
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_response_from_stream(
|
||||||
|
content: str,
|
||||||
|
tool_calls: list[Any],
|
||||||
|
usage: Any,
|
||||||
|
model: str,
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""Build an LLMResponse from accumulated stream chunks."""
|
||||||
|
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||||||
|
if usage is None:
|
||||||
|
usage = TokenUsage()
|
||||||
|
return LLMResponse(
|
||||||
|
content=content,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
usage=usage,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
|
||||||
def _find_tool(self, name: str, tools: list[Tool]) -> Tool | None:
|
def _find_tool(self, name: str, tools: list[Tool]) -> Tool | None:
|
||||||
"""根据名称从可用工具中查找工具"""
|
"""根据名称从可用工具中查找工具"""
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from agentkit.tools.composition import SequentialChain, ParallelFanOut, DynamicS
|
||||||
from agentkit.tools.web_crawl import WebCrawlTool
|
from agentkit.tools.web_crawl import WebCrawlTool
|
||||||
from agentkit.tools.schema_tools import SchemaExtractTool, SchemaGenerateTool
|
from agentkit.tools.schema_tools import SchemaExtractTool, SchemaGenerateTool
|
||||||
from agentkit.tools.baidu_search import BaiduSearchTool
|
from agentkit.tools.baidu_search import BaiduSearchTool
|
||||||
|
from agentkit.tools.ask_human import AskHumanTool
|
||||||
|
|
||||||
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
|
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
|
||||||
try:
|
try:
|
||||||
|
|
@ -29,5 +30,6 @@ __all__ = [
|
||||||
"SchemaExtractTool",
|
"SchemaExtractTool",
|
||||||
"SchemaGenerateTool",
|
"SchemaGenerateTool",
|
||||||
"BaiduSearchTool",
|
"BaiduSearchTool",
|
||||||
|
"AskHumanTool",
|
||||||
"HeadroomRetrieveTool",
|
"HeadroomRetrieveTool",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,119 @@
|
||||||
|
"""AskHumanTool — Human-in-the-Loop tool for Chat mode.
|
||||||
|
|
||||||
|
When registered in a Chat-mode Agent, this tool allows the ReAct loop
|
||||||
|
to pause and ask the user a question, then wait for a reply via the
|
||||||
|
WebSocket connection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AskHumanTool(Tool):
|
||||||
|
"""Tool that asks the human user a question and waits for a reply.
|
||||||
|
|
||||||
|
Only functional in Chat mode where a WebSocket connection exists.
|
||||||
|
In Task mode, this tool should not be registered.
|
||||||
|
|
||||||
|
Usage in ReAct loop:
|
||||||
|
The Agent calls this tool when it needs clarification or
|
||||||
|
a decision from the user. The question is pushed to the
|
||||||
|
client via WebSocket, and the tool blocks until the user
|
||||||
|
replies or a timeout expires.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, timeout: float = 60.0):
|
||||||
|
super().__init__(
|
||||||
|
name="ask_human",
|
||||||
|
description="Ask the human user a question and wait for their reply. "
|
||||||
|
"Use this when you need clarification, a decision, or "
|
||||||
|
"confirmation from the user before proceeding.",
|
||||||
|
)
|
||||||
|
self._timeout = timeout
|
||||||
|
# Shared dict injected by the Chat WebSocket handler:
|
||||||
|
# request_id -> asyncio.Future
|
||||||
|
self._pending_replies: dict[str, asyncio.Future] | None = None
|
||||||
|
# Callback to push question to client
|
||||||
|
self._ask_callback: Any = None
|
||||||
|
|
||||||
|
def configure(
|
||||||
|
self,
|
||||||
|
pending_replies: dict[str, asyncio.Future] | None = None,
|
||||||
|
ask_callback: Any = None,
|
||||||
|
) -> None:
|
||||||
|
"""Configure the tool with WebSocket communication channels.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pending_replies: Dict mapping request_id to Future that will
|
||||||
|
be resolved when the user replies.
|
||||||
|
ask_callback: Async callable(request_id, question, options)
|
||||||
|
that pushes the question to the client.
|
||||||
|
"""
|
||||||
|
self._pending_replies = pending_replies
|
||||||
|
self._ask_callback = ask_callback
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"question": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The question to ask the user",
|
||||||
|
},
|
||||||
|
"options": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "Optional list of choices for the user",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["question"],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(self, **kwargs: Any) -> dict:
|
||||||
|
"""Ask the user a question and wait for their reply.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
question: The question to ask.
|
||||||
|
options: Optional list of choices.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with "reply" key containing the user's response.
|
||||||
|
"""
|
||||||
|
question = kwargs.get("question", "")
|
||||||
|
options = kwargs.get("options")
|
||||||
|
|
||||||
|
if self._pending_replies is None or self._ask_callback is None:
|
||||||
|
# Not in Chat mode — return a default response
|
||||||
|
logger.warning("AskHumanTool called outside Chat mode, returning default response")
|
||||||
|
default = options[0] if options else "confirmed"
|
||||||
|
return {"reply": default}
|
||||||
|
|
||||||
|
request_id = str(uuid.uuid4())[:8]
|
||||||
|
|
||||||
|
# Create and register future BEFORE calling callback so the
|
||||||
|
# callback (or any concurrent task) can resolve it immediately.
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
future = loop.create_future()
|
||||||
|
self._pending_replies[request_id] = future
|
||||||
|
|
||||||
|
# Push question to client
|
||||||
|
await self._ask_callback(request_id, question, options)
|
||||||
|
|
||||||
|
try:
|
||||||
|
reply = await asyncio.wait_for(future, timeout=self._timeout)
|
||||||
|
return {"reply": str(reply)}
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning(f"AskHumanTool timeout for request {request_id}")
|
||||||
|
default = options[0] if options else "timeout — no response received"
|
||||||
|
return {"reply": default}
|
||||||
|
finally:
|
||||||
|
self._pending_replies.pop(request_id, None)
|
||||||
|
|
@ -0,0 +1,100 @@
|
||||||
|
"""Tests for AskHumanTool."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.tools.ask_human import AskHumanTool
|
||||||
|
|
||||||
|
|
||||||
|
class TestAskHumanToolBasic:
|
||||||
|
def test_tool_properties(self):
|
||||||
|
tool = AskHumanTool()
|
||||||
|
assert tool.name == "ask_human"
|
||||||
|
assert "question" in str(tool.parameters)
|
||||||
|
assert tool.parameters["required"] == ["question"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_chat_mode_returns_default(self):
|
||||||
|
tool = AskHumanTool()
|
||||||
|
result = await tool.execute(question="What should I do?")
|
||||||
|
assert result == {"reply": "confirmed"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_chat_mode_with_options(self):
|
||||||
|
tool = AskHumanTool()
|
||||||
|
result = await tool.execute(question="Choose:", options=["A", "B", "C"])
|
||||||
|
assert result == {"reply": "A"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestAskHumanToolChatMode:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ask_and_receive_reply(self):
|
||||||
|
tool = AskHumanTool(timeout=5.0)
|
||||||
|
pending: dict[str, asyncio.Future] = {}
|
||||||
|
ask_calls: list[tuple[str, str, list[str] | None]] = []
|
||||||
|
|
||||||
|
async def mock_ask_callback(request_id, question, options):
|
||||||
|
ask_calls.append((request_id, question, options))
|
||||||
|
|
||||||
|
tool.configure(pending_replies=pending, ask_callback=mock_ask_callback)
|
||||||
|
|
||||||
|
# Start the execute in a task
|
||||||
|
task = asyncio.create_task(
|
||||||
|
tool.execute(question="Continue?", options=["yes", "no"])
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait for the ask to be pushed
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
assert len(ask_calls) == 1
|
||||||
|
request_id = ask_calls[0][0]
|
||||||
|
assert ask_calls[0][1] == "Continue?"
|
||||||
|
assert ask_calls[0][2] == ["yes", "no"]
|
||||||
|
|
||||||
|
# Simulate user reply
|
||||||
|
assert request_id in pending
|
||||||
|
pending[request_id].set_result("yes")
|
||||||
|
|
||||||
|
result = await task
|
||||||
|
assert result == {"reply": "yes"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_timeout_returns_default(self):
|
||||||
|
tool = AskHumanTool(timeout=0.1)
|
||||||
|
pending: dict[str, asyncio.Future] = {}
|
||||||
|
|
||||||
|
async def mock_ask_callback(request_id, question, options):
|
||||||
|
pass # Never reply
|
||||||
|
|
||||||
|
tool.configure(pending_replies=pending, ask_callback=mock_ask_callback)
|
||||||
|
|
||||||
|
result = await tool.execute(question="Continue?", options=["yes", "no"])
|
||||||
|
assert result == {"reply": "yes"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_timeout_no_options(self):
|
||||||
|
tool = AskHumanTool(timeout=0.1)
|
||||||
|
pending: dict[str, asyncio.Future] = {}
|
||||||
|
|
||||||
|
async def mock_ask_callback(request_id, question, options):
|
||||||
|
pass
|
||||||
|
|
||||||
|
tool.configure(pending_replies=pending, ask_callback=mock_ask_callback)
|
||||||
|
|
||||||
|
result = await tool.execute(question="Continue?")
|
||||||
|
assert "timeout" in result["reply"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cleanup_on_reply(self):
|
||||||
|
tool = AskHumanTool(timeout=5.0)
|
||||||
|
pending: dict[str, asyncio.Future] = {}
|
||||||
|
|
||||||
|
async def mock_ask_callback(request_id, question, options):
|
||||||
|
# Immediately reply
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
pending[request_id].set_result("ok")
|
||||||
|
|
||||||
|
tool.configure(pending_replies=pending, ask_callback=mock_ask_callback)
|
||||||
|
|
||||||
|
await tool.execute(question="Test?")
|
||||||
|
# Pending should be cleaned up
|
||||||
|
assert len(pending) == 0
|
||||||
|
|
@ -0,0 +1,127 @@
|
||||||
|
"""Tests for ReAct engine token streaming via execute_stream()."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
from agentkit.core.react import ReActEngine, ReActEvent
|
||||||
|
from agentkit.llm.gateway import LLMGateway
|
||||||
|
from agentkit.llm.protocol import LLMResponse, StreamChunk, TokenUsage
|
||||||
|
|
||||||
|
|
||||||
|
def _make_stream_chunks(
|
||||||
|
content_parts: list[str],
|
||||||
|
model: str = "test-model",
|
||||||
|
usage: TokenUsage | None = None,
|
||||||
|
) -> list[StreamChunk]:
|
||||||
|
"""Build a list of StreamChunk objects simulating a streaming response."""
|
||||||
|
chunks = []
|
||||||
|
for part in content_parts:
|
||||||
|
chunks.append(StreamChunk(content=part, model=model))
|
||||||
|
# Final chunk with usage
|
||||||
|
chunks.append(StreamChunk(
|
||||||
|
content="",
|
||||||
|
model=model,
|
||||||
|
usage=usage or TokenUsage(prompt_tokens=10, completion_tokens=20),
|
||||||
|
is_final=True,
|
||||||
|
))
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
def _make_stream_gateway(chunks_list: list[list[StreamChunk]]) -> MagicMock:
|
||||||
|
"""Create a mock LLMGateway whose chat_stream yields the given chunks."""
|
||||||
|
gateway = MagicMock(spec=LLMGateway)
|
||||||
|
|
||||||
|
async def _stream(**kwargs):
|
||||||
|
for chunks in chunks_list:
|
||||||
|
for chunk in chunks:
|
||||||
|
yield chunk
|
||||||
|
# Remove after use
|
||||||
|
chunks_list.pop(0)
|
||||||
|
|
||||||
|
gateway.chat_stream = _stream
|
||||||
|
return gateway
|
||||||
|
|
||||||
|
|
||||||
|
class TestReActTokenStreaming:
|
||||||
|
"""Test that execute_stream() yields token events from chat_stream()."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_yields_token_events(self):
|
||||||
|
"""execute_stream should yield token events for each stream chunk."""
|
||||||
|
chunks = _make_stream_chunks(["Hello", " world", "!"])
|
||||||
|
gateway = _make_stream_gateway([chunks])
|
||||||
|
engine = ReActEngine(llm_gateway=gateway)
|
||||||
|
|
||||||
|
events = []
|
||||||
|
async for event in engine.execute_stream(
|
||||||
|
messages=[{"role": "user", "content": "Hi"}],
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
token_events = [e for e in events if e.event_type == "token"]
|
||||||
|
assert len(token_events) == 3
|
||||||
|
assert token_events[0].data["content"] == "Hello"
|
||||||
|
assert token_events[1].data["content"] == " world"
|
||||||
|
assert token_events[2].data["content"] == "!"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_final_answer_after_streaming(self):
|
||||||
|
"""After streaming tokens, a final_answer event should be yielded."""
|
||||||
|
chunks = _make_stream_chunks(["The answer is 42"])
|
||||||
|
gateway = _make_stream_gateway([chunks])
|
||||||
|
engine = ReActEngine(llm_gateway=gateway)
|
||||||
|
|
||||||
|
events = []
|
||||||
|
async for event in engine.execute_stream(
|
||||||
|
messages=[{"role": "user", "content": "What?"}],
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
final_events = [e for e in events if e.event_type == "final_answer"]
|
||||||
|
assert len(final_events) == 1
|
||||||
|
assert "42" in final_events[0].data.get("output", "")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_token_events_have_correct_step(self):
|
||||||
|
"""Token events should carry the current step number."""
|
||||||
|
chunks = _make_stream_chunks(["Hi"])
|
||||||
|
gateway = _make_stream_gateway([chunks])
|
||||||
|
engine = ReActEngine(llm_gateway=gateway)
|
||||||
|
|
||||||
|
events = []
|
||||||
|
async for event in engine.execute_stream(
|
||||||
|
messages=[{"role": "user", "content": "Hi"}],
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
token_events = [e for e in events if e.event_type == "token"]
|
||||||
|
for te in token_events:
|
||||||
|
assert te.step >= 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_build_response_from_stream(self):
|
||||||
|
"""_build_response_from_stream should construct a valid LLMResponse."""
|
||||||
|
usage = TokenUsage(prompt_tokens=5, completion_tokens=10)
|
||||||
|
response = ReActEngine._build_response_from_stream(
|
||||||
|
content="Hello world",
|
||||||
|
tool_calls=[],
|
||||||
|
usage=usage,
|
||||||
|
model="test-model",
|
||||||
|
)
|
||||||
|
assert response.content == "Hello world"
|
||||||
|
assert response.model == "test-model"
|
||||||
|
assert response.usage.prompt_tokens == 5
|
||||||
|
assert response.usage.completion_tokens == 10
|
||||||
|
assert not response.has_tool_calls
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_build_response_from_stream_no_usage(self):
|
||||||
|
"""_build_response_from_stream should handle None usage gracefully."""
|
||||||
|
response = ReActEngine._build_response_from_stream(
|
||||||
|
content="test",
|
||||||
|
tool_calls=[],
|
||||||
|
usage=None,
|
||||||
|
model="test-model",
|
||||||
|
)
|
||||||
|
assert response.content == "test"
|
||||||
|
assert response.usage.prompt_tokens == 0
|
||||||
Loading…
Reference in New Issue