"""U6: 中间件管道架构测试 测试洋葱模型中间件链的执行顺序、错误处理、向后兼容性。 """ from __future__ import annotations from unittest.mock import AsyncMock, MagicMock import pytest from agentkit.core.middleware import ( LoopDetectionMiddleware, MiddlewareChain, RequestContext, SummarizationMiddleware, TokenUsageMiddleware, ) # ── 辅助 ────────────────────────────────────────────────── class _TraceMiddleware: """记录 before/after 调用顺序的测试中间件。""" def __init__(self, name: str, calls: list[str]): self._name = name self._calls = calls async def before(self, ctx: RequestContext) -> RequestContext: self._calls.append(f"before:{self._name}") return ctx async def after(self, ctx: RequestContext, result): self._calls.append(f"after:{self._name}") return result class _ErrorBeforeMiddleware: """before 抛异常的测试中间件。""" def __init__(self, error: Exception): self._error = error async def before(self, ctx: RequestContext) -> RequestContext: raise self._error async def after(self, ctx: RequestContext, result): return result class _ErrorAfterMiddleware: """after 抛异常的测试中间件。""" def __init__(self, error: Exception): self._error = error async def before(self, ctx: RequestContext) -> RequestContext: return ctx async def after(self, ctx: RequestContext, result): raise self._error class _ModifyCtxMiddleware: """修改 ctx 的测试中间件。""" def __init__(self, key: str, value: str): self._key = key self._value = value async def before(self, ctx: RequestContext) -> RequestContext: ctx.metadata[self._key] = self._value return ctx async def after(self, ctx: RequestContext, result): return result class _ModifyResultMiddleware: """修改 result 的测试中间件。""" def __init__(self, suffix: str): self._suffix = suffix async def before(self, ctx: RequestContext) -> RequestContext: return ctx async def after(self, ctx: RequestContext, result): if isinstance(result, str): return result + self._suffix return result def _make_ctx() -> RequestContext: return RequestContext(messages=[{"role": "user", "content": "test"}]) # ── 洋葱模型执行顺序 
# ── Onion-model execution order ──────────────────────────


class TestMiddlewareChainOrder:
    """Onion-model before/after execution order."""

    @pytest.mark.asyncio
    async def test_before_outer_to_inner_after_inner_to_outer(self) -> None:
        """Three middlewares: before runs A→B→C, after runs C→B→A."""
        calls: list[str] = []
        chain = MiddlewareChain(
            [
                _TraceMiddleware("A", calls),
                _TraceMiddleware("B", calls),
                _TraceMiddleware("C", calls),
            ]
        )

        async def handler(ctx: RequestContext) -> str:
            calls.append("handler")
            return "result"

        result = await chain.execute(_make_ctx(), handler)

        assert result == "result"
        assert calls == [
            "before:A",
            "before:B",
            "before:C",
            "handler",
            "after:C",
            "after:B",
            "after:A",
        ]

    @pytest.mark.asyncio
    async def test_single_middleware(self) -> None:
        """Single middleware: before → handler → after."""
        calls: list[str] = []
        chain = MiddlewareChain([_TraceMiddleware("X", calls)])

        async def handler(ctx: RequestContext) -> str:
            calls.append("handler")
            return "ok"

        result = await chain.execute(_make_ctx(), handler)

        assert result == "ok"
        assert calls == ["before:X", "handler", "after:X"]

    @pytest.mark.asyncio
    async def test_empty_chain_calls_handler_directly(self) -> None:
        """An empty middleware chain calls the handler directly."""
        chain = MiddlewareChain([])

        async def handler(ctx: RequestContext) -> str:
            return "direct"

        result = await chain.execute(_make_ctx(), handler)
        assert result == "direct"

    @pytest.mark.asyncio
    async def test_none_middlewares_defaults_to_empty(self) -> None:
        """Passing None as the middlewares argument yields an empty chain."""
        chain = MiddlewareChain(None)

        async def handler(ctx: RequestContext) -> str:
            return "ok"

        result = await chain.execute(_make_ctx(), handler)
        assert result == "ok"


# ── Error handling ────────────────────────────────────────


class TestMiddlewareChainErrors:
    """Error handling in the middleware chain."""

    @pytest.mark.asyncio
    async def test_before_error_stops_chain(self) -> None:
        """before() raising stops the chain.

        No later before() runs, no after() fires, and the exception
        propagates to the caller.
        """
        calls: list[str] = []
        chain = MiddlewareChain(
            [
                _TraceMiddleware("A", calls),
                _ErrorBeforeMiddleware(ValueError("before error")),
                _TraceMiddleware("C", calls),
            ]
        )

        async def handler(ctx: RequestContext) -> str:
            calls.append("handler")
            return "result"

        with pytest.raises(ValueError, match="before error"):
            await chain.execute(_make_ctx(), handler)

        # A.before ran and B.before raised. C.before, the handler and every
        # after() hook were skipped. Exact equality already rules out any
        # "after:*" entry.
        assert calls == ["before:A"]

    @pytest.mark.asyncio
    async def test_after_error_does_not_stop_chain(self) -> None:
        """after() raising does not interrupt the remaining after() hooks.

        The chain is expected to log a warning (not asserted here). The
        result from the inner layers is still returned.
        """
        calls: list[str] = []
        chain = MiddlewareChain(
            [
                _TraceMiddleware("A", calls),
                _ErrorAfterMiddleware(RuntimeError("after error")),
                _TraceMiddleware("C", calls),
            ]
        )

        async def handler(ctx: RequestContext) -> str:
            calls.append("handler")
            return "result"

        # Must NOT raise: after() errors are caught by the chain.
        result = await chain.execute(_make_ctx(), handler)

        assert result == "result"
        # The failing middleware sits between A and C and records nothing
        # itself. Its after() raises between "after:C" and "after:A", yet
        # the chain keeps unwinding in onion order.
        assert calls == [
            "before:A",
            "before:C",
            "handler",
            "after:C",
            "after:A",
        ]


# ── State propagation ─────────────────────────────────────


class TestMiddlewareChainState:
    """State is passed between middlewares through RequestContext."""

    @pytest.mark.asyncio
    async def test_before_modifies_ctx_visible_to_handler(self) -> None:
        """ctx.metadata written in before() is visible to the handler."""
        chain = MiddlewareChain([_ModifyCtxMiddleware("key", "value")])
        captured: dict[str, str] = {}

        async def handler(ctx: RequestContext) -> str:
            captured.update(ctx.metadata)
            return "ok"

        await chain.execute(_make_ctx(), handler)
        assert captured["key"] == "value"

    @pytest.mark.asyncio
    async def test_after_modifies_result_visible_to_outer(self) -> None:
        """A result rewritten in after() is visible to the outer layers."""
        chain = MiddlewareChain(
            [
                _ModifyResultMiddleware("_1"),
                _ModifyResultMiddleware("_2"),
            ]
        )

        async def handler(ctx: RequestContext) -> str:
            return "base"

        result = await chain.execute(_make_ctx(), handler)

        # after() hooks unwind inner-to-outer. "_2" (inner) applies first,
        # then "_1" (outer): "base" -> "base_2" -> "base_2_1".
        assert result == "base_2_1"

    @pytest.mark.asyncio
    async def test_multiple_middlewares_share_ctx(self) -> None:
        """Metadata written by several middlewares lands in one ctx object."""
        chain = MiddlewareChain(
            [
                _ModifyCtxMiddleware("a", "1"),
                _ModifyCtxMiddleware("b", "2"),
            ]
        )
        captured: dict[str, str] = {}

        async def handler(ctx: RequestContext) -> str:
            captured.update(ctx.metadata)
            return "ok"

        await chain.execute(_make_ctx(), handler)
        assert captured == {"a": "1", "b": "2"}


# ── Backward compatibility ────────────────────────────────


class TestMiddlewareBackwardCompat:
    """Behavior is unchanged when middleware_chain=None (backward compat)."""

    @staticmethod
    def _make_gateway() -> MagicMock:
        """Build a mock LLM gateway.

        Its async chat() returns a response with content "hello", no tool
        calls, and a fixed TokenUsage.
        """
        from agentkit.llm.protocol import TokenUsage

        gateway = MagicMock()
        gateway.chat = AsyncMock(
            return_value=MagicMock(
                content="hello",
                tool_calls=[],
                usage=TokenUsage(prompt_tokens=10, completion_tokens=5),
            )
        )
        return gateway

    @pytest.mark.asyncio
    async def test_react_engine_without_middleware_works(self) -> None:
        """ReActEngine without a middleware_chain takes the existing path."""
        from agentkit.core.react import ReActEngine

        engine = ReActEngine(llm_gateway=self._make_gateway(), max_steps=1)

        # No middleware_chain was passed, so the attribute stays None.
        assert engine._middleware_chain is None

        result = await engine.execute(
            messages=[{"role": "user", "content": "hi"}],
            tools=[],
            timeout_seconds=0,
        )
        assert result is not None

    @pytest.mark.asyncio
    async def test_react_engine_with_middleware_uses_chain(self) -> None:
        """ReActEngine with a middleware_chain routes through the middlewares."""
        from agentkit.core.react import ReActEngine

        calls: list[str] = []
        chain = MiddlewareChain([_TraceMiddleware("test", calls)])

        engine = ReActEngine(
            llm_gateway=self._make_gateway(),
            max_steps=1,
            middleware_chain=chain,
        )
        assert engine._middleware_chain is chain

        result = await engine.execute(
            messages=[{"role": "user", "content": "hi"}],
            tools=[],
            timeout_seconds=0,
        )
        assert result is not None
        # Both middleware hooks must have been invoked by the engine.
        assert "before:test" in calls
        assert "after:test" in calls


# ── Example middleware tests ──────────────────────────────


