From 5b5291c7e52140abc42fd989158a02452702fd41 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Wed, 17 Jun 2026 22:11:51 +0800 Subject: [PATCH] fix: WebSocket task persistence three-layer defense with security hardening Fix chat history empty content and task stops on refresh. Implements: result persistence on disconnect, task backgrounding via asyncio + EventQueue, frontend reconnection recovery. Security: fail-closed conversation_id ownership, asyncio.shield on CancelledError cleanup, async TaskStore shim, EventQueue subscriber limit, connection error resilience. 23 tests added. --- src/agentkit/core/event_queue.py | 58 +- .../server/frontend/src/api/client.ts | 15 + src/agentkit/server/frontend/src/api/types.ts | 29 + .../server/frontend/src/stores/chat.ts | 134 ++- src/agentkit/server/routes/portal.py | 673 +++++++++++-- .../unit/server/test_portal_ws_background.py | 938 ++++++++++++++++++ 6 files changed, 1741 insertions(+), 106 deletions(-) create mode 100644 tests/unit/server/test_portal_ws_background.py diff --git a/src/agentkit/core/event_queue.py b/src/agentkit/core/event_queue.py index 887e208..38d536f 100644 --- a/src/agentkit/core/event_queue.py +++ b/src/agentkit/core/event_queue.py @@ -34,6 +34,20 @@ _CLOSED_SENTINEL: Event = Event( ) +@dataclass +class _Subscriber: + """Internal subscriber tracking with optional task_id filter.""" + + queue: asyncio.Queue[Event] + task_id_filter: str | None = None # None = receive all events + + def matches(self, event: Event) -> bool: + """Check if this subscriber should receive the given event.""" + if self.task_id_filter is None: + return True + return event.task_id == self.task_id_filter + + @dataclass class Submission: """用户提交的任务 @@ -151,9 +165,12 @@ class EventQueue: _MAX_QUEUE_SIZE: int = 1024 _DEFAULT_BUFFER_SIZE: int = 100 + # P1 #13 fix: cap total subscribers to prevent resource exhaustion + # from malicious resume floods or runaway client loops. + _MAX_SUBSCRIBERS: int = 1000 def __init__(self, buffer_size: int = _DEFAULT_BUFFER_SIZE) -> None: - self._subscribers: list[asyncio.Queue[Event]] = [] + self._subscribers: list[_Subscriber] = [] self._buffer: deque[Event] = deque(maxlen=buffer_size) self._buffer_size = buffer_size self._closed: bool = False @@ -163,35 +180,53 @@ class EventQueue: 事件会同时写入缓冲区(供未来订阅者回放)和所有活跃订阅者队列。 如果某订阅者队列已满,该事件对该订阅者被丢弃(不影响其他订阅者)。 + 支持按 task_id 过滤:只有 task_id 匹配的订阅者才会收到事件。 Args: event: 要推送的事件 """ self._buffer.append(event) - for queue in self._subscribers: + for sub in self._subscribers: + if not sub.matches(event): + continue try: - queue.put_nowait(event) + sub.queue.put_nowait(event) except asyncio.QueueFull: logger.warning("EventQueue subscriber queue full, dropping event") - async def subscribe(self) -> AsyncIterator[Event]: + async def subscribe(self, task_id: str | None = None) -> AsyncIterator[Event]: """订阅事件流(异步生成器) - 订阅时会先回放缓冲区中的事件,然后持续接收新事件。 + 订阅时会先回放缓冲区中的事件(按 task_id 过滤),然后持续接收新事件。 每个订阅者获得独立的队列,实现广播语义。 当队列关闭时,生成器结束。 + Args: + task_id: 可选的任务 ID 过滤器。如果提供,只接收该任务的 events。 + None 表示接收所有事件。 + 注意:回放和加入订阅者列表在同一同步段内完成(无 await), 保证不会遗漏或重复事件。 """ if self._closed: return + # P1 #13 fix: enforce subscriber cap to prevent resource exhaustion + # from malicious resume floods or runaway client loops. + if len(self._subscribers) >= self._MAX_SUBSCRIBERS: + logger.error( + "EventQueue subscriber limit reached (%d), rejecting new subscription", + self._MAX_SUBSCRIBERS, + ) + raise RuntimeError(f"EventQueue subscriber limit reached ({self._MAX_SUBSCRIBERS})") + queue: asyncio.Queue[Event] = asyncio.Queue(maxsize=self._MAX_QUEUE_SIZE) # 回放缓冲事件(同步操作,无 await,保证原子性) for event in list(self._buffer): + if task_id is not None and event.task_id != task_id: + continue try: queue.put_nowait(event) except asyncio.QueueFull: @@ -199,7 +234,8 @@ class EventQueue: break # 加入订阅者列表(在回放之后,确保不会收到重复事件) - self._subscribers.append(queue) + sub = _Subscriber(queue=queue, task_id_filter=task_id) + self._subscribers.append(sub) try: while True: @@ -208,9 +244,9 @@ class EventQueue: break yield event finally: - # 清理:移除当前订阅者的队列 - if queue in self._subscribers: - self._subscribers.remove(queue) + # 清理:移除当前订阅者 + if sub in self._subscribers: + self._subscribers.remove(sub) @property def subscriber_count(self) -> int: @@ -235,9 +271,9 @@ class EventQueue: """ self._closed = True # 向所有活跃订阅者队列放入哨兵,使其能够优雅退出 - for queue in self._subscribers: + for sub in self._subscribers: try: - queue.put_nowait(_CLOSED_SENTINEL) + sub.queue.put_nowait(_CLOSED_SENTINEL) except asyncio.QueueFull: pass self._subscribers.clear() diff --git a/src/agentkit/server/frontend/src/api/client.ts b/src/agentkit/server/frontend/src/api/client.ts index a68cd51..5852036 100644 --- a/src/agentkit/server/frontend/src/api/client.ts +++ b/src/agentkit/server/frontend/src/api/client.ts @@ -3,6 +3,8 @@ import type { IChatResponse, ICapabilitiesResponse, IConversation, + ITaskRecord, + TaskStatus, } from './types' import { BaseApiClient } from './base' @@ -36,6 +38,19 @@ class ApiClient extends BaseApiClient { return this.request(`/conversations/${id}`) } + /** Get a task by ID (uses /api/v1/tasks prefix) */ + async getTask(taskId: string): Promise { + return this.request(`/api/v1/tasks/${taskId}`) + } + + /** List tasks, optionally filtered by status (uses /api/v1/tasks prefix) */ + async listTasks(status?: TaskStatus, limit: number = 100): Promise { + const params = new URLSearchParams() + if (status) params.set('status', status) + params.set('limit', String(limit)) + return this.request(`/api/v1/tasks?${params.toString()}`) + } + /** Create a WebSocket connection for real-time chat */ createWebSocket(): WebSocket { return super.createWebSocket('/ws') diff --git a/src/agentkit/server/frontend/src/api/types.ts b/src/agentkit/server/frontend/src/api/types.ts index 25a3aaf..06f0353 100644 --- a/src/agentkit/server/frontend/src/api/types.ts +++ b/src/agentkit/server/frontend/src/api/types.ts @@ -78,6 +78,15 @@ export type WsClientMessage = { sources?: string[] conversation_id?: string model?: string +} | { + type: 'resume' + task_id: string + conversation_id?: string +} | { + type: 'cancel' + task_id?: string +} | { + type: 'ping' } /** WebSocket server message types — matches backend portal.py protocol */ @@ -132,3 +141,23 @@ export interface IApiError { message: string detail?: string } + +/** Task status (matches backend TaskStatus enum) */ +export type TaskStatus = 'pending' | 'running' | 'completed' | 'partially_completed' | 'failed' | 'cancelled' + +/** Task record (matches backend TaskRecord.to_dict()) */ +export interface ITaskRecord { + task_id: string + agent_name: string + skill_name: string | null + input_data: Record + status: TaskStatus + output_data: Record | null + error_message: string | null + created_at: string + started_at: string | null + completed_at: string | null + progress: number + progress_message: string + metadata: Record +} diff --git a/src/agentkit/server/frontend/src/stores/chat.ts b/src/agentkit/server/frontend/src/stores/chat.ts index 2c17080..49b0a97 100644 --- a/src/agentkit/server/frontend/src/stores/chat.ts +++ b/src/agentkit/server/frontend/src/stores/chat.ts @@ -52,10 +52,37 @@ export const useChatStore = defineStore('chat', () => { } } - /** Select a conversation by ID */ - function selectConversation(id: string): void { + /** Select a conversation by ID and load its messages */ + async function selectConversation(id: string, force = false): Promise { currentConversationId.value = id streamingSteps.value = [] + + // Load full conversation with messages if not already loaded (or when forced) + const conv = conversations.value.find((c) => c.id === id) + if (force || !conv || !conv.messages || conv.messages.length === 0) { + try { + const fullConv = await apiClient.getConversation(id) + if (conv) { + conv.messages = fullConv.messages || [] + conv.title = fullConv.title || conv.title + conv.created_at = fullConv.created_at || conv.created_at + conv.updated_at = fullConv.updated_at || conv.updated_at + } else { + // P1 #7 fix: If the conversation is not in the local list (e.g. + // after a page refresh), add it from the fetched data instead + // of silently discarding the result. + conversations.value.unshift({ + id: fullConv.id || id, + title: fullConv.title || '新对话', + messages: fullConv.messages || [], + created_at: fullConv.created_at || new Date().toISOString(), + updated_at: fullConv.updated_at || new Date().toISOString(), + }) + } + } catch (error) { + console.error('Failed to load conversation messages:', error) + } + } } /** Create a new empty conversation */ @@ -135,7 +162,7 @@ export const useChatStore = defineStore('chat', () => { } /** Send a message via WebSocket for streaming */ - function sendWsMessage(message: string, sources?: string[], model?: string): void { + async function sendWsMessage(message: string, sources?: string[], model?: string): Promise { if (!currentConversationId.value) { createConversation() } @@ -143,7 +170,7 @@ export const useChatStore = defineStore('chat', () => { // Check WebSocket state BEFORE creating messages to avoid duplicates if (!ws.value || ws.value.readyState !== WebSocket.OPEN) { // Fallback to REST directly — sendMessage will create its own messages - sendMessage(message, sources) + await sendMessage(message, sources) return } @@ -179,7 +206,22 @@ export const useChatStore = defineStore('chat', () => { model, } - ws.value.send(JSON.stringify(wsMessage)) + // Problem 7: catch send() exceptions (e.g. connection closed mid-send) + try { + ws.value.send(JSON.stringify(wsMessage)) + } catch (error) { + console.error('WebSocket send failed, falling back to REST:', error) + // Remove the placeholder messages we just added; sendMessage will re-add them + const conv = conversations.value.find((c) => c.id === conversationId) + if (conv) { + conv.messages = conv.messages.filter( + (m) => m.id !== userMessage.id && m.id !== assistantMessage.id, + ) + } + isLoading.value = false + await sendMessage(message, sources) + return + } // Update conversation title from first user message const conv = conversations.value.find((c) => c.id === conversationId) @@ -190,12 +232,16 @@ export const useChatStore = defineStore('chat', () => { /** Connect to WebSocket for real-time streaming */ let _heartbeatTimer: ReturnType | null = null + let _reconnectTimer: ReturnType | null = null + let _intentionalDisconnect = false function connectWebSocket(): void { - if (ws.value && ws.value.readyState === WebSocket.OPEN) { + // Problem 6: also skip if already CONNECTING to avoid orphan sockets + if (ws.value && (ws.value.readyState === WebSocket.OPEN || ws.value.readyState === WebSocket.CONNECTING)) { return } + _intentionalDisconnect = false const socket = apiClient.createWebSocket() socket.onopen = () => { @@ -208,6 +254,8 @@ export const useChatStore = defineStore('chat', () => { ws.value.send(JSON.stringify({ type: 'ping' })) } }, 30000) + // Check for running tasks to resume after reconnection + _recoverTaskAfterReconnect() } socket.onmessage = (event: MessageEvent) => { @@ -222,13 +270,22 @@ export const useChatStore = defineStore('chat', () => { socket.onclose = () => { isWsConnected.value = false + // P2 #21 fix: reset isLoading to prevent stuck loading state during + // disconnect. _recoverTaskAfterReconnect will re-set it if an active + // task is found after reconnection. + isLoading.value = false console.log('WebSocket disconnected') if (_heartbeatTimer) { clearInterval(_heartbeatTimer) _heartbeatTimer = null } + // Problem 1: do not auto-reconnect after an intentional disconnect + if (_intentionalDisconnect) { + return + } // Auto reconnect after 3 seconds - setTimeout(() => { + if (_reconnectTimer) clearTimeout(_reconnectTimer) + _reconnectTimer = setTimeout(() => { if (!ws.value || ws.value.readyState === WebSocket.CLOSED) { connectWebSocket() } @@ -245,6 +302,11 @@ export const useChatStore = defineStore('chat', () => { /** Disconnect WebSocket */ function disconnectWebSocket(): void { + _intentionalDisconnect = true + if (_reconnectTimer) { + clearTimeout(_reconnectTimer) + _reconnectTimer = null + } if (_heartbeatTimer) { clearInterval(_heartbeatTimer) _heartbeatTimer = null @@ -256,6 +318,64 @@ export const useChatStore = defineStore('chat', () => { } } + /** After WebSocket reconnects, check for running tasks and resume them */ + async function _recoverTaskAfterReconnect(): Promise { + if (!currentConversationId.value) return + + try { + // Problem 2: query both 'running' and 'pending' tasks — a task may be + // in PENDING state if the background worker hasn't picked it up yet. + const [runningTasks, pendingTasks] = await Promise.all([ + apiClient.listTasks('running'), + apiClient.listTasks('pending'), + ]) + const candidates = [...runningTasks, ...pendingTasks] + const activeTask = candidates.find( + (t) => t.metadata?.conversation_id === currentConversationId.value, + ) + + if (activeTask && ws.value && ws.value.readyState === WebSocket.OPEN) { + // P1 #12 fix: Clear the last pending assistant message's accumulated + // content before resuming. The EventQueue will replay buffered events + // from the beginning of the task, so any partial content already + // accumulated before disconnect would be duplicated if not cleared. + const conv = conversations.value.find((c) => c.id === currentConversationId.value) + if (conv) { + const lastPendingAssistant = [...conv.messages] + .reverse() + .find((m) => m.role === 'assistant' && m.status === 'pending') + if (lastPendingAssistant) { + updateMessage(currentConversationId.value, lastPendingAssistant.id, { + content: '', + thinking: '', + tool_calls: [], + }) + } + } + // Problem 5: only set isLoading if we can actually send the resume + isLoading.value = true + try { + ws.value.send( + JSON.stringify({ + type: 'resume', + task_id: activeTask.task_id, + conversation_id: currentConversationId.value, + }), + ) + } catch (error) { + console.error('Failed to send resume message:', error) + isLoading.value = false + } + } else { + // No active task — force reload conversation messages (may include + // completed results persisted while disconnected). + await selectConversation(currentConversationId.value, true) + } + } catch (error) { + console.error('Failed to recover task after reconnect:', error) + } + } + // --- Internal helpers --- /** Get team store lazily — safe to call inside actions after Pinia is installed */ diff --git a/src/agentkit/server/routes/portal.py b/src/agentkit/server/routes/portal.py index 4bdabed..e17abf5 100644 --- a/src/agentkit/server/routes/portal.py +++ b/src/agentkit/server/routes/portal.py @@ -21,7 +21,7 @@ from pydantic import BaseModel from agentkit.core.config_driven import ConfigDrivenAgent from agentkit.core.event_queue import EventQueue -from agentkit.core.protocol import Event, TaskEventType, TurnEventType +from agentkit.core.protocol import Event, TaskEventType, TaskStatus, TurnEventType from agentkit.core.react import ReActEngine from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult from agentkit.chat.request_preprocessor import RequestPreprocessor @@ -32,20 +32,15 @@ from agentkit.server.routes.evolution_dashboard import ( ) from agentkit.core.fallback import EMPTY_LLM_RESPONSE from agentkit.chat.sqlite_conversation_store import SqliteConversationStore +from agentkit.server.task_store import InMemoryTaskStore logger = logging.getLogger(__name__) router = APIRouter(tags=["portal"]) -# Map ReAct engine event_type strings to TurnEventType constants for EQ emission. -# Only events with a corresponding TurnEventType are forwarded to the EQ; -# other events (e.g. "token") are still sent over WebSocket but not duplicated to EQ. -_REACT_EVENT_TYPE_MAP: dict[str, str] = { - "thinking": TurnEventType.THINKING, - "tool_call": TurnEventType.TOOL_CALL, - "tool_result": TurnEventType.TOOL_RESULT, - "final_answer": TurnEventType.FINAL_ANSWER, -} +# Track background ReAct tasks so they are not garbage-collected mid-execution. +# Tasks are removed automatically via add_done_callback when they complete. +_running_background_tasks: set[asyncio.Task] = set() # --------------------------------------------------------------------------- # API Key Authentication @@ -95,6 +90,37 @@ async def _emit_event_safe( logger.warning(f"EventQueue emit failed (type={event_type}): {e}", exc_info=True) +# P1 #14 fix: TaskStore sync/async compatibility shim. +# InMemoryTaskStore methods are sync; RedisTaskStore methods are async. +# These helpers detect and await coroutines so portal.py works with both. +async def _task_store_create(store, *args, **kwargs): + result = store.create(*args, **kwargs) + if asyncio.iscoroutine(result): + return await result + return result + + +async def _task_store_get(store, *args, **kwargs): + result = store.get(*args, **kwargs) + if asyncio.iscoroutine(result): + return await result + return result + + +async def _task_store_update_status(store, *args, **kwargs): + result = store.update_status(*args, **kwargs) + if asyncio.iscoroutine(result): + return await result + return result + + +async def _task_store_list_tasks(store, *args, **kwargs): + result = store.list_tasks(*args, **kwargs) + if asyncio.iscoroutine(result): + return await result + return result + + async def _verify_api_key( request: Request, api_key_header: str | None = Security(_api_key_header), @@ -144,6 +170,19 @@ class Conversation: _WS_HEARTBEAT_TIMEOUT = float(os.environ.get("AGENTKIT_WS_TIMEOUT", "120")) _conversation_store = SqliteConversationStore() +# P1 #9 fix: ReAct event type -> TurnEventType mapping for EQ subscribers. +# Preserves the original EQ contract so CLI and other subscribers that +# filter on TurnEventType constants (e.g. 'turn.thinking') keep working. +_REACT_EVENT_TYPE_MAP: dict[str, str] = { + "thinking": TurnEventType.THINKING, + "tool_call": TurnEventType.TOOL_CALL, + "tool_result": TurnEventType.TOOL_RESULT, + "token": TurnEventType.TOKEN, + "final_answer": TurnEventType.FINAL_ANSWER, + "error": TurnEventType.TURN_COMPLETED, # best-effort mapping + "confirmation_request": TurnEventType.STEP, +} + # --------------------------------------------------------------------------- # History injection helper — configurable limit + optional compression # --------------------------------------------------------------------------- @@ -666,6 +705,171 @@ def _derive_title_from_messages(messages: list) -> str: return "对话" +async def _execute_react_background( + react_engine: ReActEngine, + messages: list[dict], + tools: list, + model: str, + agent_name: str, + system_prompt: str | None, + timeout_seconds: float | None, + conv_id: str, + task_id: str, + event_queue: EventQueue, + conversation_store: SqliteConversationStore, + task_store: InMemoryTaskStore | None = None, +) -> None: + """Execute ReAct engine in the background, decoupled from WebSocket lifecycle. + + Events are emitted to the EventQueue (filtered by task_id) so that any + subscriber — including a reconnected WebSocket — can consume them. + Results are always persisted to the conversation store, regardless of + whether a WebSocket subscriber is active. + Task status is tracked in TaskStore when provided. + """ + collected_output: list[str] = [] + try: + if task_store is not None: + try: + await _task_store_update_status( + task_store, task_id, TaskStatus.RUNNING, started_at=datetime.now(timezone.utc) + ) + except Exception: + logger.warning("Failed to update TaskStore RUNNING", exc_info=True) + + async for event in react_engine.execute_stream( + messages=messages, + tools=tools, + model=model, + agent_name=agent_name, + system_prompt=system_prompt, + timeout_seconds=timeout_seconds, + ): + if event.event_type == "final_answer": + collected_output.append(event.data.get("output", "")) + + # P1 #8/#9/#10 fix: Preserve TurnEventType mapping, step field, + # and original data structure for EQ subscriber compatibility. + # Note: Event dataclass has no 'step' field; use getattr for + # compatibility with ReActEngine events that may include it. + _turn_event_type = _REACT_EVENT_TYPE_MAP.get(event.event_type) + if _turn_event_type is not None: + await _emit_event_safe( + event_queue, + _turn_event_type, + task_id=task_id, + session_id=conv_id, + data={ + **event.data, + "step": getattr(event, "step", 0), + "timestamp": event.timestamp, + }, + ) + + # Normal completion: persist result + response_text = _ensure_non_empty("".join(collected_output) if collected_output else None) + await conversation_store.add_message(conv_id, "assistant", response_text) + + if task_store is not None: + try: + await _task_store_update_status( + task_store, + task_id, + TaskStatus.COMPLETED, + output_data={"output": response_text}, + completed_at=datetime.now(timezone.utc), + progress=1.0, + progress_message="Completed", + ) + except Exception: + logger.warning("Failed to update TaskStore COMPLETED", exc_info=True) + + # Emit task.completed so subscribers know the task is done + await _emit_event_safe( + event_queue, + TaskEventType.TASK_COMPLETED, + task_id=task_id, + session_id=conv_id, + data={"output": response_text, "timestamp": datetime.now(timezone.utc).isoformat()}, + ) + + except asyncio.CancelledError: + # Application shutdown or explicit cancel — persist partial output + # and mark task as FAILED so resume does not block forever. + # P0 #1/#2 fix: ALL persistence operations must use asyncio.shield + # and the async TaskStore shim. Without shield, a re-entrant + # cancellation kills the cleanup itself; without the shim, + # RedisTaskStore (async) silently drops the coroutine. + if collected_output: + partial = _ensure_non_empty("".join(collected_output)) + try: + await asyncio.shield(conversation_store.add_message(conv_id, "assistant", partial)) + except (Exception, asyncio.CancelledError): + logger.warning("Failed to persist partial output on cancel") + if task_store is not None: + try: + await asyncio.shield( + _task_store_update_status( + task_store, + task_id, + TaskStatus.FAILED, + error_message="Task cancelled", + completed_at=datetime.now(timezone.utc), + ) + ) + except (Exception, asyncio.CancelledError): + logger.warning("Failed to update TaskStore on cancel", exc_info=True) + # P0 #2 fix: _emit_event_safe is async (it awaits event_queue.emit). + # Shield it so a re-entrant CancelledError doesn't kill the emit + # and leave subscribers blocked until timeout. + try: + await asyncio.shield( + _emit_event_safe( + event_queue, + TaskEventType.TASK_FAILED, + task_id=task_id, + session_id=conv_id, + data={ + "error": "Task cancelled", + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) + ) + except (Exception, asyncio.CancelledError): + logger.warning("Failed to emit TASK_FAILED on cancel") + raise # Propagate cancellation + + except Exception as e: + # Persist any partial output collected before the error + if collected_output: + partial = _ensure_non_empty("".join(collected_output)) + try: + await conversation_store.add_message(conv_id, "assistant", partial) + except Exception: + logger.warning("Failed to persist partial output in background task") + + if task_store is not None: + try: + await _task_store_update_status( + task_store, + task_id, + TaskStatus.FAILED, + error_message=str(e), + completed_at=datetime.now(timezone.utc), + ) + except Exception: + logger.warning("Failed to update TaskStore FAILED", exc_info=True) + + # Emit task.failed so subscribers know the task failed + await _emit_event_safe( + event_queue, + TaskEventType.TASK_FAILED, + task_id=task_id, + session_id=conv_id, + data={"error": str(e), "timestamp": datetime.now(timezone.utc).isoformat()}, + ) + + @router.websocket("/portal/ws") async def portal_websocket(websocket: WebSocket): """Real-time chat WebSocket endpoint.""" @@ -692,6 +896,8 @@ async def portal_websocket(websocket: WebSocket): conv: Conversation | None = None # task_id is per-user-message; tracked here so the outer except can emit task.failed task_id: str | None = None + # Track the active background task so cancel can propagate to it. + active_bg_task: asyncio.Task | None = None try: while True: @@ -710,6 +916,10 @@ async def portal_websocket(websocket: WebSocket): msg_type = msg.get("type") if msg_type == "cancel": + # Cancel the active background task if still running + if active_bg_task is not None and not active_bg_task.done(): + active_bg_task.cancel() + active_bg_task = None await websocket.send_json( { "type": "result", @@ -725,6 +935,203 @@ async def portal_websocket(websocket: WebSocket): await websocket.send_json({"type": "pong"}) continue + if msg_type == "resume": + # Frontend reconnected and wants to resume a running task + resume_task_id = msg.get("task_id", "") + if not resume_task_id: + continue + + # P1 #3/#4 fix: Fail-closed ownership verification. + # Require conversation_id and TaskStore — reject if either + # is missing, to prevent cross-conversation task hijacking + # via empty conversation_id or unconfigured TaskStore. + resume_conv_id = msg.get("conversation_id", "") + if not resume_conv_id: + await websocket.send_json( + { + "type": "error", + "data": { + "message": "Resume requires conversation_id.", + "task_id": resume_task_id, + }, + } + ) + continue + + resume_task_store: InMemoryTaskStore | None = getattr( + websocket.app.state, "task_store", None + ) + resume_eq: EventQueue | None = getattr(websocket.app.state, "event_queue", None) + + # P1 #4: Fail-closed if TaskStore is unavailable — cannot + # verify ownership without it. + if resume_task_store is None: + await websocket.send_json( + { + "type": "error", + "data": { + "message": "Resume not supported (TaskStore unavailable). Please retry your request.", + "task_id": resume_task_id, + }, + } + ) + continue + + try: + record = await _task_store_get(resume_task_store, resume_task_id) + except Exception: + logger.warning("TaskStore.get failed during resume", exc_info=True) + record = None + if record is not None: + # P1 #3: Fail-closed ownership check — reject if + # conversation_id is missing from task metadata OR + # does not match the request. + task_conv_id = (record.metadata or {}).get("conversation_id", "") + if not task_conv_id or resume_conv_id != task_conv_id: + logger.warning( + "Resume rejected: conversation_id mismatch " + "(task=%s, request=%s, task_id=%s)", + task_conv_id, + resume_conv_id, + resume_task_id, + ) + await websocket.send_json( + { + "type": "error", + "data": { + "message": "Task does not belong to this conversation.", + "task_id": resume_task_id, + }, + } + ) + continue + if record.status == TaskStatus.COMPLETED: + # Task already finished — send result immediately + output = (record.output_data or {}).get("output", "") + await websocket.send_json( + { + "type": "result", + "data": { + "message": output, + "timestamp": record.completed_at.isoformat() + if record.completed_at + else datetime.now(timezone.utc).isoformat(), + }, + } + ) + continue + elif record.status == TaskStatus.FAILED: + await websocket.send_json( + { + "type": "error", + "data": { + "message": record.error_message or "Task failed", + }, + } + ) + continue + else: + # Task not found in store — cannot resume + await websocket.send_json( + { + "type": "error", + "data": { + "message": "Task not found or has expired. Please retry your request.", + "task_id": resume_task_id, + }, + } + ) + continue + + # Task is still running — subscribe to EventQueue for remaining events. + # H6: if EventQueue is unavailable, inform the client instead of + # silently continuing (which would leave the UI loading forever). + if resume_eq is None: + await websocket.send_json( + { + "type": "error", + "data": { + "message": "Resume not supported (EventQueue unavailable). Please retry your request.", + }, + } + ) + continue + + # C2: bound the subscribe loop with a timeout so a dead + # background task cannot block resume forever. + resume_timeout = _WS_HEARTBEAT_TIMEOUT * 10 if _WS_HEARTBEAT_TIMEOUT > 0 else 600 + try: + async with asyncio.timeout(resume_timeout): + async for event in resume_eq.subscribe(task_id=resume_task_id): + if event.event_type == TaskEventType.TASK_COMPLETED: + response_text = event.data.get("output", EMPTY_LLM_RESPONSE) + await websocket.send_json( + { + "type": "result", + "data": { + "message": response_text, + "timestamp": event.data.get( + "timestamp", + datetime.now(timezone.utc).isoformat(), + ), + }, + } + ) + break + elif event.event_type == TaskEventType.TASK_FAILED: + await websocket.send_json( + { + "type": "error", + "data": { + "message": event.data.get("error", "Unknown error"), + }, + } + ) + break + else: + # P1 #8/#10 fix: step and data are now + # top-level fields in event.data. + await websocket.send_json( + { + "type": "step", + "data": { + "event_type": event.event_type, + "step": event.data.get("step", 0), + "data": { + k: v + for k, v in event.data.items() + if k not in ("step", "timestamp") + }, + "timestamp": event.data.get("timestamp", ""), + }, + } + ) + except TimeoutError: + logger.warning(f"Resume subscribe timed out for task {resume_task_id}") + await websocket.send_json( + { + "type": "error", + "data": { + "message": "Task resume timed out. Please retry your request.", + "task_id": resume_task_id, + }, + } + ) + except RuntimeError as exc: + # P1 #5: subscriber limit reached or EQ closed — send + # a friendly error instead of terminating the connection. + logger.warning("Resume subscribe failed for task %s: %s", resume_task_id, exc) + await websocket.send_json( + { + "type": "error", + "data": { + "message": "Server busy, please retry shortly.", + "task_id": resume_task_id, + }, + } + ) + continue + if msg_type != "chat": continue @@ -744,6 +1151,7 @@ async def portal_websocket(websocket: WebSocket): # (EQ is a side-channel: emit failures never break the WebSocket flow) task_id = str(uuid.uuid4()) event_queue: EventQueue | None = getattr(websocket.app.state, "event_queue", None) + task_store: InMemoryTaskStore | None = getattr(websocket.app.state, "task_store", None) await _emit_event_safe( event_queue, TaskEventType.TASK_CREATED, @@ -844,6 +1252,26 @@ async def portal_websocket(websocket: WebSocket): }, ) + # Register task in TaskStore for status tracking and recovery + if task_store is not None: + try: + await _task_store_create( + task_store, + task_id=task_id, + agent_name=routing_result.agent_name or "default", + input_data={"message": message_text}, + skill_name=routing_result.skill_name, + ) + # Store conversation_id in metadata for frontend recovery + await _task_store_update_status( + task_store, + task_id, + TaskStatus.PENDING, + metadata={"conversation_id": conv.id}, + ) + except Exception: + logger.warning("Failed to register task in TaskStore", exc_info=True) + # Execute based on routing result's execution_mode # This is the single source of truth for path selection, # replacing fragile string-matching on match_method. @@ -870,6 +1298,21 @@ async def portal_websocket(websocket: WebSocket): response_content = _ensure_non_empty(response.content) await _conversation_store.add_message(conv.id, "assistant", response_content) + # Update TaskStore status to COMPLETED + if task_store is not None: + try: + await _task_store_update_status( + task_store, + task_id, + TaskStatus.COMPLETED, + output_data={"output": response_content}, + completed_at=datetime.now(timezone.utc), + progress=1.0, + progress_message="Completed", + ) + except Exception: + logger.warning("Failed to update TaskStore for DIRECT_CHAT", exc_info=True) + # Emit turn.final_answer and task.completed to EQ await _emit_event_safe( event_queue, @@ -890,8 +1333,7 @@ async def portal_websocket(websocket: WebSocket): { "type": "result", "data": { - "status": "completed", - "content": response_content, + "message": response_content, "timestamp": datetime.now(timezone.utc).isoformat(), }, } @@ -1010,99 +1452,154 @@ async def portal_websocket(websocket: WebSocket): f"[portal] agent='{agent_name}' tools={len(tools)} " f"[{', '.join(t.name for t in tools)}] model={model}" ) - collected_output: list[str] = [] - try: - async for event in react_engine.execute_stream( + + # Start ReAct execution as a background task, decoupled from + # WebSocket lifecycle. When the WebSocket disconnects, the + # background task continues running and persists the result. + bg_task = asyncio.create_task( + _execute_react_background( + react_engine=react_engine, messages=messages, tools=tools, model=model, agent_name=agent.name, system_prompt=system_prompt, timeout_seconds=timeout_seconds, - ): - if event.event_type == "final_answer": - collected_output.append(event.data.get("output", "")) - - # Map ReAct event types to TurnEventType and emit to EQ - # (side-channel: failures are swallowed by _emit_event_safe) - _turn_event_type = _REACT_EVENT_TYPE_MAP.get(event.event_type) - if _turn_event_type is not None: - await _emit_event_safe( - event_queue, - _turn_event_type, - task_id=task_id, - session_id=conv.id, - data=event.data, - ) - - await websocket.send_json( - { - "type": "step", - "data": { - "event_type": event.event_type, - "step": event.step, - "data": event.data, - "timestamp": event.timestamp, - }, - } - ) - except Exception as e: - # Emit task.failed to EQ before sending error to WebSocket - await _emit_event_safe( - event_queue, - TaskEventType.TASK_FAILED, + conv_id=conv.id, task_id=task_id, - session_id=conv.id, - data={"error": str(e)}, + event_queue=event_queue, + conversation_store=_conversation_store, + task_store=task_store, ) - await websocket.send_json({"type": "error", "data": {"message": str(e)}}) + ) + _running_background_tasks.add(bg_task) + bg_task.add_done_callback(_running_background_tasks.discard) + active_bg_task = bg_task + + # C1 guard: EventQueue is required for subscribe; fall back to + # awaiting the background task directly if unavailable. + if event_queue is None: + logger.warning("EventQueue not configured; awaiting background task directly") + try: + await bg_task + except Exception: + pass # errors handled inside _execute_react_background + active_bg_task = None continue - response_text = _ensure_non_empty( - "".join(collected_output) if collected_output else None - ) - await _conversation_store.add_message(conv.id, "assistant", response_text) - - outcome = "success" if response_text != EMPTY_LLM_RESPONSE else "failure" - - # Emit task.completed (success) or task.failed (empty response) to EQ - if outcome == "success": - await _emit_event_safe( - event_queue, - TaskEventType.TASK_COMPLETED, - task_id=task_id, - session_id=conv.id, - data={"output": response_text}, + # Subscribe to EventQueue (filtered by task_id) and forward + # events to the WebSocket. When the WebSocket disconnects, + # this loop exits but the background task continues. + # P1 #7 fix: bound the subscribe loop with a timeout so a + # hung background task cannot block the WebSocket forever. + # Matches the resume path's timeout strategy. + _subscribe_timeout = _WS_HEARTBEAT_TIMEOUT * 10 if _WS_HEARTBEAT_TIMEOUT > 0 else 600 + try: + async with asyncio.timeout(_subscribe_timeout): + async for event in event_queue.subscribe(task_id=task_id): + if event.event_type == TaskEventType.TASK_COMPLETED: + response_text = event.data.get("output", EMPTY_LLM_RESPONSE) + await websocket.send_json( + { + "type": "result", + "data": { + "message": response_text, + "timestamp": event.data.get( + "timestamp", + datetime.now(timezone.utc).isoformat(), + ), + }, + } + ) + await _record_experience( + routing_result.skill_name or "agent", + message_text, + "success" if response_text != EMPTY_LLM_RESPONSE else "failure", + (datetime.now(timezone.utc) - start_time).total_seconds(), + ) + break + elif event.event_type == TaskEventType.TASK_FAILED: + await websocket.send_json( + { + "type": "error", + "data": { + "message": event.data.get("error", "Unknown error"), + }, + } + ) + await _record_experience( + routing_result.skill_name or "agent", + message_text, + "failure", + (datetime.now(timezone.utc) - start_time).total_seconds(), + ) + break + else: + # Forward ReAct events as step messages. + # P1 #8/#10 fix: step and data are now top-level + # fields in event.data (no longer nested). + await websocket.send_json( + { + "type": "step", + "data": { + "event_type": event.event_type, + "step": event.data.get("step", 0), + "data": { + k: v + for k, v in event.data.items() + if k not in ("step", "timestamp") + }, + "timestamp": event.data.get("timestamp", ""), + }, + } + ) + except TimeoutError: + logger.warning(f"Subscribe loop timed out for task {task_id}") + if active_bg_task is not None and not active_bg_task.done(): + active_bg_task.cancel() + await websocket.send_json( + { + "type": "error", + "data": { + "message": "Task timed out. Please retry your request.", + "task_id": task_id, + }, + } ) - else: - await _emit_event_safe( - event_queue, - TaskEventType.TASK_FAILED, - task_id=task_id, - session_id=conv.id, - data={"error": "Empty LLM response"}, + except RuntimeError as exc: + # P1 #5: subscriber limit reached or EQ closed — send + # a friendly error instead of terminating the connection. + logger.warning("Subscribe failed for task %s: %s", task_id, exc) + await websocket.send_json( + { + "type": "error", + "data": { + "message": "Server busy, please retry shortly.", + "task_id": task_id, + }, + } ) - await websocket.send_json( - { - "type": "result", - "data": { - "message": response_text, - "timestamp": datetime.now(timezone.utc).isoformat(), - }, - } - ) - await _record_experience( - routing_result.skill_name or "agent", - message_text, - outcome, - (datetime.now(timezone.utc) - start_time).total_seconds(), - ) - except WebSocketDisconnect: logger.debug(f"Portal WebSocket disconnected for conversation {conv.id if conv else 'N/A'}") + # P0 fix: Do NOT cancel the background task on disconnect. + # The entire purpose of the three-layer defense is to let the + # background task continue running and persist the result so the + # frontend can resume it after reconnection. Cancelling here would + # kill the task, lose the full output, and mark it FAILED — + # defeating layers 2 and 3. The task is only cancelled on explicit + # user cancel (msg_type == 'cancel') or application shutdown. except Exception as e: logger.error(f"Portal WebSocket error: {e}") + # P1 #6 fix: Do NOT cancel the background task on connection-level + # errors (ConnectionResetError, BrokenPipeError, etc.). These are + # functionally equivalent to WebSocketDisconnect — the client dropped + # — and the background task must survive to persist its result. + # Only cancel on truly unexpected errors that may have corrupted + # state needed by the background task. + if not isinstance(e, (ConnectionResetError, BrokenPipeError, ConnectionError)): + if active_bg_task is not None and not active_bg_task.done(): + active_bg_task.cancel() # Emit task.failed to EQ if a task was in progress # (task_id is set when a user message is received; None before that) if task_id is not None and conv is not None: diff --git a/tests/unit/server/test_portal_ws_background.py b/tests/unit/server/test_portal_ws_background.py new file mode 100644 index 0000000..4da3089 --- /dev/null +++ b/tests/unit/server/test_portal_ws_background.py @@ -0,0 +1,938 @@ +"""Tests for WebSocket task persistence and background execution (U1-U3). + +Tests cover: +- U1: Partial output persistence on WebSocket disconnect +- U2: Background ReAct task execution with EventQueue event distribution +- U3: TaskStore registration and status tracking +- P0 #2: CancelledError path — partial output persisted, task marked FAILED +- P0 #3: Resume handler — conversation_id ownership verification +- P0 #4: Cancel propagation — explicit cancel marks task FAILED +- P0 #5: WebSocketDisconnect does NOT cancel background task +""" + +from __future__ import annotations + +import asyncio + +from agentkit.core.event_queue import EventQueue +from agentkit.core.protocol import Event, TaskEventType, TaskStatus, TurnEventType +from agentkit.server.routes.portal import _execute_react_background +from agentkit.server.task_store import InMemoryTaskStore + + +# --------------------------------------------------------------------------- +# Fakes +# --------------------------------------------------------------------------- + + +class FakeConversationStore: + """Minimal conversation store for testing.""" + + def __init__(self) -> None: + self.messages: list[tuple[str, str, str]] = [] + + async def add_message(self, conv_id: str, role: str, content: str) -> None: + self.messages.append((conv_id, role, content)) + + +class FakeReactEngine: + """Fake ReAct engine that yields events from a predefined list.""" + + def __init__(self, events: list[Event]) -> None: + self._events = events + + async def execute_stream(self, **kwargs): + for event in self._events: + yield event + + +class FailingReactEngine: + """Fake ReAct engine that raises an exception after yielding some events.""" + + def __init__(self, events: list[Event], error: Exception) -> None: + self._events = events + self._error = error + + async def execute_stream(self, **kwargs): + for event in self._events: + yield event + raise self._error + + +def _make_event( + event_type: str, + task_id: str = "test-task", + session_id: str = "test-conv", + data: dict | None = None, +) -> Event: + return Event.create( + event_type=event_type, + task_id=task_id, + session_id=session_id, + data=data or {}, + ) + + +class SlowFakeReactEngine: + """Fake ReAct engine with a delay to allow status checks during execution.""" + + def __init__(self, events: list[Event], delay: float = 0.1) -> None: + self._events = events + self._delay = delay + + async def execute_stream(self, **kwargs): + for event in self._events: + await asyncio.sleep(self._delay) + yield event + + +class CancellableReactEngine: + """Fake ReAct engine that blocks forever until cancelled. + + Yields one event so collected_output is non-empty, then blocks on an + Event so the test can cancel the task and verify CancelledError cleanup. + """ + + def __init__(self, first_event: Event) -> None: + self._first_event = first_event + self.started = asyncio.Event() + + async def execute_stream(self, **kwargs): + yield self._first_event + self.started.set() + # Block forever until cancelled + await asyncio.Event().wait() + + +def _suppress_cancelled(): + """Context manager that suppresses asyncio.CancelledError.""" + import contextlib + + return contextlib.suppress(asyncio.CancelledError) + + +# --------------------------------------------------------------------------- +# U1 + U2: Background task persistence tests +# --------------------------------------------------------------------------- + + +class TestExecuteReactBackground: + """Tests for _execute_react_background (U1 + U2).""" + + async def test_normal_completion_persists_result(self): + """U2: Normal completion persists result to conversation store.""" + events = [ + _make_event("thinking", data={"text": "Let me think..."}), + _make_event("final_answer", data={"output": "The answer is 42"}), + ] + engine = FakeReactEngine(events) + conv_store = FakeConversationStore() + eq = EventQueue() + + await _execute_react_background( + react_engine=engine, + messages=[], + tools=[], + model="test-model", + agent_name="test-agent", + system_prompt=None, + timeout_seconds=None, + conv_id="test-conv", + task_id="test-task", + event_queue=eq, + conversation_store=conv_store, + ) + + # Result should be persisted + assert len(conv_store.messages) == 1 + conv_id, role, content = conv_store.messages[0] + assert conv_id == "test-conv" + assert role == "assistant" + assert content == "The answer is 42" + + async def test_partial_output_persisted_on_error(self): + """U1: Partial output is persisted when ReAct engine raises an error.""" + events = [ + _make_event("thinking", data={"text": "Thinking..."}), + _make_event("final_answer", data={"output": "Partial result"}), + ] + error = RuntimeError("LLM timeout") + engine = FailingReactEngine(events, error) + conv_store = FakeConversationStore() + eq = EventQueue() + + await _execute_react_background( + react_engine=engine, + messages=[], + tools=[], + model="test-model", + agent_name="test-agent", + system_prompt=None, + timeout_seconds=None, + conv_id="test-conv", + task_id="test-task", + event_queue=eq, + conversation_store=conv_store, + ) + + # Partial output should be persisted + assert len(conv_store.messages) == 1 + _, role, content = conv_store.messages[0] + assert role == "assistant" + assert content == "Partial result" + + async def test_no_output_on_error_without_final_answer(self): + """U1: No message persisted when error occurs before any final_answer.""" + events = [_make_event("thinking", data={"text": "Thinking..."})] + error = RuntimeError("Early failure") + engine = FailingReactEngine(events, error) + conv_store = FakeConversationStore() + eq = EventQueue() + + await _execute_react_background( + react_engine=engine, + messages=[], + tools=[], + model="test-model", + agent_name="test-agent", + system_prompt=None, + timeout_seconds=None, + conv_id="test-conv", + task_id="test-task", + event_queue=eq, + conversation_store=conv_store, + ) + + # No assistant message should be persisted (collected_output is empty) + assert len(conv_store.messages) == 0 + + async def test_events_emitted_to_event_queue(self): + """U2: Events are emitted to EventQueue for subscribers.""" + events = [ + _make_event("thinking", data={"text": "Thinking..."}), + _make_event("final_answer", data={"output": "Done"}), + ] + engine = FakeReactEngine(events) + conv_store = FakeConversationStore() + eq = EventQueue() + + received: list[Event] = [] + + async def subscriber(): + async for evt in eq.subscribe(task_id="test-task"): + received.append(evt) + if evt.event_type == TaskEventType.TASK_COMPLETED: + break + + sub_task = asyncio.create_task(subscriber()) + await asyncio.sleep(0.05) + + await _execute_react_background( + react_engine=engine, + messages=[], + tools=[], + model="test-model", + agent_name="test-agent", + system_prompt=None, + timeout_seconds=None, + conv_id="test-conv", + task_id="test-task", + event_queue=eq, + conversation_store=conv_store, + ) + + await asyncio.wait_for(sub_task, timeout=2.0) + + # Should receive thinking, final_answer, and task.completed events + # P1 #9: ReAct event types are mapped to TurnEventType constants + event_types = [e.event_type for e in received] + assert TurnEventType.THINKING in event_types + assert TurnEventType.FINAL_ANSWER in event_types + assert TaskEventType.TASK_COMPLETED in event_types + + async def test_task_failed_event_on_error(self): + """U2: task.failed event is emitted on error.""" + events: list[Event] = [] + error = RuntimeError("Execution failed") + engine = FailingReactEngine(events, error) + conv_store = FakeConversationStore() + eq = EventQueue() + + received: list[Event] = [] + + async def subscriber(): + async for evt in eq.subscribe(task_id="test-task"): + received.append(evt) + if evt.event_type == TaskEventType.TASK_FAILED: + break + + sub_task = asyncio.create_task(subscriber()) + await asyncio.sleep(0.05) + + await _execute_react_background( + react_engine=engine, + messages=[], + tools=[], + model="test-model", + agent_name="test-agent", + system_prompt=None, + timeout_seconds=None, + conv_id="test-conv", + task_id="test-task", + event_queue=eq, + conversation_store=conv_store, + ) + + await asyncio.wait_for(sub_task, timeout=2.0) + + failed_events = [e for e in received if e.event_type == TaskEventType.TASK_FAILED] + assert len(failed_events) == 1 + assert "Execution failed" in failed_events[0].data.get("error", "") + + +# --------------------------------------------------------------------------- +# U3: TaskStore integration tests +# --------------------------------------------------------------------------- + + +class TestTaskStoreIntegration: + """Tests for TaskStore registration and status tracking (U3).""" + + async def test_task_store_status_running_during_execution(self): + """U3: TaskStore status is RUNNING during background execution.""" + events = [ + _make_event("thinking", data={"text": "Thinking..."}), + _make_event("final_answer", data={"output": "Result"}), + ] + engine = SlowFakeReactEngine(events, delay=0.2) + conv_store = FakeConversationStore() + eq = EventQueue() + task_store = InMemoryTaskStore() + + task_store.create( + task_id="test-task", + agent_name="test-agent", + input_data={"message": "hello"}, + ) + + # Start background task + bg_task = asyncio.create_task( + _execute_react_background( + react_engine=engine, + messages=[], + tools=[], + model="test-model", + agent_name="test-agent", + system_prompt=None, + timeout_seconds=None, + conv_id="test-conv", + task_id="test-task", + event_queue=eq, + conversation_store=conv_store, + task_store=task_store, + ) + ) + + # Check status while running (need to yield control) + await asyncio.sleep(0.01) + record = task_store.get("test-task") + assert record is not None + assert record.status == TaskStatus.RUNNING + + await asyncio.wait_for(bg_task, timeout=2.0) + + # After completion, status should be COMPLETED + record = task_store.get("test-task") + assert record is not None + assert record.status == TaskStatus.COMPLETED + assert record.output_data is not None + assert record.output_data.get("output") == "Result" + assert record.progress == 1.0 + + async def test_task_store_status_failed_on_error(self): + """U3: TaskStore status is FAILED when background task raises error.""" + events: list[Event] = [] + error = RuntimeError("Execution failed") + engine = FailingReactEngine(events, error) + conv_store = FakeConversationStore() + eq = EventQueue() + task_store = InMemoryTaskStore() + + task_store.create( + task_id="test-task", + agent_name="test-agent", + input_data={"message": "hello"}, + ) + + await _execute_react_background( + react_engine=engine, + messages=[], + tools=[], + model="test-model", + agent_name="test-agent", + system_prompt=None, + timeout_seconds=None, + conv_id="test-conv", + task_id="test-task", + event_queue=eq, + conversation_store=conv_store, + task_store=task_store, + ) + + record = task_store.get("test-task") + assert record is not None + assert record.status == TaskStatus.FAILED + assert record.error_message is not None + assert "Execution failed" in record.error_message + + async def test_task_store_none_does_not_crash(self): + """U3: Passing task_store=None should not crash.""" + events = [_make_event("final_answer", data={"output": "Result"})] + engine = FakeReactEngine(events) + conv_store = FakeConversationStore() + eq = EventQueue() + + # Should not raise + await _execute_react_background( + react_engine=engine, + messages=[], + tools=[], + model="test-model", + agent_name="test-agent", + system_prompt=None, + timeout_seconds=None, + conv_id="test-conv", + task_id="test-task", + event_queue=eq, + conversation_store=conv_store, + task_store=None, + ) + + assert len(conv_store.messages) == 1 + + async def test_task_store_list_by_status(self): + """U3: TaskStore list_tasks filters by status correctly.""" + task_store = InMemoryTaskStore() + + # Create tasks in different states + task_store.create("task-1", "agent", {}) + task_store.create("task-2", "agent", {}) + task_store.create("task-3", "agent", {}) + + task_store.update_status("task-2", TaskStatus.RUNNING) + task_store.update_status("task-3", TaskStatus.COMPLETED, progress=1.0) + + running = task_store.list_tasks(status=TaskStatus.RUNNING) + completed = task_store.list_tasks(status=TaskStatus.COMPLETED) + pending = task_store.list_tasks(status=TaskStatus.PENDING) + + assert len(running) == 1 + assert running[0].task_id == "task-2" + assert len(completed) == 1 + assert completed[0].task_id == "task-3" + assert len(pending) == 1 + assert pending[0].task_id == "task-1" + + async def test_task_store_metadata_contains_conversation_id(self): + """U3: TaskStore metadata stores conversation_id for frontend recovery.""" + task_store = InMemoryTaskStore() + task_store.create("task-1", "agent", {"message": "hello"}) + task_store.update_status( + "task-1", + TaskStatus.PENDING, + metadata={"conversation_id": "conv-123"}, + ) + + record = task_store.get("task-1") + assert record is not None + assert record.metadata.get("conversation_id") == "conv-123" + + +# --------------------------------------------------------------------------- +# EventQueue task_id filtering tests (U2) +# --------------------------------------------------------------------------- + + +class TestEventQueueTaskIdFilter: + """Tests for EventQueue subscribe(task_id=...) filtering.""" + + async def test_subscribe_with_task_id_filter(self): + """U2: subscribe(task_id=...) only receives matching events.""" + eq = EventQueue() + + received: list[Event] = [] + + async def subscriber(): + async for evt in eq.subscribe(task_id="task-A"): + received.append(evt) + if len(received) >= 2: + break + + sub_task = asyncio.create_task(subscriber()) + await asyncio.sleep(0.05) + + # Emit events for different tasks + await eq.emit(_make_event("thinking", task_id="task-A")) + await eq.emit(_make_event("thinking", task_id="task-B")) # Should be filtered out + await eq.emit(_make_event("final_answer", task_id="task-A")) + + await asyncio.wait_for(sub_task, timeout=2.0) + + # Should only receive task-A events + assert len(received) == 2 + assert all(e.task_id == "task-A" for e in received) + + async def test_subscribe_without_filter_receives_all(self): + """U2: subscribe() without task_id receives all events (backward compat).""" + eq = EventQueue() + + received: list[Event] = [] + + async def subscriber(): + async for evt in eq.subscribe(): + received.append(evt) + if len(received) >= 3: + break + + sub_task = asyncio.create_task(subscriber()) + await asyncio.sleep(0.05) + + await eq.emit(_make_event("thinking", task_id="task-A")) + await eq.emit(_make_event("thinking", task_id="task-B")) + await eq.emit(_make_event("final_answer", task_id="task-A")) + + await asyncio.wait_for(sub_task, timeout=2.0) + + # Should receive all events regardless of task_id + assert len(received) == 3 + + async def test_subscribe_replays_buffer_filtered(self): + """U2: Buffer replay respects task_id filter.""" + eq = EventQueue() + + # Emit events before subscribing + await eq.emit(_make_event("thinking", task_id="task-A")) + await eq.emit(_make_event("thinking", task_id="task-B")) + await eq.emit(_make_event("final_answer", task_id="task-A")) + + received: list[Event] = [] + + async def subscriber(): + async for evt in eq.subscribe(task_id="task-A"): + received.append(evt) + if len(received) >= 2: + break + + sub_task = asyncio.create_task(subscriber()) + await asyncio.wait_for(sub_task, timeout=2.0) + + # Should only replay task-A events from buffer + assert len(received) == 2 + assert all(e.task_id == "task-A" for e in received) + + +# --------------------------------------------------------------------------- +# P0 #2: CancelledError path tests +# --------------------------------------------------------------------------- + + +class TestCancelledErrorPath: + """P0 #2: Verify CancelledError cleanup persists partial output and + marks the task as FAILED, and TASK_FAILED event is emitted.""" + + async def test_cancel_persists_partial_output(self): + """P0 #2: When task is cancelled mid-execution, partial output + collected before cancellation is persisted to conversation store.""" + first_event = _make_event("final_answer", data={"output": "Partial before cancel"}) + engine = CancellableReactEngine(first_event) + conv_store = FakeConversationStore() + eq = EventQueue() + task_store = InMemoryTaskStore() + task_store.create("test-task", "test-agent", {"message": "hello"}) + + bg_task = asyncio.create_task( + _execute_react_background( + react_engine=engine, + messages=[], + tools=[], + model="test-model", + agent_name="test-agent", + system_prompt=None, + timeout_seconds=None, + conv_id="test-conv", + task_id="test-task", + event_queue=eq, + conversation_store=conv_store, + task_store=task_store, + ) + ) + + # Wait for the engine to yield its first event + await asyncio.wait_for(engine.started.wait(), timeout=2.0) + bg_task.cancel() + with self._expect_cancelled(): + await bg_task + + # Partial output should be persisted + assert len(conv_store.messages) == 1 + _, role, content = conv_store.messages[0] + assert role == "assistant" + assert content == "Partial before cancel" + + async def test_cancel_marks_task_failed_in_store(self): + """P0 #2: CancelledError marks task status as FAILED in TaskStore.""" + first_event = _make_event("final_answer", data={"output": "Partial"}) + engine = CancellableReactEngine(first_event) + conv_store = FakeConversationStore() + eq = EventQueue() + task_store = InMemoryTaskStore() + task_store.create("test-task", "test-agent", {"message": "hello"}) + + bg_task = asyncio.create_task( + _execute_react_background( + react_engine=engine, + messages=[], + tools=[], + model="test-model", + agent_name="test-agent", + system_prompt=None, + timeout_seconds=None, + conv_id="test-conv", + task_id="test-task", + event_queue=eq, + conversation_store=conv_store, + task_store=task_store, + ) + ) + + await asyncio.wait_for(engine.started.wait(), timeout=2.0) + bg_task.cancel() + with self._expect_cancelled(): + await bg_task + + record = task_store.get("test-task") + assert record is not None + assert record.status == TaskStatus.FAILED + assert record.error_message is not None + assert "cancelled" in record.error_message.lower() + + async def test_cancel_emits_task_failed_event(self): + """P0 #2: CancelledError emits TASK_FAILED event to EventQueue.""" + first_event = _make_event("final_answer", data={"output": "Partial"}) + engine = CancellableReactEngine(first_event) + conv_store = FakeConversationStore() + eq = EventQueue() + + received: list[Event] = [] + + async def subscriber(): + async for evt in eq.subscribe(task_id="test-task"): + received.append(evt) + if evt.event_type == TaskEventType.TASK_FAILED: + break + + sub_task = asyncio.create_task(subscriber()) + await asyncio.sleep(0.05) + + bg_task = asyncio.create_task( + _execute_react_background( + react_engine=engine, + messages=[], + tools=[], + model="test-model", + agent_name="test-agent", + system_prompt=None, + timeout_seconds=None, + conv_id="test-conv", + task_id="test-task", + event_queue=eq, + conversation_store=conv_store, + ) + ) + + await asyncio.wait_for(engine.started.wait(), timeout=2.0) + bg_task.cancel() + with self._expect_cancelled(): + await bg_task + + await asyncio.wait_for(sub_task, timeout=2.0) + + failed_events = [e for e in received if e.event_type == TaskEventType.TASK_FAILED] + assert len(failed_events) == 1 + assert "cancelled" in failed_events[0].data.get("error", "").lower() + + @staticmethod + def _expect_cancelled(): + """Context manager that expects asyncio.CancelledError to be raised.""" + import contextlib + + return contextlib.suppress(asyncio.CancelledError) + + +# --------------------------------------------------------------------------- +# P0 #3: Resume handler conversation_id ownership verification tests +# --------------------------------------------------------------------------- + + +class TestResumeOwnershipVerification: + """P0 #3: Verify resume path rejects tasks from a different conversation. + + These tests exercise the TaskStore metadata check directly, since the + WebSocket resume handler reads metadata to verify ownership. + """ + + async def test_resume_rejects_mismatched_conversation_id(self): + """P0 #3: Task with conversation_id mismatch should be rejected. + + Simulates the metadata check performed in portal.py resume handler: + if record.metadata['conversation_id'] != request conversation_id, + the resume is rejected with an error. + """ + task_store = InMemoryTaskStore() + task_store.create("task-X", "agent", {"message": "hello"}) + task_store.update_status( + "task-X", + TaskStatus.RUNNING, + metadata={"conversation_id": "conv-A"}, + ) + + record = task_store.get("task-X") + assert record is not None + + # Simulate the ownership check from portal.py resume handler + task_conv_id = (record.metadata or {}).get("conversation_id", "") + request_conv_id = "conv-B" # Different conversation + + assert task_conv_id != request_conv_id + # In portal.py, this mismatch triggers an error response and `continue` + + async def test_resume_accepts_matching_conversation_id(self): + """P0 #3: Task with matching conversation_id should be allowed to resume.""" + task_store = InMemoryTaskStore() + task_store.create("task-Y", "agent", {"message": "hello"}) + task_store.update_status( + "task-Y", + TaskStatus.RUNNING, + metadata={"conversation_id": "conv-A"}, + ) + + record = task_store.get("task-Y") + assert record is not None + + task_conv_id = (record.metadata or {}).get("conversation_id", "") + request_conv_id = "conv-A" # Same conversation + + assert task_conv_id == request_conv_id + # In portal.py, this passes the ownership check and proceeds to subscribe + + async def test_resume_rejects_when_metadata_missing(self): + """P1 #3: When task metadata has no conversation_id, resume is + rejected (fail-closed). Previously this was allowed for backward + compatibility, but the security review identified this as a + bypass vector — an attacker can omit conversation_id to subscribe + to any task's events.""" + task_store = InMemoryTaskStore() + task_store.create("task-Z", "agent", {"message": "hello"}) + # No metadata update — metadata defaults to {} + + record = task_store.get("task-Z") + assert record is not None + + task_conv_id = (record.metadata or {}).get("conversation_id", "") + request_conv_id = "conv-A" + + # P1 #3 fix: fail-closed — reject if task_conv_id is missing + should_reject = not task_conv_id or task_conv_id != request_conv_id + assert should_reject + + +# --------------------------------------------------------------------------- +# P0 #4: Cancel propagation tests +# --------------------------------------------------------------------------- + + +class TestCancelPropagation: + """P0 #4: Verify explicit cancel (msg_type == 'cancel') propagates + correctly to the background task and marks it FAILED.""" + + async def test_explicit_cancel_marks_task_failed(self): + """P0 #4: When a background task is explicitly cancelled (simulating + the msg_type == 'cancel' handler), it should propagate CancelledError + and mark the task FAILED with partial output persisted.""" + first_event = _make_event("final_answer", data={"output": "Partial before user cancel"}) + engine = CancellableReactEngine(first_event) + conv_store = FakeConversationStore() + eq = EventQueue() + task_store = InMemoryTaskStore() + task_store.create("cancel-task", "agent", {"message": "hello"}) + + # Simulate the background task as portal.py would create it + active_bg_task: asyncio.Task | None = asyncio.create_task( + _execute_react_background( + react_engine=engine, + messages=[], + tools=[], + model="test-model", + agent_name="test-agent", + system_prompt=None, + timeout_seconds=None, + conv_id="cancel-conv", + task_id="cancel-task", + event_queue=eq, + conversation_store=conv_store, + task_store=task_store, + ) + ) + + await asyncio.wait_for(engine.started.wait(), timeout=2.0) + + # Simulate the cancel handler: active_bg_task.cancel() + assert active_bg_task is not None + assert not active_bg_task.done() + active_bg_task.cancel() + + with _suppress_cancelled(): + await active_bg_task + + # Verify task is marked FAILED + record = task_store.get("cancel-task") + assert record is not None + assert record.status == TaskStatus.FAILED + + # Verify partial output was persisted + assert len(conv_store.messages) == 1 + _, _, content = conv_store.messages[0] + assert content == "Partial before user cancel" + + async def test_cancel_after_completion_is_noop(self): + """P0 #4: Cancelling an already-completed task is a no-op + (active_bg_task.done() check prevents double-cancel).""" + events = [_make_event("final_answer", data={"output": "Done"})] + engine = FakeReactEngine(events) + conv_store = FakeConversationStore() + eq = EventQueue() + + bg_task = asyncio.create_task( + _execute_react_background( + react_engine=engine, + messages=[], + tools=[], + model="test-model", + agent_name="test-agent", + system_prompt=None, + timeout_seconds=None, + conv_id="test-conv", + task_id="test-task", + event_queue=eq, + conversation_store=conv_store, + ) + ) + await bg_task + + # Already done — cancel should be a no-op per portal.py guard: + # `if active_bg_task is not None and not active_bg_task.done():` + assert bg_task.done() + # No exception, no state change + assert len(conv_store.messages) == 1 + + +# --------------------------------------------------------------------------- +# P0 #5: WebSocketDisconnect does NOT cancel background task +# --------------------------------------------------------------------------- + + +class TestWebSocketDisconnectNoCancel: + """P0 #5: Verify that WebSocketDisconnect does NOT cancel the background + task — this is the core invariant of the three-layer defense. + + The test simulates the portal.py control flow: a background task is + started, then the WebSocket disconnects (simulated by cancelling the + subscribe loop but NOT the background task). The background task should + continue running and persist its result. + """ + + async def test_disconnect_does_not_cancel_background_task(self): + """P0 #5: After WebSocketDisconnect, the background task continues + running and persists its result to the conversation store.""" + events = [ + _make_event("thinking", data={"text": "Thinking..."}), + _make_event("final_answer", data={"output": "Final result"}), + ] + engine = SlowFakeReactEngine(events, delay=0.2) + conv_store = FakeConversationStore() + eq = EventQueue() + + # Start the background task (as portal.py would) + bg_task = asyncio.create_task( + _execute_react_background( + react_engine=engine, + messages=[], + tools=[], + model="test-model", + agent_name="test-agent", + system_prompt=None, + timeout_seconds=None, + conv_id="test-conv", + task_id="test-task", + event_queue=eq, + conversation_store=conv_store, + ) + ) + + # Simulate WebSocketDisconnect: the subscribe loop is interrupted, + # but the background task is NOT cancelled (per P0 #1 fix). + # We just stop subscribing — bg_task keeps running. + await asyncio.sleep(0.1) # Let bg_task start + + # Verify bg_task is still running (not cancelled by disconnect) + assert not bg_task.done() + + # Wait for bg_task to complete naturally + await asyncio.wait_for(bg_task, timeout=5.0) + + # Result should be persisted despite "disconnect" + assert len(conv_store.messages) == 1 + _, _, content = conv_store.messages[0] + assert content == "Final result" + + async def test_disconnect_result_available_for_resume(self): + """P0 #5: After disconnect, the completed task's result is available + in TaskStore so a reconnecting client can retrieve it via resume.""" + events = [_make_event("final_answer", data={"output": "Resumable result"})] + engine = FakeReactEngine(events) + conv_store = FakeConversationStore() + eq = EventQueue() + task_store = InMemoryTaskStore() + task_store.create("resume-task", "agent", {"message": "hello"}) + task_store.update_status( + "resume-task", + TaskStatus.RUNNING, + metadata={"conversation_id": "resume-conv"}, + ) + + bg_task = asyncio.create_task( + _execute_react_background( + react_engine=engine, + messages=[], + tools=[], + model="test-model", + agent_name="test-agent", + system_prompt=None, + timeout_seconds=None, + conv_id="resume-conv", + task_id="resume-task", + event_queue=eq, + conversation_store=conv_store, + task_store=task_store, + ) + ) + + # Simulate disconnect: don't cancel bg_task, just let it run + await asyncio.wait_for(bg_task, timeout=5.0) + + # Task should be COMPLETED with output available for resume + record = task_store.get("resume-task") + assert record is not None + assert record.status == TaskStatus.COMPLETED + assert record.output_data is not None + assert record.output_data.get("output") == "Resumable result"