304 lines
12 KiB
Python
304 lines
12 KiB
Python
"""AlignmentGuard - 对齐守卫:约束注入 + 级联故障检测"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
from dataclasses import dataclass, field
|
||
|
||
from agentkit.telemetry.tracer import get_tracer
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
class AlignmentConfig:
    """Settings for the alignment guard.

    Bundles the constraint strings, the two cascade thresholds and the
    switches that control the optional LLM audit.
    """

    # Constraint strings; each instance gets its own empty list by default.
    constraints: list[str] = field(default_factory=list)
    # Threshold for the per-session interaction counter.
    cascade_max_interactions: int = 10
    # Threshold for the per-session loop depth.
    cascade_max_depth: int = 3
    # Switch for the LLM-based audit.
    audit_enabled: bool = False
    # Model identifier used for the audit.
    audit_model: str = "default"
    audit_sample_rate: float = 1.0  # audit sampling rate, 0.0-1.0; 1.0 = audit every time

    @classmethod
    def from_dict(cls, data: dict[str, object]) -> "AlignmentConfig":
        """Build a config from a dict, ignoring keys that are not fields.

        Args:
            data: Mapping of field names to values; unknown keys are dropped.

        Returns:
            A new ``AlignmentConfig``; fields absent from ``data`` keep
            their defaults.
        """
        # __dataclass_fields__ is keyed by field name, so membership in it
        # is the "is this a known field" test.
        recognized = cls.__dataclass_fields__
        kwargs = {}
        for key, value in data.items():
            if key in recognized:
                kwargs[key] = value
        return cls(**kwargs)
|
||
|
||
|
||
@dataclass
class AlignmentCheckResult:
    """Outcome of a single alignment check."""

    # Whether the checked output was accepted.
    passed: bool
    # Messages describing what was violated; empty list by default.
    violations: list[str] = field(default_factory=list)
    checked_by: str = ""  # "rule" or "llm"
|
||
|
||
|
||
@dataclass
class CascadeAlert:
    """Alert describing a cascade-failure condition for one session."""

    # Session the alert belongs to.
    session_id: str
    alert_type: str  # "interaction_limit" or "loop_depth"
    # Observed value that triggered the alert.
    current_value: int
    # Configured limit the observed value was compared against.
    threshold: int
    # Human-readable description of the alert.
    message: str
|
||
|
||
|
||
class ConstraintInjector:
    """Adds the configured global constraints to a task's ``input_data``."""

    def __init__(self, config: AlignmentConfig):
        # Only ``config.constraints`` is read by this class.
        self._config = config

    def inject(self, input_data: dict[str, object]) -> dict[str, object]:
        """Return a copy of ``input_data`` carrying the constraint list.

        The copy has an ``'alignment_constraints'`` key whose value is a
        fresh list of the configured constraints. The caller's dict is
        left untouched; an existing key of the same name is overwritten
        in the copy.
        """
        enriched = dict(input_data)
        enriched["alignment_constraints"] = list(self._config.constraints)
        return enriched
