diff --git a/src/agentkit/core/react.py b/src/agentkit/core/react.py index 0b17393..f94e90b 100644 --- a/src/agentkit/core/react.py +++ b/src/agentkit/core/react.py @@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Any from agentkit.core.exceptions import TaskCancelledError, TaskTimeoutError from agentkit.core.protocol import CancellationToken from agentkit.llm.gateway import LLMGateway +from agentkit.llm.protocol import LLMResponse from agentkit.tools.base import Tool from agentkit.telemetry.tracing import get_tracer, start_span, _OTEL_AVAILABLE from agentkit.telemetry.metrics import ( @@ -59,7 +60,7 @@ class ReActResult: class ReActEvent: """ReAct 执行事件""" - event_type: str # "thinking", "tool_call", "tool_result", "final_answer", "error" + event_type: str # "thinking", "token", "tool_call", "tool_result", "final_answer", "error" step: int data: dict[str, Any] = field(default_factory=dict) timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) @@ -533,14 +534,42 @@ class ReActEngine: data={"message": f"Step {step}: Calling LLM..."}, ) - # Think: call LLM + # Think: call LLM (with optional token streaming) llm_start = time.monotonic() - response = await self._llm_gateway.chat( + + # Use streaming for token-by-token output + stream_content = "" + stream_usage = None + stream_tool_calls: list[Any] = [] + stream_model = model + + async for chunk in self._llm_gateway.chat_stream( messages=conversation, model=model, agent_name=agent_name, task_type=task_type, tools=tool_schemas, + ): + if chunk.content: + stream_content += chunk.content + yield ReActEvent( + event_type="token", + step=step, + data={"content": chunk.content}, + ) + if chunk.usage: + stream_usage = chunk.usage + if chunk.tool_calls: + stream_tool_calls = chunk.tool_calls + if chunk.model: + stream_model = chunk.model + + # Build response-like object from stream + response = self._build_response_from_stream( + content=stream_content, + tool_calls=stream_tool_calls, + 
usage=stream_usage, + model=stream_model, ) llm_duration_ms = int((time.monotonic() - llm_start) * 1000) @@ -776,6 +805,24 @@ class ReActEngine: schemas.append(schema) return schemas + @staticmethod + def _build_response_from_stream( + content: str, + tool_calls: list[Any], + usage: Any, + model: str, + ) -> LLMResponse: + """Build an LLMResponse from accumulated stream chunks.""" + from agentkit.llm.protocol import LLMResponse, TokenUsage + if usage is None: + usage = TokenUsage() + return LLMResponse( + content=content, + tool_calls=tool_calls, + usage=usage, + model=model, + ) + def _find_tool(self, name: str, tools: list[Tool]) -> Tool | None: """根据名称从可用工具中查找工具""" for tool in tools: diff --git a/src/agentkit/tools/__init__.py b/src/agentkit/tools/__init__.py index 3aef0be..6feb772 100644 --- a/src/agentkit/tools/__init__.py +++ b/src/agentkit/tools/__init__.py @@ -9,6 +9,7 @@ from agentkit.tools.composition import SequentialChain, ParallelFanOut, DynamicS from agentkit.tools.web_crawl import WebCrawlTool from agentkit.tools.schema_tools import SchemaExtractTool, SchemaGenerateTool from agentkit.tools.baidu_search import BaiduSearchTool +from agentkit.tools.ask_human import AskHumanTool # Conditional import: HeadroomRetrieveTool requires HeadroomCompressor try: @@ -29,5 +30,6 @@ __all__ = [ "SchemaExtractTool", "SchemaGenerateTool", "BaiduSearchTool", + "AskHumanTool", "HeadroomRetrieveTool", ] diff --git a/src/agentkit/tools/ask_human.py b/src/agentkit/tools/ask_human.py new file mode 100644 index 0000000..0fc9a9b --- /dev/null +++ b/src/agentkit/tools/ask_human.py @@ -0,0 +1,119 @@ +"""AskHumanTool — Human-in-the-Loop tool for Chat mode. + +When registered in a Chat-mode Agent, this tool allows the ReAct loop +to pause and ask the user a question, then wait for a reply via the +WebSocket connection. 
class AskHumanTool(Tool):
    """Human-in-the-loop tool: ask the user a question and wait for a reply.

    The tool only talks to a user after ``configure()`` has supplied both a
    shared ``pending_replies`` dict and an ``ask_callback``. Until then (for
    example outside Chat mode) ``execute()`` logs a warning and returns a
    default answer instead of blocking.

    Usage in the ReAct loop:
        The agent calls this tool when it needs clarification or a decision.
        The question is handed to ``ask_callback``; the tool then waits until
        the future registered under the request id is resolved, or until the
        timeout expires.
    """

    def __init__(self, timeout: float = 60.0):
        """Create the tool in its unconfigured state.

        Args:
            timeout: Seconds to wait for a reply before falling back to a
                default answer.
        """
        super().__init__(
            name="ask_human",
            description="Ask the human user a question and wait for their reply. "
            "Use this when you need clarification, a decision, or "
            "confirmation from the user before proceeding.",
        )
        self._timeout = timeout
        # Shared dict injected via configure(): request_id -> asyncio.Future
        self._pending_replies: dict[str, asyncio.Future] | None = None
        # Async callable used to push the question to the client
        self._ask_callback: Any = None

    def configure(
        self,
        pending_replies: dict[str, asyncio.Future] | None = None,
        ask_callback: Any = None,
    ) -> None:
        """Attach (or detach) the communication channels used by execute().

        Args:
            pending_replies: Dict mapping request_id to the Future that is
                resolved when the user replies. Entries are added and removed
                by ``execute()``.
            ask_callback: Async callable ``(request_id, question, options)``
                that pushes the question to the client.
        """
        self._pending_replies = pending_replies
        self._ask_callback = ask_callback

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON-schema description of the arguments accepted by execute()."""
        return {
            "type": "object",
            "properties": {
                "question": {
                    "type": "string",
                    "description": "The question to ask the user",
                },
                "options": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Optional list of choices for the user",
                },
            },
            "required": ["question"],
        }

    async def execute(self, **kwargs: Any) -> dict[str, Any]:
        """Ask the user a question and wait for their reply.

        Args:
            question: The question to ask (defaults to an empty string).
            options: Optional list of choices.

        Returns:
            Dict with a "reply" key. It holds the user's answer as a string,
            or a fallback (the first option if any, otherwise a fixed text)
            when the tool is unconfigured or the wait times out.

        Raises:
            Whatever ``ask_callback`` raises; the pending entry is removed
            before the exception propagates.
        """
        question = kwargs.get("question", "")
        options = kwargs.get("options")

        # Bind both channels locally so a concurrent configure() call cannot
        # swap them out (or set them to None) while we are waiting.
        pending = self._pending_replies
        ask_callback = self._ask_callback

        if pending is None or ask_callback is None:
            # Not in Chat mode — return a default response
            logger.warning("AskHumanTool called outside Chat mode, returning default response")
            default = options[0] if options else "confirmed"
            return {"reply": default}

        request_id = str(uuid.uuid4())[:8]
        # The id is short and the dict is shared: never overwrite a future
        # that belongs to another in-flight question.
        while request_id in pending:
            request_id = str(uuid.uuid4())[:8]

        # Register the future BEFORE calling the callback so the callback
        # (or any concurrent task) can resolve it immediately.
        future = asyncio.get_running_loop().create_future()
        pending[request_id] = future

        try:
            # Inside the try block so a failing callback cannot leak the entry.
            await ask_callback(request_id, question, options)
            reply = await asyncio.wait_for(future, timeout=self._timeout)
            return {"reply": str(reply)}
        except asyncio.TimeoutError:
            logger.warning("AskHumanTool timeout for request %s", request_id)
            default = options[0] if options else "timeout — no response received"
            return {"reply": default}
        finally:
            pending.pop(request_id, None)
class TestAskHumanToolChatMode:
    """AskHumanTool behaviour once configure() has supplied both channels."""

    @pytest.mark.asyncio
    async def test_ask_and_receive_reply(self):
        """The question reaches the callback and the reply is returned."""
        tool = AskHumanTool(timeout=5.0)
        pending: dict[str, asyncio.Future] = {}
        ask_calls: list[tuple[str, str, list[str] | None]] = []
        asked = asyncio.Event()

        async def mock_ask_callback(request_id, question, options):
            ask_calls.append((request_id, question, options))
            asked.set()

        tool.configure(pending_replies=pending, ask_callback=mock_ask_callback)

        # Run execute() concurrently so the test can play the user's role.
        task = asyncio.create_task(
            tool.execute(question="Continue?", options=["yes", "no"])
        )

        # Wait on an explicit signal instead of sleeping a fixed interval,
        # so the test does not depend on scheduler timing.
        await asyncio.wait_for(asked.wait(), timeout=1.0)
        assert len(ask_calls) == 1
        request_id = ask_calls[0][0]
        assert ask_calls[0][1] == "Continue?"
        assert ask_calls[0][2] == ["yes", "no"]

        # Simulate user reply
        assert request_id in pending
        pending[request_id].set_result("yes")

        result = await task
        assert result == {"reply": "yes"}

    @pytest.mark.asyncio
    async def test_timeout_returns_default(self):
        """With options, a timeout falls back to the first option."""
        tool = AskHumanTool(timeout=0.01)
        pending: dict[str, asyncio.Future] = {}

        async def mock_ask_callback(request_id, question, options):
            pass  # Never reply

        tool.configure(pending_replies=pending, ask_callback=mock_ask_callback)

        result = await tool.execute(question="Continue?", options=["yes", "no"])
        assert result == {"reply": "yes"}
        # The timed-out request must not linger in the shared dict.
        assert not pending

    @pytest.mark.asyncio
    async def test_timeout_no_options(self):
        """Without options, a timeout yields the fixed timeout message."""
        tool = AskHumanTool(timeout=0.01)
        pending: dict[str, asyncio.Future] = {}

        async def mock_ask_callback(request_id, question, options):
            pass

        tool.configure(pending_replies=pending, ask_callback=mock_ask_callback)

        result = await tool.execute(question="Continue?")
        assert "timeout" in result["reply"]
        assert not pending

    @pytest.mark.asyncio
    async def test_cleanup_on_reply(self):
        """A reply given from inside the callback is accepted and cleaned up."""
        tool = AskHumanTool(timeout=5.0)
        pending: dict[str, asyncio.Future] = {}

        async def mock_ask_callback(request_id, question, options):
            # Reply immediately: the future is already registered when the
            # callback runs, so no delay is needed.
            pending[request_id].set_result("ok")

        tool.configure(pending_replies=pending, ask_callback=mock_ask_callback)

        result = await tool.execute(question="Test?")
        assert result == {"reply": "ok"}
        # Pending should be cleaned up
        assert len(pending) == 0
def _make_stream_chunks(
    content_parts: list[str],
    model: str = "test-model",
    usage: TokenUsage | None = None,
) -> list[StreamChunk]:
    """Build StreamChunk objects simulating one streaming response.

    One chunk is produced per entry of ``content_parts``, followed by a final
    empty-content chunk that carries ``usage`` (a 10/20 prompt/completion
    TokenUsage when none is given) and ``is_final=True``.
    """
    chunks = [StreamChunk(content=part, model=model) for part in content_parts]
    # Final chunk with usage
    chunks.append(StreamChunk(
        content="",
        model=model,
        usage=usage or TokenUsage(prompt_tokens=10, completion_tokens=20),
        is_final=True,
    ))
    return chunks


def _make_stream_gateway(chunks_list: list[list[StreamChunk]]) -> MagicMock:
    """Create a mock LLMGateway whose chat_stream replays scripted responses.

    Each call to ``chat_stream`` consumes the next entry of ``chunks_list``
    and yields its chunks; once the script is exhausted the stream is empty.
    """
    gateway = MagicMock(spec=LLMGateway)

    async def _stream(**kwargs):
        # Take exactly one scripted response per call. Popping up front
        # avoids mutating the list while it is being iterated.
        if not chunks_list:
            return
        for chunk in chunks_list.pop(0):
            yield chunk

    gateway.chat_stream = _stream
    return gateway


async def _collect_events(engine: ReActEngine, user_text: str) -> list:
    """Run execute_stream() for a single user message and gather all events."""
    return [
        event
        async for event in engine.execute_stream(
            messages=[{"role": "user", "content": user_text}],
        )
    ]


class TestReActTokenStreaming:
    """Test that execute_stream() yields token events from chat_stream()."""

    @pytest.mark.asyncio
    async def test_yields_token_events(self):
        """execute_stream should yield a token event per non-empty chunk."""
        chunks = _make_stream_chunks(["Hello", " world", "!"])
        engine = ReActEngine(llm_gateway=_make_stream_gateway([chunks]))

        events = await _collect_events(engine, "Hi")

        token_events = [e for e in events if e.event_type == "token"]
        assert len(token_events) == 3
        assert token_events[0].data["content"] == "Hello"
        assert token_events[1].data["content"] == " world"
        assert token_events[2].data["content"] == "!"

    @pytest.mark.asyncio
    async def test_final_answer_after_streaming(self):
        """After streaming tokens, a final_answer event should be yielded."""
        chunks = _make_stream_chunks(["The answer is 42"])
        engine = ReActEngine(llm_gateway=_make_stream_gateway([chunks]))

        events = await _collect_events(engine, "What?")

        final_events = [e for e in events if e.event_type == "final_answer"]
        assert len(final_events) == 1
        assert "42" in final_events[0].data.get("output", "")

    @pytest.mark.asyncio
    async def test_token_events_have_correct_step(self):
        """Token events should carry the current step number."""
        chunks = _make_stream_chunks(["Hi"])
        engine = ReActEngine(llm_gateway=_make_stream_gateway([chunks]))

        events = await _collect_events(engine, "Hi")

        token_events = [e for e in events if e.event_type == "token"]
        # Guard against a vacuous pass when no token events were produced.
        assert token_events
        for te in token_events:
            assert te.step >= 1

    def test_build_response_from_stream(self):
        """_build_response_from_stream should construct a valid LLMResponse."""
        usage = TokenUsage(prompt_tokens=5, completion_tokens=10)
        response = ReActEngine._build_response_from_stream(
            content="Hello world",
            tool_calls=[],
            usage=usage,
            model="test-model",
        )
        assert response.content == "Hello world"
        assert response.model == "test-model"
        assert response.usage.prompt_tokens == 5
        assert response.usage.completion_tokens == 10
        assert not response.has_tool_calls

    def test_build_response_from_stream_no_usage(self):
        """_build_response_from_stream should handle None usage gracefully."""
        response = ReActEngine._build_response_from_stream(
            content="test",
            tool_calls=[],
            usage=None,
            model="test-model",
        )
        assert response.content == "test"
        assert response.usage.prompt_tokens == 0