refactor: simplify code across U1-U7 (bug fix + efficiency + reuse + quality)

This commit is contained in:
chiguyong 2026-06-24 22:35:52 +08:00
parent 0847c0e086
commit 567cbc9c9b
6 changed files with 71 additions and 72 deletions

View File

@ -98,6 +98,31 @@ def parse_skill_prefix(content: str) -> tuple[str | None, str]:
return explicit_skill, clean
_PROMPT_KEYS = ("identity", "context", "instructions", "constraints", "output_format")
def format_preconditions_block(preconditions: list[str], header_level: int = 2) -> str:
"""格式化激活前置条件文本块。"""
header = "#" * header_level
lines = [f"{header} Activation Preconditions", "Before executing this skill, verify:"]
lines.extend(f"- {p}" for p in preconditions)
lines.append(
"If any precondition is not met, refuse to execute or ask the user for clarification."
)
return "\n".join(lines)
def collect_prompt_parts(config: Any, with_headers: bool = False) -> list[str]:
"""从 skill config 的 prompt 字典中收集各部分文本。"""
prompt = config.prompt or {}
parts: list[str] = []
for key in _PROMPT_KEYS:
val = prompt.get(key)
if val:
parts.append(f"### {key.title()}\n{val}" if with_headers else val)
return parts
def build_skill_system_prompt(skill_config) -> str | None:
"""Build system prompt from skill config's prompt section.
@ -120,12 +145,7 @@ def build_skill_system_prompt(skill_config) -> str | None:
# 安全守卫:前置条件即使在概要模式下也必须注入
preconditions = getattr(skill_config, "preconditions", None)
if preconditions:
lines = ["## Activation Preconditions", "Before executing this skill, verify:"]
lines.extend(f"- {p}" for p in preconditions)
lines.append(
"If any precondition is not met, refuse to execute or ask the user for clarification."
)
return f"{summary}\n\n" + "\n".join(lines)
return f"{summary}\n\n" + format_preconditions_block(preconditions)
# 提示 LLM 可通过 skill_detail 工具加载完整 instructions
return (
f"{summary}\n\n"
@ -135,22 +155,12 @@ def build_skill_system_prompt(skill_config) -> str | None:
# Level 1+: 全量加载(现有行为)
if not skill_config.prompt:
return None
prompt_parts = []
for key in ("identity", "context", "instructions", "constraints", "output_format"):
val = skill_config.prompt.get(key)
if val:
prompt_parts.append(val)
base = "\n\n".join(prompt_parts) if prompt_parts else None
base = "\n\n".join(collect_prompt_parts(skill_config)) or None
# v7: 注入激活前置条件(软检查)
preconditions = getattr(skill_config, "preconditions", None)
if preconditions:
lines = ["## Activation Preconditions", "Before executing this skill, verify:"]
lines.extend(f"- {p}" for p in preconditions)
lines.append(
"If any precondition is not met, refuse to execute or ask the user for clarification."
)
preconditions_block = "\n".join(lines)
preconditions_block = format_preconditions_block(preconditions)
return f"{base}\n\n{preconditions_block}" if base else preconditions_block
return base

View File

@ -90,15 +90,12 @@ class MiddlewareChain:
return await handler(ctx)
# before: 外 → 内
# before 异常自然传播,不执行 after 链
executed_befores: list[Middleware] = []
current_ctx = ctx
try:
for mw in self._middlewares:
current_ctx = await mw.before(current_ctx)
executed_befores.append(mw)
except Exception:
# before 异常:不执行 after直接传播
raise
for mw in self._middlewares:
current_ctx = await mw.before(current_ctx)
executed_befores.append(mw)
# handler
result = await handler(current_ctx)
@ -152,12 +149,11 @@ class SummarizationMiddleware:
class TokenUsageMiddleware:
"""Token 计量中间件 — 记录请求的 token 用量。
before: 记录起始时间
before: 无操作
after: result 中提取 token usage记录到 metadata
"""
async def before(self, ctx: RequestContext) -> RequestContext:
ctx.metadata["token_usage_start"] = ctx.metadata.get("token_usage_start", 0)
return ctx
async def after(self, ctx: RequestContext, result: Any) -> Any:
@ -191,17 +187,18 @@ class LoopDetectionMiddleware:
if len(trajectory) < self._threshold:
return result
# 检查最终 trajectory 中的重复工具调用模式
# 检查最终 trajectory 中的重复工具调用模式(只取尾部窗口)
recent = trajectory[-self._window_size :] if trajectory else []
tool_calls = [
(step.get("tool_name", ""), step.get("arguments_hash", ""))
for step in trajectory
for step in recent
if isinstance(step, dict) and "tool_name" in step
]
if not tool_calls:
return result
# 滑动窗口检测
window = tool_calls[-self._window_size :]
window = tool_calls
unique = set(window)
if len(unique) < len(window) and len(window) - len(unique) >= self._threshold - 1:
logger.warning(

View File

@ -98,8 +98,7 @@ async def _ensure_async_iterable(obj: Any, label: str = "<obj>"):
yield item
return
raise TypeError(
f"{label}: awaited value is not async iterable "
f"(got {type(resolved).__name__})"
f"{label}: awaited value is not async iterable (got {type(resolved).__name__})"
)
# Case 3: anything else — surface a clear, actionable error rather
@ -218,19 +217,17 @@ class ReActEngine:
ponytail: 精确 hash 匹配不做语义相似度
"""
hash_to_name: dict[str, str] = {}
for tc in tool_calls:
args_str = json.dumps(tc.arguments, sort_keys=True, default=str)
h = hash(f"{tc.name}:{args_str}")
self._loop_window.append(str(h))
h = str(hash(f"{tc.name}:{args_str}"))
self._loop_window.append(h)
hash_to_name[h] = tc.name
counts = Counter(self._loop_window)
for h, count in counts.items():
if count >= self._loop_threshold:
# Find the tool name for this hash
for tc in tool_calls:
args_str = json.dumps(tc.arguments, sort_keys=True, default=str)
if str(hash(f"{tc.name}:{args_str}")) == h:
return tc.name
return hash_to_name.get(h)
return None
async def execute(
@ -1162,6 +1159,13 @@ class ReActEngine:
)
# 记录 LLM 调用步骤
if trace_recorder is not None:
trace_recorder.record_step(
step=step,
action="llm_call",
duration_ms=llm_duration_ms,
tokens_used=step_tokens,
)
# Record assistant message
assistant_msg: dict[str, Any] = {

View File

@ -116,13 +116,9 @@ class PipelineCheckpoint:
# 尝试写入 Redis
if self._redis is not None:
try:
await self._redis.set(
self._plan_key(plan_id), json.dumps(plan_dict), ex=self._ttl
)
await self._redis.set(self._plan_key(plan_id), json.dumps(plan_dict), ex=self._ttl)
except Exception as e:
logger.warning(
f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}"
)
logger.warning(f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}")
async def load_plan(self, plan_id: str) -> dict[str, Any] | None:
"""加载完整 plan JSON。"""
@ -133,9 +129,7 @@ class PipelineCheckpoint:
if raw:
return json.loads(raw)
except Exception as e:
logger.warning(
f"PipelineCheckpoint.load_plan Redis failed for plan {plan_id}: {e}"
)
logger.warning(f"PipelineCheckpoint.load_plan Redis failed for plan {plan_id}: {e}")
# 内存降级(检查 TTL 过期)
entry = self._memory_plans.get(plan_id)
if entry is None:
@ -217,14 +211,17 @@ class PipelineCheckpoint:
if not phase_ids:
# Redis 无数据,检查内存(过滤过期)
return [
c
for c in self._memory.get(plan_id, [])
if not self._is_expired(c.saved_at)
c for c in self._memory.get(plan_id, []) if not self._is_expired(c.saved_at)
]
results: list[CheckpointData] = []
# 批量 GETpipeline 避免 N+1 往返)
pipe = self._redis.pipeline()
for phase_id in phase_ids:
raw = await self._redis.get(self._key(plan_id, phase_id))
pipe.get(self._key(plan_id, phase_id))
raws = await pipe.execute()
results: list[CheckpointData] = []
for raw in raws:
if raw:
data = json.loads(raw)
results.append(CheckpointData.from_dict(data))
@ -255,6 +252,4 @@ class PipelineCheckpoint:
pipe.delete(self._plan_key(plan_id))
await pipe.execute()
except Exception as e:
logger.warning(
f"PipelineCheckpoint.clear Redis failed for plan {plan_id}: {e}"
)
logger.warning(f"PipelineCheckpoint.clear Redis failed for plan {plan_id}: {e}")

