fischer-agentkit/tests/unit/core/test_event_queue.py

667 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Tests for EventQueue — SQ/EQ 双队列实现
测试场景:
- SQ 正确接收用户输入并返回 task_id
- EQ 正确推送事件给订阅者
- 多订阅者同时接收事件(广播)
- 事件缓冲对新订阅者的回放
- SQ 取消任务
- 事件类型正确分类
"""
from __future__ import annotations
import asyncio
from datetime import datetime
from agentkit.core.event_queue import EventQueue, Submission, SubmissionQueue
from agentkit.core.protocol import (
Event,
SessionEventType,
TaskEventType,
TurnEventType,
)
# ── SubmissionQueue Tests ───────────────────────────────────────
class TestSubmissionQueue:
    """Unit tests for SubmissionQueue (SQ): submit, drain, cancel and close."""

    async def test_submit_returns_task_id(self):
        """submit() returns a non-empty string task_id."""
        sq = SubmissionQueue()
        task_id = await sq.submit("hello", "session-1")
        assert isinstance(task_id, str)
        assert len(task_id) > 0

    async def test_submit_returns_unique_task_ids(self):
        """Each submit() call returns a different task_id."""
        sq = SubmissionQueue()
        task_id_1 = await sq.submit("hello", "session-1")
        task_id_2 = await sq.submit("world", "session-1")
        assert task_id_1 != task_id_2

    async def test_submit_stores_submission(self):
        """submit() records the Submission under its task_id.

        Inspects the private ``_submissions`` mapping directly.
        """
        sq = SubmissionQueue()
        task_id = await sq.submit("hello world", "session-1")
        assert task_id in sq._submissions
        submission = sq._submissions[task_id]
        assert submission.content == "hello world"
        assert submission.session_id == "session-1"
        assert submission.task_id == task_id
        assert submission.cancelled is False

    async def test_drain_receives_submissions_in_order(self):
        """drain() yields submissions in the order they were submitted."""
        sq = SubmissionQueue()
        await sq.submit("first", "session-1")
        await sq.submit("second", "session-1")
        received: list[str] = []

        async def consumer():
            async for submission in sq.drain():
                received.append(submission.content)
                if len(received) >= 2:
                    break

        # wait_for wraps the coroutine in a task and bounds it, so a drain()
        # that never yields fails the test instead of hanging it.
        await asyncio.wait_for(consumer(), timeout=1.0)
        assert received == ["first", "second"]

    async def test_drain_preserves_submission_fields(self):
        """Submissions yielded by drain() carry all of their fields."""
        sq = SubmissionQueue()
        await sq.submit("hello", "session-1")
        received: list[Submission] = []

        async def consumer():
            async for submission in sq.drain():
                received.append(submission)
                break

        await asyncio.wait_for(consumer(), timeout=1.0)
        assert len(received) == 1
        sub = received[0]
        assert sub.content == "hello"
        assert sub.session_id == "session-1"
        assert isinstance(sub.task_id, str)
        assert isinstance(sub.created_at, datetime)

    async def test_cancel_task_succeeds(self):
        """cancel() on a known task returns True and marks it cancelled."""
        sq = SubmissionQueue()
        task_id = await sq.submit("hello", "session-1")
        result = await sq.cancel(task_id)
        assert result is True
        assert task_id in sq._cancelled_tasks
        assert sq._submissions[task_id].cancelled is True

    async def test_cancel_nonexistent_task_returns_false(self):
        """cancel() on an unknown task_id returns False."""
        sq = SubmissionQueue()
        result = await sq.cancel("nonexistent-task-id")
        assert result is False

    async def test_cancel_already_cancelled_task_returns_false(self):
        """A second cancel() of the same task returns False."""
        sq = SubmissionQueue()
        task_id = await sq.submit("hello", "session-1")
        first_cancel = await sq.cancel(task_id)
        second_cancel = await sq.cancel(task_id)
        assert first_cancel is True
        assert second_cancel is False

    async def test_drain_skips_cancelled_submissions(self):
        """drain() does not yield submissions that were cancelled."""
        sq = SubmissionQueue()
        task_id_1 = await sq.submit("first", "session-1")
        await sq.submit("second", "session-1")
        # Cancel the first submission before anything is drained.
        await sq.cancel(task_id_1)
        received: list[str] = []

        async def consumer():
            async for submission in sq.drain():
                received.append(submission.content)
                if len(received) >= 1:
                    break

        await asyncio.wait_for(consumer(), timeout=1.0)
        # Only the second (non-cancelled) submission should be delivered.
        assert received == ["second"]

    async def test_close_prevents_new_submissions(self):
        """After close(), submit() raises RuntimeError."""
        sq = SubmissionQueue()
        sq.close()
        assert sq.is_closed is True
        # Keep only the call that is expected to raise inside ``try``; the
        # "did not raise" failure lives in ``else`` so it can never be
        # confused with the exception under test.
        try:
            await sq.submit("hello", "session-1")
        except RuntimeError:
            pass  # expected: a closed queue rejects new submissions
        else:
            raise AssertionError("Should have raised RuntimeError")

    async def test_close_does_not_affect_existing_submissions(self):
        """Submissions made before close() can still be drained."""
        sq = SubmissionQueue()
        await sq.submit("before-close", "session-1")
        sq.close()
        received: list[str] = []

        async def consumer():
            async for submission in sq.drain():
                received.append(submission.content)
                break

        await asyncio.wait_for(consumer(), timeout=1.0)
        assert received == ["before-close"]
# ── EventQueue Tests ────────────────────────────────────────────
class TestEventQueue:
    """Unit tests for EventQueue (EQ): broadcast, buffer replay and close.

    Tests that need a subscriber running before they emit or close wait on
    ``subscriber_count`` (bounded by a timeout) instead of sleeping for a
    fixed delay, so they do not depend on scheduler timing.
    """

    @staticmethod
    async def _wait_for_subscribers(
        eq: EventQueue, count: int = 1, timeout: float = 1.0
    ) -> None:
        """Wait until ``eq.subscriber_count`` is at least ``count``.

        Raises ``asyncio.TimeoutError`` if that does not happen within
        ``timeout`` seconds, rather than relying on a fixed delay.
        """

        async def _poll() -> None:
            while eq.subscriber_count < count:
                await asyncio.sleep(0)  # yield so subscriber tasks can run

        await asyncio.wait_for(_poll(), timeout=timeout)

    async def test_emit_and_subscribe_single_event(self):
        """EQ delivers an emitted event to a live subscriber."""
        eq = EventQueue()
        event = Event.create(
            event_type=TurnEventType.TOKEN,
            task_id="task-1",
            session_id="session-1",
            data={"text": "hello"},
        )
        received: list[Event] = []

        async def subscriber():
            async for evt in eq.subscribe():
                received.append(evt)
                break

        sub_task = asyncio.create_task(subscriber())
        await self._wait_for_subscribers(eq)  # subscriber registered before emit
        await eq.emit(event)
        await asyncio.wait_for(sub_task, timeout=1.0)

        assert len(received) == 1
        assert received[0].event_type == TurnEventType.TOKEN
        assert received[0].task_id == "task-1"
        assert received[0].session_id == "session-1"
        assert received[0].data == {"text": "hello"}

    async def test_emit_preserves_event_fields(self):
        """emit() does not alter the event's fields."""
        eq = EventQueue()
        original = Event.create(
            event_type=TaskEventType.TASK_STARTED,
            task_id="task-1",
            session_id="session-1",
            data={"agent": "react"},
        )
        received: list[Event] = []

        async def subscriber():
            async for evt in eq.subscribe():
                received.append(evt)
                break

        sub_task = asyncio.create_task(subscriber())
        await self._wait_for_subscribers(eq)
        await eq.emit(original)
        await asyncio.wait_for(sub_task, timeout=1.0)

        assert received[0] is original or received[0].to_dict() == original.to_dict()

    async def test_broadcast_to_multiple_subscribers(self):
        """Every subscriber receives every event (broadcast)."""
        eq = EventQueue()
        received_a: list[Event] = []
        received_b: list[Event] = []

        async def subscriber_a():
            async for evt in eq.subscribe():
                received_a.append(evt)
                if len(received_a) >= 2:
                    break

        async def subscriber_b():
            async for evt in eq.subscribe():
                received_b.append(evt)
                if len(received_b) >= 2:
                    break

        task_a = asyncio.create_task(subscriber_a())
        task_b = asyncio.create_task(subscriber_b())
        await self._wait_for_subscribers(eq, 2)  # both registered before emit
        await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1}))
        await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 2}))
        await asyncio.wait_for(task_a, timeout=1.0)
        await asyncio.wait_for(task_b, timeout=1.0)

        assert len(received_a) == 2
        assert len(received_b) == 2
        assert received_a[0].data == {"seq": 1}
        assert received_a[1].data == {"seq": 2}
        assert received_b[0].data == {"seq": 1}
        assert received_b[1].data == {"seq": 2}

    async def test_buffer_replay_for_new_subscriber(self):
        """A new subscriber first receives the buffered events."""
        eq = EventQueue(buffer_size=100)
        # Emit several events while nobody is subscribed.
        await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1}))
        await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 2}))
        await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 3}))

        # The new subscriber should get them via buffer replay.
        received: list[Event] = []

        async def subscriber():
            async for evt in eq.subscribe():
                received.append(evt)
                if len(received) >= 3:
                    break

        await asyncio.wait_for(subscriber(), timeout=1.0)

        assert len(received) == 3
        assert received[0].data == {"seq": 1}
        assert received[1].data == {"seq": 2}
        assert received[2].data == {"seq": 3}

    async def test_buffer_replay_then_live_events(self):
        """A new subscriber gets the replay first, then newly emitted events."""
        eq = EventQueue(buffer_size=100)
        # Two events go into the buffer before anyone subscribes.
        await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1}))
        await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 2}))
        received: list[Event] = []

        async def subscriber():
            async for evt in eq.subscribe():
                received.append(evt)
                if len(received) >= 4:
                    break

        sub_task = asyncio.create_task(subscriber())
        await self._wait_for_subscribers(eq)  # subscribed (replay started)
        # Then emit two live events.
        await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 3}))
        await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 4}))
        await asyncio.wait_for(sub_task, timeout=1.0)

        assert len(received) == 4
        assert [r.data["seq"] for r in received] == [1, 2, 3, 4]

    async def test_buffer_size_limit_keeps_latest(self):
        """The buffer is bounded and keeps only the newest N events."""
        eq = EventQueue(buffer_size=3)
        # Emit 5 events; only the last 3 should remain in the buffer.
        for i in range(5):
            await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": i}))
        received: list[Event] = []

        async def subscriber():
            async for evt in eq.subscribe():
                received.append(evt)
                if len(received) >= 3:
                    break

        await asyncio.wait_for(subscriber(), timeout=1.0)

        assert len(received) == 3
        # Expect the last three (seq: 2, 3, 4).
        assert [r.data["seq"] for r in received] == [2, 3, 4]

    async def test_default_buffer_size_is_100(self):
        """The default buffer size is 100."""
        eq = EventQueue()
        assert eq.buffer_size == 100

    async def test_close_unblocks_subscribers(self):
        """close() ends the iteration of a subscriber that is waiting."""
        eq = EventQueue()

        async def subscriber():
            async for _ in eq.subscribe():
                pass  # consume events until the queue is closed

        sub_task = asyncio.create_task(subscriber())
        # Ensure the subscriber is actually registered/waiting before close();
        # otherwise this would only re-test the subscribe-after-close path.
        await self._wait_for_subscribers(eq)
        eq.close()
        await asyncio.wait_for(sub_task, timeout=1.0)

        assert sub_task.done()
        assert eq.is_closed is True

    async def test_subscribe_after_close_returns_immediately(self):
        """Subscribing to a closed queue finishes without blocking."""
        eq = EventQueue()
        eq.close()
        received: list[Event] = []

        async def subscriber():
            async for evt in eq.subscribe():
                received.append(evt)

        sub_task = asyncio.create_task(subscriber())
        await asyncio.wait_for(sub_task, timeout=1.0)

        assert sub_task.done()
        assert len(received) == 0

    async def test_subscriber_count_tracks_subscriptions(self):
        """subscriber_count rises on subscribe and drops back after close()."""
        eq = EventQueue()
        assert eq.subscriber_count == 0

        async def subscriber():
            async for _ in eq.subscribe():
                pass

        task = asyncio.create_task(subscriber())
        # Times out (fails) if the count never reaches 1.
        await self._wait_for_subscribers(eq)
        assert eq.subscriber_count == 1
        eq.close()
        await asyncio.wait_for(task, timeout=1.0)
        assert eq.subscriber_count == 0

    async def test_subscriber_removed_on_explicit_close(self):
        """Explicitly closing the subscribe generator removes the subscriber.

        Note: breaking out of ``async for`` does not immediately run the
        generator's ``finally``; an explicit ``aclose()`` is needed to
        guarantee cleanup.
        """
        eq = EventQueue()
        received: list[Event] = []

        async def subscriber():
            gen = eq.subscribe()
            try:
                async for evt in gen:
                    received.append(evt)
                    break
            finally:
                await gen.aclose()

        task = asyncio.create_task(subscriber())
        await self._wait_for_subscribers(eq)
        assert eq.subscriber_count == 1
        # Emit once so the subscriber receives an event and breaks.
        await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1}))
        await asyncio.wait_for(task, timeout=1.0)

        assert len(received) == 1
        assert eq.subscriber_count == 0

    async def test_emit_to_no_subscribers_still_buffers(self):
        """emit() with no subscribers still writes to the buffer."""
        eq = EventQueue(buffer_size=100)
        await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1}))
        # The (private) buffer should now hold exactly one event.
        assert len(eq._buffer) == 1

        # A new subscriber should receive it via replay.
        received: list[Event] = []

        async def subscriber():
            async for evt in eq.subscribe():
                received.append(evt)
                break

        await asyncio.wait_for(subscriber(), timeout=1.0)

        assert len(received) == 1
        assert received[0].data == {"seq": 1}
