"""AlignmentGuard 单元测试""" import asyncio from unittest.mock import AsyncMock, MagicMock import pytest from agentkit.quality.alignment import ( AlignmentCheckResult, AlignmentConfig, AlignmentGuard, CascadeAlert, ConstraintInjector, ) from agentkit.quality.cascade_detector import CascadeDetector from agentkit.skills.base import SkillConfig # ── AlignmentConfig 测试 ────────────────────────────────── class TestAlignmentConfig: """AlignmentConfig 默认值测试""" def test_default_values(self): config = AlignmentConfig() assert config.constraints == [] assert config.cascade_max_interactions == 10 assert config.cascade_max_depth == 3 assert config.audit_enabled is False assert config.audit_model == "default" def test_custom_values(self): config = AlignmentConfig( constraints=["no_harm", "be_honest"], cascade_max_interactions=5, cascade_max_depth=2, audit_enabled=True, audit_model="gpt-4", ) assert config.constraints == ["no_harm", "be_honest"] assert config.cascade_max_interactions == 5 assert config.cascade_max_depth == 2 assert config.audit_enabled is True assert config.audit_model == "gpt-4" # ── ConstraintInjector 测试 ─────────────────────────────── class TestConstraintInjector: """ConstraintInjector 约束注入测试""" def test_inject_constraints_into_input_data(self): config = AlignmentConfig(constraints=["no_harm", "be_honest"]) injector = ConstraintInjector(config) result = injector.inject({"task": "translate"}) assert "alignment_constraints" in result assert result["alignment_constraints"] == ["no_harm", "be_honest"] assert result["task"] == "translate" def test_does_not_modify_original_dict(self): config = AlignmentConfig(constraints=["no_harm"]) injector = ConstraintInjector(config) original = {"task": "translate"} result = injector.inject(original) assert "alignment_constraints" not in original assert "alignment_constraints" in result def test_empty_constraints(self): config = AlignmentConfig(constraints=[]) injector = ConstraintInjector(config) result = injector.inject({"task": "translate"}) assert result["alignment_constraints"] == [] # ── AlignmentGuard.check_output 测试 ────────────────────── class TestAlignmentGuardCheckOutput: """AlignmentGuard.check_output 对齐检查""" async def test_rule_check_violation_keyword_match(self): config = AlignmentConfig(constraints=["forbidden_word"]) guard = AlignmentGuard(config) output = {"content": "This contains forbidden_word in text"} result = await guard.check_output(output) assert result.passed is False assert "forbidden_word" in result.violations assert result.checked_by == "rule" async def test_rule_check_passes_no_violations(self): config = AlignmentConfig(constraints=["forbidden_word"]) guard = AlignmentGuard(config) output = {"content": "This is clean text"} result = await guard.check_output(output) assert result.passed is True assert result.violations == [] assert result.checked_by == "rule" async def test_no_constraints_passes(self): config = AlignmentConfig(constraints=[]) guard = AlignmentGuard(config) result = await guard.check_output({"content": "anything"}) assert result.passed is True assert result.checked_by == "rule" async def test_audit_disabled_does_not_call_llm(self): config = AlignmentConfig( constraints=["no_harm"], audit_enabled=False ) mock_gateway = AsyncMock() guard = AlignmentGuard(config, llm_gateway=mock_gateway) output = {"content": "This is safe"} result = await guard.check_output(output) assert result.checked_by == "rule" mock_gateway.chat.assert_not_called() async def test_audit_enabled_calls_llm_for_semantic_check(self): config = AlignmentConfig( constraints=["be_respectful"], audit_enabled=True, audit_model="gpt-4" ) mock_response = MagicMock() mock_response.content = "PASS" mock_gateway = AsyncMock() mock_gateway.chat.return_value = mock_response guard = AlignmentGuard(config, llm_gateway=mock_gateway) output = {"content": "This is respectful text"} # Rule check passes first (no keyword match), then LLM audit result = await guard.check_output(output) assert result.checked_by == "llm" mock_gateway.chat.assert_called_once() async def test_audit_enabled_llm_detects_violation(self): config = AlignmentConfig( constraints=["be_respectful"], audit_enabled=True ) mock_response = MagicMock() mock_response.content = "VIOLATION: Output is disrespectful" mock_gateway = AsyncMock() mock_gateway.chat.return_value = mock_response guard = AlignmentGuard(config, llm_gateway=mock_gateway) output = {"content": "This is borderline text"} result = await guard.check_output(output) assert result.passed is False assert result.checked_by == "llm" async def test_audit_enabled_no_llm_gateway_skips_llm(self): config = AlignmentConfig( constraints=["be_respectful"], audit_enabled=True ) guard = AlignmentGuard(config, llm_gateway=None) output = {"content": "This is safe"} result = await guard.check_output(output) assert result.checked_by == "rule" async def test_custom_constraints_override_config(self): config = AlignmentConfig(constraints=["default_constraint"]) guard = AlignmentGuard(config) output = {"content": "This has custom_violation in it"} result = await guard.check_output(output, constraints=["custom_violation"]) assert result.passed is False assert "custom_violation" in result.violations async def test_case_insensitive_matching(self): config = AlignmentConfig(constraints=["ForBiDdEn"]) guard = AlignmentGuard(config) output = {"content": "This has forbidden in it"} result = await guard.check_output(output) assert result.passed is False # ── AlignmentGuard 级联检测测试 ─────────────────────────── class TestAlignmentGuardCascade: """AlignmentGuard 级联故障检测""" def test_record_interaction_returns_alert_when_exceeded(self): config = AlignmentConfig(cascade_max_interactions=3) guard = AlignmentGuard(config) # 前 3 次不触发 assert guard.record_interaction("s1") is None assert guard.record_interaction("s1") is None assert guard.record_interaction("s1") is None # 第 4 次触发 alert = guard.record_interaction("s1") assert alert is not None assert alert.session_id == "s1" assert alert.alert_type == "interaction_limit" assert alert.current_value == 4 assert alert.threshold == 3 def test_record_interaction_below_threshold_returns_none(self): config = AlignmentConfig(cascade_max_interactions=10) guard = AlignmentGuard(config) assert guard.record_interaction("s1") is None def test_record_loop_depth_returns_alert_when_exceeded(self): config = AlignmentConfig(cascade_max_depth=2) guard = AlignmentGuard(config) assert guard.record_loop_depth("s1", 2) is None alert = guard.record_loop_depth("s1", 3) assert alert is not None assert alert.alert_type == "loop_depth" assert alert.current_value == 3 def test_reset_session_clears_counters(self): config = AlignmentConfig(cascade_max_interactions=5) guard = AlignmentGuard(config) guard.record_interaction("s1") guard.record_interaction("s1") assert guard.get_interaction_count("s1") == 2 guard.reset_session("s1") assert guard.get_interaction_count("s1") == 0 def test_get_interaction_count_default_zero(self): config = AlignmentConfig() guard = AlignmentGuard(config) assert guard.get_interaction_count("unknown") == 0 def test_inject_constraints_delegates_to_injector(self): config = AlignmentConfig(constraints=["no_harm"]) guard = AlignmentGuard(config) result = guard.inject_constraints({"task": "test"}) assert result["alignment_constraints"] == ["no_harm"] # ── CascadeDetector 测试 ────────────────────────────────── class TestCascadeDetector: """CascadeDetector 独立级联检测测试""" def test_interaction_exceeds_threshold_triggers_alert(self): detector = CascadeDetector(max_interactions=3) assert detector.check_interaction("s1") is None assert detector.check_interaction("s1") is None assert detector.check_interaction("s1") is None alert = detector.check_interaction("s1") assert alert is not None assert alert.alert_type == "interaction_limit" assert alert.current_value == 4 assert alert.threshold == 3 def test_interaction_below_threshold_returns_none(self): detector = CascadeDetector(max_interactions=10) assert detector.check_interaction("s1") is None def test_loop_depth_exceeds_threshold_triggers_alert(self): detector = CascadeDetector(max_depth=3) assert detector.check_depth("s1", 3) is None alert = detector.check_depth("s1", 4) assert alert is not None assert alert.alert_type == "loop_depth" assert alert.current_value == 4 def test_reset_clears_counters(self): detector = CascadeDetector(max_interactions=2) detector.check_interaction("s1") detector.check_interaction("s1") detector.reset("s1") stats = detector.get_stats("s1") assert stats["interactions"] == 0 assert stats["depth"] == 0 def test_get_stats_returns_current_values(self): detector = CascadeDetector() detector.check_interaction("s1") detector.check_interaction("s1") detector.check_depth("s1", 5) stats = detector.get_stats("s1") assert stats["interactions"] == 2 assert stats["depth"] == 5 def test_get_stats_unknown_session(self): detector = CascadeDetector() stats = detector.get_stats("unknown") assert stats["interactions"] == 0 assert stats["depth"] == 0 # ── SkillConfig alignment 字段测试 ──────────────────────── class TestSkillConfigAlignment: """SkillConfig alignment 字段测试""" def test_default_alignment(self): config = SkillConfig(name="test", agent_type="test", prompt={"identity": "test"}) assert config.alignment.constraints == [] assert config.alignment.cascade_max_interactions == 10 assert config.alignment.cascade_max_depth == 3 assert config.alignment.audit_enabled is False assert config.alignment.audit_model == "default" def test_alignment_from_dict(self): config = SkillConfig.from_dict({ "name": "test", "agent_type": "test", "prompt": {"identity": "test"}, "alignment": { "constraints": ["no_harm"], "cascade_max_interactions": 5, "cascade_max_depth": 2, "audit_enabled": True, "audit_model": "gpt-4", }, }) assert config.alignment.constraints == ["no_harm"] assert config.alignment.cascade_max_interactions == 5 assert config.alignment.cascade_max_depth == 2 assert config.alignment.audit_enabled is True assert config.alignment.audit_model == "gpt-4" def test_alignment_to_dict(self): config = SkillConfig( name="test", agent_type="test", prompt={"identity": "test"}, alignment={"constraints": ["no_harm"], "audit_enabled": True}, ) d = config.to_dict() assert "alignment" in d assert d["alignment"]["constraints"] == ["no_harm"] assert d["alignment"]["audit_enabled"] is True def test_backward_compatibility_no_alignment(self): config = SkillConfig.from_dict({ "name": "test", "agent_type": "test", "prompt": {"identity": "test"}, }) assert config.alignment.constraints == []