|
||
|
||
|
||
class AlignmentGuard:
|
||
"""对齐守卫 — 扩展 QualityGate,增加约束注入和级联检测"""
|
||
|
||
def __init__(self, config: AlignmentConfig, llm_gateway=None):
|
||
self._config = config
|
||
self._injector = ConstraintInjector(config)
|
||
self._llm_gateway = llm_gateway
|
||
self._interaction_counts: dict[str, int] = {}
|
||
self._loop_depths: dict[str, int] = {}
|
||
|
||
def inject_constraints(self, input_data: dict[str, object]) -> dict[str, object]:
|
||
"""委托给 ConstraintInjector"""
|
||
return self._injector.inject(input_data)
|
||
|
||
async def check_output(
|
||
self,
|
||
output: dict[str, object],
|
||
constraints: list[str] | None = None,
|
||
) -> AlignmentCheckResult:
|
||
"""检查输出是否符合约束
|
||
|
||
- 系统级约束:基于规则的检查(关键词 + 正则匹配)
|
||
- 组织级约束:LLM 语义检查(仅当 audit_enabled=True)
|
||
"""
|
||
effective_constraints = constraints if constraints is not None else self._config.constraints
|
||
if not effective_constraints:
|
||
return AlignmentCheckResult(passed=True, checked_by="rule")
|
||
|
||
tracer = get_tracer()
|
||
with tracer.start_span("guard.check") as span:
|
||
span.set_attribute("guard.constraints_count", len(effective_constraints))
|
||
|
||
# 1. 基于规则的检查:关键词/子串匹配
|
||
violations = self._rule_check(output, effective_constraints)
|
||
if violations:
|
||
result = AlignmentCheckResult(
|
||
passed=False,
|
||
violations=violations,
|
||
checked_by="rule",
|
||
)
|
||
span.set_attribute("guard.passed", result.passed)
|
||
span.set_attribute("guard.checked_by", result.checked_by)
|
||
return result
|
||
|
||
# 2. LLM 语义检查(仅当 audit_enabled=True 且有 llm_gateway,按采样率执行)
|
||
if self._config.audit_enabled and self._llm_gateway is not None:
|
||
import random
|
||
if random.random() < self._config.audit_sample_rate:
|
||
result = await self._llm_check(output, effective_constraints)
|
||
span.set_attribute("guard.passed", result.passed)
|
||
span.set_attribute("guard.checked_by", result.checked_by)
|
||
return result
|
||
# 采样未命中,信任规则检查结果
|
||
logger.debug("LLM audit skipped (sample rate=%.2f)", self._config.audit_sample_rate)
|
||
|
||
result = AlignmentCheckResult(passed=True, checked_by="rule")
|
||
span.set_attribute("guard.passed", result.passed)
|
||
span.set_attribute("guard.checked_by", result.checked_by)
|
||
return result
|
||
|
||
def _rule_check(
|
||
self, output: dict[str, object], constraints: list[str]
|
||
) -> list[str]:
|
||
"""基于规则的约束检查:方向性判断,区分'禁止X'和'提及X'
|
||
|
||
约束格式:
|
||
- "不要X" / "禁止X" / "不得X" / "不能X" / "no X" / "don't X" → 输出中不应执行X
|
||
- "必须X" / "需要X" / "务必X" / "must X" / "should X" → 输出中应包含X
|
||
- 其他 → 简单子串匹配(约束关键词出现在输出中即违规)
|
||
"""
|
||
import re
|
||
|
||
content = self._extract_text(output)
|
||
content_lower = content.lower()
|
||
violations: list[str] = []
|
||
|
||
for constraint in constraints:
|
||
constraint_lower = constraint.lower().strip()
|
||
|
||
# 检测否定约束:"不要X" / "禁止X" / "不得X" / "不能X" / "no X" / "don't X"
|
||
neg_match = re.match(
|
||
r"^(?:不要|禁止|不得|不能|不可以|别|no\s+|don'?t\s+|never\s+|must\s+not\s+|should\s+not\s+)\s*(.+)",
|
||
constraint_lower,
|
||
)
|
||
if neg_match:
|
||
# 否定约束:提取被禁止的内容关键词
|
||
forbidden = neg_match.group(1).strip()
|
||
# 只有当输出中实际执行了该行为时才判违规(而非仅仅提及)
|
||
# 简单启发式:如果输出中包含"执行/输出/提供了 + forbidden"则违规
|
||
# 更安全的做法:如果 forbidden 直接出现在输出中且不是否定语境
|
||
if self._is_positive_mention(content_lower, forbidden):
|
||
violations.append(constraint)
|
||
continue
|
||
|
||
# 检测肯定约束:"必须X" / "需要X" / "务必X" / "must X" / "should X"
|
||
pos_match = re.match(
|
||
r"^(?:必须|需要|务必|一定要|must\s+|should\s+|shall\s+)\s*(.+)",
|
||
constraint_lower,
|
||
)
|
||
if pos_match:
|
||
# 肯定约束:输出中应包含该内容
|
||
required = pos_match.group(1).strip()
|
||
if required not in content_lower:
|
||
violations.append(constraint)
|
||
continue
|
||
|
||
# 默认:简单子串匹配
|
||
if constraint_lower in content_lower:
|
||
violations.append(constraint)
|
||
|
||
return violations
|
||
|
||
@staticmethod
|
||
def _is_positive_mention(content: str, keyword: str) -> bool:
|
||
"""判断 keyword 在 content 中是否为肯定性提及(实际执行/输出)
|
||
|
||
如果 keyword 出现在否定语境中(如"我们不会存储X"),不算违规。
|
||
遍历所有出现位置,只要有一次肯定性提及即返回 True。
|
||
"""
|
||
start = 0
|
||
while True:
|
||
idx = content.find(keyword, start)
|
||
if idx == -1:
|
||
return False
|
||
|
||
# 检查 keyword 前面是否有否定词
|
||
prefix = content[max(0, idx - 20) : idx]
|
||
neg_prefixes = [
|
||
"不会", "不能", "不要", "没有", "并未", "并未", "无法",
|
||
"won't", "don't", "not ", "never ", "no ",
|
||
]
|
||
is_negated = any(neg in prefix for neg in neg_prefixes)
|
||
|
||
if not is_negated:
|
||
return True
|
||
|
||
# 继续搜索下一个出现位置
|
||
start = idx + len(keyword)
|
||
|
||
@staticmethod
|
||
def _extract_text(output: dict[str, object]) -> str:
|
||
"""从 output dict 中提取所有文本内容"""
|
||
parts: list[str] = []
|
||
for value in output.values():
|
||
if isinstance(value, str):
|
||
parts.append(value)
|
||
else:
|
||
parts.append(str(value))
|
||
return " ".join(parts)
|
||
|
||
async def _llm_check(
|
||
self, output: dict[str, object], constraints: list[str]
|
||
) -> AlignmentCheckResult:
|
||
"""LLM 语义检查"""
|
||
content = self._extract_text(output)
|
||
constraint_text = "\n".join(f"- {c}" for c in constraints)
|
||
messages = [
|
||
{
|
||
"role": "system",
|
||
"content": (
|
||
"You are an alignment auditor. Check if the following output "
|
||
"violates any of the listed constraints. "
|
||
"Reply with 'PASS' if no violations, or list the violated constraints."
|
||
),
|
||
},
|
||
{
|
||
"role": "user",
|
||
"content": (
|
||
f"Constraints:\n{constraint_text}\n\nOutput:\n{content}"
|
||
),
|
||
},
|
||
]
|
||
try:
|
||
response = await self._llm_gateway.chat(
|
||
messages=messages, model=self._config.audit_model
|
||
)
|
||
reply = response.content.strip()
|
||
if reply.upper().startswith("PASS"):
|
||
return AlignmentCheckResult(passed=True, checked_by="llm")
|
||
else:
|
||
return AlignmentCheckResult(
|
||
passed=False,
|
||
violations=[reply],
|
||
checked_by="llm",
|
||
)
|
||
except Exception as e:
|
||
logger.warning(f"LLM audit failed: {e}")
|
||
return AlignmentCheckResult(
|
||
passed=True,
|
||
violations=[f"LLM audit unavailable (delegated to rule check): {e}"],
|
||
checked_by="rule",
|
||
)
|
||
|
||
def record_interaction(self, session_id: str) -> CascadeAlert | None:
|
||
"""记录一次 agent 间交互,超过阈值返回 CascadeAlert"""
|
||
self._interaction_counts[session_id] = (
|
||
self._interaction_counts.get(session_id, 0) + 1
|
||
)
|
||
count = self._interaction_counts[session_id]
|
||
if count > self._config.cascade_max_interactions:
|
||
return CascadeAlert(
|
||
session_id=session_id,
|
||
alert_type="interaction_limit",
|
||
current_value=count,
|
||
threshold=self._config.cascade_max_interactions,
|
||
message=(
|
||
f"Session {session_id} exceeded max interactions: "
|
||
f"{count} > {self._config.cascade_max_interactions}"
|
||
),
|
||
)
|
||
return None
|
||
|
||
def record_loop_depth(self, session_id: str, depth: int) -> CascadeAlert | None:
|
||
"""记录循环深度,超过阈值返回 CascadeAlert"""
|
||
self._loop_depths[session_id] = depth
|
||
if depth > self._config.cascade_max_depth:
|
||
return CascadeAlert(
|
||
session_id=session_id,
|
||
alert_type="loop_depth",
|
||
current_value=depth,
|
||
threshold=self._config.cascade_max_depth,
|
||
message=(
|
||
f"Session {session_id} exceeded max loop depth: "
|
||
f"{depth} > {self._config.cascade_max_depth}"
|
||
),
|
||
)
|
||
return None
|
||
|
||
def reset_session(self, session_id: str) -> None:
|
||
"""重置某个 session 的交互计数"""
|
||
self._interaction_counts.pop(session_id, None)
|
||
self._loop_depths.pop(session_id, None)
|
||
|
||
def get_interaction_count(self, session_id: str) -> int:
|
||
"""获取某个 session 的当前交互计数"""
|
||
return self._interaction_counts.get(session_id, 0)
|