fix: address remaining code review issues

- AlignmentGuard: direction-aware constraint checking (negation/affirmation detection)
  instead of simple substring matching to reduce false positives
- Reflexion: extract actual token usage from LLM response instead of hardcoded 1
- MemoryTool: protect version/history sections from update_soul modification
- Fix AsyncMock warnings for sync find_best_agent method
This commit is contained in:
chiguyong 2026-06-11 00:14:11 +08:00
parent 5171e942d6
commit 79eb8469f9
4 changed files with 95 additions and 7 deletions

View File

@ -237,7 +237,7 @@ class ReflexionEngine:
agent_name=agent_name,
task_type=task_type,
)
total_tokens += 1 # approximate token cost for evaluation call
total_tokens += self._extract_usage_tokens(react_result)
# Track best result
if score > best_score:
@ -269,7 +269,7 @@ class ReflexionEngine:
agent_name=agent_name,
task_type=task_type,
)
total_tokens += 1 # approximate token cost for reflection call
total_tokens += self._extract_usage_tokens(react_result)
if reflection_text is None:
# 反思失败,返回当前最佳结果
@ -672,6 +672,27 @@ class ReflexionEngine:
logger.warning(f"Reflection LLM call failed, skipping reflection: {e}")
return None
@staticmethod
def _extract_usage_tokens(result: ReActResult) -> int:
"""从 LLM 响应中提取实际 token 用量,降级时估算
尝试从 ReActResult trajectory 中获取最后一步的 usage 信息
如果不可用基于输出长度估算
"""
# 尝试从 trajectory 中获取 usage
if result.trajectory:
last_step = result.trajectory[-1]
# ReActStep 可能携带 usage 信息
usage = getattr(last_step, "usage", None) or getattr(last_step, "token_usage", None)
if usage and isinstance(usage, dict):
total = usage.get("total_tokens", 0)
if total > 0:
return total
# 降级:基于输出长度估算(约 4 字符 = 1 token
estimated = max(1, len(result.output) // 4)
return estimated
def _build_reflection_prompt(
self,
original_prompt: str | None,

View File

@ -102,15 +102,78 @@ class AlignmentGuard:
def _rule_check(
    self, output: dict[str, Any], constraints: list[str]
) -> list[str]:
    """Check the output against constraints using direction-aware rules.

    Each constraint is lower-cased, stripped and classified by its prefix:

    - Negative ("不要X" / "禁止X" / "不得X" / "不能X" / "no X" / "don't X" /
      "do not X" / "never X" / "must not X" / "should not X" / "cannot X"):
      violated when X is mentioned in the output outside a negated context,
      as decided by ``_is_positive_mention``.
    - Affirmative ("必须X" / "需要X" / "务必X" / "must X" / "should X" /
      "shall X"): violated when X does not appear in the output.
    - Anything else: violated when the constraint text itself appears in the
      output (plain substring match).

    Args:
        output: Structured output whose text is obtained through
            ``self._extract_text``.
        constraints: Constraint strings to evaluate. Blank entries are
            ignored.

    Returns:
        The violated constraints, in input order, as originally passed in.
    """
    import re

    # Compiled once per call rather than once per constraint.
    negative_rule = re.compile(
        r"^(?:不要|禁止|不得|不能|不可以|别|no\s+|don'?t\s+|do\s+not\s+|never\s+"
        r"|must\s+not\s+|should\s+not\s+|cannot\s+|can'?t\s+)\s*(.+)"
    )
    affirmative_rule = re.compile(
        r"^(?:必须|需要|务必|一定要|must\s+|should\s+|shall\s+)\s*(.+)"
    )

    content_lower = self._extract_text(output).lower()
    violations: list[str] = []

    for constraint in constraints:
        constraint_lower = constraint.lower().strip()
        if not constraint_lower:
            # An empty string is a substring of everything and would
            # otherwise always be reported as a violation.
            continue

        # Negative rules must be tried first: "must not X" also starts with
        # the affirmative prefix "must ".
        neg_match = negative_rule.match(constraint_lower)
        if neg_match:
            forbidden = neg_match.group(1).strip()
            # Merely mentioning the forbidden item in a negated sentence
            # ("we will not store X") is not a violation.
            if self._is_positive_mention(content_lower, forbidden):
                violations.append(constraint)
            continue

        pos_match = affirmative_rule.match(constraint_lower)
        if pos_match:
            required = pos_match.group(1).strip()
            if required not in content_lower:
                violations.append(constraint)
            continue

        # Default: plain substring match on the whole constraint.
        if constraint_lower in content_lower:
            violations.append(constraint)

    return violations
@staticmethod
def _is_positive_mention(content: str, keyword: str) -> bool:
"""判断 keyword 在 content 中是否为肯定性提及(实际执行/输出)
如果 keyword 出现在否定语境中"我们不会存储X"不算违规
"""
# 找到 keyword 在 content 中的位置
idx = content.find(keyword)
if idx == -1:
return False
# 检查 keyword 前面是否有否定词
prefix = content[max(0, idx - 20) : idx]
neg_prefixes = [
"不会", "不能", "不要", "没有", "并未", "并未", "无法",
"won't", "don't", "not ", "never ", "no ",
]
for neg in neg_prefixes:
if neg in prefix:
return False
return True
@staticmethod
def _extract_text(output: dict[str, Any]) -> str:
"""从 output dict 中提取所有文本内容"""

View File

@ -22,6 +22,8 @@ from agentkit.tools.base import Tool
# Identifiers of the memory files (presumably checked against the tool's file argument; confirm against the caller).
VALID_FILES = {"soul", "user", "memory", "daily"}
# Action names recognised by the tool.
VALID_ACTIONS = {"add", "replace", "remove", "read", "update_soul"}
# Protected sections: they may not be modified through update_soul, so that an
# Agent cannot overwrite version-tracking metadata.
# ("版本" means "version"; "更新历史" means "update history".)
PROTECTED_SOUL_SECTIONS = {"版本", "更新历史"}
class MemoryTool(Tool):
@ -126,6 +128,8 @@ class MemoryTool(Tool):
return {"success": False, "error": "section is required for update_soul action"}
if not content:
return {"success": False, "error": "content is required for update_soul action"}
if section in PROTECTED_SOUL_SECTIONS:
return {"success": False, "error": f"Cannot modify protected section '{section}' via update_soul"}
return await self._update_soul(mf, section, content, reason)
return {"success": False, "error": f"Unhandled action: {action}"}

View File

@ -328,7 +328,7 @@ class TestLayer2CapabilityMatching:
"""org_context.find_best_agent 返回 None 时回退到 IntentRouter"""
gateway = _make_llm_gateway(json.dumps({"complexity": 0.8}))
org_context = MagicMock()
org_context.find_best_agent = AsyncMock(return_value=None)
org_context.find_best_agent = MagicMock(return_value=None)
search_skill = _make_skill("search", keywords=["调研"], description="市场调研")
registry = _make_skill_registry([search_skill])
@ -403,7 +403,7 @@ class TestTransparency:
"""TRACE 模式显示完整路由追踪"""
gateway = _make_llm_gateway(json.dumps({"complexity": 0.85}))
org_context = MagicMock()
org_context.find_best_agent = AsyncMock(return_value="analyst")
org_context.find_best_agent = MagicMock(return_value="analyst")
router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)
result = await router.route(