fischer-agentkit/tests/unit/test_middleware.py

518 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""U6: 中间件管道架构测试
测试洋葱模型中间件链的执行顺序、错误处理、向后兼容性。
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from agentkit.core.middleware import (
LoopDetectionMiddleware,
MiddlewareChain,
RequestContext,
SummarizationMiddleware,
TokenUsageMiddleware,
)
# ── 辅助 ──────────────────────────────────────────────────
class _TraceMiddleware:
    """Test middleware that logs the order of its ``before``/``after`` calls.

    Each hook appends a ``"<phase>:<name>"`` marker to the shared ``calls``
    list and hands its argument back untouched.
    """

    def __init__(self, name: str, calls: list[str]):
        self._name = name
        self._calls = calls

    def _record(self, phase: str) -> None:
        """Append the ``phase:name`` marker for this middleware to the log."""
        self._calls.append(f"{phase}:{self._name}")

    async def before(self, ctx: RequestContext) -> RequestContext:
        """Log ``before:<name>`` and return ``ctx`` unchanged."""
        self._record("before")
        return ctx

    async def after(self, ctx: RequestContext, result):
        """Log ``after:<name>`` and return ``result`` unchanged."""
        self._record("after")
        return result
class _ErrorBeforeMiddleware:
    """Test middleware whose ``before`` hook always raises a preset exception."""

    def __init__(self, error: Exception):
        self._error = error

    async def before(self, ctx: RequestContext) -> RequestContext:
        """Raise the exception instance supplied at construction time."""
        failure = self._error
        raise failure

    async def after(self, ctx: RequestContext, result):
        """Return ``result`` unchanged."""
        return result
class _ErrorAfterMiddleware:
    """Test middleware whose ``after`` hook always raises a preset exception."""

    def __init__(self, error: Exception):
        self._error = error

    async def before(self, ctx: RequestContext) -> RequestContext:
        """Return ``ctx`` unchanged."""
        return ctx

    async def after(self, ctx: RequestContext, result):
        """Raise the exception instance supplied at construction time."""
        failure = self._error
        raise failure
class _ModifyCtxMiddleware:
    """Test middleware that writes one key/value pair into ``ctx.metadata``."""

    def __init__(self, key: str, value: str):
        self._key = key
        self._value = value

    async def before(self, ctx: RequestContext) -> RequestContext:
        """Store the configured value under the configured key, then return ``ctx``."""
        metadata = ctx.metadata
        metadata[self._key] = self._value
        return ctx

    async def after(self, ctx: RequestContext, result):
        """Return ``result`` unchanged."""
        return result
class _ModifyResultMiddleware:
    """Test middleware that appends a suffix to string results in ``after``."""

    def __init__(self, suffix: str):
        self._suffix = suffix

    async def before(self, ctx: RequestContext) -> RequestContext:
        """Return ``ctx`` unchanged."""
        return ctx

    async def after(self, ctx: RequestContext, result):
        """Return ``result`` with the suffix appended; non-strings pass through."""
        if not isinstance(result, str):
            return result
        return result + self._suffix
def _make_ctx() -> RequestContext:
    """Build a ``RequestContext`` holding a single user message ``"test"``."""
    user_message = {"role": "user", "content": "test"}
    return RequestContext(messages=[user_message])
