518 lines
17 KiB
Python
518 lines
17 KiB
Python
"""U6: 中间件管道架构测试
|
||
|
||
测试洋葱模型中间件链的执行顺序、错误处理、向后兼容性。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
import pytest
|
||
|
||
from agentkit.core.middleware import (
|
||
LoopDetectionMiddleware,
|
||
MiddlewareChain,
|
||
RequestContext,
|
||
SummarizationMiddleware,
|
||
TokenUsageMiddleware,
|
||
)
|
||
|
||
|
||
# ── 辅助 ──────────────────────────────────────────────────
|
||
|
||
|
||
class _TraceMiddleware:
    """Test middleware that logs each hook invocation into a shared list.

    ``before`` appends ``"before:<name>"`` and ``after`` appends
    ``"after:<name>"``. Both hooks pass their input through unchanged, so
    the shared list ends up mirroring the order in which the chain ran them.
    """

    def __init__(self, name: str, calls: list[str]):
        # ``calls`` is kept by reference: the test and sibling middlewares
        # append to the same list object.
        self._name, self._calls = name, calls

    async def before(self, ctx: RequestContext) -> RequestContext:
        record = self._calls.append
        record(f"before:{self._name}")
        return ctx

    async def after(self, ctx: RequestContext, result):
        record = self._calls.append
        record(f"after:{self._name}")
        return result
|
||
|
||
|
||
class _ErrorBeforeMiddleware:
    """Test middleware whose ``before`` hook always raises a preset exception.

    ``after`` is a plain pass-through of ``result``.
    """

    def __init__(self, error: Exception):
        # The exact instance given here is what ``before`` raises.
        self._error = error

    async def before(self, ctx: RequestContext) -> RequestContext:
        failure = self._error
        raise failure

    async def after(self, ctx: RequestContext, result):
        # Pass-through: hand the result back untouched.
        return result
|
||
|
||
|
||
class _ErrorAfterMiddleware:
    """Test middleware whose ``after`` hook always raises a preset exception.

    ``before`` is a plain pass-through of ``ctx``.
    """

    def __init__(self, error: Exception):
        # The exact instance given here is what ``after`` raises.
        self._error = error

    async def before(self, ctx: RequestContext) -> RequestContext:
        # Pass-through: the context reaches inner middlewares unmodified.
        return ctx

    async def after(self, ctx: RequestContext, result):
        failure = self._error
        raise failure
|
||
|
||
|
||
class _ModifyCtxMiddleware:
    """Test middleware that writes one key/value pair into ``ctx.metadata``.

    The write happens in ``before``, so it is visible to every inner
    middleware and to the handler. ``after`` passes ``result`` through.
    """

    def __init__(self, key: str, value: str):
        self._key, self._value = key, value

    async def before(self, ctx: RequestContext) -> RequestContext:
        # Mutate the shared context in place and hand the same object on.
        metadata = ctx.metadata
        metadata[self._key] = self._value
        return ctx

    async def after(self, ctx: RequestContext, result):
        return result
|
||
|
||
|
||
class _ModifyResultMiddleware:
    """Test middleware that appends a suffix to string results in ``after``.

    Non-string results are returned unchanged. ``before`` passes ``ctx``
    through.
    """

    def __init__(self, suffix: str):
        self._suffix = suffix

    async def before(self, ctx: RequestContext) -> RequestContext:
        return ctx

    async def after(self, ctx: RequestContext, result):
        # Only strings are decorated; anything else is handed back as-is.
        return result + self._suffix if isinstance(result, str) else result
|
||
|
||
|
||
def _make_ctx() -> RequestContext:
    """Build a fresh RequestContext holding a single user message."""
    user_message = {"role": "user", "content": "test"}
    return RequestContext(messages=[user_message])
|
||
|
||
|
||
# ── 洋葱模型执行顺序 ──────────────────────────────────────
|
||
|
||
|
||
class TestMiddlewareChainOrder:
    """Onion-model ordering of before/after hooks around the handler."""

    @pytest.mark.asyncio
    async def test_before_outer_to_inner_after_inner_to_outer(self) -> None:
        """With A, B, C: befores run A→B→C, then the handler, then afters C→B→A."""
        trace: list[str] = []
        middlewares = [
            _TraceMiddleware("A", trace),
            _TraceMiddleware("B", trace),
            _TraceMiddleware("C", trace),
        ]
        chain = MiddlewareChain(middlewares)

        async def handler(ctx: RequestContext) -> str:
            trace.append("handler")
            return "result"

        outcome = await chain.execute(_make_ctx(), handler)

        assert outcome == "result"
        assert trace == [
            "before:A",
            "before:B",
            "before:C",
            "handler",
            "after:C",
            "after:B",
            "after:A",
        ]

    @pytest.mark.asyncio
    async def test_single_middleware(self) -> None:
        """A lone middleware wraps the handler: before → handler → after."""
        trace: list[str] = []
        chain = MiddlewareChain([_TraceMiddleware("X", trace)])

        async def handler(ctx: RequestContext) -> str:
            trace.append("handler")
            return "ok"

        outcome = await chain.execute(_make_ctx(), handler)

        assert outcome == "ok"
        assert trace == ["before:X", "handler", "after:X"]

    @pytest.mark.asyncio
    async def test_empty_chain_calls_handler_directly(self) -> None:
        """An empty chain simply returns whatever the handler returns."""

        async def handler(ctx: RequestContext) -> str:
            return "direct"

        chain = MiddlewareChain([])
        assert await chain.execute(_make_ctx(), handler) == "direct"

    @pytest.mark.asyncio
    async def test_none_middlewares_defaults_to_empty(self) -> None:
        """Passing None for middlewares behaves like an empty chain."""

        async def handler(ctx: RequestContext) -> str:
            return "ok"

        chain = MiddlewareChain(None)
        assert await chain.execute(_make_ctx(), handler) == "ok"
|
||
|
||
|
||
# ── 错误处理 ──────────────────────────────────────────────
|
||
|
||
|
||
class TestMiddlewareChainErrors:
    """Error handling inside the middleware chain."""

    @pytest.mark.asyncio
    async def test_before_error_stops_chain(self) -> None:
        """A failing ``before`` aborts the chain and propagates the exception.

        Inner ``before`` hooks and the handler never run, and no ``after``
        hook fires — not even for middlewares whose ``before`` already ran.
        """
        calls: list[str] = []
        chain = MiddlewareChain(
            [
                _TraceMiddleware("A", calls),
                _ErrorBeforeMiddleware(ValueError("before error")),
                _TraceMiddleware("C", calls),
            ]
        )

        async def handler(ctx: RequestContext) -> str:
            calls.append("handler")
            return "result"

        with pytest.raises(ValueError, match="before error"):
            await chain.execute(_make_ctx(), handler)

        # Only A.before ran. The second middleware raised before C.before,
        # the handler, or any after hook could execute. The exact equality
        # also proves no "after:*" entry was recorded.
        assert calls == ["before:A"]

    @pytest.mark.asyncio
    async def test_after_error_does_not_stop_chain(self) -> None:
        """A failing ``after`` is swallowed and the remaining ``after`` hooks still run.

        The chain does not raise, and the caller receives the result as it
        stood when the failing hook was reached.
        """
        calls: list[str] = []
        chain = MiddlewareChain(
            [
                _TraceMiddleware("A", calls),
                _ErrorAfterMiddleware(RuntimeError("after error")),
                _TraceMiddleware("C", calls),
            ]
        )

        async def handler(ctx: RequestContext) -> str:
            calls.append("handler")
            return "result"

        # Must NOT raise: after-hook errors are caught by the chain.
        result = await chain.execute(_make_ctx(), handler)

        assert result == "result"
        # _ErrorAfterMiddleware records nothing, so only A and C appear.
        # The key property is that "after:A" still follows "after:C": the
        # exception raised between them did not abort the unwinding.
        assert calls == ["before:A", "before:C", "handler", "after:C", "after:A"]
|
||
|
||
|
||
# ── 状态传递 ──────────────────────────────────────────────
|
||
|
||
|
||
class TestMiddlewareChainState:
    """State flows between middlewares and the handler through RequestContext."""

    @pytest.mark.asyncio
    async def test_before_modifies_ctx_visible_to_handler(self) -> None:
        """Metadata written in ``before`` is visible to the handler."""
        chain = MiddlewareChain([_ModifyCtxMiddleware("key", "value")])

        captured: dict[str, str] = {}

        async def handler(ctx: RequestContext) -> str:
            captured.update(ctx.metadata)
            return "ok"

        await chain.execute(_make_ctx(), handler)

        assert captured["key"] == "value"

    @pytest.mark.asyncio
    async def test_after_modifies_result_visible_to_outer(self) -> None:
        """A result rewritten by an inner ``after`` is what the outer ``after`` sees."""
        chain = MiddlewareChain(
            [
                _ModifyResultMiddleware("_1"),  # outer
                _ModifyResultMiddleware("_2"),  # inner
            ]
        )

        async def handler(ctx: RequestContext) -> str:
            return "base"

        result = await chain.execute(_make_ctx(), handler)

        # ``before`` runs _1 then _2, so ``after`` unwinds _2 then _1, and
        # each appends its suffix to the result it receives:
        #   "base" -> "base_2" (inner, "_2") -> "base_2_1" (outer, "_1")
        assert result == "base_2_1"

    @pytest.mark.asyncio
    async def test_multiple_middlewares_share_ctx(self) -> None:
        """Every middleware writes into the same ctx object that reaches the handler."""
        chain = MiddlewareChain(
            [
                _ModifyCtxMiddleware("a", "1"),
                _ModifyCtxMiddleware("b", "2"),
            ]
        )

        captured: dict[str, str] = {}

        async def handler(ctx: RequestContext) -> str:
            captured.update(ctx.metadata)
            return "ok"

        await chain.execute(_make_ctx(), handler)

        assert captured == {"a": "1", "b": "2"}
|
||
|
||
|
||
# ── 向后兼容 ──────────────────────────────────────────────
|
||
|
||
|
||
class TestMiddlewareBackwardCompat:
    """ReActEngine keeps its previous behaviour when middleware_chain is None."""

    @staticmethod
    def _stub_gateway() -> MagicMock:
        """Return a gateway mock whose async ``chat`` yields a tool-free reply."""
        from agentkit.llm.protocol import TokenUsage

        reply = MagicMock(
            content="hello",
            tool_calls=[],
            usage=TokenUsage(prompt_tokens=10, completion_tokens=5),
        )
        gateway = MagicMock()
        gateway.chat = AsyncMock(return_value=reply)
        return gateway

    @pytest.mark.asyncio
    async def test_react_engine_without_middleware_works(self) -> None:
        """Without a middleware_chain, ReActEngine runs the existing code path."""
        from agentkit.core.react import ReActEngine

        engine = ReActEngine(llm_gateway=self._stub_gateway(), max_steps=1)
        # No chain was supplied, so none should be attached.
        assert engine._middleware_chain is None

        outcome = await engine.execute(
            messages=[{"role": "user", "content": "hi"}],
            tools=[],
            timeout_seconds=0,
        )
        assert outcome is not None

    @pytest.mark.asyncio
    async def test_react_engine_with_middleware_uses_chain(self) -> None:
        """With a middleware_chain, ReActEngine routes execution through it."""
        from agentkit.core.react import ReActEngine

        trace: list[str] = []
        chain = MiddlewareChain([_TraceMiddleware("test", trace)])

        engine = ReActEngine(
            llm_gateway=self._stub_gateway(),
            max_steps=1,
            middleware_chain=chain,
        )
        assert engine._middleware_chain is chain

        outcome = await engine.execute(
            messages=[{"role": "user", "content": "hi"}],
            tools=[],
            timeout_seconds=0,
        )
        assert outcome is not None
        # Both hooks of the traced middleware must have fired.
        assert "before:test" in trace
        assert "after:test" in trace
|
||
|
||
|
||
# ── 示例中间件测试 ─────────────────────────────────────────
|
||
|
||
|
||
class TestSummarizationMiddleware:
    """Behaviour of SummarizationMiddleware.before with and without a compressor."""

    @staticmethod
    def _compressor(needs_compression: bool) -> MagicMock:
        """Build a compressor mock whose ``should_compress`` returns *needs_compression*."""
        stub = MagicMock()
        stub.should_compress = MagicMock(return_value=needs_compression)
        return stub

    @pytest.mark.asyncio
    async def test_no_compressor_passes_through(self) -> None:
        """Without a compressor, ``before`` returns the same ctx, unmarked."""
        ctx = _make_ctx()
        returned = await SummarizationMiddleware(compressor=None).before(ctx)
        assert returned is ctx
        assert "compressed" not in ctx.metadata

    @pytest.mark.asyncio
    async def test_compressor_not_needed_passes_through(self) -> None:
        """When the compressor reports no need, messages are left as they were."""
        mw = SummarizationMiddleware(compressor=self._compressor(False))
        ctx = _make_ctx()
        snapshot = list(ctx.messages)

        await mw.before(ctx)

        assert ctx.messages == snapshot
        assert "compressed" not in ctx.metadata

    @pytest.mark.asyncio
    async def test_compressor_triggers_compression(self) -> None:
        """When compression is needed, ``before`` swaps in the compressor's output."""
        compressed_messages = [{"role": "user", "content": "compressed"}]
        compressor = self._compressor(True)
        compressor.compress = AsyncMock(return_value=compressed_messages)
        ctx = _make_ctx()

        await SummarizationMiddleware(compressor=compressor).before(ctx)

        assert ctx.messages == compressed_messages
        assert ctx.metadata["compressed"] is True

    @pytest.mark.asyncio
    async def test_compressor_failure_does_not_block(self) -> None:
        """A raising ``compress`` is tolerated and leaves messages unchanged."""
        compressor = self._compressor(True)
        compressor.compress = AsyncMock(side_effect=RuntimeError("compress failed"))
        ctx = _make_ctx()
        snapshot = list(ctx.messages)

        # Must not raise.
        await SummarizationMiddleware(compressor=compressor).before(ctx)

        assert ctx.messages == snapshot
        assert "compressed" not in ctx.metadata
|
||
|
||
|
||
class TestTokenUsageMiddleware:
    """Behaviour of TokenUsageMiddleware."""

    @pytest.mark.asyncio
    async def test_before_returns_ctx_unchanged(self) -> None:
        """``before`` is a pass-through; token accounting happens in ``after``."""
        mw = TokenUsageMiddleware()
        ctx = _make_ctx()
        result = await mw.before(ctx)
        assert result is ctx

    @pytest.mark.asyncio
    async def test_after_extracts_usage_from_result(self) -> None:
        """``after`` copies ``result.total_tokens`` (the ReActResult attribute name)
        into ``ctx.metadata["token_usage_total"]``."""
        mw = TokenUsageMiddleware()
        ctx = _make_ctx()

        result = MagicMock()
        result.total_tokens = {"total": 100}

        await mw.after(ctx, result)

        assert ctx.metadata["token_usage_total"] == {"total": 100}

    @pytest.mark.asyncio
    async def test_after_no_usage_in_result(self) -> None:
        """When the result has no ``total_tokens`` attribute, ``after`` leaves metadata alone."""
        mw = TokenUsageMiddleware()
        ctx = _make_ctx()

        # spec=[] yields a mock with no attributes at all, so ``total_tokens``
        # is absent (a plain MagicMock would auto-create it on access).
        result = MagicMock(spec=[])
        await mw.after(ctx, result)

        assert "token_usage_total" not in ctx.metadata