# ── Event Type Tests ────────────────────────────────────────────
class TestEventTypes:
    """Check event type string values and their session./task./turn. prefixes."""

    # Members grouped by level; shared by the prefix and distinctness checks.
    _SESSION_EVENTS = (
        SessionEventType.SESSION_STARTED,
        SessionEventType.SESSION_ENDED,
    )
    _TASK_EVENTS = (
        TaskEventType.TASK_CREATED,
        TaskEventType.TASK_STARTED,
        TaskEventType.TASK_COMPLETED,
        TaskEventType.TASK_FAILED,
    )
    _TURN_EVENTS = (
        TurnEventType.TURN_STARTED,
        TurnEventType.THINKING,
        TurnEventType.TOOL_CALL,
        TurnEventType.TOOL_RESULT,
        TurnEventType.TOKEN,
        TurnEventType.STEP,
        TurnEventType.FINAL_ANSWER,
        TurnEventType.TURN_COMPLETED,
    )

    def test_session_event_types(self):
        """Session-level members compare equal to their wire strings."""
        pairs = [
            (SessionEventType.SESSION_STARTED, "session.started"),
            (SessionEventType.SESSION_ENDED, "session.ended"),
        ]
        for member, wire in pairs:
            assert member == wire

    def test_task_event_types(self):
        """Task-level members compare equal to their wire strings."""
        pairs = [
            (TaskEventType.TASK_CREATED, "task.created"),
            (TaskEventType.TASK_STARTED, "task.started"),
            (TaskEventType.TASK_COMPLETED, "task.completed"),
            (TaskEventType.TASK_FAILED, "task.failed"),
        ]
        for member, wire in pairs:
            assert member == wire

    def test_turn_event_types(self):
        """Turn-level members compare equal to their wire strings."""
        pairs = [
            (TurnEventType.TURN_STARTED, "turn.started"),
            (TurnEventType.THINKING, "turn.thinking"),
            (TurnEventType.TOOL_CALL, "turn.tool_call"),
            (TurnEventType.TOOL_RESULT, "turn.tool_result"),
            (TurnEventType.TOKEN, "turn.token"),
            (TurnEventType.STEP, "turn.step"),
            (TurnEventType.FINAL_ANSWER, "turn.final_answer"),
            (TurnEventType.TURN_COMPLETED, "turn.completed"),
        ]
        for member, wire in pairs:
            assert member == wire

    def test_event_type_prefixes(self):
        """Each member's value starts with its level's namespace prefix."""
        for member in self._SESSION_EVENTS:
            assert member.startswith("session."), f"{member} should start with 'session.'"
        for member in self._TASK_EVENTS:
            assert member.startswith("task."), f"{member} should start with 'task.'"
        for member in self._TURN_EVENTS:
            assert member.startswith("turn."), f"{member} should start with 'turn.'"

    def test_event_types_are_distinct(self):
        """No two event types share a value across all levels."""
        every_type = (*self._SESSION_EVENTS, *self._TASK_EVENTS, *self._TURN_EVENTS)
        assert len(set(every_type)) == len(every_type), "Event types should be distinct"
