939 lines
33 KiB
Python
939 lines
33 KiB
Python
"""Tests for WebSocket task persistence and background execution (U1-U3).

Tests cover:

- U1: Partial output persistence on WebSocket disconnect
- U2: Background ReAct task execution with EventQueue event distribution
- U3: TaskStore registration and status tracking
- P0 #2: CancelledError path — partial output persisted, task marked FAILED
- P0 #3: Resume handler — conversation_id ownership verification
- P0 #4: Cancel propagation — explicit cancel marks task FAILED
- P0 #5: WebSocketDisconnect does NOT cancel background task
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
|
|
from agentkit.core.event_queue import EventQueue
|
|
from agentkit.core.protocol import Event, TaskEventType, TaskStatus, TurnEventType
|
|
from agentkit.server.routes.portal import _execute_react_background
|
|
from agentkit.server.task_store import InMemoryTaskStore
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fakes
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class FakeConversationStore:
    """In-memory stand-in for the conversation store.

    Each ``add_message`` call is recorded as a ``(conv_id, role, content)``
    tuple in ``messages``, in call order, so tests can assert on exactly
    what the code under test persisted.
    """

    def __init__(self) -> None:
        # Recorded (conv_id, role, content) triples, oldest first.
        self.messages: list[tuple[str, str, str]] = []

    async def add_message(self, conv_id: str, role: str, content: str) -> None:
        """Record a single message; this fake never fails."""
        record = (conv_id, role, content)
        self.messages.append(record)
|
|
|
|
|
|
class FakeReactEngine:
    """Stand-in ReAct engine that replays a fixed, scripted list of events.

    ``execute_stream`` accepts and ignores any keyword arguments, mirroring
    the real engine's call shape, and then finishes normally.
    """

    def __init__(self, events: list[Event]) -> None:
        self._events = events

    async def execute_stream(self, **kwargs):
        """Yield every scripted event in order, then stop."""
        for scripted in self._events:
            yield scripted
|
|
|
|
|
|
class FailingReactEngine:
    """Stand-in ReAct engine that yields scripted events and then fails.

    Once the last scripted event has been consumed, the stream raises the
    exception object supplied at construction time. This drives the error
    path of the background executor.
    """

    def __init__(self, events: list[Event], error: Exception) -> None:
        self._events = events
        self._error = error

    async def execute_stream(self, **kwargs):
        """Yield each scripted event, then raise the configured error."""
        for scripted in self._events:
            yield scripted
        raise self._error
|
|
|
|
|
|
def _make_event(
    event_type: str,
    task_id: str = "test-task",
    session_id: str = "test-conv",
    data: dict | None = None,
) -> Event:
    """Build an ``Event`` with test-friendly defaults.

    Any falsy ``data`` (``None`` or an empty dict) is replaced by a fresh
    empty dict, so no two calls share one mutable payload object.
    """
    payload = data if data else {}
    return Event.create(
        event_type=event_type,
        task_id=task_id,
        session_id=session_id,
        data=payload,
    )
|
|
|
|
|
|
class SlowFakeReactEngine:
    """Stand-in ReAct engine that pauses before emitting each event.

    The ``delay``-second sleep before every yield keeps the stream alive
    long enough for tests to observe intermediate state, such as a RUNNING
    task status.
    """

    def __init__(self, events: list[Event], delay: float = 0.1) -> None:
        self._events = events
        self._delay = delay

    async def execute_stream(self, **kwargs):
        """Repeatedly sleep ``delay`` seconds, then yield the next event."""
        for scripted in self._events:
            await asyncio.sleep(self._delay)
            yield scripted
|
|
|
|
|
|
class CancellableReactEngine:
    """Stand-in ReAct engine that emits one event and then hangs.

    The single event gives the executor some collected output. The stream
    then awaits an ``asyncio.Event`` that nothing ever sets, so it only ends
    when the consuming task is cancelled.

    ``started`` is set when the consumer resumes the stream after the first
    event, i.e. once that event has been handed over and processed.
    """

    def __init__(self, first_event: Event) -> None:
        self._first_event = first_event
        # Flags that the first event was consumed and the stream is now parked.
        self.started = asyncio.Event()

    async def execute_stream(self, **kwargs):
        """Yield the first event, flag ``started``, then block until cancelled."""
        yield self._first_event
        self.started.set()
        never_set = asyncio.Event()
        # Nobody sets this event: only cancellation ends the wait.
        await never_set.wait()
|
|
|
|
|
|
def _suppress_cancelled():
|
|
"""Context manager that suppresses asyncio.CancelledError."""
|
|
import contextlib
|
|
|
|
return contextlib.suppress(asyncio.CancelledError)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# U1 + U2: Background task persistence tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestExecuteReactBackground:
    """Tests for _execute_react_background (U1 + U2).

    Covers result/partial-output persistence (U1) and event fan-out through
    the EventQueue (U2).
    """

    @staticmethod
    async def _run(engine, conv_store, eq) -> None:
        """Await ``_execute_react_background`` with this suite's fixed arguments
        (``test-conv`` / ``test-task``, no task store argument)."""
        await _execute_react_background(
            react_engine=engine,
            messages=[],
            tools=[],
            model="test-model",
            agent_name="test-agent",
            system_prompt=None,
            timeout_seconds=None,
            conv_id="test-conv",
            task_id="test-task",
            event_queue=eq,
            conversation_store=conv_store,
        )

    @staticmethod
    def _start_subscriber(eq, sink, stop_type):
        """Spawn a task that appends ``test-task`` events to ``sink`` and
        stops after the first event whose type equals ``stop_type``."""

        async def _consume():
            async for evt in eq.subscribe(task_id="test-task"):
                sink.append(evt)
                if evt.event_type == stop_type:
                    break

        return asyncio.create_task(_consume())

    async def test_normal_completion_persists_result(self):
        """U2: Normal completion persists result to conversation store."""
        engine = FakeReactEngine(
            [
                _make_event("thinking", data={"text": "Let me think..."}),
                _make_event("final_answer", data={"output": "The answer is 42"}),
            ]
        )
        store = FakeConversationStore()

        await self._run(engine, store, EventQueue())

        # Exactly one assistant message carrying the final answer.
        assert store.messages == [("test-conv", "assistant", "The answer is 42")]

    async def test_partial_output_persisted_on_error(self):
        """U1: Partial output is persisted when ReAct engine raises an error."""
        engine = FailingReactEngine(
            [
                _make_event("thinking", data={"text": "Thinking..."}),
                _make_event("final_answer", data={"output": "Partial result"}),
            ],
            RuntimeError("LLM timeout"),
        )
        store = FakeConversationStore()

        await self._run(engine, store, EventQueue())

        # The answer streamed before the failure must survive it.
        assert len(store.messages) == 1
        _, role, content = store.messages[0]
        assert (role, content) == ("assistant", "Partial result")

    async def test_no_output_on_error_without_final_answer(self):
        """U1: No message persisted when error occurs before any final_answer."""
        engine = FailingReactEngine(
            [_make_event("thinking", data={"text": "Thinking..."})],
            RuntimeError("Early failure"),
        )
        store = FakeConversationStore()

        await self._run(engine, store, EventQueue())

        # Nothing was collected before the failure, so nothing is persisted.
        assert not store.messages

    async def test_events_emitted_to_event_queue(self):
        """U2: Events are emitted to EventQueue for subscribers."""
        engine = FakeReactEngine(
            [
                _make_event("thinking", data={"text": "Thinking..."}),
                _make_event("final_answer", data={"output": "Done"}),
            ]
        )
        eq = EventQueue()
        received: list[Event] = []

        listener = self._start_subscriber(eq, received, TaskEventType.TASK_COMPLETED)
        await asyncio.sleep(0.05)  # give the subscriber time to register

        await self._run(engine, FakeConversationStore(), eq)
        await asyncio.wait_for(listener, timeout=2.0)

        # Expect thinking, final_answer and task.completed.
        # P1 #9: raw ReAct event types are mapped to TurnEventType constants.
        seen_types = [evt.event_type for evt in received]
        assert TurnEventType.THINKING in seen_types
        assert TurnEventType.FINAL_ANSWER in seen_types
        assert TaskEventType.TASK_COMPLETED in seen_types

    async def test_task_failed_event_on_error(self):
        """U2: task.failed event is emitted on error."""
        engine = FailingReactEngine([], RuntimeError("Execution failed"))
        eq = EventQueue()
        received: list[Event] = []

        listener = self._start_subscriber(eq, received, TaskEventType.TASK_FAILED)
        await asyncio.sleep(0.05)  # give the subscriber time to register

        await self._run(engine, FakeConversationStore(), eq)
        await asyncio.wait_for(listener, timeout=2.0)

        failures = [evt for evt in received if evt.event_type == TaskEventType.TASK_FAILED]
        assert len(failures) == 1
        assert "Execution failed" in failures[0].data.get("error", "")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# U3: TaskStore integration tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestTaskStoreIntegration:
    """Tests for TaskStore registration and status tracking (U3)."""

    async def test_task_store_status_running_during_execution(self):
        """U3: TaskStore status is RUNNING during background execution."""
        events = [
            _make_event("thinking", data={"text": "Thinking..."}),
            _make_event("final_answer", data={"output": "Result"}),
        ]
        # The 0.2s per-event delay keeps the task alive while we inspect it.
        engine = SlowFakeReactEngine(events, delay=0.2)
        conv_store = FakeConversationStore()
        eq = EventQueue()
        task_store = InMemoryTaskStore()

        task_store.create(
            task_id="test-task",
            agent_name="test-agent",
            input_data={"message": "hello"},
        )

        # Start background task
        bg_task = asyncio.create_task(
            _execute_react_background(
                react_engine=engine,
                messages=[],
                tools=[],
                model="test-model",
                agent_name="test-agent",
                system_prompt=None,
                timeout_seconds=None,
                conv_id="test-conv",
                task_id="test-task",
                event_queue=eq,
                conversation_store=conv_store,
                task_store=task_store,
            )
        )
        try:
            # Yield control so the background task gets to start.
            await asyncio.sleep(0.01)
            record = task_store.get("test-task")
            assert record is not None
            assert record.status == TaskStatus.RUNNING

            await asyncio.wait_for(bg_task, timeout=2.0)
        finally:
            # If an assertion above failed, the task would otherwise keep
            # running into later tests and be destroyed while still pending.
            if not bg_task.done():
                bg_task.cancel()
                with _suppress_cancelled():
                    await bg_task

        # After completion, status should be COMPLETED
        record = task_store.get("test-task")
        assert record is not None
        assert record.status == TaskStatus.COMPLETED
        assert record.output_data is not None
        assert record.output_data.get("output") == "Result"
        assert record.progress == 1.0

    async def test_task_store_status_failed_on_error(self):
        """U3: TaskStore status is FAILED when background task raises error."""
        events: list[Event] = []
        error = RuntimeError("Execution failed")
        engine = FailingReactEngine(events, error)
        conv_store = FakeConversationStore()
        eq = EventQueue()
        task_store = InMemoryTaskStore()

        task_store.create(
            task_id="test-task",
            agent_name="test-agent",
            input_data={"message": "hello"},
        )

        await _execute_react_background(
            react_engine=engine,
            messages=[],
            tools=[],
            model="test-model",
            agent_name="test-agent",
            system_prompt=None,
            timeout_seconds=None,
            conv_id="test-conv",
            task_id="test-task",
            event_queue=eq,
            conversation_store=conv_store,
            task_store=task_store,
        )

        record = task_store.get("test-task")
        assert record is not None
        assert record.status == TaskStatus.FAILED
        assert record.error_message is not None
        assert "Execution failed" in record.error_message

    async def test_task_store_none_does_not_crash(self):
        """U3: Passing task_store=None should not crash."""
        events = [_make_event("final_answer", data={"output": "Result"})]
        engine = FakeReactEngine(events)
        conv_store = FakeConversationStore()
        eq = EventQueue()

        # Should not raise
        await _execute_react_background(
            react_engine=engine,
            messages=[],
            tools=[],
            model="test-model",
            agent_name="test-agent",
            system_prompt=None,
            timeout_seconds=None,
            conv_id="test-conv",
            task_id="test-task",
            event_queue=eq,
            conversation_store=conv_store,
            task_store=None,
        )

        assert len(conv_store.messages) == 1

    async def test_task_store_list_by_status(self):
        """U3: TaskStore list_tasks filters by status correctly."""
        task_store = InMemoryTaskStore()

        # Create tasks in different states; task-1 stays in its initial state.
        task_store.create("task-1", "agent", {})
        task_store.create("task-2", "agent", {})
        task_store.create("task-3", "agent", {})

        task_store.update_status("task-2", TaskStatus.RUNNING)
        task_store.update_status("task-3", TaskStatus.COMPLETED, progress=1.0)

        running = task_store.list_tasks(status=TaskStatus.RUNNING)
        completed = task_store.list_tasks(status=TaskStatus.COMPLETED)
        pending = task_store.list_tasks(status=TaskStatus.PENDING)

        assert len(running) == 1
        assert running[0].task_id == "task-2"
        assert len(completed) == 1
        assert completed[0].task_id == "task-3"
        assert len(pending) == 1
        assert pending[0].task_id == "task-1"

    async def test_task_store_metadata_contains_conversation_id(self):
        """U3: TaskStore metadata stores conversation_id for frontend recovery."""
        task_store = InMemoryTaskStore()
        task_store.create("task-1", "agent", {"message": "hello"})
        task_store.update_status(
            "task-1",
            TaskStatus.PENDING,
            metadata={"conversation_id": "conv-123"},
        )

        record = task_store.get("task-1")
        assert record is not None
        assert record.metadata.get("conversation_id") == "conv-123"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# EventQueue task_id filtering tests (U2)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestEventQueueTaskIdFilter:
    """Tests for EventQueue subscribe(task_id=...) filtering."""

    # (event_type, task_id) pairs emitted by every test, in this order.
    _MIXED_EMISSIONS = (
        ("thinking", "task-A"),
        ("thinking", "task-B"),
        ("final_answer", "task-A"),
    )

    @staticmethod
    def _collect(eq, sink, limit, **subscribe_kwargs):
        """Spawn a task that appends subscribed events to ``sink`` until it
        holds ``limit`` items; ``subscribe_kwargs`` go to ``eq.subscribe``."""

        async def _consume():
            async for evt in eq.subscribe(**subscribe_kwargs):
                sink.append(evt)
                if len(sink) >= limit:
                    break

        return asyncio.create_task(_consume())

    async def _emit_mixed(self, eq) -> None:
        """Emit two task-A events with a task-B event between them."""
        for event_type, task_id in self._MIXED_EMISSIONS:
            await eq.emit(_make_event(event_type, task_id=task_id))

    async def test_subscribe_with_task_id_filter(self):
        """U2: subscribe(task_id=...) only receives matching events."""
        eq = EventQueue()
        received: list[Event] = []

        listener = self._collect(eq, received, 2, task_id="task-A")
        await asyncio.sleep(0.05)

        # The task-B event in the middle must be filtered out.
        await self._emit_mixed(eq)
        await asyncio.wait_for(listener, timeout=2.0)

        assert len(received) == 2
        assert all(evt.task_id == "task-A" for evt in received)

    async def test_subscribe_without_filter_receives_all(self):
        """U2: subscribe() without task_id receives all events (backward compat)."""
        eq = EventQueue()
        received: list[Event] = []

        listener = self._collect(eq, received, 3)
        await asyncio.sleep(0.05)

        await self._emit_mixed(eq)
        await asyncio.wait_for(listener, timeout=2.0)

        # Every event arrives regardless of task_id.
        assert len(received) == 3

    async def test_subscribe_replays_buffer_filtered(self):
        """U2: Buffer replay respects task_id filter."""
        eq = EventQueue()

        # Emit before anyone subscribes so delivery must come from the buffer.
        await self._emit_mixed(eq)

        received: list[Event] = []
        listener = self._collect(eq, received, 2, task_id="task-A")
        await asyncio.wait_for(listener, timeout=2.0)

        # Only task-A events may be replayed.
        assert len(received) == 2
        assert all(evt.task_id == "task-A" for evt in received)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# P0 #2: CancelledError path tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCancelledErrorPath:
    """P0 #2: Verify CancelledError cleanup persists partial output and
    marks the task as FAILED, and TASK_FAILED event is emitted."""

    @staticmethod
    def _start_background(engine, conv_store, eq, **extra) -> asyncio.Task:
        """Schedule ``_execute_react_background`` for ``test-task`` in ``test-conv``.

        ``extra`` is forwarded verbatim (e.g. ``task_store=...``). When it is
        empty, the function's own defaults apply, exactly as if omitted.
        """
        return asyncio.create_task(
            _execute_react_background(
                react_engine=engine,
                messages=[],
                tools=[],
                model="test-model",
                agent_name="test-agent",
                system_prompt=None,
                timeout_seconds=None,
                conv_id="test-conv",
                task_id="test-task",
                event_queue=eq,
                conversation_store=conv_store,
                **extra,
            )
        )

    async def _cancel_once_started(self, engine, bg_task) -> None:
        """Cancel ``bg_task`` once ``engine`` has handed over its first event.

        The cancel-and-await runs in ``finally``. If ``started`` never fires
        and the wait times out, the task would block forever. Running the
        teardown in ``finally`` still cancels it, so it does not leak into
        later tests.
        """
        try:
            await asyncio.wait_for(engine.started.wait(), timeout=2.0)
        finally:
            bg_task.cancel()
            with self._expect_cancelled():
                await bg_task

    async def test_cancel_persists_partial_output(self):
        """P0 #2: When task is cancelled mid-execution, partial output
        collected before cancellation is persisted to conversation store."""
        first_event = _make_event("final_answer", data={"output": "Partial before cancel"})
        engine = CancellableReactEngine(first_event)
        conv_store = FakeConversationStore()
        eq = EventQueue()
        task_store = InMemoryTaskStore()
        task_store.create("test-task", "test-agent", {"message": "hello"})

        bg_task = self._start_background(engine, conv_store, eq, task_store=task_store)
        await self._cancel_once_started(engine, bg_task)

        # Partial output should be persisted
        assert len(conv_store.messages) == 1
        _, role, content = conv_store.messages[0]
        assert role == "assistant"
        assert content == "Partial before cancel"

    async def test_cancel_marks_task_failed_in_store(self):
        """P0 #2: CancelledError marks task status as FAILED in TaskStore."""
        first_event = _make_event("final_answer", data={"output": "Partial"})
        engine = CancellableReactEngine(first_event)
        conv_store = FakeConversationStore()
        eq = EventQueue()
        task_store = InMemoryTaskStore()
        task_store.create("test-task", "test-agent", {"message": "hello"})

        bg_task = self._start_background(engine, conv_store, eq, task_store=task_store)
        await self._cancel_once_started(engine, bg_task)

        record = task_store.get("test-task")
        assert record is not None
        assert record.status == TaskStatus.FAILED
        assert record.error_message is not None
        assert "cancelled" in record.error_message.lower()

    async def test_cancel_emits_task_failed_event(self):
        """P0 #2: CancelledError emits TASK_FAILED event to EventQueue."""
        first_event = _make_event("final_answer", data={"output": "Partial"})
        engine = CancellableReactEngine(first_event)
        conv_store = FakeConversationStore()
        eq = EventQueue()

        received: list[Event] = []

        async def subscriber():
            async for evt in eq.subscribe(task_id="test-task"):
                received.append(evt)
                if evt.event_type == TaskEventType.TASK_FAILED:
                    break

        sub_task = asyncio.create_task(subscriber())
        await asyncio.sleep(0.05)  # let the subscriber register first

        bg_task = self._start_background(engine, conv_store, eq)
        await self._cancel_once_started(engine, bg_task)

        await asyncio.wait_for(sub_task, timeout=2.0)

        failed_events = [e for e in received if e.event_type == TaskEventType.TASK_FAILED]
        assert len(failed_events) == 1
        assert "cancelled" in failed_events[0].data.get("error", "").lower()

    @staticmethod
    def _expect_cancelled():
        """Return a context manager that swallows ``asyncio.CancelledError``.

        Despite the name, it does not *require* the error. If the awaited task
        finished normally after ``cancel()``, the block also exits silently.
        The cancellation path is verified by the assertions on persisted
        output, TaskStore status and emitted events, not by this helper.
        """
        import contextlib

        return contextlib.suppress(asyncio.CancelledError)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# P0 #3: Resume handler conversation_id ownership verification tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestResumeOwnershipVerification:
    """P0 #3: Verify resume path rejects tasks from a different conversation.

    These tests exercise the TaskStore metadata check directly, since the
    WebSocket resume handler reads metadata to verify ownership. All three
    tests apply the same fail-closed rule via ``_should_reject``.

    NOTE(review): ``_should_reject`` mirrors the check in the portal.py
    resume handler; it does not call it. Keep the two in sync. TODO: consider
    extracting the check from portal.py so these tests exercise the real code.
    """

    @staticmethod
    def _should_reject(task_conv_id: str, request_conv_id: str) -> bool:
        """Fail-closed ownership rule (P1 #3).

        Reject when the task has no recorded conversation_id, or when it
        differs from the conversation_id in the resume request.
        """
        return not task_conv_id or task_conv_id != request_conv_id

    @staticmethod
    def _stored_conv_id(task_store, task_id: str) -> str:
        """Return the conversation_id recorded in a task's metadata, or ``""``."""
        record = task_store.get(task_id)
        assert record is not None
        return (record.metadata or {}).get("conversation_id", "")

    async def test_resume_rejects_mismatched_conversation_id(self):
        """P0 #3: Task with conversation_id mismatch should be rejected.

        In the portal.py resume handler, a mismatch between
        record.metadata['conversation_id'] and the request's conversation_id
        triggers an error response and the resume is skipped.
        """
        task_store = InMemoryTaskStore()
        task_store.create("task-X", "agent", {"message": "hello"})
        task_store.update_status(
            "task-X",
            TaskStatus.RUNNING,
            metadata={"conversation_id": "conv-A"},
        )

        task_conv_id = self._stored_conv_id(task_store, "task-X")

        # Request comes from a different conversation -> must be rejected.
        assert self._should_reject(task_conv_id, "conv-B")

    async def test_resume_accepts_matching_conversation_id(self):
        """P0 #3: Task with matching conversation_id should be allowed to resume."""
        task_store = InMemoryTaskStore()
        task_store.create("task-Y", "agent", {"message": "hello"})
        task_store.update_status(
            "task-Y",
            TaskStatus.RUNNING,
            metadata={"conversation_id": "conv-A"},
        )

        task_conv_id = self._stored_conv_id(task_store, "task-Y")

        # Same conversation -> passes the ownership check (portal.py then subscribes).
        assert not self._should_reject(task_conv_id, "conv-A")

    async def test_resume_rejects_when_metadata_missing(self):
        """P1 #3: When task metadata has no conversation_id, resume is
        rejected (fail-closed).

        Previously this was allowed for backward compatibility. The security
        review identified it as a bypass vector: an attacker could omit
        conversation_id to subscribe to any task's events."""
        task_store = InMemoryTaskStore()
        task_store.create("task-Z", "agent", {"message": "hello"})
        # No metadata update, so no conversation_id is recorded for this task.

        task_conv_id = self._stored_conv_id(task_store, "task-Z")

        assert self._should_reject(task_conv_id, "conv-A")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# P0 #4: Cancel propagation tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCancelPropagation:
    """P0 #4: Verify explicit cancel (msg_type == 'cancel') propagates
    correctly to the background task and marks it FAILED."""

    async def test_explicit_cancel_marks_task_failed(self):
        """P0 #4: When a background task is explicitly cancelled (simulating
        the msg_type == 'cancel' handler), it should propagate CancelledError
        and mark the task FAILED with partial output persisted."""
        first_event = _make_event("final_answer", data={"output": "Partial before user cancel"})
        engine = CancellableReactEngine(first_event)
        conv_store = FakeConversationStore()
        eq = EventQueue()
        task_store = InMemoryTaskStore()
        task_store.create("cancel-task", "agent", {"message": "hello"})

        # Simulate the background task as portal.py would create it
        active_bg_task: asyncio.Task | None = asyncio.create_task(
            _execute_react_background(
                react_engine=engine,
                messages=[],
                tools=[],
                model="test-model",
                agent_name="test-agent",
                system_prompt=None,
                timeout_seconds=None,
                conv_id="cancel-conv",
                task_id="cancel-task",
                event_queue=eq,
                conversation_store=conv_store,
                task_store=task_store,
            )
        )

        await asyncio.wait_for(engine.started.wait(), timeout=2.0)

        # Simulate the cancel handler: active_bg_task.cancel()
        assert active_bg_task is not None
        assert not active_bg_task.done()
        active_bg_task.cancel()

        with _suppress_cancelled():
            await active_bg_task

        # Verify task is marked FAILED
        record = task_store.get("cancel-task")
        assert record is not None
        assert record.status == TaskStatus.FAILED

        # Verify partial output was persisted
        assert len(conv_store.messages) == 1
        _, _, content = conv_store.messages[0]
        assert content == "Partial before user cancel"

    async def test_cancel_after_completion_is_noop(self):
        """P0 #4: Cancelling an already-completed task is a no-op.

        portal.py guards with
        ``if active_bg_task is not None and not active_bg_task.done():``.
        This test also checks that cancelling a finished task without the
        guard is harmless.
        """
        events = [_make_event("final_answer", data={"output": "Done"})]
        engine = FakeReactEngine(events)
        conv_store = FakeConversationStore()
        eq = EventQueue()

        bg_task = asyncio.create_task(
            _execute_react_background(
                react_engine=engine,
                messages=[],
                tools=[],
                model="test-model",
                agent_name="test-agent",
                system_prompt=None,
                timeout_seconds=None,
                conv_id="test-conv",
                task_id="test-task",
                event_queue=eq,
                conversation_store=conv_store,
            )
        )
        await bg_task

        # Already done, so the portal.py guard would skip the cancel.
        assert bg_task.done()

        # Check that cancel() really is a no-op: Task.cancel() returns False for
        # a finished task, the task is not flagged as cancelled, and awaiting
        # it again still succeeds without raising.
        assert bg_task.cancel() is False
        assert not bg_task.cancelled()
        await bg_task

        # No exception, no state change
        assert len(conv_store.messages) == 1
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# P0 #5: WebSocketDisconnect does NOT cancel background task
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestWebSocketDisconnectNoCancel:
    """P0 #5: Verify that WebSocketDisconnect does NOT cancel the background
    task — this is the core invariant of the three-layer defense.

    The test simulates the portal.py control flow. A background task is
    started, then the WebSocket disconnects. The disconnect is simulated by
    cancelling the subscribe loop but NOT the background task. The background
    task should continue running and persist its result.
    """

    async def test_disconnect_does_not_cancel_background_task(self):
        """P0 #5: After WebSocketDisconnect, the background task continues
        running and persists its result to the conversation store."""
        events = [
            _make_event("thinking", data={"text": "Thinking..."}),
            _make_event("final_answer", data={"output": "Final result"}),
        ]
        # 0.2s per event, so the task is still mid-stream when we "disconnect".
        engine = SlowFakeReactEngine(events, delay=0.2)
        conv_store = FakeConversationStore()
        eq = EventQueue()

        # Stand-in for the WebSocket's subscribe loop that forwards events.
        async def websocket_forwarder():
            async for _evt in eq.subscribe(task_id="test-task"):
                pass

        forwarder = asyncio.create_task(websocket_forwarder())

        # Start the background task (as portal.py would)
        bg_task = asyncio.create_task(
            _execute_react_background(
                react_engine=engine,
                messages=[],
                tools=[],
                model="test-model",
                agent_name="test-agent",
                system_prompt=None,
                timeout_seconds=None,
                conv_id="test-conv",
                task_id="test-task",
                event_queue=eq,
                conversation_store=conv_store,
            )
        )

        await asyncio.sleep(0.1)  # let both tasks start

        # Simulate WebSocketDisconnect: tear down ONLY the subscribe loop.
        # Per the P0 #1 fix, the background task must be left running.
        forwarder.cancel()
        with _suppress_cancelled():
            await forwarder

        # Verify bg_task is still running (not cancelled by disconnect)
        assert not bg_task.done()

        # Wait for bg_task to complete naturally
        await asyncio.wait_for(bg_task, timeout=5.0)

        # Result should be persisted despite "disconnect"
        assert len(conv_store.messages) == 1
        _, _, content = conv_store.messages[0]
        assert content == "Final result"

    async def test_disconnect_result_available_for_resume(self):
        """P0 #5: After disconnect, the completed task's result is available
        in TaskStore so a reconnecting client can retrieve it via resume."""
        events = [_make_event("final_answer", data={"output": "Resumable result"})]
        engine = FakeReactEngine(events)
        conv_store = FakeConversationStore()
        eq = EventQueue()
        task_store = InMemoryTaskStore()
        task_store.create("resume-task", "agent", {"message": "hello"})
        task_store.update_status(
            "resume-task",
            TaskStatus.RUNNING,
            metadata={"conversation_id": "resume-conv"},
        )

        bg_task = asyncio.create_task(
            _execute_react_background(
                react_engine=engine,
                messages=[],
                tools=[],
                model="test-model",
                agent_name="test-agent",
                system_prompt=None,
                timeout_seconds=None,
                conv_id="resume-conv",
                task_id="resume-task",
                event_queue=eq,
                conversation_store=conv_store,
                task_store=task_store,
            )
        )

        # Simulate disconnect: don't cancel bg_task, just let it run
        await asyncio.wait_for(bg_task, timeout=5.0)

        # Task should be COMPLETED with output available for resume
        record = task_store.get("resume-task")
        assert record is not None
        assert record.status == TaskStatus.COMPLETED
        assert record.output_data is not None
        assert record.output_data.get("output") == "Resumable result"
|