View File

@ -14,6 +14,8 @@ from __future__ import annotations
import logging
from typing import Any
from agentkit.chat.skill_routing import collect_prompt_parts, format_preconditions_block
from agentkit.core.exceptions import SkillNotFoundError
from agentkit.tools.base import Tool
logger = logging.getLogger(__name__)
@ -91,7 +93,7 @@ class SkillDetailTool(Tool):
try:
skill = self._registry.get(query)
return self._format_skill_full(skill)
except Exception:
except SkillNotFoundError:
pass # 非精确匹配,降级到关键词搜索
# 关键词搜索:匹配 skill 名称和描述
@ -118,21 +120,12 @@ class SkillDetailTool(Tool):
def _format_skill_full(skill: Any) -> dict[str, Any]:
"""格式化 skill 的完整 instructions 供 LLM 使用。"""
config = skill.config
prompt_parts: list[str] = []
for key in ("identity", "context", "instructions", "constraints", "output_format"):
val = (config.prompt or {}).get(key)
if val:
prompt_parts.append(f"### {key.title()}\n{val}")
prompt_parts = collect_prompt_parts(config, with_headers=True)
# v7: 注入激活前置条件(安全守卫,即使在渐进加载模式下也不应省略)
preconditions = getattr(config, "preconditions", None)
if preconditions:
lines = ["### Activation Preconditions", "Before executing this skill, verify:"]
lines.extend(f"- {p}" for p in preconditions)
lines.append(
"If any precondition is not met, refuse to execute or ask the user for clarification."
)
prompt_parts.append("\n".join(lines))
prompt_parts.append(format_preconditions_block(preconditions, header_level=3))
return {
"name": config.name,

View File

@ -428,12 +428,12 @@ class TestTokenUsageMiddleware:
"""TokenUsageMiddleware 测试。"""
@pytest.mark.asyncio
async def test_before_initializes_metadata(self) -> None:
"""before 初始化 token_usage_start"""
async def test_before_returns_ctx_unchanged(self) -> None:
"""before 返回 ctx 不做修改token 计量在 after 中完成)"""
mw = TokenUsageMiddleware()
ctx = _make_ctx()
await mw.before(ctx)
assert "token_usage_start" in ctx.metadata
result = await mw.before(ctx)
assert result is ctx
@pytest.mark.asyncio
async def test_after_extracts_usage_from_result(self) -> None: