fischer-agentkit/src/agentkit/core/middleware.py

220 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Middleware pipeline architecture (U6)
洋葱模型中间件管道,将横切关注点(压缩/计量/安全/循环检测)集中化。
并行接入KTD1通过 feature flag 控制,与现有 ReActEngine 路径共存。
Usage::
from agentkit.core.middleware import (
MiddlewareChain,
RequestContext,
SummarizationMiddleware,
TokenUsageMiddleware,
)
chain = MiddlewareChain([
SummarizationMiddleware(compressor),
TokenUsageMiddleware(),
])
result = await chain.execute(ctx, handler)
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, Protocol, runtime_checkable
logger = logging.getLogger(__name__)
@dataclass
class RequestContext:
    """Per-request state that travels through the whole middleware chain.

    Middlewares may both read and mutate this object; that is how one
    middleware hands state over to another.
    """

    # Conversation messages; the only field without a default.
    messages: list[dict[str, str]]
    # Tools offered for this request (fresh empty list per instance).
    tools: list[Any] = field(default_factory=list)
    system_prompt: str | None = None
    model: str = "default"
    agent_name: str = ""
    task_type: str = ""
    task_id: str | None = None
    # Scratch space shared between middlewares (compression results, token
    # usage, loop-detection state, ...). Fresh empty dict per instance.
    metadata: dict[str, Any] = field(default_factory=dict)
@runtime_checkable
class Middleware(Protocol):
    """Structural interface for an onion-model middleware (before/after hooks).

    ``before`` is invoked ahead of request handling and may modify the
    context (for example by compressing the conversation).
    ``after`` is invoked once handling has finished and may modify the
    result (for example by attaching token usage).
    """

    async def before(self, ctx: RequestContext) -> RequestContext:
        """Pre-handler hook; returns the (possibly modified) context."""
        ...

    async def after(self, ctx: RequestContext, result: Any) -> Any:
        """Post-handler hook; returns the (possibly modified) result."""
        ...
class MiddlewareChain:
    """Onion-model middleware chain.

    Execution order for middlewares ``[A, B, C]``:

    * ``before`` hooks run outside-in: A -> B -> C
    * the handler runs
    * ``after`` hooks run inside-out: C -> B -> A

    If a ``before`` hook raises, the remaining ``before`` hooks are skipped,
    no ``after`` hook fires, and the exception propagates to the caller.
    An exception raised by the handler propagates the same way.
    A failing ``after`` hook is logged and skipped; the result it was given
    is passed on unchanged to the next-outer middleware.
    """

    def __init__(self, middlewares: list[Middleware] | None = None) -> None:
        """Create a chain from ``middlewares`` (outermost first).

        Args:
            middlewares: Middlewares in outside-in order, or ``None`` for an
                empty chain.
        """
        # Always take a private copy. `middlewares or []` used to alias a
        # non-empty caller list (but not an empty one), so later mutation of
        # the caller's list changed the chain inconsistently.
        self._middlewares: list[Middleware] = list(middlewares) if middlewares else []

    @property
    def middlewares(self) -> list[Middleware]:
        """Return a copy of the configured middlewares, outermost first."""
        return list(self._middlewares)

    async def execute(
        self,
        ctx: RequestContext,
        handler: Callable[[RequestContext], Awaitable[Any]],
    ) -> Any:
        """Run the ``before`` hooks, the handler, then the ``after`` hooks.

        Args:
            ctx: Context handed to the first ``before`` hook.
            handler: Awaitable callable invoked with the context returned by
                the last ``before`` hook.

        Returns:
            The handler's result after every ``after`` hook has had a chance
            to transform it.

        Raises:
            Exception: Whatever a ``before`` hook or the handler raises.
        """
        if not self._middlewares:
            return await handler(ctx)

        # Snapshot so the before and after passes see the same middlewares
        # even if the chain is modified while a hook is awaiting.
        pipeline = tuple(self._middlewares)

        # before: outside -> inside. Exceptions propagate on purpose, which
        # also means the after pass below is never reached in that case.
        current_ctx = ctx
        for mw in pipeline:
            current_ctx = await mw.before(current_ctx)

        result = await handler(current_ctx)

        # after: inside -> outside. Best effort: one failing hook must not
        # discard the handler's result or block the outer hooks.
        for mw in reversed(pipeline):
            try:
                result = await mw.after(current_ctx, result)
            except Exception as e:
                logger.warning(
                    "Middleware %s.after() failed: %s — continuing with current result",
                    type(mw).__name__,
                    e,
                )
        return result
# ── 示例中间件 ──────────────────────────────────────────────
class SummarizationMiddleware:
    """Compression middleware wrapping a context compressor (headroom compression, U3).

    ``before`` asks the compressor whether the conversation should be
    compressed and, if so, replaces ``ctx.messages`` with the compressed
    messages. ``after`` is a no-op because all work happens in ``before``.
    """

    def __init__(self, compressor: Any = None) -> None:
        """Store the compressor.

        Args:
            compressor: Object expected to provide a synchronous
                ``should_compress(messages)`` and an awaitable
                ``compress(messages)``. ``None`` disables the middleware.
        """
        self._compressor = compressor

    async def before(self, ctx: RequestContext) -> RequestContext:
        """Compress ``ctx.messages`` in place when the compressor asks for it.

        On success ``ctx.metadata["compressed"]`` is set to ``True``. A
        failure inside ``compress`` is logged and the context is returned
        untouched, so compression problems never abort the request.

        Returns:
            The same ``ctx`` object.
        """
        compressor = self._compressor
        # Identity check, not truthiness: a compressor that happens to be
        # falsy (e.g. defines __len__ and is currently empty) is still valid.
        if compressor is None:
            return ctx

        should_compress_fn = getattr(compressor, "should_compress", None)
        if should_compress_fn is None or not should_compress_fn(ctx.messages):
            return ctx

        try:
            compressed = await compressor.compress(ctx.messages)
        except Exception as e:
            # Best effort: fall back to the uncompressed conversation.
            logger.warning("SummarizationMiddleware: compression failed: %s", e)
            return ctx

        ctx.messages = compressed
        ctx.metadata["compressed"] = True
        logger.info("SummarizationMiddleware: compressed conversation")
        return ctx

    async def after(self, ctx: RequestContext, result: Any) -> Any:
        """Return ``result`` unchanged."""
        return result
class TokenUsageMiddleware:
    """Token metering middleware: records a request's token usage.

    ``before`` does nothing; ``after`` copies the result's ``total_tokens``
    value into ``ctx.metadata["token_usage_total"]`` when it is present.
    """

    async def before(self, ctx: RequestContext) -> RequestContext:
        """Hand the context back untouched."""
        return ctx

    async def after(self, ctx: RequestContext, result: Any) -> Any:
        """Record ``result.total_tokens`` in the context metadata, if any.

        The attribute read is ``total_tokens`` (not ``token_usage``); a
        missing attribute or a ``None`` value records nothing.
        """
        total = getattr(result, "total_tokens", None)
        if total is None:
            return result
        ctx.metadata["token_usage_total"] = total
        return result
class LoopDetectionMiddleware:
    """Loop-detection middleware: request-level audit of repeated tool calls (U1).

    ``before`` initialises ``ctx.metadata["loop_detection_window"]``.
    ``after`` inspects the tail of the result's ``trajectory`` for repeated
    tool calls, logs a warning and sets ``ctx.metadata["loop_detected"]``.

    Note: per-step loop detection is still performed inside ``_execute_loop``
    (U1); this middleware only audits the final trajectory.
    """

    def __init__(self, window_size: int = 5, threshold: int = 2) -> None:
        """Configure the audit.

        Args:
            window_size: Number of trailing trajectory steps to inspect. A
                value of zero or less inspects nothing.
            threshold: Trajectories shorter than this are skipped, and a loop
                is reported once the window holds at least ``threshold - 1``
                duplicate calls (and at least one).
        """
        self._window_size = window_size
        self._threshold = threshold

    async def before(self, ctx: RequestContext) -> RequestContext:
        """Reset the loop-detection window stored in the context metadata."""
        ctx.metadata["loop_detection_window"] = []
        return ctx

    async def after(self, ctx: RequestContext, result: Any) -> Any:
        """Flag repeated tool calls in the tail of ``result.trajectory``.

        Steps may be dicts or objects exposing ``tool_name`` / ``arguments``.
        Two calls count as identical when the tool name and the key-sorted
        JSON form of the arguments match.

        Returns:
            ``result`` unchanged; findings are reported via the log and
            ``ctx.metadata["loop_detected"]``.
        """
        trajectory = getattr(result, "trajectory", None) or []
        if len(trajectory) < self._threshold:
            return result

        # A non-positive window means "inspect nothing". Without this guard
        # `trajectory[-0:]` would silently select the WHOLE trajectory.
        if self._window_size <= 0:
            return result

        recent = trajectory[-self._window_size :]
        signatures = [self._call_signature(step) for step in recent]
        tool_calls = [sig for sig in signatures if sig is not None]
        if not tool_calls:
            return result

        # Number of calls in the window that duplicate an earlier one.
        duplicates = len(tool_calls) - len(set(tool_calls))
        if duplicates > 0 and duplicates >= self._threshold - 1:
            logger.warning(
                "LoopDetectionMiddleware: detected repeated tool calls in final trajectory"
            )
            ctx.metadata["loop_detected"] = True
        return result

    @staticmethod
    def _call_signature(step: Any) -> tuple[str, str] | None:
        """Return ``(tool_name, canonical_args)`` for a step, or ``None`` if it names no tool."""
        if isinstance(step, dict):
            name = step.get("tool_name", "")
            args = step.get("arguments", {})
        else:
            name = getattr(step, "tool_name", "") or ""
            args = getattr(step, "arguments", {}) or {}
        if not name:
            return None
        # sort_keys makes the comparison independent of argument order;
        # default=str keeps non-JSON values from aborting the audit.
        args_str = json.dumps(args, sort_keys=True, default=str) if args else ""
        return (name, args_str)