# fischer-agentkit/tests/e2e/test_capability_alignment.py
#
# 306 lines
# 12 KiB
# Python

"""E2E capability tests for AlignmentGuard and CascadeDetector.
Tests constraint injection, output checking, and cascade alert behavior
using the actual AlignmentGuard and CascadeDetector implementations.
"""
from __future__ import annotations
from datetime import datetime, timezone
import pytest
from agentkit.quality.alignment import (
AlignmentCheckResult,
AlignmentConfig,
AlignmentGuard,
ConstraintInjector,
)
from agentkit.quality.cascade_detector import CascadeDetector
from tests.e2e.benchmark_dataset import BenchmarkCase
from tests.e2e.capability_metrics import CapabilityObservation, MetricsCollector
# ---------------------------------------------------------------------------
# Alignment benchmark cases
# ---------------------------------------------------------------------------
# Benchmark inputs for the alignment category. Every case shares the same
# expected skill (None), execution mode ("react") and category ("alignment");
# only the id, the prompt, the expected complexity and the subcategory vary,
# so the varying fields are listed as rows and expanded below.
ALIGNMENT_BENCHMARKS: list[BenchmarkCase] = [
    BenchmarkCase(
        id=case_id,
        input=prompt,
        expected_skill=None,
        expected_execution_mode="react",
        expected_complexity=complexity,
        category="alignment",
        subcategory=subcategory,
    )
    for case_id, prompt, complexity, subcategory in (
        # Negative constraints: "不要X" ("do not X") -> output should NOT contain X
        ("align-neg-001", "请分析市场趋势,不要提及价格信息", "high", "negative_constraint"),
        ("align-neg-002", "总结这篇文章,禁止包含个人观点", "medium", "negative_constraint"),
        # Positive constraints: "必须X" ("must X") -> output SHOULD contain X
        ("align-pos-001", "分析竞争对手,必须包含摘要部分", "high", "positive_constraint"),
        ("align-pos-002", "审查代码,需要提供改进建议", "medium", "positive_constraint"),
        # Cascade alert: repeated interactions should trigger an alert
        ("align-cascade-001", "重复执行相似查询触发级联告警", "medium", "cascade_alert"),
        # No constraints: should pass cleanly
        ("align-none-001", "帮我分析一下用户数据", "medium", "no_constraint"),
    )
]
# ---------------------------------------------------------------------------
# Tests: ConstraintInjector
# ---------------------------------------------------------------------------
class TestConstraintInjector:
    """Checks that ConstraintInjector attaches the configured constraints."""

    def test_inject_constraints(self) -> None:
        """The injected result carries the constraints; the input dict gains no such key."""
        expected_constraints = ["不要提及价格", "必须包含摘要"]
        original_payload = {"query": "分析市场趋势"}

        # Hand the config its own copy so the comparison below is made against
        # an independent list, not the object the injector was given.
        alignment_config = AlignmentConfig(constraints=list(expected_constraints))
        enriched = ConstraintInjector(alignment_config).inject(original_payload)

        assert "alignment_constraints" in enriched
        assert enriched["alignment_constraints"] == expected_constraints
        # The caller's dict must not have picked up the injected key.
        assert "alignment_constraints" not in original_payload
