281 lines
9.0 KiB
Python
281 lines
9.0 KiB
Python
"""统一事件模型 - SQ/EQ 双队列实现
|
||
|
||
参考 Codex 的 Session/Task/Turn 三级模型,提供:
|
||
- SubmissionQueue (SQ): 接收用户输入,返回 task_id
|
||
- EventQueue (EQ): 推送 Agent 事件,支持多订阅者广播和事件缓冲回放
|
||
|
||
集成点(Phase 4 再做实际集成):
|
||
- Portal WebSocket 可通过 EventQueue.subscribe() 订阅事件流并推送给前端
|
||
- CLI 可通过 EventQueue.subscribe() 订阅事件流并打印
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import logging
|
||
import uuid
|
||
from collections import deque
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime, timezone
|
||
from typing import AsyncIterator
|
||
|
||
from agentkit.core.protocol import Event
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# 哨兵对象:用于通知订阅者队列已关闭(基于身份比较,不参与事件流)
|
||
_CLOSED_SENTINEL: Event = Event(
|
||
event_type="__closed__",
|
||
task_id="",
|
||
session_id="",
|
||
data={},
|
||
timestamp="",
|
||
)
|
||
|
||
|
||
@dataclass
|
||
class _Subscriber:
|
||
"""Internal subscriber tracking with optional task_id filter."""
|
||
|
||
queue: asyncio.Queue[Event]
|
||
task_id_filter: str | None = None # None = receive all events
|
||
|
||
def matches(self, event: Event) -> bool:
|
||
"""Check if this subscriber should receive the given event."""
|
||
if self.task_id_filter is None:
|
||
return True
|
||
return event.task_id == self.task_id_filter
|
||
|
||
|
||
@dataclass
|
||
class Submission:
|
||
"""用户提交的任务
|
||
|
||
由 SubmissionQueue.submit() 创建,消费者通过 SubmissionQueue.drain() 获取。
|
||
|
||
Attributes:
|
||
task_id: 唯一任务 ID
|
||
session_id: 会话 ID
|
||
content: 用户输入内容
|
||
created_at: 创建时间
|
||
cancelled: 是否已取消
|
||
"""
|
||
|
||
task_id: str
|
||
session_id: str
|
||
content: str
|
||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||
cancelled: bool = False
|
||
|
||
|
||
class SubmissionQueue:
|
||
"""提交队列 (SQ) - 接收用户输入
|
||
|
||
用户通过 submit() 提交输入,消费者通过 drain() 获取提交。
|
||
支持通过 task_id 取消提交。
|
||
|
||
内部使用 asyncio.Queue 实现,队列大小上限 1024(与 HandoffTransport 一致)。
|
||
"""
|
||
|
||
_MAX_QUEUE_SIZE: int = 1024
|
||
|
||
def __init__(self) -> None:
|
||
self._queue: asyncio.Queue[Submission] = asyncio.Queue(maxsize=self._MAX_QUEUE_SIZE)
|
||
self._submissions: dict[str, Submission] = {}
|
||
self._cancelled_tasks: set[str] = set()
|
||
self._closed: bool = False
|
||
|
||
async def submit(self, content: str, session_id: str) -> str:
|
||
"""提交用户输入,返回 task_id
|
||
|
||
Args:
|
||
content: 用户输入内容
|
||
session_id: 会话 ID
|
||
|
||
Returns:
|
||
新生成的 task_id(UUID4)
|
||
|
||
Raises:
|
||
RuntimeError: 队列已关闭
|
||
"""
|
||
if self._closed:
|
||
raise RuntimeError("SubmissionQueue is closed")
|
||
task_id = str(uuid.uuid4())
|
||
submission = Submission(
|
||
task_id=task_id,
|
||
session_id=session_id,
|
||
content=content,
|
||
)
|
||
self._submissions[task_id] = submission
|
||
await self._queue.put(submission)
|
||
return task_id
|
||
|
||
async def drain(self) -> AsyncIterator[Submission]:
|
||
"""消费提交队列(异步生成器)
|
||
|
||
已取消的提交会被跳过。当队列关闭时,生成器结束。
|
||
"""
|
||
while True:
|
||
submission = await self._queue.get()
|
||
if submission.task_id in self._cancelled_tasks:
|
||
continue
|
||
yield submission
|
||
|
||
async def cancel(self, task_id: str) -> bool:
|
||
"""取消任务
|
||
|
||
Args:
|
||
task_id: 要取消的任务 ID
|
||
|
||
Returns:
|
||
是否成功取消(任务存在且未取消过)
|
||
"""
|
||
if task_id not in self._submissions:
|
||
return False
|
||
if task_id in self._cancelled_tasks:
|
||
return False
|
||
self._cancelled_tasks.add(task_id)
|
||
self._submissions[task_id].cancelled = True
|
||
return True
|
||
|
||
@property
|
||
def is_closed(self) -> bool:
|
||
"""返回队列是否已关闭"""
|
||
return self._closed
|
||
|
||
def close(self) -> None:
|
||
"""关闭队列,不再接受新提交。
|
||
|
||
已在队列中的提交不受影响,消费者仍可通过 drain() 获取。
|
||
"""
|
||
self._closed = True
|
||
|
||
|
||
class EventQueue:
|
||
"""事件队列 (EQ) - 推送 Agent 事件
|
||
|
||
支持多订阅者广播模式:每条事件会投递到所有活跃订阅者。
|
||
新订阅者会收到最近 N 条缓冲事件的回放(默认 100 条)。
|
||
|
||
集成点:
|
||
- Portal WebSocket 可通过 subscribe() 订阅事件流并推送给前端
|
||
- CLI 可通过 subscribe() 订阅事件流并打印
|
||
"""
|
||
|
||
_MAX_QUEUE_SIZE: int = 1024
|
||
_DEFAULT_BUFFER_SIZE: int = 100
|
||
# P1 #13 fix: cap total subscribers to prevent resource exhaustion
|
||
# from malicious resume floods or runaway client loops.
|
||
_MAX_SUBSCRIBERS: int = 1000
|
||
|
||
def __init__(self, buffer_size: int = _DEFAULT_BUFFER_SIZE) -> None:
|
||
self._subscribers: list[_Subscriber] = []
|
||
self._buffer: deque[Event] = deque(maxlen=buffer_size)
|
||
self._buffer_size = buffer_size
|
||
self._closed: bool = False
|
||
|
||
async def emit(self, event: Event) -> None:
|
||
"""推送事件给所有订阅者
|
||
|
||
事件会同时写入缓冲区(供未来订阅者回放)和所有活跃订阅者队列。
|
||
如果某订阅者队列已满,该事件对该订阅者被丢弃(不影响其他订阅者)。
|
||
支持按 task_id 过滤:只有 task_id 匹配的订阅者才会收到事件。
|
||
|
||
Args:
|
||
event: 要推送的事件
|
||
"""
|
||
self._buffer.append(event)
|
||
for sub in self._subscribers:
|
||
if not sub.matches(event):
|
||
continue
|
||
try:
|
||
sub.queue.put_nowait(event)
|
||
except asyncio.QueueFull:
|
||
logger.warning("EventQueue subscriber queue full, dropping event")
|
||
|
||
async def subscribe(self, task_id: str | None = None) -> AsyncIterator[Event]:
|
||
"""订阅事件流(异步生成器)
|
||
|
||
订阅时会先回放缓冲区中的事件(按 task_id 过滤),然后持续接收新事件。
|
||
每个订阅者获得独立的队列,实现广播语义。
|
||
|
||
当队列关闭时,生成器结束。
|
||
|
||
Args:
|
||
task_id: 可选的任务 ID 过滤器。如果提供,只接收该任务的 events。
|
||
None 表示接收所有事件。
|
||
|
||
注意:回放和加入订阅者列表在同一同步段内完成(无 await),
|
||
保证不会遗漏或重复事件。
|
||
"""
|
||
if self._closed:
|
||
return
|
||
|
||
# P1 #13 fix: enforce subscriber cap to prevent resource exhaustion
|
||
# from malicious resume floods or runaway client loops.
|
||
if len(self._subscribers) >= self._MAX_SUBSCRIBERS:
|
||
logger.error(
|
||
"EventQueue subscriber limit reached (%d), rejecting new subscription",
|
||
self._MAX_SUBSCRIBERS,
|
||
)
|
||
raise RuntimeError(f"EventQueue subscriber limit reached ({self._MAX_SUBSCRIBERS})")
|
||
|
||
queue: asyncio.Queue[Event] = asyncio.Queue(maxsize=self._MAX_QUEUE_SIZE)
|
||
|
||
# 回放缓冲事件(同步操作,无 await,保证原子性)
|
||
for event in list(self._buffer):
|
||
if task_id is not None and event.task_id != task_id:
|
||
continue
|
||
try:
|
||
queue.put_nowait(event)
|
||
except asyncio.QueueFull:
|
||
logger.warning("EventQueue replay buffer full, skipping remaining")
|
||
break
|
||
|
||
# 加入订阅者列表(在回放之后,确保不会收到重复事件)
|
||
sub = _Subscriber(queue=queue, task_id_filter=task_id)
|
||
self._subscribers.append(sub)
|
||
|
||
try:
|
||
while True:
|
||
event = await queue.get()
|
||
if event is _CLOSED_SENTINEL:
|
||
break
|
||
yield event
|
||
finally:
|
||
# 清理:移除当前订阅者
|
||
if sub in self._subscribers:
|
||
self._subscribers.remove(sub)
|
||
|
||
@property
|
||
def subscriber_count(self) -> int:
|
||
"""返回当前订阅者数量"""
|
||
return len(self._subscribers)
|
||
|
||
@property
|
||
def buffer_size(self) -> int:
|
||
"""返回缓冲区大小上限"""
|
||
return self._buffer_size
|
||
|
||
@property
|
||
def is_closed(self) -> bool:
|
||
"""返回队列是否已关闭"""
|
||
return self._closed
|
||
|
||
def close(self) -> None:
|
||
"""关闭队列,向所有订阅者发送哨兵并清理状态。
|
||
|
||
使用哨兵模式(参考 HandoffTransport),确保阻塞在 subscribe() 上的
|
||
订阅者能够优雅退出。
|
||
"""
|
||
self._closed = True
|
||
# 向所有活跃订阅者队列放入哨兵,使其能够优雅退出
|
||
for sub in self._subscribers:
|
||
try:
|
||
sub.queue.put_nowait(_CLOSED_SENTINEL)
|
||
except asyncio.QueueFull:
|
||
pass
|
||
self._subscribers.clear()
|
||
self._buffer.clear()
|