# ── 洋葱模型执行顺序 ──────────────────────────────────────
class TestMiddlewareChainOrder:
    """Onion-model ordering of ``before``/``after`` hooks."""

    @pytest.mark.asyncio
    async def test_before_outer_to_inner_after_inner_to_outer(self) -> None:
        """Three middlewares: ``before`` runs A→B→C, ``after`` runs C→B→A."""
        names = ["A", "B", "C"]
        trace: list[str] = []
        chain = MiddlewareChain([_TraceMiddleware(name, trace) for name in names])

        async def handler(ctx: RequestContext) -> str:
            trace.append("handler")
            return "result"

        outcome = await chain.execute(_make_ctx(), handler)

        # Befores in registration order, then the handler, then afters reversed.
        expected = (
            [f"before:{name}" for name in names]
            + ["handler"]
            + [f"after:{name}" for name in reversed(names)]
        )
        assert outcome == "result"
        assert trace == expected

    @pytest.mark.asyncio
    async def test_single_middleware(self) -> None:
        """One middleware: before → handler → after."""
        trace: list[str] = []
        only = _TraceMiddleware("X", trace)
        chain = MiddlewareChain([only])

        async def handler(ctx: RequestContext) -> str:
            trace.append("handler")
            return "ok"

        outcome = await chain.execute(_make_ctx(), handler)

        assert outcome == "ok"
        assert trace == ["before:X", "handler", "after:X"]

    @pytest.mark.asyncio
    async def test_empty_chain_calls_handler_directly(self) -> None:
        """An empty middleware list still returns the handler's result."""

        async def handler(ctx: RequestContext) -> str:
            return "direct"

        outcome = await MiddlewareChain([]).execute(_make_ctx(), handler)

        assert outcome == "direct"

    @pytest.mark.asyncio
    async def test_none_middlewares_defaults_to_empty(self) -> None:
        """Passing ``None`` as the middleware list still returns the handler's result."""

        async def handler(ctx: RequestContext) -> str:
            return "ok"

        outcome = await MiddlewareChain(None).execute(_make_ctx(), handler)

        assert outcome == "ok"
# ── 错误处理 ──────────────────────────────────────────────
class TestMiddlewareChainErrors:
    """Error handling in the middleware chain."""

    @pytest.mark.asyncio
    async def test_before_error_stops_chain(self) -> None:
        """A ``before`` failure propagates and nothing after it runs.

        Later ``before`` hooks, the handler and every ``after`` hook must be
        skipped, and the original exception must reach the caller.
        """
        calls: list[str] = []
        chain = MiddlewareChain(
            [
                _TraceMiddleware("A", calls),
                _ErrorBeforeMiddleware(ValueError("before error")),
                _TraceMiddleware("C", calls),
            ]
        )

        async def handler(ctx: RequestContext) -> str:
            calls.append("handler")
            return "result"

        with pytest.raises(ValueError, match="before error"):
            await chain.execute(_make_ctx(), handler)

        # Only A.before ran. Exact equality also proves that C.before, the
        # handler and every after hook were skipped, so no separate
        # "no after:" check is needed.
        assert calls == ["before:A"]

    @pytest.mark.asyncio
    async def test_after_error_does_not_stop_chain(self) -> None:
        """An ``after`` failure is swallowed; remaining afters still run.

        The failing middleware sits between two tracing middlewares. The call
        to ``execute`` must not raise, must return the handler's result, and
        the tracing middlewares must still see the full onion-ordered
        sequence.
        """
        calls: list[str] = []
        chain = MiddlewareChain(
            [
                _TraceMiddleware("A", calls),
                _ErrorAfterMiddleware(RuntimeError("after error")),
                _TraceMiddleware("C", calls),
            ]
        )

        async def handler(ctx: RequestContext) -> str:
            calls.append("handler")
            return "result"

        # Must NOT raise: the error from the middle middleware's after hook
        # is expected to be handled inside the chain.
        result = await chain.execute(_make_ctx(), handler)

        assert result == "result"
        # The failing middleware records nothing itself, so the trace holds
        # only A, C and the handler. Asserting the exact sequence checks the
        # ordering and that each hook ran exactly once, including A.after,
        # which runs after the failing hook.
        assert calls == [
            "before:A",
            "before:C",
            "handler",
            "after:C",
            "after:A",
        ]
# ── 状态传递 ──────────────────────────────────────────────
class TestMiddlewareChainState:
    """State flows between middlewares and the handler via ``RequestContext``."""

    @pytest.mark.asyncio
    async def test_before_modifies_ctx_visible_to_handler(self) -> None:
        """Metadata written in ``before`` is visible inside the handler."""
        seen: dict[str, str] = {}

        async def handler(ctx: RequestContext) -> str:
            seen.update(ctx.metadata)
            return "ok"

        chain = MiddlewareChain([_ModifyCtxMiddleware("key", "value")])
        await chain.execute(_make_ctx(), handler)

        assert seen["key"] == "value"

    @pytest.mark.asyncio
    async def test_after_modifies_result_visible_to_outer(self) -> None:
        """A result rewritten by an inner ``after`` is what the outer one receives."""
        outer = _ModifyResultMiddleware("_1")
        inner = _ModifyResultMiddleware("_2")
        chain = MiddlewareChain([outer, inner])

        async def handler(ctx: RequestContext) -> str:
            return "base"

        outcome = await chain.execute(_make_ctx(), handler)

        # after hooks run in reverse registration order:
        #   inner: "base"   + "_2" -> "base_2"
        #   outer: "base_2" + "_1" -> "base_2_1"
        assert outcome == "base_2_1"

    @pytest.mark.asyncio
    async def test_multiple_middlewares_share_ctx(self) -> None:
        """Metadata written by several middlewares all reaches the handler."""
        seen: dict[str, str] = {}

        async def handler(ctx: RequestContext) -> str:
            seen.update(ctx.metadata)
            return "ok"

        writers = [
            _ModifyCtxMiddleware("a", "1"),
            _ModifyCtxMiddleware("b", "2"),
        ]
        await MiddlewareChain(writers).execute(_make_ctx(), handler)

        assert seen == {"a": "1", "b": "2"}
# ── 向后兼容 ──────────────────────────────────────────────
class TestMiddlewareBackwardCompat:
    """``ReActEngine`` behaviour with and without a ``middleware_chain``."""

    @staticmethod
    def _stub_gateway() -> MagicMock:
        """Return a gateway mock whose async ``chat`` yields one canned reply.

        The reply has content ``"hello"``, an empty ``tool_calls`` list and a
        ``TokenUsage`` of 10 prompt / 5 completion tokens.
        """
        from agentkit.llm.protocol import TokenUsage

        reply = MagicMock(
            content="hello",
            tool_calls=[],
            usage=TokenUsage(prompt_tokens=10, completion_tokens=5),
        )
        gateway = MagicMock()
        gateway.chat = AsyncMock(return_value=reply)
        return gateway

    @pytest.mark.asyncio
    async def test_react_engine_without_middleware_works(self) -> None:
        """Without a chain the engine stores ``None`` and still returns a result."""
        from agentkit.core.react import ReActEngine

        engine = ReActEngine(llm_gateway=self._stub_gateway(), max_steps=1)

        # No middleware_chain was passed, so none should be stored.
        assert engine._middleware_chain is None

        # NOTE(review): timeout_seconds=0 presumably disables the timeout; confirm.
        outcome = await engine.execute(
            messages=[{"role": "user", "content": "hi"}],
            tools=[],
            timeout_seconds=0,
        )

        assert outcome is not None

    @pytest.mark.asyncio
    async def test_react_engine_with_middleware_uses_chain(self) -> None:
        """With a chain, the engine stores it and invokes its before/after hooks."""
        from agentkit.core.react import ReActEngine

        gateway = self._stub_gateway()
        trace: list[str] = []
        chain = MiddlewareChain([_TraceMiddleware("test", trace)])
        engine = ReActEngine(
            llm_gateway=gateway,
            max_steps=1,
            middleware_chain=chain,
        )

        assert engine._middleware_chain is chain

        outcome = await engine.execute(
            messages=[{"role": "user", "content": "hi"}],
            tools=[],
            timeout_seconds=0,
        )

        assert outcome is not None
        # Both hooks of the tracing middleware must have fired.
        assert "before:test" in trace
        assert "after:test" in trace
# ── 示例中间件测试 ─────────────────────────────────────────
class TestSummarizationMiddleware:
    """Tests for ``SummarizationMiddleware``."""

    @pytest.mark.asyncio
    async def test_no_compressor_passes_through(self) -> None:
        """With no compressor, ``before`` returns the same ctx and sets no flag."""
        ctx = _make_ctx()
        middleware = SummarizationMiddleware(compressor=None)

        returned = await middleware.before(ctx)

        assert returned is ctx
        assert "compressed" not in ctx.metadata

    @pytest.mark.asyncio
    async def test_compressor_not_needed_passes_through(self) -> None:
        """When ``should_compress`` is False, messages stay untouched."""
        compressor = MagicMock(should_compress=MagicMock(return_value=False))
        middleware = SummarizationMiddleware(compressor=compressor)
        ctx = _make_ctx()
        snapshot = [*ctx.messages]

        await middleware.before(ctx)

        assert ctx.messages == snapshot
        assert "compressed" not in ctx.metadata

    @pytest.mark.asyncio
    async def test_compressor_triggers_compression(self) -> None:
        """When ``should_compress`` is True, messages are replaced and flagged."""
        summary = [{"role": "user", "content": "compressed"}]
        compressor = MagicMock(
            should_compress=MagicMock(return_value=True),
            compress=AsyncMock(return_value=summary),
        )
        middleware = SummarizationMiddleware(compressor=compressor)
        ctx = _make_ctx()

        await middleware.before(ctx)

        assert ctx.messages == summary
        assert ctx.metadata["compressed"] is True

    @pytest.mark.asyncio
    async def test_compressor_failure_does_not_block(self) -> None:
        """A failing ``compress`` call must not raise and leaves messages as-is."""
        compressor = MagicMock(
            should_compress=MagicMock(return_value=True),
            compress=AsyncMock(side_effect=RuntimeError("compress failed")),
        )
        middleware = SummarizationMiddleware(compressor=compressor)
        ctx = _make_ctx()
        snapshot = [*ctx.messages]

        # Reaching the assertions at all proves before() did not raise.
        await middleware.before(ctx)

        assert ctx.messages == snapshot
        assert "compressed" not in ctx.metadata