|
||
|
||
|
||
class TestLoopDetectionMiddleware:
    """Behaviour of LoopDetectionMiddleware."""

    @staticmethod
    def _call(tool_name: str, arguments_hash: str) -> dict[str, str]:
        """Build one trajectory entry in the shape these tests feed the middleware."""
        return {"tool_name": tool_name, "arguments_hash": arguments_hash}

    @pytest.mark.asyncio
    async def test_before_initializes_window(self) -> None:
        """``before`` seeds ``ctx.metadata["loop_detection_window"]`` with an empty list."""
        ctx = _make_ctx()
        await LoopDetectionMiddleware().before(ctx)
        assert ctx.metadata["loop_detection_window"] == []

    @pytest.mark.asyncio
    async def test_after_no_trajectory_no_detection(self) -> None:
        """Without a ``trajectory`` attribute on the result, nothing is flagged."""
        ctx = _make_ctx()
        # spec=[] gives a mock with no attributes, hence no ``trajectory``.
        bare_result = MagicMock(spec=[])

        await LoopDetectionMiddleware().after(ctx, bare_result)

        assert "loop_detected" not in ctx.metadata

    @pytest.mark.asyncio
    async def test_after_detects_repeated_tool_calls(self) -> None:
        """Identical repeated tool calls in the trajectory set ``loop_detected``."""
        mw = LoopDetectionMiddleware(window_size=3, threshold=2)
        ctx = _make_ctx()

        result = MagicMock()
        # Three separate (but equal) dicts, one per step.
        result.trajectory = [self._call("search", "abc") for _ in range(3)]

        await mw.after(ctx, result)

        assert ctx.metadata["loop_detected"] is True

    @pytest.mark.asyncio
    async def test_after_no_loop_with_different_calls(self) -> None:
        """Distinct tool calls in the trajectory do not set ``loop_detected``."""
        mw = LoopDetectionMiddleware(window_size=3, threshold=2)
        ctx = _make_ctx()

        result = MagicMock()
        result.trajectory = [
            self._call("search", "abc"),
            self._call("read_file", "def"),
            self._call("write_file", "ghi"),
        ]

        await mw.after(ctx, result)

        assert "loop_detected" not in ctx.metadata
|