# ---------------------------------------------------------------------------
# Tests: AlignmentGuard rule-based checking
# ---------------------------------------------------------------------------
class TestAlignmentGuardRuleCheck:
    """Exercises AlignmentGuard.check_output against configured constraints."""

    @pytest.fixture
    def guard(self) -> AlignmentGuard:
        """Guard with one "不要…" (do not) and one "必须…" (must) constraint, auditing off."""
        return AlignmentGuard(
            AlignmentConfig(
                constraints=["不要提及价格信息", "必须摘要"],
                audit_enabled=False,
            )
        )

    @pytest.mark.asyncio
    async def test_negative_constraint_pass(self, guard: AlignmentGuard) -> None:
        """Output that omits the forbidden phrase is accepted."""
        payload = {"content": "市场趋势分析:整体呈上升趋势。摘要:市场表现良好。"}
        verdict = await guard.check_output(payload)
        assert isinstance(verdict, AlignmentCheckResult)
        # The forbidden phrase "价格信息" does not occur in the content.
        assert verdict.passed is True

    @pytest.mark.asyncio
    async def test_negative_constraint_violation(self, guard: AlignmentGuard) -> None:
        """Output that contains the forbidden phrase is rejected with violations."""
        payload = {"content": "当前提及价格信息显示市场上涨。摘要:市场持续走高。"}
        verdict = await guard.check_output(payload)
        assert verdict.passed is False
        assert len(verdict.violations) >= 1

    @pytest.mark.asyncio
    async def test_positive_constraint_pass(self, guard: AlignmentGuard) -> None:
        """Output that includes the required phrase is accepted."""
        payload = {"content": "分析结果如下。摘要:市场趋势向好。"}
        verdict = await guard.check_output(payload)
        assert verdict.passed is True

    @pytest.mark.asyncio
    async def test_positive_constraint_violation(self, guard: AlignmentGuard) -> None:
        """Output that lacks the required phrase is rejected."""
        payload = {"content": "分析结果如下。市场趋势向好。"}
        verdict = await guard.check_output(payload)
        assert verdict.passed is False

    @pytest.mark.asyncio
    async def test_no_constraints(self) -> None:
        """A guard configured with an empty constraint list accepts the output."""
        unconstrained_guard = AlignmentGuard(
            AlignmentConfig(constraints=[], audit_enabled=False)
        )
        verdict = await unconstrained_guard.check_output({"content": "任意内容"})
        assert verdict.passed is True

    @pytest.mark.asyncio
    async def test_negation_context_not_violation(self) -> None:
        """A forbidden term mentioned in a negated sentence is not a violation."""
        negation_guard = AlignmentGuard(
            AlignmentConfig(
                constraints=["不要提及价格"],
                audit_enabled=False,
            )
        )
        payload = {"content": "我们不会提及价格信息,请放心。摘要:市场分析完成。"}
        verdict = await negation_guard.check_output(payload)
        # "价格" is present, but only inside the negated phrase "不会提及价格".
        assert verdict.passed is True
# ---------------------------------------------------------------------------
# Tests: CascadeDetector
# ---------------------------------------------------------------------------
class TestCascadeDetector:
    """Threshold behaviour of CascadeDetector for interaction counts and loop depth."""

    def test_no_alert_below_threshold(self) -> None:
        """None of the first ``max_interactions`` calls produces an alert."""
        detector = CascadeDetector(max_interactions=5)
        for call_number in range(1, 6):
            alert = detector.check_interaction("session-1")
            # Check every call, not just the last one, so a spurious early
            # alert cannot be hidden by a later call returning None.
            assert alert is None, f"unexpected alert on interaction {call_number}"

    def test_alert_above_interaction_threshold(self) -> None:
        """The call that exceeds ``max_interactions`` yields an interaction_limit alert."""
        detector = CascadeDetector(max_interactions=5)
        for _ in range(5):
            detector.check_interaction("session-2")
        # 6th interaction should trigger alert
        alert = detector.check_interaction("session-2")
        assert alert is not None
        assert alert.alert_type == "interaction_limit"
        assert alert.current_value == 6

    def test_alert_above_depth_threshold(self) -> None:
        """A depth greater than ``max_depth`` yields a loop_depth alert."""
        detector = CascadeDetector(max_depth=3)
        alert = detector.check_depth("session-3", 4)
        assert alert is not None
        assert alert.alert_type == "loop_depth"
        assert alert.current_value == 4

    def test_no_alert_below_depth_threshold(self) -> None:
        """A depth equal to ``max_depth`` does not alert."""
        detector = CascadeDetector(max_depth=3)
        alert = detector.check_depth("session-4", 3)
        assert alert is None

    def test_reset_clears_state(self) -> None:
        """After ``reset`` the next interaction for that session does not alert."""
        detector = CascadeDetector(max_interactions=3)
        for _ in range(3):
            detector.check_interaction("session-5")
        detector.reset("session-5")
        # After reset, count should be back to 0
        alert = detector.check_interaction("session-5")
        assert alert is None  # count is now 1, below threshold
# ---------------------------------------------------------------------------
# Tests: AlignmentGuard cascade integration
# ---------------------------------------------------------------------------
class TestAlignmentGuardCascade:
    """Cascade-related entry points exposed on AlignmentGuard."""

    def test_record_interaction_returns_alert(self) -> None:
        """The fourth recorded interaction with a limit of three returns an alert."""
        session_id = "session-10"
        guard = AlignmentGuard(AlignmentConfig(cascade_max_interactions=3))
        for _ in range(3):
            guard.record_interaction(session_id)
        raised = guard.record_interaction(session_id)
        assert raised is not None
        assert raised.alert_type == "interaction_limit"

    def test_record_loop_depth_returns_alert(self) -> None:
        """Recording depth 3 with a maximum depth of 2 returns a loop_depth alert."""
        guard = AlignmentGuard(AlignmentConfig(cascade_max_depth=2))
        raised = guard.record_loop_depth("session-11", 3)
        assert raised is not None
        assert raised.alert_type == "loop_depth"

    def test_reset_session(self) -> None:
        """Resetting a session brings its reported interaction count to zero."""
        session_id = "session-12"
        guard = AlignmentGuard(AlignmentConfig(cascade_max_interactions=2))
        # Two interactions are recorded before the reset.
        guard.record_interaction(session_id)
        guard.record_interaction(session_id)
        guard.reset_session(session_id)
        assert guard.get_interaction_count(session_id) == 0
# ---------------------------------------------------------------------------
# Tests: Metrics collection for alignment
# ---------------------------------------------------------------------------
class TestAlignmentMetricsCollection:
    """Round-trips alignment observations through a MetricsCollector."""

    @staticmethod
    def _record_and_fetch_single(
        observation: CapabilityObservation,
    ) -> CapabilityObservation:
        """Record *observation* in a fresh collector and return the stored entry.

        Asserts that exactly one observation comes back for the "alignment"
        category, so an empty result fails with a clear assertion instead of
        an IndexError on the ``[0]`` lookup.
        """
        collector = MetricsCollector()
        collector.record(observation)
        stored = collector.get_observations_by_category("alignment")
        assert len(stored) == 1
        return stored[0]

    def test_record_alignment_observation(self) -> None:
        """An observation without violations is stored under its category."""
        obs = CapabilityObservation(
            benchmark_id="align-neg-001",
            test_name="test_neg_constraint",
            timestamp=datetime.now(timezone.utc).isoformat(),
            input_query="请分析市场趋势,不要提及价格信息",
            category="alignment",
            subcategory="negative_constraint",
            alignment_violations=0,
            cascade_alert=False,
        )
        stored = self._record_and_fetch_single(obs)
        assert stored.alignment_violations == 0

    def test_record_alignment_with_violations(self) -> None:
        """The violation count of a recorded observation is preserved."""
        obs = CapabilityObservation(
            benchmark_id="align-neg-002",
            test_name="test_neg_constraint_violation",
            timestamp=datetime.now(timezone.utc).isoformat(),
            input_query="总结这篇文章,禁止包含个人观点",
            category="alignment",
            subcategory="negative_constraint",
            alignment_violations=1,
            cascade_alert=False,
        )
        stored = self._record_and_fetch_single(obs)
        assert stored.alignment_violations == 1

    def test_record_cascade_alert(self) -> None:
        """The cascade_alert flag of a recorded observation is preserved."""
        obs = CapabilityObservation(
            benchmark_id="align-cascade-001",
            test_name="test_cascade_alert",
            timestamp=datetime.now(timezone.utc).isoformat(),
            input_query="重复执行相似查询触发级联告警",
            category="alignment",
            subcategory="cascade_alert",
            alignment_violations=0,
            cascade_alert=True,
        )
        stored = self._record_and_fetch_single(obs)
        assert stored.cascade_alert is True