"""Tests for WebSocket task persistence and background execution (U1-U3).

Tests cover:
- U1: Partial output persistence on WebSocket disconnect
- U2: Background ReAct task execution with EventQueue event distribution
- U3: TaskStore registration and status tracking
- P0 #2: CancelledError path — partial output persisted, task marked FAILED
- P0 #3: Resume handler — conversation_id ownership verification
- P0 #4: Cancel propagation — explicit cancel marks task FAILED
- P0 #5: WebSocketDisconnect does NOT cancel background task
"""

from __future__ import annotations

import asyncio

from agentkit.core.event_queue import EventQueue
from agentkit.core.protocol import Event, TaskEventType, TaskStatus, TurnEventType
from agentkit.server.routes.portal import _execute_react_background
from agentkit.server.task_store import InMemoryTaskStore

# ---------------------------------------------------------------------------
# Fakes
# ---------------------------------------------------------------------------


class FakeConversationStore:
    """In-memory stand-in for the conversation store.

    Every ``add_message`` call is recorded as a ``(conv_id, role, content)``
    tuple in ``messages`` so tests can inspect what was persisted.
    """

    def __init__(self) -> None:
        self.messages: list[tuple[str, str, str]] = []

    async def add_message(self, conv_id: str, role: str, content: str) -> None:
        """Record one persisted message."""
        entry = (conv_id, role, content)
        self.messages.append(entry)


class FakeReactEngine:
    """ReAct engine double that replays a fixed list of events."""

    name = "test-agent"

    def __init__(self, events: list[Event]) -> None:
        self._events = events

    async def execute_stream(self, task):
        """Yield the scripted events in order; ``task`` is ignored."""
        for scripted in self._events:
            yield scripted


class FailingReactEngine:
    """ReAct engine double that replays its events and then raises."""

    name = "test-agent"

    def __init__(self, events: list[Event], error: Exception) -> None:
        self._events = events
        self._error = error

    async def execute_stream(self, task):
        """Yield the scripted events, then raise the configured error."""
        for scripted in self._events:
            yield scripted
        raise self._error


def _make_event(
    event_type: str,
    task_id: str = "test-task",
    session_id: str = "test-conv",
    data: dict | None = None,
) -> Event:
    """Build an ``Event`` via ``Event.create``; a falsy ``data`` becomes ``{}``."""
    payload = data or {}
    return Event.create(
        event_type=event_type,
        task_id=task_id,
        session_id=session_id,
        data=payload,
    )


class SlowFakeReactEngine:
    """ReAct engine double that sleeps before each event.

    The delay gives tests a window to observe state while the stream is
    still in progress.
    """

    name = "test-agent"

    def __init__(self, events: list[Event], delay: float = 0.1) -> None:
        self._events = events
        self._delay = delay

    async def execute_stream(self, task):
        """Sleep ``delay`` seconds, then yield, for every scripted event."""
        for scripted in self._events:
            await asyncio.sleep(self._delay)
            yield scripted


class CancellableReactEngine:
    """ReAct engine double that never finishes on its own.

    It yields a single event, sets ``started`` and then waits on a fresh
    ``asyncio.Event`` that nobody sets, so the stream only ends when the
    surrounding task is cancelled.
    """

    name = "test-agent"

    def __init__(self, first_event: Event) -> None:
        self._first_event = first_event
        self.started = asyncio.Event()

    async def execute_stream(self, task):
        """Yield the first event, signal ``started``, then block."""
        yield self._first_event
        self.started.set()
        # Nothing ever sets this event: only cancellation ends the wait.
        never_set = asyncio.Event()
        await never_set.wait()


def _suppress_cancelled():
    """Return a context manager that swallows ``asyncio.CancelledError``."""
    from contextlib import suppress

    return suppress(asyncio.CancelledError)


# ---------------------------------------------------------------------------
# U1 + U2: Background task persistence tests
# ---------------------------------------------------------------------------


class TestExecuteReactBackground:
    """Tests for _execute_react_background (U1 + U2)."""

    async def test_normal_completion_persists_result(self):
        """U2: Normal completion persists result to conversation store."""
        script = [
            _make_event("thinking", data={"text": "Let me think..."}),
            _make_event("final_answer", data={"output": "The answer is 42"}),
        ]
        agent = FakeReactEngine(script)
        store = FakeConversationStore()
        queue = EventQueue()

        await _execute_react_background(
            agent=agent,
            messages=[],
            system_prompt=None,
            timeout_seconds=None,
            conv_id="test-conv",
            task_id="test-task",
            event_queue=queue,
            conversation_store=store,
        )

        # Exactly one assistant message carrying the final answer
        assert store.messages == [("test-conv", "assistant", "The answer is 42")]

    async def test_partial_output_persisted_on_error(self):
        """U1: Partial output is persisted when ReAct engine raises an error."""
        script = [
            _make_event("thinking", data={"text": "Thinking..."}),
            _make_event("final_answer", data={"output": "Partial result"}),
        ]
        agent = FailingReactEngine(script, RuntimeError("LLM timeout"))
        store = FakeConversationStore()
        queue = EventQueue()

        await _execute_react_background(
            agent=agent,
            messages=[],
            system_prompt=None,
            timeout_seconds=None,
            conv_id="test-conv",
            task_id="test-task",
            event_queue=queue,
            conversation_store=store,
        )

        # The output collected before the failure is still written out
        assert len(store.messages) == 1
        assert store.messages[0][1:] == ("assistant", "Partial result")

    async def test_no_output_on_error_without_final_answer(self):
        """U1: No message persisted when error occurs before any final_answer."""
        script = [_make_event("thinking", data={"text": "Thinking..."})]
        agent = FailingReactEngine(script, RuntimeError("Early failure"))
        store = FakeConversationStore()
        queue = EventQueue()

        await _execute_react_background(
            agent=agent,
            messages=[],
            system_prompt=None,
            timeout_seconds=None,
            conv_id="test-conv",
            task_id="test-task",
            event_queue=queue,
            conversation_store=store,
        )

        # Nothing was collected, so nothing may be persisted
        assert not store.messages

    async def test_events_emitted_to_event_queue(self):
        """U2: Events are emitted to EventQueue for subscribers."""
        script = [
            _make_event("thinking", data={"text": "Thinking..."}),
            _make_event("final_answer", data={"output": "Done"}),
        ]
        agent = FakeReactEngine(script)
        store = FakeConversationStore()
        queue = EventQueue()
        seen: list[Event] = []

        async def collect_until_completed():
            async for event in queue.subscribe(task_id="test-task"):
                seen.append(event)
                if event.event_type == TaskEventType.TASK_COMPLETED:
                    break

        listener = asyncio.create_task(collect_until_completed())
        # Give the listener a chance to subscribe before events are emitted
        await asyncio.sleep(0.05)

        await _execute_react_background(
            agent=agent,
            messages=[],
            system_prompt=None,
            timeout_seconds=None,
            conv_id="test-conv",
            task_id="test-task",
            event_queue=queue,
            conversation_store=store,
        )
        await asyncio.wait_for(listener, timeout=2.0)

        # Should receive thinking, final_answer, and task.completed events
        # P1 #9: ReAct event types are mapped to TurnEventType constants
        seen_types = [event.event_type for event in seen]
        for expected_type in (
            TurnEventType.THINKING,
            TurnEventType.FINAL_ANSWER,
            TaskEventType.TASK_COMPLETED,
        ):
            assert expected_type in seen_types

    async def test_task_failed_event_on_error(self):
        """U2: task.failed event is emitted on error."""
        script: list[Event] = []
        agent = FailingReactEngine(script, RuntimeError("Execution failed"))
        store = FakeConversationStore()
        queue = EventQueue()
        seen: list[Event] = []

        async def collect_until_failed():
            async for event in queue.subscribe(task_id="test-task"):
                seen.append(event)
                if event.event_type == TaskEventType.TASK_FAILED:
                    break

        listener = asyncio.create_task(collect_until_failed())
        # Give the listener a chance to subscribe before events are emitted
        await asyncio.sleep(0.05)

        await _execute_react_background(
            agent=agent,
            messages=[],
            system_prompt=None,
            timeout_seconds=None,
            conv_id="test-conv",
            task_id="test-task",
            event_queue=queue,
            conversation_store=store,
        )
        await asyncio.wait_for(listener, timeout=2.0)

        failures = [
            event for event in seen if event.event_type == TaskEventType.TASK_FAILED
        ]
        assert len(failures) == 1
        assert "Execution failed" in failures[0].data.get("error", "")


