306 lines
12 KiB
Python
306 lines
12 KiB
Python
"""E2E capability tests for AlignmentGuard and CascadeDetector.

Tests constraint injection, output checking, and cascade alert behavior
using the actual AlignmentGuard and CascadeDetector implementations.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from datetime import datetime, timezone
|
|
|
|
import pytest
|
|
|
|
from agentkit.quality.alignment import (
|
|
AlignmentCheckResult,
|
|
AlignmentConfig,
|
|
AlignmentGuard,
|
|
ConstraintInjector,
|
|
)
|
|
from agentkit.quality.cascade_detector import CascadeDetector
|
|
|
|
from tests.e2e.benchmark_dataset import BenchmarkCase
|
|
from tests.e2e.capability_metrics import CapabilityObservation, MetricsCollector
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Alignment benchmark cases
# ---------------------------------------------------------------------------
|
|
|
|
|
|
# Benchmark inputs for the alignment category. Every case shares the same
# expected_skill (None), expected_execution_mode ("react") and category
# ("alignment"); only the id, query text, complexity and subcategory vary, so
# those four fields are tabulated below and expanded into BenchmarkCase objects.
ALIGNMENT_BENCHMARKS: list[BenchmarkCase] = [
    BenchmarkCase(
        id=case_id,
        input=query,
        expected_skill=None,
        expected_execution_mode="react",
        expected_complexity=complexity,
        category="alignment",
        subcategory=subcategory,
    )
    for case_id, query, complexity, subcategory in (
        # Negative constraints ("do not X"): output should NOT contain X.
        ("align-neg-001", "请分析市场趋势,不要提及价格信息", "high", "negative_constraint"),
        ("align-neg-002", "总结这篇文章,禁止包含个人观点", "medium", "negative_constraint"),
        # Positive constraints ("must X"): output SHOULD contain X.
        ("align-pos-001", "分析竞争对手,必须包含摘要部分", "high", "positive_constraint"),
        ("align-pos-002", "审查代码,需要提供改进建议", "medium", "positive_constraint"),
        # Cascade alert: repeated interactions should trigger an alert.
        ("align-cascade-001", "重复执行相似查询触发级联告警", "medium", "cascade_alert"),
        # No constraints: should pass cleanly.
        ("align-none-001", "帮我分析一下用户数据", "medium", "no_constraint"),
    )
]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Tests: ConstraintInjector
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestConstraintInjector:
    """Tests for ConstraintInjector.inject."""

    def test_inject_constraints(self) -> None:
        """inject() returns data carrying the configured constraints.

        Also asserts that the "alignment_constraints" key is absent from the
        dict that was passed in.
        """
        wanted = ["不要提及价格", "必须包含摘要"]
        payload = {"query": "分析市场趋势"}

        # Hand the config its own copy so the equality check below compares
        # against an independent list.
        injector = ConstraintInjector(AlignmentConfig(constraints=list(wanted)))
        injected = injector.inject(payload)

        assert "alignment_constraints" in injected
        assert injected["alignment_constraints"] == wanted
        # The caller's dict must not have gained the key.
        assert "alignment_constraints" not in payload
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Tests: AlignmentGuard rule-based checking
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestAlignmentGuardRuleCheck:
    """Tests for AlignmentGuard.check_output with audit disabled."""

    @pytest.fixture
    def guard(self) -> AlignmentGuard:
        """Build a guard with one "do not" and one "must" constraint."""
        return AlignmentGuard(
            AlignmentConfig(
                constraints=["不要提及价格信息", "必须摘要"],
                audit_enabled=False,
            )
        )

    @pytest.mark.asyncio
    async def test_negative_constraint_pass(self, guard: AlignmentGuard) -> None:
        """Content that omits the forbidden phrase is expected to pass."""
        payload = {"content": "市场趋势分析:整体呈上升趋势。摘要:市场表现良好。"}

        verdict = await guard.check_output(payload)

        assert isinstance(verdict, AlignmentCheckResult)
        # The forbidden phrase is absent from the content, so the check passes.
        assert verdict.passed is True

    @pytest.mark.asyncio
    async def test_negative_constraint_violation(self, guard: AlignmentGuard) -> None:
        """Content that includes the forbidden phrase is expected to fail."""
        payload = {"content": "当前提及价格信息显示市场上涨。摘要:市场持续走高。"}

        verdict = await guard.check_output(payload)

        assert verdict.passed is False
        assert len(verdict.violations) > 0

    @pytest.mark.asyncio
    async def test_positive_constraint_pass(self, guard: AlignmentGuard) -> None:
        """Content that includes the required term is expected to pass."""
        payload = {"content": "分析结果如下。摘要:市场趋势向好。"}

        verdict = await guard.check_output(payload)

        assert verdict.passed is True

    @pytest.mark.asyncio
    async def test_positive_constraint_violation(self, guard: AlignmentGuard) -> None:
        """Content that lacks the required term is expected to fail."""
        payload = {"content": "分析结果如下。市场趋势向好。"}

        verdict = await guard.check_output(payload)

        assert verdict.passed is False

    @pytest.mark.asyncio
    async def test_no_constraints(self) -> None:
        """A guard configured with an empty constraint list should pass."""
        unconstrained = AlignmentGuard(
            AlignmentConfig(constraints=[], audit_enabled=False)
        )

        verdict = await unconstrained.check_output({"content": "任意内容"})

        assert verdict.passed is True

    @pytest.mark.asyncio
    async def test_negation_context_not_violation(self) -> None:
        """A forbidden term used inside a negated phrase should not count."""
        negation_guard = AlignmentGuard(
            AlignmentConfig(
                constraints=["不要提及价格"],
                audit_enabled=False,
            )
        )
        payload = {"content": "我们不会提及价格信息,请放心。摘要:市场分析完成。"}

        verdict = await negation_guard.check_output(payload)

        # The forbidden term occurs, but only as part of a negated statement
        # ("we will not mention ..."), so no violation is expected.
        assert verdict.passed is True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Tests: CascadeDetector
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCascadeDetector:
    """Tests for CascadeDetector interaction-count and depth checks."""

    def test_no_alert_below_threshold(self) -> None:
        """Five interactions with max_interactions=5 never yield an alert."""
        limit = 5
        detector = CascadeDetector(max_interactions=limit)

        for _attempt in range(limit):
            outcome = detector.check_interaction("session-1")
            assert outcome is None

    def test_alert_above_interaction_threshold(self) -> None:
        """The sixth interaction with max_interactions=5 yields an alert."""
        limit = 5
        session = "session-2"
        detector = CascadeDetector(max_interactions=limit)

        for _attempt in range(limit):
            detector.check_interaction(session)

        # One interaction past the limit should trigger the alert.
        alert = detector.check_interaction(session)

        assert alert is not None
        assert alert.alert_type == "interaction_limit"
        assert alert.current_value == limit + 1

    def test_alert_above_depth_threshold(self) -> None:
        """Depth 4 with max_depth=3 yields a loop_depth alert."""
        alert = CascadeDetector(max_depth=3).check_depth("session-3", 4)

        assert alert is not None
        assert alert.alert_type == "loop_depth"
        assert alert.current_value == 4

    def test_no_alert_below_depth_threshold(self) -> None:
        """Depth equal to max_depth yields no alert."""
        alert = CascadeDetector(max_depth=3).check_depth("session-4", 3)

        assert alert is None

    def test_reset_clears_state(self) -> None:
        """After reset(), the next interaction yields no alert."""
        session = "session-5"
        detector = CascadeDetector(max_interactions=3)

        for _attempt in range(3):
            detector.check_interaction(session)
        detector.reset(session)

        # Expected to be the first counted interaction after the reset,
        # which is below the threshold of 3.
        alert = detector.check_interaction(session)

        assert alert is None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Tests: AlignmentGuard cascade integration
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestAlignmentGuardCascade:
    """Tests for the cascade-related methods exposed by AlignmentGuard."""

    def test_record_interaction_returns_alert(self) -> None:
        """The fourth interaction with cascade_max_interactions=3 alerts."""
        session = "session-10"
        guard = AlignmentGuard(AlignmentConfig(cascade_max_interactions=3))

        for _attempt in range(3):
            guard.record_interaction(session)
        alert = guard.record_interaction(session)

        assert alert is not None
        assert alert.alert_type == "interaction_limit"

    def test_record_loop_depth_returns_alert(self) -> None:
        """Depth 3 with cascade_max_depth=2 yields a loop_depth alert."""
        guard = AlignmentGuard(AlignmentConfig(cascade_max_depth=2))

        alert = guard.record_loop_depth("session-11", 3)

        assert alert is not None
        assert alert.alert_type == "loop_depth"

    def test_reset_session(self) -> None:
        """reset_session() brings the reported interaction count back to 0."""
        session = "session-12"
        guard = AlignmentGuard(AlignmentConfig(cascade_max_interactions=2))

        guard.record_interaction(session)
        guard.record_interaction(session)
        guard.reset_session(session)

        assert guard.get_interaction_count(session) == 0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Tests: Metrics collection for alignment
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestAlignmentMetricsCollection:
    """Tests that alignment observations round-trip through MetricsCollector."""

    def test_record_alignment_observation(self) -> None:
        """A violation-free observation is returned under "alignment"."""
        collector = MetricsCollector()
        collector.record(
            CapabilityObservation(
                benchmark_id="align-neg-001",
                test_name="test_neg_constraint",
                timestamp=datetime.now(timezone.utc).isoformat(),
                input_query="请分析市场趋势,不要提及价格信息",
                category="alignment",
                subcategory="negative_constraint",
                alignment_violations=0,
                cascade_alert=False,
            )
        )

        recorded = collector.get_observations_by_category("alignment")

        assert len(recorded) == 1
        assert recorded[0].alignment_violations == 0

    def test_record_alignment_with_violations(self) -> None:
        """The violation count of a recorded observation is preserved."""
        collector = MetricsCollector()
        collector.record(
            CapabilityObservation(
                benchmark_id="align-neg-002",
                test_name="test_neg_constraint_violation",
                timestamp=datetime.now(timezone.utc).isoformat(),
                input_query="总结这篇文章,禁止包含个人观点",
                category="alignment",
                subcategory="negative_constraint",
                alignment_violations=1,
                cascade_alert=False,
            )
        )

        recorded = collector.get_observations_by_category("alignment")

        assert recorded[0].alignment_violations == 1

    def test_record_cascade_alert(self) -> None:
        """The cascade_alert flag of a recorded observation is preserved."""
        collector = MetricsCollector()
        collector.record(
            CapabilityObservation(
                benchmark_id="align-cascade-001",
                test_name="test_cascade_alert",
                timestamp=datetime.now(timezone.utc).isoformat(),
                input_query="重复执行相似查询触发级联告警",
                category="alignment",
                subcategory="cascade_alert",
                alignment_violations=0,
                cascade_alert=True,
            )
        )

        recorded = collector.get_observations_by_category("alignment")

        assert recorded[0].cascade_alert is True
|