fischer-agentkit/src/agentkit/quality/alignment.py

304 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""AlignmentGuard - 对齐守卫:约束注入 + 级联故障检测"""
from __future__ import annotations

import logging
import random
import re
from dataclasses import dataclass, field

from agentkit.telemetry.tracer import get_tracer
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
@dataclass
class AlignmentConfig:
    """Settings for the alignment guard (constraints, cascade limits, LLM audit)."""

    # Global constraint strings injected into task input and checked on output.
    constraints: list[str] = field(default_factory=list)
    # A session alerts once its interaction count exceeds this value.
    cascade_max_interactions: int = 10
    # A session alerts once a reported loop depth exceeds this value.
    cascade_max_depth: int = 3
    # Enables the LLM semantic audit (an LLM gateway must also be supplied).
    audit_enabled: bool = False
    # Model name handed to the gateway for the audit call.
    audit_model: str = "default"
    # Audit sampling rate in [0.0, 1.0]; 1.0 audits every check.
    audit_sample_rate: float = 1.0

    @classmethod
    def from_dict(cls, data: dict[str, object]) -> "AlignmentConfig":
        """Build a config from *data*, silently dropping keys that are not fields."""
        accepted = cls.__dataclass_fields__.keys()
        kwargs = {}
        for key, value in data.items():
            if key in accepted:
                kwargs[key] = value
        return cls(**kwargs)
@dataclass
class AlignmentCheckResult:
    """Outcome of one alignment check."""

    # Whether the output is considered acceptable.
    passed: bool
    # Violated constraints, or a free-text report from the checker.
    violations: list[str] = field(default_factory=list)
    # Which checker produced the verdict: "rule" or "llm".
    checked_by: str = ""
@dataclass
class CascadeAlert:
    """Alert returned when a session crosses a cascade-failure threshold."""

    # Session the alert refers to.
    session_id: str
    # Which limit tripped: "interaction_limit" or "loop_depth".
    alert_type: str
    # Observed value that exceeded the threshold.
    current_value: int
    # Configured limit that was exceeded.
    threshold: int
    # Human-readable description of the alert.
    message: str
class ConstraintInjector:
    """Attach the configured global constraints to task input."""

    def __init__(self, config: AlignmentConfig):
        self._config = config

    def inject(self, input_data: dict[str, object]) -> dict[str, object]:
        """Return a copy of *input_data* with an ``alignment_constraints`` entry.

        The entry is a fresh list copy of the configured constraints. The
        caller's *input_data* dict is never modified.
        """
        enriched = {**input_data}
        enriched["alignment_constraints"] = list(self._config.constraints)
        return enriched
class AlignmentGuard:
    """Alignment guard: constraint injection plus cascade-failure detection.

    Builds on the quality-gate idea with three responsibilities:

    * ``inject_constraints`` attaches the configured global constraints to
      task input (delegating to ``ConstraintInjector``);
    * ``check_output`` validates task output against constraints, first with
      keyword rules and then, when enabled and sampled, with an LLM audit;
    * ``record_interaction`` / ``record_loop_depth`` track per-session values
      and return a ``CascadeAlert`` once a configured threshold is exceeded.
    """

    # Prohibition prefixes: "don't X" means the output must not *do* X.
    # Every "<modal> not/never" form is listed here so it is claimed before
    # _POSITIVE_RE is tried; otherwise "must never X" or "shall not X" would be
    # parsed as the requirement "output must contain 'never x' / 'not x'".
    _NEGATIVE_RE = re.compile(
        r"^(?:不要|禁止|不得|不能|不可以|严禁|切勿|请勿|别"
        r"|no\s+|never\s+|do\s+not\s+|don[\u2019']?t\s+"
        r"|cannot\s+|can[\u2019']t\s+"
        r"|(?:must|should|shall)\s+(?:not|never)\s+"
        r"|(?:must|should)n[\u2019']?t\s+)"
        r"\s*(.+)"
    )
    # Requirement prefixes: "must X" means the output must contain X.
    _POSITIVE_RE = re.compile(
        r"^(?:必须|需要|务必|一定要|must\s+|should\s+|shall\s+)\s*(.+)"
    )
    # Cues that, found shortly before a forbidden keyword, mark that occurrence
    # as negated ("we will not store X") rather than actually performed.
    _NEGATION_CUES = (
        "不会", "不能", "不要", "没有", "并未", "无法",
        "won't", "don't", "won\u2019t", "don\u2019t",
        "not ", "never ", "no ",
    )
    # Number of characters before a keyword that are scanned for a cue.
    _NEGATION_WINDOW = 20

    def __init__(self, config: AlignmentConfig, llm_gateway=None):
        """Create a guard.

        Args:
            config: Constraints, cascade thresholds and audit settings.
            llm_gateway: Optional gateway for the LLM audit. It is called as
                ``await llm_gateway.chat(messages=..., model=...)``, and the
                result's ``content`` attribute is read. ``None`` disables the
                audit.
        """
        self._config = config
        self._injector = ConstraintInjector(config)
        self._llm_gateway = llm_gateway
        # Per-session state; entries are only removed by reset_session().
        self._interaction_counts: dict[str, int] = {}
        self._loop_depths: dict[str, int] = {}

    def inject_constraints(self, input_data: dict[str, object]) -> dict[str, object]:
        """Return a copy of *input_data* with the global constraints attached."""
        return self._injector.inject(input_data)

    async def check_output(
        self,
        output: dict[str, object],
        constraints: list[str] | None = None,
    ) -> AlignmentCheckResult:
        """Check whether *output* satisfies the constraints.

        Args:
            output: Task output. All of its values are stringified and joined
                before matching.
            constraints: Constraints to enforce. ``None`` means the configured
                global constraints. An empty list always passes (no span is
                opened).

        Returns:
            An ``AlignmentCheckResult``. Rule violations short-circuit the LLM
            audit. The audit only runs when ``audit_enabled`` is set, a gateway
            is available, and the ``audit_sample_rate`` draw hits.
        """
        effective_constraints = (
            constraints if constraints is not None else self._config.constraints
        )
        if not effective_constraints:
            return AlignmentCheckResult(passed=True, checked_by="rule")

        tracer = get_tracer()
        with tracer.start_span("guard.check") as span:
            span.set_attribute("guard.constraints_count", len(effective_constraints))
            result = await self._run_checks(output, effective_constraints)
            # Recorded once here instead of on every return path.
            span.set_attribute("guard.passed", result.passed)
            span.set_attribute("guard.checked_by", result.checked_by)
            return result

    async def _run_checks(
        self, output: dict[str, object], constraints: list[str]
    ) -> AlignmentCheckResult:
        """Run the rule check and, if it passes, the sampled LLM audit."""
        violations = self._rule_check(output, constraints)
        if violations:
            return AlignmentCheckResult(
                passed=False, violations=violations, checked_by="rule"
            )

        if self._config.audit_enabled and self._llm_gateway is not None:
            # random() is in [0, 1): a rate of 1.0 always audits, 0.0 never does.
            if random.random() < self._config.audit_sample_rate:
                return await self._llm_check(output, constraints)
            # Sample missed: trust the (passing) rule check.
            logger.debug(
                "LLM audit skipped (sample rate=%.2f)", self._config.audit_sample_rate
            )

        return AlignmentCheckResult(passed=True, checked_by="rule")

    def _rule_check(
        self, output: dict[str, object], constraints: list[str]
    ) -> list[str]:
        """Check constraints with keyword rules, telling "forbid X" from "mention X".

        Each constraint is lower-cased, stripped, and classified:

        * prohibition ("不要X", "禁止X", "no X", "don't X", "do not X",
          "must not X", "shall never X", ...): violated only if X appears in
          the output outside a negated context;
        * requirement ("必须X", "需要X", "must X", "should X", ...): violated if
          X does not appear in the output;
        * anything else: violated if the whole constraint text appears verbatim.

        Blank constraints are ignored. Violated constraints are returned in
        their original (unnormalized) form, in input order.
        """
        content = self._extract_text(output).lower()
        violations: list[str] = []
        for constraint in constraints:
            normalized = constraint.lower().strip()
            if not normalized:
                # "" is a substring of everything; without this guard a blank
                # entry would fail every output.
                continue

            prohibition = self._NEGATIVE_RE.match(normalized)
            if prohibition:
                forbidden = prohibition.group(1).strip()
                if self._is_positive_mention(content, forbidden):
                    violations.append(constraint)
                continue

            requirement = self._POSITIVE_RE.match(normalized)
            if requirement:
                required = requirement.group(1).strip()
                if required not in content:
                    violations.append(constraint)
                continue

            if normalized in content:
                violations.append(constraint)
        return violations

    @staticmethod
    def _is_positive_mention(content: str, keyword: str) -> bool:
        """Return True if *keyword* occurs in *content* outside a negated context.

        An occurrence counts as negated when one of ``_NEGATION_CUES`` appears
        in the ``_NEGATION_WINDOW`` characters just before it (e.g. "we will
        not store X"). Every occurrence is examined, and a single non-negated
        hit is enough. The cues are lower-case, so callers pass lower-cased
        text.
        """
        cues = AlignmentGuard._NEGATION_CUES
        window = AlignmentGuard._NEGATION_WINDOW
        start = 0
        while True:
            idx = content.find(keyword, start)
            if idx == -1:
                return False
            prefix = content[max(0, idx - window) : idx]
            if not any(cue in prefix for cue in cues):
                return True
            # Negated occurrence: keep searching after it.
            start = idx + len(keyword)

    @staticmethod
    def _extract_text(output: dict[str, object]) -> str:
        """Join all values of *output* into one space-separated string.

        Non-string values are converted with ``str()``.
        """
        return " ".join(
            value if isinstance(value, str) else str(value)
            for value in output.values()
        )

    async def _llm_check(
        self, output: dict[str, object], constraints: list[str]
    ) -> AlignmentCheckResult:
        """Ask the LLM gateway whether *output* violates any of *constraints*.

        A reply starting with "PASS" (case-insensitive) passes. Any other reply
        fails and is returned verbatim as the single violation.

        If the gateway call or reply handling raises, the check fails *open*.
        It returns ``passed=True`` with ``checked_by="rule"``, because the rule
        check has already passed, and records the error as an informational
        entry in ``violations``.
        """
        content = self._extract_text(output)
        constraint_text = "\n".join(f"- {c}" for c in constraints)
        messages = [
            {
                "role": "system",
                "content": (
                    "You are an alignment auditor. Check if the following output "
                    "violates any of the listed constraints. "
                    "Reply with 'PASS' if no violations, or list the violated constraints."
                ),
            },
            {
                "role": "user",
                "content": (
                    f"Constraints:\n{constraint_text}\n\nOutput:\n{content}"
                ),
            },
        ]
        try:
            response = await self._llm_gateway.chat(
                messages=messages, model=self._config.audit_model
            )
            reply = response.content.strip()
            if reply.upper().startswith("PASS"):
                return AlignmentCheckResult(passed=True, checked_by="llm")
            return AlignmentCheckResult(
                passed=False,
                violations=[reply],
                checked_by="llm",
            )
        except Exception as e:  # deliberate fail-open boundary around the external call
            logger.warning("LLM audit failed: %s", e)
            return AlignmentCheckResult(
                passed=True,
                violations=[f"LLM audit unavailable (delegated to rule check): {e}"],
                checked_by="rule",
            )

    def record_interaction(self, session_id: str) -> CascadeAlert | None:
        """Count one agent-to-agent interaction for *session_id*.

        Returns a ``CascadeAlert`` while the session's count is above
        ``cascade_max_interactions`` (so every call past the limit alerts),
        otherwise ``None``.
        """
        count = self._interaction_counts.get(session_id, 0) + 1
        self._interaction_counts[session_id] = count
        limit = self._config.cascade_max_interactions
        if count > limit:
            return CascadeAlert(
                session_id=session_id,
                alert_type="interaction_limit",
                current_value=count,
                threshold=limit,
                message=(
                    f"Session {session_id} exceeded max interactions: "
                    f"{count} > {limit}"
                ),
            )
        return None

    def record_loop_depth(self, session_id: str, depth: int) -> CascadeAlert | None:
        """Store the latest loop *depth* for *session_id*.

        Returns a ``CascadeAlert`` when *depth* exceeds ``cascade_max_depth``,
        otherwise ``None``.
        """
        self._loop_depths[session_id] = depth
        limit = self._config.cascade_max_depth
        if depth > limit:
            return CascadeAlert(
                session_id=session_id,
                alert_type="loop_depth",
                current_value=depth,
                threshold=limit,
                message=(
                    f"Session {session_id} exceeded max loop depth: "
                    f"{depth} > {limit}"
                ),
            )
        return None

    def reset_session(self, session_id: str) -> None:
        """Forget the interaction count and loop depth of *session_id*."""
        self._interaction_counts.pop(session_id, None)
        self._loop_depths.pop(session_id, None)

    def get_interaction_count(self, session_id: str) -> int:
        """Return the current interaction count of *session_id* (0 if unknown)."""
        return self._interaction_counts.get(session_id, 0)