fix: WebSocket task persistence three-layer defense with security hardening
Fix empty chat-history content and tasks stopping on page refresh. Implements three layers: result persistence on disconnect, task backgrounding via asyncio + EventQueue, and frontend reconnection recovery. Security hardening: fail-closed conversation_id ownership checks, asyncio.shield around CancelledError cleanup, an async TaskStore shim, an EventQueue subscriber limit, and connection-error resilience. Adds 23 tests.
This commit is contained in:
parent
840d1afd6a
commit
5b5291c7e5
|
|
@ -34,6 +34,20 @@ _CLOSED_SENTINEL: Event = Event(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class _Subscriber:
    """One EventQueue subscriber: its private queue plus an optional task filter."""

    # Per-subscriber delivery queue.
    queue: asyncio.Queue[Event]
    # When None the subscriber receives every event; otherwise only events
    # whose ``task_id`` equals this value.
    task_id_filter: str | None = None  # None = receive all events

    def matches(self, event: Event) -> bool:
        """Return True when *event* should be delivered to this subscriber."""
        wanted = self.task_id_filter
        return wanted is None or event.task_id == wanted
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Submission:
|
class Submission:
|
||||||
"""用户提交的任务
|
"""用户提交的任务
|
||||||
|
|
@ -151,9 +165,12 @@ class EventQueue:
|
||||||
|
|
||||||
_MAX_QUEUE_SIZE: int = 1024
|
_MAX_QUEUE_SIZE: int = 1024
|
||||||
_DEFAULT_BUFFER_SIZE: int = 100
|
_DEFAULT_BUFFER_SIZE: int = 100
|
||||||
|
# P1 #13 fix: cap total subscribers to prevent resource exhaustion
|
||||||
|
# from malicious resume floods or runaway client loops.
|
||||||
|
_MAX_SUBSCRIBERS: int = 1000
|
||||||
|
|
||||||
def __init__(self, buffer_size: int = _DEFAULT_BUFFER_SIZE) -> None:
|
def __init__(self, buffer_size: int = _DEFAULT_BUFFER_SIZE) -> None:
|
||||||
self._subscribers: list[asyncio.Queue[Event]] = []
|
self._subscribers: list[_Subscriber] = []
|
||||||
self._buffer: deque[Event] = deque(maxlen=buffer_size)
|
self._buffer: deque[Event] = deque(maxlen=buffer_size)
|
||||||
self._buffer_size = buffer_size
|
self._buffer_size = buffer_size
|
||||||
self._closed: bool = False
|
self._closed: bool = False
|
||||||
|
|
@ -163,35 +180,53 @@ class EventQueue:
|
||||||
|
|
||||||
事件会同时写入缓冲区(供未来订阅者回放)和所有活跃订阅者队列。
|
事件会同时写入缓冲区(供未来订阅者回放)和所有活跃订阅者队列。
|
||||||
如果某订阅者队列已满,该事件对该订阅者被丢弃(不影响其他订阅者)。
|
如果某订阅者队列已满,该事件对该订阅者被丢弃(不影响其他订阅者)。
|
||||||
|
支持按 task_id 过滤:只有 task_id 匹配的订阅者才会收到事件。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event: 要推送的事件
|
event: 要推送的事件
|
||||||
"""
|
"""
|
||||||
self._buffer.append(event)
|
self._buffer.append(event)
|
||||||
for queue in self._subscribers:
|
for sub in self._subscribers:
|
||||||
|
if not sub.matches(event):
|
||||||
|
continue
|
||||||
try:
|
try:
|
||||||
queue.put_nowait(event)
|
sub.queue.put_nowait(event)
|
||||||
except asyncio.QueueFull:
|
except asyncio.QueueFull:
|
||||||
logger.warning("EventQueue subscriber queue full, dropping event")
|
logger.warning("EventQueue subscriber queue full, dropping event")
|
||||||
|
|
||||||
async def subscribe(self) -> AsyncIterator[Event]:
|
async def subscribe(self, task_id: str | None = None) -> AsyncIterator[Event]:
|
||||||
"""订阅事件流(异步生成器)
|
"""订阅事件流(异步生成器)
|
||||||
|
|
||||||
订阅时会先回放缓冲区中的事件,然后持续接收新事件。
|
订阅时会先回放缓冲区中的事件(按 task_id 过滤),然后持续接收新事件。
|
||||||
每个订阅者获得独立的队列,实现广播语义。
|
每个订阅者获得独立的队列,实现广播语义。
|
||||||
|
|
||||||
当队列关闭时,生成器结束。
|
当队列关闭时,生成器结束。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: 可选的任务 ID 过滤器。如果提供,只接收该任务的 events。
|
||||||
|
None 表示接收所有事件。
|
||||||
|
|
||||||
注意:回放和加入订阅者列表在同一同步段内完成(无 await),
|
注意:回放和加入订阅者列表在同一同步段内完成(无 await),
|
||||||
保证不会遗漏或重复事件。
|
保证不会遗漏或重复事件。
|
||||||
"""
|
"""
|
||||||
if self._closed:
|
if self._closed:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# P1 #13 fix: enforce subscriber cap to prevent resource exhaustion
|
||||||
|
# from malicious resume floods or runaway client loops.
|
||||||
|
if len(self._subscribers) >= self._MAX_SUBSCRIBERS:
|
||||||
|
logger.error(
|
||||||
|
"EventQueue subscriber limit reached (%d), rejecting new subscription",
|
||||||
|
self._MAX_SUBSCRIBERS,
|
||||||
|
)
|
||||||
|
raise RuntimeError(f"EventQueue subscriber limit reached ({self._MAX_SUBSCRIBERS})")
|
||||||
|
|
||||||
queue: asyncio.Queue[Event] = asyncio.Queue(maxsize=self._MAX_QUEUE_SIZE)
|
queue: asyncio.Queue[Event] = asyncio.Queue(maxsize=self._MAX_QUEUE_SIZE)
|
||||||
|
|
||||||
# 回放缓冲事件(同步操作,无 await,保证原子性)
|
# 回放缓冲事件(同步操作,无 await,保证原子性)
|
||||||
for event in list(self._buffer):
|
for event in list(self._buffer):
|
||||||
|
if task_id is not None and event.task_id != task_id:
|
||||||
|
continue
|
||||||
try:
|
try:
|
||||||
queue.put_nowait(event)
|
queue.put_nowait(event)
|
||||||
except asyncio.QueueFull:
|
except asyncio.QueueFull:
|
||||||
|
|
@ -199,7 +234,8 @@ class EventQueue:
|
||||||
break
|
break
|
||||||
|
|
||||||
# 加入订阅者列表(在回放之后,确保不会收到重复事件)
|
# 加入订阅者列表(在回放之后,确保不会收到重复事件)
|
||||||
self._subscribers.append(queue)
|
sub = _Subscriber(queue=queue, task_id_filter=task_id)
|
||||||
|
self._subscribers.append(sub)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
|
|
@ -208,9 +244,9 @@ class EventQueue:
|
||||||
break
|
break
|
||||||
yield event
|
yield event
|
||||||
finally:
|
finally:
|
||||||
# 清理:移除当前订阅者的队列
|
# 清理:移除当前订阅者
|
||||||
if queue in self._subscribers:
|
if sub in self._subscribers:
|
||||||
self._subscribers.remove(queue)
|
self._subscribers.remove(sub)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def subscriber_count(self) -> int:
|
def subscriber_count(self) -> int:
|
||||||
|
|
@ -235,9 +271,9 @@ class EventQueue:
|
||||||
"""
|
"""
|
||||||
self._closed = True
|
self._closed = True
|
||||||
# 向所有活跃订阅者队列放入哨兵,使其能够优雅退出
|
# 向所有活跃订阅者队列放入哨兵,使其能够优雅退出
|
||||||
for queue in self._subscribers:
|
for sub in self._subscribers:
|
||||||
try:
|
try:
|
||||||
queue.put_nowait(_CLOSED_SENTINEL)
|
sub.queue.put_nowait(_CLOSED_SENTINEL)
|
||||||
except asyncio.QueueFull:
|
except asyncio.QueueFull:
|
||||||
pass
|
pass
|
||||||
self._subscribers.clear()
|
self._subscribers.clear()
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,8 @@ import type {
|
||||||
IChatResponse,
|
IChatResponse,
|
||||||
ICapabilitiesResponse,
|
ICapabilitiesResponse,
|
||||||
IConversation,
|
IConversation,
|
||||||
|
ITaskRecord,
|
||||||
|
TaskStatus,
|
||||||
} from './types'
|
} from './types'
|
||||||
import { BaseApiClient } from './base'
|
import { BaseApiClient } from './base'
|
||||||
|
|
||||||
|
|
@ -36,6 +38,19 @@ class ApiClient extends BaseApiClient {
|
||||||
return this.request<IConversation>(`/conversations/${id}`)
|
return this.request<IConversation>(`/conversations/${id}`)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Get a task by ID (uses /api/v1/tasks prefix).
 *
 * The ID is percent-encoded before being placed in the path so that an ID
 * containing `/`, `?` or `#` cannot change which endpoint is requested.
 */
async getTask(taskId: string): Promise<ITaskRecord> {
  return this.request<ITaskRecord>(`/api/v1/tasks/${encodeURIComponent(taskId)}`)
}
|
||||||
|
|
||||||
|
/**
 * List tasks (uses /api/v1/tasks prefix).
 *
 * @param status optional status filter; omitted from the query when falsy
 * @param limit  maximum number of records requested (default 100)
 */
async listTasks(status?: TaskStatus, limit: number = 100): Promise<ITaskRecord[]> {
  const query = new URLSearchParams()
  if (status) {
    query.append('status', status)
  }
  query.append('limit', String(limit))
  const queryString = query.toString()
  return this.request<ITaskRecord[]>(`/api/v1/tasks?${queryString}`)
}
|
||||||
|
|
||||||
/** Create a WebSocket connection for real-time chat */
|
/** Create a WebSocket connection for real-time chat */
|
||||||
createWebSocket(): WebSocket {
|
createWebSocket(): WebSocket {
|
||||||
return super.createWebSocket('/ws')
|
return super.createWebSocket('/ws')
|
||||||
|
|
|
||||||
|
|
@ -78,6 +78,15 @@ export type WsClientMessage = {
|
||||||
sources?: string[]
|
sources?: string[]
|
||||||
conversation_id?: string
|
conversation_id?: string
|
||||||
model?: string
|
model?: string
|
||||||
|
} | {
|
||||||
|
type: 'resume'
|
||||||
|
task_id: string
|
||||||
|
conversation_id?: string
|
||||||
|
} | {
|
||||||
|
type: 'cancel'
|
||||||
|
task_id?: string
|
||||||
|
} | {
|
||||||
|
type: 'ping'
|
||||||
}
|
}
|
||||||
|
|
||||||
/** WebSocket server message types — matches backend portal.py protocol */
|
/** WebSocket server message types — matches backend portal.py protocol */
|
||||||
|
|
@ -132,3 +141,23 @@ export interface IApiError {
|
||||||
message: string
|
message: string
|
||||||
detail?: string
|
detail?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Task status (matches backend TaskStatus enum) */
|
||||||
|
export type TaskStatus = 'pending' | 'running' | 'completed' | 'partially_completed' | 'failed' | 'cancelled'
|
||||||
|
|
||||||
|
/**
 * Task record (matches backend TaskRecord.to_dict()).
 *
 * Fields typed `| null` can be null in the payload; callers must check before use.
 * NOTE(review): the timestamp fields are plain strings here — presumably
 * ISO-8601 as produced by the backend; confirm against TaskRecord.to_dict().
 */
export interface ITaskRecord {
  task_id: string
  agent_name: string
  skill_name: string | null
  input_data: Record<string, unknown>
  status: TaskStatus
  output_data: Record<string, unknown> | null
  error_message: string | null
  created_at: string
  started_at: string | null
  completed_at: string | null
  // NOTE(review): the backend writes 1.0 on completion, so this looks like a 0..1 fraction — confirm.
  progress: number
  progress_message: string
  // Free-form bag; the chat store reads `metadata.conversation_id` from it.
  metadata: Record<string, unknown>
}
|
||||||
|
|
|
||||||
|
|
@ -52,10 +52,37 @@ export const useChatStore = defineStore('chat', () => {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Select a conversation by ID */
|
/** Select a conversation by ID and load its messages */
|
||||||
function selectConversation(id: string): void {
|
async function selectConversation(id: string, force = false): Promise<void> {
|
||||||
currentConversationId.value = id
|
currentConversationId.value = id
|
||||||
streamingSteps.value = []
|
streamingSteps.value = []
|
||||||
|
|
||||||
|
// Load full conversation with messages if not already loaded (or when forced)
|
||||||
|
const conv = conversations.value.find((c) => c.id === id)
|
||||||
|
if (force || !conv || !conv.messages || conv.messages.length === 0) {
|
||||||
|
try {
|
||||||
|
const fullConv = await apiClient.getConversation(id)
|
||||||
|
if (conv) {
|
||||||
|
conv.messages = fullConv.messages || []
|
||||||
|
conv.title = fullConv.title || conv.title
|
||||||
|
conv.created_at = fullConv.created_at || conv.created_at
|
||||||
|
conv.updated_at = fullConv.updated_at || conv.updated_at
|
||||||
|
} else {
|
||||||
|
// P1 #7 fix: If the conversation is not in the local list (e.g.
|
||||||
|
// after a page refresh), add it from the fetched data instead
|
||||||
|
// of silently discarding the result.
|
||||||
|
conversations.value.unshift({
|
||||||
|
id: fullConv.id || id,
|
||||||
|
title: fullConv.title || '新对话',
|
||||||
|
messages: fullConv.messages || [],
|
||||||
|
created_at: fullConv.created_at || new Date().toISOString(),
|
||||||
|
updated_at: fullConv.updated_at || new Date().toISOString(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to load conversation messages:', error)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Create a new empty conversation */
|
/** Create a new empty conversation */
|
||||||
|
|
@ -135,7 +162,7 @@ export const useChatStore = defineStore('chat', () => {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Send a message via WebSocket for streaming */
|
/** Send a message via WebSocket for streaming */
|
||||||
function sendWsMessage(message: string, sources?: string[], model?: string): void {
|
async function sendWsMessage(message: string, sources?: string[], model?: string): Promise<void> {
|
||||||
if (!currentConversationId.value) {
|
if (!currentConversationId.value) {
|
||||||
createConversation()
|
createConversation()
|
||||||
}
|
}
|
||||||
|
|
@ -143,7 +170,7 @@ export const useChatStore = defineStore('chat', () => {
|
||||||
// Check WebSocket state BEFORE creating messages to avoid duplicates
|
// Check WebSocket state BEFORE creating messages to avoid duplicates
|
||||||
if (!ws.value || ws.value.readyState !== WebSocket.OPEN) {
|
if (!ws.value || ws.value.readyState !== WebSocket.OPEN) {
|
||||||
// Fallback to REST directly — sendMessage will create its own messages
|
// Fallback to REST directly — sendMessage will create its own messages
|
||||||
sendMessage(message, sources)
|
await sendMessage(message, sources)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -179,7 +206,22 @@ export const useChatStore = defineStore('chat', () => {
|
||||||
model,
|
model,
|
||||||
}
|
}
|
||||||
|
|
||||||
ws.value.send(JSON.stringify(wsMessage))
|
// Problem 7: catch send() exceptions (e.g. connection closed mid-send)
|
||||||
|
try {
|
||||||
|
ws.value.send(JSON.stringify(wsMessage))
|
||||||
|
} catch (error) {
|
||||||
|
console.error('WebSocket send failed, falling back to REST:', error)
|
||||||
|
// Remove the placeholder messages we just added; sendMessage will re-add them
|
||||||
|
const conv = conversations.value.find((c) => c.id === conversationId)
|
||||||
|
if (conv) {
|
||||||
|
conv.messages = conv.messages.filter(
|
||||||
|
(m) => m.id !== userMessage.id && m.id !== assistantMessage.id,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
isLoading.value = false
|
||||||
|
await sendMessage(message, sources)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Update conversation title from first user message
|
// Update conversation title from first user message
|
||||||
const conv = conversations.value.find((c) => c.id === conversationId)
|
const conv = conversations.value.find((c) => c.id === conversationId)
|
||||||
|
|
@ -190,12 +232,16 @@ export const useChatStore = defineStore('chat', () => {
|
||||||
|
|
||||||
/** Connect to WebSocket for real-time streaming */
|
/** Connect to WebSocket for real-time streaming */
|
||||||
let _heartbeatTimer: ReturnType<typeof setInterval> | null = null
|
let _heartbeatTimer: ReturnType<typeof setInterval> | null = null
|
||||||
|
let _reconnectTimer: ReturnType<typeof setTimeout> | null = null
|
||||||
|
let _intentionalDisconnect = false
|
||||||
|
|
||||||
function connectWebSocket(): void {
|
function connectWebSocket(): void {
|
||||||
if (ws.value && ws.value.readyState === WebSocket.OPEN) {
|
// Problem 6: also skip if already CONNECTING to avoid orphan sockets
|
||||||
|
if (ws.value && (ws.value.readyState === WebSocket.OPEN || ws.value.readyState === WebSocket.CONNECTING)) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_intentionalDisconnect = false
|
||||||
const socket = apiClient.createWebSocket()
|
const socket = apiClient.createWebSocket()
|
||||||
|
|
||||||
socket.onopen = () => {
|
socket.onopen = () => {
|
||||||
|
|
@ -208,6 +254,8 @@ export const useChatStore = defineStore('chat', () => {
|
||||||
ws.value.send(JSON.stringify({ type: 'ping' }))
|
ws.value.send(JSON.stringify({ type: 'ping' }))
|
||||||
}
|
}
|
||||||
}, 30000)
|
}, 30000)
|
||||||
|
// Check for running tasks to resume after reconnection
|
||||||
|
_recoverTaskAfterReconnect()
|
||||||
}
|
}
|
||||||
|
|
||||||
socket.onmessage = (event: MessageEvent) => {
|
socket.onmessage = (event: MessageEvent) => {
|
||||||
|
|
@ -222,13 +270,22 @@ export const useChatStore = defineStore('chat', () => {
|
||||||
|
|
||||||
socket.onclose = () => {
|
socket.onclose = () => {
|
||||||
isWsConnected.value = false
|
isWsConnected.value = false
|
||||||
|
// P2 #21 fix: reset isLoading to prevent stuck loading state during
|
||||||
|
// disconnect. _recoverTaskAfterReconnect will re-set it if an active
|
||||||
|
// task is found after reconnection.
|
||||||
|
isLoading.value = false
|
||||||
console.log('WebSocket disconnected')
|
console.log('WebSocket disconnected')
|
||||||
if (_heartbeatTimer) {
|
if (_heartbeatTimer) {
|
||||||
clearInterval(_heartbeatTimer)
|
clearInterval(_heartbeatTimer)
|
||||||
_heartbeatTimer = null
|
_heartbeatTimer = null
|
||||||
}
|
}
|
||||||
|
// Problem 1: do not auto-reconnect after an intentional disconnect
|
||||||
|
if (_intentionalDisconnect) {
|
||||||
|
return
|
||||||
|
}
|
||||||
// Auto reconnect after 3 seconds
|
// Auto reconnect after 3 seconds
|
||||||
setTimeout(() => {
|
if (_reconnectTimer) clearTimeout(_reconnectTimer)
|
||||||
|
_reconnectTimer = setTimeout(() => {
|
||||||
if (!ws.value || ws.value.readyState === WebSocket.CLOSED) {
|
if (!ws.value || ws.value.readyState === WebSocket.CLOSED) {
|
||||||
connectWebSocket()
|
connectWebSocket()
|
||||||
}
|
}
|
||||||
|
|
@ -245,6 +302,11 @@ export const useChatStore = defineStore('chat', () => {
|
||||||
|
|
||||||
/** Disconnect WebSocket */
|
/** Disconnect WebSocket */
|
||||||
function disconnectWebSocket(): void {
|
function disconnectWebSocket(): void {
|
||||||
|
_intentionalDisconnect = true
|
||||||
|
if (_reconnectTimer) {
|
||||||
|
clearTimeout(_reconnectTimer)
|
||||||
|
_reconnectTimer = null
|
||||||
|
}
|
||||||
if (_heartbeatTimer) {
|
if (_heartbeatTimer) {
|
||||||
clearInterval(_heartbeatTimer)
|
clearInterval(_heartbeatTimer)
|
||||||
_heartbeatTimer = null
|
_heartbeatTimer = null
|
||||||
|
|
@ -256,6 +318,64 @@ export const useChatStore = defineStore('chat', () => {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * After the WebSocket (re)connects, look for a task that is still running or
 * pending for the selected conversation and ask the server to resume it.
 *
 * When no such task exists (or the socket is not open), the conversation is
 * force-reloaded instead so results persisted while disconnected show up.
 * All failures are logged and swallowed; this function never rejects.
 */
async function _recoverTaskAfterReconnect(): Promise<void> {
  if (!currentConversationId.value) return

  try {
    // Query both 'running' and 'pending' tasks — a task may still be PENDING
    // if the background worker has not picked it up yet.
    const [runningTasks, pendingTasks] = await Promise.all([
      apiClient.listTasks('running'),
      apiClient.listTasks('pending'),
    ])

    // Re-read the selection after the await: the user may have cleared or
    // switched the conversation while the task lists were loading. Bail out
    // if nothing is selected any more, and use one stable id from here on.
    const conversationId = currentConversationId.value
    if (!conversationId) return

    const candidates = [...runningTasks, ...pendingTasks]
    const activeTask = candidates.find(
      (t) => t.metadata?.conversation_id === conversationId,
    )

    if (activeTask && ws.value && ws.value.readyState === WebSocket.OPEN) {
      // Clear the last pending assistant message before resuming: the server
      // replays buffered events for the task, so partial content accumulated
      // before the disconnect would otherwise be duplicated.
      const conv = conversations.value.find((c) => c.id === conversationId)
      if (conv) {
        const lastPendingAssistant = [...conv.messages]
          .reverse()
          .find((m) => m.role === 'assistant' && m.status === 'pending')
        if (lastPendingAssistant) {
          updateMessage(conversationId, lastPendingAssistant.id, {
            content: '',
            thinking: '',
            tool_calls: [],
          })
        }
      }

      // Only show the loading state if the resume request actually goes out.
      isLoading.value = true
      try {
        ws.value.send(
          JSON.stringify({
            type: 'resume',
            task_id: activeTask.task_id,
            conversation_id: conversationId,
          }),
        )
      } catch (error) {
        console.error('Failed to send resume message:', error)
        isLoading.value = false
      }
    } else {
      // No active task — force reload conversation messages (may include
      // completed results persisted while disconnected).
      await selectConversation(conversationId, true)
    }
  } catch (error) {
    console.error('Failed to recover task after reconnect:', error)
  }
}
|
||||||
|
|
||||||
// --- Internal helpers ---
|
// --- Internal helpers ---
|
||||||
|
|
||||||
/** Get team store lazily — safe to call inside actions after Pinia is installed */
|
/** Get team store lazily — safe to call inside actions after Pinia is installed */
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
from agentkit.core.config_driven import ConfigDrivenAgent
|
from agentkit.core.config_driven import ConfigDrivenAgent
|
||||||
from agentkit.core.event_queue import EventQueue
|
from agentkit.core.event_queue import EventQueue
|
||||||
from agentkit.core.protocol import Event, TaskEventType, TurnEventType
|
from agentkit.core.protocol import Event, TaskEventType, TaskStatus, TurnEventType
|
||||||
from agentkit.core.react import ReActEngine
|
from agentkit.core.react import ReActEngine
|
||||||
from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult
|
from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult
|
||||||
from agentkit.chat.request_preprocessor import RequestPreprocessor
|
from agentkit.chat.request_preprocessor import RequestPreprocessor
|
||||||
|
|
@ -32,20 +32,15 @@ from agentkit.server.routes.evolution_dashboard import (
|
||||||
)
|
)
|
||||||
from agentkit.core.fallback import EMPTY_LLM_RESPONSE
|
from agentkit.core.fallback import EMPTY_LLM_RESPONSE
|
||||||
from agentkit.chat.sqlite_conversation_store import SqliteConversationStore
|
from agentkit.chat.sqlite_conversation_store import SqliteConversationStore
|
||||||
|
from agentkit.server.task_store import InMemoryTaskStore
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(tags=["portal"])
|
router = APIRouter(tags=["portal"])
|
||||||
|
|
||||||
# Map ReAct engine event_type strings to TurnEventType constants for EQ emission.
|
# Track background ReAct tasks so they are not garbage-collected mid-execution.
|
||||||
# Only events with a corresponding TurnEventType are forwarded to the EQ;
|
# Tasks are removed automatically via add_done_callback when they complete.
|
||||||
# other events (e.g. "token") are still sent over WebSocket but not duplicated to EQ.
|
_running_background_tasks: set[asyncio.Task] = set()
|
||||||
_REACT_EVENT_TYPE_MAP: dict[str, str] = {
|
|
||||||
"thinking": TurnEventType.THINKING,
|
|
||||||
"tool_call": TurnEventType.TOOL_CALL,
|
|
||||||
"tool_result": TurnEventType.TOOL_RESULT,
|
|
||||||
"final_answer": TurnEventType.FINAL_ANSWER,
|
|
||||||
}
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# API Key Authentication
|
# API Key Authentication
|
||||||
|
|
@ -95,6 +90,37 @@ async def _emit_event_safe(
|
||||||
logger.warning(f"EventQueue emit failed (type={event_type}): {e}", exc_info=True)
|
logger.warning(f"EventQueue emit failed (type={event_type}): {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
# P1 #14 fix: TaskStore sync/async compatibility shim.
|
||||||
|
# InMemoryTaskStore methods are sync; RedisTaskStore methods are async.
|
||||||
|
# These helpers detect and await coroutines so portal.py works with both.
|
||||||
|
async def _task_store_create(store, *args, **kwargs):
    """Call ``store.create`` and await the result when it is awaitable.

    Lets callers use synchronous and asynchronous TaskStore implementations
    through one ``await``-able entry point.

    Args:
        store: Any object exposing ``create(*args, **kwargs)``.
        *args, **kwargs: Forwarded unchanged to ``store.create``.

    Returns:
        Whatever ``store.create`` returns, or resolves to when awaitable.
    """
    import inspect

    result = store.create(*args, **kwargs)
    # isawaitable also covers Futures, Tasks and objects defining __await__,
    # which asyncio.iscoroutine would hand back un-awaited.
    if inspect.isawaitable(result):
        return await result
    return result
|
||||||
|
|
||||||
|
|
||||||
|
async def _task_store_get(store, *args, **kwargs):
    """Call ``store.get`` and await the result when it is awaitable.

    Args:
        store: Any object exposing ``get(*args, **kwargs)``.
        *args, **kwargs: Forwarded unchanged to ``store.get``.

    Returns:
        Whatever ``store.get`` returns, or resolves to when awaitable.
    """
    import inspect

    result = store.get(*args, **kwargs)
    # isawaitable also covers Futures, Tasks and objects defining __await__,
    # which asyncio.iscoroutine would hand back un-awaited.
    if inspect.isawaitable(result):
        return await result
    return result
|
||||||
|
|
||||||
|
|
||||||
|
async def _task_store_update_status(store, *args, **kwargs):
    """Call ``store.update_status`` and await the result when it is awaitable.

    Args:
        store: Any object exposing ``update_status(*args, **kwargs)``.
        *args, **kwargs: Forwarded unchanged to ``store.update_status``.

    Returns:
        Whatever ``store.update_status`` returns, or resolves to when awaitable.
    """
    import inspect

    result = store.update_status(*args, **kwargs)
    # isawaitable also covers Futures, Tasks and objects defining __await__,
    # which asyncio.iscoroutine would hand back un-awaited.
    if inspect.isawaitable(result):
        return await result
    return result
|
||||||
|
|
||||||
|
|
||||||
|
async def _task_store_list_tasks(store, *args, **kwargs):
    """Call ``store.list_tasks`` and await the result when it is awaitable.

    Args:
        store: Any object exposing ``list_tasks(*args, **kwargs)``.
        *args, **kwargs: Forwarded unchanged to ``store.list_tasks``.

    Returns:
        Whatever ``store.list_tasks`` returns, or resolves to when awaitable.
    """
    import inspect

    result = store.list_tasks(*args, **kwargs)
    # isawaitable also covers Futures, Tasks and objects defining __await__,
    # which asyncio.iscoroutine would hand back un-awaited.
    if inspect.isawaitable(result):
        return await result
    return result
|
||||||
|
|
||||||
|
|
||||||
async def _verify_api_key(
|
async def _verify_api_key(
|
||||||
request: Request,
|
request: Request,
|
||||||
api_key_header: str | None = Security(_api_key_header),
|
api_key_header: str | None = Security(_api_key_header),
|
||||||
|
|
@ -144,6 +170,19 @@ class Conversation:
|
||||||
_WS_HEARTBEAT_TIMEOUT = float(os.environ.get("AGENTKIT_WS_TIMEOUT", "120"))
|
_WS_HEARTBEAT_TIMEOUT = float(os.environ.get("AGENTKIT_WS_TIMEOUT", "120"))
|
||||||
_conversation_store = SqliteConversationStore()
|
_conversation_store = SqliteConversationStore()
|
||||||
|
|
||||||
|
# P1 #9 fix: ReAct event type -> TurnEventType mapping for EQ subscribers.
# Keys are the ReAct engine's ``event_type`` strings; values are the
# TurnEventType constants under which the event is re-emitted to the
# EventQueue. Engine events whose type has no entry here are not forwarded.
# Preserves the original EQ contract so CLI and other subscribers that
# filter on TurnEventType constants (e.g. 'turn.thinking') keep working.
_REACT_EVENT_TYPE_MAP: dict[str, str] = {
    "thinking": TurnEventType.THINKING,
    "tool_call": TurnEventType.TOOL_CALL,
    "tool_result": TurnEventType.TOOL_RESULT,
    "token": TurnEventType.TOKEN,
    "final_answer": TurnEventType.FINAL_ANSWER,
    "error": TurnEventType.TURN_COMPLETED,  # best-effort mapping
    "confirmation_request": TurnEventType.STEP,
}
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# History injection helper — configurable limit + optional compression
|
# History injection helper — configurable limit + optional compression
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -666,6 +705,171 @@ def _derive_title_from_messages(messages: list) -> str:
|
||||||
return "对话"
|
return "对话"
|
||||||
|
|
||||||
|
|
||||||
|
def _shield_tracked(coro):
    """Schedule *coro* as a task, keep a strong reference to it, and shield it.

    ``asyncio.shield`` lets the inner work keep running when the awaiting task
    is cancelled again, but the event loop only holds weak references to
    tasks. Registering the task in ``_running_background_tasks`` until it
    finishes prevents it from being garbage-collected mid-flight.
    """
    inner = asyncio.ensure_future(coro)
    _running_background_tasks.add(inner)
    inner.add_done_callback(_running_background_tasks.discard)
    return asyncio.shield(inner)


async def _execute_react_background(
    react_engine: ReActEngine,
    messages: list[dict],
    tools: list,
    model: str,
    agent_name: str,
    system_prompt: str | None,
    timeout_seconds: float | None,
    conv_id: str,
    task_id: str,
    event_queue: EventQueue,
    conversation_store: SqliteConversationStore,
    task_store: InMemoryTaskStore | None = None,
) -> None:
    """Execute ReAct engine in the background, decoupled from WebSocket lifecycle.

    Engine events are re-emitted to the EventQueue tagged with ``task_id`` so
    that any subscriber — including a reconnected WebSocket — can consume
    them. The assistant result is persisted to the conversation store whether
    or not a subscriber is attached. Task status is tracked in ``task_store``
    when one is provided; TaskStore failures are logged and never abort the run.

    Args:
        react_engine: Engine whose ``execute_stream`` yields the events.
        messages, tools, model, agent_name, system_prompt, timeout_seconds:
            Forwarded unchanged to ``react_engine.execute_stream``.
        conv_id: Conversation the assistant message is stored under; also used
            as the emitted events' ``session_id``.
        task_id: Identifier attached to every emitted event and TaskStore update.
        event_queue: Queue the events are emitted to.
        conversation_store: Store receiving the final (or partial) output.
        task_store: Optional status tracker (sync or async implementation).

    Raises:
        asyncio.CancelledError: Re-raised after best-effort cleanup when the
            task is cancelled. Other exceptions are handled here and reported
            through a ``TASK_FAILED`` event instead of propagating.
    """
    collected_output: list[str] = []
    try:
        if task_store is not None:
            try:
                await _task_store_update_status(
                    task_store, task_id, TaskStatus.RUNNING, started_at=datetime.now(timezone.utc)
                )
            except Exception:
                logger.warning("Failed to update TaskStore RUNNING", exc_info=True)

        async for event in react_engine.execute_stream(
            messages=messages,
            tools=tools,
            model=model,
            agent_name=agent_name,
            system_prompt=system_prompt,
            timeout_seconds=timeout_seconds,
        ):
            if event.event_type == "final_answer":
                # Coerce to str: a None/non-string "output" would otherwise make
                # "".join() raise later — including inside the error handlers.
                output = event.data.get("output", "")
                collected_output.append("" if output is None else str(output))

            # Preserve TurnEventType mapping, step field, and original data
            # structure for EQ subscriber compatibility. The event object may
            # or may not carry a 'step' attribute, hence getattr with a default.
            _turn_event_type = _REACT_EVENT_TYPE_MAP.get(event.event_type)
            if _turn_event_type is not None:
                await _emit_event_safe(
                    event_queue,
                    _turn_event_type,
                    task_id=task_id,
                    session_id=conv_id,
                    data={
                        **event.data,
                        "step": getattr(event, "step", 0),
                        "timestamp": event.timestamp,
                    },
                )

        # Normal completion: persist result
        response_text = _ensure_non_empty("".join(collected_output) if collected_output else None)
        await conversation_store.add_message(conv_id, "assistant", response_text)

        if task_store is not None:
            try:
                await _task_store_update_status(
                    task_store,
                    task_id,
                    TaskStatus.COMPLETED,
                    output_data={"output": response_text},
                    completed_at=datetime.now(timezone.utc),
                    progress=1.0,
                    progress_message="Completed",
                )
            except Exception:
                logger.warning("Failed to update TaskStore COMPLETED", exc_info=True)

        # Emit task.completed so subscribers know the task is done
        await _emit_event_safe(
            event_queue,
            TaskEventType.TASK_COMPLETED,
            task_id=task_id,
            session_id=conv_id,
            data={"output": response_text, "timestamp": datetime.now(timezone.utc).isoformat()},
        )

    except asyncio.CancelledError:
        # Application shutdown or explicit cancel — persist partial output
        # and mark task as FAILED so resume does not block forever.
        # Every cleanup step is shielded (and strongly referenced) so that a
        # re-entrant cancellation neither kills nor orphans the cleanup itself.
        if collected_output:
            partial = _ensure_non_empty("".join(collected_output))
            try:
                await _shield_tracked(conversation_store.add_message(conv_id, "assistant", partial))
            except (Exception, asyncio.CancelledError):
                logger.warning("Failed to persist partial output on cancel", exc_info=True)
        if task_store is not None:
            try:
                await _shield_tracked(
                    _task_store_update_status(
                        task_store,
                        task_id,
                        TaskStatus.FAILED,
                        error_message="Task cancelled",
                        completed_at=datetime.now(timezone.utc),
                    )
                )
            except (Exception, asyncio.CancelledError):
                logger.warning("Failed to update TaskStore on cancel", exc_info=True)
        # Shield the emit as well so subscribers are not left waiting for a
        # terminal event until their own timeout.
        try:
            await _shield_tracked(
                _emit_event_safe(
                    event_queue,
                    TaskEventType.TASK_FAILED,
                    task_id=task_id,
                    session_id=conv_id,
                    data={
                        "error": "Task cancelled",
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
            )
        except (Exception, asyncio.CancelledError):
            logger.warning("Failed to emit TASK_FAILED on cancel", exc_info=True)
        raise  # Propagate cancellation

    except Exception as e:
        # Persist any partial output collected before the error
        if collected_output:
            partial = _ensure_non_empty("".join(collected_output))
            try:
                await conversation_store.add_message(conv_id, "assistant", partial)
            except Exception:
                logger.warning("Failed to persist partial output in background task", exc_info=True)

        if task_store is not None:
            try:
                await _task_store_update_status(
                    task_store,
                    task_id,
                    TaskStatus.FAILED,
                    error_message=str(e),
                    completed_at=datetime.now(timezone.utc),
                )
            except Exception:
                logger.warning("Failed to update TaskStore FAILED", exc_info=True)

        # Emit task.failed so subscribers know the task failed
        await _emit_event_safe(
            event_queue,
            TaskEventType.TASK_FAILED,
            task_id=task_id,
            session_id=conv_id,
            data={"error": str(e), "timestamp": datetime.now(timezone.utc).isoformat()},
        )
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/portal/ws")
|
@router.websocket("/portal/ws")
|
||||||
async def portal_websocket(websocket: WebSocket):
|
async def portal_websocket(websocket: WebSocket):
|
||||||
"""Real-time chat WebSocket endpoint."""
|
"""Real-time chat WebSocket endpoint."""
|
||||||
|
|
@ -692,6 +896,8 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
conv: Conversation | None = None
|
conv: Conversation | None = None
|
||||||
# task_id is per-user-message; tracked here so the outer except can emit task.failed
|
# task_id is per-user-message; tracked here so the outer except can emit task.failed
|
||||||
task_id: str | None = None
|
task_id: str | None = None
|
||||||
|
# Track the active background task so cancel can propagate to it.
|
||||||
|
active_bg_task: asyncio.Task | None = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
|
|
@ -710,6 +916,10 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
msg_type = msg.get("type")
|
msg_type = msg.get("type")
|
||||||
|
|
||||||
if msg_type == "cancel":
|
if msg_type == "cancel":
|
||||||
|
# Cancel the active background task if still running
|
||||||
|
if active_bg_task is not None and not active_bg_task.done():
|
||||||
|
active_bg_task.cancel()
|
||||||
|
active_bg_task = None
|
||||||
await websocket.send_json(
|
await websocket.send_json(
|
||||||
{
|
{
|
||||||
"type": "result",
|
"type": "result",
|
||||||
|
|
@ -725,6 +935,203 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
await websocket.send_json({"type": "pong"})
|
await websocket.send_json({"type": "pong"})
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if msg_type == "resume":
|
||||||
|
# Frontend reconnected and wants to resume a running task
|
||||||
|
resume_task_id = msg.get("task_id", "")
|
||||||
|
if not resume_task_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# P1 #3/#4 fix: Fail-closed ownership verification.
|
||||||
|
# Require conversation_id and TaskStore — reject if either
|
||||||
|
# is missing, to prevent cross-conversation task hijacking
|
||||||
|
# via empty conversation_id or unconfigured TaskStore.
|
||||||
|
resume_conv_id = msg.get("conversation_id", "")
|
||||||
|
if not resume_conv_id:
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"data": {
|
||||||
|
"message": "Resume requires conversation_id.",
|
||||||
|
"task_id": resume_task_id,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
resume_task_store: InMemoryTaskStore | None = getattr(
|
||||||
|
websocket.app.state, "task_store", None
|
||||||
|
)
|
||||||
|
resume_eq: EventQueue | None = getattr(websocket.app.state, "event_queue", None)
|
||||||
|
|
||||||
|
# P1 #4: Fail-closed if TaskStore is unavailable — cannot
|
||||||
|
# verify ownership without it.
|
||||||
|
if resume_task_store is None:
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"data": {
|
||||||
|
"message": "Resume not supported (TaskStore unavailable). Please retry your request.",
|
||||||
|
"task_id": resume_task_id,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
record = await _task_store_get(resume_task_store, resume_task_id)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("TaskStore.get failed during resume", exc_info=True)
|
||||||
|
record = None
|
||||||
|
if record is not None:
|
||||||
|
# P1 #3: Fail-closed ownership check — reject if
|
||||||
|
# conversation_id is missing from task metadata OR
|
||||||
|
# does not match the request.
|
||||||
|
task_conv_id = (record.metadata or {}).get("conversation_id", "")
|
||||||
|
if not task_conv_id or resume_conv_id != task_conv_id:
|
||||||
|
logger.warning(
|
||||||
|
"Resume rejected: conversation_id mismatch "
|
||||||
|
"(task=%s, request=%s, task_id=%s)",
|
||||||
|
task_conv_id,
|
||||||
|
resume_conv_id,
|
||||||
|
resume_task_id,
|
||||||
|
)
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"data": {
|
||||||
|
"message": "Task does not belong to this conversation.",
|
||||||
|
"task_id": resume_task_id,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
if record.status == TaskStatus.COMPLETED:
|
||||||
|
# Task already finished — send result immediately
|
||||||
|
output = (record.output_data or {}).get("output", "")
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "result",
|
||||||
|
"data": {
|
||||||
|
"message": output,
|
||||||
|
"timestamp": record.completed_at.isoformat()
|
||||||
|
if record.completed_at
|
||||||
|
else datetime.now(timezone.utc).isoformat(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
elif record.status == TaskStatus.FAILED:
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"data": {
|
||||||
|
"message": record.error_message or "Task failed",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# Task not found in store — cannot resume
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"data": {
|
||||||
|
"message": "Task not found or has expired. Please retry your request.",
|
||||||
|
"task_id": resume_task_id,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Task is still running — subscribe to EventQueue for remaining events.
|
||||||
|
# H6: if EventQueue is unavailable, inform the client instead of
|
||||||
|
# silently continuing (which would leave the UI loading forever).
|
||||||
|
if resume_eq is None:
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"data": {
|
||||||
|
"message": "Resume not supported (EventQueue unavailable). Please retry your request.",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# C2: bound the subscribe loop with a timeout so a dead
|
||||||
|
# background task cannot block resume forever.
|
||||||
|
resume_timeout = _WS_HEARTBEAT_TIMEOUT * 10 if _WS_HEARTBEAT_TIMEOUT > 0 else 600
|
||||||
|
try:
|
||||||
|
async with asyncio.timeout(resume_timeout):
|
||||||
|
async for event in resume_eq.subscribe(task_id=resume_task_id):
|
||||||
|
if event.event_type == TaskEventType.TASK_COMPLETED:
|
||||||
|
response_text = event.data.get("output", EMPTY_LLM_RESPONSE)
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "result",
|
||||||
|
"data": {
|
||||||
|
"message": response_text,
|
||||||
|
"timestamp": event.data.get(
|
||||||
|
"timestamp",
|
||||||
|
datetime.now(timezone.utc).isoformat(),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
break
|
||||||
|
elif event.event_type == TaskEventType.TASK_FAILED:
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"data": {
|
||||||
|
"message": event.data.get("error", "Unknown error"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# P1 #8/#10 fix: step and data are now
|
||||||
|
# top-level fields in event.data.
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "step",
|
||||||
|
"data": {
|
||||||
|
"event_type": event.event_type,
|
||||||
|
"step": event.data.get("step", 0),
|
||||||
|
"data": {
|
||||||
|
k: v
|
||||||
|
for k, v in event.data.items()
|
||||||
|
if k not in ("step", "timestamp")
|
||||||
|
},
|
||||||
|
"timestamp": event.data.get("timestamp", ""),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except TimeoutError:
|
||||||
|
logger.warning(f"Resume subscribe timed out for task {resume_task_id}")
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"data": {
|
||||||
|
"message": "Task resume timed out. Please retry your request.",
|
||||||
|
"task_id": resume_task_id,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except RuntimeError as exc:
|
||||||
|
# P1 #5: subscriber limit reached or EQ closed — send
|
||||||
|
# a friendly error instead of terminating the connection.
|
||||||
|
logger.warning("Resume subscribe failed for task %s: %s", resume_task_id, exc)
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"data": {
|
||||||
|
"message": "Server busy, please retry shortly.",
|
||||||
|
"task_id": resume_task_id,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
if msg_type != "chat":
|
if msg_type != "chat":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
@ -744,6 +1151,7 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
# (EQ is a side-channel: emit failures never break the WebSocket flow)
|
# (EQ is a side-channel: emit failures never break the WebSocket flow)
|
||||||
task_id = str(uuid.uuid4())
|
task_id = str(uuid.uuid4())
|
||||||
event_queue: EventQueue | None = getattr(websocket.app.state, "event_queue", None)
|
event_queue: EventQueue | None = getattr(websocket.app.state, "event_queue", None)
|
||||||
|
task_store: InMemoryTaskStore | None = getattr(websocket.app.state, "task_store", None)
|
||||||
await _emit_event_safe(
|
await _emit_event_safe(
|
||||||
event_queue,
|
event_queue,
|
||||||
TaskEventType.TASK_CREATED,
|
TaskEventType.TASK_CREATED,
|
||||||
|
|
@ -844,6 +1252,26 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Register task in TaskStore for status tracking and recovery
|
||||||
|
if task_store is not None:
|
||||||
|
try:
|
||||||
|
await _task_store_create(
|
||||||
|
task_store,
|
||||||
|
task_id=task_id,
|
||||||
|
agent_name=routing_result.agent_name or "default",
|
||||||
|
input_data={"message": message_text},
|
||||||
|
skill_name=routing_result.skill_name,
|
||||||
|
)
|
||||||
|
# Store conversation_id in metadata for frontend recovery
|
||||||
|
await _task_store_update_status(
|
||||||
|
task_store,
|
||||||
|
task_id,
|
||||||
|
TaskStatus.PENDING,
|
||||||
|
metadata={"conversation_id": conv.id},
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to register task in TaskStore", exc_info=True)
|
||||||
|
|
||||||
# Execute based on routing result's execution_mode
|
# Execute based on routing result's execution_mode
|
||||||
# This is the single source of truth for path selection,
|
# This is the single source of truth for path selection,
|
||||||
# replacing fragile string-matching on match_method.
|
# replacing fragile string-matching on match_method.
|
||||||
|
|
@ -870,6 +1298,21 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
response_content = _ensure_non_empty(response.content)
|
response_content = _ensure_non_empty(response.content)
|
||||||
await _conversation_store.add_message(conv.id, "assistant", response_content)
|
await _conversation_store.add_message(conv.id, "assistant", response_content)
|
||||||
|
|
||||||
|
# Update TaskStore status to COMPLETED
|
||||||
|
if task_store is not None:
|
||||||
|
try:
|
||||||
|
await _task_store_update_status(
|
||||||
|
task_store,
|
||||||
|
task_id,
|
||||||
|
TaskStatus.COMPLETED,
|
||||||
|
output_data={"output": response_content},
|
||||||
|
completed_at=datetime.now(timezone.utc),
|
||||||
|
progress=1.0,
|
||||||
|
progress_message="Completed",
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to update TaskStore for DIRECT_CHAT", exc_info=True)
|
||||||
|
|
||||||
# Emit turn.final_answer and task.completed to EQ
|
# Emit turn.final_answer and task.completed to EQ
|
||||||
await _emit_event_safe(
|
await _emit_event_safe(
|
||||||
event_queue,
|
event_queue,
|
||||||
|
|
@ -890,8 +1333,7 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
{
|
{
|
||||||
"type": "result",
|
"type": "result",
|
||||||
"data": {
|
"data": {
|
||||||
"status": "completed",
|
"message": response_content,
|
||||||
"content": response_content,
|
|
||||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -1010,99 +1452,154 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
f"[portal] agent='{agent_name}' tools={len(tools)} "
|
f"[portal] agent='{agent_name}' tools={len(tools)} "
|
||||||
f"[{', '.join(t.name for t in tools)}] model={model}"
|
f"[{', '.join(t.name for t in tools)}] model={model}"
|
||||||
)
|
)
|
||||||
collected_output: list[str] = []
|
|
||||||
try:
|
# Start ReAct execution as a background task, decoupled from
|
||||||
async for event in react_engine.execute_stream(
|
# WebSocket lifecycle. When the WebSocket disconnects, the
|
||||||
|
# background task continues running and persists the result.
|
||||||
|
bg_task = asyncio.create_task(
|
||||||
|
_execute_react_background(
|
||||||
|
react_engine=react_engine,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
model=model,
|
model=model,
|
||||||
agent_name=agent.name,
|
agent_name=agent.name,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
timeout_seconds=timeout_seconds,
|
timeout_seconds=timeout_seconds,
|
||||||
):
|
conv_id=conv.id,
|
||||||
if event.event_type == "final_answer":
|
|
||||||
collected_output.append(event.data.get("output", ""))
|
|
||||||
|
|
||||||
# Map ReAct event types to TurnEventType and emit to EQ
|
|
||||||
# (side-channel: failures are swallowed by _emit_event_safe)
|
|
||||||
_turn_event_type = _REACT_EVENT_TYPE_MAP.get(event.event_type)
|
|
||||||
if _turn_event_type is not None:
|
|
||||||
await _emit_event_safe(
|
|
||||||
event_queue,
|
|
||||||
_turn_event_type,
|
|
||||||
task_id=task_id,
|
|
||||||
session_id=conv.id,
|
|
||||||
data=event.data,
|
|
||||||
)
|
|
||||||
|
|
||||||
await websocket.send_json(
|
|
||||||
{
|
|
||||||
"type": "step",
|
|
||||||
"data": {
|
|
||||||
"event_type": event.event_type,
|
|
||||||
"step": event.step,
|
|
||||||
"data": event.data,
|
|
||||||
"timestamp": event.timestamp,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
# Emit task.failed to EQ before sending error to WebSocket
|
|
||||||
await _emit_event_safe(
|
|
||||||
event_queue,
|
|
||||||
TaskEventType.TASK_FAILED,
|
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
session_id=conv.id,
|
event_queue=event_queue,
|
||||||
data={"error": str(e)},
|
conversation_store=_conversation_store,
|
||||||
|
task_store=task_store,
|
||||||
)
|
)
|
||||||
await websocket.send_json({"type": "error", "data": {"message": str(e)}})
|
)
|
||||||
|
_running_background_tasks.add(bg_task)
|
||||||
|
bg_task.add_done_callback(_running_background_tasks.discard)
|
||||||
|
active_bg_task = bg_task
|
||||||
|
|
||||||
|
# C1 guard: EventQueue is required for subscribe; fall back to
|
||||||
|
# awaiting the background task directly if unavailable.
|
||||||
|
if event_queue is None:
|
||||||
|
logger.warning("EventQueue not configured; awaiting background task directly")
|
||||||
|
try:
|
||||||
|
await bg_task
|
||||||
|
except Exception:
|
||||||
|
pass # errors handled inside _execute_react_background
|
||||||
|
active_bg_task = None
|
||||||
continue
|
continue
|
||||||
|
|
||||||
response_text = _ensure_non_empty(
|
# Subscribe to EventQueue (filtered by task_id) and forward
|
||||||
"".join(collected_output) if collected_output else None
|
# events to the WebSocket. When the WebSocket disconnects,
|
||||||
)
|
# this loop exits but the background task continues.
|
||||||
await _conversation_store.add_message(conv.id, "assistant", response_text)
|
# P1 #7 fix: bound the subscribe loop with a timeout so a
|
||||||
|
# hung background task cannot block the WebSocket forever.
|
||||||
outcome = "success" if response_text != EMPTY_LLM_RESPONSE else "failure"
|
# Matches the resume path's timeout strategy.
|
||||||
|
_subscribe_timeout = _WS_HEARTBEAT_TIMEOUT * 10 if _WS_HEARTBEAT_TIMEOUT > 0 else 600
|
||||||
# Emit task.completed (success) or task.failed (empty response) to EQ
|
try:
|
||||||
if outcome == "success":
|
async with asyncio.timeout(_subscribe_timeout):
|
||||||
await _emit_event_safe(
|
async for event in event_queue.subscribe(task_id=task_id):
|
||||||
event_queue,
|
if event.event_type == TaskEventType.TASK_COMPLETED:
|
||||||
TaskEventType.TASK_COMPLETED,
|
response_text = event.data.get("output", EMPTY_LLM_RESPONSE)
|
||||||
task_id=task_id,
|
await websocket.send_json(
|
||||||
session_id=conv.id,
|
{
|
||||||
data={"output": response_text},
|
"type": "result",
|
||||||
|
"data": {
|
||||||
|
"message": response_text,
|
||||||
|
"timestamp": event.data.get(
|
||||||
|
"timestamp",
|
||||||
|
datetime.now(timezone.utc).isoformat(),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await _record_experience(
|
||||||
|
routing_result.skill_name or "agent",
|
||||||
|
message_text,
|
||||||
|
"success" if response_text != EMPTY_LLM_RESPONSE else "failure",
|
||||||
|
(datetime.now(timezone.utc) - start_time).total_seconds(),
|
||||||
|
)
|
||||||
|
break
|
||||||
|
elif event.event_type == TaskEventType.TASK_FAILED:
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"data": {
|
||||||
|
"message": event.data.get("error", "Unknown error"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await _record_experience(
|
||||||
|
routing_result.skill_name or "agent",
|
||||||
|
message_text,
|
||||||
|
"failure",
|
||||||
|
(datetime.now(timezone.utc) - start_time).total_seconds(),
|
||||||
|
)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Forward ReAct events as step messages.
|
||||||
|
# P1 #8/#10 fix: step and data are now top-level
|
||||||
|
# fields in event.data (no longer nested).
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "step",
|
||||||
|
"data": {
|
||||||
|
"event_type": event.event_type,
|
||||||
|
"step": event.data.get("step", 0),
|
||||||
|
"data": {
|
||||||
|
k: v
|
||||||
|
for k, v in event.data.items()
|
||||||
|
if k not in ("step", "timestamp")
|
||||||
|
},
|
||||||
|
"timestamp": event.data.get("timestamp", ""),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except TimeoutError:
|
||||||
|
logger.warning(f"Subscribe loop timed out for task {task_id}")
|
||||||
|
if active_bg_task is not None and not active_bg_task.done():
|
||||||
|
active_bg_task.cancel()
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"data": {
|
||||||
|
"message": "Task timed out. Please retry your request.",
|
||||||
|
"task_id": task_id,
|
||||||
|
},
|
||||||
|
}
|
||||||
)
|
)
|
||||||
else:
|
except RuntimeError as exc:
|
||||||
await _emit_event_safe(
|
# P1 #5: subscriber limit reached or EQ closed — send
|
||||||
event_queue,
|
# a friendly error instead of terminating the connection.
|
||||||
TaskEventType.TASK_FAILED,
|
logger.warning("Subscribe failed for task %s: %s", task_id, exc)
|
||||||
task_id=task_id,
|
await websocket.send_json(
|
||||||
session_id=conv.id,
|
{
|
||||||
data={"error": "Empty LLM response"},
|
"type": "error",
|
||||||
|
"data": {
|
||||||
|
"message": "Server busy, please retry shortly.",
|
||||||
|
"task_id": task_id,
|
||||||
|
},
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
await websocket.send_json(
|
|
||||||
{
|
|
||||||
"type": "result",
|
|
||||||
"data": {
|
|
||||||
"message": response_text,
|
|
||||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
await _record_experience(
|
|
||||||
routing_result.skill_name or "agent",
|
|
||||||
message_text,
|
|
||||||
outcome,
|
|
||||||
(datetime.now(timezone.utc) - start_time).total_seconds(),
|
|
||||||
)
|
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
logger.debug(f"Portal WebSocket disconnected for conversation {conv.id if conv else 'N/A'}")
|
logger.debug(f"Portal WebSocket disconnected for conversation {conv.id if conv else 'N/A'}")
|
||||||
|
# P0 fix: Do NOT cancel the background task on disconnect.
|
||||||
|
# The entire purpose of the three-layer defense is to let the
|
||||||
|
# background task continue running and persist the result so the
|
||||||
|
# frontend can resume it after reconnection. Cancelling here would
|
||||||
|
# kill the task, lose the full output, and mark it FAILED —
|
||||||
|
# defeating layers 2 and 3. The task is only cancelled on explicit
|
||||||
|
# user cancel (msg_type == 'cancel') or application shutdown.
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Portal WebSocket error: {e}")
|
logger.error(f"Portal WebSocket error: {e}")
|
||||||
|
# P1 #6 fix: Do NOT cancel the background task on connection-level
|
||||||
|
# errors (ConnectionResetError, BrokenPipeError, etc.). These are
|
||||||
|
# functionally equivalent to WebSocketDisconnect — the client dropped
|
||||||
|
# — and the background task must survive to persist its result.
|
||||||
|
# Only cancel on truly unexpected errors that may have corrupted
|
||||||
|
# state needed by the background task.
|
||||||
|
if not isinstance(e, (ConnectionResetError, BrokenPipeError, ConnectionError)):
|
||||||
|
if active_bg_task is not None and not active_bg_task.done():
|
||||||
|
active_bg_task.cancel()
|
||||||
# Emit task.failed to EQ if a task was in progress
|
# Emit task.failed to EQ if a task was in progress
|
||||||
# (task_id is set when a user message is received; None before that)
|
# (task_id is set when a user message is received; None before that)
|
||||||
if task_id is not None and conv is not None:
|
if task_id is not None and conv is not None:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,938 @@
|
||||||
|
"""Tests for WebSocket task persistence and background execution (U1-U3).
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- U1: Partial output persistence on WebSocket disconnect
|
||||||
|
- U2: Background ReAct task execution with EventQueue event distribution
|
||||||
|
- U3: TaskStore registration and status tracking
|
||||||
|
- P0 #2: CancelledError path — partial output persisted, task marked FAILED
|
||||||
|
- P0 #3: Resume handler — conversation_id ownership verification
|
||||||
|
- P0 #4: Cancel propagation — explicit cancel marks task FAILED
|
||||||
|
- P0 #5: WebSocketDisconnect does NOT cancel background task
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from agentkit.core.event_queue import EventQueue
|
||||||
|
from agentkit.core.protocol import Event, TaskEventType, TaskStatus, TurnEventType
|
||||||
|
from agentkit.server.routes.portal import _execute_react_background
|
||||||
|
from agentkit.server.task_store import InMemoryTaskStore
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fakes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class FakeConversationStore:
    """Minimal in-memory conversation store for testing.

    Records every ``add_message`` call so tests can assert on exactly what
    was persisted, and in which order.
    """

    def __init__(self) -> None:
        # One (conv_id, role, content) tuple per add_message call, in call order.
        self.messages: list[tuple[str, str, str]] = []

    async def add_message(self, conv_id: str, role: str, content: str) -> None:
        """Record the message as a ``(conv_id, role, content)`` tuple."""
        self.messages.append((conv_id, role, content))
|
||||||
|
|
||||||
|
|
||||||
|
class FakeReactEngine:
    """Fake ReAct engine that yields events from a predefined list."""

    def __init__(self, events: list[Event]) -> None:
        self._events = events

    async def execute_stream(self, **kwargs):
        """Yield the predefined events in order.

        All keyword arguments are accepted and ignored so the fake can be
        called with whatever arguments the code under test passes.
        """
        for event in self._events:
            yield event
|
||||||
|
|
||||||
|
|
||||||
|
class FailingReactEngine:
    """Fake ReAct engine that raises an exception after yielding some events."""

    def __init__(self, events: list[Event], error: Exception) -> None:
        self._events = events
        self._error = error

    async def execute_stream(self, **kwargs):
        """Yield every predefined event, then raise the configured error.

        With an empty event list the error is raised on the first iteration.
        Keyword arguments are accepted and ignored.
        """
        for event in self._events:
            yield event
        raise self._error
|
||||||
|
|
||||||
|
|
||||||
|
def _make_event(
    event_type: str,
    task_id: str = "test-task",
    session_id: str = "test-conv",
    data: dict | None = None,
) -> Event:
    """Build an ``Event`` via ``Event.create`` with test-friendly defaults.

    Args:
        event_type: Event type string passed through unchanged.
        task_id: Task identifier; defaults to ``"test-task"``.
        session_id: Session identifier; defaults to ``"test-conv"``.
        data: Event payload; ``None`` (or any falsy value) becomes a new
            empty dict.
    """
    return Event.create(
        event_type=event_type,
        task_id=task_id,
        session_id=session_id,
        data=data or {},
    )
|
||||||
|
|
||||||
|
|
||||||
|
class SlowFakeReactEngine:
    """Fake ReAct engine with a delay to allow status checks during execution."""

    def __init__(self, events: list[Event], delay: float = 0.1) -> None:
        self._events = events
        # Seconds to sleep before yielding each event.
        self._delay = delay

    async def execute_stream(self, **kwargs):
        """Yield the predefined events, sleeping ``delay`` seconds before each one.

        Keyword arguments are accepted and ignored.
        """
        for event in self._events:
            await asyncio.sleep(self._delay)
            yield event
|
||||||
|
|
||||||
|
|
||||||
|
class CancellableReactEngine:
    """Fake ReAct engine that blocks forever until cancelled.

    Yields one event so collected_output is non-empty, then blocks on an
    Event so the test can cancel the task and verify CancelledError cleanup.

    ``started`` is set when the consumer resumes the stream after the first
    event, i.e. immediately before the engine begins blocking.
    """

    def __init__(self, first_event: Event) -> None:
        self._first_event = first_event
        self.started = asyncio.Event()

    async def execute_stream(self, **kwargs):
        """Yield the first event, signal ``started``, then wait indefinitely.

        Keyword arguments are accepted and ignored.
        """
        yield self._first_event
        self.started.set()
        # Block forever until cancelled: this fresh Event is never set.
        await asyncio.Event().wait()
|
||||||
|
|
||||||
|
|
||||||
|
def _suppress_cancelled():
    """Return a context manager that suppresses asyncio.CancelledError."""
    # Imported locally so the helper does not depend on a module-level import.
    import contextlib

    return contextlib.suppress(asyncio.CancelledError)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# U1 + U2: Background task persistence tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecuteReactBackground:
    """Tests for _execute_react_background (U1 + U2)."""

    @staticmethod
    async def _run_background(
        engine,
        conv_store: FakeConversationStore,
        eq: EventQueue,
    ) -> None:
        """Run ``_execute_react_background`` with the fixed arguments every test here uses.

        Only the engine, conversation store and event queue vary between
        tests; the conversation id is always ``"test-conv"`` and the task id
        is always ``"test-task"``.
        """
        await _execute_react_background(
            react_engine=engine,
            messages=[],
            tools=[],
            model="test-model",
            agent_name="test-agent",
            system_prompt=None,
            timeout_seconds=None,
            conv_id="test-conv",
            task_id="test-task",
            event_queue=eq,
            conversation_store=conv_store,
        )

    @staticmethod
    async def _collect_until(
        eq: EventQueue,
        terminal_type: str,
        received: list[Event],
    ) -> None:
        """Append events for ``"test-task"`` to ``received`` until ``terminal_type`` arrives.

        The terminal event itself is appended before the loop stops.
        """
        async for evt in eq.subscribe(task_id="test-task"):
            received.append(evt)
            if evt.event_type == terminal_type:
                break

    async def test_normal_completion_persists_result(self):
        """U2: Normal completion persists result to conversation store."""
        engine = FakeReactEngine(
            [
                _make_event("thinking", data={"text": "Let me think..."}),
                _make_event("final_answer", data={"output": "The answer is 42"}),
            ]
        )
        conv_store = FakeConversationStore()
        eq = EventQueue()

        await self._run_background(engine, conv_store, eq)

        # Result should be persisted
        assert len(conv_store.messages) == 1
        conv_id, role, content = conv_store.messages[0]
        assert conv_id == "test-conv"
        assert role == "assistant"
        assert content == "The answer is 42"

    async def test_partial_output_persisted_on_error(self):
        """U1: Partial output is persisted when ReAct engine raises an error."""
        engine = FailingReactEngine(
            [
                _make_event("thinking", data={"text": "Thinking..."}),
                _make_event("final_answer", data={"output": "Partial result"}),
            ],
            RuntimeError("LLM timeout"),
        )
        conv_store = FakeConversationStore()
        eq = EventQueue()

        await self._run_background(engine, conv_store, eq)

        # Partial output should be persisted
        assert len(conv_store.messages) == 1
        _, role, content = conv_store.messages[0]
        assert role == "assistant"
        assert content == "Partial result"

    async def test_no_output_on_error_without_final_answer(self):
        """U1: No message persisted when error occurs before any final_answer."""
        engine = FailingReactEngine(
            [_make_event("thinking", data={"text": "Thinking..."})],
            RuntimeError("Early failure"),
        )
        conv_store = FakeConversationStore()
        eq = EventQueue()

        await self._run_background(engine, conv_store, eq)

        # No assistant message should be persisted (collected_output is empty)
        assert len(conv_store.messages) == 0

    async def test_events_emitted_to_event_queue(self):
        """U2: Events are emitted to EventQueue for subscribers."""
        engine = FakeReactEngine(
            [
                _make_event("thinking", data={"text": "Thinking..."}),
                _make_event("final_answer", data={"output": "Done"}),
            ]
        )
        conv_store = FakeConversationStore()
        eq = EventQueue()
        received: list[Event] = []

        sub_task = asyncio.create_task(
            self._collect_until(eq, TaskEventType.TASK_COMPLETED, received)
        )
        # Yield to the event loop so the subscriber task starts before events are emitted.
        await asyncio.sleep(0.05)

        await self._run_background(engine, conv_store, eq)

        # Fail the test instead of hanging if the terminal event never arrives.
        await asyncio.wait_for(sub_task, timeout=2.0)

        # Should receive thinking, final_answer, and task.completed events
        # P1 #9: ReAct event types are mapped to TurnEventType constants
        event_types = [e.event_type for e in received]
        assert TurnEventType.THINKING in event_types
        assert TurnEventType.FINAL_ANSWER in event_types
        assert TaskEventType.TASK_COMPLETED in event_types

    async def test_task_failed_event_on_error(self):
        """U2: task.failed event is emitted on error."""
        engine = FailingReactEngine([], RuntimeError("Execution failed"))
        conv_store = FakeConversationStore()
        eq = EventQueue()
        received: list[Event] = []

        sub_task = asyncio.create_task(
            self._collect_until(eq, TaskEventType.TASK_FAILED, received)
        )
        # Yield to the event loop so the subscriber task starts before events are emitted.
        await asyncio.sleep(0.05)

        await self._run_background(engine, conv_store, eq)

        # Fail the test instead of hanging if the terminal event never arrives.
        await asyncio.wait_for(sub_task, timeout=2.0)

        failed_events = [e for e in received if e.event_type == TaskEventType.TASK_FAILED]
        assert len(failed_events) == 1
        assert "Execution failed" in failed_events[0].data.get("error", "")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# U3: TaskStore integration tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestTaskStoreIntegration:
    """U3 coverage: TaskStore registration and status tracking."""

    @staticmethod
    def _background_kwargs(engine, conv_store, eq, task_store):
        """Build the keyword arguments shared by the background-run tests."""
        return {
            "react_engine": engine,
            "messages": [],
            "tools": [],
            "model": "test-model",
            "agent_name": "test-agent",
            "system_prompt": None,
            "timeout_seconds": None,
            "conv_id": "test-conv",
            "task_id": "test-task",
            "event_queue": eq,
            "conversation_store": conv_store,
            "task_store": task_store,
        }

    async def test_task_store_status_running_during_execution(self):
        """U3: the record reads RUNNING mid-flight and COMPLETED afterwards."""
        scripted = [
            _make_event("thinking", data={"text": "Thinking..."}),
            _make_event("final_answer", data={"output": "Result"}),
        ]
        slow_engine = SlowFakeReactEngine(scripted, delay=0.2)
        conv_store = FakeConversationStore()
        eq = EventQueue()
        store = InMemoryTaskStore()

        store.create(
            task_id="test-task",
            agent_name="test-agent",
            input_data={"message": "hello"},
        )

        # Run in the background so the store can be inspected mid-execution.
        runner = asyncio.create_task(
            _execute_react_background(
                **self._background_kwargs(slow_engine, conv_store, eq, store)
            )
        )

        # Hand control to the loop so the background task can start.
        await asyncio.sleep(0.01)
        in_flight = store.get("test-task")
        assert in_flight is not None
        assert in_flight.status == TaskStatus.RUNNING

        await asyncio.wait_for(runner, timeout=2.0)

        # Once the run has finished the record must be terminal and complete.
        finished = store.get("test-task")
        assert finished is not None
        assert finished.status == TaskStatus.COMPLETED
        assert finished.output_data is not None
        assert finished.output_data.get("output") == "Result"
        assert finished.progress == 1.0

    async def test_task_store_status_failed_on_error(self):
        """U3: an engine error leaves the record FAILED with the error text."""
        no_events: list[Event] = []
        boom = RuntimeError("Execution failed")
        failing_engine = FailingReactEngine(no_events, boom)
        conv_store = FakeConversationStore()
        eq = EventQueue()
        store = InMemoryTaskStore()

        store.create(
            task_id="test-task",
            agent_name="test-agent",
            input_data={"message": "hello"},
        )

        await _execute_react_background(
            **self._background_kwargs(failing_engine, conv_store, eq, store)
        )

        outcome = store.get("test-task")
        assert outcome is not None
        assert outcome.status == TaskStatus.FAILED
        assert outcome.error_message is not None
        assert "Execution failed" in outcome.error_message

    async def test_task_store_none_does_not_crash(self):
        """U3: running with ``task_store=None`` completes and still persists."""
        engine = FakeReactEngine(
            [_make_event("final_answer", data={"output": "Result"})]
        )
        conv_store = FakeConversationStore()
        eq = EventQueue()

        # The call itself is the check: it must return without raising.
        await _execute_react_background(
            **self._background_kwargs(engine, conv_store, eq, None)
        )

        assert len(conv_store.messages) == 1

    async def test_task_store_list_by_status(self):
        """U3: ``list_tasks(status=...)`` returns only tasks in that status."""
        store = InMemoryTaskStore()

        # Three tasks; the first is left untouched after creation.
        for name in ("task-1", "task-2", "task-3"):
            store.create(name, "agent", {})

        store.update_status("task-2", TaskStatus.RUNNING)
        store.update_status("task-3", TaskStatus.COMPLETED, progress=1.0)

        wanted = [
            (TaskStatus.RUNNING, "task-2"),
            (TaskStatus.COMPLETED, "task-3"),
            (TaskStatus.PENDING, "task-1"),
        ]
        # Query every status first, then verify each listing in turn.
        listings = [store.list_tasks(status=state) for state, _ in wanted]

        for (_, expected_id), found in zip(wanted, listings):
            assert len(found) == 1
            assert found[0].task_id == expected_id

    async def test_task_store_metadata_contains_conversation_id(self):
        """U3: metadata given to ``update_status`` is readable from the record."""
        store = InMemoryTaskStore()
        store.create("task-1", "agent", {"message": "hello"})

        owner = {"conversation_id": "conv-123"}
        store.update_status("task-1", TaskStatus.PENDING, metadata=owner)

        stored = store.get("task-1")
        assert stored is not None
        assert stored.metadata.get("conversation_id") == "conv-123"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# EventQueue task_id filtering tests (U2)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEventQueueTaskIdFilter:
    """Checks on ``EventQueue.subscribe(task_id=...)`` filtering."""

    # (event kind, task id) triples emitted by every test in this class.
    _SCRIPT = (
        ("thinking", "task-A"),
        ("thinking", "task-B"),
        ("final_answer", "task-A"),
    )

    @staticmethod
    async def _collect(eq, sink, limit, **filters):
        """Append subscribed events to *sink* until it holds *limit* items.

        ``subscribe`` is invoked here, inside the consuming task, with
        whatever keyword filters the caller supplied (possibly none).
        """
        async for evt in eq.subscribe(**filters):
            sink.append(evt)
            if len(sink) >= limit:
                break

    async def test_subscribe_with_task_id_filter(self):
        """U2: a filtered subscriber only sees events for its own task."""
        eq = EventQueue()
        seen: list[Event] = []

        consumer = asyncio.create_task(
            self._collect(eq, seen, 2, task_id="task-A")
        )
        await asyncio.sleep(0.05)

        # The middle event belongs to task-B and must not be delivered.
        for kind, owner in self._SCRIPT:
            await eq.emit(_make_event(kind, task_id=owner))

        await asyncio.wait_for(consumer, timeout=2.0)

        assert len(seen) == 2
        assert all(item.task_id == "task-A" for item in seen)

    async def test_subscribe_without_filter_receives_all(self):
        """U2: an unfiltered subscriber sees every event (backward compat)."""
        eq = EventQueue()
        seen: list[Event] = []

        consumer = asyncio.create_task(self._collect(eq, seen, 3))
        await asyncio.sleep(0.05)

        for kind, owner in self._SCRIPT:
            await eq.emit(_make_event(kind, task_id=owner))

        await asyncio.wait_for(consumer, timeout=2.0)

        # Nothing is dropped, whichever task the event belongs to.
        assert len(seen) == 3

    async def test_subscribe_replays_buffer_filtered(self):
        """U2: events emitted before subscribing are replayed, filtered."""
        eq = EventQueue()

        # Everything is emitted up front, before any subscriber exists.
        for kind, owner in self._SCRIPT:
            await eq.emit(_make_event(kind, task_id=owner))

        seen: list[Event] = []
        consumer = asyncio.create_task(
            self._collect(eq, seen, 2, task_id="task-A")
        )
        await asyncio.wait_for(consumer, timeout=2.0)

        # Only the two task-A events come back.
        assert len(seen) == 2
        assert all(item.task_id == "task-A" for item in seen)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# P0 #2: CancelledError path tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCancelledErrorPath:
    """P0 #2: cancelling a run persists partial output, marks the task
    FAILED, and emits a TASK_FAILED event."""

    @staticmethod
    def _spawn(engine, eq, conv_store, **extra):
        """Start ``_execute_react_background`` as a task with the fixed test
        arguments; *extra* carries optional keywords such as ``task_store``."""
        return asyncio.create_task(
            _execute_react_background(
                react_engine=engine,
                messages=[],
                tools=[],
                model="test-model",
                agent_name="test-agent",
                system_prompt=None,
                timeout_seconds=None,
                conv_id="test-conv",
                task_id="test-task",
                event_queue=eq,
                conversation_store=conv_store,
                **extra,
            )
        )

    async def _cancel_once_started(self, engine, runner):
        """Wait for ``engine.started``, cancel *runner*, and await it."""
        # engine.started is presumably set once the first event was yielded.
        await asyncio.wait_for(engine.started.wait(), timeout=2.0)
        runner.cancel()
        with self._expect_cancelled():
            await runner

    async def test_cancel_persists_partial_output(self):
        """P0 #2: output gathered before a mid-run cancel is written to the
        conversation store as an assistant message."""
        engine = CancellableReactEngine(
            _make_event("final_answer", data={"output": "Partial before cancel"})
        )
        conv_store = FakeConversationStore()
        eq = EventQueue()
        store = InMemoryTaskStore()
        store.create("test-task", "test-agent", {"message": "hello"})

        runner = self._spawn(engine, eq, conv_store, task_store=store)
        await self._cancel_once_started(engine, runner)

        # Exactly one message, holding the partial answer.
        assert len(conv_store.messages) == 1
        _, speaker, text = conv_store.messages[0]
        assert speaker == "assistant"
        assert text == "Partial before cancel"

    async def test_cancel_marks_task_failed_in_store(self):
        """P0 #2: a cancelled run is recorded as FAILED in the TaskStore."""
        engine = CancellableReactEngine(
            _make_event("final_answer", data={"output": "Partial"})
        )
        conv_store = FakeConversationStore()
        eq = EventQueue()
        store = InMemoryTaskStore()
        store.create("test-task", "test-agent", {"message": "hello"})

        runner = self._spawn(engine, eq, conv_store, task_store=store)
        await self._cancel_once_started(engine, runner)

        outcome = store.get("test-task")
        assert outcome is not None
        assert outcome.status == TaskStatus.FAILED
        assert outcome.error_message is not None
        assert "cancelled" in outcome.error_message.lower()

    async def test_cancel_emits_task_failed_event(self):
        """P0 #2: a cancelled run publishes TASK_FAILED on the EventQueue."""
        engine = CancellableReactEngine(
            _make_event("final_answer", data={"output": "Partial"})
        )
        conv_store = FakeConversationStore()
        eq = EventQueue()

        captured: list[Event] = []

        async def collect_until_failed():
            async for incoming in eq.subscribe(task_id="test-task"):
                captured.append(incoming)
                if incoming.event_type == TaskEventType.TASK_FAILED:
                    return

        listener = asyncio.create_task(collect_until_failed())
        await asyncio.sleep(0.05)

        # No task_store is supplied for this run.
        runner = self._spawn(engine, eq, conv_store)
        await self._cancel_once_started(engine, runner)

        await asyncio.wait_for(listener, timeout=2.0)

        failures = [
            item for item in captured
            if item.event_type == TaskEventType.TASK_FAILED
        ]
        assert len(failures) == 1
        assert "cancelled" in failures[0].data.get("error", "").lower()

    @staticmethod
    def _expect_cancelled():
        """Return a context manager that swallows ``asyncio.CancelledError``.

        Built on ``contextlib.suppress``: a raised CancelledError is absorbed,
        and the block also exits normally when none is raised.
        """
        from contextlib import suppress

        return suppress(asyncio.CancelledError)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# P0 #3: Resume handler conversation_id ownership verification tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestResumeOwnershipVerification:
    """P0 #3: resume must be refused for a task owned by another conversation.

    The WebSocket resume handler reads ``conversation_id`` from the task's
    metadata to verify ownership; these tests exercise that metadata lookup
    against the TaskStore directly and re-state the comparison locally.
    """

    @staticmethod
    def _owner_of(task_id, conversation_id=None):
        """Create *task_id* in a fresh store and return its stored owner.

        When *conversation_id* is given, the task is moved to RUNNING with
        that id in its metadata; otherwise metadata is left as created.
        The lookup mirrors the handler: missing metadata or key yields "".
        """
        store = InMemoryTaskStore()
        store.create(task_id, "agent", {"message": "hello"})
        if conversation_id is not None:
            store.update_status(
                task_id,
                TaskStatus.RUNNING,
                metadata={"conversation_id": conversation_id},
            )

        record = store.get(task_id)
        assert record is not None
        return (record.metadata or {}).get("conversation_id", "")

    async def test_resume_rejects_mismatched_conversation_id(self):
        """P0 #3: a task owned by conv-A is not resumable from conv-B.

        Mirrors the check in the portal.py resume handler: when the stored
        ``conversation_id`` differs from the requested one, the handler
        answers with an error and skips the subscription.
        """
        stored_owner = self._owner_of("task-X", "conv-A")
        requester = "conv-B"  # a different conversation

        assert stored_owner != requester

    async def test_resume_accepts_matching_conversation_id(self):
        """P0 #3: a task owned by conv-A is resumable from conv-A."""
        stored_owner = self._owner_of("task-Y", "conv-A")
        requester = "conv-A"  # the owning conversation

        # A match is what lets the handler go on to subscribe.
        assert stored_owner == requester

    async def test_resume_rejects_when_metadata_missing(self):
        """P1 #3: a task without a stored conversation_id is rejected.

        Fail-closed: this case used to be allowed for backward compatibility,
        but the security review flagged it as a bypass -- omitting the
        conversation_id would let a caller subscribe to any task's events.
        """
        # No status update, so the metadata carries no conversation_id.
        stored_owner = self._owner_of("task-Z")
        requester = "conv-A"

        # Resume is permitted only for a present AND matching owner.
        permitted = stored_owner and stored_owner == requester
        assert not permitted
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# P0 #4: Cancel propagation tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCancelPropagation:
    """P0 #4: an explicit cancel (msg_type == 'cancel') reaches the
    background task and marks it FAILED."""

    async def test_explicit_cancel_marks_task_failed(self):
        """P0 #4: cancelling a running background task marks it FAILED in
        the TaskStore and persists the partial output.

        The task is created the way portal.py would create it, then
        cancelled directly to simulate the msg_type == 'cancel' handler.
        """
        first_event = _make_event("final_answer", data={"output": "Partial before user cancel"})
        engine = CancellableReactEngine(first_event)
        conv_store = FakeConversationStore()
        eq = EventQueue()
        task_store = InMemoryTaskStore()
        task_store.create("cancel-task", "agent", {"message": "hello"})

        # Simulate the background task as portal.py would create it
        active_bg_task: asyncio.Task | None = asyncio.create_task(
            _execute_react_background(
                react_engine=engine,
                messages=[],
                tools=[],
                model="test-model",
                agent_name="test-agent",
                system_prompt=None,
                timeout_seconds=None,
                conv_id="cancel-conv",
                task_id="cancel-task",
                event_queue=eq,
                conversation_store=conv_store,
                task_store=task_store,
            )
        )

        # Only cancel once the engine has signalled that it is running.
        await asyncio.wait_for(engine.started.wait(), timeout=2.0)

        # Simulate the cancel handler: active_bg_task.cancel()
        assert active_bg_task is not None
        assert not active_bg_task.done()
        active_bg_task.cancel()

        with _suppress_cancelled():
            await active_bg_task

        # Verify task is marked FAILED
        record = task_store.get("cancel-task")
        assert record is not None
        assert record.status == TaskStatus.FAILED

        # Verify partial output was persisted
        assert len(conv_store.messages) == 1
        _, _, content = conv_store.messages[0]
        assert content == "Partial before user cancel"

    async def test_cancel_after_completion_is_noop(self):
        """P0 #4: cancelling an already-completed task changes nothing.

        portal.py guards its cancel with
        ``if active_bg_task is not None and not active_bg_task.done():``.
        This test checks the guard condition and also issues the cancel
        itself, so the no-op claim is actually exercised.
        """
        events = [_make_event("final_answer", data={"output": "Done"})]
        engine = FakeReactEngine(events)
        conv_store = FakeConversationStore()
        eq = EventQueue()

        bg_task = asyncio.create_task(
            _execute_react_background(
                react_engine=engine,
                messages=[],
                tools=[],
                model="test-model",
                agent_name="test-agent",
                system_prompt=None,
                timeout_seconds=None,
                conv_id="test-conv",
                task_id="test-task",
                event_queue=eq,
                conversation_store=conv_store,
            )
        )
        await bg_task

        # Already done, so the portal.py guard would skip the cancel.
        assert bg_task.done()

        # Cancelling a finished task reports that nothing was cancelled
        # and must not flip the task into the cancelled state.
        assert bg_task.cancel() is False
        assert not bg_task.cancelled()

        # No exception, no state change: the single persisted message remains.
        assert len(conv_store.messages) == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# P0 #5: WebSocketDisconnect does NOT cancel background task
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestWebSocketDisconnectNoCancel:
    """P0 #5: a WebSocketDisconnect must NOT cancel the background task --
    the core invariant of the three-layer defense.

    The portal.py flow is simulated: a background task is started and the
    "disconnect" is modelled by simply not subscribing to (and never
    cancelling) that task. It has to run to completion and persist its
    result on its own.
    """

    @staticmethod
    def _start(engine, eq, conv_store, conv_id, task_id, **extra):
        """Launch ``_execute_react_background`` as a task, as portal.py would.

        *extra* carries optional keywords such as ``task_store``.
        """
        return asyncio.create_task(
            _execute_react_background(
                react_engine=engine,
                messages=[],
                tools=[],
                model="test-model",
                agent_name="test-agent",
                system_prompt=None,
                timeout_seconds=None,
                conv_id=conv_id,
                task_id=task_id,
                event_queue=eq,
                conversation_store=conv_store,
                **extra,
            )
        )

    async def test_disconnect_does_not_cancel_background_task(self):
        """P0 #5: with no client attached, the run still finishes and its
        final answer lands in the conversation store."""
        scripted = [
            _make_event("thinking", data={"text": "Thinking..."}),
            _make_event("final_answer", data={"output": "Final result"}),
        ]
        slow_engine = SlowFakeReactEngine(scripted, delay=0.2)
        conv_store = FakeConversationStore()
        eq = EventQueue()

        runner = self._start(slow_engine, eq, conv_store, "test-conv", "test-task")

        # "Disconnect": nobody subscribes and nothing cancels the runner
        # (per the P0 #1 fix). Just give it time to get going.
        await asyncio.sleep(0.1)

        # Still in flight, i.e. the disconnect did not cancel it.
        assert not runner.done()

        # Let it finish on its own.
        await asyncio.wait_for(runner, timeout=5.0)

        # The answer was persisted even though no client was listening.
        assert len(conv_store.messages) == 1
        _, _, text = conv_store.messages[0]
        assert text == "Final result"

    async def test_disconnect_result_available_for_resume(self):
        """P0 #5: after a disconnect the finished task's output is held in
        the TaskStore, where a reconnecting client can fetch it via resume."""
        engine = FakeReactEngine(
            [_make_event("final_answer", data={"output": "Resumable result"})]
        )
        conv_store = FakeConversationStore()
        eq = EventQueue()
        store = InMemoryTaskStore()
        store.create("resume-task", "agent", {"message": "hello"})
        store.update_status(
            "resume-task",
            TaskStatus.RUNNING,
            metadata={"conversation_id": "resume-conv"},
        )

        runner = self._start(
            engine, eq, conv_store, "resume-conv", "resume-task",
            task_store=store,
        )

        # "Disconnect": the runner is left alone until it completes.
        await asyncio.wait_for(runner, timeout=5.0)

        # COMPLETED, with the output ready to be served on resume.
        outcome = store.get("resume-task")
        assert outcome is not None
        assert outcome.status == TaskStatus.COMPLETED
        assert outcome.output_data is not None
        assert outcome.output_data.get("output") == "Resumable result"
|
||||||
Loading…
Reference in New Issue