refactor: simplify code across U1-U7 (bug fix + efficiency + reuse + quality)
This commit is contained in:
parent
0847c0e086
commit
567cbc9c9b
|
|
@ -98,6 +98,31 @@ def parse_skill_prefix(content: str) -> tuple[str | None, str]:
|
|||
return explicit_skill, clean
|
||||
|
||||
|
||||
_PROMPT_KEYS = ("identity", "context", "instructions", "constraints", "output_format")
|
||||
|
||||
|
||||
def format_preconditions_block(preconditions: list[str], header_level: int = 2) -> str:
|
||||
"""格式化激活前置条件文本块。"""
|
||||
header = "#" * header_level
|
||||
lines = [f"{header} Activation Preconditions", "Before executing this skill, verify:"]
|
||||
lines.extend(f"- {p}" for p in preconditions)
|
||||
lines.append(
|
||||
"If any precondition is not met, refuse to execute or ask the user for clarification."
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def collect_prompt_parts(config: Any, with_headers: bool = False) -> list[str]:
|
||||
"""从 skill config 的 prompt 字典中收集各部分文本。"""
|
||||
prompt = config.prompt or {}
|
||||
parts: list[str] = []
|
||||
for key in _PROMPT_KEYS:
|
||||
val = prompt.get(key)
|
||||
if val:
|
||||
parts.append(f"### {key.title()}\n{val}" if with_headers else val)
|
||||
return parts
|
||||
|
||||
|
||||
def build_skill_system_prompt(skill_config) -> str | None:
|
||||
"""Build system prompt from skill config's prompt section.
|
||||
|
||||
|
|
@ -120,12 +145,7 @@ def build_skill_system_prompt(skill_config) -> str | None:
|
|||
# 安全守卫:前置条件即使在概要模式下也必须注入
|
||||
preconditions = getattr(skill_config, "preconditions", None)
|
||||
if preconditions:
|
||||
lines = ["## Activation Preconditions", "Before executing this skill, verify:"]
|
||||
lines.extend(f"- {p}" for p in preconditions)
|
||||
lines.append(
|
||||
"If any precondition is not met, refuse to execute or ask the user for clarification."
|
||||
)
|
||||
return f"{summary}\n\n" + "\n".join(lines)
|
||||
return f"{summary}\n\n" + format_preconditions_block(preconditions)
|
||||
# 提示 LLM 可通过 skill_detail 工具加载完整 instructions
|
||||
return (
|
||||
f"{summary}\n\n"
|
||||
|
|
@ -135,22 +155,12 @@ def build_skill_system_prompt(skill_config) -> str | None:
|
|||
# Level 1+: 全量加载(现有行为)
|
||||
if not skill_config.prompt:
|
||||
return None
|
||||
prompt_parts = []
|
||||
for key in ("identity", "context", "instructions", "constraints", "output_format"):
|
||||
val = skill_config.prompt.get(key)
|
||||
if val:
|
||||
prompt_parts.append(val)
|
||||
base = "\n\n".join(prompt_parts) if prompt_parts else None
|
||||
base = "\n\n".join(collect_prompt_parts(skill_config)) or None
|
||||
|
||||
# v7: 注入激活前置条件(软检查)
|
||||
preconditions = getattr(skill_config, "preconditions", None)
|
||||
if preconditions:
|
||||
lines = ["## Activation Preconditions", "Before executing this skill, verify:"]
|
||||
lines.extend(f"- {p}" for p in preconditions)
|
||||
lines.append(
|
||||
"If any precondition is not met, refuse to execute or ask the user for clarification."
|
||||
)
|
||||
preconditions_block = "\n".join(lines)
|
||||
preconditions_block = format_preconditions_block(preconditions)
|
||||
return f"{base}\n\n{preconditions_block}" if base else preconditions_block
|
||||
return base
|
||||
|
||||
|
|
|
|||
|
|
@ -90,15 +90,12 @@ class MiddlewareChain:
|
|||
return await handler(ctx)
|
||||
|
||||
# before: 外 → 内
|
||||
# before 异常自然传播,不执行 after 链
|
||||
executed_befores: list[Middleware] = []
|
||||
current_ctx = ctx
|
||||
try:
|
||||
for mw in self._middlewares:
|
||||
current_ctx = await mw.before(current_ctx)
|
||||
executed_befores.append(mw)
|
||||
except Exception:
|
||||
# before 异常:不执行 after,直接传播
|
||||
raise
|
||||
|
||||
# handler
|
||||
result = await handler(current_ctx)
|
||||
|
|
@ -152,12 +149,11 @@ class SummarizationMiddleware:
|
|||
class TokenUsageMiddleware:
|
||||
"""Token 计量中间件 — 记录请求的 token 用量。
|
||||
|
||||
before: 记录起始时间
|
||||
before: 无操作
|
||||
after: 从 result 中提取 token usage,记录到 metadata
|
||||
"""
|
||||
|
||||
async def before(self, ctx: RequestContext) -> RequestContext:
|
||||
ctx.metadata["token_usage_start"] = ctx.metadata.get("token_usage_start", 0)
|
||||
return ctx
|
||||
|
||||
async def after(self, ctx: RequestContext, result: Any) -> Any:
|
||||
|
|
@ -191,17 +187,18 @@ class LoopDetectionMiddleware:
|
|||
if len(trajectory) < self._threshold:
|
||||
return result
|
||||
|
||||
# 检查最终 trajectory 中的重复工具调用模式
|
||||
# 检查最终 trajectory 中的重复工具调用模式(只取尾部窗口)
|
||||
recent = trajectory[-self._window_size :] if trajectory else []
|
||||
tool_calls = [
|
||||
(step.get("tool_name", ""), step.get("arguments_hash", ""))
|
||||
for step in trajectory
|
||||
for step in recent
|
||||
if isinstance(step, dict) and "tool_name" in step
|
||||
]
|
||||
if not tool_calls:
|
||||
return result
|
||||
|
||||
# 滑动窗口检测
|
||||
window = tool_calls[-self._window_size :]
|
||||
window = tool_calls
|
||||
unique = set(window)
|
||||
if len(unique) < len(window) and len(window) - len(unique) >= self._threshold - 1:
|
||||
logger.warning(
|
||||
|
|
|
|||
|
|
@ -98,8 +98,7 @@ async def _ensure_async_iterable(obj: Any, label: str = "<obj>"):
|
|||
yield item
|
||||
return
|
||||
raise TypeError(
|
||||
f"{label}: awaited value is not async iterable "
|
||||
f"(got {type(resolved).__name__})"
|
||||
f"{label}: awaited value is not async iterable (got {type(resolved).__name__})"
|
||||
)
|
||||
|
||||
# Case 3: anything else — surface a clear, actionable error rather
|
||||
|
|
@ -218,19 +217,17 @@ class ReActEngine:
|
|||
|
||||
ponytail: 精确 hash 匹配,不做语义相似度。
|
||||
"""
|
||||
hash_to_name: dict[str, str] = {}
|
||||
for tc in tool_calls:
|
||||
args_str = json.dumps(tc.arguments, sort_keys=True, default=str)
|
||||
h = hash(f"{tc.name}:{args_str}")
|
||||
self._loop_window.append(str(h))
|
||||
h = str(hash(f"{tc.name}:{args_str}"))
|
||||
self._loop_window.append(h)
|
||||
hash_to_name[h] = tc.name
|
||||
|
||||
counts = Counter(self._loop_window)
|
||||
for h, count in counts.items():
|
||||
if count >= self._loop_threshold:
|
||||
# Find the tool name for this hash
|
||||
for tc in tool_calls:
|
||||
args_str = json.dumps(tc.arguments, sort_keys=True, default=str)
|
||||
if str(hash(f"{tc.name}:{args_str}")) == h:
|
||||
return tc.name
|
||||
return hash_to_name.get(h)
|
||||
return None
|
||||
|
||||
async def execute(
|
||||
|
|
@ -1162,6 +1159,13 @@ class ReActEngine:
|
|||
)
|
||||
|
||||
# 记录 LLM 调用步骤
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.record_step(
|
||||
step=step,
|
||||
action="llm_call",
|
||||
duration_ms=llm_duration_ms,
|
||||
tokens_used=step_tokens,
|
||||
)
|
||||
|
||||
# Record assistant message
|
||||
assistant_msg: dict[str, Any] = {
|
||||
|
|
|
|||
|
|
@ -116,13 +116,9 @@ class PipelineCheckpoint:
|
|||
# 尝试写入 Redis
|
||||
if self._redis is not None:
|
||||
try:
|
||||
await self._redis.set(
|
||||
self._plan_key(plan_id), json.dumps(plan_dict), ex=self._ttl
|
||||
)
|
||||
await self._redis.set(self._plan_key(plan_id), json.dumps(plan_dict), ex=self._ttl)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}"
|
||||
)
|
||||
logger.warning(f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}")
|
||||
|
||||
async def load_plan(self, plan_id: str) -> dict[str, Any] | None:
|
||||
"""加载完整 plan JSON。"""
|
||||
|
|
@ -133,9 +129,7 @@ class PipelineCheckpoint:
|
|||
if raw:
|
||||
return json.loads(raw)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"PipelineCheckpoint.load_plan Redis failed for plan {plan_id}: {e}"
|
||||
)
|
||||
logger.warning(f"PipelineCheckpoint.load_plan Redis failed for plan {plan_id}: {e}")
|
||||
# 内存降级(检查 TTL 过期)
|
||||
entry = self._memory_plans.get(plan_id)
|
||||
if entry is None:
|
||||
|
|
@ -217,14 +211,17 @@ class PipelineCheckpoint:
|
|||
if not phase_ids:
|
||||
# Redis 无数据,检查内存(过滤过期)
|
||||
return [
|
||||
c
|
||||
for c in self._memory.get(plan_id, [])
|
||||
if not self._is_expired(c.saved_at)
|
||||
c for c in self._memory.get(plan_id, []) if not self._is_expired(c.saved_at)
|
||||
]
|
||||
|
||||
results: list[CheckpointData] = []
|
||||
# 批量 GET(pipeline 避免 N+1 往返)
|
||||
pipe = self._redis.pipeline()
|
||||
for phase_id in phase_ids:
|
||||
raw = await self._redis.get(self._key(plan_id, phase_id))
|
||||
pipe.get(self._key(plan_id, phase_id))
|
||||
raws = await pipe.execute()
|
||||
|
||||
results: list[CheckpointData] = []
|
||||
for raw in raws:
|
||||
if raw:
|
||||
data = json.loads(raw)
|
||||
results.append(CheckpointData.from_dict(data))
|
||||
|
|
@ -255,6 +252,4 @@ class PipelineCheckpoint:
|
|||
pipe.delete(self._plan_key(plan_id))
|
||||
await pipe.execute()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"PipelineCheckpoint.clear Redis failed for plan {plan_id}: {e}"
|
||||
)
|
||||
logger.warning(f"PipelineCheckpoint.clear Redis failed for plan {plan_id}: {e}")
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ from __future__ import annotations
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from agentkit.chat.skill_routing import collect_prompt_parts, format_preconditions_block
|
||||
from agentkit.core.exceptions import SkillNotFoundError
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -91,7 +93,7 @@ class SkillDetailTool(Tool):
|
|||
try:
|
||||
skill = self._registry.get(query)
|
||||
return self._format_skill_full(skill)
|
||||
except Exception:
|
||||
except SkillNotFoundError:
|
||||
pass # 非精确匹配,降级到关键词搜索
|
||||
|
||||
# 关键词搜索:匹配 skill 名称和描述
|
||||
|
|
@ -118,21 +120,12 @@ class SkillDetailTool(Tool):
|
|||
def _format_skill_full(skill: Any) -> dict[str, Any]:
|
||||
"""格式化 skill 的完整 instructions 供 LLM 使用。"""
|
||||
config = skill.config
|
||||
prompt_parts: list[str] = []
|
||||
for key in ("identity", "context", "instructions", "constraints", "output_format"):
|
||||
val = (config.prompt or {}).get(key)
|
||||
if val:
|
||||
prompt_parts.append(f"### {key.title()}\n{val}")
|
||||
prompt_parts = collect_prompt_parts(config, with_headers=True)
|
||||
|
||||
# v7: 注入激活前置条件(安全守卫,即使在渐进加载模式下也不应省略)
|
||||
preconditions = getattr(config, "preconditions", None)
|
||||
if preconditions:
|
||||
lines = ["### Activation Preconditions", "Before executing this skill, verify:"]
|
||||
lines.extend(f"- {p}" for p in preconditions)
|
||||
lines.append(
|
||||
"If any precondition is not met, refuse to execute or ask the user for clarification."
|
||||
)
|
||||
prompt_parts.append("\n".join(lines))
|
||||
prompt_parts.append(format_preconditions_block(preconditions, header_level=3))
|
||||
|
||||
return {
|
||||
"name": config.name,
|
||||
|
|
|
|||
|
|
@ -428,12 +428,12 @@ class TestTokenUsageMiddleware:
|
|||
"""TokenUsageMiddleware 测试。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_before_initializes_metadata(self) -> None:
|
||||
"""before 初始化 token_usage_start。"""
|
||||
async def test_before_returns_ctx_unchanged(self) -> None:
|
||||
"""before 返回 ctx 不做修改(token 计量在 after 中完成)。"""
|
||||
mw = TokenUsageMiddleware()
|
||||
ctx = _make_ctx()
|
||||
await mw.before(ctx)
|
||||
assert "token_usage_start" in ctx.metadata
|
||||
result = await mw.before(ctx)
|
||||
assert result is ctx
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_after_extracts_usage_from_result(self) -> None:
|
||||
|
|
|
|||
Loading…
Reference in New Issue