diff --git a/configs/geo_tools.py b/configs/geo_tools.py index 5e34ceb..27dd0d7 100644 --- a/configs/geo_tools.py +++ b/configs/geo_tools.py @@ -462,4 +462,4 @@ def register_geo_tools(registry: ToolRegistry) -> None: tags=["knowledge", "deai"], )) - logger.info(f"GEO tools registered: {len(registry.list_all_tools())} tools") + logger.info(f"GEO tools registered: {len(registry.list_tools())} tools") diff --git a/pyproject.toml b/pyproject.toml index 2f0b212..96da667 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ server = [ "fastapi>=0.110", "uvicorn>=0.27", + "sse-starlette>=2.0", ] mcp = [ "mcp>=1.0", diff --git a/src/agentkit/core/react.py b/src/agentkit/core/react.py index 68534ae..3439f91 100644 --- a/src/agentkit/core/react.py +++ b/src/agentkit/core/react.py @@ -8,6 +8,7 @@ import json import logging import re from dataclasses import dataclass, field +from datetime import datetime, timezone from typing import Any from agentkit.llm.gateway import LLMGateway @@ -39,6 +40,16 @@ class ReActResult: total_tokens: int +@dataclass +class ReActEvent: + """ReAct 执行事件""" + + event_type: str # "thinking", "tool_call", "tool_result", "final_answer", "error" + step: int + data: dict[str, Any] = field(default_factory=dict) + timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + + class ReActEngine: """ReAct 推理-行动循环引擎 @@ -186,6 +197,172 @@ class ReActEngine: total_tokens=total_tokens, ) + async def execute_stream( + self, + messages: list[dict[str, str]], + tools: list[Tool] | None = None, + model: str = "default", + agent_name: str = "", + task_type: str = "", + system_prompt: str | None = None, + ): + """Execute ReAct loop, yielding ReActEvent objects. + + Same logic as execute() but yields events at each step instead of + accumulating a result. + """ + tools = tools or [] + tool_schemas = self._build_tool_schemas(tools) if tools else None + + conversation: list[dict[str, Any]] = [] + if system_prompt: + conversation.append({"role": "system", "content": system_prompt}) + conversation.extend(messages) + + trajectory: list[ReActStep] = [] + total_tokens = 0 + step = 0 + output = "" + + while step < self._max_steps: + step += 1 + + # Yield thinking event + yield ReActEvent( + event_type="thinking", + step=step, + data={"message": f"Step {step}: Calling LLM..."}, + ) + + # Think: call LLM + response = await self._llm_gateway.chat( + messages=conversation, + model=model, + agent_name=agent_name, + task_type=task_type, + tools=tool_schemas, + ) + + step_tokens = response.usage.total_tokens + total_tokens += step_tokens + + if response.has_tool_calls: + # Record assistant message + assistant_msg: dict[str, Any] = { + "role": "assistant", + "content": response.content or "", + "tool_calls": [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.name, + "arguments": json.dumps(tc.arguments), + }, + } + for tc in response.tool_calls + ], + } + conversation.append(assistant_msg) + + for tc in response.tool_calls: + # Yield tool_call event + yield ReActEvent( + event_type="tool_call", + step=step, + data={"tool_name": tc.name, "arguments": tc.arguments}, + ) + + tool_result = await self._execute_tool(tc.name, tc.arguments, tools) + react_step = ReActStep( + step=step, + action="tool_call", + tool_name=tc.name, + arguments=tc.arguments, + result=tool_result, + tokens=step_tokens, + ) + trajectory.append(react_step) + + # Yield tool_result event + yield ReActEvent( + event_type="tool_result", + step=step, + data={"tool_name": tc.name, "result": tool_result}, + ) + + tool_msg = self._build_tool_result_message(tc.id, tool_result) + conversation.append(tool_msg) + + else: + # Check text parsing mode + parsed_calls = self._parse_text_tool_calls(response.content or "") + if parsed_calls and tools: + conversation.append({"role": "assistant", "content": response.content}) + + for pc in parsed_calls: + yield ReActEvent( + event_type="tool_call", + step=step, + data={"tool_name": pc["name"], "arguments": pc["arguments"]}, + ) + tool_result = await self._execute_tool(pc["name"], pc["arguments"], tools) + trajectory.append(ReActStep( + step=step, + action="tool_call", + tool_name=pc["name"], + arguments=pc["arguments"], + result=tool_result, + tokens=step_tokens, + )) + yield ReActEvent( + event_type="tool_result", + step=step, + data={"tool_name": pc["name"], "result": tool_result}, + ) + tool_msg = self._build_tool_result_message( + pc.get("id", f"text_tc_{step}"), tool_result + ) + conversation.append(tool_msg) + else: + # Final answer + react_step = ReActStep( + step=step, + action="final_answer", + content=response.content, + tokens=step_tokens, + ) + trajectory.append(react_step) + output = response.content or "" + yield ReActEvent( + event_type="final_answer", + step=step, + data={ + "output": output, + "total_steps": len(trajectory), + "total_tokens": total_tokens, + }, + ) + break + + if step >= self._max_steps and not output: + if trajectory and trajectory[-1].content: + output = trajectory[-1].content + elif trajectory and trajectory[-1].result is not None: + output = str(trajectory[-1].result) + else: + output = response.content or "" + yield ReActEvent( + event_type="final_answer", + step=step, + data={ + "output": output, + "total_steps": len(trajectory), + "total_tokens": total_tokens, + "max_steps_reached": True, + }, + ) + def _build_tool_schemas(self, tools: list[Tool]) -> list[dict]: """将 Tool 对象转换为 OpenAI Function Calling schema 格式""" schemas = [] diff --git a/src/agentkit/llm/gateway.py b/src/agentkit/llm/gateway.py index f79996b..33885d4 100644 --- a/src/agentkit/llm/gateway.py +++ b/src/agentkit/llm/gateway.py @@ -5,7 +5,7 @@ import time from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError from agentkit.llm.config import LLMConfig -from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage +from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, StreamChunk, TokenUsage from agentkit.llm.providers.tracker import UsageSummary, UsageTracker logger = logging.getLogger(__name__) @@ -97,6 +97,62 @@ class LLMGateway: return response + async def chat_stream( + self, + messages: list[dict[str, str]], + model: str, + agent_name: str = "", + task_type: str = "", + tools: list[dict] | None = None, + tool_choice: str = "auto", + **kwargs, + ): + """Stream chat response, yielding StreamChunk objects""" + resolved_model = self._resolve_model_alias(model) + + if not self._providers: + raise LLMProviderError("", "No provider registered") + + try: + provider, actual_model = self._resolve_model(resolved_model) + except ModelNotFoundError as e: + raise LLMProviderError("", str(e)) from e + + request = LLMRequest( + messages=messages, + model=actual_model, + tools=tools, + tool_choice=tool_choice, + **kwargs, + ) + + start = time.monotonic() + total_content = "" + final_usage = None + final_model = resolved_model + + async for chunk in provider.chat_stream(request): + if chunk.content: + total_content += chunk.content + if chunk.usage: + final_usage = chunk.usage + if chunk.model: + final_model = chunk.model + yield chunk + + # Track usage after stream completes + latency_ms = (time.monotonic() - start) * 1000 + if final_usage is None: + final_usage = TokenUsage() + cost = self._calculate_cost(final_model, final_usage) + self._usage_tracker.record( + agent_name=agent_name, + model=final_model, + usage=final_usage, + cost=cost, + latency_ms=latency_ms, + ) + def _resolve_model_alias(self, model: str) -> str: """解析模型别名""" if model in self._config.model_aliases: diff --git a/src/agentkit/llm/protocol.py b/src/agentkit/llm/protocol.py index f9f0f15..15e52c8 100644 --- a/src/agentkit/llm/protocol.py +++ b/src/agentkit/llm/protocol.py @@ -56,6 +56,17 @@ class LLMRequest: self._extra = kwargs +@dataclass +class StreamChunk: + """LLM 流式响应块""" + + content: str # Delta content + model: str + tool_calls: list[ToolCall] = field(default_factory=list) # Accumulated tool calls (only in final chunk) + usage: TokenUsage | None = None # Only in final chunk + is_final: bool = False # True for the last chunk + + @dataclass class LLMResponse: """LLM 响应""" @@ -78,3 +89,18 @@ class LLMProvider(ABC): async def chat(self, request: LLMRequest) -> LLMResponse: """发送 chat 请求并返回响应""" ... + + async def chat_stream(self, request: LLMRequest): + """Stream chat response. Override in subclasses that support streaming. + + Yields StreamChunk objects. Default implementation falls back to + non-streaming chat and yields a single chunk. + """ + response = await self.chat(request) + yield StreamChunk( + content=response.content, + model=response.model, + tool_calls=response.tool_calls, + usage=response.usage, + is_final=True, + ) diff --git a/src/agentkit/llm/providers/openai.py b/src/agentkit/llm/providers/openai.py index 1bc4f09..f71cb51 100644 --- a/src/agentkit/llm/providers/openai.py +++ b/src/agentkit/llm/providers/openai.py @@ -7,7 +7,7 @@ import time import httpx from agentkit.core.exceptions import LLMProviderError -from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage, ToolCall +from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, StreamChunk, TokenUsage, ToolCall logger = logging.getLogger(__name__) @@ -100,3 +100,108 @@ class OpenAICompatibleProvider(LLMProvider): tool_calls=tool_calls, latency_ms=latency_ms, ) + + async def chat_stream(self, request: LLMRequest): + """Stream chat response using SSE""" + url = f"{self._base_url}/chat/completions" + headers = { + "Authorization": f"Bearer {self._api_key}", + "Content-Type": "application/json", + } + payload: dict = { + "model": request.model, + "messages": request.messages, + "temperature": request.temperature, + "max_tokens": request.max_tokens, + "stream": True, + "stream_options": {"include_usage": True}, + } + if request.tools: + payload["tools"] = request.tools + payload["tool_choice"] = request.tool_choice + + async with self._client.stream("POST", url, json=payload, headers=headers) as response: + if response.status_code != 200: + error_text = await response.aread() + raise LLMProviderError("openai", f"HTTP {response.status_code}") + + accumulated_tool_calls: dict[int, dict] = {} # index -> {id, name, arguments_str} + + async for line in response.aiter_lines(): + line = line.strip() + if not line or not line.startswith("data: "): + continue + data_str = line[6:] # Remove "data: " prefix + if data_str == "[DONE]": + break + + try: + data = json.loads(data_str) + except json.JSONDecodeError: + continue + + choices = data.get("choices", []) + if not choices: + # Usage-only chunk + usage_data = data.get("usage") + if usage_data: + yield StreamChunk( + content="", + model=data.get("model", request.model), + usage=TokenUsage( + prompt_tokens=usage_data.get("prompt_tokens", 0), + completion_tokens=usage_data.get("completion_tokens", 0), + ), + is_final=True, + ) + continue + + delta = choices[0].get("delta", {}) + content = delta.get("content", "") + + # Accumulate tool calls from streaming + raw_tool_calls = delta.get("tool_calls") + if raw_tool_calls: + for tc in raw_tool_calls: + idx = tc.get("index", 0) + if idx not in accumulated_tool_calls: + accumulated_tool_calls[idx] = { + "id": tc.get("id", ""), + "name": "", + "arguments_str": "", + } + if tc.get("id"): + accumulated_tool_calls[idx]["id"] = tc["id"] + func = tc.get("function", {}) + if func.get("name"): + accumulated_tool_calls[idx]["name"] = func["name"] + if func.get("arguments"): + accumulated_tool_calls[idx]["arguments_str"] += func["arguments"] + + # Only yield content chunks (not empty deltas) + if content: + yield StreamChunk( + content=content, + model=data.get("model", request.model), + ) + + # If we accumulated tool calls, yield them as a final chunk + if accumulated_tool_calls: + tool_calls = [] + for idx in sorted(accumulated_tool_calls.keys()): + tc_data = accumulated_tool_calls[idx] + try: + arguments = json.loads(tc_data["arguments_str"]) if tc_data["arguments_str"] else {} + except json.JSONDecodeError: + arguments = {"raw": tc_data["arguments_str"]} + tool_calls.append(ToolCall( + id=tc_data["id"], + name=tc_data["name"], + arguments=arguments, + )) + yield StreamChunk( + content="", + model=request.model, + tool_calls=tool_calls, + is_final=True, + ) diff --git a/src/agentkit/server/client.py b/src/agentkit/server/client.py index f850a35..8c813a6 100644 --- a/src/agentkit/server/client.py +++ b/src/agentkit/server/client.py @@ -126,6 +126,38 @@ class AgentKitClient: response.raise_for_status() return response.json() + async def stream_task( + self, + input_data: dict, + skill_name: str | None = None, + agent_name: str | None = None, + ): + """Stream task execution events via SSE. + + Yields event dicts with 'event' and 'data' keys. + """ + payload: dict[str, Any] = {"input_data": input_data} + if skill_name: + payload["skill_name"] = skill_name + if agent_name: + payload["agent_name"] = agent_name + + async with self._client.stream( + "POST", "/api/v1/tasks/stream", json=payload + ) as response: + response.raise_for_status() + event_type = "" + async for line in response.aiter_lines(): + line = line.strip() + if not line: + continue + if line.startswith("event: "): + event_type = line[7:] + elif line.startswith("data: "): + import json as _json + data = _json.loads(line[6:]) + yield {"event": event_type, "data": data} + async def close(self) -> None: """Close the HTTP client""" await self._client.aclose() diff --git a/src/agentkit/server/routes/tasks.py b/src/agentkit/server/routes/tasks.py index 52d70e9..6557118 100644 --- a/src/agentkit/server/routes/tasks.py +++ b/src/agentkit/server/routes/tasks.py @@ -1,5 +1,6 @@ """Task submission routes""" +import json import uuid from datetime import datetime, timezone @@ -191,3 +192,79 @@ async def cancel_task(task_id: str, req: Request): if not cancelled: raise HTTPException(status_code=400, detail="Task cannot be cancelled (not running or not found)") return {"task_id": task_id, "status": "cancelled"} + + +@router.post("/tasks/stream") +async def stream_task(request: SubmitTaskRequest, req: Request): + """Submit a task and stream ReAct events via SSE""" + from sse_starlette.sse import EventSourceResponse + + pool = req.app.state.agent_pool + skill_registry = req.app.state.skill_registry + intent_router = req.app.state.intent_router + + agent = None + + # Same agent resolution logic as submit_task + if request.agent_name: + agent = pool.get_agent(request.agent_name) + if agent is None: + raise HTTPException( + status_code=404, + detail=f"Agent '{request.agent_name}' not found", + ) + elif request.skill_name: + try: + skill_registry.get(request.skill_name) + except Exception: + raise HTTPException( + status_code=404, + detail=f"Skill '{request.skill_name}' not found", + ) + agent = pool.get_agent(request.skill_name) + if agent is None: + agent = await pool.create_agent_from_skill(request.skill_name) + else: + all_skills = skill_registry.list_skills() + if not all_skills: + raise HTTPException( + status_code=400, + detail="No skills registered and no skill_name or agent_name specified", + ) + try: + routing_result = await intent_router.route(request.input_data, all_skills) + skill_registry.get(routing_result.matched_skill) + agent = pool.get_agent(routing_result.matched_skill) + if agent is None: + agent = await pool.create_agent_from_skill(routing_result.matched_skill) + except (ValueError, RuntimeError) as e: + raise HTTPException(status_code=400, detail=str(e)) + + async def event_generator(): + from agentkit.core.react import ReActEngine + + react_engine = ReActEngine(llm_gateway=req.app.state.llm_gateway) + + # Build messages from input + messages = [{"role": "user", "content": str(request.input_data)}] + + # Get tools from agent + tools = list(agent._tool_registry._tools.values()) if agent._tool_registry else [] + + async for event in react_engine.execute_stream( + messages=messages, + tools=tools, + model=agent._llm_model if hasattr(agent, "_llm_model") else "default", + agent_name=agent.name, + system_prompt=agent._system_prompt if hasattr(agent, "_system_prompt") else None, + ): + yield { + "event": event.event_type, + "data": json.dumps({ + "step": event.step, + "data": event.data, + "timestamp": event.timestamp, + }), + } + + return EventSourceResponse(event_generator()) diff --git a/tests/unit/test_streaming.py b/tests/unit/test_streaming.py new file mode 100644 index 0000000..7b09224 --- /dev/null +++ b/tests/unit/test_streaming.py @@ -0,0 +1,431 @@ +"""Streaming System 单元测试 - U8/U9/U10 + +覆盖: +- StreamChunk 数据类 +- LLMProvider.chat_stream 默认回退 +- Gateway.chat_stream 流式 + 用量追踪 +- ReActEvent 数据类 +- ReActEngine.execute_stream 事件流 +- SSE 端点 /tasks/stream +""" + +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentkit.llm.protocol import LLMRequest, LLMResponse, TokenUsage, ToolCall +from agentkit.tools.base import Tool + + +# ── Test Helpers ────────────────────────────────────────── + + +class FakeTool(Tool): + """用于测试的 Fake Tool""" + + def __init__( + self, + name: str = "fake_tool", + description: str = "A fake tool for testing", + result: dict | None = None, + ): + super().__init__(name=name, description=description) + self._result = result or {"status": "ok"} + + async def execute(self, **kwargs) -> dict: + return self._result + + +def make_response( + content: str = "", + tool_calls: list[ToolCall] | None = None, + prompt_tokens: int = 10, + completion_tokens: int = 20, +) -> LLMResponse: + """快速构造 LLMResponse""" + return LLMResponse( + content=content, + model="test-model", + usage=TokenUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ), + tool_calls=tool_calls or [], + ) + + +# ══════════════════════════════════════════════════════════ +# U8: StreamChunk + chat_stream +# ══════════════════════════════════════════════════════════ + + +class TestStreamChunk: + """StreamChunk 数据类测试""" + + def test_creation_with_content(self): + from agentkit.llm.protocol import StreamChunk + + chunk = StreamChunk(content="Hello", model="gpt-4o") + assert chunk.content == "Hello" + assert chunk.model == "gpt-4o" + assert chunk.tool_calls == [] + assert chunk.usage is None + assert chunk.is_final is False + + def test_with_tool_calls_and_is_final(self): + from agentkit.llm.protocol import StreamChunk + + tc = ToolCall(id="tc_1", name="search", arguments={"q": "test"}) + chunk = StreamChunk( + content="", + model="gpt-4o", + tool_calls=[tc], + is_final=True, + ) + assert len(chunk.tool_calls) == 1 + assert chunk.tool_calls[0].name == "search" + assert chunk.is_final is True + + def test_with_usage(self): + from agentkit.llm.protocol import StreamChunk + + usage = TokenUsage(prompt_tokens=100, completion_tokens=50) + chunk = StreamChunk( + content="", + model="gpt-4o", + usage=usage, + is_final=True, + ) + assert chunk.usage is not None + assert chunk.usage.total_tokens == 150 + assert chunk.is_final is True + + +class TestLLMProviderChatStreamDefault: + """LLMProvider.chat_stream 默认实现回退到 chat()""" + + async def test_default_chat_stream_yields_single_chunk(self): + from agentkit.llm.protocol import LLMProvider, StreamChunk + + class SimpleProvider(LLMProvider): + async def chat(self, request: LLMRequest) -> LLMResponse: + return LLMResponse( + content="hello", + model="test", + usage=TokenUsage(prompt_tokens=5, completion_tokens=10), + ) + + provider = SimpleProvider() + request = LLMRequest( + messages=[{"role": "user", "content": "hi"}], + model="test", + ) + + chunks = [] + async for chunk in provider.chat_stream(request): + chunks.append(chunk) + + assert len(chunks) == 1 + assert chunks[0].content == "hello" + assert chunks[0].is_final is True + assert chunks[0].usage.total_tokens == 15 + + +class TestGatewayChatStream: + """Gateway.chat_stream 流式测试""" + + async def test_yields_chunks_from_provider(self): + from agentkit.llm.protocol import StreamChunk + from agentkit.llm.gateway import LLMGateway + from agentkit.llm.protocol import LLMProvider + + class StreamingProvider(LLMProvider): + async def chat(self, request: LLMRequest) -> LLMResponse: + return LLMResponse( + content="fallback", + model="test", + usage=TokenUsage(), + ) + + async def chat_stream(self, request: LLMRequest): + yield StreamChunk(content="Hello ", model="test") + yield StreamChunk(content="World", model="test") + yield StreamChunk( + content="", + model="test", + usage=TokenUsage(prompt_tokens=10, completion_tokens=5), + is_final=True, + ) + + gateway = LLMGateway() + gateway.register_provider("test", StreamingProvider()) + + chunks = [] + async for chunk in gateway.chat_stream( + messages=[{"role": "user", "content": "hi"}], + model="test/model", + ): + chunks.append(chunk) + + assert len(chunks) == 3 + assert chunks[0].content == "Hello " + assert chunks[1].content == "World" + assert chunks[2].is_final is True + + async def test_tracks_usage_after_stream_completes(self): + from agentkit.llm.protocol import StreamChunk + from agentkit.llm.gateway import LLMGateway + from agentkit.llm.protocol import LLMProvider + + class StreamingProvider(LLMProvider): + async def chat(self, request: LLMRequest) -> LLMResponse: + return LLMResponse( + content="fallback", + model="test", + usage=TokenUsage(), + ) + + async def chat_stream(self, request: LLMRequest): + yield StreamChunk(content="Hi", model="test") + yield StreamChunk( + content="", + model="test", + usage=TokenUsage(prompt_tokens=100, completion_tokens=50), + is_final=True, + ) + + gateway = LLMGateway() + gateway.register_provider("test", StreamingProvider()) + + # Consume the stream + chunks = [] + async for chunk in gateway.chat_stream( + messages=[{"role": "user", "content": "hi"}], + model="test/model", + agent_name="stream_agent", + ): + chunks.append(chunk) + + # Verify usage was tracked + usage = gateway.get_usage() + assert usage.total_tokens == 150 + + +# ══════════════════════════════════════════════════════════ +# U9: ReActEvent + execute_stream +# ══════════════════════════════════════════════════════════ + + +class TestReActEvent: + """ReActEvent 数据类测试""" + + def test_creation_with_event_type_and_step(self): + from agentkit.core.react import ReActEvent + + event = ReActEvent(event_type="thinking", step=1) + assert event.event_type == "thinking" + assert event.step == 1 + assert event.data == {} + + def test_has_timestamp(self): + from agentkit.core.react import ReActEvent + + event = ReActEvent(event_type="thinking", step=1) + assert event.timestamp is not None + assert len(event.timestamp) > 0 + + def test_with_data(self): + from agentkit.core.react import ReActEvent + + event = ReActEvent( + event_type="tool_call", + step=2, + data={"tool_name": "search", "arguments": {"q": "test"}}, + ) + assert event.data["tool_name"] == "search" + + +class TestReActEngineExecuteStream: + """ReActEngine.execute_stream 事件流测试""" + + async def test_yields_thinking_event_at_each_step(self): + from agentkit.core.react import ReActEngine, ReActEvent + + gateway = MagicMock() + gateway.chat = AsyncMock(return_value=make_response(content="Final answer")) + + engine = ReActEngine(llm_gateway=gateway) + + events = [] + async for event in engine.execute_stream( + messages=[{"role": "user", "content": "Hello"}], + ): + events.append(event) + + # Should have thinking + final_answer + thinking_events = [e for e in events if e.event_type == "thinking"] + assert len(thinking_events) >= 1 + assert thinking_events[0].step == 1 + + async def test_yields_tool_call_and_tool_result_events(self): + from agentkit.core.react import ReActEngine, ReActEvent + + tool = FakeTool(name="calculator", result={"value": 42}) + + gateway = MagicMock() + gateway.chat = AsyncMock(side_effect=[ + make_response( + content="", + tool_calls=[ToolCall(id="tc_1", name="calculator", arguments={"expr": "6*7"})], + ), + make_response(content="The result is 42"), + ]) + + engine = ReActEngine(llm_gateway=gateway) + + events = [] + async for event in engine.execute_stream( + messages=[{"role": "user", "content": "Calculate"}], + tools=[tool], + ): + events.append(event) + + tool_call_events = [e for e in events if e.event_type == "tool_call"] + tool_result_events = [e for e in events if e.event_type == "tool_result"] + + assert len(tool_call_events) == 1 + assert tool_call_events[0].data["tool_name"] == "calculator" + assert len(tool_result_events) == 1 + assert tool_result_events[0].data["tool_name"] == "calculator" + assert tool_result_events[0].data["result"] == {"value": 42} + + async def test_yields_final_answer_event(self): + from agentkit.core.react import ReActEngine, ReActEvent + + gateway = MagicMock() + gateway.chat = AsyncMock(return_value=make_response(content="The answer is 42")) + + engine = ReActEngine(llm_gateway=gateway) + + events = [] + async for event in engine.execute_stream( + messages=[{"role": "user", "content": "What is the answer?"}], + ): + events.append(event) + + final_events = [e for e in events if e.event_type == "final_answer"] + assert len(final_events) == 1 + assert final_events[0].data["output"] == "The answer is 42" + assert final_events[0].data["total_steps"] >= 1 + assert final_events[0].data["total_tokens"] > 0 + + async def test_yields_max_steps_reached_when_hitting_limit(self): + from agentkit.core.react import ReActEngine, ReActEvent + + tool = FakeTool(name="search", result={"results": ["data"]}) + + always_tool_response = make_response( + content="Thinking...", + tool_calls=[ToolCall(id="tc_loop", name="search", arguments={"query": "more"})], + ) + gateway = MagicMock() + gateway.chat = AsyncMock(return_value=always_tool_response) + + engine = ReActEngine(llm_gateway=gateway, max_steps=3) + + events = [] + async for event in engine.execute_stream( + messages=[{"role": "user", "content": "Keep searching"}], + tools=[tool], + ): + events.append(event) + + final_events = [e for e in events if e.event_type == "final_answer"] + assert len(final_events) == 1 + assert final_events[0].data.get("max_steps_reached") is True + + +# ══════════════════════════════════════════════════════════ +# U10: SSE Endpoint + Client SDK +# ══════════════════════════════════════════════════════════ + + +class TestSSEEndpoint: + """SSE /tasks/stream 端点测试""" + + def test_stream_task_returns_event_source_response(self): + from fastapi.testclient import TestClient + from agentkit.server.app import create_app + from agentkit.llm.gateway import LLMGateway + from agentkit.skills.registry import SkillRegistry + from agentkit.tools.registry import ToolRegistry + + gateway = LLMGateway() + mock_provider = AsyncMock() + mock_provider.chat.return_value = LLMResponse( + content="Final answer", + model="test-model", + usage=TokenUsage(prompt_tokens=10, completion_tokens=20), + ) + gateway.register_provider("test", mock_provider) + + skill_registry = SkillRegistry() + tool_registry = ToolRegistry() + app = create_app( + llm_gateway=gateway, + skill_registry=skill_registry, + tool_registry=tool_registry, + ) + client = TestClient(app) + + # Create an agent first + client.post( + "/api/v1/agents", + json={ + "config": { + "name": "stream_agent", + "agent_type": "test", + "task_mode": "llm_generate", + "prompt": {"identity": "Stream Agent"}, + } + }, + ) + + # Stream task + response = client.post( + "/api/v1/tasks/stream", + json={ + "input_data": {"query": "test"}, + "agent_name": "stream_agent", + }, + ) + # Should return 200 with SSE content type + assert response.status_code == 200 + assert "text/event-stream" in response.headers.get("content-type", "") + + def test_stream_task_with_invalid_agent_returns_404(self): + from fastapi.testclient import TestClient + from agentkit.server.app import create_app + from agentkit.llm.gateway import LLMGateway + from agentkit.skills.registry import SkillRegistry + from agentkit.tools.registry import ToolRegistry + + gateway = LLMGateway() + skill_registry = SkillRegistry() + tool_registry = ToolRegistry() + app = create_app( + llm_gateway=gateway, + skill_registry=skill_registry, + tool_registry=tool_registry, + ) + client = TestClient(app) + + response = client.post( + "/api/v1/tasks/stream", + json={ + "input_data": {"query": "test"}, + "agent_name": "nonexistent_agent", + }, + ) + assert response.status_code == 404