fischer-agentkit/tests/unit/test_alignment_guard.py

335 lines
13 KiB
Python

"""AlignmentGuard 单元测试"""
import asyncio
from unittest.mock import AsyncMock, MagicMock
import pytest
from agentkit.quality.alignment import (
AlignmentCheckResult,
AlignmentConfig,
AlignmentGuard,
CascadeAlert,
ConstraintInjector,
)
from agentkit.quality.cascade_detector import CascadeDetector
from agentkit.skills.base import SkillConfig
# ── AlignmentConfig 测试 ──────────────────────────────────
class TestAlignmentConfig:
    """Verify AlignmentConfig defaults and explicit keyword overrides."""

    def test_default_values(self):
        # A bare AlignmentConfig must carry the expected default settings.
        cfg = AlignmentConfig()
        assert cfg.audit_model == "default"
        assert cfg.audit_enabled is False
        assert cfg.cascade_max_depth == 3
        assert cfg.cascade_max_interactions == 10
        assert cfg.constraints == []

    def test_custom_values(self):
        # Every keyword handed to the constructor must be stored verbatim.
        overrides = {
            "constraints": ["no_harm", "be_honest"],
            "cascade_max_interactions": 5,
            "cascade_max_depth": 2,
            "audit_enabled": True,
            "audit_model": "gpt-4",
        }
        cfg = AlignmentConfig(**overrides)
        assert cfg.audit_model == "gpt-4"
        assert cfg.audit_enabled is True
        assert cfg.cascade_max_depth == 2
        assert cfg.cascade_max_interactions == 5
        assert cfg.constraints == ["no_harm", "be_honest"]
# ── ConstraintInjector 测试 ───────────────────────────────
class TestConstraintInjector:
    """Tests for ConstraintInjector: constraints are added to a copy of the input."""

    def test_inject_constraints_into_input_data(self):
        config = AlignmentConfig(constraints=["no_harm", "be_honest"])
        injector = ConstraintInjector(config)
        result = injector.inject({"task": "translate"})
        assert "alignment_constraints" in result
        assert result["alignment_constraints"] == ["no_harm", "be_honest"]
        # Keys already present in the input must survive the injection.
        assert result["task"] == "translate"

    def test_does_not_modify_original_dict(self):
        config = AlignmentConfig(constraints=["no_harm"])
        injector = ConstraintInjector(config)
        original = {"task": "translate"}
        result = injector.inject(original)
        # The caller's dict must be left exactly as it was (not merely free of
        # the injected key), and the result must be a separate object.
        assert original == {"task": "translate"}
        assert result is not original
        assert "alignment_constraints" in result

    def test_empty_constraints(self):
        # An empty constraint list is still injected, as an empty list.
        config = AlignmentConfig(constraints=[])
        injector = ConstraintInjector(config)
        result = injector.inject({"task": "translate"})
        assert result["alignment_constraints"] == []