class TestSummarizationMiddleware:
    """SummarizationMiddleware tests."""

    @pytest.mark.asyncio
    async def test_no_compressor_passes_through(self) -> None:
        """Without a compressor, before() leaves ctx untouched."""
        mw = SummarizationMiddleware(compressor=None)
        ctx = _make_ctx()

        result_ctx = await mw.before(ctx)

        assert result_ctx is ctx
        assert "compressed" not in ctx.metadata

    @pytest.mark.asyncio
    async def test_compressor_not_needed_passes_through(self) -> None:
        """If the compressor says no compression is needed, messages stay as-is."""
        compressor = MagicMock()
        compressor.should_compress = MagicMock(return_value=False)
        mw = SummarizationMiddleware(compressor=compressor)
        ctx = _make_ctx()
        original_messages = list(ctx.messages)

        await mw.before(ctx)

        assert ctx.messages == original_messages
        assert "compressed" not in ctx.metadata

    @pytest.mark.asyncio
    async def test_compressor_triggers_compression(self) -> None:
        """If the compressor says compression is needed, before() compresses."""
        compressed_messages = [{"role": "user", "content": "compressed"}]
        compressor = MagicMock()
        compressor.should_compress = MagicMock(return_value=True)
        compressor.compress = AsyncMock(return_value=compressed_messages)
        mw = SummarizationMiddleware(compressor=compressor)
        ctx = _make_ctx()

        await mw.before(ctx)

        assert ctx.messages == compressed_messages
        assert ctx.metadata["compressed"] is True

    @pytest.mark.asyncio
    async def test_compressor_failure_does_not_block(self) -> None:
        """A failing compressor does not block execution; messages are unchanged."""
        compressor = MagicMock()
        compressor.should_compress = MagicMock(return_value=True)
        compressor.compress = AsyncMock(side_effect=RuntimeError("compress failed"))
        mw = SummarizationMiddleware(compressor=compressor)
        ctx = _make_ctx()
        original_messages = list(ctx.messages)

        # Must not raise.
        await mw.before(ctx)

        assert ctx.messages == original_messages
        assert "compressed" not in ctx.metadata


class TestTokenUsageMiddleware:
    """TokenUsageMiddleware tests."""

    @pytest.mark.asyncio
    async def test_before_returns_ctx_unchanged(self) -> None:
        """before() returns ctx unmodified (token accounting is done in after())."""
        mw = TokenUsageMiddleware()
        ctx = _make_ctx()

        result = await mw.before(ctx)
        assert result is ctx

    @pytest.mark.asyncio
    async def test_after_extracts_usage_from_result(self) -> None:
        """after() copies result.token_usage into ctx.metadata."""
        mw = TokenUsageMiddleware()
        ctx = _make_ctx()
        result = MagicMock()
        result.token_usage = {"total": 100}

        await mw.after(ctx, result)
        assert ctx.metadata["token_usage_total"] == {"total": 100}

    @pytest.mark.asyncio
    async def test_after_no_usage_in_result(self) -> None:
        """When the result has no token_usage, after() leaves metadata alone."""
        mw = TokenUsageMiddleware()
        ctx = _make_ctx()
        # spec=[] makes attribute access raise, so token_usage is absent.
        result = MagicMock(spec=[])

        await mw.after(ctx, result)
        assert "token_usage_total" not in ctx.metadata


class TestLoopDetectionMiddleware:
    """LoopDetectionMiddleware tests."""

    @pytest.mark.asyncio
    async def test_before_initializes_window(self) -> None:
        """before() initializes loop_detection_window to an empty list."""
        mw = LoopDetectionMiddleware()
        ctx = _make_ctx()

        await mw.before(ctx)
        assert ctx.metadata["loop_detection_window"] == []

    @pytest.mark.asyncio
    async def test_after_no_trajectory_no_detection(self) -> None:
        """Without a trajectory on the result, after() performs no detection."""
        mw = LoopDetectionMiddleware()
        ctx = _make_ctx()
        # spec=[] makes attribute access raise, so trajectory is absent.
        result = MagicMock(spec=[])

        await mw.after(ctx, result)
        assert "loop_detected" not in ctx.metadata

    @pytest.mark.asyncio
    async def test_after_detects_repeated_tool_calls(self) -> None:
        """Repeated identical tool calls in the trajectory set loop_detected."""
        mw = LoopDetectionMiddleware(window_size=3, threshold=2)
        ctx = _make_ctx()
        result = MagicMock()
        result.trajectory = [
            {"tool_name": "search", "arguments_hash": "abc"},
            {"tool_name": "search", "arguments_hash": "abc"},
            {"tool_name": "search", "arguments_hash": "abc"},
        ]

        await mw.after(ctx, result)
        assert ctx.metadata["loop_detected"] is True

    @pytest.mark.asyncio
    async def test_after_no_loop_with_different_calls(self) -> None:
        """Distinct tool calls in the trajectory do not set loop_detected."""
        mw = LoopDetectionMiddleware(window_size=3, threshold=2)
        ctx = _make_ctx()
        result = MagicMock()
        result.trajectory = [
            {"tool_name": "search", "arguments_hash": "abc"},
            {"tool_name": "read_file", "arguments_hash": "def"},
            {"tool_name": "write_file", "arguments_hash": "ghi"},
        ]

        await mw.after(ctx, result)
        assert "loop_detected" not in ctx.metadata