# ---------------------------------------------------------------------------
# U3: TaskStore integration tests
# ---------------------------------------------------------------------------


class TestTaskStoreIntegration:
    """Tests for TaskStore registration and status tracking (U3)."""

    async def test_task_store_status_running_during_execution(self):
        """U3: TaskStore status is RUNNING during background execution."""
        script = [
            _make_event("thinking", data={"text": "Thinking..."}),
            _make_event("final_answer", data={"output": "Result"}),
        ]
        agent = SlowFakeReactEngine(script, delay=0.2)
        store = FakeConversationStore()
        queue = EventQueue()
        tasks = InMemoryTaskStore()
        tasks.create(
            task_id="test-task",
            agent_name="test-agent",
            input_data={"message": "hello"},
        )

        # Launch in the background so the status can be sampled mid-run
        runner = asyncio.create_task(
            _execute_react_background(
                agent=agent,
                messages=[],
                system_prompt=None,
                timeout_seconds=None,
                conv_id="test-conv",
                task_id="test-task",
                event_queue=queue,
                conversation_store=store,
                task_store=tasks,
            )
        )

        # Yield control briefly; the slow engine is still before its first event
        await asyncio.sleep(0.01)
        in_flight = tasks.get("test-task")
        assert in_flight is not None
        assert in_flight.status == TaskStatus.RUNNING

        await asyncio.wait_for(runner, timeout=2.0)

        # Once the runner returns, the record must be COMPLETED
        finished = tasks.get("test-task")
        assert finished is not None
        assert finished.status == TaskStatus.COMPLETED
        assert finished.output_data is not None
        assert finished.output_data.get("output") == "Result"
        assert finished.progress == 1.0

    async def test_task_store_status_failed_on_error(self):
        """U3: TaskStore status is FAILED when background task raises error."""
        script: list[Event] = []
        agent = FailingReactEngine(script, RuntimeError("Execution failed"))
        store = FakeConversationStore()
        queue = EventQueue()
        tasks = InMemoryTaskStore()
        tasks.create(
            task_id="test-task",
            agent_name="test-agent",
            input_data={"message": "hello"},
        )

        await _execute_react_background(
            agent=agent,
            messages=[],
            system_prompt=None,
            timeout_seconds=None,
            conv_id="test-conv",
            task_id="test-task",
            event_queue=queue,
            conversation_store=store,
            task_store=tasks,
        )

        failed = tasks.get("test-task")
        assert failed is not None
        assert failed.status == TaskStatus.FAILED
        assert failed.error_message is not None
        assert "Execution failed" in failed.error_message

    async def test_task_store_none_does_not_crash(self):
        """U3: Passing task_store=None should not crash."""
        script = [_make_event("final_answer", data={"output": "Result"})]
        agent = FakeReactEngine(script)
        store = FakeConversationStore()
        queue = EventQueue()

        # The call itself is the check: it must not raise
        await _execute_react_background(
            agent=agent,
            messages=[],
            system_prompt=None,
            timeout_seconds=None,
            conv_id="test-conv",
            task_id="test-task",
            event_queue=queue,
            conversation_store=store,
            task_store=None,
        )

        assert len(store.messages) == 1

    async def test_task_store_list_by_status(self):
        """U3: TaskStore list_tasks filters by status correctly."""
        tasks = InMemoryTaskStore()

        # Three tasks; task-1 is left in its initial state
        for new_id in ("task-1", "task-2", "task-3"):
            tasks.create(new_id, "agent", {})
        tasks.update_status("task-2", TaskStatus.RUNNING)
        tasks.update_status("task-3", TaskStatus.COMPLETED, progress=1.0)

        running = tasks.list_tasks(status=TaskStatus.RUNNING)
        completed = tasks.list_tasks(status=TaskStatus.COMPLETED)
        pending = tasks.list_tasks(status=TaskStatus.PENDING)

        assert len(running) == 1
        assert running[0].task_id == "task-2"
        assert len(completed) == 1
        assert completed[0].task_id == "task-3"
        assert len(pending) == 1
        assert pending[0].task_id == "task-1"

    async def test_task_store_metadata_contains_conversation_id(self):
        """U3: TaskStore metadata stores conversation_id for frontend recovery."""
        tasks = InMemoryTaskStore()
        tasks.create("task-1", "agent", {"message": "hello"})
        tasks.update_status(
            "task-1",
            TaskStatus.PENDING,
            metadata={"conversation_id": "conv-123"},
        )

        stored = tasks.get("task-1")
        assert stored is not None
        assert stored.metadata.get("conversation_id") == "conv-123"


