From 79eb8469f97f33e676e077e5630e107658a83982 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Thu, 11 Jun 2026 00:14:11 +0800 Subject: [PATCH] fix: address remaining code review issues - AlignmentGuard: direction-aware constraint checking (negation/affirmation detection) instead of simple substring matching to reduce false positives - Reflexion: extract actual token usage from LLM response instead of hardcoded 1 - MemoryTool: protect version/history sections from update_soul modification - Fix AsyncMock warnings for sync find_best_agent method --- src/agentkit/core/reflexion.py | 25 +++++++++- src/agentkit/quality/alignment.py | 69 ++++++++++++++++++++++++++-- src/agentkit/tools/memory_tool.py | 4 ++ tests/unit/test_cost_aware_router.py | 4 +- 4 files changed, 95 insertions(+), 7 deletions(-) diff --git a/src/agentkit/core/reflexion.py b/src/agentkit/core/reflexion.py index 6571d2b..2b26e54 100644 --- a/src/agentkit/core/reflexion.py +++ b/src/agentkit/core/reflexion.py @@ -237,7 +237,7 @@ class ReflexionEngine: agent_name=agent_name, task_type=task_type, ) - total_tokens += 1 # approximate token cost for evaluation call + total_tokens += self._extract_usage_tokens(react_result) # Track best result if score > best_score: @@ -269,7 +269,7 @@ class ReflexionEngine: agent_name=agent_name, task_type=task_type, ) - total_tokens += 1 # approximate token cost for reflection call + total_tokens += self._extract_usage_tokens(react_result) if reflection_text is None: # 反思失败,返回当前最佳结果 @@ -672,6 +672,27 @@ class ReflexionEngine: logger.warning(f"Reflection LLM call failed, skipping reflection: {e}") return None + @staticmethod + def _extract_usage_tokens(result: ReActResult) -> int: + """从 LLM 响应中提取实际 token 用量,降级时估算 + + 尝试从 ReActResult 的 trajectory 中获取最后一步的 usage 信息。 + 如果不可用,基于输出长度估算。 + """ + # 尝试从 trajectory 中获取 usage + if result.trajectory: + last_step = result.trajectory[-1] + # ReActStep 可能携带 usage 信息 + usage = getattr(last_step, "usage", None) or getattr(last_step, "token_usage", None) + if usage and isinstance(usage, dict): + total = usage.get("total_tokens", 0) + if total > 0: + return total + + # 降级:基于输出长度估算(约 4 字符 = 1 token) + estimated = max(1, len(result.output) // 4) + return estimated + def _build_reflection_prompt( self, original_prompt: str | None, diff --git a/src/agentkit/quality/alignment.py b/src/agentkit/quality/alignment.py index 66b82de..86faddd 100644 --- a/src/agentkit/quality/alignment.py +++ b/src/agentkit/quality/alignment.py @@ -102,15 +102,78 @@ class AlignmentGuard: def _rule_check( self, output: dict[str, Any], constraints: list[str] ) -> list[str]: - """基于规则的约束检查:将 output 内容拼接后做关键词/子串匹配""" + """基于规则的约束检查:方向性判断,区分'禁止X'和'提及X' + + 约束格式: + - "不要X" / "禁止X" / "不得X" / "不能X" / "no X" / "don't X" → 输出中不应执行X + - "必须X" / "需要X" / "务必X" / "must X" / "should X" → 输出中应包含X + - 其他 → 简单子串匹配(约束关键词出现在输出中即违规) + """ + import re + content = self._extract_text(output) + content_lower = content.lower() violations: list[str] = [] + for constraint in constraints: - # 简单子串匹配:约束关键词出现在输出中即视为违规 - if constraint.lower() in content.lower(): + constraint_lower = constraint.lower().strip() + + # 检测否定约束:"不要X" / "禁止X" / "不得X" / "不能X" / "no X" / "don't X" + neg_match = re.match( + r"^(?:不要|禁止|不得|不能|不可以|别|no\s+|don'?t\s+|never\s+|must\s+not\s+|should\s+not\s+)\s*(.+)", + constraint_lower, + ) + if neg_match: + # 否定约束:提取被禁止的内容关键词 + forbidden = neg_match.group(1).strip() + # 只有当输出中实际执行了该行为时才判违规(而非仅仅提及) + # 简单启发式:如果输出中包含"执行/输出/提供了 + forbidden"则违规 + # 更安全的做法:如果 forbidden 直接出现在输出中且不是否定语境 + if self._is_positive_mention(content_lower, forbidden): + violations.append(constraint) + continue + + # 检测肯定约束:"必须X" / "需要X" / "务必X" / "must X" / "should X" + pos_match = re.match( + r"^(?:必须|需要|务必|一定要|must\s+|should\s+|shall\s+)\s*(.+)", + constraint_lower, + ) + if pos_match: + # 肯定约束:输出中应包含该内容 + required = pos_match.group(1).strip() + if required not in content_lower: + violations.append(constraint) + continue + + # 默认:简单子串匹配 + if constraint_lower in content_lower: violations.append(constraint) + return violations + @staticmethod + def _is_positive_mention(content: str, keyword: str) -> bool: + """判断 keyword 在 content 中是否为肯定性提及(实际执行/输出) + + 如果 keyword 出现在否定语境中(如"我们不会存储X"),不算违规。 + """ + # 找到 keyword 在 content 中的位置 + idx = content.find(keyword) + if idx == -1: + return False + + # 检查 keyword 前面是否有否定词 + prefix = content[max(0, idx - 20) : idx] + neg_prefixes = [ + "不会", "不能", "不要", "没有", "并未", "并未", "无法", + "won't", "don't", "not ", "never ", "no ", + ] + for neg in neg_prefixes: + if neg in prefix: + return False + + return True + @staticmethod def _extract_text(output: dict[str, Any]) -> str: """从 output dict 中提取所有文本内容""" diff --git a/src/agentkit/tools/memory_tool.py b/src/agentkit/tools/memory_tool.py index 1901682..448054a 100644 --- a/src/agentkit/tools/memory_tool.py +++ b/src/agentkit/tools/memory_tool.py @@ -22,6 +22,8 @@ from agentkit.tools.base import Tool VALID_FILES = {"soul", "user", "memory", "daily"} VALID_ACTIONS = {"add", "replace", "remove", "read", "update_soul"} +# 受保护的 section:不允许通过 update_soul 修改,防止 Agent 覆盖版本追踪元数据 +PROTECTED_SOUL_SECTIONS = {"版本", "更新历史"} class MemoryTool(Tool): @@ -126,6 +128,8 @@ class MemoryTool(Tool): return {"success": False, "error": "section is required for update_soul action"} if not content: return {"success": False, "error": "content is required for update_soul action"} + if section in PROTECTED_SOUL_SECTIONS: + return {"success": False, "error": f"Cannot modify protected section '{section}' via update_soul"} return await self._update_soul(mf, section, content, reason) return {"success": False, "error": f"Unhandled action: {action}"} diff --git a/tests/unit/test_cost_aware_router.py b/tests/unit/test_cost_aware_router.py index 3f0326f..50a29f2 100644 --- a/tests/unit/test_cost_aware_router.py +++ b/tests/unit/test_cost_aware_router.py @@ -328,7 +328,7 @@ class TestLayer2CapabilityMatching: """org_context.find_best_agent 返回 None 时回退到 IntentRouter""" gateway = _make_llm_gateway(json.dumps({"complexity": 0.8})) org_context = MagicMock() - org_context.find_best_agent = AsyncMock(return_value=None) + org_context.find_best_agent = MagicMock(return_value=None) search_skill = _make_skill("search", keywords=["调研"], description="市场调研") registry = _make_skill_registry([search_skill]) @@ -403,7 +403,7 @@ class TestTransparency: """TRACE 模式显示完整路由追踪""" gateway = _make_llm_gateway(json.dumps({"complexity": 0.85})) org_context = MagicMock() - org_context.find_best_agent = AsyncMock(return_value="analyst") + org_context.find_best_agent = MagicMock(return_value="analyst") router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context) result = await router.route(