fischer-agentkit/src/agentkit/core/event_queue.py

281 lines
9.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""统一事件模型 - SQ/EQ 双队列实现
参考 Codex 的 Session/Task/Turn 三级模型,提供:
- SubmissionQueue (SQ): 接收用户输入,返回 task_id
- EventQueue (EQ): 推送 Agent 事件,支持多订阅者广播和事件缓冲回放
集成点Phase 4 再做实际集成):
- Portal WebSocket 可通过 EventQueue.subscribe() 订阅事件流并推送给前端
- CLI 可通过 EventQueue.subscribe() 订阅事件流并打印
"""
from __future__ import annotations
import asyncio
import logging
import uuid
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import AsyncIterator
from agentkit.core.protocol import Event
logger = logging.getLogger(__name__)
# 哨兵对象:用于通知订阅者队列已关闭(基于身份比较,不参与事件流)
_CLOSED_SENTINEL: Event = Event(
event_type="__closed__",
task_id="",
session_id="",
data={},
timestamp="",
)
@dataclass
class _Subscriber:
"""Internal subscriber tracking with optional task_id filter."""
queue: asyncio.Queue[Event]
task_id_filter: str | None = None # None = receive all events
def matches(self, event: Event) -> bool:
"""Check if this subscriber should receive the given event."""
if self.task_id_filter is None:
return True
return event.task_id == self.task_id_filter
@dataclass
class Submission:
"""用户提交的任务
由 SubmissionQueue.submit() 创建,消费者通过 SubmissionQueue.drain() 获取。
Attributes:
task_id: 唯一任务 ID
session_id: 会话 ID
content: 用户输入内容
created_at: 创建时间
cancelled: 是否已取消
"""
task_id: str
session_id: str
content: str
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
cancelled: bool = False
class SubmissionQueue:
"""提交队列 (SQ) - 接收用户输入
用户通过 submit() 提交输入,消费者通过 drain() 获取提交。
支持通过 task_id 取消提交。
内部使用 asyncio.Queue 实现,队列大小上限 1024与 HandoffTransport 一致)。
"""
_MAX_QUEUE_SIZE: int = 1024
def __init__(self) -> None:
self._queue: asyncio.Queue[Submission] = asyncio.Queue(maxsize=self._MAX_QUEUE_SIZE)
self._submissions: dict[str, Submission] = {}
self._cancelled_tasks: set[str] = set()
self._closed: bool = False
async def submit(self, content: str, session_id: str) -> str:
"""提交用户输入,返回 task_id
Args:
content: 用户输入内容
session_id: 会话 ID
Returns:
新生成的 task_idUUID4
Raises:
RuntimeError: 队列已关闭
"""
if self._closed:
raise RuntimeError("SubmissionQueue is closed")
task_id = str(uuid.uuid4())
submission = Submission(
task_id=task_id,
session_id=session_id,
content=content,
)
self._submissions[task_id] = submission
await self._queue.put(submission)
return task_id
async def drain(self) -> AsyncIterator[Submission]:
"""消费提交队列(异步生成器)
已取消的提交会被跳过。当队列关闭时,生成器结束。
"""
while True:
submission = await self._queue.get()
if submission.task_id in self._cancelled_tasks:
continue
yield submission
async def cancel(self, task_id: str) -> bool:
"""取消任务
Args:
task_id: 要取消的任务 ID
Returns:
是否成功取消(任务存在且未取消过)
"""
if task_id not in self._submissions:
return False
if task_id in self._cancelled_tasks:
return False
self._cancelled_tasks.add(task_id)
self._submissions[task_id].cancelled = True
return True
@property
def is_closed(self) -> bool:
"""返回队列是否已关闭"""
return self._closed
def close(self) -> None:
"""关闭队列,不再接受新提交。
已在队列中的提交不受影响,消费者仍可通过 drain() 获取。
"""
self._closed = True
class EventQueue:
"""事件队列 (EQ) - 推送 Agent 事件
支持多订阅者广播模式:每条事件会投递到所有活跃订阅者。
新订阅者会收到最近 N 条缓冲事件的回放(默认 100 条)。
集成点:
- Portal WebSocket 可通过 subscribe() 订阅事件流并推送给前端
- CLI 可通过 subscribe() 订阅事件流并打印
"""
_MAX_QUEUE_SIZE: int = 1024
_DEFAULT_BUFFER_SIZE: int = 100
# P1 #13 fix: cap total subscribers to prevent resource exhaustion
# from malicious resume floods or runaway client loops.
_MAX_SUBSCRIBERS: int = 1000
def __init__(self, buffer_size: int = _DEFAULT_BUFFER_SIZE) -> None:
self._subscribers: list[_Subscriber] = []
self._buffer: deque[Event] = deque(maxlen=buffer_size)
self._buffer_size = buffer_size
self._closed: bool = False
async def emit(self, event: Event) -> None:
"""推送事件给所有订阅者
事件会同时写入缓冲区(供未来订阅者回放)和所有活跃订阅者队列。
如果某订阅者队列已满,该事件对该订阅者被丢弃(不影响其他订阅者)。
支持按 task_id 过滤:只有 task_id 匹配的订阅者才会收到事件。
Args:
event: 要推送的事件
"""
self._buffer.append(event)
for sub in self._subscribers:
if not sub.matches(event):
continue
try:
sub.queue.put_nowait(event)
except asyncio.QueueFull:
logger.warning("EventQueue subscriber queue full, dropping event")
async def subscribe(self, task_id: str | None = None) -> AsyncIterator[Event]:
"""订阅事件流(异步生成器)
订阅时会先回放缓冲区中的事件(按 task_id 过滤),然后持续接收新事件。
每个订阅者获得独立的队列,实现广播语义。
当队列关闭时,生成器结束。
Args:
task_id: 可选的任务 ID 过滤器。如果提供,只接收该任务的 events。
None 表示接收所有事件。
注意:回放和加入订阅者列表在同一同步段内完成(无 await
保证不会遗漏或重复事件。
"""
if self._closed:
return
# P1 #13 fix: enforce subscriber cap to prevent resource exhaustion
# from malicious resume floods or runaway client loops.
if len(self._subscribers) >= self._MAX_SUBSCRIBERS:
logger.error(
"EventQueue subscriber limit reached (%d), rejecting new subscription",
self._MAX_SUBSCRIBERS,
)
raise RuntimeError(f"EventQueue subscriber limit reached ({self._MAX_SUBSCRIBERS})")
queue: asyncio.Queue[Event] = asyncio.Queue(maxsize=self._MAX_QUEUE_SIZE)
# 回放缓冲事件(同步操作,无 await保证原子性
for event in list(self._buffer):
if task_id is not None and event.task_id != task_id:
continue
try:
queue.put_nowait(event)
except asyncio.QueueFull:
logger.warning("EventQueue replay buffer full, skipping remaining")
break
# 加入订阅者列表(在回放之后,确保不会收到重复事件)
sub = _Subscriber(queue=queue, task_id_filter=task_id)
self._subscribers.append(sub)
try:
while True:
event = await queue.get()
if event is _CLOSED_SENTINEL:
break
yield event
finally:
# 清理:移除当前订阅者
if sub in self._subscribers:
self._subscribers.remove(sub)
@property
def subscriber_count(self) -> int:
"""返回当前订阅者数量"""
return len(self._subscribers)
@property
def buffer_size(self) -> int:
"""返回缓冲区大小上限"""
return self._buffer_size
@property
def is_closed(self) -> bool:
"""返回队列是否已关闭"""
return self._closed
def close(self) -> None:
"""关闭队列,向所有订阅者发送哨兵并清理状态。
使用哨兵模式(参考 HandoffTransport确保阻塞在 subscribe() 上的
订阅者能够优雅退出。
"""
self._closed = True
# 向所有活跃订阅者队列放入哨兵,使其能够优雅退出
for sub in self._subscribers:
try:
sub.queue.put_nowait(_CLOSED_SENTINEL)
except asyncio.QueueFull:
pass
self._subscribers.clear()
self._buffer.clear()