# ── AlignmentGuard.check_output 测试 ──────────────────────
class TestAlignmentGuardCheckOutput:
    """Tests for AlignmentGuard.check_output: keyword rule check plus optional LLM audit."""

    async def test_rule_check_violation_keyword_match(self):
        config = AlignmentConfig(constraints=["forbidden_word"])
        guard = AlignmentGuard(config)
        output = {"content": "This contains forbidden_word in text"}
        result = await guard.check_output(output)
        assert result.passed is False
        assert "forbidden_word" in result.violations
        assert result.checked_by == "rule"

    async def test_rule_check_passes_no_violations(self):
        config = AlignmentConfig(constraints=["forbidden_word"])
        guard = AlignmentGuard(config)
        output = {"content": "This is clean text"}
        result = await guard.check_output(output)
        assert result.passed is True
        assert result.violations == []
        assert result.checked_by == "rule"

    async def test_no_constraints_passes(self):
        config = AlignmentConfig(constraints=[])
        guard = AlignmentGuard(config)
        result = await guard.check_output({"content": "anything"})
        assert result.passed is True
        # With nothing to check there can be nothing to report.
        assert result.violations == []
        assert result.checked_by == "rule"

    async def test_audit_disabled_does_not_call_llm(self):
        config = AlignmentConfig(
            constraints=["no_harm"], audit_enabled=False
        )
        mock_gateway = AsyncMock()
        guard = AlignmentGuard(config, llm_gateway=mock_gateway)
        output = {"content": "This is safe"}
        result = await guard.check_output(output)
        # A gateway is available but auditing is off: only the rule check runs.
        assert result.passed is True
        assert result.checked_by == "rule"
        mock_gateway.chat.assert_not_called()

    async def test_audit_enabled_calls_llm_for_semantic_check(self):
        config = AlignmentConfig(
            constraints=["be_respectful"], audit_enabled=True, audit_model="gpt-4"
        )
        mock_response = MagicMock()
        mock_response.content = "PASS"
        mock_gateway = AsyncMock()
        mock_gateway.chat.return_value = mock_response
        guard = AlignmentGuard(config, llm_gateway=mock_gateway)
        output = {"content": "This is respectful text"}
        # Rule check passes first (no keyword match), then the LLM audit runs.
        result = await guard.check_output(output)
        # A "PASS" verdict from the auditor must produce a passing result.
        assert result.passed is True
        assert result.checked_by == "llm"
        mock_gateway.chat.assert_called_once()

    async def test_audit_enabled_llm_detects_violation(self):
        config = AlignmentConfig(
            constraints=["be_respectful"], audit_enabled=True
        )
        mock_response = MagicMock()
        mock_response.content = "VIOLATION: Output is disrespectful"
        mock_gateway = AsyncMock()
        mock_gateway.chat.return_value = mock_response
        guard = AlignmentGuard(config, llm_gateway=mock_gateway)
        output = {"content": "This is borderline text"}
        result = await guard.check_output(output)
        assert result.passed is False
        assert result.checked_by == "llm"
        # The failure must come from the auditor, so it must have been consulted.
        mock_gateway.chat.assert_called_once()

    async def test_audit_enabled_no_llm_gateway_skips_llm(self):
        config = AlignmentConfig(
            constraints=["be_respectful"], audit_enabled=True
        )
        guard = AlignmentGuard(config, llm_gateway=None)
        output = {"content": "This is safe"}
        result = await guard.check_output(output)
        # Auditing is requested but no gateway exists: fall back to rules only.
        assert result.passed is True
        assert result.checked_by == "rule"

    async def test_custom_constraints_override_config(self):
        config = AlignmentConfig(constraints=["default_constraint"])
        guard = AlignmentGuard(config)
        output = {"content": "This has custom_violation in it"}
        result = await guard.check_output(output, constraints=["custom_violation"])
        assert result.passed is False
        assert "custom_violation" in result.violations

    async def test_case_insensitive_matching(self):
        # Constraint "ForBiDdEn" must still match "forbidden" in the content.
        config = AlignmentConfig(constraints=["ForBiDdEn"])
        guard = AlignmentGuard(config)
        output = {"content": "This has forbidden in it"}
        result = await guard.check_output(output)
        assert result.passed is False
# ── AlignmentGuard 级联检测测试 ───────────────────────────
class TestAlignmentGuardCascade:
    """Cascade-failure bookkeeping exposed through AlignmentGuard."""

    def test_record_interaction_returns_alert_when_exceeded(self):
        guard = AlignmentGuard(AlignmentConfig(cascade_max_interactions=3))
        # Interactions up to and including the limit (3) stay silent.
        for _ in range(3):
            assert guard.record_interaction("s1") is None
        # The 4th interaction crosses the limit and yields an alert.
        alert = guard.record_interaction("s1")
        assert alert is not None
        assert (alert.session_id, alert.alert_type) == ("s1", "interaction_limit")
        assert (alert.current_value, alert.threshold) == (4, 3)

    def test_record_interaction_below_threshold_returns_none(self):
        guard = AlignmentGuard(AlignmentConfig(cascade_max_interactions=10))
        assert guard.record_interaction("s1") is None

    def test_record_loop_depth_returns_alert_when_exceeded(self):
        guard = AlignmentGuard(AlignmentConfig(cascade_max_depth=2))
        # A depth equal to the limit is tolerated ...
        assert guard.record_loop_depth("s1", 2) is None
        # ... one level deeper is reported.
        alert = guard.record_loop_depth("s1", 3)
        assert alert is not None
        assert alert.alert_type == "loop_depth"
        assert alert.current_value == 3

    def test_reset_session_clears_counters(self):
        guard = AlignmentGuard(AlignmentConfig(cascade_max_interactions=5))
        for _ in range(2):
            guard.record_interaction("s1")
        assert guard.get_interaction_count("s1") == 2
        guard.reset_session("s1")
        assert guard.get_interaction_count("s1") == 0

    def test_get_interaction_count_default_zero(self):
        # A session that was never recorded reports zero interactions.
        guard = AlignmentGuard(AlignmentConfig())
        assert guard.get_interaction_count("unknown") == 0

    def test_inject_constraints_delegates_to_injector(self):
        guard = AlignmentGuard(AlignmentConfig(constraints=["no_harm"]))
        injected = guard.inject_constraints({"task": "test"})
        assert injected["alignment_constraints"] == ["no_harm"]