# ── Event Dataclass Tests ───────────────────────────────────────
class TestEventDataclass:
    """Construction, the ``create`` factory, and dict round-tripping of ``Event``."""

    _FIXED_TS = "2025-01-01T00:00:00+00:00"

    @classmethod
    def _token_payload(cls) -> dict:
        """Return a fresh dict describing a TOKEN event with a fixed timestamp."""
        return {
            "event_type": TurnEventType.TOKEN,
            "task_id": "task-1",
            "session_id": "session-1",
            "data": {"text": "hello"},
            "timestamp": cls._FIXED_TS,
        }

    def test_event_creation(self):
        """Direct construction stores every field exactly as given."""
        event = Event(**self._token_payload())
        for field_name, expected in self._token_payload().items():
            assert getattr(event, field_name) == expected

    def test_event_create_factory_generates_timestamp(self):
        """Event.create fills in a non-empty ISO 8601 timestamp."""
        event = Event.create(
            event_type=TaskEventType.TASK_STARTED,
            task_id="task-1",
            session_id="session-1",
            data={"agent": "react"},
        )
        assert event.event_type == TaskEventType.TASK_STARTED
        assert event.task_id == "task-1"
        assert event.session_id == "session-1"
        assert event.data == {"agent": "react"}
        assert len(event.timestamp) > 0
        # Parsing succeeds only for ISO 8601 text; raising here fails the test.
        datetime.fromisoformat(event.timestamp)

    def test_event_create_with_default_data(self):
        """Event.create without ``data`` yields an empty dict."""
        event = Event.create(
            event_type=SessionEventType.SESSION_STARTED,
            task_id="task-1",
            session_id="session-1",
        )
        assert event.data == {}

    def test_event_to_dict(self):
        """to_dict() returns exactly the constructor fields."""
        event = Event(**self._token_payload())
        assert event.to_dict() == self._token_payload()

    def test_event_from_dict(self):
        """from_dict() restores every field from a plain dict."""
        event = Event.from_dict(self._token_payload())
        for field_name, expected in self._token_payload().items():
            assert getattr(event, field_name) == expected

    def test_event_to_dict_from_dict_roundtrip(self):
        """to_dict() followed by from_dict() preserves every field."""
        original = Event.create(
            event_type=SessionEventType.SESSION_STARTED,
            task_id="task-1",
            session_id="session-1",
            data={"user": "alice"},
        )
        restored = Event.from_dict(original.to_dict())
        for field_name in ("event_type", "task_id", "session_id", "data", "timestamp"):
            assert getattr(restored, field_name) == getattr(original, field_name)

    def test_event_from_dict_with_missing_data_defaults_to_empty(self):
        """from_dict() defaults ``data`` to an empty dict when the key is absent."""
        payload = self._token_payload()
        del payload["data"]
        event = Event.from_dict(payload)
        assert event.data == {}