refactor: simplify code across U1-U7 (bug fix + efficiency + reuse + quality)
This commit is contained in:
parent
0847c0e086
commit
567cbc9c9b
|
|
@ -98,6 +98,31 @@ def parse_skill_prefix(content: str) -> tuple[str | None, str]:
|
||||||
return explicit_skill, clean
|
return explicit_skill, clean
|
||||||
|
|
||||||
|
|
||||||
|
_PROMPT_KEYS = ("identity", "context", "instructions", "constraints", "output_format")
|
||||||
|
|
||||||
|
|
||||||
|
def format_preconditions_block(preconditions: list[str], header_level: int = 2) -> str:
|
||||||
|
"""格式化激活前置条件文本块。"""
|
||||||
|
header = "#" * header_level
|
||||||
|
lines = [f"{header} Activation Preconditions", "Before executing this skill, verify:"]
|
||||||
|
lines.extend(f"- {p}" for p in preconditions)
|
||||||
|
lines.append(
|
||||||
|
"If any precondition is not met, refuse to execute or ask the user for clarification."
|
||||||
|
)
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def collect_prompt_parts(config: Any, with_headers: bool = False) -> list[str]:
|
||||||
|
"""从 skill config 的 prompt 字典中收集各部分文本。"""
|
||||||
|
prompt = config.prompt or {}
|
||||||
|
parts: list[str] = []
|
||||||
|
for key in _PROMPT_KEYS:
|
||||||
|
val = prompt.get(key)
|
||||||
|
if val:
|
||||||
|
parts.append(f"### {key.title()}\n{val}" if with_headers else val)
|
||||||
|
return parts
|
||||||
|
|
||||||
|
|
||||||
def build_skill_system_prompt(skill_config) -> str | None:
|
def build_skill_system_prompt(skill_config) -> str | None:
|
||||||
"""Build system prompt from skill config's prompt section.
|
"""Build system prompt from skill config's prompt section.
|
||||||
|
|
||||||
|
|
@ -120,12 +145,7 @@ def build_skill_system_prompt(skill_config) -> str | None:
|
||||||
# 安全守卫:前置条件即使在概要模式下也必须注入
|
# 安全守卫:前置条件即使在概要模式下也必须注入
|
||||||
preconditions = getattr(skill_config, "preconditions", None)
|
preconditions = getattr(skill_config, "preconditions", None)
|
||||||
if preconditions:
|
if preconditions:
|
||||||
lines = ["## Activation Preconditions", "Before executing this skill, verify:"]
|
return f"{summary}\n\n" + format_preconditions_block(preconditions)
|
||||||
lines.extend(f"- {p}" for p in preconditions)
|
|
||||||
lines.append(
|
|
||||||
"If any precondition is not met, refuse to execute or ask the user for clarification."
|
|
||||||
)
|
|
||||||
return f"{summary}\n\n" + "\n".join(lines)
|
|
||||||
# 提示 LLM 可通过 skill_detail 工具加载完整 instructions
|
# 提示 LLM 可通过 skill_detail 工具加载完整 instructions
|
||||||
return (
|
return (
|
||||||
f"{summary}\n\n"
|
f"{summary}\n\n"
|
||||||
|
|
@ -135,22 +155,12 @@ def build_skill_system_prompt(skill_config) -> str | None:
|
||||||
# Level 1+: 全量加载(现有行为)
|
# Level 1+: 全量加载(现有行为)
|
||||||
if not skill_config.prompt:
|
if not skill_config.prompt:
|
||||||
return None
|
return None
|
||||||
prompt_parts = []
|
base = "\n\n".join(collect_prompt_parts(skill_config)) or None
|
||||||
for key in ("identity", "context", "instructions", "constraints", "output_format"):
|
|
||||||
val = skill_config.prompt.get(key)
|
|
||||||
if val:
|
|
||||||
prompt_parts.append(val)
|
|
||||||
base = "\n\n".join(prompt_parts) if prompt_parts else None
|
|
||||||
|
|
||||||
# v7: 注入激活前置条件(软检查)
|
# v7: 注入激活前置条件(软检查)
|
||||||
preconditions = getattr(skill_config, "preconditions", None)
|
preconditions = getattr(skill_config, "preconditions", None)
|
||||||
if preconditions:
|
if preconditions:
|
||||||
lines = ["## Activation Preconditions", "Before executing this skill, verify:"]
|
preconditions_block = format_preconditions_block(preconditions)
|
||||||
lines.extend(f"- {p}" for p in preconditions)
|
|
||||||
lines.append(
|
|
||||||
"If any precondition is not met, refuse to execute or ask the user for clarification."
|
|
||||||
)
|
|
||||||
preconditions_block = "\n".join(lines)
|
|
||||||
return f"{base}\n\n{preconditions_block}" if base else preconditions_block
|
return f"{base}\n\n{preconditions_block}" if base else preconditions_block
|
||||||
return base
|
return base
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -90,15 +90,12 @@ class MiddlewareChain:
|
||||||
return await handler(ctx)
|
return await handler(ctx)
|
||||||
|
|
||||||
# before: 外 → 内
|
# before: 外 → 内
|
||||||
|
# before 异常自然传播,不执行 after 链
|
||||||
executed_befores: list[Middleware] = []
|
executed_befores: list[Middleware] = []
|
||||||
current_ctx = ctx
|
current_ctx = ctx
|
||||||
try:
|
for mw in self._middlewares:
|
||||||
for mw in self._middlewares:
|
current_ctx = await mw.before(current_ctx)
|
||||||
current_ctx = await mw.before(current_ctx)
|
executed_befores.append(mw)
|
||||||
executed_befores.append(mw)
|
|
||||||
except Exception:
|
|
||||||
# before 异常:不执行 after,直接传播
|
|
||||||
raise
|
|
||||||
|
|
||||||
# handler
|
# handler
|
||||||
result = await handler(current_ctx)
|
result = await handler(current_ctx)
|
||||||
|
|
@ -152,12 +149,11 @@ class SummarizationMiddleware:
|
||||||
class TokenUsageMiddleware:
|
class TokenUsageMiddleware:
|
||||||
"""Token 计量中间件 — 记录请求的 token 用量。
|
"""Token 计量中间件 — 记录请求的 token 用量。
|
||||||
|
|
||||||
before: 记录起始时间
|
before: 无操作
|
||||||
after: 从 result 中提取 token usage,记录到 metadata
|
after: 从 result 中提取 token usage,记录到 metadata
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def before(self, ctx: RequestContext) -> RequestContext:
|
async def before(self, ctx: RequestContext) -> RequestContext:
|
||||||
ctx.metadata["token_usage_start"] = ctx.metadata.get("token_usage_start", 0)
|
|
||||||
return ctx
|
return ctx
|
||||||
|
|
||||||
async def after(self, ctx: RequestContext, result: Any) -> Any:
|
async def after(self, ctx: RequestContext, result: Any) -> Any:
|
||||||
|
|
@ -191,17 +187,18 @@ class LoopDetectionMiddleware:
|
||||||
if len(trajectory) < self._threshold:
|
if len(trajectory) < self._threshold:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# 检查最终 trajectory 中的重复工具调用模式
|
# 检查最终 trajectory 中的重复工具调用模式(只取尾部窗口)
|
||||||
|
recent = trajectory[-self._window_size :] if trajectory else []
|
||||||
tool_calls = [
|
tool_calls = [
|
||||||
(step.get("tool_name", ""), step.get("arguments_hash", ""))
|
(step.get("tool_name", ""), step.get("arguments_hash", ""))
|
||||||
for step in trajectory
|
for step in recent
|
||||||
if isinstance(step, dict) and "tool_name" in step
|
if isinstance(step, dict) and "tool_name" in step
|
||||||
]
|
]
|
||||||
if not tool_calls:
|
if not tool_calls:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# 滑动窗口检测
|
# 滑动窗口检测
|
||||||
window = tool_calls[-self._window_size :]
|
window = tool_calls
|
||||||
unique = set(window)
|
unique = set(window)
|
||||||
if len(unique) < len(window) and len(window) - len(unique) >= self._threshold - 1:
|
if len(unique) < len(window) and len(window) - len(unique) >= self._threshold - 1:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
|
||||||
|
|
@ -98,8 +98,7 @@ async def _ensure_async_iterable(obj: Any, label: str = "<obj>"):
|
||||||
yield item
|
yield item
|
||||||
return
|
return
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"{label}: awaited value is not async iterable "
|
f"{label}: awaited value is not async iterable (got {type(resolved).__name__})"
|
||||||
f"(got {type(resolved).__name__})"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Case 3: anything else — surface a clear, actionable error rather
|
# Case 3: anything else — surface a clear, actionable error rather
|
||||||
|
|
@ -218,19 +217,17 @@ class ReActEngine:
|
||||||
|
|
||||||
ponytail: 精确 hash 匹配,不做语义相似度。
|
ponytail: 精确 hash 匹配,不做语义相似度。
|
||||||
"""
|
"""
|
||||||
|
hash_to_name: dict[str, str] = {}
|
||||||
for tc in tool_calls:
|
for tc in tool_calls:
|
||||||
args_str = json.dumps(tc.arguments, sort_keys=True, default=str)
|
args_str = json.dumps(tc.arguments, sort_keys=True, default=str)
|
||||||
h = hash(f"{tc.name}:{args_str}")
|
h = str(hash(f"{tc.name}:{args_str}"))
|
||||||
self._loop_window.append(str(h))
|
self._loop_window.append(h)
|
||||||
|
hash_to_name[h] = tc.name
|
||||||
|
|
||||||
counts = Counter(self._loop_window)
|
counts = Counter(self._loop_window)
|
||||||
for h, count in counts.items():
|
for h, count in counts.items():
|
||||||
if count >= self._loop_threshold:
|
if count >= self._loop_threshold:
|
||||||
# Find the tool name for this hash
|
return hash_to_name.get(h)
|
||||||
for tc in tool_calls:
|
|
||||||
args_str = json.dumps(tc.arguments, sort_keys=True, default=str)
|
|
||||||
if str(hash(f"{tc.name}:{args_str}")) == h:
|
|
||||||
return tc.name
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
|
|
@ -1162,6 +1159,13 @@ class ReActEngine:
|
||||||
)
|
)
|
||||||
|
|
||||||
# 记录 LLM 调用步骤
|
# 记录 LLM 调用步骤
|
||||||
|
if trace_recorder is not None:
|
||||||
|
trace_recorder.record_step(
|
||||||
|
step=step,
|
||||||
|
action="llm_call",
|
||||||
|
duration_ms=llm_duration_ms,
|
||||||
|
tokens_used=step_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
# Record assistant message
|
# Record assistant message
|
||||||
assistant_msg: dict[str, Any] = {
|
assistant_msg: dict[str, Any] = {
|
||||||
|
|
|
||||||
|
|
@ -116,13 +116,9 @@ class PipelineCheckpoint:
|
||||||
# 尝试写入 Redis
|
# 尝试写入 Redis
|
||||||
if self._redis is not None:
|
if self._redis is not None:
|
||||||
try:
|
try:
|
||||||
await self._redis.set(
|
await self._redis.set(self._plan_key(plan_id), json.dumps(plan_dict), ex=self._ttl)
|
||||||
self._plan_key(plan_id), json.dumps(plan_dict), ex=self._ttl
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}")
|
||||||
f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def load_plan(self, plan_id: str) -> dict[str, Any] | None:
|
async def load_plan(self, plan_id: str) -> dict[str, Any] | None:
|
||||||
"""加载完整 plan JSON。"""
|
"""加载完整 plan JSON。"""
|
||||||
|
|
@ -133,9 +129,7 @@ class PipelineCheckpoint:
|
||||||
if raw:
|
if raw:
|
||||||
return json.loads(raw)
|
return json.loads(raw)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(f"PipelineCheckpoint.load_plan Redis failed for plan {plan_id}: {e}")
|
||||||
f"PipelineCheckpoint.load_plan Redis failed for plan {plan_id}: {e}"
|
|
||||||
)
|
|
||||||
# 内存降级(检查 TTL 过期)
|
# 内存降级(检查 TTL 过期)
|
||||||
entry = self._memory_plans.get(plan_id)
|
entry = self._memory_plans.get(plan_id)
|
||||||
if entry is None:
|
if entry is None:
|
||||||
|
|
@ -217,14 +211,17 @@ class PipelineCheckpoint:
|
||||||
if not phase_ids:
|
if not phase_ids:
|
||||||
# Redis 无数据,检查内存(过滤过期)
|
# Redis 无数据,检查内存(过滤过期)
|
||||||
return [
|
return [
|
||||||
c
|
c for c in self._memory.get(plan_id, []) if not self._is_expired(c.saved_at)
|
||||||
for c in self._memory.get(plan_id, [])
|
|
||||||
if not self._is_expired(c.saved_at)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
results: list[CheckpointData] = []
|
# 批量 GET(pipeline 避免 N+1 往返)
|
||||||
|
pipe = self._redis.pipeline()
|
||||||
for phase_id in phase_ids:
|
for phase_id in phase_ids:
|
||||||
raw = await self._redis.get(self._key(plan_id, phase_id))
|
pipe.get(self._key(plan_id, phase_id))
|
||||||
|
raws = await pipe.execute()
|
||||||
|
|
||||||
|
results: list[CheckpointData] = []
|
||||||
|
for raw in raws:
|
||||||
if raw:
|
if raw:
|
||||||
data = json.loads(raw)
|
data = json.loads(raw)
|
||||||
results.append(CheckpointData.from_dict(data))
|
results.append(CheckpointData.from_dict(data))
|
||||||
|
|
@ -255,6 +252,4 @@ class PipelineCheckpoint:
|
||||||
pipe.delete(self._plan_key(plan_id))
|
pipe.delete(self._plan_key(plan_id))
|
||||||
await pipe.execute()
|
await pipe.execute()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(f"PipelineCheckpoint.clear Redis failed for plan {plan_id}: {e}")
|
||||||
f"PipelineCheckpoint.clear Redis failed for plan {plan_id}: {e}"
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,8 @@ from __future__ import annotations
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.chat.skill_routing import collect_prompt_parts, format_preconditions_block
|
||||||
|
from agentkit.core.exceptions import SkillNotFoundError
|
||||||
from agentkit.tools.base import Tool
|
from agentkit.tools.base import Tool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -91,7 +93,7 @@ class SkillDetailTool(Tool):
|
||||||
try:
|
try:
|
||||||
skill = self._registry.get(query)
|
skill = self._registry.get(query)
|
||||||
return self._format_skill_full(skill)
|
return self._format_skill_full(skill)
|
||||||
except Exception:
|
except SkillNotFoundError:
|
||||||
pass # 非精确匹配,降级到关键词搜索
|
pass # 非精确匹配,降级到关键词搜索
|
||||||
|
|
||||||
# 关键词搜索:匹配 skill 名称和描述
|
# 关键词搜索:匹配 skill 名称和描述
|
||||||
|
|
@ -118,21 +120,12 @@ class SkillDetailTool(Tool):
|
||||||
def _format_skill_full(skill: Any) -> dict[str, Any]:
|
def _format_skill_full(skill: Any) -> dict[str, Any]:
|
||||||
"""格式化 skill 的完整 instructions 供 LLM 使用。"""
|
"""格式化 skill 的完整 instructions 供 LLM 使用。"""
|
||||||
config = skill.config
|
config = skill.config
|
||||||
prompt_parts: list[str] = []
|
prompt_parts = collect_prompt_parts(config, with_headers=True)
|
||||||
for key in ("identity", "context", "instructions", "constraints", "output_format"):
|
|
||||||
val = (config.prompt or {}).get(key)
|
|
||||||
if val:
|
|
||||||
prompt_parts.append(f"### {key.title()}\n{val}")
|
|
||||||
|
|
||||||
# v7: 注入激活前置条件(安全守卫,即使在渐进加载模式下也不应省略)
|
# v7: 注入激活前置条件(安全守卫,即使在渐进加载模式下也不应省略)
|
||||||
preconditions = getattr(config, "preconditions", None)
|
preconditions = getattr(config, "preconditions", None)
|
||||||
if preconditions:
|
if preconditions:
|
||||||
lines = ["### Activation Preconditions", "Before executing this skill, verify:"]
|
prompt_parts.append(format_preconditions_block(preconditions, header_level=3))
|
||||||
lines.extend(f"- {p}" for p in preconditions)
|
|
||||||
lines.append(
|
|
||||||
"If any precondition is not met, refuse to execute or ask the user for clarification."
|
|
||||||
)
|
|
||||||
prompt_parts.append("\n".join(lines))
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"name": config.name,
|
"name": config.name,
|
||||||
|
|
|
||||||
|
|
@ -428,12 +428,12 @@ class TestTokenUsageMiddleware:
|
||||||
"""TokenUsageMiddleware 测试。"""
|
"""TokenUsageMiddleware 测试。"""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_before_initializes_metadata(self) -> None:
|
async def test_before_returns_ctx_unchanged(self) -> None:
|
||||||
"""before 初始化 token_usage_start。"""
|
"""before 返回 ctx 不做修改(token 计量在 after 中完成)。"""
|
||||||
mw = TokenUsageMiddleware()
|
mw = TokenUsageMiddleware()
|
||||||
ctx = _make_ctx()
|
ctx = _make_ctx()
|
||||||
await mw.before(ctx)
|
result = await mw.before(ctx)
|
||||||
assert "token_usage_start" in ctx.metadata
|
assert result is ctx
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_after_extracts_usage_from_result(self) -> None:
|
async def test_after_extracts_usage_from_result(self) -> None:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue