feat(core): add middleware pipeline architecture with onion model

U6: Unified middleware protocol (before/after) with MiddlewareChain
implementing onion model execution. Parallel integration (KTD1) —
middleware path controlled by presence of middleware_chain parameter,
existing ReActEngine path unchanged when None.

- New core/middleware.py: RequestContext, Middleware protocol,
  MiddlewareChain (onion model: before outer→inner, after inner→outer)
- 3 example middlewares: SummarizationMiddleware (U3 headroom compression),
  TokenUsageMiddleware, LoopDetectionMiddleware (request-level audit)
- ReActEngine.__init__ accepts middleware_chain parameter
- execute() branches: middleware path when chain present, existing path otherwise
- 22 tests covering ordering, error handling, state passing, backward compat
This commit is contained in:
chiguyong 2026-06-24 20:52:15 +08:00
parent ef84e3fd53
commit 3dfda904d7
3 changed files with 784 additions and 0 deletions

View File

@ -0,0 +1,212 @@
"""Middleware pipeline architecture (U6)
洋葱模型中间件管道将横切关注点压缩/计量/安全/循环检测集中化
并行接入KTD1通过 feature flag 控制与现有 ReActEngine 路径共存
Usage::
from agentkit.core.middleware import (
MiddlewareChain,
RequestContext,
SummarizationMiddleware,
TokenUsageMiddleware,
)
chain = MiddlewareChain([
SummarizationMiddleware(compressor),
TokenUsageMiddleware(),
])
result = await chain.execute(ctx, handler)
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, Protocol, runtime_checkable
logger = logging.getLogger(__name__)
@dataclass
class RequestContext:
"""请求上下文,贯穿中间件链。
中间件可读写此上下文实现中间件间状态传递
"""
messages: list[dict[str, str]]
tools: list[Any] = field(default_factory=list)
system_prompt: str | None = None
model: str = "default"
agent_name: str = ""
task_type: str = ""
task_id: str | None = None
# 中间件间共享状态压缩结果、token 用量、循环检测状态等)
metadata: dict[str, Any] = field(default_factory=dict)
@runtime_checkable
class Middleware(Protocol):
"""中间件协议 — 洋葱模型的 before/after 钩子。
before: 请求处理前调用可修改 ctx如压缩 conversation
after: 请求处理后调用可修改 result如附加 token 用量
"""
async def before(self, ctx: RequestContext) -> RequestContext: ...
async def after(self, ctx: RequestContext, result: Any) -> Any: ...
class MiddlewareChain:
"""洋葱模型中间件链。
before 由外到内执行A B C
handler 执行
after 由内到外执行C B A
若某中间件 before 抛异常后续 before 不执行after 链不触发
异常向上传播由调用者处理
"""
def __init__(self, middlewares: list[Middleware] | None = None) -> None:
self._middlewares: list[Middleware] = middlewares or []
@property
def middlewares(self) -> list[Middleware]:
return list(self._middlewares)
async def execute(
self,
ctx: RequestContext,
handler: Callable[[RequestContext], Awaitable[Any]],
) -> Any:
"""执行中间件链 + handler。
洋葱模型before 顺序执行 handler after 逆序执行
"""
if not self._middlewares:
return await handler(ctx)
# before: 外 → 内
executed_befores: list[Middleware] = []
current_ctx = ctx
try:
for mw in self._middlewares:
current_ctx = await mw.before(current_ctx)
executed_befores.append(mw)
except Exception:
# before 异常:不执行 after直接传播
raise
# handler
result = await handler(current_ctx)
# after: 内 → 外(逆序)
for mw in reversed(executed_befores):
try:
result = await mw.after(current_ctx, result)
except Exception as e:
logger.warning(
"Middleware %s.after() failed: %s — continuing with current result",
type(mw).__name__,
e,
)
return result
# ── 示例中间件 ──────────────────────────────────────────────
class SummarizationMiddleware:
"""压缩中间件 — 包装 ContextCompressor 的 headroom 压缩U3
before: conversation token 用量超过 headroom 阈值压缩历史消息
after: 无操作压缩在 before 完成
"""
def __init__(self, compressor: Any = None) -> None:
self._compressor = compressor
async def before(self, ctx: RequestContext) -> RequestContext:
if not self._compressor:
return ctx
# 检查是否需要压缩
should_compress_fn = getattr(self._compressor, "should_compress", None)
if should_compress_fn is not None and should_compress_fn(ctx.messages):
try:
compressed = await self._compressor.compress(ctx.messages)
ctx.messages = compressed
ctx.metadata["compressed"] = True
logger.info("SummarizationMiddleware: compressed conversation")
except Exception as e:
logger.warning(f"SummarizationMiddleware: compression failed: {e}")
return ctx
async def after(self, ctx: RequestContext, result: Any) -> Any:
return result
class TokenUsageMiddleware:
"""Token 计量中间件 — 记录请求的 token 用量。
before: 记录起始时间
after: result 中提取 token usage记录到 metadata
"""
async def before(self, ctx: RequestContext) -> RequestContext:
ctx.metadata["token_usage_start"] = ctx.metadata.get("token_usage_start", 0)
return ctx
async def after(self, ctx: RequestContext, result: Any) -> Any:
# 从 ReActResult 或类似结构提取 token usage
usage = getattr(result, "token_usage", None)
if usage is not None:
ctx.metadata["token_usage_total"] = usage
return result
class LoopDetectionMiddleware:
"""循环检测中间件 — 包装 U1 的循环检测逻辑。
before: 初始化循环检测窗口
after: 检查最终 trajectory 是否有循环模式记录警告
注意per-step 循环检测仍在 _execute_loop 内进行U1
此中间件提供请求级的循环模式审计
"""
def __init__(self, window_size: int = 5, threshold: int = 2) -> None:
self._window_size = window_size
self._threshold = threshold
async def before(self, ctx: RequestContext) -> RequestContext:
ctx.metadata["loop_detection_window"] = []
return ctx
async def after(self, ctx: RequestContext, result: Any) -> Any:
trajectory = getattr(result, "trajectory", None) or []
if len(trajectory) < self._threshold:
return result
# 检查最终 trajectory 中的重复工具调用模式
tool_calls = [
(step.get("tool_name", ""), step.get("arguments_hash", ""))
for step in trajectory
if isinstance(step, dict) and "tool_name" in step
]
if not tool_calls:
return result
# 滑动窗口检测
window = tool_calls[-self._window_size :]
unique = set(window)
if len(unique) < len(window) and len(window) - len(unique) >= self._threshold - 1:
logger.warning(
"LoopDetectionMiddleware: detected repeated tool calls in final trajectory"
)
ctx.metadata["loop_detected"] = True
return result

View File

@ -27,6 +27,7 @@ from agentkit.telemetry.metrics import (
if TYPE_CHECKING:
from agentkit.core.compressor import CompressionStrategy
from agentkit.core.middleware import MiddlewareChain
from agentkit.core.trace import TraceRecorder
from agentkit.memory.retriever import MemoryRetriever
@ -162,6 +163,7 @@ class ReActEngine:
verification_commands: list[str] | None = None,
core_tool_names: list[str] | None = None,
enable_tool_search: bool = True,
middleware_chain: "MiddlewareChain | None" = None,
):
if max_steps < 1:
raise ValueError(f"max_steps must be >= 1, got {max_steps}")
@ -193,6 +195,8 @@ class ReActEngine:
self._loop_window: deque[str] = deque(maxlen=5)
self._loop_threshold: int = 2
self._loop_corrected: bool = False
# U6: Middleware chain (parallel integration, feature flag controlled)
self._middleware_chain = middleware_chain
def reset(self) -> None:
"""Reset internal state for reuse across conversations.
@ -263,6 +267,57 @@ class ReActEngine:
timeout_seconds if timeout_seconds is not None else self._default_timeout
)
# U6: Middleware chain (parallel integration, KTD1)
# If middleware_chain is present, wrap the handler with it.
# Otherwise, use the existing path (backward compatible).
if self._middleware_chain is not None:
from agentkit.core.middleware import RequestContext
ctx = RequestContext(
messages=messages,
tools=tools or [],
system_prompt=system_prompt,
model=model,
agent_name=agent_name,
task_type=task_type,
task_id=task_id,
)
async def _handler(c: RequestContext) -> ReActResult:
return await self._execute_loop(
messages=c.messages,
tools=c.tools or None,
model=c.model,
agent_name=c.agent_name,
task_type=c.task_type,
system_prompt=c.system_prompt,
trace_recorder=trace_recorder,
memory_retriever=memory_retriever,
task_id=c.task_id,
compressor=effective_compressor,
retrieval_config=retrieval_config,
cancellation_token=cancellation_token,
confirmation_handler=confirmation_handler,
)
try:
if effective_timeout > 0:
result = await asyncio.wait_for(
self._middleware_chain.execute(ctx, _handler),
timeout=effective_timeout,
)
else:
result = await self._middleware_chain.execute(ctx, _handler)
except asyncio.TimeoutError:
raise TaskTimeoutError(
task_id=task_id or "",
timeout_seconds=int(effective_timeout),
)
except TaskCancelledError:
raise
return result
try:
if effective_timeout > 0:
result = await asyncio.wait_for(

View File

@ -0,0 +1,517 @@
"""U6: 中间件管道架构测试
测试洋葱模型中间件链的执行顺序错误处理向后兼容性
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from agentkit.core.middleware import (
LoopDetectionMiddleware,
MiddlewareChain,
RequestContext,
SummarizationMiddleware,
TokenUsageMiddleware,
)
# ── 辅助 ──────────────────────────────────────────────────
class _TraceMiddleware:
"""记录 before/after 调用顺序的测试中间件。"""
def __init__(self, name: str, calls: list[str]):
self._name = name
self._calls = calls
async def before(self, ctx: RequestContext) -> RequestContext:
self._calls.append(f"before:{self._name}")
return ctx
async def after(self, ctx: RequestContext, result):
self._calls.append(f"after:{self._name}")
return result
class _ErrorBeforeMiddleware:
"""before 抛异常的测试中间件。"""
def __init__(self, error: Exception):
self._error = error
async def before(self, ctx: RequestContext) -> RequestContext:
raise self._error
async def after(self, ctx: RequestContext, result):
return result
class _ErrorAfterMiddleware:
"""after 抛异常的测试中间件。"""
def __init__(self, error: Exception):
self._error = error
async def before(self, ctx: RequestContext) -> RequestContext:
return ctx
async def after(self, ctx: RequestContext, result):
raise self._error
class _ModifyCtxMiddleware:
"""修改 ctx 的测试中间件。"""
def __init__(self, key: str, value: str):
self._key = key
self._value = value
async def before(self, ctx: RequestContext) -> RequestContext:
ctx.metadata[self._key] = self._value
return ctx
async def after(self, ctx: RequestContext, result):
return result
class _ModifyResultMiddleware:
"""修改 result 的测试中间件。"""
def __init__(self, suffix: str):
self._suffix = suffix
async def before(self, ctx: RequestContext) -> RequestContext:
return ctx
async def after(self, ctx: RequestContext, result):
if isinstance(result, str):
return result + self._suffix
return result
def _make_ctx() -> RequestContext:
return RequestContext(messages=[{"role": "user", "content": "test"}])
# ── 洋葱模型执行顺序 ──────────────────────────────────────
class TestMiddlewareChainOrder:
"""洋葱模型 before/after 执行顺序。"""
@pytest.mark.asyncio
async def test_before_outer_to_inner_after_inner_to_outer(self) -> None:
"""3 个中间件before A→B→Cafter C→B→A。"""
calls: list[str] = []
chain = MiddlewareChain(
[
_TraceMiddleware("A", calls),
_TraceMiddleware("B", calls),
_TraceMiddleware("C", calls),
]
)
async def handler(ctx: RequestContext) -> str:
calls.append("handler")
return "result"
result = await chain.execute(_make_ctx(), handler)
assert result == "result"
assert calls == [
"before:A",
"before:B",
"before:C",
"handler",
"after:C",
"after:B",
"after:A",
]
@pytest.mark.asyncio
async def test_single_middleware(self) -> None:
"""单个中间件before → handler → after。"""
calls: list[str] = []
chain = MiddlewareChain([_TraceMiddleware("X", calls)])
async def handler(ctx: RequestContext) -> str:
calls.append("handler")
return "ok"
result = await chain.execute(_make_ctx(), handler)
assert result == "ok"
assert calls == ["before:X", "handler", "after:X"]
@pytest.mark.asyncio
async def test_empty_chain_calls_handler_directly(self) -> None:
"""空中间件链直接调用 handler。"""
chain = MiddlewareChain([])
async def handler(ctx: RequestContext) -> str:
return "direct"
result = await chain.execute(_make_ctx(), handler)
assert result == "direct"
@pytest.mark.asyncio
async def test_none_middlewares_defaults_to_empty(self) -> None:
"""None middlewares 参数默认为空链。"""
chain = MiddlewareChain(None)
async def handler(ctx: RequestContext) -> str:
return "ok"
result = await chain.execute(_make_ctx(), handler)
assert result == "ok"
# ── 错误处理 ──────────────────────────────────────────────
class TestMiddlewareChainErrors:
"""中间件链错误处理。"""
@pytest.mark.asyncio
async def test_before_error_stops_chain(self) -> None:
"""before 异常 → 后续 before 不执行after 不触发,异常传播。"""
calls: list[str] = []
chain = MiddlewareChain(
[
_TraceMiddleware("A", calls),
_ErrorBeforeMiddleware(ValueError("before error")),
_TraceMiddleware("C", calls),
]
)
async def handler(ctx: RequestContext) -> str:
calls.append("handler")
return "result"
with pytest.raises(ValueError, match="before error"):
await chain.execute(_make_ctx(), handler)
# A.before ran, B.before raised, C.before and handler never ran
assert calls == ["before:A"]
# No after calls (chain was interrupted)
assert "after:" not in " ".join(calls)
@pytest.mark.asyncio
async def test_after_error_does_not_stop_chain(self) -> None:
"""after 异常 → 记录 warning不中断后续 after返回当前 result。"""
calls: list[str] = []
chain = MiddlewareChain(
[
_TraceMiddleware("A", calls),
_ErrorAfterMiddleware(RuntimeError("after error")),
_TraceMiddleware("C", calls),
]
)
async def handler(ctx: RequestContext) -> str:
calls.append("handler")
return "result"
# Should NOT raise — after errors are caught and logged
result = await chain.execute(_make_ctx(), handler)
assert result == "result"
# All befores ran, handler ran, C.after ran (B.after failed but didn't stop)
assert "before:A" in calls
assert "before:B" not in calls # _ErrorAfterMiddleware doesn't trace
assert "before:C" in calls
assert "handler" in calls
assert "after:C" in calls
assert "after:A" in calls
# ── 状态传递 ──────────────────────────────────────────────
class TestMiddlewareChainState:
"""中间件间状态通过 RequestContext 传递。"""
@pytest.mark.asyncio
async def test_before_modifies_ctx_visible_to_handler(self) -> None:
"""before 修改的 ctx.metadata 在 handler 中可见。"""
chain = MiddlewareChain([_ModifyCtxMiddleware("key", "value")])
captured: dict[str, str] = {}
async def handler(ctx: RequestContext) -> str:
captured.update(ctx.metadata)
return "ok"
await chain.execute(_make_ctx(), handler)
assert captured["key"] == "value"
@pytest.mark.asyncio
async def test_after_modifies_result_visible_to_outer(self) -> None:
"""after 修改的 result 对外层可见。"""
chain = MiddlewareChain(
[
_ModifyResultMiddleware("_1"),
_ModifyResultMiddleware("_2"),
]
)
async def handler(ctx: RequestContext) -> str:
return "base"
result = await chain.execute(_make_ctx(), handler)
# after runs inner to outer: _2 first, then _1
# But _ModifyResultMiddleware appends suffix, so:
# handler returns "base"
# C.after: "base" + "_2" = "base_2"
# B.after: "base_2" + "_1" = "base_2_1"
# Wait — after runs in reverse order of before
# before: B(_1), C(_2) → after: C(_2), B(_1)
assert result == "base_2_1"
@pytest.mark.asyncio
async def test_multiple_middlewares_share_ctx(self) -> None:
"""多个中间件共享同一个 ctx 对象。"""
chain = MiddlewareChain(
[
_ModifyCtxMiddleware("a", "1"),
_ModifyCtxMiddleware("b", "2"),
]
)
captured: dict[str, str] = {}
async def handler(ctx: RequestContext) -> str:
captured.update(ctx.metadata)
return "ok"
await chain.execute(_make_ctx(), handler)
assert captured == {"a": "1", "b": "2"}
# ── 向后兼容 ──────────────────────────────────────────────
class TestMiddlewareBackwardCompat:
"""middleware_chain=None 时行为不变(向后兼容)。"""
@pytest.mark.asyncio
async def test_react_engine_without_middleware_works(self) -> None:
"""ReActEngine 无 middleware_chain 时走现有路径。"""
from agentkit.core.react import ReActEngine
from agentkit.llm.protocol import TokenUsage
gateway = MagicMock()
gateway.chat = AsyncMock(
return_value=MagicMock(
content="hello",
tool_calls=[],
usage=TokenUsage(prompt_tokens=10, completion_tokens=5),
)
)
engine = ReActEngine(llm_gateway=gateway, max_steps=1)
# No middleware_chain — should use existing path
assert engine._middleware_chain is None
result = await engine.execute(
messages=[{"role": "user", "content": "hi"}],
tools=[],
timeout_seconds=0,
)
assert result is not None
@pytest.mark.asyncio
async def test_react_engine_with_middleware_uses_chain(self) -> None:
"""ReActEngine 有 middleware_chain 时走中间件路径。"""
from agentkit.core.react import ReActEngine
from agentkit.llm.protocol import TokenUsage
gateway = MagicMock()
gateway.chat = AsyncMock(
return_value=MagicMock(
content="hello",
tool_calls=[],
usage=TokenUsage(prompt_tokens=10, completion_tokens=5),
)
)
calls: list[str] = []
chain = MiddlewareChain([_TraceMiddleware("test", calls)])
engine = ReActEngine(
llm_gateway=gateway,
max_steps=1,
middleware_chain=chain,
)
assert engine._middleware_chain is chain
result = await engine.execute(
messages=[{"role": "user", "content": "hi"}],
tools=[],
timeout_seconds=0,
)
assert result is not None
# Middleware before/after should have been called
assert "before:test" in calls
assert "after:test" in calls
# ── 示例中间件测试 ─────────────────────────────────────────
class TestSummarizationMiddleware:
"""SummarizationMiddleware 测试。"""
@pytest.mark.asyncio
async def test_no_compressor_passes_through(self) -> None:
"""无 compressor 时before 不修改 ctx。"""
mw = SummarizationMiddleware(compressor=None)
ctx = _make_ctx()
result_ctx = await mw.before(ctx)
assert result_ctx is ctx
assert "compressed" not in ctx.metadata
@pytest.mark.asyncio
async def test_compressor_not_needed_passes_through(self) -> None:
"""compressor 报告不需要压缩时before 不修改 messages。"""
compressor = MagicMock()
compressor.should_compress = MagicMock(return_value=False)
mw = SummarizationMiddleware(compressor=compressor)
ctx = _make_ctx()
original_messages = list(ctx.messages)
await mw.before(ctx)
assert ctx.messages == original_messages
assert "compressed" not in ctx.metadata
@pytest.mark.asyncio
async def test_compressor_triggers_compression(self) -> None:
"""compressor 报告需要压缩时before 压缩 messages。"""
compressed_messages = [{"role": "user", "content": "compressed"}]
compressor = MagicMock()
compressor.should_compress = MagicMock(return_value=True)
compressor.compress = AsyncMock(return_value=compressed_messages)
mw = SummarizationMiddleware(compressor=compressor)
ctx = _make_ctx()
await mw.before(ctx)
assert ctx.messages == compressed_messages
assert ctx.metadata["compressed"] is True
@pytest.mark.asyncio
async def test_compressor_failure_does_not_block(self) -> None:
"""compressor 压缩失败时不阻断执行messages 不变。"""
compressor = MagicMock()
compressor.should_compress = MagicMock(return_value=True)
compressor.compress = AsyncMock(side_effect=RuntimeError("compress failed"))
mw = SummarizationMiddleware(compressor=compressor)
ctx = _make_ctx()
original_messages = list(ctx.messages)
await mw.before(ctx)
# Should not raise, messages unchanged
assert ctx.messages == original_messages
assert "compressed" not in ctx.metadata
class TestTokenUsageMiddleware:
"""TokenUsageMiddleware 测试。"""
@pytest.mark.asyncio
async def test_before_initializes_metadata(self) -> None:
"""before 初始化 token_usage_start。"""
mw = TokenUsageMiddleware()
ctx = _make_ctx()
await mw.before(ctx)
assert "token_usage_start" in ctx.metadata
@pytest.mark.asyncio
async def test_after_extracts_usage_from_result(self) -> None:
"""after 从 result 提取 token_usage。"""
mw = TokenUsageMiddleware()
ctx = _make_ctx()
result = MagicMock()
result.token_usage = {"total": 100}
await mw.after(ctx, result)
assert ctx.metadata["token_usage_total"] == {"total": 100}
@pytest.mark.asyncio
async def test_after_no_usage_in_result(self) -> None:
"""result 无 token_usage 时after 不修改 metadata。"""
mw = TokenUsageMiddleware()
ctx = _make_ctx()
result = MagicMock(spec=[]) # No token_usage attribute
await mw.after(ctx, result)
assert "token_usage_total" not in ctx.metadata
class TestLoopDetectionMiddleware:
"""LoopDetectionMiddleware 测试。"""
@pytest.mark.asyncio
async def test_before_initializes_window(self) -> None:
"""before 初始化 loop_detection_window。"""
mw = LoopDetectionMiddleware()
ctx = _make_ctx()
await mw.before(ctx)
assert ctx.metadata["loop_detection_window"] == []
@pytest.mark.asyncio
async def test_after_no_trajectory_no_detection(self) -> None:
"""result 无 trajectory 时after 不检测。"""
mw = LoopDetectionMiddleware()
ctx = _make_ctx()
result = MagicMock(spec=[]) # No trajectory attribute
await mw.after(ctx, result)
assert "loop_detected" not in ctx.metadata
@pytest.mark.asyncio
async def test_after_detects_repeated_tool_calls(self) -> None:
"""trajectory 中有重复工具调用时after 标记 loop_detected。"""
mw = LoopDetectionMiddleware(window_size=3, threshold=2)
ctx = _make_ctx()
result = MagicMock()
result.trajectory = [
{"tool_name": "search", "arguments_hash": "abc"},
{"tool_name": "search", "arguments_hash": "abc"},
{"tool_name": "search", "arguments_hash": "abc"},
]
await mw.after(ctx, result)
assert ctx.metadata["loop_detected"] is True
@pytest.mark.asyncio
async def test_after_no_loop_with_different_calls(self) -> None:
"""trajectory 中工具调用不同时after 不标记 loop_detected。"""
mw = LoopDetectionMiddleware(window_size=3, threshold=2)
ctx = _make_ctx()
result = MagicMock()
result.trajectory = [
{"tool_name": "search", "arguments_hash": "abc"},
{"tool_name": "read_file", "arguments_hash": "def"},
{"tool_name": "write_file", "arguments_hash": "ghi"},
]
await mw.after(ctx, result)
assert "loop_detected" not in ctx.metadata