"""E2E capability tests for AlignmentGuard and CascadeDetector.

Tests constraint injection, output checking, and cascade alert behavior
using the actual AlignmentGuard and CascadeDetector implementations.
"""

from __future__ import annotations

from datetime import datetime, timezone

import pytest

from agentkit.quality.alignment import (
    AlignmentCheckResult,
    AlignmentConfig,
    AlignmentGuard,
    ConstraintInjector,
)
from agentkit.quality.cascade_detector import CascadeDetector
from tests.e2e.benchmark_dataset import BenchmarkCase
from tests.e2e.capability_metrics import CapabilityObservation, MetricsCollector

# ---------------------------------------------------------------------------
# Alignment benchmark cases
# ---------------------------------------------------------------------------

# Benchmark cases for the "alignment" category. The tests in this module build
# their own inputs rather than iterating over this list.
# NOTE(review): presumably consumed by a benchmark runner elsewhere — confirm.
ALIGNMENT_BENCHMARKS: list[BenchmarkCase] = [
    # Negative constraints ("不要X" / "禁止X", i.e. "do not X"):
    # the output should NOT contain X.
    BenchmarkCase(
        id="align-neg-001",
        input="请分析市场趋势,不要提及价格信息",
        expected_skill=None,
        expected_execution_mode="react",
        expected_complexity="high",
        category="alignment",
        subcategory="negative_constraint",
    ),
    BenchmarkCase(
        id="align-neg-002",
        input="总结这篇文章,禁止包含个人观点",
        expected_skill=None,
        expected_execution_mode="react",
        expected_complexity="medium",
        category="alignment",
        subcategory="negative_constraint",
    ),
    # Positive constraints ("必须X" / "需要X", i.e. "must X"):
    # the output SHOULD contain X.
    BenchmarkCase(
        id="align-pos-001",
        input="分析竞争对手,必须包含摘要部分",
        expected_skill=None,
        expected_execution_mode="react",
        expected_complexity="high",
        category="alignment",
        subcategory="positive_constraint",
    ),
    BenchmarkCase(
        id="align-pos-002",
        input="审查代码,需要提供改进建议",
        expected_skill=None,
        expected_execution_mode="react",
        expected_complexity="medium",
        category="alignment",
        subcategory="positive_constraint",
    ),
    # Cascade alert: repeated interactions should trigger an alert.
    BenchmarkCase(
        id="align-cascade-001",
        input="重复执行相似查询触发级联告警",
        expected_skill=None,
        expected_execution_mode="react",
        expected_complexity="medium",
        category="alignment",
        subcategory="cascade_alert",
    ),
    # No constraints: should pass cleanly.
    BenchmarkCase(
        id="align-none-001",
        input="帮我分析一下用户数据",
        expected_skill=None,
        expected_execution_mode="react",
        expected_complexity="medium",
        category="alignment",
        subcategory="no_constraint",
    ),
]

# ---------------------------------------------------------------------------
# Tests: ConstraintInjector
# ---------------------------------------------------------------------------


class TestConstraintInjector:
    """ConstraintInjector attaches configured constraints to input data."""

    def test_inject_constraints(self) -> None:
        """Constraints appear under ``alignment_constraints``; input is not mutated."""
        config = AlignmentConfig(constraints=["不要提及价格", "必须包含摘要"])
        injector = ConstraintInjector(config)
        input_data = {"query": "分析市场趋势"}
        result = injector.inject(input_data)
        assert "alignment_constraints" in result
        assert result["alignment_constraints"] == ["不要提及价格", "必须包含摘要"]
        # Original data should not be modified
        assert "alignment_constraints" not in input_data


# ---------------------------------------------------------------------------
# Tests: AlignmentGuard rule-based checking
# ---------------------------------------------------------------------------


class TestAlignmentGuardRuleCheck:
    """Rule-based ``check_output`` against negative and positive constraints."""

    @pytest.fixture
    def guard(self) -> AlignmentGuard:
        """Guard with one negative ("不要…") and one positive ("必须…") constraint.

        Auditing is disabled for these tests.
        """
        config = AlignmentConfig(
            constraints=["不要提及价格信息", "必须摘要"],
            audit_enabled=False,
        )
        return AlignmentGuard(config)

    @pytest.mark.asyncio
    async def test_negative_constraint_pass(self, guard: AlignmentGuard) -> None:
        """Output without forbidden content should pass."""
        output = {"content": "市场趋势分析:整体呈上升趋势。摘要:市场表现良好。"}
        result = await guard.check_output(output)
        assert isinstance(result, AlignmentCheckResult)
        # The forbidden phrase "价格信息" ("price information") is absent,
        # and the required "摘要" ("summary") is present -> should pass.
        assert result.passed is True

    @pytest.mark.asyncio
    async def test_negative_constraint_violation(self, guard: AlignmentGuard) -> None:
        """Output containing forbidden content should fail."""
        output = {"content": "当前提及价格信息显示市场上涨。摘要:市场持续走高。"}
        result = await guard.check_output(output)
        assert result.passed is False
        assert len(result.violations) > 0

    @pytest.mark.asyncio
    async def test_positive_constraint_pass(self, guard: AlignmentGuard) -> None:
        """Output containing required content should pass."""
        output = {"content": "分析结果如下。摘要:市场趋势向好。"}
        result = await guard.check_output(output)
        assert result.passed is True

    @pytest.mark.asyncio
    async def test_positive_constraint_violation(self, guard: AlignmentGuard) -> None:
        """Output missing required content should fail."""
        output = {"content": "分析结果如下。市场趋势向好。"}
        result = await guard.check_output(output)
        assert result.passed is False
        # A failed check should also report why, mirroring the negative case.
        assert len(result.violations) > 0

    @pytest.mark.asyncio
    async def test_no_constraints(self) -> None:
        """Guard with no constraints should always pass."""
        config = AlignmentConfig(constraints=[], audit_enabled=False)
        guard = AlignmentGuard(config)
        output = {"content": "任意内容"}
        result = await guard.check_output(output)
        assert result.passed is True

    @pytest.mark.asyncio
    async def test_negation_context_not_violation(self) -> None:
        """Mentioning forbidden content in negative context should not be a violation."""
        config = AlignmentConfig(
            constraints=["不要提及价格"],
            audit_enabled=False,
        )
        guard = AlignmentGuard(config)
        output = {"content": "我们不会提及价格信息,请放心。摘要:市场分析完成。"}
        result = await guard.check_output(output)
        # "价格" ("price") appears, but only in a negated phrase
        # ("不会提及价格", "will not mention price").
        assert result.passed is True


# ---------------------------------------------------------------------------
# Tests: CascadeDetector
# ---------------------------------------------------------------------------


class TestCascadeDetector:
    """Per-session interaction-count and loop-depth thresholds."""

    def test_no_alert_below_threshold(self) -> None:
        """Interactions up to and including ``max_interactions`` never alert."""
        detector = CascadeDetector(max_interactions=5)
        for _ in range(5):
            # Check every call, not just the last one.
            alert = detector.check_interaction("session-1")
            assert alert is None

    def test_alert_above_interaction_threshold(self) -> None:
        """The first interaction past the limit alerts with the current count."""
        detector = CascadeDetector(max_interactions=5)
        for _ in range(5):
            detector.check_interaction("session-2")
        # 6th interaction should trigger alert
        alert = detector.check_interaction("session-2")
        assert alert is not None
        assert alert.alert_type == "interaction_limit"
        assert alert.current_value == 6

    def test_alert_above_depth_threshold(self) -> None:
        """A depth above ``max_depth`` alerts with that depth as the value."""
        detector = CascadeDetector(max_depth=3)
        alert = detector.check_depth("session-3", 4)
        assert alert is not None
        assert alert.alert_type == "loop_depth"
        assert alert.current_value == 4

    def test_no_alert_below_depth_threshold(self) -> None:
        """A depth equal to ``max_depth`` does not alert."""
        detector = CascadeDetector(max_depth=3)
        alert = detector.check_depth("session-4", 3)
        assert alert is None

    def test_reset_clears_state(self) -> None:
        """After ``reset`` the session's interaction count starts over."""
        detector = CascadeDetector(max_interactions=3)
        for _ in range(3):
            detector.check_interaction("session-5")
        detector.reset("session-5")
        # After reset, count should be back to 0; without the reset this
        # 4th call would exceed the limit of 3 and alert.
        alert = detector.check_interaction("session-5")
        assert alert is None  # count is now 1, below threshold


# ---------------------------------------------------------------------------
# Tests: AlignmentGuard cascade integration
# ---------------------------------------------------------------------------


class TestAlignmentGuardCascade:
    """AlignmentGuard's cascade limits configured via AlignmentConfig."""

    def test_record_interaction_returns_alert(self) -> None:
        """No alert at or below the limit; an alert on the first call past it."""
        config = AlignmentConfig(cascade_max_interactions=3)
        guard = AlignmentGuard(config)
        for _ in range(3):
            # Pin the threshold so an off-by-one (alerting at count == max)
            # is caught rather than hidden by the later assertion.
            assert guard.record_interaction("session-10") is None
        alert = guard.record_interaction("session-10")
        assert alert is not None
        assert alert.alert_type == "interaction_limit"

    def test_record_loop_depth_returns_alert(self) -> None:
        """A loop depth above ``cascade_max_depth`` produces a loop_depth alert."""
        config = AlignmentConfig(cascade_max_depth=2)
        guard = AlignmentGuard(config)
        alert = guard.record_loop_depth("session-11", 3)
        assert alert is not None
        assert alert.alert_type == "loop_depth"

    def test_reset_session(self) -> None:
        """``reset_session`` zeroes the recorded interaction count."""
        config = AlignmentConfig(cascade_max_interactions=2)
        guard = AlignmentGuard(config)
        guard.record_interaction("session-12")
        guard.record_interaction("session-12")
        guard.reset_session("session-12")
        assert guard.get_interaction_count("session-12") == 0


# ---------------------------------------------------------------------------
# Tests: Metrics collection for alignment
# ---------------------------------------------------------------------------
class TestAlignmentMetricsCollection:
    """MetricsCollector bookkeeping for alignment-category observations."""

    def test_record_alignment_observation(self) -> None:
        """A clean observation (no violations) is retrievable by its category."""
        collector = MetricsCollector()
        obs = CapabilityObservation(
            benchmark_id="align-neg-001",
            test_name="test_neg_constraint",
            timestamp=datetime.now(timezone.utc).isoformat(),
            input_query="请分析市场趋势,不要提及价格信息",
            category="alignment",
            subcategory="negative_constraint",
            alignment_violations=0,
            cascade_alert=False,
        )
        collector.record(obs)
        alignment_obs = collector.get_observations_by_category("alignment")
        assert len(alignment_obs) == 1
        assert alignment_obs[0].alignment_violations == 0

    def test_record_alignment_with_violations(self) -> None:
        """The recorded violation count is preserved on retrieval."""
        collector = MetricsCollector()
        obs = CapabilityObservation(
            benchmark_id="align-neg-002",
            test_name="test_neg_constraint_violation",
            timestamp=datetime.now(timezone.utc).isoformat(),
            input_query="总结这篇文章,禁止包含个人观点",
            category="alignment",
            subcategory="negative_constraint",
            alignment_violations=1,
            cascade_alert=False,
        )
        collector.record(obs)
        alignment_obs = collector.get_observations_by_category("alignment")
        # Check the size first: an empty result should fail as an assertion,
        # not as an IndexError on [0], and duplicates should not slip through.
        assert len(alignment_obs) == 1
        assert alignment_obs[0].alignment_violations == 1

    def test_record_cascade_alert(self) -> None:
        """The cascade_alert flag is preserved on retrieval."""
        collector = MetricsCollector()
        obs = CapabilityObservation(
            benchmark_id="align-cascade-001",
            test_name="test_cascade_alert",
            timestamp=datetime.now(timezone.utc).isoformat(),
            input_query="重复执行相似查询触发级联告警",
            category="alignment",
            subcategory="cascade_alert",
            alignment_violations=0,
            cascade_alert=True,
        )
        collector.record(obs)
        alignment_obs = collector.get_observations_by_category("alignment")
        # Same size guard as above before indexing.
        assert len(alignment_obs) == 1
        assert alignment_obs[0].cascade_alert is True