"""Middleware pipeline architecture (U6) 洋葱模型中间件管道,将横切关注点(压缩/计量/安全/循环检测)集中化。 并行接入(KTD1):通过 feature flag 控制,与现有 ReActEngine 路径共存。 Usage:: from agentkit.core.middleware import ( MiddlewareChain, RequestContext, SummarizationMiddleware, TokenUsageMiddleware, ) chain = MiddlewareChain([ SummarizationMiddleware(compressor), TokenUsageMiddleware(), ]) result = await chain.execute(ctx, handler) """ from __future__ import annotations import json import logging from dataclasses import dataclass, field from typing import Any, Awaitable, Callable, Protocol, runtime_checkable logger = logging.getLogger(__name__) @dataclass class RequestContext: """请求上下文,贯穿中间件链。 中间件可读写此上下文,实现中间件间状态传递。 """ messages: list[dict[str, str]] tools: list[Any] = field(default_factory=list) system_prompt: str | None = None model: str = "default" agent_name: str = "" task_type: str = "" task_id: str | None = None # 中间件间共享状态(压缩结果、token 用量、循环检测状态等) metadata: dict[str, Any] = field(default_factory=dict) @runtime_checkable class Middleware(Protocol): """中间件协议 — 洋葱模型的 before/after 钩子。 before: 请求处理前调用,可修改 ctx(如压缩 conversation) after: 请求处理后调用,可修改 result(如附加 token 用量) """ async def before(self, ctx: RequestContext) -> RequestContext: ... async def after(self, ctx: RequestContext, result: Any) -> Any: ... 
class MiddlewareChain:
    """Onion-model middleware chain.

    ``before`` hooks run from the outermost middleware inwards (A → B → C),
    then the handler runs, then ``after`` hooks run from the innermost
    middleware outwards (C → B → A).

    If a ``before`` hook raises, the remaining ``before`` hooks, the handler
    and the entire ``after`` chain are skipped; the exception propagates to
    the caller.
    """

    def __init__(self, middlewares: list[Middleware] | None = None) -> None:
        """Build a chain from *middlewares*, listed outermost first.

        The list is always copied, so later mutation of the caller's list
        never changes the chain (previously a non-empty list was aliased
        while an empty one was not).
        """
        self._middlewares: list[Middleware] = list(middlewares) if middlewares else []

    @property
    def middlewares(self) -> list[Middleware]:
        """Return a shallow copy of the registered middlewares, outermost first."""
        return list(self._middlewares)

    async def execute(
        self,
        ctx: RequestContext,
        handler: Callable[[RequestContext], Awaitable[Any]],
    ) -> Any:
        """Run the ``before`` hooks, the handler, then the ``after`` hooks.

        Args:
            ctx: Request context handed to the first ``before`` hook.
            handler: Awaitable callable invoked with the context returned by
                the last ``before`` hook.

        Returns:
            The handler result after every ``after`` hook has processed it.

        Raises:
            Exception: Whatever a ``before`` hook or the handler raises.
                Failures inside ``after`` hooks are logged and swallowed.
        """
        # Snapshot so both phases see the same middlewares even if the chain
        # is modified while a hook is awaiting.
        pipeline = tuple(self._middlewares)
        if not pipeline:
            return await handler(ctx)

        # before: outer → inner. An exception propagates naturally and no
        # ``after`` hook is run.
        current_ctx = ctx
        for mw in pipeline:
            current_ctx = await mw.before(current_ctx)

        result = await handler(current_ctx)

        # after: inner → outer. A failing hook must not discard the result,
        # so it is logged and the current result is carried forward.
        for mw in reversed(pipeline):
            try:
                result = await mw.after(current_ctx, result)
            except Exception as exc:
                logger.warning(
                    "Middleware %s.after() failed: %s — continuing with current result",
                    type(mw).__name__,
                    exc,
                )

        return result


# ── Example middlewares ─────────────────────────────────────


class SummarizationMiddleware:
    """Compression middleware wrapping a compressor object (U3).

    before: when the compressor's ``should_compress(messages)`` returns a
        truthy value, replace ``ctx.messages`` with the awaited result of
        ``compress(messages)``.
    after: no-op (compression happens in ``before``).
    """

    def __init__(self, compressor: Any = None) -> None:
        """Store *compressor*; ``None`` disables the middleware."""
        self._compressor = compressor

    async def before(self, ctx: RequestContext) -> RequestContext:
        """Compress ``ctx.messages`` if the compressor asks for it.

        Compression is best-effort: a failure in ``compress`` is logged and
        the context is returned unchanged. An exception raised by
        ``should_compress`` itself is not caught.
        """
        compressor = self._compressor
        # Identity check: a compressor that happens to be falsy (defines
        # ``__bool__``/``__len__``) must still be used.
        if compressor is None:
            return ctx

        should_compress = getattr(compressor, "should_compress", None)
        if should_compress is None or not should_compress(ctx.messages):
            return ctx

        try:
            compressed = await compressor.compress(ctx.messages)
        except Exception as exc:
            logger.warning("SummarizationMiddleware: compression failed: %s", exc)
            return ctx

        ctx.messages = compressed
        ctx.metadata["compressed"] = True
        logger.info("SummarizationMiddleware: compressed conversation")
        return ctx

    async def after(self, ctx: RequestContext, result: Any) -> Any:
        """Return *result* unchanged."""
        return result


class TokenUsageMiddleware:
    """Token metering middleware that records a request's token usage.

    before: no-op.
    after: copy ``result.total_tokens`` into ``ctx.metadata``.
    """

    async def before(self, ctx: RequestContext) -> RequestContext:
        """Return *ctx* unchanged."""
        return ctx

    async def after(self, ctx: RequestContext, result: Any) -> Any:
        """Store ``result.total_tokens`` under ``metadata["token_usage_total"]``.

        The attribute read is ``total_tokens`` (not ``token_usage``). Nothing
        is recorded when the attribute is missing or ``None``.
        """
        usage = getattr(result, "total_tokens", None)
        if usage is not None:
            ctx.metadata["token_usage_total"] = usage
        return result


class LoopDetectionMiddleware:
    """Loop-detection middleware providing a request-level audit (U1).

    before: initialise ``metadata["loop_detection_window"]`` to an empty list.
    after: inspect the tail of ``result.trajectory`` for repeated tool calls
        and, if found, log a warning and set ``metadata["loop_detected"]``.

    Note: per the original design notes, per-step loop detection still
    happens inside ``_execute_loop`` (U1); this middleware only audits the
    final trajectory.
    """

    def __init__(self, window_size: int = 5, threshold: int = 2) -> None:
        """Configure the audit.

        Args:
            window_size: Number of trailing trajectory steps to inspect. A
                non-positive value disables the audit.
            threshold: Trajectories shorter than this are ignored, and a loop
                is flagged once the window holds at least ``threshold - 1``
                duplicate calls (and at least one).
        """
        self._window_size = window_size
        self._threshold = threshold

    async def before(self, ctx: RequestContext) -> RequestContext:
        """Reset the loop-detection window stored in ``ctx.metadata``."""
        ctx.metadata["loop_detection_window"] = []
        return ctx

    @staticmethod
    def _call_signature(step: Any) -> tuple[str, str] | None:
        """Return a hashable ``(tool_name, serialized_arguments)`` pair, or ``None`` if the step names no tool."""
        # Steps may be dicts or attribute-style objects; support both.
        if isinstance(step, dict):
            name = step.get("tool_name", "")
            args = step.get("arguments", {})
        else:
            name = getattr(step, "tool_name", "") or ""
            args = getattr(step, "arguments", {}) or {}

        if not name:
            return None
        if not args:
            return name, ""

        try:
            args_str = json.dumps(args, sort_keys=True, default=str)
        except (TypeError, ValueError):
            # Non-JSON keys, unorderable mixed keys or circular references:
            # fall back to repr so the audit still completes.
            args_str = repr(args)
        return name, args_str

    async def after(self, ctx: RequestContext, result: Any) -> Any:
        """Flag repeated tool calls in the tail of ``result.trajectory``.

        Always returns *result* unchanged; the outcome of the audit is
        reported only through the log and ``ctx.metadata["loop_detected"]``.
        """
        trajectory = getattr(result, "trajectory", None) or []
        if len(trajectory) < self._threshold:
            return result
        # ``seq[-0:]`` is the whole sequence, not an empty window, so a
        # non-positive window has to be handled explicitly.
        if self._window_size <= 0:
            return result

        recent = trajectory[-self._window_size:]
        signatures = [
            signature
            for signature in map(self._call_signature, recent)
            if signature is not None
        ]
        if not signatures:
            return result

        # Number of calls in the window that duplicate another call.
        duplicates = len(signatures) - len(set(signatures))
        if duplicates > 0 and duplicates >= self._threshold - 1:
            logger.warning(
                "LoopDetectionMiddleware: detected repeated tool calls in final trajectory"
            )
            ctx.metadata["loop_detected"] = True

        return result