# ---------------------------------------------------------------------------
# EventQueue task_id filtering tests (U2)
# ---------------------------------------------------------------------------
class TestEventQueueTaskIdFilter:
    """Tests for EventQueue subscribe(task_id=...) filtering."""

    async def test_subscribe_with_task_id_filter(self):
        """U2: subscribe(task_id=...) only receives matching events."""
        queue = EventQueue()
        captured: list[Event] = []

        async def consume():
            async for event in queue.subscribe(task_id="task-A"):
                captured.append(event)
                if len(captured) >= 2:
                    break

        consumer = asyncio.create_task(consume())
        # Let the consumer subscribe before anything is emitted
        await asyncio.sleep(0.05)

        # Interleave two tasks; the task-B event must be filtered out
        for kind, owner in (
            ("thinking", "task-A"),
            ("thinking", "task-B"),
            ("final_answer", "task-A"),
        ):
            await queue.emit(_make_event(kind, task_id=owner))

        await asyncio.wait_for(consumer, timeout=2.0)

        # Exactly two events, both belonging to task-A
        assert [event.task_id for event in captured] == ["task-A", "task-A"]

    async def test_subscribe_without_filter_receives_all(self):
        """U2: subscribe() without task_id receives all events (backward compat)."""
        queue = EventQueue()
        captured: list[Event] = []

        async def consume():
            async for event in queue.subscribe():
                captured.append(event)
                if len(captured) >= 3:
                    break

        consumer = asyncio.create_task(consume())
        # Let the consumer subscribe before anything is emitted
        await asyncio.sleep(0.05)

        for kind, owner in (
            ("thinking", "task-A"),
            ("thinking", "task-B"),
            ("final_answer", "task-A"),
        ):
            await queue.emit(_make_event(kind, task_id=owner))

        await asyncio.wait_for(consumer, timeout=2.0)

        # Every event arrives regardless of its task_id
        assert len(captured) == 3

    async def test_subscribe_replays_buffer_filtered(self):
        """U2: Buffer replay respects task_id filter."""
        queue = EventQueue()

        # Everything is emitted before anyone subscribes
        for kind, owner in (
            ("thinking", "task-A"),
            ("thinking", "task-B"),
            ("final_answer", "task-A"),
        ):
            await queue.emit(_make_event(kind, task_id=owner))

        captured: list[Event] = []

        async def consume():
            async for event in queue.subscribe(task_id="task-A"):
                captured.append(event)
                if len(captured) >= 2:
                    break

        consumer = asyncio.create_task(consume())
        await asyncio.wait_for(consumer, timeout=2.0)

        # Only the two task-A events come back from the buffer
        assert [event.task_id for event in captured] == ["task-A", "task-A"]


# ---------------------------------------------------------------------------
# P0 #2: CancelledError path tests
# ---------------------------------------------------------------------------