# ── CascadeDetector 测试 ──────────────────────────────────
class TestCascadeDetector:
    """Tests for the standalone CascadeDetector."""

    def test_interaction_exceeds_threshold_triggers_alert(self):
        detector = CascadeDetector(max_interactions=3)
        # Up to the limit: no alert.
        assert detector.check_interaction("s1") is None
        assert detector.check_interaction("s1") is None
        assert detector.check_interaction("s1") is None
        # One past the limit: alert.
        alert = detector.check_interaction("s1")
        assert alert is not None
        assert alert.alert_type == "interaction_limit"
        assert alert.current_value == 4
        assert alert.threshold == 3

    def test_interaction_below_threshold_returns_none(self):
        detector = CascadeDetector(max_interactions=10)
        assert detector.check_interaction("s1") is None

    def test_loop_depth_exceeds_threshold_triggers_alert(self):
        detector = CascadeDetector(max_depth=3)
        assert detector.check_depth("s1", 3) is None
        alert = detector.check_depth("s1", 4)
        assert alert is not None
        assert alert.alert_type == "loop_depth"
        assert alert.current_value == 4
        # Mirror the interaction test: the configured limit is reported too.
        assert alert.threshold == 3

    def test_reset_clears_counters(self):
        detector = CascadeDetector(max_interactions=2)
        detector.check_interaction("s1")
        detector.check_interaction("s1")
        # Record a depth as well; otherwise the depth == 0 check below would
        # pass even if reset() left the depth counter untouched.
        detector.check_depth("s1", 2)
        # Confirm both counters are non-zero before resetting.
        before = detector.get_stats("s1")
        assert before["interactions"] == 2
        assert before["depth"] == 2
        detector.reset("s1")
        stats = detector.get_stats("s1")
        assert stats["interactions"] == 0
        assert stats["depth"] == 0

    def test_get_stats_returns_current_values(self):
        detector = CascadeDetector()
        detector.check_interaction("s1")
        detector.check_interaction("s1")
        detector.check_depth("s1", 5)
        stats = detector.get_stats("s1")
        assert stats["interactions"] == 2
        assert stats["depth"] == 5

    def test_get_stats_unknown_session(self):
        # A session that was never seen reports zeroed counters.
        detector = CascadeDetector()
        stats = detector.get_stats("unknown")
        assert stats["interactions"] == 0
        assert stats["depth"] == 0
# ── SkillConfig alignment 字段测试 ────────────────────────
class TestSkillConfigAlignment:
    """Tests for the `alignment` field on SkillConfig."""

    @staticmethod
    def _assert_default_alignment(alignment):
        """Assert that *alignment* carries every default AlignmentConfig value."""
        assert alignment.constraints == []
        assert alignment.cascade_max_interactions == 10
        assert alignment.cascade_max_depth == 3
        assert alignment.audit_enabled is False
        assert alignment.audit_model == "default"

    def test_default_alignment(self):
        config = SkillConfig(name="test", agent_type="test", prompt={"identity": "test"})
        self._assert_default_alignment(config.alignment)

    def test_alignment_from_dict(self):
        config = SkillConfig.from_dict({
            "name": "test",
            "agent_type": "test",
            "prompt": {"identity": "test"},
            "alignment": {
                "constraints": ["no_harm"],
                "cascade_max_interactions": 5,
                "cascade_max_depth": 2,
                "audit_enabled": True,
                "audit_model": "gpt-4",
            },
        })
        assert config.alignment.constraints == ["no_harm"]
        assert config.alignment.cascade_max_interactions == 5
        assert config.alignment.cascade_max_depth == 2
        assert config.alignment.audit_enabled is True
        assert config.alignment.audit_model == "gpt-4"

    def test_alignment_to_dict(self):
        # `alignment` may be given as a plain dict and must serialize back out.
        config = SkillConfig(
            name="test",
            agent_type="test",
            prompt={"identity": "test"},
            alignment={"constraints": ["no_harm"], "audit_enabled": True},
        )
        d = config.to_dict()
        assert "alignment" in d
        assert d["alignment"]["constraints"] == ["no_harm"]
        assert d["alignment"]["audit_enabled"] is True

    def test_backward_compatibility_no_alignment(self):
        # A dict without an "alignment" key (e.g. an older skill definition)
        # must load with the full set of defaults, not just empty constraints.
        config = SkillConfig.from_dict({
            "name": "test",
            "agent_type": "test",
            "prompt": {"identity": "test"},
        })
        self._assert_default_alignment(config.alignment)