class TestTokenUsageMiddleware:
    """Tests for ``TokenUsageMiddleware``."""

    @pytest.mark.asyncio
    async def test_before_returns_ctx_unchanged(self) -> None:
        """``before`` hands back the very same ctx object."""
        ctx = _make_ctx()

        returned = await TokenUsageMiddleware().before(ctx)

        assert returned is ctx

    @pytest.mark.asyncio
    async def test_after_extracts_usage_from_result(self) -> None:
        """``after`` copies ``result.total_tokens`` into ``token_usage_total``."""
        usage = {"total": 100}
        outcome = MagicMock()
        outcome.total_tokens = usage
        ctx = _make_ctx()

        await TokenUsageMiddleware().after(ctx, outcome)

        assert ctx.metadata["token_usage_total"] == {"total": 100}

    @pytest.mark.asyncio
    async def test_after_no_usage_in_result(self) -> None:
        """If the result has no ``total_tokens``, metadata is left alone."""
        ctx = _make_ctx()
        # spec=[] yields a mock exposing no attributes, hence no total_tokens.
        outcome = MagicMock(spec=[])

        await TokenUsageMiddleware().after(ctx, outcome)

        assert "token_usage_total" not in ctx.metadata
class TestLoopDetectionMiddleware:
    """Tests for ``LoopDetectionMiddleware``."""

    @pytest.mark.asyncio
    async def test_before_initializes_window(self) -> None:
        """``before`` sets ``loop_detection_window`` to an empty list."""
        ctx = _make_ctx()

        await LoopDetectionMiddleware().before(ctx)

        assert ctx.metadata["loop_detection_window"] == []

    @pytest.mark.asyncio
    async def test_after_no_trajectory_no_detection(self) -> None:
        """A result without ``trajectory`` leaves ``loop_detected`` unset."""
        ctx = _make_ctx()
        # spec=[] yields a mock exposing no attributes, hence no trajectory.
        outcome = MagicMock(spec=[])

        await LoopDetectionMiddleware().after(ctx, outcome)

        assert "loop_detected" not in ctx.metadata

    @pytest.mark.asyncio
    async def test_after_detects_repeated_tool_calls(self) -> None:
        """Three identical tool calls (window 3, threshold 2) set ``loop_detected``."""
        middleware = LoopDetectionMiddleware(window_size=3, threshold=2)
        ctx = _make_ctx()
        outcome = MagicMock()
        # Three distinct dict objects with identical content.
        outcome.trajectory = [
            {"tool_name": "search", "arguments_hash": "abc"} for _ in range(3)
        ]

        await middleware.after(ctx, outcome)

        assert ctx.metadata["loop_detected"] is True

    @pytest.mark.asyncio
    async def test_after_no_loop_with_different_calls(self) -> None:
        """Three different tool calls leave ``loop_detected`` unset."""
        middleware = LoopDetectionMiddleware(window_size=3, threshold=2)
        ctx = _make_ctx()
        distinct_calls = (
            ("search", "abc"),
            ("read_file", "def"),
            ("write_file", "ghi"),
        )
        outcome = MagicMock()
        outcome.trajectory = [
            {"tool_name": tool, "arguments_hash": digest}
            for tool, digest in distinct_calls
        ]

        await middleware.after(ctx, outcome)

        assert "loop_detected" not in ctx.metadata