class TestCancelledErrorPath:
    """P0 #2: Verify CancelledError cleanup persists partial output and marks
    the task as FAILED, and TASK_FAILED event is emitted."""

    async def test_cancel_persists_partial_output(self):
        """P0 #2: When task is cancelled mid-execution, partial output collected
        before cancellation is persisted to conversation store."""
        opening = _make_event("final_answer", data={"output": "Partial before cancel"})
        agent = CancellableReactEngine(opening)
        store = FakeConversationStore()
        queue = EventQueue()
        tasks = InMemoryTaskStore()
        tasks.create("test-task", "test-agent", {"message": "hello"})

        runner = asyncio.create_task(
            _execute_react_background(
                agent=agent,
                messages=[],
                system_prompt=None,
                timeout_seconds=None,
                conv_id="test-conv",
                task_id="test-task",
                event_queue=queue,
                conversation_store=store,
                task_store=tasks,
            )
        )

        # Cancel only after the engine has produced its first event
        await asyncio.wait_for(agent.started.wait(), timeout=2.0)
        runner.cancel()
        with self._expect_cancelled():
            await runner

        # The output collected before the cancel is still written out
        assert len(store.messages) == 1
        assert store.messages[0][1:] == ("assistant", "Partial before cancel")

    async def test_cancel_marks_task_failed_in_store(self):
        """P0 #2: CancelledError marks task status as FAILED in TaskStore."""
        opening = _make_event("final_answer", data={"output": "Partial"})
        agent = CancellableReactEngine(opening)
        store = FakeConversationStore()
        queue = EventQueue()
        tasks = InMemoryTaskStore()
        tasks.create("test-task", "test-agent", {"message": "hello"})

        runner = asyncio.create_task(
            _execute_react_background(
                agent=agent,
                messages=[],
                system_prompt=None,
                timeout_seconds=None,
                conv_id="test-conv",
                task_id="test-task",
                event_queue=queue,
                conversation_store=store,
                task_store=tasks,
            )
        )

        await asyncio.wait_for(agent.started.wait(), timeout=2.0)
        runner.cancel()
        with self._expect_cancelled():
            await runner

        outcome = tasks.get("test-task")
        assert outcome is not None
        assert outcome.status == TaskStatus.FAILED
        assert outcome.error_message is not None
        assert "cancelled" in outcome.error_message.lower()

    async def test_cancel_emits_task_failed_event(self):
        """P0 #2: CancelledError emits TASK_FAILED event to EventQueue."""
        opening = _make_event("final_answer", data={"output": "Partial"})
        agent = CancellableReactEngine(opening)
        store = FakeConversationStore()
        queue = EventQueue()
        captured: list[Event] = []

        async def consume():
            async for event in queue.subscribe(task_id="test-task"):
                captured.append(event)
                if event.event_type == TaskEventType.TASK_FAILED:
                    break

        consumer = asyncio.create_task(consume())
        # Let the consumer subscribe before the background task starts
        await asyncio.sleep(0.05)

        runner = asyncio.create_task(
            _execute_react_background(
                agent=agent,
                messages=[],
                system_prompt=None,
                timeout_seconds=None,
                conv_id="test-conv",
                task_id="test-task",
                event_queue=queue,
                conversation_store=store,
            )
        )

        await asyncio.wait_for(agent.started.wait(), timeout=2.0)
        runner.cancel()
        with self._expect_cancelled():
            await runner

        await asyncio.wait_for(consumer, timeout=2.0)

        failures = [
            event for event in captured if event.event_type == TaskEventType.TASK_FAILED
        ]
        assert len(failures) == 1
        assert "cancelled" in failures[0].data.get("error", "").lower()

    @staticmethod
    def _expect_cancelled():
        """Return a context manager that swallows ``asyncio.CancelledError``.

        The manager only suppresses the exception; it does not fail when the
        awaited task finishes without raising it.
        """
        from contextlib import suppress

        return suppress(asyncio.CancelledError)


# ---------------------------------------------------------------------------
# P0 #3: Resume handler conversation_id ownership verification tests
# ---------------------------------------------------------------------------


class TestResumeOwnershipVerification:
    """P0 #3: Verify resume path rejects tasks from a different conversation.

    These tests exercise the TaskStore metadata check directly, since the
    WebSocket resume handler reads metadata to verify ownership.
    """

    async def test_resume_rejects_mismatched_conversation_id(self):
        """P0 #3: Task with conversation_id mismatch should be rejected.

        Simulates the metadata check performed in portal.py resume handler:
        if record.metadata['conversation_id'] != request conversation_id,
        the resume is rejected with an error.
        """
        task_store = InMemoryTaskStore()
        task_store.create("task-X", "agent", {"message": "hello"})
        task_store.update_status(
            "task-X",
            TaskStatus.RUNNING,
            metadata={"conversation_id": "conv-A"},
        )

        record = task_store.get("task-X")
        assert record is not None

        # Simulate the ownership check from portal.py resume handler
        task_conv_id = (record.metadata or {}).get("conversation_id", "")
        request_conv_id = "conv-B"  # Different conversation
        assert task_conv_id != request_conv_id
        # In portal.py, this mismatch triggers an error response and `continue`

    async def test_resume_accepts_matching_conversation_id(self):
        """P0 #3: Task with matching conversation_id should be allowed to resume."""
        task_store = InMemoryTaskStore()
        task_store.create("task-Y", "agent", {"message": "hello"})
        task_store.update_status(
            "task-Y",
            TaskStatus.RUNNING,
            metadata={"conversation_id": "conv-A"},
        )

        record = task_store.get("task-Y")
        assert record is not None

        task_conv_id = (record.metadata or {}).get("conversation_id", "")
        request_conv_id = "conv-A"  # Same conversation
        assert task_conv_id == request_conv_id
        # In portal.py, this passes the ownership check and proceeds to subscribe

    async def test_resume_rejects_when_metadata_missing(self):
        """P1 #3: When task metadata has no conversation_id, resume is rejected
        (fail-closed).

        Previously this was allowed for backward compatibility, but the
        security review identified this as a bypass vector — an attacker can
        omit conversation_id to subscribe to any task's events.
        """
        task_store = InMemoryTaskStore()
        task_store.create("task-Z", "agent", {"message": "hello"})
        # No metadata update — metadata defaults to {}

        record = task_store.get("task-Z")
        assert record is not None

        task_conv_id = (record.metadata or {}).get("conversation_id", "")
        request_conv_id = "conv-A"
        # P1 #3 fix: fail-closed — reject if task_conv_id is missing
        should_reject = not task_conv_id or task_conv_id != request_conv_id
        assert should_reject


# ---------------------------------------------------------------------------
# P0 #4: Cancel propagation tests
# ---------------------------------------------------------------------------


class TestCancelPropagation:
    """P0 #4: Verify explicit cancel (msg_type == 'cancel') propagates
    correctly to the background task and marks it FAILED."""

    async def test_explicit_cancel_marks_task_failed(self):
        """P0 #4: When a background task is explicitly cancelled (simulating
        the msg_type == 'cancel' handler), it should propagate CancelledError
        and mark the task FAILED with partial output persisted."""
        first_event = _make_event(
            "final_answer", data={"output": "Partial before user cancel"}
        )
        engine = CancellableReactEngine(first_event)
        conv_store = FakeConversationStore()
        eq = EventQueue()
        task_store = InMemoryTaskStore()
        task_store.create("cancel-task", "agent", {"message": "hello"})

        # Simulate the background task as portal.py would create it
        active_bg_task: asyncio.Task | None = asyncio.create_task(
            _execute_react_background(
                agent=engine,
                messages=[],
                system_prompt=None,
                timeout_seconds=None,
                conv_id="cancel-conv",
                task_id="cancel-task",
                event_queue=eq,
                conversation_store=conv_store,
                task_store=task_store,
            )
        )

        # Cancel only after the engine has produced its first event
        await asyncio.wait_for(engine.started.wait(), timeout=2.0)

        # Simulate the cancel handler: active_bg_task.cancel()
        assert active_bg_task is not None
        assert not active_bg_task.done()
        active_bg_task.cancel()

        with _suppress_cancelled():
            await active_bg_task

        # Verify task is marked FAILED
        record = task_store.get("cancel-task")
        assert record is not None
        assert record.status == TaskStatus.FAILED

        # Verify partial output was persisted
        assert len(conv_store.messages) == 1
        _, _, content = conv_store.messages[0]
        assert content == "Partial before user cancel"

    async def test_cancel_after_completion_is_noop(self):
        """P0 #4: Cancelling an already-completed task is a no-op
        (active_bg_task.done() check prevents double-cancel).

        The test calls ``cancel()`` on the finished task and checks that it
        reports ``False``, that the task is not marked cancelled, and that
        the persisted output is unchanged.
        """
        events = [_make_event("final_answer", data={"output": "Done"})]
        engine = FakeReactEngine(events)
        conv_store = FakeConversationStore()
        eq = EventQueue()

        bg_task = asyncio.create_task(
            _execute_react_background(
                agent=engine,
                messages=[],
                system_prompt=None,
                timeout_seconds=None,
                conv_id="test-conv",
                task_id="test-task",
                event_queue=eq,
                conversation_store=conv_store,
            )
        )
        await bg_task

        # Already done — cancel should be a no-op per portal.py guard:
        # `if active_bg_task is not None and not active_bg_task.done():`
        assert bg_task.done()
        persisted_before = list(conv_store.messages)

        # Even without the guard, cancelling a finished task must change
        # nothing: cancel() returns False and the task is not cancelled.
        assert bg_task.cancel() is False
        assert not bg_task.cancelled()

        # No exception, no state change
        assert len(conv_store.messages) == 1
        assert conv_store.messages == persisted_before


