"""Tests for EventQueue — SQ/EQ 双队列实现 测试场景: - SQ 正确接收用户输入并返回 task_id - EQ 正确推送事件给订阅者 - 多订阅者同时接收事件(广播) - 事件缓冲对新订阅者的回放 - SQ 取消任务 - 事件类型正确分类 """ from __future__ import annotations import asyncio from datetime import datetime from agentkit.core.event_queue import EventQueue, Submission, SubmissionQueue from agentkit.core.protocol import ( Event, SessionEventType, TaskEventType, TurnEventType, ) # ── SubmissionQueue Tests ─────────────────────────────────────── class TestSubmissionQueue: """SubmissionQueue 单元测试""" async def test_submit_returns_task_id(self): """测试 submit 返回有效的 task_id""" sq = SubmissionQueue() task_id = await sq.submit("hello", "session-1") assert isinstance(task_id, str) assert len(task_id) > 0 async def test_submit_returns_unique_task_ids(self): """测试每次 submit 返回不同的 task_id""" sq = SubmissionQueue() task_id_1 = await sq.submit("hello", "session-1") task_id_2 = await sq.submit("world", "session-1") assert task_id_1 != task_id_2 async def test_submit_stores_submission(self): """测试 submit 正确存储提交内容""" sq = SubmissionQueue() task_id = await sq.submit("hello world", "session-1") assert task_id in sq._submissions submission = sq._submissions[task_id] assert submission.content == "hello world" assert submission.session_id == "session-1" assert submission.task_id == task_id assert submission.cancelled is False async def test_drain_receives_submissions_in_order(self): """测试 drain 按提交顺序接收提交""" sq = SubmissionQueue() await sq.submit("first", "session-1") await sq.submit("second", "session-1") received: list[str] = [] async def consumer(): async for submission in sq.drain(): received.append(submission.content) if len(received) >= 2: break consumer_task = asyncio.create_task(consumer()) await asyncio.wait_for(consumer_task, timeout=1.0) assert received == ["first", "second"] async def test_drain_preserves_submission_fields(self): """测试 drain 返回的 Submission 字段完整""" sq = SubmissionQueue() await sq.submit("hello", "session-1") received: list[Submission] = [] async def consumer(): async for submission in sq.drain(): received.append(submission) break consumer_task = asyncio.create_task(consumer()) await asyncio.wait_for(consumer_task, timeout=1.0) assert len(received) == 1 sub = received[0] assert sub.content == "hello" assert sub.session_id == "session-1" assert isinstance(sub.task_id, str) assert isinstance(sub.created_at, datetime) async def test_cancel_task_succeeds(self): """测试取消已存在的任务""" sq = SubmissionQueue() task_id = await sq.submit("hello", "session-1") result = await sq.cancel(task_id) assert result is True assert task_id in sq._cancelled_tasks assert sq._submissions[task_id].cancelled is True async def test_cancel_nonexistent_task_returns_false(self): """测试取消不存在的任务返回 False""" sq = SubmissionQueue() result = await sq.cancel("nonexistent-task-id") assert result is False async def test_cancel_already_cancelled_task_returns_false(self): """测试重复取消返回 False""" sq = SubmissionQueue() task_id = await sq.submit("hello", "session-1") first_cancel = await sq.cancel(task_id) second_cancel = await sq.cancel(task_id) assert first_cancel is True assert second_cancel is False async def test_drain_skips_cancelled_submissions(self): """测试 drain 跳过已取消的提交""" sq = SubmissionQueue() task_id_1 = await sq.submit("first", "session-1") await sq.submit("second", "session-1") # 取消第一个提交 await sq.cancel(task_id_1) received: list[str] = [] async def consumer(): async for submission in sq.drain(): received.append(submission.content) if len(received) >= 1: break consumer_task = asyncio.create_task(consumer()) await asyncio.wait_for(consumer_task, timeout=1.0) # 应只收到第二个提交 assert received == ["second"] async def test_close_prevents_new_submissions(self): """测试关闭后不再接受新提交""" sq = SubmissionQueue() sq.close() assert sq.is_closed is True try: await sq.submit("hello", "session-1") raise AssertionError("Should have raised RuntimeError") except RuntimeError: pass async def test_close_does_not_affect_existing_submissions(self): """测试关闭后已提交的内容仍可消费""" sq = SubmissionQueue() await sq.submit("before-close", "session-1") sq.close() received: list[str] = [] async def consumer(): async for submission in sq.drain(): received.append(submission.content) break consumer_task = asyncio.create_task(consumer()) await asyncio.wait_for(consumer_task, timeout=1.0) assert received == ["before-close"] # ── EventQueue Tests ──────────────────────────────────────────── class TestEventQueue: """EventQueue 单元测试""" async def test_emit_and_subscribe_single_event(self): """测试 EQ 正确推送事件给订阅者""" eq = EventQueue() event = Event.create( event_type=TurnEventType.TOKEN, task_id="task-1", session_id="session-1", data={"text": "hello"}, ) received: list[Event] = [] async def subscriber(): async for evt in eq.subscribe(): received.append(evt) break sub_task = asyncio.create_task(subscriber()) await asyncio.sleep(0.05) # 给订阅者启动时间 await eq.emit(event) await asyncio.wait_for(sub_task, timeout=1.0) assert len(received) == 1 assert received[0].event_type == TurnEventType.TOKEN assert received[0].task_id == "task-1" assert received[0].session_id == "session-1" assert received[0].data == {"text": "hello"} async def test_emit_preserves_event_fields(self): """测试 emit 不修改事件字段""" eq = EventQueue() original = Event.create( event_type=TaskEventType.TASK_STARTED, task_id="task-1", session_id="session-1", data={"agent": "react"}, ) received: list[Event] = [] async def subscriber(): async for evt in eq.subscribe(): received.append(evt) break sub_task = asyncio.create_task(subscriber()) await asyncio.sleep(0.05) await eq.emit(original) await asyncio.wait_for(sub_task, timeout=1.0) assert received[0] is original or received[0].to_dict() == original.to_dict() async def test_broadcast_to_multiple_subscribers(self): """测试多订阅者同时接收事件(广播)""" eq = EventQueue() received_a: list[Event] = [] received_b: list[Event] = [] async def subscriber_a(): async for evt in eq.subscribe(): received_a.append(evt) if len(received_a) >= 2: break async def subscriber_b(): async for evt in eq.subscribe(): received_b.append(evt) if len(received_b) >= 2: break task_a = asyncio.create_task(subscriber_a()) task_b = asyncio.create_task(subscriber_b()) await asyncio.sleep(0.05) # 给订阅者启动时间 await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1})) await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 2})) await asyncio.wait_for(task_a, timeout=1.0) await asyncio.wait_for(task_b, timeout=1.0) assert len(received_a) == 2 assert len(received_b) == 2 assert received_a[0].data == {"seq": 1} assert received_a[1].data == {"seq": 2} assert received_b[0].data == {"seq": 1} assert received_b[1].data == {"seq": 2} async def test_buffer_replay_for_new_subscriber(self): """测试事件缓冲对新订阅者的回放""" eq = EventQueue(buffer_size=100) # 先发送几条事件(无订阅者) await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1})) await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 2})) await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 3})) # 新订阅者应收到缓冲回放 received: list[Event] = [] async def subscriber(): async for evt in eq.subscribe(): received.append(evt) if len(received) >= 3: break sub_task = asyncio.create_task(subscriber()) await asyncio.wait_for(sub_task, timeout=1.0) assert len(received) == 3 assert received[0].data == {"seq": 1} assert received[1].data == {"seq": 2} assert received[2].data == {"seq": 3} async def test_buffer_replay_then_live_events(self): """测试新订阅者先收到回放,再收到新事件""" eq = EventQueue(buffer_size=100) # 先发送 2 条事件(进入缓冲) await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1})) await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 2})) received: list[Event] = [] async def subscriber(): async for evt in eq.subscribe(): received.append(evt) if len(received) >= 4: break sub_task = asyncio.create_task(subscriber()) await asyncio.sleep(0.05) # 给订阅者启动时间(回放缓冲) # 再发送 2 条新事件 await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 3})) await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 4})) await asyncio.wait_for(sub_task, timeout=1.0) assert len(received) == 4 assert [r.data["seq"] for r in received] == [1, 2, 3, 4] async def test_buffer_size_limit_keeps_latest(self): """测试缓冲区大小限制,只保留最新 N 条""" eq = EventQueue(buffer_size=3) # 发送 5 条事件,缓冲区只保留最后 3 条 for i in range(5): await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": i})) received: list[Event] = [] async def subscriber(): async for evt in eq.subscribe(): received.append(evt) if len(received) >= 3: break sub_task = asyncio.create_task(subscriber()) await asyncio.wait_for(sub_task, timeout=1.0) assert len(received) == 3 # 应该是最后 3 条(seq: 2, 3, 4) assert [r.data["seq"] for r in received] == [2, 3, 4] async def test_default_buffer_size_is_100(self): """测试默认缓冲区大小为 100""" eq = EventQueue() assert eq.buffer_size == 100 async def test_close_unblocks_subscribers(self): """测试 close 解除订阅者阻塞""" eq = EventQueue() async def subscriber(): async for _ in eq.subscribe(): pass # 消费事件直到队列关闭 sub_task = asyncio.create_task(subscriber()) await asyncio.sleep(0.05) eq.close() await asyncio.wait_for(sub_task, timeout=1.0) assert sub_task.done() assert eq.is_closed is True async def test_subscribe_after_close_returns_immediately(self): """测试关闭后订阅立即返回(不阻塞)""" eq = EventQueue() eq.close() received: list[Event] = [] async def subscriber(): async for evt in eq.subscribe(): received.append(evt) sub_task = asyncio.create_task(subscriber()) await asyncio.wait_for(sub_task, timeout=1.0) assert sub_task.done() assert len(received) == 0 async def test_subscriber_count_tracks_subscriptions(self): """测试订阅者计数正确跟踪订阅""" eq = EventQueue() assert eq.subscriber_count == 0 async def subscriber(): async for _ in eq.subscribe(): pass task = asyncio.create_task(subscriber()) await asyncio.sleep(0.05) assert eq.subscriber_count == 1 eq.close() await asyncio.wait_for(task, timeout=1.0) assert eq.subscriber_count == 0 async def test_subscriber_removed_on_explicit_close(self): """测试显式关闭订阅生成器后从列表移除 注意:async for 的 break 不会立即触发生成器的 finally, 需要显式调用 aclose() 才能保证清理。 """ eq = EventQueue() received: list[Event] = [] async def subscriber(): gen = eq.subscribe() try: async for evt in gen: received.append(evt) break finally: await gen.aclose() task = asyncio.create_task(subscriber()) await asyncio.sleep(0.05) assert eq.subscriber_count == 1 # 触发一次 emit 让订阅者能收到事件并 break await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1})) await asyncio.wait_for(task, timeout=1.0) assert len(received) == 1 assert eq.subscriber_count == 0 async def test_emit_to_no_subscribers_still_buffers(self): """测试无订阅者时 emit 仍写入缓冲区""" eq = EventQueue(buffer_size=100) await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1})) # 缓冲区应有 1 条 assert len(eq._buffer) == 1 # 新订阅者应收到回放 received: list[Event] = [] async def subscriber(): async for evt in eq.subscribe(): received.append(evt) break sub_task = asyncio.create_task(subscriber()) await asyncio.wait_for(sub_task, timeout=1.0) assert len(received) == 1 assert received[0].data == {"seq": 1} # ── Event Type Tests ──────────────────────────────────────────── class TestEventTypes: """事件类型分类测试""" def test_session_event_types(self): """测试 Session 级别事件类型""" assert SessionEventType.SESSION_STARTED == "session.started" assert SessionEventType.SESSION_ENDED == "session.ended" def test_task_event_types(self): """测试 Task 级别事件类型""" assert TaskEventType.TASK_CREATED == "task.created" assert TaskEventType.TASK_STARTED == "task.started" assert TaskEventType.TASK_COMPLETED == "task.completed" assert TaskEventType.TASK_FAILED == "task.failed" def test_turn_event_types(self): """测试 Turn 级别事件类型""" assert TurnEventType.TURN_STARTED == "turn.started" assert TurnEventType.THINKING == "turn.thinking" assert TurnEventType.TOOL_CALL == "turn.tool_call" assert TurnEventType.TOOL_RESULT == "turn.tool_result" assert TurnEventType.TOKEN == "turn.token" assert TurnEventType.STEP == "turn.step" assert TurnEventType.FINAL_ANSWER == "turn.final_answer" assert TurnEventType.TURN_COMPLETED == "turn.completed" def test_event_type_prefixes(self): """测试事件类型按前缀正确分类""" session_events = [SessionEventType.SESSION_STARTED, SessionEventType.SESSION_ENDED] task_events = [ TaskEventType.TASK_CREATED, TaskEventType.TASK_STARTED, TaskEventType.TASK_COMPLETED, TaskEventType.TASK_FAILED, ] turn_events = [ TurnEventType.TURN_STARTED, TurnEventType.THINKING, TurnEventType.TOOL_CALL, TurnEventType.TOOL_RESULT, TurnEventType.TOKEN, TurnEventType.STEP, TurnEventType.FINAL_ANSWER, TurnEventType.TURN_COMPLETED, ] for evt in session_events: assert evt.startswith("session."), f"{evt} should start with 'session.'" for evt in task_events: assert evt.startswith("task."), f"{evt} should start with 'task.'" for evt in turn_events: assert evt.startswith("turn."), f"{evt} should start with 'turn.'" def test_event_types_are_distinct(self): """测试所有事件类型互不相同""" all_types = [ SessionEventType.SESSION_STARTED, SessionEventType.SESSION_ENDED, TaskEventType.TASK_CREATED, TaskEventType.TASK_STARTED, TaskEventType.TASK_COMPLETED, TaskEventType.TASK_FAILED, TurnEventType.TURN_STARTED, TurnEventType.THINKING, TurnEventType.TOOL_CALL, TurnEventType.TOOL_RESULT, TurnEventType.TOKEN, TurnEventType.STEP, TurnEventType.FINAL_ANSWER, TurnEventType.TURN_COMPLETED, ] assert len(all_types) == len(set(all_types)), "Event types should be distinct" # ── Event Dataclass Tests ─────────────────────────────────────── class TestEventDataclass: """Event 数据结构测试""" def test_event_creation(self): """测试 Event 创建""" event = Event( event_type=TurnEventType.TOKEN, task_id="task-1", session_id="session-1", data={"text": "hello"}, timestamp="2025-01-01T00:00:00+00:00", ) assert event.event_type == TurnEventType.TOKEN assert event.task_id == "task-1" assert event.session_id == "session-1" assert event.data == {"text": "hello"} assert event.timestamp == "2025-01-01T00:00:00+00:00" def test_event_create_factory_generates_timestamp(self): """测试 Event.create 工厂方法自动生成时间戳""" event = Event.create( event_type=TaskEventType.TASK_STARTED, task_id="task-1", session_id="session-1", data={"agent": "react"}, ) assert event.event_type == TaskEventType.TASK_STARTED assert event.task_id == "task-1" assert event.session_id == "session-1" assert event.data == {"agent": "react"} assert len(event.timestamp) > 0 # 时间戳应为 ISO 8601 格式(fromisoformat 不抛异常即正确) datetime.fromisoformat(event.timestamp) def test_event_create_with_default_data(self): """测试 Event.create 不传 data 时默认为空 dict""" event = Event.create( event_type=SessionEventType.SESSION_STARTED, task_id="task-1", session_id="session-1", ) assert event.data == {} def test_event_to_dict(self): """测试 Event.to_dict""" event = Event( event_type=TurnEventType.TOKEN, task_id="task-1", session_id="session-1", data={"text": "hello"}, timestamp="2025-01-01T00:00:00+00:00", ) d = event.to_dict() assert d == { "event_type": TurnEventType.TOKEN, "task_id": "task-1", "session_id": "session-1", "data": {"text": "hello"}, "timestamp": "2025-01-01T00:00:00+00:00", } def test_event_from_dict(self): """测试 Event.from_dict""" data = { "event_type": TurnEventType.TOKEN, "task_id": "task-1", "session_id": "session-1", "data": {"text": "hello"}, "timestamp": "2025-01-01T00:00:00+00:00", } event = Event.from_dict(data) assert event.event_type == TurnEventType.TOKEN assert event.task_id == "task-1" assert event.session_id == "session-1" assert event.data == {"text": "hello"} assert event.timestamp == "2025-01-01T00:00:00+00:00" def test_event_to_dict_from_dict_roundtrip(self): """测试 to_dict -> from_dict 往返保持一致""" original = Event.create( event_type=SessionEventType.SESSION_STARTED, task_id="task-1", session_id="session-1", data={"user": "alice"}, ) d = original.to_dict() restored = Event.from_dict(d) assert restored.event_type == original.event_type assert restored.task_id == original.task_id assert restored.session_id == original.session_id assert restored.data == original.data assert restored.timestamp == original.timestamp def test_event_from_dict_with_missing_data_defaults_to_empty(self): """测试 from_dict 缺少 data 字段时默认为空 dict""" data = { "event_type": TurnEventType.TOKEN, "task_id": "task-1", "session_id": "session-1", "timestamp": "2025-01-01T00:00:00+00:00", } event = Event.from_dict(data) assert event.data == {}