"""Unit tests for AlignmentGuard."""
|
|
|
|
import asyncio
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from agentkit.quality.alignment import (
|
|
AlignmentCheckResult,
|
|
AlignmentConfig,
|
|
AlignmentGuard,
|
|
CascadeAlert,
|
|
ConstraintInjector,
|
|
)
|
|
from agentkit.quality.cascade_detector import CascadeDetector
|
|
from agentkit.skills.base import SkillConfig
|
|
|
|
|
|
# ── AlignmentConfig tests ─────────────────────────────────
|
|
|
|
|
|
class TestAlignmentConfig:
    """Checks the values AlignmentConfig exposes, by default and when overridden."""

    def test_default_values(self):
        """A config built with no arguments reports the expected defaults."""
        defaults = AlignmentConfig()

        assert defaults.constraints == []
        assert defaults.cascade_max_interactions == 10
        assert defaults.cascade_max_depth == 3
        assert defaults.audit_enabled is False
        assert defaults.audit_model == "default"

    def test_custom_values(self):
        """Every keyword passed to the constructor is readable back unchanged."""
        overrides = {
            "constraints": ["no_harm", "be_honest"],
            "cascade_max_interactions": 5,
            "cascade_max_depth": 2,
            "audit_enabled": True,
            "audit_model": "gpt-4",
        }
        customised = AlignmentConfig(**overrides)

        assert customised.constraints == ["no_harm", "be_honest"]
        assert customised.cascade_max_interactions == 5
        assert customised.cascade_max_depth == 2
        assert customised.audit_enabled is True
        assert customised.audit_model == "gpt-4"
|
|
|
|
|
|
# ── ConstraintInjector tests ──────────────────────────────
|
|
|
|
|
|
class TestConstraintInjector:
    """Exercises ConstraintInjector.inject for populated and empty constraint lists."""

    def test_inject_constraints_into_input_data(self):
        """The result carries the configured constraints next to the input keys."""
        injector = ConstraintInjector(
            AlignmentConfig(constraints=["no_harm", "be_honest"])
        )

        enriched = injector.inject({"task": "translate"})

        assert "alignment_constraints" in enriched
        assert enriched["alignment_constraints"] == ["no_harm", "be_honest"]
        assert enriched["task"] == "translate"

    def test_does_not_modify_original_dict(self):
        """inject() must add the key to its result, not to the caller's dict."""
        injector = ConstraintInjector(AlignmentConfig(constraints=["no_harm"]))
        payload = {"task": "translate"}

        enriched = injector.inject(payload)

        assert "alignment_constraints" not in payload
        assert "alignment_constraints" in enriched

    def test_empty_constraints(self):
        """With no constraints configured the injected value is an empty list."""
        injector = ConstraintInjector(AlignmentConfig(constraints=[]))

        enriched = injector.inject({"task": "translate"})

        assert enriched["alignment_constraints"] == []
|
|
|
|
|
|
# ── AlignmentGuard.check_output tests ─────────────────────
|
|
|
|
|
|
# Every test in this class is a coroutine. The explicit class-level marker makes
# pytest-asyncio run them regardless of the configured asyncio_mode; without it
# they are only awaited when the project opts into "auto" mode.
@pytest.mark.asyncio
class TestAlignmentGuardCheckOutput:
    """Alignment checks performed by AlignmentGuard.check_output."""

    async def test_rule_check_violation_keyword_match(self):
        """A constraint keyword present in the content fails the rule check."""
        config = AlignmentConfig(constraints=["forbidden_word"])
        guard = AlignmentGuard(config)
        output = {"content": "This contains forbidden_word in text"}

        result = await guard.check_output(output)

        assert result.passed is False
        assert "forbidden_word" in result.violations
        assert result.checked_by == "rule"

    async def test_rule_check_passes_no_violations(self):
        """Content without any constraint keyword passes with no violations."""
        config = AlignmentConfig(constraints=["forbidden_word"])
        guard = AlignmentGuard(config)
        output = {"content": "This is clean text"}

        result = await guard.check_output(output)

        assert result.passed is True
        assert result.violations == []
        assert result.checked_by == "rule"

    async def test_no_constraints_passes(self):
        """With an empty constraint list any content passes the rule check."""
        config = AlignmentConfig(constraints=[])
        guard = AlignmentGuard(config)

        result = await guard.check_output({"content": "anything"})

        assert result.passed is True
        assert result.checked_by == "rule"

    async def test_audit_disabled_does_not_call_llm(self):
        """A supplied gateway must stay untouched while audit_enabled is False."""
        config = AlignmentConfig(
            constraints=["no_harm"], audit_enabled=False
        )
        mock_gateway = AsyncMock()
        guard = AlignmentGuard(config, llm_gateway=mock_gateway)
        output = {"content": "This is safe"}

        result = await guard.check_output(output)

        assert result.checked_by == "rule"
        mock_gateway.chat.assert_not_called()

    async def test_audit_enabled_calls_llm_for_semantic_check(self):
        """With auditing on and a gateway present, the LLM is consulted once."""
        config = AlignmentConfig(
            constraints=["be_respectful"], audit_enabled=True, audit_model="gpt-4"
        )
        mock_response = MagicMock()
        mock_response.content = "PASS"
        mock_gateway = AsyncMock()
        mock_gateway.chat.return_value = mock_response
        guard = AlignmentGuard(config, llm_gateway=mock_gateway)
        output = {"content": "This is respectful text"}

        # Rule check passes first (no keyword match), then LLM audit
        result = await guard.check_output(output)

        assert result.checked_by == "llm"
        mock_gateway.chat.assert_called_once()

    async def test_audit_enabled_llm_detects_violation(self):
        """A "VIOLATION: ..." reply from the gateway fails the check."""
        config = AlignmentConfig(
            constraints=["be_respectful"], audit_enabled=True
        )
        mock_response = MagicMock()
        mock_response.content = "VIOLATION: Output is disrespectful"
        mock_gateway = AsyncMock()
        mock_gateway.chat.return_value = mock_response
        guard = AlignmentGuard(config, llm_gateway=mock_gateway)
        output = {"content": "This is borderline text"}

        result = await guard.check_output(output)

        assert result.passed is False
        assert result.checked_by == "llm"

    async def test_audit_enabled_no_llm_gateway_skips_llm(self):
        """Auditing enabled but no gateway: the result is still rule-checked."""
        config = AlignmentConfig(
            constraints=["be_respectful"], audit_enabled=True
        )
        guard = AlignmentGuard(config, llm_gateway=None)
        output = {"content": "This is safe"}

        result = await guard.check_output(output)

        assert result.checked_by == "rule"

    async def test_custom_constraints_override_config(self):
        """Constraints passed to check_output are the ones that get matched."""
        config = AlignmentConfig(constraints=["default_constraint"])
        guard = AlignmentGuard(config)
        output = {"content": "This has custom_violation in it"}

        result = await guard.check_output(output, constraints=["custom_violation"])

        assert result.passed is False
        assert "custom_violation" in result.violations

    async def test_case_insensitive_matching(self):
        """A mixed-case constraint still matches lower-case content."""
        config = AlignmentConfig(constraints=["ForBiDdEn"])
        guard = AlignmentGuard(config)
        output = {"content": "This has forbidden in it"}

        result = await guard.check_output(output)

        assert result.passed is False
|
|
|
|
|
|
# ── AlignmentGuard cascade detection tests ────────────────
|
|
|
|
|
|
class TestAlignmentGuardCascade:
    """Cascade-failure detection exposed through AlignmentGuard."""

    def test_record_interaction_returns_alert_when_exceeded(self):
        """The call after the configured maximum yields an interaction_limit alert."""
        guard = AlignmentGuard(AlignmentConfig(cascade_max_interactions=3))

        # The first 3 interactions stay within the limit.
        for _ in range(3):
            assert guard.record_interaction("s1") is None

        # The 4th interaction trips the alert.
        alert = guard.record_interaction("s1")
        assert alert is not None
        assert alert.session_id == "s1"
        assert alert.alert_type == "interaction_limit"
        assert alert.current_value == 4
        assert alert.threshold == 3

    def test_record_interaction_below_threshold_returns_none(self):
        """A single interaction under a limit of 10 produces no alert."""
        guard = AlignmentGuard(AlignmentConfig(cascade_max_interactions=10))

        assert guard.record_interaction("s1") is None

    def test_record_loop_depth_returns_alert_when_exceeded(self):
        """Depth equal to the limit is fine; one level deeper raises loop_depth."""
        guard = AlignmentGuard(AlignmentConfig(cascade_max_depth=2))

        assert guard.record_loop_depth("s1", 2) is None

        alert = guard.record_loop_depth("s1", 3)
        assert alert is not None
        assert alert.alert_type == "loop_depth"
        assert alert.current_value == 3

    def test_reset_session_clears_counters(self):
        """reset_session brings a session's interaction count back to zero."""
        guard = AlignmentGuard(AlignmentConfig(cascade_max_interactions=5))

        for _ in range(2):
            guard.record_interaction("s1")
        assert guard.get_interaction_count("s1") == 2

        guard.reset_session("s1")
        assert guard.get_interaction_count("s1") == 0

    def test_get_interaction_count_default_zero(self):
        """A session that was never recorded reports a count of zero."""
        guard = AlignmentGuard(AlignmentConfig())

        assert guard.get_interaction_count("unknown") == 0

    def test_inject_constraints_delegates_to_injector(self):
        """inject_constraints returns data carrying the configured constraints."""
        guard = AlignmentGuard(AlignmentConfig(constraints=["no_harm"]))

        enriched = guard.inject_constraints({"task": "test"})

        assert enriched["alignment_constraints"] == ["no_harm"]
|
|
|
|
|
|
# ── CascadeDetector tests ─────────────────────────────────
|
|
|
|
|
|
class TestCascadeDetector:
    """Stand-alone tests for CascadeDetector."""

    def test_interaction_exceeds_threshold_triggers_alert(self):
        """With max_interactions=3 the 4th check returns an interaction_limit alert."""
        detector = CascadeDetector(max_interactions=3)

        for _ in range(3):
            assert detector.check_interaction("s1") is None

        alert = detector.check_interaction("s1")
        assert alert is not None
        assert alert.alert_type == "interaction_limit"
        assert alert.current_value == 4
        assert alert.threshold == 3

    def test_interaction_below_threshold_returns_none(self):
        """One interaction under a limit of 10 does not alert."""
        detector = CascadeDetector(max_interactions=10)

        assert detector.check_interaction("s1") is None

    def test_loop_depth_exceeds_threshold_triggers_alert(self):
        """Depth 3 is accepted at max_depth=3; depth 4 returns a loop_depth alert."""
        detector = CascadeDetector(max_depth=3)

        assert detector.check_depth("s1", 3) is None

        alert = detector.check_depth("s1", 4)
        assert alert is not None
        assert alert.alert_type == "loop_depth"
        assert alert.current_value == 4

    def test_reset_clears_counters(self):
        """After reset() the session's stats read zero for both counters."""
        detector = CascadeDetector(max_interactions=2)
        for _ in range(2):
            detector.check_interaction("s1")

        detector.reset("s1")

        snapshot = detector.get_stats("s1")
        assert snapshot["interactions"] == 0
        assert snapshot["depth"] == 0

    def test_get_stats_returns_current_values(self):
        """get_stats reflects the recorded interaction count and depth."""
        detector = CascadeDetector()
        for _ in range(2):
            detector.check_interaction("s1")
        detector.check_depth("s1", 5)

        snapshot = detector.get_stats("s1")
        assert snapshot["interactions"] == 2
        assert snapshot["depth"] == 5

    def test_get_stats_unknown_session(self):
        """An unseen session reports zero interactions and zero depth."""
        snapshot = CascadeDetector().get_stats("unknown")

        assert snapshot["interactions"] == 0
        assert snapshot["depth"] == 0
|
|
|
|
|
|
# ── SkillConfig alignment field tests ─────────────────────
|
|
|
|
|
|
class TestSkillConfigAlignment:
    """Tests for the ``alignment`` field carried by SkillConfig."""

    def test_default_alignment(self):
        """A SkillConfig built without alignment settings exposes the defaults."""
        skill = SkillConfig(name="test", agent_type="test", prompt={"identity": "test"})
        alignment = skill.alignment

        assert alignment.constraints == []
        assert alignment.cascade_max_interactions == 10
        assert alignment.cascade_max_depth == 3
        assert alignment.audit_enabled is False
        assert alignment.audit_model == "default"

    def test_alignment_from_dict(self):
        """from_dict maps a nested "alignment" mapping onto the alignment field."""
        alignment_settings = {
            "constraints": ["no_harm"],
            "cascade_max_interactions": 5,
            "cascade_max_depth": 2,
            "audit_enabled": True,
            "audit_model": "gpt-4",
        }
        raw = {
            "name": "test",
            "agent_type": "test",
            "prompt": {"identity": "test"},
            "alignment": alignment_settings,
        }

        alignment = SkillConfig.from_dict(raw).alignment

        assert alignment.constraints == ["no_harm"]
        assert alignment.cascade_max_interactions == 5
        assert alignment.cascade_max_depth == 2
        assert alignment.audit_enabled is True
        assert alignment.audit_model == "gpt-4"

    def test_alignment_to_dict(self):
        """to_dict emits an "alignment" entry holding the supplied settings."""
        skill = SkillConfig(
            name="test",
            agent_type="test",
            prompt={"identity": "test"},
            alignment={"constraints": ["no_harm"], "audit_enabled": True},
        )

        serialised = skill.to_dict()

        assert "alignment" in serialised
        assert serialised["alignment"]["constraints"] == ["no_harm"]
        assert serialised["alignment"]["audit_enabled"] is True

    def test_backward_compatibility_no_alignment(self):
        """A dict with no "alignment" key still loads, with empty constraints."""
        legacy = {
            "name": "test",
            "agent_type": "test",
            "prompt": {"identity": "test"},
        }

        skill = SkillConfig.from_dict(legacy)

        assert skill.alignment.constraints == []
|