From 567cbc9c9b8ce49299e8eb009047ca06698052b1 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Wed, 24 Jun 2026 22:35:52 +0800 Subject: [PATCH] refactor: simplify code across U1-U7 (bug fix + efficiency + reuse + quality) --- src/agentkit/chat/skill_routing.py | 46 +++++++++++++++---------- src/agentkit/core/middleware.py | 21 +++++------ src/agentkit/core/react.py | 22 +++++++----- src/agentkit/orchestrator/checkpoint.py | 29 +++++++--------- src/agentkit/skills/skill_detail.py | 17 +++------ tests/unit/test_middleware.py | 8 ++--- 6 files changed, 71 insertions(+), 72 deletions(-) diff --git a/src/agentkit/chat/skill_routing.py b/src/agentkit/chat/skill_routing.py index c0e2d4d..9972173 100644 --- a/src/agentkit/chat/skill_routing.py +++ b/src/agentkit/chat/skill_routing.py @@ -98,6 +98,31 @@ def parse_skill_prefix(content: str) -> tuple[str | None, str]: return explicit_skill, clean +_PROMPT_KEYS = ("identity", "context", "instructions", "constraints", "output_format") + + +def format_preconditions_block(preconditions: list[str], header_level: int = 2) -> str: + """格式化激活前置条件文本块。""" + header = "#" * header_level + lines = [f"{header} Activation Preconditions", "Before executing this skill, verify:"] + lines.extend(f"- {p}" for p in preconditions) + lines.append( + "If any precondition is not met, refuse to execute or ask the user for clarification." + ) + return "\n".join(lines) + + +def collect_prompt_parts(config: Any, with_headers: bool = False) -> list[str]: + """从 skill config 的 prompt 字典中收集各部分文本。""" + prompt = config.prompt or {} + parts: list[str] = [] + for key in _PROMPT_KEYS: + val = prompt.get(key) + if val: + parts.append(f"### {key.title()}\n{val}" if with_headers else val) + return parts + + def build_skill_system_prompt(skill_config) -> str | None: """Build system prompt from skill config's prompt section. @@ -120,12 +145,7 @@ def build_skill_system_prompt(skill_config) -> str | None: # 安全守卫:前置条件即使在概要模式下也必须注入 preconditions = getattr(skill_config, "preconditions", None) if preconditions: - lines = ["## Activation Preconditions", "Before executing this skill, verify:"] - lines.extend(f"- {p}" for p in preconditions) - lines.append( - "If any precondition is not met, refuse to execute or ask the user for clarification." - ) - return f"{summary}\n\n" + "\n".join(lines) + return f"{summary}\n\n" + format_preconditions_block(preconditions) # 提示 LLM 可通过 skill_detail 工具加载完整 instructions return ( f"{summary}\n\n" @@ -135,22 +155,12 @@ def build_skill_system_prompt(skill_config) -> str | None: # Level 1+: 全量加载(现有行为) if not skill_config.prompt: return None - prompt_parts = [] - for key in ("identity", "context", "instructions", "constraints", "output_format"): - val = skill_config.prompt.get(key) - if val: - prompt_parts.append(val) - base = "\n\n".join(prompt_parts) if prompt_parts else None + base = "\n\n".join(collect_prompt_parts(skill_config)) or None # v7: 注入激活前置条件(软检查) preconditions = getattr(skill_config, "preconditions", None) if preconditions: - lines = ["## Activation Preconditions", "Before executing this skill, verify:"] - lines.extend(f"- {p}" for p in preconditions) - lines.append( - "If any precondition is not met, refuse to execute or ask the user for clarification." - ) - preconditions_block = "\n".join(lines) + preconditions_block = format_preconditions_block(preconditions) return f"{base}\n\n{preconditions_block}" if base else preconditions_block return base diff --git a/src/agentkit/core/middleware.py b/src/agentkit/core/middleware.py index eb02dbc..3a48b90 100644 --- a/src/agentkit/core/middleware.py +++ b/src/agentkit/core/middleware.py @@ -90,15 +90,12 @@ class MiddlewareChain: return await handler(ctx) # before: 外 → 内 + # before 异常自然传播,不执行 after 链 executed_befores: list[Middleware] = [] current_ctx = ctx - try: - for mw in self._middlewares: - current_ctx = await mw.before(current_ctx) - executed_befores.append(mw) - except Exception: - # before 异常:不执行 after,直接传播 - raise + for mw in self._middlewares: + current_ctx = await mw.before(current_ctx) + executed_befores.append(mw) # handler result = await handler(current_ctx) @@ -152,12 +149,11 @@ class SummarizationMiddleware: class TokenUsageMiddleware: """Token 计量中间件 — 记录请求的 token 用量。 - before: 记录起始时间 + before: 无操作 after: 从 result 中提取 token usage,记录到 metadata """ async def before(self, ctx: RequestContext) -> RequestContext: - ctx.metadata["token_usage_start"] = ctx.metadata.get("token_usage_start", 0) return ctx async def after(self, ctx: RequestContext, result: Any) -> Any: @@ -191,17 +187,18 @@ class LoopDetectionMiddleware: if len(trajectory) < self._threshold: return result - # 检查最终 trajectory 中的重复工具调用模式 + # 检查最终 trajectory 中的重复工具调用模式(只取尾部窗口) + recent = trajectory[-self._window_size :] if trajectory else [] tool_calls = [ (step.get("tool_name", ""), step.get("arguments_hash", "")) - for step in trajectory + for step in recent if isinstance(step, dict) and "tool_name" in step ] if not tool_calls: return result # 滑动窗口检测 - window = tool_calls[-self._window_size :] + window = tool_calls unique = set(window) if len(unique) < len(window) and len(window) - len(unique) >= self._threshold - 1: logger.warning( diff --git a/src/agentkit/core/react.py b/src/agentkit/core/react.py index 22fdfc2..cbdd600 100644 --- a/src/agentkit/core/react.py +++ b/src/agentkit/core/react.py @@ -98,8 +98,7 @@ async def _ensure_async_iterable(obj: Any, label: str = ""): yield item return raise TypeError( - f"{label}: awaited value is not async iterable " - f"(got {type(resolved).__name__})" + f"{label}: awaited value is not async iterable (got {type(resolved).__name__})" ) # Case 3: anything else — surface a clear, actionable error rather @@ -218,19 +217,17 @@ class ReActEngine: ponytail: 精确 hash 匹配,不做语义相似度。 """ + hash_to_name: dict[str, str] = {} for tc in tool_calls: args_str = json.dumps(tc.arguments, sort_keys=True, default=str) - h = hash(f"{tc.name}:{args_str}") - self._loop_window.append(str(h)) + h = str(hash(f"{tc.name}:{args_str}")) + self._loop_window.append(h) + hash_to_name[h] = tc.name counts = Counter(self._loop_window) for h, count in counts.items(): if count >= self._loop_threshold: - # Find the tool name for this hash - for tc in tool_calls: - args_str = json.dumps(tc.arguments, sort_keys=True, default=str) - if str(hash(f"{tc.name}:{args_str}")) == h: - return tc.name + return hash_to_name.get(h) return None async def execute( @@ -1162,6 +1159,13 @@ class ReActEngine: ) # 记录 LLM 调用步骤 + if trace_recorder is not None: + trace_recorder.record_step( + step=step, + action="llm_call", + duration_ms=llm_duration_ms, + tokens_used=step_tokens, + ) # Record assistant message assistant_msg: dict[str, Any] = { diff --git a/src/agentkit/orchestrator/checkpoint.py b/src/agentkit/orchestrator/checkpoint.py index 4c115bf..a2433a4 100644 --- a/src/agentkit/orchestrator/checkpoint.py +++ b/src/agentkit/orchestrator/checkpoint.py @@ -116,13 +116,9 @@ class PipelineCheckpoint: # 尝试写入 Redis if self._redis is not None: try: - await self._redis.set( - self._plan_key(plan_id), json.dumps(plan_dict), ex=self._ttl - ) + await self._redis.set(self._plan_key(plan_id), json.dumps(plan_dict), ex=self._ttl) except Exception as e: - logger.warning( - f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}" - ) + logger.warning(f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}") async def load_plan(self, plan_id: str) -> dict[str, Any] | None: """加载完整 plan JSON。""" @@ -133,9 +129,7 @@ class PipelineCheckpoint: if raw: return json.loads(raw) except Exception as e: - logger.warning( - f"PipelineCheckpoint.load_plan Redis failed for plan {plan_id}: {e}" - ) + logger.warning(f"PipelineCheckpoint.load_plan Redis failed for plan {plan_id}: {e}") # 内存降级(检查 TTL 过期) entry = self._memory_plans.get(plan_id) if entry is None: @@ -217,14 +211,17 @@ class PipelineCheckpoint: if not phase_ids: # Redis 无数据,检查内存(过滤过期) return [ - c - for c in self._memory.get(plan_id, []) - if not self._is_expired(c.saved_at) + c for c in self._memory.get(plan_id, []) if not self._is_expired(c.saved_at) ] - results: list[CheckpointData] = [] + # 批量 GET(pipeline 避免 N+1 往返) + pipe = self._redis.pipeline() for phase_id in phase_ids: - raw = await self._redis.get(self._key(plan_id, phase_id)) + pipe.get(self._key(plan_id, phase_id)) + raws = await pipe.execute() + + results: list[CheckpointData] = [] + for raw in raws: if raw: data = json.loads(raw) results.append(CheckpointData.from_dict(data)) @@ -255,6 +252,4 @@ class PipelineCheckpoint: pipe.delete(self._plan_key(plan_id)) await pipe.execute() except Exception as e: - logger.warning( - f"PipelineCheckpoint.clear Redis failed for plan {plan_id}: {e}" - ) + logger.warning(f"PipelineCheckpoint.clear Redis failed for plan {plan_id}: {e}") diff --git a/src/agentkit/skills/skill_detail.py b/src/agentkit/skills/skill_detail.py index bcdf6ac..f2882be 100644 --- a/src/agentkit/skills/skill_detail.py +++ b/src/agentkit/skills/skill_detail.py @@ -14,6 +14,8 @@ from __future__ import annotations import logging from typing import Any +from agentkit.chat.skill_routing import collect_prompt_parts, format_preconditions_block +from agentkit.core.exceptions import SkillNotFoundError from agentkit.tools.base import Tool logger = logging.getLogger(__name__) @@ -91,7 +93,7 @@ class SkillDetailTool(Tool): try: skill = self._registry.get(query) return self._format_skill_full(skill) - except Exception: + except SkillNotFoundError: pass # 非精确匹配,降级到关键词搜索 # 关键词搜索:匹配 skill 名称和描述 @@ -118,21 +120,12 @@ class SkillDetailTool(Tool): def _format_skill_full(skill: Any) -> dict[str, Any]: """格式化 skill 的完整 instructions 供 LLM 使用。""" config = skill.config - prompt_parts: list[str] = [] - for key in ("identity", "context", "instructions", "constraints", "output_format"): - val = (config.prompt or {}).get(key) - if val: - prompt_parts.append(f"### {key.title()}\n{val}") + prompt_parts = collect_prompt_parts(config, with_headers=True) # v7: 注入激活前置条件(安全守卫,即使在渐进加载模式下也不应省略) preconditions = getattr(config, "preconditions", None) if preconditions: - lines = ["### Activation Preconditions", "Before executing this skill, verify:"] - lines.extend(f"- {p}" for p in preconditions) - lines.append( - "If any precondition is not met, refuse to execute or ask the user for clarification." - ) - prompt_parts.append("\n".join(lines)) + prompt_parts.append(format_preconditions_block(preconditions, header_level=3)) return { "name": config.name, diff --git a/tests/unit/test_middleware.py b/tests/unit/test_middleware.py index 2d68ce6..9dc6145 100644 --- a/tests/unit/test_middleware.py +++ b/tests/unit/test_middleware.py @@ -428,12 +428,12 @@ class TestTokenUsageMiddleware: """TokenUsageMiddleware 测试。""" @pytest.mark.asyncio - async def test_before_initializes_metadata(self) -> None: - """before 初始化 token_usage_start。""" + async def test_before_returns_ctx_unchanged(self) -> None: + """before 返回 ctx 不做修改(token 计量在 after 中完成)。""" mw = TokenUsageMiddleware() ctx = _make_ctx() - await mw.before(ctx) - assert "token_usage_start" in ctx.metadata + result = await mw.before(ctx) + assert result is ctx @pytest.mark.asyncio async def test_after_extracts_usage_from_result(self) -> None: