feat(core): add middleware pipeline architecture with onion model
U6: Unified middleware protocol (before/after) with MiddlewareChain implementing onion-model execution.

Parallel integration (KTD1): the middleware path is controlled by the presence of the middleware_chain parameter; the existing ReActEngine path is unchanged when it is None.

- New core/middleware.py: RequestContext, Middleware protocol, MiddlewareChain (onion model: before outer→inner, after inner→outer)
- 3 example middlewares: SummarizationMiddleware (U3 headroom compression), TokenUsageMiddleware, LoopDetectionMiddleware (request-level audit)
- ReActEngine.__init__ accepts a middleware_chain parameter
- execute() branches: middleware path when a chain is present, existing path otherwise
- 22 tests covering ordering, error handling, state passing, and backward compatibility
This commit is contained in:
parent
ef84e3fd53
commit
3dfda904d7
|
|
@ -0,0 +1,212 @@
|
||||||
|
"""Middleware pipeline architecture (U6)
|
||||||
|
|
||||||
|
洋葱模型中间件管道,将横切关注点(压缩/计量/安全/循环检测)集中化。
|
||||||
|
并行接入(KTD1):通过 feature flag 控制,与现有 ReActEngine 路径共存。
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
from agentkit.core.middleware import (
|
||||||
|
MiddlewareChain,
|
||||||
|
RequestContext,
|
||||||
|
SummarizationMiddleware,
|
||||||
|
TokenUsageMiddleware,
|
||||||
|
)
|
||||||
|
|
||||||
|
chain = MiddlewareChain([
|
||||||
|
SummarizationMiddleware(compressor),
|
||||||
|
TokenUsageMiddleware(),
|
||||||
|
])
|
||||||
|
result = await chain.execute(ctx, handler)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Awaitable, Callable, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RequestContext:
    """Per-request state that travels through the whole middleware chain.

    Middlewares may read and mutate this object; that is how they pass
    state to one another and to the final handler.
    """

    # Conversation messages for the request (the only required field).
    messages: list[dict[str, str]]
    # Tools offered for the request; every instance gets its own list.
    tools: list[Any] = field(default_factory=list)
    system_prompt: str | None = None
    model: str = "default"
    agent_name: str = ""
    task_type: str = ""
    task_id: str | None = None
    # Scratch space shared between middlewares (compression result, token
    # usage, loop-detection state, ...); every instance gets its own dict.
    metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
class Middleware(Protocol):
    """Structural interface for one layer of the onion-model pipeline.

    ``before`` is awaited ahead of the request handler and may modify the
    context (for example, compress the conversation).  ``after`` is awaited
    once a result exists and may modify that result (for example, attach
    token usage).
    """

    async def before(self, ctx: RequestContext) -> RequestContext:
        """Pre-process *ctx* and return the context to pass inward."""
        ...

    async def after(self, ctx: RequestContext, result: Any) -> Any:
        """Post-process *result* and return the value to pass outward."""
        ...
|
||||||
|
|
||||||
|
|
||||||
|
class MiddlewareChain:
    """Onion-model middleware chain.

    For middlewares ``[A, B, C]`` the ``before`` hooks run outermost first
    (A -> B -> C), then the handler runs, then the ``after`` hooks run
    innermost first (C -> B -> A).

    Error policy:

    * An exception raised by a ``before`` hook or by the handler propagates
      to the caller; the remaining ``before`` hooks, the handler (if not yet
      run) and every ``after`` hook are skipped.
    * An exception raised by an ``after`` hook is logged as a warning and
      swallowed; the remaining ``after`` hooks still run with the result
      produced so far.
    """

    def __init__(self, middlewares: list[Middleware] | None = None) -> None:
        """Create a chain from *middlewares*, listed outermost first.

        The list is copied so that later mutation of the caller's list
        cannot change the chain.  ``None`` (or an empty list) yields an
        empty chain.
        """
        self._middlewares: list[Middleware] = list(middlewares) if middlewares else []

    @property
    def middlewares(self) -> list[Middleware]:
        """Return a copy of the configured middlewares, outermost first."""
        return list(self._middlewares)

    async def execute(
        self,
        ctx: RequestContext,
        handler: Callable[[RequestContext], Awaitable[Any]],
    ) -> Any:
        """Run the ``before`` hooks, the handler, then the ``after`` hooks.

        Args:
            ctx: Context handed to the first ``before`` hook; each hook's
                return value becomes the context for the next layer.
            handler: Awaitable callable invoked with the context returned by
                the innermost ``before`` hook.

        Returns:
            The handler's result after every ``after`` hook has had the
            chance to replace it.

        Raises:
            Exception: Whatever a ``before`` hook or the handler raises.
        """
        if not self._middlewares:
            return await handler(ctx)

        # before phase: outer -> inner.  Failures propagate untouched, so no
        # after hook ever runs for a request that did not reach the handler.
        entered: list[Middleware] = []
        current_ctx = ctx
        for mw in self._middlewares:
            current_ctx = await mw.before(current_ctx)
            entered.append(mw)

        result = await handler(current_ctx)

        # after phase: inner -> outer.  A failing hook is logged and skipped
        # so one bad layer cannot discard an otherwise valid result.
        for mw in reversed(entered):
            try:
                result = await mw.after(current_ctx, result)
            except Exception as exc:
                logger.warning(
                    "Middleware %s.after() failed: %s — continuing with current result",
                    type(mw).__name__,
                    exc,
                    exc_info=True,
                )

        return result
|
||||||
|
|
||||||
|
|
||||||
|
# ── 示例中间件 ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class SummarizationMiddleware:
    """Compression middleware wrapping a context compressor (U3 headroom).

    ``before`` asks the compressor whether the conversation needs
    compressing and, if so, replaces ``ctx.messages`` with the compressed
    messages and sets ``ctx.metadata["compressed"]``.  ``after`` is a no-op
    because all the work happens up front.
    """

    def __init__(self, compressor: Any = None) -> None:
        """Store *compressor*; ``None`` turns the middleware into a no-op.

        The compressor is expected to offer a synchronous
        ``should_compress(messages)`` predicate and an awaitable
        ``compress(messages)``.
        """
        self._compressor = compressor

    async def before(self, ctx: RequestContext) -> RequestContext:
        """Compress ``ctx.messages`` when the compressor asks for it.

        Compression is best effort: if ``compress`` raises, a warning is
        logged and the context is returned unchanged.
        """
        compressor = self._compressor
        # Identity check rather than truthiness: a compressor object that
        # happens to be falsy (e.g. defines __len__ returning 0) is still a
        # configured compressor.
        if compressor is None:
            return ctx

        should_compress = getattr(compressor, "should_compress", None)
        if should_compress is None or not should_compress(ctx.messages):
            return ctx

        try:
            compressed = await compressor.compress(ctx.messages)
        except Exception as exc:
            logger.warning("SummarizationMiddleware: compression failed: %s", exc)
        else:
            ctx.messages = compressed
            ctx.metadata["compressed"] = True
            logger.info("SummarizationMiddleware: compressed conversation")
        return ctx

    async def after(self, ctx: RequestContext, result: Any) -> Any:
        """Return *result* unchanged."""
        return result
|
||||||
|
|
||||||
|
|
||||||
|
class TokenUsageMiddleware:
    """Token accounting middleware.

    ``before`` makes sure ``ctx.metadata["token_usage_start"]`` exists,
    defaulting it to 0 without overwriting a value that is already there.
    ``after`` copies the result's ``token_usage`` attribute into
    ``ctx.metadata["token_usage_total"]`` when that attribute is present and
    not ``None``.
    """

    async def before(self, ctx: RequestContext) -> RequestContext:
        """Seed the ``token_usage_start`` metadata entry and return *ctx*."""
        ctx.metadata.setdefault("token_usage_start", 0)
        return ctx

    async def after(self, ctx: RequestContext, result: Any) -> Any:
        """Record the result's token usage, if any, and return *result*."""
        # Works for ReActResult or anything else exposing ``token_usage``.
        reported = getattr(result, "token_usage", None)
        if reported is None:
            return result
        ctx.metadata["token_usage_total"] = reported
        return result
|
||||||
|
|
||||||
|
|
||||||
|
class LoopDetectionMiddleware:
    """Request-level audit for repeated tool calls.

    ``before`` resets ``ctx.metadata["loop_detection_window"]`` to an empty
    list.  ``after`` inspects the tail of the result's ``trajectory`` and,
    when it contains repeated ``(tool_name, arguments_hash)`` pairs, logs a
    warning and sets ``ctx.metadata["loop_detected"]``.

    Per the original author's note, per-step loop detection still happens
    inside ``_execute_loop`` (U1); this middleware only audits the finished
    request.
    """

    def __init__(self, window_size: int = 5, threshold: int = 2) -> None:
        """Configure the audit.

        Args:
            window_size: Number of most recent tool calls to inspect.  A
                value of zero or less inspects nothing.
            threshold: Trajectories shorter than this are ignored, and at
                least ``max(1, threshold - 1)`` duplicate calls must appear
                in the window for a loop to be reported.
        """
        self._window_size = window_size
        self._threshold = threshold

    async def before(self, ctx: RequestContext) -> RequestContext:
        """Reset the loop-detection window in the metadata and return *ctx*."""
        ctx.metadata["loop_detection_window"] = []
        return ctx

    async def after(self, ctx: RequestContext, result: Any) -> Any:
        """Flag repeated tool calls near the end of the trajectory.

        Only ``dict`` steps containing a ``"tool_name"`` key are considered.
        *result* is always returned unchanged.
        """
        trajectory = getattr(result, "trajectory", None) or []
        if len(trajectory) < self._threshold:
            return result

        tool_calls = [
            (step.get("tool_name", ""), step.get("arguments_hash", ""))
            for step in trajectory
            if isinstance(step, dict) and "tool_name" in step
        ]
        # Guard the slice below: ``seq[-0:]`` is the whole list and a
        # negative size would drop the head instead of taking the tail, so a
        # non-positive window must be treated as "nothing to inspect".
        if not tool_calls or self._window_size <= 0:
            return result

        window = tool_calls[-self._window_size:]
        duplicates = len(window) - len(set(window))
        if duplicates >= max(1, self._threshold - 1):
            logger.warning(
                "LoopDetectionMiddleware: detected repeated tool calls in final trajectory"
            )
            ctx.metadata["loop_detected"] = True

        return result
|
||||||
|
|
@ -27,6 +27,7 @@ from agentkit.telemetry.metrics import (
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from agentkit.core.compressor import CompressionStrategy
|
from agentkit.core.compressor import CompressionStrategy
|
||||||
|
from agentkit.core.middleware import MiddlewareChain
|
||||||
from agentkit.core.trace import TraceRecorder
|
from agentkit.core.trace import TraceRecorder
|
||||||
from agentkit.memory.retriever import MemoryRetriever
|
from agentkit.memory.retriever import MemoryRetriever
|
||||||
|
|
||||||
|
|
@ -162,6 +163,7 @@ class ReActEngine:
|
||||||
verification_commands: list[str] | None = None,
|
verification_commands: list[str] | None = None,
|
||||||
core_tool_names: list[str] | None = None,
|
core_tool_names: list[str] | None = None,
|
||||||
enable_tool_search: bool = True,
|
enable_tool_search: bool = True,
|
||||||
|
middleware_chain: "MiddlewareChain | None" = None,
|
||||||
):
|
):
|
||||||
if max_steps < 1:
|
if max_steps < 1:
|
||||||
raise ValueError(f"max_steps must be >= 1, got {max_steps}")
|
raise ValueError(f"max_steps must be >= 1, got {max_steps}")
|
||||||
|
|
@ -193,6 +195,8 @@ class ReActEngine:
|
||||||
self._loop_window: deque[str] = deque(maxlen=5)
|
self._loop_window: deque[str] = deque(maxlen=5)
|
||||||
self._loop_threshold: int = 2
|
self._loop_threshold: int = 2
|
||||||
self._loop_corrected: bool = False
|
self._loop_corrected: bool = False
|
||||||
|
# U6: Middleware chain (parallel integration, feature flag controlled)
|
||||||
|
self._middleware_chain = middleware_chain
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
"""Reset internal state for reuse across conversations.
|
"""Reset internal state for reuse across conversations.
|
||||||
|
|
@ -263,6 +267,57 @@ class ReActEngine:
|
||||||
timeout_seconds if timeout_seconds is not None else self._default_timeout
|
timeout_seconds if timeout_seconds is not None else self._default_timeout
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# U6: Middleware chain (parallel integration, KTD1)
|
||||||
|
# If middleware_chain is present, wrap the handler with it.
|
||||||
|
# Otherwise, use the existing path (backward compatible).
|
||||||
|
if self._middleware_chain is not None:
|
||||||
|
from agentkit.core.middleware import RequestContext
|
||||||
|
|
||||||
|
ctx = RequestContext(
|
||||||
|
messages=messages,
|
||||||
|
tools=tools or [],
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
model=model,
|
||||||
|
agent_name=agent_name,
|
||||||
|
task_type=task_type,
|
||||||
|
task_id=task_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handler(c: RequestContext) -> ReActResult:
|
||||||
|
return await self._execute_loop(
|
||||||
|
messages=c.messages,
|
||||||
|
tools=c.tools or None,
|
||||||
|
model=c.model,
|
||||||
|
agent_name=c.agent_name,
|
||||||
|
task_type=c.task_type,
|
||||||
|
system_prompt=c.system_prompt,
|
||||||
|
trace_recorder=trace_recorder,
|
||||||
|
memory_retriever=memory_retriever,
|
||||||
|
task_id=c.task_id,
|
||||||
|
compressor=effective_compressor,
|
||||||
|
retrieval_config=retrieval_config,
|
||||||
|
cancellation_token=cancellation_token,
|
||||||
|
confirmation_handler=confirmation_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if effective_timeout > 0:
|
||||||
|
result = await asyncio.wait_for(
|
||||||
|
self._middleware_chain.execute(ctx, _handler),
|
||||||
|
timeout=effective_timeout,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = await self._middleware_chain.execute(ctx, _handler)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
raise TaskTimeoutError(
|
||||||
|
task_id=task_id or "",
|
||||||
|
timeout_seconds=int(effective_timeout),
|
||||||
|
)
|
||||||
|
except TaskCancelledError:
|
||||||
|
raise
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if effective_timeout > 0:
|
if effective_timeout > 0:
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,517 @@
|
||||||
|
"""U6: 中间件管道架构测试
|
||||||
|
|
||||||
|
测试洋葱模型中间件链的执行顺序、错误处理、向后兼容性。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.core.middleware import (
|
||||||
|
LoopDetectionMiddleware,
|
||||||
|
MiddlewareChain,
|
||||||
|
RequestContext,
|
||||||
|
SummarizationMiddleware,
|
||||||
|
TokenUsageMiddleware,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── 辅助 ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _TraceMiddleware:
|
||||||
|
"""记录 before/after 调用顺序的测试中间件。"""
|
||||||
|
|
||||||
|
def __init__(self, name: str, calls: list[str]):
|
||||||
|
self._name = name
|
||||||
|
self._calls = calls
|
||||||
|
|
||||||
|
async def before(self, ctx: RequestContext) -> RequestContext:
|
||||||
|
self._calls.append(f"before:{self._name}")
|
||||||
|
return ctx
|
||||||
|
|
||||||
|
async def after(self, ctx: RequestContext, result):
|
||||||
|
self._calls.append(f"after:{self._name}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class _ErrorBeforeMiddleware:
|
||||||
|
"""before 抛异常的测试中间件。"""
|
||||||
|
|
||||||
|
def __init__(self, error: Exception):
|
||||||
|
self._error = error
|
||||||
|
|
||||||
|
async def before(self, ctx: RequestContext) -> RequestContext:
|
||||||
|
raise self._error
|
||||||
|
|
||||||
|
async def after(self, ctx: RequestContext, result):
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class _ErrorAfterMiddleware:
|
||||||
|
"""after 抛异常的测试中间件。"""
|
||||||
|
|
||||||
|
def __init__(self, error: Exception):
|
||||||
|
self._error = error
|
||||||
|
|
||||||
|
async def before(self, ctx: RequestContext) -> RequestContext:
|
||||||
|
return ctx
|
||||||
|
|
||||||
|
async def after(self, ctx: RequestContext, result):
|
||||||
|
raise self._error
|
||||||
|
|
||||||
|
|
||||||
|
class _ModifyCtxMiddleware:
|
||||||
|
"""修改 ctx 的测试中间件。"""
|
||||||
|
|
||||||
|
def __init__(self, key: str, value: str):
|
||||||
|
self._key = key
|
||||||
|
self._value = value
|
||||||
|
|
||||||
|
async def before(self, ctx: RequestContext) -> RequestContext:
|
||||||
|
ctx.metadata[self._key] = self._value
|
||||||
|
return ctx
|
||||||
|
|
||||||
|
async def after(self, ctx: RequestContext, result):
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class _ModifyResultMiddleware:
|
||||||
|
"""修改 result 的测试中间件。"""
|
||||||
|
|
||||||
|
def __init__(self, suffix: str):
|
||||||
|
self._suffix = suffix
|
||||||
|
|
||||||
|
async def before(self, ctx: RequestContext) -> RequestContext:
|
||||||
|
return ctx
|
||||||
|
|
||||||
|
async def after(self, ctx: RequestContext, result):
|
||||||
|
if isinstance(result, str):
|
||||||
|
return result + self._suffix
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ctx() -> RequestContext:
    """Build a minimal context holding a single user message."""
    user_message = {"role": "user", "content": "test"}
    return RequestContext(messages=[user_message])
|
||||||
|
|
||||||
|
|
||||||
|
# ── 洋葱模型执行顺序 ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestMiddlewareChainOrder:
    """Execution order of the before/after hooks in the onion model."""

    @pytest.mark.asyncio
    async def test_before_outer_to_inner_after_inner_to_outer(self) -> None:
        """Three middlewares: before runs A→B→C, after runs C→B→A."""
        calls: list[str] = []

        async def handler(ctx: RequestContext) -> str:
            calls.append("handler")
            return "result"

        layers = [_TraceMiddleware(label, calls) for label in ("A", "B", "C")]
        result = await MiddlewareChain(layers).execute(_make_ctx(), handler)

        expected = [
            "before:A",
            "before:B",
            "before:C",
            "handler",
            "after:C",
            "after:B",
            "after:A",
        ]
        assert result == "result"
        assert calls == expected

    @pytest.mark.asyncio
    async def test_single_middleware(self) -> None:
        """One middleware: before → handler → after."""
        calls: list[str] = []

        async def handler(ctx: RequestContext) -> str:
            calls.append("handler")
            return "ok"

        only = _TraceMiddleware("X", calls)
        result = await MiddlewareChain([only]).execute(_make_ctx(), handler)

        assert result == "ok"
        assert calls == ["before:X", "handler", "after:X"]

    @pytest.mark.asyncio
    async def test_empty_chain_calls_handler_directly(self) -> None:
        """An empty chain simply awaits the handler."""

        async def handler(ctx: RequestContext) -> str:
            return "direct"

        empty_chain = MiddlewareChain([])
        assert await empty_chain.execute(_make_ctx(), handler) == "direct"

    @pytest.mark.asyncio
    async def test_none_middlewares_defaults_to_empty(self) -> None:
        """Passing ``None`` for the middlewares yields an empty chain."""

        async def handler(ctx: RequestContext) -> str:
            return "ok"

        default_chain = MiddlewareChain(None)
        assert await default_chain.execute(_make_ctx(), handler) == "ok"
|
||||||
|
|
||||||
|
|
||||||
|
# ── 错误处理 ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestMiddlewareChainErrors:
    """Error handling of the middleware chain."""

    @pytest.mark.asyncio
    async def test_before_error_stops_chain(self) -> None:
        """A failing before hook stops everything and the error propagates."""
        calls: list[str] = []
        chain = MiddlewareChain(
            [
                _TraceMiddleware("A", calls),
                _ErrorBeforeMiddleware(ValueError("before error")),
                _TraceMiddleware("C", calls),
            ]
        )

        async def handler(ctx: RequestContext) -> str:
            calls.append("handler")
            return "result"

        with pytest.raises(ValueError, match="before error"):
            await chain.execute(_make_ctx(), handler)

        # A.before ran, the second middleware's before raised; C.before, the
        # handler and every after hook were skipped.
        assert calls == ["before:A"]
        assert not any(entry.startswith("after:") for entry in calls)

    @pytest.mark.asyncio
    async def test_after_error_does_not_stop_chain(self) -> None:
        """A failing after hook is swallowed; the other after hooks still run."""
        calls: list[str] = []
        chain = MiddlewareChain(
            [
                _TraceMiddleware("A", calls),
                _ErrorAfterMiddleware(RuntimeError("after error")),
                _TraceMiddleware("C", calls),
            ]
        )

        async def handler(ctx: RequestContext) -> str:
            calls.append("handler")
            return "result"

        # Must not raise: after-hook errors are caught and logged.
        result = await chain.execute(_make_ctx(), handler)

        assert result == "result"
        # The failing middleware does not trace, so it leaves no entries.
        # Asserting the exact sequence (instead of membership) also pins the
        # order and proves the outer after hook ran after the failure.
        assert calls == [
            "before:A",
            "before:C",
            "handler",
            "after:C",
            "after:A",
        ]
|
||||||
|
|
||||||
|
|
||||||
|
# ── 状态传递 ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestMiddlewareChainState:
    """State flows between middlewares through the shared RequestContext."""

    @pytest.mark.asyncio
    async def test_before_modifies_ctx_visible_to_handler(self) -> None:
        """Metadata written by a before hook is visible inside the handler."""
        seen: dict[str, str] = {}

        async def handler(ctx: RequestContext) -> str:
            seen.update(ctx.metadata)
            return "ok"

        chain = MiddlewareChain([_ModifyCtxMiddleware("key", "value")])
        await chain.execute(_make_ctx(), handler)

        assert seen["key"] == "value"

    @pytest.mark.asyncio
    async def test_after_modifies_result_visible_to_outer(self) -> None:
        """A result changed by an inner after hook reaches the outer one."""
        outer = _ModifyResultMiddleware("_1")
        inner = _ModifyResultMiddleware("_2")

        async def handler(ctx: RequestContext) -> str:
            return "base"

        result = await MiddlewareChain([outer, inner]).execute(_make_ctx(), handler)

        # after hooks unwind innermost first:
        #   "base" -> inner appends "_2" -> outer appends "_1".
        assert result == "base_2_1"

    @pytest.mark.asyncio
    async def test_multiple_middlewares_share_ctx(self) -> None:
        """Several middlewares write into the same context object."""
        seen: dict[str, str] = {}

        async def handler(ctx: RequestContext) -> str:
            seen.update(ctx.metadata)
            return "ok"

        writers = [
            _ModifyCtxMiddleware("a", "1"),
            _ModifyCtxMiddleware("b", "2"),
        ]
        await MiddlewareChain(writers).execute(_make_ctx(), handler)

        assert seen == {"a": "1", "b": "2"}
|
||||||
|
|
||||||
|
|
||||||
|
# ── 向后兼容 ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestMiddlewareBackwardCompat:
    """Behaviour is unchanged when ``middleware_chain`` is ``None``."""

    @staticmethod
    def _fake_gateway() -> MagicMock:
        """Return a mocked LLM gateway whose ``chat`` yields one plain reply."""
        from agentkit.llm.protocol import TokenUsage

        reply = MagicMock(
            content="hello",
            tool_calls=[],
            usage=TokenUsage(prompt_tokens=10, completion_tokens=5),
        )
        gateway = MagicMock()
        gateway.chat = AsyncMock(return_value=reply)
        return gateway

    @pytest.mark.asyncio
    async def test_react_engine_without_middleware_works(self) -> None:
        """Without a middleware_chain, ReActEngine takes the existing path."""
        from agentkit.core.react import ReActEngine

        engine = ReActEngine(llm_gateway=self._fake_gateway(), max_steps=1)
        # No middleware_chain was passed, so the existing path must be used.
        assert engine._middleware_chain is None

        user_turn = {"role": "user", "content": "hi"}
        result = await engine.execute(
            messages=[user_turn],
            tools=[],
            timeout_seconds=0,
        )
        assert result is not None

    @pytest.mark.asyncio
    async def test_react_engine_with_middleware_uses_chain(self) -> None:
        """With a middleware_chain, ReActEngine goes through the chain."""
        from agentkit.core.react import ReActEngine

        calls: list[str] = []
        chain = MiddlewareChain([_TraceMiddleware("test", calls)])
        engine = ReActEngine(
            llm_gateway=self._fake_gateway(),
            max_steps=1,
            middleware_chain=chain,
        )
        assert engine._middleware_chain is chain

        user_turn = {"role": "user", "content": "hi"}
        result = await engine.execute(
            messages=[user_turn],
            tools=[],
            timeout_seconds=0,
        )

        assert result is not None
        # Both hooks of the middleware must have been invoked.
        assert "before:test" in calls
        assert "after:test" in calls
|
||||||
|
|
||||||
|
|
||||||
|
# ── 示例中间件测试 ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestSummarizationMiddleware:
    """Tests for SummarizationMiddleware."""

    @pytest.mark.asyncio
    async def test_no_compressor_passes_through(self) -> None:
        """Without a compressor, before leaves the context alone."""
        ctx = _make_ctx()
        middleware = SummarizationMiddleware(compressor=None)

        returned = await middleware.before(ctx)

        assert returned is ctx
        assert "compressed" not in ctx.metadata

    @pytest.mark.asyncio
    async def test_compressor_not_needed_passes_through(self) -> None:
        """When the compressor declines, before keeps the messages."""
        compressor = MagicMock(should_compress=MagicMock(return_value=False))
        ctx = _make_ctx()
        snapshot = list(ctx.messages)

        await SummarizationMiddleware(compressor=compressor).before(ctx)

        assert ctx.messages == snapshot
        assert "compressed" not in ctx.metadata

    @pytest.mark.asyncio
    async def test_compressor_triggers_compression(self) -> None:
        """When the compressor asks for it, before compresses the messages."""
        compressed_messages = [{"role": "user", "content": "compressed"}]
        compressor = MagicMock(
            should_compress=MagicMock(return_value=True),
            compress=AsyncMock(return_value=compressed_messages),
        )
        ctx = _make_ctx()

        await SummarizationMiddleware(compressor=compressor).before(ctx)

        assert ctx.messages == compressed_messages
        assert ctx.metadata["compressed"] is True

    @pytest.mark.asyncio
    async def test_compressor_failure_does_not_block(self) -> None:
        """A failing compression neither raises nor changes the messages."""
        compressor = MagicMock(
            should_compress=MagicMock(return_value=True),
            compress=AsyncMock(side_effect=RuntimeError("compress failed")),
        )
        ctx = _make_ctx()
        snapshot = list(ctx.messages)

        # Must not raise.
        await SummarizationMiddleware(compressor=compressor).before(ctx)

        assert ctx.messages == snapshot
        assert "compressed" not in ctx.metadata
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenUsageMiddleware:
    """Tests for TokenUsageMiddleware."""

    @pytest.mark.asyncio
    async def test_before_initializes_metadata(self) -> None:
        """before creates the token_usage_start metadata entry."""
        ctx = _make_ctx()

        await TokenUsageMiddleware().before(ctx)

        assert "token_usage_start" in ctx.metadata

    @pytest.mark.asyncio
    async def test_after_extracts_usage_from_result(self) -> None:
        """after copies token_usage from the result into the metadata."""
        ctx = _make_ctx()
        outcome = MagicMock()
        outcome.token_usage = {"total": 100}

        await TokenUsageMiddleware().after(ctx, outcome)

        assert ctx.metadata["token_usage_total"] == {"total": 100}

    @pytest.mark.asyncio
    async def test_after_no_usage_in_result(self) -> None:
        """after leaves the metadata alone when the result has no token_usage."""
        ctx = _make_ctx()
        # spec=[] gives a mock without any attributes, so no token_usage.
        outcome = MagicMock(spec=[])

        await TokenUsageMiddleware().after(ctx, outcome)

        assert "token_usage_total" not in ctx.metadata
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoopDetectionMiddleware:
    """Tests for LoopDetectionMiddleware."""

    @pytest.mark.asyncio
    async def test_before_initializes_window(self) -> None:
        """before resets loop_detection_window to an empty list."""
        ctx = _make_ctx()

        await LoopDetectionMiddleware().before(ctx)

        assert ctx.metadata["loop_detection_window"] == []

    @pytest.mark.asyncio
    async def test_after_no_trajectory_no_detection(self) -> None:
        """after flags nothing when the result has no trajectory."""
        ctx = _make_ctx()
        # spec=[] gives a mock without any attributes, so no trajectory.
        outcome = MagicMock(spec=[])

        await LoopDetectionMiddleware().after(ctx, outcome)

        assert "loop_detected" not in ctx.metadata

    @pytest.mark.asyncio
    async def test_after_detects_repeated_tool_calls(self) -> None:
        """Repeated tool calls in the trajectory set loop_detected."""
        ctx = _make_ctx()
        outcome = MagicMock()
        outcome.trajectory = [
            {"tool_name": "search", "arguments_hash": "abc"} for _ in range(3)
        ]
        middleware = LoopDetectionMiddleware(window_size=3, threshold=2)

        await middleware.after(ctx, outcome)

        assert ctx.metadata["loop_detected"] is True

    @pytest.mark.asyncio
    async def test_after_no_loop_with_different_calls(self) -> None:
        """Distinct tool calls in the trajectory do not set loop_detected."""
        ctx = _make_ctx()
        outcome = MagicMock()
        outcome.trajectory = [
            {"tool_name": "search", "arguments_hash": "abc"},
            {"tool_name": "read_file", "arguments_hash": "def"},
            {"tool_name": "write_file", "arguments_hash": "ghi"},
        ]
        middleware = LoopDetectionMiddleware(window_size=3, threshold=2)

        await middleware.after(ctx, outcome)

        assert "loop_detected" not in ctx.metadata
|
||||||
Loading…
Reference in New Issue