# ---------------------------------------------------------------------------
# P0 #5: WebSocketDisconnect does NOT cancel background task
# ---------------------------------------------------------------------------


class TestWebSocketDisconnectNoCancel:
    """P0 #5: Verify that WebSocketDisconnect does NOT cancel the background
    task — this is the core invariant of the three-layer defense.

    The test simulates the portal.py control flow: a background task is
    started, then the WebSocket disconnects (simulated by cancelling the
    subscribe loop but NOT the background task). The background task should
    continue running and persist its result.
    """

    async def test_disconnect_does_not_cancel_background_task(self):
        """P0 #5: After WebSocketDisconnect, the background task continues
        running and persists its result to the conversation store."""
        events = [
            _make_event("thinking", data={"text": "Thinking..."}),
            _make_event("final_answer", data={"output": "Final result"}),
        ]
        engine = SlowFakeReactEngine(events, delay=0.2)
        conv_store = FakeConversationStore()
        eq = EventQueue()

        # Start the background task (as portal.py would)
        bg_task = asyncio.create_task(
            _execute_react_background(
                agent=engine,
                messages=[],
                system_prompt=None,
                timeout_seconds=None,
                conv_id="test-conv",
                task_id="test-task",
                event_queue=eq,
                conversation_store=conv_store,
            )
        )

        # Simulate WebSocketDisconnect: the subscribe loop is interrupted,
        # but the background task is NOT cancelled (per P0 #1 fix).
        # We just stop subscribing — bg_task keeps running.
        await asyncio.sleep(0.1)  # Let bg_task start

        # Verify bg_task is still running (not cancelled by disconnect)
        assert not bg_task.done()

        # Wait for bg_task to complete naturally
        await asyncio.wait_for(bg_task, timeout=5.0)

        # Result should be persisted despite "disconnect"
        assert len(conv_store.messages) == 1
        _, _, content = conv_store.messages[0]
        assert content == "Final result"

    async def test_disconnect_result_available_for_resume(self):
        """P0 #5: After disconnect, the completed task's result is available
        in TaskStore so a reconnecting client can retrieve it via resume."""
        events = [_make_event("final_answer", data={"output": "Resumable result"})]
        engine = FakeReactEngine(events)
        conv_store = FakeConversationStore()
        eq = EventQueue()
        task_store = InMemoryTaskStore()
        task_store.create("resume-task", "agent", {"message": "hello"})
        task_store.update_status(
            "resume-task",
            TaskStatus.RUNNING,
            metadata={"conversation_id": "resume-conv"},
        )

        bg_task = asyncio.create_task(
            _execute_react_background(
                agent=engine,
                messages=[],
                system_prompt=None,
                timeout_seconds=None,
                conv_id="resume-conv",
                task_id="resume-task",
                event_queue=eq,
                conversation_store=conv_store,
                task_store=task_store,
            )
        )

        # Simulate disconnect: don't cancel bg_task, just let it run
        await asyncio.wait_for(bg_task, timeout=5.0)

        # Task should be COMPLETED with output available for resume
        record = task_store.get("resume-task")
        assert record is not None
        assert record.status == TaskStatus.COMPLETED
        assert record.output_data is not None
        assert record.output_data.get("output") == "Resumable result"