fix: address remaining code review issues
- AlignmentGuard: direction-aware constraint checking (negation/affirmation detection) instead of simple substring matching to reduce false positives - Reflexion: extract actual token usage from LLM response instead of hardcoded 1 - MemoryTool: protect version/history sections from update_soul modification - Fix AsyncMock warnings for sync find_best_agent method
This commit is contained in:
parent
5171e942d6
commit
79eb8469f9
|
|
@ -237,7 +237,7 @@ class ReflexionEngine:
|
|||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
)
|
||||
total_tokens += 1 # approximate token cost for evaluation call
|
||||
total_tokens += self._extract_usage_tokens(react_result)
|
||||
|
||||
# Track best result
|
||||
if score > best_score:
|
||||
|
|
@ -269,7 +269,7 @@ class ReflexionEngine:
|
|||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
)
|
||||
total_tokens += 1 # approximate token cost for reflection call
|
||||
total_tokens += self._extract_usage_tokens(react_result)
|
||||
|
||||
if reflection_text is None:
|
||||
# 反思失败,返回当前最佳结果
|
||||
|
|
@ -672,6 +672,27 @@ class ReflexionEngine:
|
|||
logger.warning(f"Reflection LLM call failed, skipping reflection: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
def _extract_usage_tokens(result: ReActResult) -> int:
    """Return the token count to attribute to an LLM call.

    The last step of ``result.trajectory`` is inspected for a ``usage``
    attribute (or, failing that, ``token_usage``). When that attribute is
    a dict whose ``total_tokens`` entry is a positive number, that value
    is returned. Otherwise the count is estimated from the length of
    ``result.output`` at roughly four characters per token.

    Args:
        result: Object exposing ``trajectory`` (a sequence of steps) and
            ``output`` (the produced text, possibly empty or ``None``).

    Returns:
        A token count that is always at least 1.
    """
    if result.trajectory:
        last_step = result.trajectory[-1]
        usage = getattr(last_step, "usage", None) or getattr(
            last_step, "token_usage", None
        )
        if isinstance(usage, dict):
            total = usage.get("total_tokens")
            # Providers may report ``None`` or a non-numeric placeholder;
            # comparing those with ``> 0`` would raise, so validate the
            # type first. ``bool`` is excluded because it subclasses int.
            if (
                isinstance(total, (int, float))
                and not isinstance(total, bool)
                and total > 0
            ):
                return int(total)

    # Fallback estimate: about 4 characters per token, never below 1.
    # A missing (None) or empty output still costs the minimum of 1.
    output = result.output
    if not output:
        return 1
    return max(1, len(output) // 4)
|
||||
|
||||
def _build_reflection_prompt(
|
||||
self,
|
||||
original_prompt: str | None,
|
||||
|
|
|
|||
|
|
@ -102,15 +102,78 @@ class AlignmentGuard:
|
|||
def _rule_check(
    self, output: dict[str, Any], constraints: list[str]
) -> list[str]:
    """Rule-based, direction-aware constraint check.

    Each constraint is lower-cased, stripped and classified by its prefix:

    - Prohibitions ("不要X" / "禁止X" / "不得X" / "不能X" / "no X" /
      "don't X" / "never X" / "must not X" / "should not X"): violated
      when ``self._is_positive_mention`` reports X as affirmatively
      mentioned in the output text.
    - Requirements ("必须X" / "需要X" / "务必X" / "一定要X" / "must X" /
      "should X" / "shall X"): violated when X is absent from the output
      text.
    - Anything else: plain substring match, where the constraint text
      appearing in the output counts as a violation.

    Empty or whitespace-only constraints are ignored; the empty string is
    a substring of every text, so they would otherwise always be flagged.

    Args:
        output: Agent output; its text is obtained via
            ``self._extract_text``.
        constraints: Constraint strings to check against the output.

    Returns:
        The constraints that were violated, in input order, with their
        original spelling.
    """
    import re

    # Compile once per call instead of once per constraint.
    # The prohibition pattern is tested first so that "must not X" and
    # "should not X" are not misread as the requirements "must X" / "should X".
    prohibition_re = re.compile(
        r"^(?:不要|禁止|不得|不能|不可以|别|no\s+|don'?t\s+|never\s+|must\s+not\s+|should\s+not\s+)\s*(.+)"
    )
    requirement_re = re.compile(
        r"^(?:必须|需要|务必|一定要|must\s+|should\s+|shall\s+)\s*(.+)"
    )

    content_lower = self._extract_text(output).lower()
    violations: list[str] = []

    for constraint in constraints:
        constraint_lower = constraint.lower().strip()
        if not constraint_lower:
            # "" is contained in every string; skip instead of flagging.
            continue

        prohibition = prohibition_re.match(constraint_lower)
        if prohibition:
            forbidden = prohibition.group(1).strip()
            # Flag only when the forbidden content is affirmatively
            # mentioned, not merely referenced in a negated context.
            if self._is_positive_mention(content_lower, forbidden):
                violations.append(constraint)
            continue

        requirement = requirement_re.match(constraint_lower)
        if requirement:
            required = requirement.group(1).strip()
            if required not in content_lower:
                violations.append(constraint)
            continue

        # Default: the constraint text appearing in the output is a violation.
        if constraint_lower in content_lower:
            violations.append(constraint)

    return violations
|
||||
|
||||
@staticmethod
def _is_positive_mention(content: str, keyword: str) -> bool:
    """Report whether ``keyword`` is affirmatively mentioned in ``content``.

    An occurrence counts as negated when one of a fixed set of negation
    markers (e.g. "不会", "没有", "won't", "not ") appears within the 20
    characters before it, as in "we won't store X". Every occurrence is
    examined; a single non-negated occurrence makes the result True, so a
    negated first mention cannot hide a later affirmative one.

    Args:
        content: Text to search. Callers are expected to pass it already
            lower-cased, since the markers are lower-case.
        keyword: Text whose mentions are examined.

    Returns:
        True if at least one occurrence of ``keyword`` has no negation
        marker in the 20 characters before it. False if ``keyword`` is
        empty, absent, or every occurrence is negated.
    """
    if not keyword:
        # str.find("") returns 0, which would otherwise look like an
        # affirmative mention at the very start of the text.
        return False

    negation_markers = (
        "不会", "不能", "不要", "没有", "并未", "无法",
        "won't", "don't", "not ", "never ", "no ",
    )

    position = content.find(keyword)
    while position != -1:
        # Only the 20 characters directly before the occurrence are checked.
        window = content[max(0, position - 20) : position]
        if not any(marker in window for marker in negation_markers):
            return True
        position = content.find(keyword, position + 1)

    return False
|
||||
|
||||
@staticmethod
|
||||
def _extract_text(output: dict[str, Any]) -> str:
|
||||
"""从 output dict 中提取所有文本内容"""
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ from agentkit.tools.base import Tool
|
|||
|
||||
VALID_FILES = {"soul", "user", "memory", "daily"}
|
||||
VALID_ACTIONS = {"add", "replace", "remove", "read", "update_soul"}
|
||||
# 受保护的 section:不允许通过 update_soul 修改,防止 Agent 覆盖版本追踪元数据
|
||||
PROTECTED_SOUL_SECTIONS = {"版本", "更新历史"}
|
||||
|
||||
|
||||
class MemoryTool(Tool):
|
||||
|
|
@ -126,6 +128,8 @@ class MemoryTool(Tool):
|
|||
return {"success": False, "error": "section is required for update_soul action"}
|
||||
if not content:
|
||||
return {"success": False, "error": "content is required for update_soul action"}
|
||||
if section in PROTECTED_SOUL_SECTIONS:
|
||||
return {"success": False, "error": f"Cannot modify protected section '{section}' via update_soul"}
|
||||
return await self._update_soul(mf, section, content, reason)
|
||||
|
||||
return {"success": False, "error": f"Unhandled action: {action}"}
|
||||
|
|
|
|||
|
|
@ -328,7 +328,7 @@ class TestLayer2CapabilityMatching:
|
|||
"""org_context.find_best_agent 返回 None 时回退到 IntentRouter"""
|
||||
gateway = _make_llm_gateway(json.dumps({"complexity": 0.8}))
|
||||
org_context = MagicMock()
|
||||
org_context.find_best_agent = AsyncMock(return_value=None)
|
||||
org_context.find_best_agent = MagicMock(return_value=None)
|
||||
|
||||
search_skill = _make_skill("search", keywords=["调研"], description="市场调研")
|
||||
registry = _make_skill_registry([search_skill])
|
||||
|
|
@ -403,7 +403,7 @@ class TestTransparency:
|
|||
"""TRACE 模式显示完整路由追踪"""
|
||||
gateway = _make_llm_gateway(json.dumps({"complexity": 0.85}))
|
||||
org_context = MagicMock()
|
||||
org_context.find_best_agent = AsyncMock(return_value="analyst")
|
||||
org_context.find_best_agent = MagicMock(return_value="analyst")
|
||||
|
||||
router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)
|
||||
result = await router.route(
|
||||
|
|
|
|||
Loading…
Reference in New Issue