220 lines
7.4 KiB
Python
220 lines
7.4 KiB
Python
"""Middleware pipeline architecture (U6)
|
||
|
||
洋葱模型中间件管道,将横切关注点(压缩/计量/安全/循环检测)集中化。
|
||
并行接入(KTD1):通过 feature flag 控制,与现有 ReActEngine 路径共存。
|
||
|
||
Usage::
|
||
|
||
from agentkit.core.middleware import (
|
||
MiddlewareChain,
|
||
RequestContext,
|
||
SummarizationMiddleware,
|
||
TokenUsageMiddleware,
|
||
)
|
||
|
||
chain = MiddlewareChain([
|
||
SummarizationMiddleware(compressor),
|
||
TokenUsageMiddleware(),
|
||
])
|
||
result = await chain.execute(ctx, handler)
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
from dataclasses import dataclass, field
|
||
from typing import Any, Awaitable, Callable, Protocol, runtime_checkable
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
class RequestContext:
    """Per-request context threaded through the middleware chain.

    Middlewares may both read and mutate this object; it is the vehicle for
    passing state from one middleware to the next (and on to the handler).
    """

    messages: list[dict[str, str]]
    tools: list[Any] = field(default_factory=list)
    system_prompt: str | None = field(default=None)
    model: str = field(default="default")
    agent_name: str = field(default="")
    task_type: str = field(default="")
    task_id: str | None = field(default=None)
    # State shared between middlewares (compression result, token usage,
    # loop-detection state, ...). Each instance gets its own fresh dict.
    metadata: dict[str, Any] = field(default_factory=dict)
|
||
|
||
|
||
@runtime_checkable
class Middleware(Protocol):
    """Middleware protocol: the before/after hooks of the onion model.

    Any object that provides ``before`` and ``after`` coroutine methods with
    these signatures satisfies the protocol structurally (no subclassing
    required).
    """

    async def before(self, ctx: RequestContext) -> RequestContext:
        """Hook invoked before the handler; may modify ``ctx`` (e.g. compress the conversation).

        Returns the context handed to the next middleware or to the handler.
        """

    async def after(self, ctx: RequestContext, result: Any) -> Any:
        """Hook invoked after the handler; may modify ``result`` (e.g. attach token usage).

        Returns the result handed outward to the next middleware.
        """
|
||
|
||
|
||
class MiddlewareChain:
    """Onion-model middleware chain.

    ``before`` hooks run outside-in (A -> B -> C), then the handler runs, then
    ``after`` hooks run inside-out (C -> B -> A).

    If a middleware's ``before`` raises, the remaining ``before`` hooks are
    skipped, no ``after`` hook runs, and the exception propagates to the
    caller. An exception raised by the handler likewise propagates without
    running any ``after`` hook. An exception raised by an ``after`` hook is
    logged and swallowed: the remaining ``after`` hooks still run with the
    last successfully produced result.
    """

    def __init__(self, middlewares: list[Middleware] | None = None) -> None:
        # Take a private copy so later mutation of the caller's list (or a
        # one-shot iterator) cannot silently change this chain; the
        # ``middlewares`` property below likewise hands out copies.
        self._middlewares: list[Middleware] = (
            list(middlewares) if middlewares is not None else []
        )

    @property
    def middlewares(self) -> list[Middleware]:
        """Return a shallow copy of the configured middlewares, in order."""
        return list(self._middlewares)

    async def execute(
        self,
        ctx: RequestContext,
        handler: Callable[[RequestContext], Awaitable[Any]],
    ) -> Any:
        """Run the middleware chain around ``handler``.

        Args:
            ctx: request context passed to the first ``before`` hook.
            handler: coroutine function receiving the context returned by the
                last ``before`` hook.

        Returns:
            The handler's result after it has been threaded through every
            ``after`` hook in reverse order.

        Raises:
            Whatever a ``before`` hook or the handler raises.
        """
        if not self._middlewares:
            return await handler(ctx)

        # before: outer -> inner. An exception here propagates naturally and
        # the after chain is not run.
        executed_befores: list[Middleware] = []
        current_ctx = ctx
        for mw in self._middlewares:
            current_ctx = await mw.before(current_ctx)
            executed_befores.append(mw)

        result = await handler(current_ctx)

        # after: inner -> outer. Failures are best-effort: log (with the
        # traceback) and keep going with the current result.
        for mw in reversed(executed_befores):
            try:
                result = await mw.after(current_ctx, result)
            except Exception as e:
                logger.warning(
                    "Middleware %s.after() failed: %s — continuing with current result",
                    type(mw).__name__,
                    e,
                    exc_info=True,
                )

        return result
|
||
|
||
|
||
# ── 示例中间件 ──────────────────────────────────────────────
|
||
|
||
|
||
class SummarizationMiddleware:
    """Compression middleware wrapping ContextCompressor's headroom compression (U3).

    before: if the compressor's ``should_compress(messages)`` says so, replace
        ``ctx.messages`` with ``await compressor.compress(messages)`` and set
        ``ctx.metadata["compressed"] = True``.
    after: no-op (all work happens in ``before``).

    Compression is best-effort: a failure in either the decision hook or the
    compression itself is logged and the context passes through unchanged.
    """

    def __init__(self, compressor: Any = None) -> None:
        # NOTE(review): presumably a ContextCompressor-like object exposing an
        # optional ``should_compress(messages)`` (sync or async) and an
        # awaitable ``compress(messages)`` — confirm against the compressor.
        self._compressor = compressor

    async def before(self, ctx: RequestContext) -> RequestContext:
        """Compress ``ctx.messages`` in place when the compressor asks for it.

        Returns the same ``ctx`` object, possibly with new ``messages``.
        """
        import inspect

        compressor = self._compressor
        if compressor is None:
            return ctx

        should_compress_fn = getattr(compressor, "should_compress", None)
        if should_compress_fn is None:
            # Without a decision hook we never compress.
            return ctx

        try:
            decision = should_compress_fn(ctx.messages)
            # An async decision hook returns a coroutine, which is always
            # truthy; await it so the real answer is used.
            if inspect.isawaitable(decision):
                decision = await decision
            if not decision:
                return ctx
            compressed = await compressor.compress(ctx.messages)
        except Exception as e:
            logger.warning("SummarizationMiddleware: compression failed: %s", e)
            return ctx

        ctx.messages = compressed
        ctx.metadata["compressed"] = True
        logger.info("SummarizationMiddleware: compressed conversation")
        return ctx

    async def after(self, ctx: RequestContext, result: Any) -> Any:
        """Return ``result`` unchanged; compression happens in ``before``."""
        return result
|
||
|
||
|
||
class TokenUsageMiddleware:
    """Token-accounting middleware: records the request's token usage.

    before: pass-through.
    after: when the result carries a non-None ``total_tokens`` attribute,
        copy it into ``ctx.metadata["token_usage_total"]``. The result itself
        is always returned untouched.
    """

    async def before(self, ctx: RequestContext) -> RequestContext:
        """Return ``ctx`` unchanged; this middleware only acts after the handler."""
        return ctx

    async def after(self, ctx: RequestContext, result: Any) -> Any:
        """Record ``result.total_tokens`` (if present and not None) in ``ctx.metadata``."""
        # ReActResult (or a similar structure) exposes ``total_tokens``,
        # not ``token_usage``.
        total = getattr(result, "total_tokens", None)
        if total is None:
            return result
        ctx.metadata["token_usage_total"] = total
        return result
|
||
|
||
|
||
class LoopDetectionMiddleware:
    """Loop-detection middleware wrapping U1's loop-detection logic.

    before: initialise ``ctx.metadata["loop_detection_window"]`` to ``[]``.
    after: inspect the tail of ``result.trajectory`` for repeated
        ``(tool_name, arguments)`` calls; if found, log a warning and set
        ``ctx.metadata["loop_detected"] = True``.

    Note: per-step loop detection still happens inside ``_execute_loop`` (U1);
    this middleware only provides a request-level audit of the final
    trajectory.

    Args:
        window_size: number of trailing trajectory steps to inspect. A
            non-positive value inspects the whole trajectory.
        threshold: a loop is flagged when the inspected calls contain at
            least one duplicate and at least ``threshold - 1`` duplicate
            entries in total (duplicates of different calls add up, so
            alternating patterns such as A, B, A, B also count).
    """

    def __init__(self, window_size: int = 5, threshold: int = 2) -> None:
        self._window_size = window_size
        self._threshold = threshold

    async def before(self, ctx: RequestContext) -> RequestContext:
        """Reset the loop-detection slot in ``ctx.metadata`` and return ``ctx``."""
        # NOTE(review): this key is initialised here but never read by
        # ``after``; kept because other code may rely on its presence.
        ctx.metadata["loop_detection_window"] = []
        return ctx

    async def after(self, ctx: RequestContext, result: Any) -> Any:
        """Audit ``result.trajectory`` for repeated tool calls; return ``result`` unchanged."""
        trajectory = getattr(result, "trajectory", None) or []
        if len(trajectory) < self._threshold:
            return result

        tool_calls = self._recent_tool_calls(trajectory)
        if not tool_calls:
            return result

        # Count surplus occurrences across all distinct calls in the window.
        duplicates = len(tool_calls) - len(set(tool_calls))
        if duplicates > 0 and duplicates >= self._threshold - 1:
            logger.warning(
                "LoopDetectionMiddleware: detected repeated tool calls in final trajectory"
            )
            ctx.metadata["loop_detected"] = True

        return result

    def _recent_tool_calls(self, trajectory: Any) -> list[tuple[str, str]]:
        """Return ``(tool_name, canonical_args)`` for named tool calls in the tail window."""
        # Guard non-positive sizes explicitly: ``trajectory[-0:]`` already meant
        # "everything", but a negative size would slice nonsense
        # (e.g. ``trajectory[1:]`` for -1).
        if self._window_size > 0:
            recent = trajectory[-self._window_size :]
        else:
            recent = trajectory

        tool_calls: list[tuple[str, str]] = []
        for step in recent:
            # Steps may be ReActStep dataclass objects or plain dicts.
            if isinstance(step, dict):
                name = step.get("tool_name", "")
                args = step.get("arguments", {})
            else:
                name = getattr(step, "tool_name", "") or ""
                args = getattr(step, "arguments", {}) or {}
            if name:
                tool_calls.append((name, self._serialize_args(args)))
        return tool_calls

    @staticmethod
    def _serialize_args(args: Any) -> str:
        """Return a canonical string for ``args`` so equal arguments compare equal."""
        if not args:
            return ""
        try:
            return json.dumps(args, sort_keys=True, default=str)
        except (TypeError, ValueError):
            # sort_keys fails on mixed-type dict keys and json rejects
            # circular references; fall back to repr rather than abort the audit.
            return repr(args)
|