refactor: simplify code across U1-U7 (bug fix + efficiency + reuse + quality)

This commit is contained in:
chiguyong 2026-06-24 22:35:52 +08:00
parent 0847c0e086
commit 567cbc9c9b
6 changed files with 71 additions and 72 deletions

View File

@ -98,6 +98,31 @@ def parse_skill_prefix(content: str) -> tuple[str | None, str]:
return explicit_skill, clean return explicit_skill, clean
_PROMPT_KEYS = ("identity", "context", "instructions", "constraints", "output_format")
def format_preconditions_block(preconditions: list[str], header_level: int = 2) -> str:
"""格式化激活前置条件文本块。"""
header = "#" * header_level
lines = [f"{header} Activation Preconditions", "Before executing this skill, verify:"]
lines.extend(f"- {p}" for p in preconditions)
lines.append(
"If any precondition is not met, refuse to execute or ask the user for clarification."
)
return "\n".join(lines)
def collect_prompt_parts(config: Any, with_headers: bool = False) -> list[str]:
"""从 skill config 的 prompt 字典中收集各部分文本。"""
prompt = config.prompt or {}
parts: list[str] = []
for key in _PROMPT_KEYS:
val = prompt.get(key)
if val:
parts.append(f"### {key.title()}\n{val}" if with_headers else val)
return parts
def build_skill_system_prompt(skill_config) -> str | None: def build_skill_system_prompt(skill_config) -> str | None:
"""Build system prompt from skill config's prompt section. """Build system prompt from skill config's prompt section.
@ -120,12 +145,7 @@ def build_skill_system_prompt(skill_config) -> str | None:
# 安全守卫:前置条件即使在概要模式下也必须注入 # 安全守卫:前置条件即使在概要模式下也必须注入
preconditions = getattr(skill_config, "preconditions", None) preconditions = getattr(skill_config, "preconditions", None)
if preconditions: if preconditions:
lines = ["## Activation Preconditions", "Before executing this skill, verify:"] return f"{summary}\n\n" + format_preconditions_block(preconditions)
lines.extend(f"- {p}" for p in preconditions)
lines.append(
"If any precondition is not met, refuse to execute or ask the user for clarification."
)
return f"{summary}\n\n" + "\n".join(lines)
# 提示 LLM 可通过 skill_detail 工具加载完整 instructions # 提示 LLM 可通过 skill_detail 工具加载完整 instructions
return ( return (
f"{summary}\n\n" f"{summary}\n\n"
@ -135,22 +155,12 @@ def build_skill_system_prompt(skill_config) -> str | None:
# Level 1+: 全量加载(现有行为) # Level 1+: 全量加载(现有行为)
if not skill_config.prompt: if not skill_config.prompt:
return None return None
prompt_parts = [] base = "\n\n".join(collect_prompt_parts(skill_config)) or None
for key in ("identity", "context", "instructions", "constraints", "output_format"):
val = skill_config.prompt.get(key)
if val:
prompt_parts.append(val)
base = "\n\n".join(prompt_parts) if prompt_parts else None
# v7: 注入激活前置条件(软检查) # v7: 注入激活前置条件(软检查)
preconditions = getattr(skill_config, "preconditions", None) preconditions = getattr(skill_config, "preconditions", None)
if preconditions: if preconditions:
lines = ["## Activation Preconditions", "Before executing this skill, verify:"] preconditions_block = format_preconditions_block(preconditions)
lines.extend(f"- {p}" for p in preconditions)
lines.append(
"If any precondition is not met, refuse to execute or ask the user for clarification."
)
preconditions_block = "\n".join(lines)
return f"{base}\n\n{preconditions_block}" if base else preconditions_block return f"{base}\n\n{preconditions_block}" if base else preconditions_block
return base return base

View File

@ -90,15 +90,12 @@ class MiddlewareChain:
return await handler(ctx) return await handler(ctx)
# before: 外 → 内 # before: 外 → 内
# before 异常自然传播,不执行 after 链
executed_befores: list[Middleware] = [] executed_befores: list[Middleware] = []
current_ctx = ctx current_ctx = ctx
try:
for mw in self._middlewares: for mw in self._middlewares:
current_ctx = await mw.before(current_ctx) current_ctx = await mw.before(current_ctx)
executed_befores.append(mw) executed_befores.append(mw)
except Exception:
# before 异常:不执行 after直接传播
raise
# handler # handler
result = await handler(current_ctx) result = await handler(current_ctx)
@ -152,12 +149,11 @@ class SummarizationMiddleware:
class TokenUsageMiddleware: class TokenUsageMiddleware:
"""Token 计量中间件 — 记录请求的 token 用量。 """Token 计量中间件 — 记录请求的 token 用量。
before: 记录起始时间 before: 无操作
after: result 中提取 token usage记录到 metadata after: result 中提取 token usage记录到 metadata
""" """
async def before(self, ctx: RequestContext) -> RequestContext: async def before(self, ctx: RequestContext) -> RequestContext:
ctx.metadata["token_usage_start"] = ctx.metadata.get("token_usage_start", 0)
return ctx return ctx
async def after(self, ctx: RequestContext, result: Any) -> Any: async def after(self, ctx: RequestContext, result: Any) -> Any:
@ -191,17 +187,18 @@ class LoopDetectionMiddleware:
if len(trajectory) < self._threshold: if len(trajectory) < self._threshold:
return result return result
# 检查最终 trajectory 中的重复工具调用模式 # 检查最终 trajectory 中的重复工具调用模式(只取尾部窗口)
recent = trajectory[-self._window_size :] if trajectory else []
tool_calls = [ tool_calls = [
(step.get("tool_name", ""), step.get("arguments_hash", "")) (step.get("tool_name", ""), step.get("arguments_hash", ""))
for step in trajectory for step in recent
if isinstance(step, dict) and "tool_name" in step if isinstance(step, dict) and "tool_name" in step
] ]
if not tool_calls: if not tool_calls:
return result return result
# 滑动窗口检测 # 滑动窗口检测
window = tool_calls[-self._window_size :] window = tool_calls
unique = set(window) unique = set(window)
if len(unique) < len(window) and len(window) - len(unique) >= self._threshold - 1: if len(unique) < len(window) and len(window) - len(unique) >= self._threshold - 1:
logger.warning( logger.warning(

View File

@ -98,8 +98,7 @@ async def _ensure_async_iterable(obj: Any, label: str = "<obj>"):
yield item yield item
return return
raise TypeError( raise TypeError(
f"{label}: awaited value is not async iterable " f"{label}: awaited value is not async iterable (got {type(resolved).__name__})"
f"(got {type(resolved).__name__})"
) )
# Case 3: anything else — surface a clear, actionable error rather # Case 3: anything else — surface a clear, actionable error rather
@ -218,19 +217,17 @@ class ReActEngine:
ponytail: 精确 hash 匹配不做语义相似度 ponytail: 精确 hash 匹配不做语义相似度
""" """
hash_to_name: dict[str, str] = {}
for tc in tool_calls: for tc in tool_calls:
args_str = json.dumps(tc.arguments, sort_keys=True, default=str) args_str = json.dumps(tc.arguments, sort_keys=True, default=str)
h = hash(f"{tc.name}:{args_str}") h = str(hash(f"{tc.name}:{args_str}"))
self._loop_window.append(str(h)) self._loop_window.append(h)
hash_to_name[h] = tc.name
counts = Counter(self._loop_window) counts = Counter(self._loop_window)
for h, count in counts.items(): for h, count in counts.items():
if count >= self._loop_threshold: if count >= self._loop_threshold:
# Find the tool name for this hash return hash_to_name.get(h)
for tc in tool_calls:
args_str = json.dumps(tc.arguments, sort_keys=True, default=str)
if str(hash(f"{tc.name}:{args_str}")) == h:
return tc.name
return None return None
async def execute( async def execute(
@ -1162,6 +1159,13 @@ class ReActEngine:
) )
# 记录 LLM 调用步骤 # 记录 LLM 调用步骤
if trace_recorder is not None:
trace_recorder.record_step(
step=step,
action="llm_call",
duration_ms=llm_duration_ms,
tokens_used=step_tokens,
)
# Record assistant message # Record assistant message
assistant_msg: dict[str, Any] = { assistant_msg: dict[str, Any] = {

View File

@ -116,13 +116,9 @@ class PipelineCheckpoint:
# 尝试写入 Redis # 尝试写入 Redis
if self._redis is not None: if self._redis is not None:
try: try:
await self._redis.set( await self._redis.set(self._plan_key(plan_id), json.dumps(plan_dict), ex=self._ttl)
self._plan_key(plan_id), json.dumps(plan_dict), ex=self._ttl
)
except Exception as e: except Exception as e:
logger.warning( logger.warning(f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}")
f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}"
)
async def load_plan(self, plan_id: str) -> dict[str, Any] | None: async def load_plan(self, plan_id: str) -> dict[str, Any] | None:
"""加载完整 plan JSON。""" """加载完整 plan JSON。"""
@ -133,9 +129,7 @@ class PipelineCheckpoint:
if raw: if raw:
return json.loads(raw) return json.loads(raw)
except Exception as e: except Exception as e:
logger.warning( logger.warning(f"PipelineCheckpoint.load_plan Redis failed for plan {plan_id}: {e}")
f"PipelineCheckpoint.load_plan Redis failed for plan {plan_id}: {e}"
)
# 内存降级(检查 TTL 过期) # 内存降级(检查 TTL 过期)
entry = self._memory_plans.get(plan_id) entry = self._memory_plans.get(plan_id)
if entry is None: if entry is None:
@ -217,14 +211,17 @@ class PipelineCheckpoint:
if not phase_ids: if not phase_ids:
# Redis 无数据,检查内存(过滤过期) # Redis 无数据,检查内存(过滤过期)
return [ return [
c c for c in self._memory.get(plan_id, []) if not self._is_expired(c.saved_at)
for c in self._memory.get(plan_id, [])
if not self._is_expired(c.saved_at)
] ]
results: list[CheckpointData] = [] # 批量 GETpipeline 避免 N+1 往返)
pipe = self._redis.pipeline()
for phase_id in phase_ids: for phase_id in phase_ids:
raw = await self._redis.get(self._key(plan_id, phase_id)) pipe.get(self._key(plan_id, phase_id))
raws = await pipe.execute()
results: list[CheckpointData] = []
for raw in raws:
if raw: if raw:
data = json.loads(raw) data = json.loads(raw)
results.append(CheckpointData.from_dict(data)) results.append(CheckpointData.from_dict(data))
@ -255,6 +252,4 @@ class PipelineCheckpoint:
pipe.delete(self._plan_key(plan_id)) pipe.delete(self._plan_key(plan_id))
await pipe.execute() await pipe.execute()
except Exception as e: except Exception as e:
logger.warning( logger.warning(f"PipelineCheckpoint.clear Redis failed for plan {plan_id}: {e}")
f"PipelineCheckpoint.clear Redis failed for plan {plan_id}: {e}"
)

View File

@ -14,6 +14,8 @@ from __future__ import annotations
import logging import logging
from typing import Any from typing import Any
from agentkit.chat.skill_routing import collect_prompt_parts, format_preconditions_block
from agentkit.core.exceptions import SkillNotFoundError
from agentkit.tools.base import Tool from agentkit.tools.base import Tool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -91,7 +93,7 @@ class SkillDetailTool(Tool):
try: try:
skill = self._registry.get(query) skill = self._registry.get(query)
return self._format_skill_full(skill) return self._format_skill_full(skill)
except Exception: except SkillNotFoundError:
pass # 非精确匹配,降级到关键词搜索 pass # 非精确匹配,降级到关键词搜索
# 关键词搜索:匹配 skill 名称和描述 # 关键词搜索:匹配 skill 名称和描述
@ -118,21 +120,12 @@ class SkillDetailTool(Tool):
def _format_skill_full(skill: Any) -> dict[str, Any]: def _format_skill_full(skill: Any) -> dict[str, Any]:
"""格式化 skill 的完整 instructions 供 LLM 使用。""" """格式化 skill 的完整 instructions 供 LLM 使用。"""
config = skill.config config = skill.config
prompt_parts: list[str] = [] prompt_parts = collect_prompt_parts(config, with_headers=True)
for key in ("identity", "context", "instructions", "constraints", "output_format"):
val = (config.prompt or {}).get(key)
if val:
prompt_parts.append(f"### {key.title()}\n{val}")
# v7: 注入激活前置条件(安全守卫,即使在渐进加载模式下也不应省略) # v7: 注入激活前置条件(安全守卫,即使在渐进加载模式下也不应省略)
preconditions = getattr(config, "preconditions", None) preconditions = getattr(config, "preconditions", None)
if preconditions: if preconditions:
lines = ["### Activation Preconditions", "Before executing this skill, verify:"] prompt_parts.append(format_preconditions_block(preconditions, header_level=3))
lines.extend(f"- {p}" for p in preconditions)
lines.append(
"If any precondition is not met, refuse to execute or ask the user for clarification."
)
prompt_parts.append("\n".join(lines))
return { return {
"name": config.name, "name": config.name,

View File

@ -428,12 +428,12 @@ class TestTokenUsageMiddleware:
"""TokenUsageMiddleware 测试。""" """TokenUsageMiddleware 测试。"""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_before_initializes_metadata(self) -> None: async def test_before_returns_ctx_unchanged(self) -> None:
"""before 初始化 token_usage_start""" """before 返回 ctx 不做修改token 计量在 after 中完成)"""
mw = TokenUsageMiddleware() mw = TokenUsageMiddleware()
ctx = _make_ctx() ctx = _make_ctx()
await mw.before(ctx) result = await mw.before(ctx)
assert "token_usage_start" in ctx.metadata assert result is ctx
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_after_extracts_usage_from_result(self) -> None: async def test_after_extracts_usage_from_result(self) -> None: