fix: WebSocket task persistence three-layer defense with security hardening

Fix chat history empty content and task stops on refresh. Implements: result persistence on disconnect, task backgrounding via asyncio + EventQueue, frontend reconnection recovery. Security: fail-closed conversation_id ownership, asyncio.shield on CancelledError cleanup, async TaskStore shim, EventQueue subscriber limit, connection error resilience. 23 tests added.
This commit is contained in:
chiguyong 2026-06-17 22:11:51 +08:00
parent 840d1afd6a
commit 5b5291c7e5
6 changed files with 1741 additions and 106 deletions

View File

@ -34,6 +34,20 @@ _CLOSED_SENTINEL: Event = Event(
)
@dataclass
class _Subscriber:
"""Internal subscriber tracking with optional task_id filter."""
queue: asyncio.Queue[Event]
task_id_filter: str | None = None # None = receive all events
def matches(self, event: Event) -> bool:
"""Check if this subscriber should receive the given event."""
if self.task_id_filter is None:
return True
return event.task_id == self.task_id_filter
@dataclass
class Submission:
"""用户提交的任务
@ -151,9 +165,12 @@ class EventQueue:
_MAX_QUEUE_SIZE: int = 1024
_DEFAULT_BUFFER_SIZE: int = 100
# P1 #13 fix: cap total subscribers to prevent resource exhaustion
# from malicious resume floods or runaway client loops.
_MAX_SUBSCRIBERS: int = 1000
def __init__(self, buffer_size: int = _DEFAULT_BUFFER_SIZE) -> None:
self._subscribers: list[asyncio.Queue[Event]] = []
self._subscribers: list[_Subscriber] = []
self._buffer: deque[Event] = deque(maxlen=buffer_size)
self._buffer_size = buffer_size
self._closed: bool = False
@ -163,35 +180,53 @@ class EventQueue:
事件会同时写入缓冲区供未来订阅者回放和所有活跃订阅者队列
如果某订阅者队列已满该事件对该订阅者被丢弃不影响其他订阅者
支持按 task_id 过滤只有 task_id 匹配的订阅者才会收到事件
Args:
event: 要推送的事件
"""
self._buffer.append(event)
for queue in self._subscribers:
for sub in self._subscribers:
if not sub.matches(event):
continue
try:
queue.put_nowait(event)
sub.queue.put_nowait(event)
except asyncio.QueueFull:
logger.warning("EventQueue subscriber queue full, dropping event")
async def subscribe(self) -> AsyncIterator[Event]:
async def subscribe(self, task_id: str | None = None) -> AsyncIterator[Event]:
"""订阅事件流(异步生成器)
订阅时会先回放缓冲区中的事件然后持续接收新事件
订阅时会先回放缓冲区中的事件 task_id 过滤然后持续接收新事件
每个订阅者获得独立的队列实现广播语义
当队列关闭时生成器结束
Args:
task_id: 可选的任务 ID 过滤器如果提供只接收该任务的 events
None 表示接收所有事件
注意回放和加入订阅者列表在同一同步段内完成 await
保证不会遗漏或重复事件
"""
if self._closed:
return
# P1 #13 fix: enforce subscriber cap to prevent resource exhaustion
# from malicious resume floods or runaway client loops.
if len(self._subscribers) >= self._MAX_SUBSCRIBERS:
logger.error(
"EventQueue subscriber limit reached (%d), rejecting new subscription",
self._MAX_SUBSCRIBERS,
)
raise RuntimeError(f"EventQueue subscriber limit reached ({self._MAX_SUBSCRIBERS})")
queue: asyncio.Queue[Event] = asyncio.Queue(maxsize=self._MAX_QUEUE_SIZE)
# 回放缓冲事件(同步操作,无 await保证原子性
for event in list(self._buffer):
if task_id is not None and event.task_id != task_id:
continue
try:
queue.put_nowait(event)
except asyncio.QueueFull:
@ -199,7 +234,8 @@ class EventQueue:
break
# 加入订阅者列表(在回放之后,确保不会收到重复事件)
self._subscribers.append(queue)
sub = _Subscriber(queue=queue, task_id_filter=task_id)
self._subscribers.append(sub)
try:
while True:
@ -208,9 +244,9 @@ class EventQueue:
break
yield event
finally:
# 清理:移除当前订阅者的队列
if queue in self._subscribers:
self._subscribers.remove(queue)
# 清理:移除当前订阅者
if sub in self._subscribers:
self._subscribers.remove(sub)
@property
def subscriber_count(self) -> int:
@ -235,9 +271,9 @@ class EventQueue:
"""
self._closed = True
# 向所有活跃订阅者队列放入哨兵,使其能够优雅退出
for queue in self._subscribers:
for sub in self._subscribers:
try:
queue.put_nowait(_CLOSED_SENTINEL)
sub.queue.put_nowait(_CLOSED_SENTINEL)
except asyncio.QueueFull:
pass
self._subscribers.clear()

View File

@ -3,6 +3,8 @@ import type {
IChatResponse,
ICapabilitiesResponse,
IConversation,
ITaskRecord,
TaskStatus,
} from './types'
import { BaseApiClient } from './base'
@ -36,6 +38,19 @@ class ApiClient extends BaseApiClient {
return this.request<IConversation>(`/conversations/${id}`)
}
/** Get a task by ID (uses /api/v1/tasks prefix) */
async getTask(taskId: string): Promise<ITaskRecord> {
return this.request<ITaskRecord>(`/api/v1/tasks/${taskId}`)
}
/** List tasks, optionally filtered by status (uses /api/v1/tasks prefix) */
async listTasks(status?: TaskStatus, limit: number = 100): Promise<ITaskRecord[]> {
const params = new URLSearchParams()
if (status) params.set('status', status)
params.set('limit', String(limit))
return this.request<ITaskRecord[]>(`/api/v1/tasks?${params.toString()}`)
}
/** Create a WebSocket connection for real-time chat */
createWebSocket(): WebSocket {
return super.createWebSocket('/ws')

View File

@ -78,6 +78,15 @@ export type WsClientMessage = {
sources?: string[]
conversation_id?: string
model?: string
} | {
type: 'resume'
task_id: string
conversation_id?: string
} | {
type: 'cancel'
task_id?: string
} | {
type: 'ping'
}
/** WebSocket server message types — matches backend portal.py protocol */
@ -132,3 +141,23 @@ export interface IApiError {
message: string
detail?: string
}
/** Task status (matches backend TaskStatus enum) */
export type TaskStatus = 'pending' | 'running' | 'completed' | 'partially_completed' | 'failed' | 'cancelled'
/** Task record (matches backend TaskRecord.to_dict()) */
export interface ITaskRecord {
task_id: string
agent_name: string
skill_name: string | null
input_data: Record<string, unknown>
status: TaskStatus
output_data: Record<string, unknown> | null
error_message: string | null
created_at: string
started_at: string | null
completed_at: string | null
progress: number
progress_message: string
metadata: Record<string, unknown>
}

View File

@ -52,10 +52,37 @@ export const useChatStore = defineStore('chat', () => {
}
}
/** Select a conversation by ID */
function selectConversation(id: string): void {
/** Select a conversation by ID and load its messages */
async function selectConversation(id: string, force = false): Promise<void> {
currentConversationId.value = id
streamingSteps.value = []
// Load full conversation with messages if not already loaded (or when forced)
const conv = conversations.value.find((c) => c.id === id)
if (force || !conv || !conv.messages || conv.messages.length === 0) {
try {
const fullConv = await apiClient.getConversation(id)
if (conv) {
conv.messages = fullConv.messages || []
conv.title = fullConv.title || conv.title
conv.created_at = fullConv.created_at || conv.created_at
conv.updated_at = fullConv.updated_at || conv.updated_at
} else {
// P1 #7 fix: If the conversation is not in the local list (e.g.
// after a page refresh), add it from the fetched data instead
// of silently discarding the result.
conversations.value.unshift({
id: fullConv.id || id,
title: fullConv.title || '新对话',
messages: fullConv.messages || [],
created_at: fullConv.created_at || new Date().toISOString(),
updated_at: fullConv.updated_at || new Date().toISOString(),
})
}
} catch (error) {
console.error('Failed to load conversation messages:', error)
}
}
}
/** Create a new empty conversation */
@ -135,7 +162,7 @@ export const useChatStore = defineStore('chat', () => {
}
/** Send a message via WebSocket for streaming */
function sendWsMessage(message: string, sources?: string[], model?: string): void {
async function sendWsMessage(message: string, sources?: string[], model?: string): Promise<void> {
if (!currentConversationId.value) {
createConversation()
}
@ -143,7 +170,7 @@ export const useChatStore = defineStore('chat', () => {
// Check WebSocket state BEFORE creating messages to avoid duplicates
if (!ws.value || ws.value.readyState !== WebSocket.OPEN) {
// Fallback to REST directly — sendMessage will create its own messages
sendMessage(message, sources)
await sendMessage(message, sources)
return
}
@ -179,7 +206,22 @@ export const useChatStore = defineStore('chat', () => {
model,
}
ws.value.send(JSON.stringify(wsMessage))
// Problem 7: catch send() exceptions (e.g. connection closed mid-send)
try {
ws.value.send(JSON.stringify(wsMessage))
} catch (error) {
console.error('WebSocket send failed, falling back to REST:', error)
// Remove the placeholder messages we just added; sendMessage will re-add them
const conv = conversations.value.find((c) => c.id === conversationId)
if (conv) {
conv.messages = conv.messages.filter(
(m) => m.id !== userMessage.id && m.id !== assistantMessage.id,
)
}
isLoading.value = false
await sendMessage(message, sources)
return
}
// Update conversation title from first user message
const conv = conversations.value.find((c) => c.id === conversationId)
@ -190,12 +232,16 @@ export const useChatStore = defineStore('chat', () => {
/** Connect to WebSocket for real-time streaming */
let _heartbeatTimer: ReturnType<typeof setInterval> | null = null
let _reconnectTimer: ReturnType<typeof setTimeout> | null = null
let _intentionalDisconnect = false
function connectWebSocket(): void {
if (ws.value && ws.value.readyState === WebSocket.OPEN) {
// Problem 6: also skip if already CONNECTING to avoid orphan sockets
if (ws.value && (ws.value.readyState === WebSocket.OPEN || ws.value.readyState === WebSocket.CONNECTING)) {
return
}
_intentionalDisconnect = false
const socket = apiClient.createWebSocket()
socket.onopen = () => {
@ -208,6 +254,8 @@ export const useChatStore = defineStore('chat', () => {
ws.value.send(JSON.stringify({ type: 'ping' }))
}
}, 30000)
// Check for running tasks to resume after reconnection
_recoverTaskAfterReconnect()
}
socket.onmessage = (event: MessageEvent) => {
@ -222,13 +270,22 @@ export const useChatStore = defineStore('chat', () => {
socket.onclose = () => {
isWsConnected.value = false
// P2 #21 fix: reset isLoading to prevent stuck loading state during
// disconnect. _recoverTaskAfterReconnect will re-set it if an active
// task is found after reconnection.
isLoading.value = false
console.log('WebSocket disconnected')
if (_heartbeatTimer) {
clearInterval(_heartbeatTimer)
_heartbeatTimer = null
}
// Problem 1: do not auto-reconnect after an intentional disconnect
if (_intentionalDisconnect) {
return
}
// Auto reconnect after 3 seconds
setTimeout(() => {
if (_reconnectTimer) clearTimeout(_reconnectTimer)
_reconnectTimer = setTimeout(() => {
if (!ws.value || ws.value.readyState === WebSocket.CLOSED) {
connectWebSocket()
}
@ -245,6 +302,11 @@ export const useChatStore = defineStore('chat', () => {
/** Disconnect WebSocket */
function disconnectWebSocket(): void {
_intentionalDisconnect = true
if (_reconnectTimer) {
clearTimeout(_reconnectTimer)
_reconnectTimer = null
}
if (_heartbeatTimer) {
clearInterval(_heartbeatTimer)
_heartbeatTimer = null
@ -256,6 +318,64 @@ export const useChatStore = defineStore('chat', () => {
}
}
/** After WebSocket reconnects, check for running tasks and resume them */
async function _recoverTaskAfterReconnect(): Promise<void> {
if (!currentConversationId.value) return
try {
// Problem 2: query both 'running' and 'pending' tasks — a task may be
// in PENDING state if the background worker hasn't picked it up yet.
const [runningTasks, pendingTasks] = await Promise.all([
apiClient.listTasks('running'),
apiClient.listTasks('pending'),
])
const candidates = [...runningTasks, ...pendingTasks]
const activeTask = candidates.find(
(t) => t.metadata?.conversation_id === currentConversationId.value,
)
if (activeTask && ws.value && ws.value.readyState === WebSocket.OPEN) {
// P1 #12 fix: Clear the last pending assistant message's accumulated
// content before resuming. The EventQueue will replay buffered events
// from the beginning of the task, so any partial content already
// accumulated before disconnect would be duplicated if not cleared.
const conv = conversations.value.find((c) => c.id === currentConversationId.value)
if (conv) {
const lastPendingAssistant = [...conv.messages]
.reverse()
.find((m) => m.role === 'assistant' && m.status === 'pending')
if (lastPendingAssistant) {
updateMessage(currentConversationId.value, lastPendingAssistant.id, {
content: '',
thinking: '',
tool_calls: [],
})
}
}
// Problem 5: only set isLoading if we can actually send the resume
isLoading.value = true
try {
ws.value.send(
JSON.stringify({
type: 'resume',
task_id: activeTask.task_id,
conversation_id: currentConversationId.value,
}),
)
} catch (error) {
console.error('Failed to send resume message:', error)
isLoading.value = false
}
} else {
// No active task — force reload conversation messages (may include
// completed results persisted while disconnected).
await selectConversation(currentConversationId.value, true)
}
} catch (error) {
console.error('Failed to recover task after reconnect:', error)
}
}
// --- Internal helpers ---
/** Get team store lazily — safe to call inside actions after Pinia is installed */

View File

@ -21,7 +21,7 @@ from pydantic import BaseModel
from agentkit.core.config_driven import ConfigDrivenAgent
from agentkit.core.event_queue import EventQueue
from agentkit.core.protocol import Event, TaskEventType, TurnEventType
from agentkit.core.protocol import Event, TaskEventType, TaskStatus, TurnEventType
from agentkit.core.react import ReActEngine
from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult
from agentkit.chat.request_preprocessor import RequestPreprocessor
@ -32,20 +32,15 @@ from agentkit.server.routes.evolution_dashboard import (
)
from agentkit.core.fallback import EMPTY_LLM_RESPONSE
from agentkit.chat.sqlite_conversation_store import SqliteConversationStore
from agentkit.server.task_store import InMemoryTaskStore
logger = logging.getLogger(__name__)
router = APIRouter(tags=["portal"])
# Map ReAct engine event_type strings to TurnEventType constants for EQ emission.
# Only events with a corresponding TurnEventType are forwarded to the EQ;
# other events (e.g. "token") are still sent over WebSocket but not duplicated to EQ.
_REACT_EVENT_TYPE_MAP: dict[str, str] = {
"thinking": TurnEventType.THINKING,
"tool_call": TurnEventType.TOOL_CALL,
"tool_result": TurnEventType.TOOL_RESULT,
"final_answer": TurnEventType.FINAL_ANSWER,
}
# Track background ReAct tasks so they are not garbage-collected mid-execution.
# Tasks are removed automatically via add_done_callback when they complete.
_running_background_tasks: set[asyncio.Task] = set()
# ---------------------------------------------------------------------------
# API Key Authentication
@ -95,6 +90,37 @@ async def _emit_event_safe(
logger.warning(f"EventQueue emit failed (type={event_type}): {e}", exc_info=True)
# P1 #14 fix: TaskStore sync/async compatibility shim.
# InMemoryTaskStore methods are sync; RedisTaskStore methods are async.
# These helpers detect and await coroutines so portal.py works with both.
async def _task_store_create(store, *args, **kwargs):
result = store.create(*args, **kwargs)
if asyncio.iscoroutine(result):
return await result
return result
async def _task_store_get(store, *args, **kwargs):
result = store.get(*args, **kwargs)
if asyncio.iscoroutine(result):
return await result
return result
async def _task_store_update_status(store, *args, **kwargs):
result = store.update_status(*args, **kwargs)
if asyncio.iscoroutine(result):
return await result
return result
async def _task_store_list_tasks(store, *args, **kwargs):
result = store.list_tasks(*args, **kwargs)
if asyncio.iscoroutine(result):
return await result
return result
async def _verify_api_key(
request: Request,
api_key_header: str | None = Security(_api_key_header),
@ -144,6 +170,19 @@ class Conversation:
_WS_HEARTBEAT_TIMEOUT = float(os.environ.get("AGENTKIT_WS_TIMEOUT", "120"))
_conversation_store = SqliteConversationStore()
# P1 #9 fix: ReAct event type -> TurnEventType mapping for EQ subscribers.
# Preserves the original EQ contract so CLI and other subscribers that
# filter on TurnEventType constants (e.g. 'turn.thinking') keep working.
_REACT_EVENT_TYPE_MAP: dict[str, str] = {
"thinking": TurnEventType.THINKING,
"tool_call": TurnEventType.TOOL_CALL,
"tool_result": TurnEventType.TOOL_RESULT,
"token": TurnEventType.TOKEN,
"final_answer": TurnEventType.FINAL_ANSWER,
"error": TurnEventType.TURN_COMPLETED, # best-effort mapping
"confirmation_request": TurnEventType.STEP,
}
# ---------------------------------------------------------------------------
# History injection helper — configurable limit + optional compression
# ---------------------------------------------------------------------------
@ -666,6 +705,171 @@ def _derive_title_from_messages(messages: list) -> str:
return "对话"
async def _execute_react_background(
react_engine: ReActEngine,
messages: list[dict],
tools: list,
model: str,
agent_name: str,
system_prompt: str | None,
timeout_seconds: float | None,
conv_id: str,
task_id: str,
event_queue: EventQueue,
conversation_store: SqliteConversationStore,
task_store: InMemoryTaskStore | None = None,
) -> None:
"""Execute ReAct engine in the background, decoupled from WebSocket lifecycle.
Events are emitted to the EventQueue (filtered by task_id) so that any
subscriber including a reconnected WebSocket can consume them.
Results are always persisted to the conversation store, regardless of
whether a WebSocket subscriber is active.
Task status is tracked in TaskStore when provided.
"""
collected_output: list[str] = []
try:
if task_store is not None:
try:
await _task_store_update_status(
task_store, task_id, TaskStatus.RUNNING, started_at=datetime.now(timezone.utc)
)
except Exception:
logger.warning("Failed to update TaskStore RUNNING", exc_info=True)
async for event in react_engine.execute_stream(
messages=messages,
tools=tools,
model=model,
agent_name=agent_name,
system_prompt=system_prompt,
timeout_seconds=timeout_seconds,
):
if event.event_type == "final_answer":
collected_output.append(event.data.get("output", ""))
# P1 #8/#9/#10 fix: Preserve TurnEventType mapping, step field,
# and original data structure for EQ subscriber compatibility.
# Note: Event dataclass has no 'step' field; use getattr for
# compatibility with ReActEngine events that may include it.
_turn_event_type = _REACT_EVENT_TYPE_MAP.get(event.event_type)
if _turn_event_type is not None:
await _emit_event_safe(
event_queue,
_turn_event_type,
task_id=task_id,
session_id=conv_id,
data={
**event.data,
"step": getattr(event, "step", 0),
"timestamp": event.timestamp,
},
)
# Normal completion: persist result
response_text = _ensure_non_empty("".join(collected_output) if collected_output else None)
await conversation_store.add_message(conv_id, "assistant", response_text)
if task_store is not None:
try:
await _task_store_update_status(
task_store,
task_id,
TaskStatus.COMPLETED,
output_data={"output": response_text},
completed_at=datetime.now(timezone.utc),
progress=1.0,
progress_message="Completed",
)
except Exception:
logger.warning("Failed to update TaskStore COMPLETED", exc_info=True)
# Emit task.completed so subscribers know the task is done
await _emit_event_safe(
event_queue,
TaskEventType.TASK_COMPLETED,
task_id=task_id,
session_id=conv_id,
data={"output": response_text, "timestamp": datetime.now(timezone.utc).isoformat()},
)
except asyncio.CancelledError:
# Application shutdown or explicit cancel — persist partial output
# and mark task as FAILED so resume does not block forever.
# P0 #1/#2 fix: ALL persistence operations must use asyncio.shield
# and the async TaskStore shim. Without shield, a re-entrant
# cancellation kills the cleanup itself; without the shim,
# RedisTaskStore (async) silently drops the coroutine.
if collected_output:
partial = _ensure_non_empty("".join(collected_output))
try:
await asyncio.shield(conversation_store.add_message(conv_id, "assistant", partial))
except (Exception, asyncio.CancelledError):
logger.warning("Failed to persist partial output on cancel")
if task_store is not None:
try:
await asyncio.shield(
_task_store_update_status(
task_store,
task_id,
TaskStatus.FAILED,
error_message="Task cancelled",
completed_at=datetime.now(timezone.utc),
)
)
except (Exception, asyncio.CancelledError):
logger.warning("Failed to update TaskStore on cancel", exc_info=True)
# P0 #2 fix: _emit_event_safe is async (it awaits event_queue.emit).
# Shield it so a re-entrant CancelledError doesn't kill the emit
# and leave subscribers blocked until timeout.
try:
await asyncio.shield(
_emit_event_safe(
event_queue,
TaskEventType.TASK_FAILED,
task_id=task_id,
session_id=conv_id,
data={
"error": "Task cancelled",
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
)
except (Exception, asyncio.CancelledError):
logger.warning("Failed to emit TASK_FAILED on cancel")
raise # Propagate cancellation
except Exception as e:
# Persist any partial output collected before the error
if collected_output:
partial = _ensure_non_empty("".join(collected_output))
try:
await conversation_store.add_message(conv_id, "assistant", partial)
except Exception:
logger.warning("Failed to persist partial output in background task")
if task_store is not None:
try:
await _task_store_update_status(
task_store,
task_id,
TaskStatus.FAILED,
error_message=str(e),
completed_at=datetime.now(timezone.utc),
)
except Exception:
logger.warning("Failed to update TaskStore FAILED", exc_info=True)
# Emit task.failed so subscribers know the task failed
await _emit_event_safe(
event_queue,
TaskEventType.TASK_FAILED,
task_id=task_id,
session_id=conv_id,
data={"error": str(e), "timestamp": datetime.now(timezone.utc).isoformat()},
)
@router.websocket("/portal/ws")
async def portal_websocket(websocket: WebSocket):
"""Real-time chat WebSocket endpoint."""
@ -692,6 +896,8 @@ async def portal_websocket(websocket: WebSocket):
conv: Conversation | None = None
# task_id is per-user-message; tracked here so the outer except can emit task.failed
task_id: str | None = None
# Track the active background task so cancel can propagate to it.
active_bg_task: asyncio.Task | None = None
try:
while True:
@ -710,6 +916,10 @@ async def portal_websocket(websocket: WebSocket):
msg_type = msg.get("type")
if msg_type == "cancel":
# Cancel the active background task if still running
if active_bg_task is not None and not active_bg_task.done():
active_bg_task.cancel()
active_bg_task = None
await websocket.send_json(
{
"type": "result",
@ -725,6 +935,203 @@ async def portal_websocket(websocket: WebSocket):
await websocket.send_json({"type": "pong"})
continue
if msg_type == "resume":
# Frontend reconnected and wants to resume a running task
resume_task_id = msg.get("task_id", "")
if not resume_task_id:
continue
# P1 #3/#4 fix: Fail-closed ownership verification.
# Require conversation_id and TaskStore — reject if either
# is missing, to prevent cross-conversation task hijacking
# via empty conversation_id or unconfigured TaskStore.
resume_conv_id = msg.get("conversation_id", "")
if not resume_conv_id:
await websocket.send_json(
{
"type": "error",
"data": {
"message": "Resume requires conversation_id.",
"task_id": resume_task_id,
},
}
)
continue
resume_task_store: InMemoryTaskStore | None = getattr(
websocket.app.state, "task_store", None
)
resume_eq: EventQueue | None = getattr(websocket.app.state, "event_queue", None)
# P1 #4: Fail-closed if TaskStore is unavailable — cannot
# verify ownership without it.
if resume_task_store is None:
await websocket.send_json(
{
"type": "error",
"data": {
"message": "Resume not supported (TaskStore unavailable). Please retry your request.",
"task_id": resume_task_id,
},
}
)
continue
try:
record = await _task_store_get(resume_task_store, resume_task_id)
except Exception:
logger.warning("TaskStore.get failed during resume", exc_info=True)
record = None
if record is not None:
# P1 #3: Fail-closed ownership check — reject if
# conversation_id is missing from task metadata OR
# does not match the request.
task_conv_id = (record.metadata or {}).get("conversation_id", "")
if not task_conv_id or resume_conv_id != task_conv_id:
logger.warning(
"Resume rejected: conversation_id mismatch "
"(task=%s, request=%s, task_id=%s)",
task_conv_id,
resume_conv_id,
resume_task_id,
)
await websocket.send_json(
{
"type": "error",
"data": {
"message": "Task does not belong to this conversation.",
"task_id": resume_task_id,
},
}
)
continue
if record.status == TaskStatus.COMPLETED:
# Task already finished — send result immediately
output = (record.output_data or {}).get("output", "")
await websocket.send_json(
{
"type": "result",
"data": {
"message": output,
"timestamp": record.completed_at.isoformat()
if record.completed_at
else datetime.now(timezone.utc).isoformat(),
},
}
)
continue
elif record.status == TaskStatus.FAILED:
await websocket.send_json(
{
"type": "error",
"data": {
"message": record.error_message or "Task failed",
},
}
)
continue
else:
# Task not found in store — cannot resume
await websocket.send_json(
{
"type": "error",
"data": {
"message": "Task not found or has expired. Please retry your request.",
"task_id": resume_task_id,
},
}
)
continue
# Task is still running — subscribe to EventQueue for remaining events.
# H6: if EventQueue is unavailable, inform the client instead of
# silently continuing (which would leave the UI loading forever).
if resume_eq is None:
await websocket.send_json(
{
"type": "error",
"data": {
"message": "Resume not supported (EventQueue unavailable). Please retry your request.",
},
}
)
continue
# C2: bound the subscribe loop with a timeout so a dead
# background task cannot block resume forever.
resume_timeout = _WS_HEARTBEAT_TIMEOUT * 10 if _WS_HEARTBEAT_TIMEOUT > 0 else 600
try:
async with asyncio.timeout(resume_timeout):
async for event in resume_eq.subscribe(task_id=resume_task_id):
if event.event_type == TaskEventType.TASK_COMPLETED:
response_text = event.data.get("output", EMPTY_LLM_RESPONSE)
await websocket.send_json(
{
"type": "result",
"data": {
"message": response_text,
"timestamp": event.data.get(
"timestamp",
datetime.now(timezone.utc).isoformat(),
),
},
}
)
break
elif event.event_type == TaskEventType.TASK_FAILED:
await websocket.send_json(
{
"type": "error",
"data": {
"message": event.data.get("error", "Unknown error"),
},
}
)
break
else:
# P1 #8/#10 fix: step and data are now
# top-level fields in event.data.
await websocket.send_json(
{
"type": "step",
"data": {
"event_type": event.event_type,
"step": event.data.get("step", 0),
"data": {
k: v
for k, v in event.data.items()
if k not in ("step", "timestamp")
},
"timestamp": event.data.get("timestamp", ""),
},
}
)
except TimeoutError:
logger.warning(f"Resume subscribe timed out for task {resume_task_id}")
await websocket.send_json(
{
"type": "error",
"data": {
"message": "Task resume timed out. Please retry your request.",
"task_id": resume_task_id,
},
}
)
except RuntimeError as exc:
# P1 #5: subscriber limit reached or EQ closed — send
# a friendly error instead of terminating the connection.
logger.warning("Resume subscribe failed for task %s: %s", resume_task_id, exc)
await websocket.send_json(
{
"type": "error",
"data": {
"message": "Server busy, please retry shortly.",
"task_id": resume_task_id,
},
}
)
continue
if msg_type != "chat":
continue
@ -744,6 +1151,7 @@ async def portal_websocket(websocket: WebSocket):
# (EQ is a side-channel: emit failures never break the WebSocket flow)
task_id = str(uuid.uuid4())
event_queue: EventQueue | None = getattr(websocket.app.state, "event_queue", None)
task_store: InMemoryTaskStore | None = getattr(websocket.app.state, "task_store", None)
await _emit_event_safe(
event_queue,
TaskEventType.TASK_CREATED,
@ -844,6 +1252,26 @@ async def portal_websocket(websocket: WebSocket):
},
)
# Register task in TaskStore for status tracking and recovery
if task_store is not None:
try:
await _task_store_create(
task_store,
task_id=task_id,
agent_name=routing_result.agent_name or "default",
input_data={"message": message_text},
skill_name=routing_result.skill_name,
)
# Store conversation_id in metadata for frontend recovery
await _task_store_update_status(
task_store,
task_id,
TaskStatus.PENDING,
metadata={"conversation_id": conv.id},
)
except Exception:
logger.warning("Failed to register task in TaskStore", exc_info=True)
# Execute based on routing result's execution_mode
# This is the single source of truth for path selection,
# replacing fragile string-matching on match_method.
@ -870,6 +1298,21 @@ async def portal_websocket(websocket: WebSocket):
response_content = _ensure_non_empty(response.content)
await _conversation_store.add_message(conv.id, "assistant", response_content)
# Update TaskStore status to COMPLETED
if task_store is not None:
try:
await _task_store_update_status(
task_store,
task_id,
TaskStatus.COMPLETED,
output_data={"output": response_content},
completed_at=datetime.now(timezone.utc),
progress=1.0,
progress_message="Completed",
)
except Exception:
logger.warning("Failed to update TaskStore for DIRECT_CHAT", exc_info=True)
# Emit turn.final_answer and task.completed to EQ
await _emit_event_safe(
event_queue,
@ -890,8 +1333,7 @@ async def portal_websocket(websocket: WebSocket):
{
"type": "result",
"data": {
"status": "completed",
"content": response_content,
"message": response_content,
"timestamp": datetime.now(timezone.utc).isoformat(),
},
}
@ -1010,99 +1452,154 @@ async def portal_websocket(websocket: WebSocket):
f"[portal] agent='{agent_name}' tools={len(tools)} "
f"[{', '.join(t.name for t in tools)}] model={model}"
)
collected_output: list[str] = []
try:
async for event in react_engine.execute_stream(
# Start ReAct execution as a background task, decoupled from
# WebSocket lifecycle. When the WebSocket disconnects, the
# background task continues running and persists the result.
bg_task = asyncio.create_task(
_execute_react_background(
react_engine=react_engine,
messages=messages,
tools=tools,
model=model,
agent_name=agent.name,
system_prompt=system_prompt,
timeout_seconds=timeout_seconds,
):
if event.event_type == "final_answer":
collected_output.append(event.data.get("output", ""))
# Map ReAct event types to TurnEventType and emit to EQ
# (side-channel: failures are swallowed by _emit_event_safe)
_turn_event_type = _REACT_EVENT_TYPE_MAP.get(event.event_type)
if _turn_event_type is not None:
await _emit_event_safe(
event_queue,
_turn_event_type,
task_id=task_id,
session_id=conv.id,
data=event.data,
)
await websocket.send_json(
{
"type": "step",
"data": {
"event_type": event.event_type,
"step": event.step,
"data": event.data,
"timestamp": event.timestamp,
},
}
)
except Exception as e:
# Emit task.failed to EQ before sending error to WebSocket
await _emit_event_safe(
event_queue,
TaskEventType.TASK_FAILED,
conv_id=conv.id,
task_id=task_id,
session_id=conv.id,
data={"error": str(e)},
event_queue=event_queue,
conversation_store=_conversation_store,
task_store=task_store,
)
await websocket.send_json({"type": "error", "data": {"message": str(e)}})
)
_running_background_tasks.add(bg_task)
bg_task.add_done_callback(_running_background_tasks.discard)
active_bg_task = bg_task
# C1 guard: EventQueue is required for subscribe; fall back to
# awaiting the background task directly if unavailable.
if event_queue is None:
logger.warning("EventQueue not configured; awaiting background task directly")
try:
await bg_task
except Exception:
pass # errors handled inside _execute_react_background
active_bg_task = None
continue
response_text = _ensure_non_empty(
"".join(collected_output) if collected_output else None
)
await _conversation_store.add_message(conv.id, "assistant", response_text)
outcome = "success" if response_text != EMPTY_LLM_RESPONSE else "failure"
# Emit task.completed (success) or task.failed (empty response) to EQ
if outcome == "success":
await _emit_event_safe(
event_queue,
TaskEventType.TASK_COMPLETED,
task_id=task_id,
session_id=conv.id,
data={"output": response_text},
# Subscribe to EventQueue (filtered by task_id) and forward
# events to the WebSocket. When the WebSocket disconnects,
# this loop exits but the background task continues.
# P1 #7 fix: bound the subscribe loop with a timeout so a
# hung background task cannot block the WebSocket forever.
# Matches the resume path's timeout strategy.
_subscribe_timeout = _WS_HEARTBEAT_TIMEOUT * 10 if _WS_HEARTBEAT_TIMEOUT > 0 else 600
try:
async with asyncio.timeout(_subscribe_timeout):
async for event in event_queue.subscribe(task_id=task_id):
if event.event_type == TaskEventType.TASK_COMPLETED:
response_text = event.data.get("output", EMPTY_LLM_RESPONSE)
await websocket.send_json(
{
"type": "result",
"data": {
"message": response_text,
"timestamp": event.data.get(
"timestamp",
datetime.now(timezone.utc).isoformat(),
),
},
}
)
await _record_experience(
routing_result.skill_name or "agent",
message_text,
"success" if response_text != EMPTY_LLM_RESPONSE else "failure",
(datetime.now(timezone.utc) - start_time).total_seconds(),
)
break
elif event.event_type == TaskEventType.TASK_FAILED:
await websocket.send_json(
{
"type": "error",
"data": {
"message": event.data.get("error", "Unknown error"),
},
}
)
await _record_experience(
routing_result.skill_name or "agent",
message_text,
"failure",
(datetime.now(timezone.utc) - start_time).total_seconds(),
)
break
else:
# Forward ReAct events as step messages.
# P1 #8/#10 fix: step and data are now top-level
# fields in event.data (no longer nested).
await websocket.send_json(
{
"type": "step",
"data": {
"event_type": event.event_type,
"step": event.data.get("step", 0),
"data": {
k: v
for k, v in event.data.items()
if k not in ("step", "timestamp")
},
"timestamp": event.data.get("timestamp", ""),
},
}
)
except TimeoutError:
logger.warning(f"Subscribe loop timed out for task {task_id}")
if active_bg_task is not None and not active_bg_task.done():
active_bg_task.cancel()
await websocket.send_json(
{
"type": "error",
"data": {
"message": "Task timed out. Please retry your request.",
"task_id": task_id,
},
}
)
else:
await _emit_event_safe(
event_queue,
TaskEventType.TASK_FAILED,
task_id=task_id,
session_id=conv.id,
data={"error": "Empty LLM response"},
except RuntimeError as exc:
# P1 #5: subscriber limit reached or EQ closed — send
# a friendly error instead of terminating the connection.
logger.warning("Subscribe failed for task %s: %s", task_id, exc)
await websocket.send_json(
{
"type": "error",
"data": {
"message": "Server busy, please retry shortly.",
"task_id": task_id,
},
}
)
await websocket.send_json(
{
"type": "result",
"data": {
"message": response_text,
"timestamp": datetime.now(timezone.utc).isoformat(),
},
}
)
await _record_experience(
routing_result.skill_name or "agent",
message_text,
outcome,
(datetime.now(timezone.utc) - start_time).total_seconds(),
)
except WebSocketDisconnect:
logger.debug(f"Portal WebSocket disconnected for conversation {conv.id if conv else 'N/A'}")
# P0 fix: Do NOT cancel the background task on disconnect.
# The entire purpose of the three-layer defense is to let the
# background task continue running and persist the result so the
# frontend can resume it after reconnection. Cancelling here would
# kill the task, lose the full output, and mark it FAILED —
# defeating layers 2 and 3. The task is only cancelled on explicit
# user cancel (msg_type == 'cancel') or application shutdown.
except Exception as e:
logger.error(f"Portal WebSocket error: {e}")
# P1 #6 fix: Do NOT cancel the background task on connection-level
# errors (ConnectionResetError, BrokenPipeError, etc.). These are
# functionally equivalent to WebSocketDisconnect — the client dropped
# — and the background task must survive to persist its result.
# Only cancel on truly unexpected errors that may have corrupted
# state needed by the background task.
if not isinstance(e, (ConnectionResetError, BrokenPipeError, ConnectionError)):
if active_bg_task is not None and not active_bg_task.done():
active_bg_task.cancel()
# Emit task.failed to EQ if a task was in progress
# (task_id is set when a user message is received; None before that)
if task_id is not None and conv is not None:

View File

@ -0,0 +1,938 @@
"""Tests for WebSocket task persistence and background execution (U1-U3).
Tests cover:
- U1: Partial output persistence on WebSocket disconnect
- U2: Background ReAct task execution with EventQueue event distribution
- U3: TaskStore registration and status tracking
- P0 #2: CancelledError path — partial output persisted, task marked FAILED
- P0 #3: Resume handler — conversation_id ownership verification
- P0 #4: Cancel propagation — explicit cancel marks task FAILED
- P0 #5: WebSocketDisconnect does NOT cancel background task
"""
from __future__ import annotations
import asyncio
from agentkit.core.event_queue import EventQueue
from agentkit.core.protocol import Event, TaskEventType, TaskStatus, TurnEventType
from agentkit.server.routes.portal import _execute_react_background
from agentkit.server.task_store import InMemoryTaskStore
# ---------------------------------------------------------------------------
# Fakes
# ---------------------------------------------------------------------------
class FakeConversationStore:
"""Minimal conversation store for testing."""
def __init__(self) -> None:
self.messages: list[tuple[str, str, str]] = []
async def add_message(self, conv_id: str, role: str, content: str) -> None:
self.messages.append((conv_id, role, content))
class FakeReactEngine:
"""Fake ReAct engine that yields events from a predefined list."""
def __init__(self, events: list[Event]) -> None:
self._events = events
async def execute_stream(self, **kwargs):
for event in self._events:
yield event
class FailingReactEngine:
"""Fake ReAct engine that raises an exception after yielding some events."""
def __init__(self, events: list[Event], error: Exception) -> None:
self._events = events
self._error = error
async def execute_stream(self, **kwargs):
for event in self._events:
yield event
raise self._error
def _make_event(
event_type: str,
task_id: str = "test-task",
session_id: str = "test-conv",
data: dict | None = None,
) -> Event:
return Event.create(
event_type=event_type,
task_id=task_id,
session_id=session_id,
data=data or {},
)
class SlowFakeReactEngine:
"""Fake ReAct engine with a delay to allow status checks during execution."""
def __init__(self, events: list[Event], delay: float = 0.1) -> None:
self._events = events
self._delay = delay
async def execute_stream(self, **kwargs):
for event in self._events:
await asyncio.sleep(self._delay)
yield event
class CancellableReactEngine:
"""Fake ReAct engine that blocks forever until cancelled.
Yields one event so collected_output is non-empty, then blocks on an
Event so the test can cancel the task and verify CancelledError cleanup.
"""
def __init__(self, first_event: Event) -> None:
self._first_event = first_event
self.started = asyncio.Event()
async def execute_stream(self, **kwargs):
yield self._first_event
self.started.set()
# Block forever until cancelled
await asyncio.Event().wait()
def _suppress_cancelled():
"""Context manager that suppresses asyncio.CancelledError."""
import contextlib
return contextlib.suppress(asyncio.CancelledError)
# ---------------------------------------------------------------------------
# U1 + U2: Background task persistence tests
# ---------------------------------------------------------------------------
class TestExecuteReactBackground:
"""Tests for _execute_react_background (U1 + U2)."""
async def test_normal_completion_persists_result(self):
"""U2: Normal completion persists result to conversation store."""
events = [
_make_event("thinking", data={"text": "Let me think..."}),
_make_event("final_answer", data={"output": "The answer is 42"}),
]
engine = FakeReactEngine(events)
conv_store = FakeConversationStore()
eq = EventQueue()
await _execute_react_background(
react_engine=engine,
messages=[],
tools=[],
model="test-model",
agent_name="test-agent",
system_prompt=None,
timeout_seconds=None,
conv_id="test-conv",
task_id="test-task",
event_queue=eq,
conversation_store=conv_store,
)
# Result should be persisted
assert len(conv_store.messages) == 1
conv_id, role, content = conv_store.messages[0]
assert conv_id == "test-conv"
assert role == "assistant"
assert content == "The answer is 42"
async def test_partial_output_persisted_on_error(self):
"""U1: Partial output is persisted when ReAct engine raises an error."""
events = [
_make_event("thinking", data={"text": "Thinking..."}),
_make_event("final_answer", data={"output": "Partial result"}),
]
error = RuntimeError("LLM timeout")
engine = FailingReactEngine(events, error)
conv_store = FakeConversationStore()
eq = EventQueue()
await _execute_react_background(
react_engine=engine,
messages=[],
tools=[],
model="test-model",
agent_name="test-agent",
system_prompt=None,
timeout_seconds=None,
conv_id="test-conv",
task_id="test-task",
event_queue=eq,
conversation_store=conv_store,
)
# Partial output should be persisted
assert len(conv_store.messages) == 1
_, role, content = conv_store.messages[0]
assert role == "assistant"
assert content == "Partial result"
async def test_no_output_on_error_without_final_answer(self):
"""U1: No message persisted when error occurs before any final_answer."""
events = [_make_event("thinking", data={"text": "Thinking..."})]
error = RuntimeError("Early failure")
engine = FailingReactEngine(events, error)
conv_store = FakeConversationStore()
eq = EventQueue()
await _execute_react_background(
react_engine=engine,
messages=[],
tools=[],
model="test-model",
agent_name="test-agent",
system_prompt=None,
timeout_seconds=None,
conv_id="test-conv",
task_id="test-task",
event_queue=eq,
conversation_store=conv_store,
)
# No assistant message should be persisted (collected_output is empty)
assert len(conv_store.messages) == 0
async def test_events_emitted_to_event_queue(self):
"""U2: Events are emitted to EventQueue for subscribers."""
events = [
_make_event("thinking", data={"text": "Thinking..."}),
_make_event("final_answer", data={"output": "Done"}),
]
engine = FakeReactEngine(events)
conv_store = FakeConversationStore()
eq = EventQueue()
received: list[Event] = []
async def subscriber():
async for evt in eq.subscribe(task_id="test-task"):
received.append(evt)
if evt.event_type == TaskEventType.TASK_COMPLETED:
break
sub_task = asyncio.create_task(subscriber())
await asyncio.sleep(0.05)
await _execute_react_background(
react_engine=engine,
messages=[],
tools=[],
model="test-model",
agent_name="test-agent",
system_prompt=None,
timeout_seconds=None,
conv_id="test-conv",
task_id="test-task",
event_queue=eq,
conversation_store=conv_store,
)
await asyncio.wait_for(sub_task, timeout=2.0)
# Should receive thinking, final_answer, and task.completed events
# P1 #9: ReAct event types are mapped to TurnEventType constants
event_types = [e.event_type for e in received]
assert TurnEventType.THINKING in event_types
assert TurnEventType.FINAL_ANSWER in event_types
assert TaskEventType.TASK_COMPLETED in event_types
async def test_task_failed_event_on_error(self):
"""U2: task.failed event is emitted on error."""
events: list[Event] = []
error = RuntimeError("Execution failed")
engine = FailingReactEngine(events, error)
conv_store = FakeConversationStore()
eq = EventQueue()
received: list[Event] = []
async def subscriber():
async for evt in eq.subscribe(task_id="test-task"):
received.append(evt)
if evt.event_type == TaskEventType.TASK_FAILED:
break
sub_task = asyncio.create_task(subscriber())
await asyncio.sleep(0.05)
await _execute_react_background(
react_engine=engine,
messages=[],
tools=[],
model="test-model",
agent_name="test-agent",
system_prompt=None,
timeout_seconds=None,
conv_id="test-conv",
task_id="test-task",
event_queue=eq,
conversation_store=conv_store,
)
await asyncio.wait_for(sub_task, timeout=2.0)
failed_events = [e for e in received if e.event_type == TaskEventType.TASK_FAILED]
assert len(failed_events) == 1
assert "Execution failed" in failed_events[0].data.get("error", "")
# ---------------------------------------------------------------------------
# U3: TaskStore integration tests
# ---------------------------------------------------------------------------
class TestTaskStoreIntegration:
"""Tests for TaskStore registration and status tracking (U3)."""
async def test_task_store_status_running_during_execution(self):
"""U3: TaskStore status is RUNNING during background execution."""
events = [
_make_event("thinking", data={"text": "Thinking..."}),
_make_event("final_answer", data={"output": "Result"}),
]
engine = SlowFakeReactEngine(events, delay=0.2)
conv_store = FakeConversationStore()
eq = EventQueue()
task_store = InMemoryTaskStore()
task_store.create(
task_id="test-task",
agent_name="test-agent",
input_data={"message": "hello"},
)
# Start background task
bg_task = asyncio.create_task(
_execute_react_background(
react_engine=engine,
messages=[],
tools=[],
model="test-model",
agent_name="test-agent",
system_prompt=None,
timeout_seconds=None,
conv_id="test-conv",
task_id="test-task",
event_queue=eq,
conversation_store=conv_store,
task_store=task_store,
)
)
# Check status while running (need to yield control)
await asyncio.sleep(0.01)
record = task_store.get("test-task")
assert record is not None
assert record.status == TaskStatus.RUNNING
await asyncio.wait_for(bg_task, timeout=2.0)
# After completion, status should be COMPLETED
record = task_store.get("test-task")
assert record is not None
assert record.status == TaskStatus.COMPLETED
assert record.output_data is not None
assert record.output_data.get("output") == "Result"
assert record.progress == 1.0
async def test_task_store_status_failed_on_error(self):
"""U3: TaskStore status is FAILED when background task raises error."""
events: list[Event] = []
error = RuntimeError("Execution failed")
engine = FailingReactEngine(events, error)
conv_store = FakeConversationStore()
eq = EventQueue()
task_store = InMemoryTaskStore()
task_store.create(
task_id="test-task",
agent_name="test-agent",
input_data={"message": "hello"},
)
await _execute_react_background(
react_engine=engine,
messages=[],
tools=[],
model="test-model",
agent_name="test-agent",
system_prompt=None,
timeout_seconds=None,
conv_id="test-conv",
task_id="test-task",
event_queue=eq,
conversation_store=conv_store,
task_store=task_store,
)
record = task_store.get("test-task")
assert record is not None
assert record.status == TaskStatus.FAILED
assert record.error_message is not None
assert "Execution failed" in record.error_message
async def test_task_store_none_does_not_crash(self):
"""U3: Passing task_store=None should not crash."""
events = [_make_event("final_answer", data={"output": "Result"})]
engine = FakeReactEngine(events)
conv_store = FakeConversationStore()
eq = EventQueue()
# Should not raise
await _execute_react_background(
react_engine=engine,
messages=[],
tools=[],
model="test-model",
agent_name="test-agent",
system_prompt=None,
timeout_seconds=None,
conv_id="test-conv",
task_id="test-task",
event_queue=eq,
conversation_store=conv_store,
task_store=None,
)
assert len(conv_store.messages) == 1
async def test_task_store_list_by_status(self):
"""U3: TaskStore list_tasks filters by status correctly."""
task_store = InMemoryTaskStore()
# Create tasks in different states
task_store.create("task-1", "agent", {})
task_store.create("task-2", "agent", {})
task_store.create("task-3", "agent", {})
task_store.update_status("task-2", TaskStatus.RUNNING)
task_store.update_status("task-3", TaskStatus.COMPLETED, progress=1.0)
running = task_store.list_tasks(status=TaskStatus.RUNNING)
completed = task_store.list_tasks(status=TaskStatus.COMPLETED)
pending = task_store.list_tasks(status=TaskStatus.PENDING)
assert len(running) == 1
assert running[0].task_id == "task-2"
assert len(completed) == 1
assert completed[0].task_id == "task-3"
assert len(pending) == 1
assert pending[0].task_id == "task-1"
async def test_task_store_metadata_contains_conversation_id(self):
"""U3: TaskStore metadata stores conversation_id for frontend recovery."""
task_store = InMemoryTaskStore()
task_store.create("task-1", "agent", {"message": "hello"})
task_store.update_status(
"task-1",
TaskStatus.PENDING,
metadata={"conversation_id": "conv-123"},
)
record = task_store.get("task-1")
assert record is not None
assert record.metadata.get("conversation_id") == "conv-123"
# ---------------------------------------------------------------------------
# EventQueue task_id filtering tests (U2)
# ---------------------------------------------------------------------------
class TestEventQueueTaskIdFilter:
"""Tests for EventQueue subscribe(task_id=...) filtering."""
async def test_subscribe_with_task_id_filter(self):
"""U2: subscribe(task_id=...) only receives matching events."""
eq = EventQueue()
received: list[Event] = []
async def subscriber():
async for evt in eq.subscribe(task_id="task-A"):
received.append(evt)
if len(received) >= 2:
break
sub_task = asyncio.create_task(subscriber())
await asyncio.sleep(0.05)
# Emit events for different tasks
await eq.emit(_make_event("thinking", task_id="task-A"))
await eq.emit(_make_event("thinking", task_id="task-B")) # Should be filtered out
await eq.emit(_make_event("final_answer", task_id="task-A"))
await asyncio.wait_for(sub_task, timeout=2.0)
# Should only receive task-A events
assert len(received) == 2
assert all(e.task_id == "task-A" for e in received)
async def test_subscribe_without_filter_receives_all(self):
"""U2: subscribe() without task_id receives all events (backward compat)."""
eq = EventQueue()
received: list[Event] = []
async def subscriber():
async for evt in eq.subscribe():
received.append(evt)
if len(received) >= 3:
break
sub_task = asyncio.create_task(subscriber())
await asyncio.sleep(0.05)
await eq.emit(_make_event("thinking", task_id="task-A"))
await eq.emit(_make_event("thinking", task_id="task-B"))
await eq.emit(_make_event("final_answer", task_id="task-A"))
await asyncio.wait_for(sub_task, timeout=2.0)
# Should receive all events regardless of task_id
assert len(received) == 3
async def test_subscribe_replays_buffer_filtered(self):
"""U2: Buffer replay respects task_id filter."""
eq = EventQueue()
# Emit events before subscribing
await eq.emit(_make_event("thinking", task_id="task-A"))
await eq.emit(_make_event("thinking", task_id="task-B"))
await eq.emit(_make_event("final_answer", task_id="task-A"))
received: list[Event] = []
async def subscriber():
async for evt in eq.subscribe(task_id="task-A"):
received.append(evt)
if len(received) >= 2:
break
sub_task = asyncio.create_task(subscriber())
await asyncio.wait_for(sub_task, timeout=2.0)
# Should only replay task-A events from buffer
assert len(received) == 2
assert all(e.task_id == "task-A" for e in received)
# ---------------------------------------------------------------------------
# P0 #2: CancelledError path tests
# ---------------------------------------------------------------------------
class TestCancelledErrorPath:
"""P0 #2: Verify CancelledError cleanup persists partial output and
marks the task as FAILED, and TASK_FAILED event is emitted."""
async def test_cancel_persists_partial_output(self):
"""P0 #2: When task is cancelled mid-execution, partial output
collected before cancellation is persisted to conversation store."""
first_event = _make_event("final_answer", data={"output": "Partial before cancel"})
engine = CancellableReactEngine(first_event)
conv_store = FakeConversationStore()
eq = EventQueue()
task_store = InMemoryTaskStore()
task_store.create("test-task", "test-agent", {"message": "hello"})
bg_task = asyncio.create_task(
_execute_react_background(
react_engine=engine,
messages=[],
tools=[],
model="test-model",
agent_name="test-agent",
system_prompt=None,
timeout_seconds=None,
conv_id="test-conv",
task_id="test-task",
event_queue=eq,
conversation_store=conv_store,
task_store=task_store,
)
)
# Wait for the engine to yield its first event
await asyncio.wait_for(engine.started.wait(), timeout=2.0)
bg_task.cancel()
with self._expect_cancelled():
await bg_task
# Partial output should be persisted
assert len(conv_store.messages) == 1
_, role, content = conv_store.messages[0]
assert role == "assistant"
assert content == "Partial before cancel"
async def test_cancel_marks_task_failed_in_store(self):
"""P0 #2: CancelledError marks task status as FAILED in TaskStore."""
first_event = _make_event("final_answer", data={"output": "Partial"})
engine = CancellableReactEngine(first_event)
conv_store = FakeConversationStore()
eq = EventQueue()
task_store = InMemoryTaskStore()
task_store.create("test-task", "test-agent", {"message": "hello"})
bg_task = asyncio.create_task(
_execute_react_background(
react_engine=engine,
messages=[],
tools=[],
model="test-model",
agent_name="test-agent",
system_prompt=None,
timeout_seconds=None,
conv_id="test-conv",
task_id="test-task",
event_queue=eq,
conversation_store=conv_store,
task_store=task_store,
)
)
await asyncio.wait_for(engine.started.wait(), timeout=2.0)
bg_task.cancel()
with self._expect_cancelled():
await bg_task
record = task_store.get("test-task")
assert record is not None
assert record.status == TaskStatus.FAILED
assert record.error_message is not None
assert "cancelled" in record.error_message.lower()
async def test_cancel_emits_task_failed_event(self):
"""P0 #2: CancelledError emits TASK_FAILED event to EventQueue."""
first_event = _make_event("final_answer", data={"output": "Partial"})
engine = CancellableReactEngine(first_event)
conv_store = FakeConversationStore()
eq = EventQueue()
received: list[Event] = []
async def subscriber():
async for evt in eq.subscribe(task_id="test-task"):
received.append(evt)
if evt.event_type == TaskEventType.TASK_FAILED:
break
sub_task = asyncio.create_task(subscriber())
await asyncio.sleep(0.05)
bg_task = asyncio.create_task(
_execute_react_background(
react_engine=engine,
messages=[],
tools=[],
model="test-model",
agent_name="test-agent",
system_prompt=None,
timeout_seconds=None,
conv_id="test-conv",
task_id="test-task",
event_queue=eq,
conversation_store=conv_store,
)
)
await asyncio.wait_for(engine.started.wait(), timeout=2.0)
bg_task.cancel()
with self._expect_cancelled():
await bg_task
await asyncio.wait_for(sub_task, timeout=2.0)
failed_events = [e for e in received if e.event_type == TaskEventType.TASK_FAILED]
assert len(failed_events) == 1
assert "cancelled" in failed_events[0].data.get("error", "").lower()
@staticmethod
def _expect_cancelled():
"""Context manager that expects asyncio.CancelledError to be raised."""
import contextlib
return contextlib.suppress(asyncio.CancelledError)
# ---------------------------------------------------------------------------
# P0 #3: Resume handler conversation_id ownership verification tests
# ---------------------------------------------------------------------------
class TestResumeOwnershipVerification:
"""P0 #3: Verify resume path rejects tasks from a different conversation.
These tests exercise the TaskStore metadata check directly, since the
WebSocket resume handler reads metadata to verify ownership.
"""
async def test_resume_rejects_mismatched_conversation_id(self):
"""P0 #3: Task with conversation_id mismatch should be rejected.
Simulates the metadata check performed in portal.py resume handler:
if record.metadata['conversation_id'] != request conversation_id,
the resume is rejected with an error.
"""
task_store = InMemoryTaskStore()
task_store.create("task-X", "agent", {"message": "hello"})
task_store.update_status(
"task-X",
TaskStatus.RUNNING,
metadata={"conversation_id": "conv-A"},
)
record = task_store.get("task-X")
assert record is not None
# Simulate the ownership check from portal.py resume handler
task_conv_id = (record.metadata or {}).get("conversation_id", "")
request_conv_id = "conv-B" # Different conversation
assert task_conv_id != request_conv_id
# In portal.py, this mismatch triggers an error response and `continue`
async def test_resume_accepts_matching_conversation_id(self):
"""P0 #3: Task with matching conversation_id should be allowed to resume."""
task_store = InMemoryTaskStore()
task_store.create("task-Y", "agent", {"message": "hello"})
task_store.update_status(
"task-Y",
TaskStatus.RUNNING,
metadata={"conversation_id": "conv-A"},
)
record = task_store.get("task-Y")
assert record is not None
task_conv_id = (record.metadata or {}).get("conversation_id", "")
request_conv_id = "conv-A" # Same conversation
assert task_conv_id == request_conv_id
# In portal.py, this passes the ownership check and proceeds to subscribe
async def test_resume_rejects_when_metadata_missing(self):
"""P1 #3: When task metadata has no conversation_id, resume is
rejected (fail-closed). Previously this was allowed for backward
compatibility, but the security review identified this as a
bypass vector an attacker can omit conversation_id to subscribe
to any task's events."""
task_store = InMemoryTaskStore()
task_store.create("task-Z", "agent", {"message": "hello"})
# No metadata update — metadata defaults to {}
record = task_store.get("task-Z")
assert record is not None
task_conv_id = (record.metadata or {}).get("conversation_id", "")
request_conv_id = "conv-A"
# P1 #3 fix: fail-closed — reject if task_conv_id is missing
should_reject = not task_conv_id or task_conv_id != request_conv_id
assert should_reject
# ---------------------------------------------------------------------------
# P0 #4: Cancel propagation tests
# ---------------------------------------------------------------------------
class TestCancelPropagation:
"""P0 #4: Verify explicit cancel (msg_type == 'cancel') propagates
correctly to the background task and marks it FAILED."""
async def test_explicit_cancel_marks_task_failed(self):
"""P0 #4: When a background task is explicitly cancelled (simulating
the msg_type == 'cancel' handler), it should propagate CancelledError
and mark the task FAILED with partial output persisted."""
first_event = _make_event("final_answer", data={"output": "Partial before user cancel"})
engine = CancellableReactEngine(first_event)
conv_store = FakeConversationStore()
eq = EventQueue()
task_store = InMemoryTaskStore()
task_store.create("cancel-task", "agent", {"message": "hello"})
# Simulate the background task as portal.py would create it
active_bg_task: asyncio.Task | None = asyncio.create_task(
_execute_react_background(
react_engine=engine,
messages=[],
tools=[],
model="test-model",
agent_name="test-agent",
system_prompt=None,
timeout_seconds=None,
conv_id="cancel-conv",
task_id="cancel-task",
event_queue=eq,
conversation_store=conv_store,
task_store=task_store,
)
)
await asyncio.wait_for(engine.started.wait(), timeout=2.0)
# Simulate the cancel handler: active_bg_task.cancel()
assert active_bg_task is not None
assert not active_bg_task.done()
active_bg_task.cancel()
with _suppress_cancelled():
await active_bg_task
# Verify task is marked FAILED
record = task_store.get("cancel-task")
assert record is not None
assert record.status == TaskStatus.FAILED
# Verify partial output was persisted
assert len(conv_store.messages) == 1
_, _, content = conv_store.messages[0]
assert content == "Partial before user cancel"
async def test_cancel_after_completion_is_noop(self):
"""P0 #4: Cancelling an already-completed task is a no-op
(active_bg_task.done() check prevents double-cancel)."""
events = [_make_event("final_answer", data={"output": "Done"})]
engine = FakeReactEngine(events)
conv_store = FakeConversationStore()
eq = EventQueue()
bg_task = asyncio.create_task(
_execute_react_background(
react_engine=engine,
messages=[],
tools=[],
model="test-model",
agent_name="test-agent",
system_prompt=None,
timeout_seconds=None,
conv_id="test-conv",
task_id="test-task",
event_queue=eq,
conversation_store=conv_store,
)
)
await bg_task
# Already done — cancel should be a no-op per portal.py guard:
# `if active_bg_task is not None and not active_bg_task.done():`
assert bg_task.done()
# No exception, no state change
assert len(conv_store.messages) == 1
# ---------------------------------------------------------------------------
# P0 #5: WebSocketDisconnect does NOT cancel background task
# ---------------------------------------------------------------------------
class TestWebSocketDisconnectNoCancel:
"""P0 #5: Verify that WebSocketDisconnect does NOT cancel the background
task this is the core invariant of the three-layer defense.
The test simulates the portal.py control flow: a background task is
started, then the WebSocket disconnects (simulated by cancelling the
subscribe loop but NOT the background task). The background task should
continue running and persist its result.
"""
async def test_disconnect_does_not_cancel_background_task(self):
"""P0 #5: After WebSocketDisconnect, the background task continues
running and persists its result to the conversation store."""
events = [
_make_event("thinking", data={"text": "Thinking..."}),
_make_event("final_answer", data={"output": "Final result"}),
]
engine = SlowFakeReactEngine(events, delay=0.2)
conv_store = FakeConversationStore()
eq = EventQueue()
# Start the background task (as portal.py would)
bg_task = asyncio.create_task(
_execute_react_background(
react_engine=engine,
messages=[],
tools=[],
model="test-model",
agent_name="test-agent",
system_prompt=None,
timeout_seconds=None,
conv_id="test-conv",
task_id="test-task",
event_queue=eq,
conversation_store=conv_store,
)
)
# Simulate WebSocketDisconnect: the subscribe loop is interrupted,
# but the background task is NOT cancelled (per P0 #1 fix).
# We just stop subscribing — bg_task keeps running.
await asyncio.sleep(0.1) # Let bg_task start
# Verify bg_task is still running (not cancelled by disconnect)
assert not bg_task.done()
# Wait for bg_task to complete naturally
await asyncio.wait_for(bg_task, timeout=5.0)
# Result should be persisted despite "disconnect"
assert len(conv_store.messages) == 1
_, _, content = conv_store.messages[0]
assert content == "Final result"
async def test_disconnect_result_available_for_resume(self):
"""P0 #5: After disconnect, the completed task's result is available
in TaskStore so a reconnecting client can retrieve it via resume."""
events = [_make_event("final_answer", data={"output": "Resumable result"})]
engine = FakeReactEngine(events)
conv_store = FakeConversationStore()
eq = EventQueue()
task_store = InMemoryTaskStore()
task_store.create("resume-task", "agent", {"message": "hello"})
task_store.update_status(
"resume-task",
TaskStatus.RUNNING,
metadata={"conversation_id": "resume-conv"},
)
bg_task = asyncio.create_task(
_execute_react_background(
react_engine=engine,
messages=[],
tools=[],
model="test-model",
agent_name="test-agent",
system_prompt=None,
timeout_seconds=None,
conv_id="resume-conv",
task_id="resume-task",
event_queue=eq,
conversation_store=conv_store,
task_store=task_store,
)
)
# Simulate disconnect: don't cancel bg_task, just let it run
await asyncio.wait_for(bg_task, timeout=5.0)
# Task should be COMPLETED with output available for resume
record = task_store.get("resume-task")
assert record is not None
assert record.status == TaskStatus.COMPLETED
assert record.output_data is not None
assert record.output_data.get("output") == "Resumable result"