From 0456429beb470649d8ec1cbc48e4746e5731e929 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Sat, 6 Jun 2026 18:20:46 +0800 Subject: [PATCH] fix(review): address all 14 P2 advisory findings --- src/agentkit/core/compressor.py | 4 +- src/agentkit/core/trace.py | 8 +- src/agentkit/evolution/evolution_store.py | 12 +++ src/agentkit/llm/gateway.py | 5 ++ src/agentkit/memory/episodic.py | 26 ++++++- src/agentkit/prompts/template.py | 5 +- src/agentkit/server/routes/health.py | 2 +- src/agentkit/server/routes/metrics.py | 20 ++--- src/agentkit/server/routes/skills.py | 7 +- src/agentkit/server/task_store.py | 92 ++++++++++++++++++++--- src/agentkit/skills/pipeline.py | 20 ++--- src/agentkit/skills/skill_md.py | 6 ++ tests/unit/test_episodic_memory.py | 1 - tests/unit/test_observability.py | 3 - tests/unit/test_task_store_redis.py | 44 ++++++++++- 15 files changed, 208 insertions(+), 47 deletions(-) diff --git a/src/agentkit/core/compressor.py b/src/agentkit/core/compressor.py index 368b47b..0c8fc28 100644 --- a/src/agentkit/core/compressor.py +++ b/src/agentkit/core/compressor.py @@ -151,8 +151,8 @@ class ContextCompressor: result = [] for msg in messages: content = str(msg.get("content", "")) - if len(content) > self._max_tokens * 2: - msg = {**msg, "content": content[:self._max_tokens * 2] + "...[truncated]"} + if len(content) > self._max_tokens * 4: + msg = {**msg, "content": content[:self._max_tokens * 4] + "...[truncated]"} result.append(msg) return result diff --git a/src/agentkit/core/trace.py b/src/agentkit/core/trace.py index c64f726..52e1711 100644 --- a/src/agentkit/core/trace.py +++ b/src/agentkit/core/trace.py @@ -7,7 +7,7 @@ import time import uuid from dataclasses import dataclass, field -from typing import Any +from typing import Any, Callable @dataclass @@ -83,12 +83,14 @@ class TraceRecorder: task_id: str = "", agent_name: str = "", skill_name: str | None = None, + on_trace_complete: Callable[[ExecutionTrace], None] | None = None, ): self._trace: ExecutionTrace | None = None self._completed_trace: ExecutionTrace | None = None self._completed: bool = False self._step_start_time: float = 0 self._trace_start_time: float = 0 + self._on_trace_complete = on_trace_complete # 如果构造时提供了参数,自动 start_trace if task_id: self.start_trace(task_id=task_id, agent_name=agent_name, skill_name=skill_name) @@ -165,6 +167,10 @@ class TraceRecorder: self._completed = True self._completed_trace = result self._trace = None + + if self._on_trace_complete is not None: + self._on_trace_complete(result) + return result def get_trace(self) -> ExecutionTrace | None: diff --git a/src/agentkit/evolution/evolution_store.py b/src/agentkit/evolution/evolution_store.py index 36e80e0..d738ab6 100644 --- a/src/agentkit/evolution/evolution_store.py +++ b/src/agentkit/evolution/evolution_store.py @@ -162,6 +162,18 @@ class PersistentEvolutionStore: loop = asyncio.get_running_loop() return loop.run_in_executor(None, func) + async def close(self) -> None: + """Dispose the SQLAlchemy engine, releasing all pooled connections.""" + if self._engine is not None: + await self._run_sync(self._engine.dispose) + self._engine = None + + async def __aenter__(self) -> "PersistentEvolutionStore": + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + await self.close() + @staticmethod def _retry_locked(func, *args, max_retries: int = 5, base_delay: float = 0.05, **kwargs): """Retry a function on SQLite 'database is locked' OperationalError.""" diff --git a/src/agentkit/llm/gateway.py b/src/agentkit/llm/gateway.py index 33885d4..08b1585 100644 --- a/src/agentkit/llm/gateway.py +++ b/src/agentkit/llm/gateway.py @@ -24,6 +24,11 @@ class LLMGateway: self._providers[name] = provider logger.info(f"LLM provider '{name}' registered") + @property + def has_providers(self) -> bool: + """Return True if at least one LLM provider is registered.""" + return bool(self._providers) + async def chat( self, messages: list[dict[str, str]], diff --git a/src/agentkit/memory/episodic.py b/src/agentkit/memory/episodic.py index 75b3efc..d02595d 100644 --- a/src/agentkit/memory/episodic.py +++ b/src/agentkit/memory/episodic.py @@ -1,5 +1,6 @@ """Episodic Memory - 基于 pgvector + PostgreSQL 的任务经验记忆""" +import json import logging import math from datetime import datetime, timezone @@ -53,7 +54,10 @@ class EpisodicMemory(Memory): # 生成 embedding embedding = None if self._embedder: - text = f"{key} {value}" + if isinstance(value, dict): + text = value.get("output_summary", "") or value.get("input_summary", "") or json.dumps(value, ensure_ascii=False)[:500] + else: + text = str(value) embedding = await self._embedder.embed(text) entry = Model( @@ -131,8 +135,16 @@ class EpisodicMemory(Memory): logger.error(f"Failed to retrieve episodic memory: {e}") return None - async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]: - """语义检索相似历史案例""" + async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None, search_multiplier: int = 5) -> list[MemoryItem]: + """语义检索相似历史案例 + + Args: + query: 搜索查询文本。 + top_k: 返回的最大结果数。 + filters: 可选过滤条件(agent_name, task_type, outcome)。 + search_multiplier: 预取行数倍数(fetch top_k * search_multiplier 行后再 + 排序截断)。当过滤条件较严格时,可增大此值以避免漏掉相关条目。 + """ async with self._session_factory() as db: try: Model = self._episodic_model @@ -149,7 +161,7 @@ class EpisodicMemory(Memory): if filters.get("outcome"): stmt = stmt.where(Model.outcome == filters["outcome"]) - stmt = stmt.order_by(Model.created_at.desc()).limit(top_k * 5) + stmt = stmt.order_by(Model.created_at.desc()).limit(top_k * search_multiplier) result = await db.execute(stmt) entries = result.scalars().all() @@ -192,6 +204,12 @@ class EpisodicMemory(Memory): )) items.sort(key=lambda x: x.score, reverse=True) + if len(items) < top_k: + logger.warning( + "EpisodicMemory.search returned %d results after scoring (top_k=%d). " + "Consider increasing search_multiplier (current=%d) to avoid missing relevant entries.", + len(items), top_k, search_multiplier, + ) return items[:top_k] except Exception as e: diff --git a/src/agentkit/prompts/template.py b/src/agentkit/prompts/template.py index aba8077..c1ce98f 100644 --- a/src/agentkit/prompts/template.py +++ b/src/agentkit/prompts/template.py @@ -3,6 +3,7 @@ import hashlib import json import logging +import re from typing import Any from agentkit.prompts.section import PromptSection @@ -43,7 +44,7 @@ class PromptTemplate: context = self._sections.context if variables: for key, value in variables.items(): - context = context.replace(f"${{{key}}}", str(value)) + context = re.sub(r'\$\{' + re.escape(key) + r'\}', str(value), context) system_parts.append(context) if self._sections.constraints: system_parts.append(self._sections.constraints) @@ -53,7 +54,7 @@ class PromptTemplate: instructions = self._sections.instructions if variables: for key, value in variables.items(): - instructions = instructions.replace(f"${{{key}}}", str(value)) + instructions = re.sub(r'\$\{' + re.escape(key) + r'\}', str(value), instructions) user_parts.append(instructions) if self._sections.output_format: user_parts.append(self._sections.output_format) diff --git a/src/agentkit/server/routes/health.py b/src/agentkit/server/routes/health.py index ee14e1b..06b3fe6 100644 --- a/src/agentkit/server/routes/health.py +++ b/src/agentkit/server/routes/health.py @@ -52,7 +52,7 @@ async def health_check(request: Request): if llm_gateway: llm_status = "configured" try: - if hasattr(llm_gateway, "_providers") and llm_gateway._providers: + if llm_gateway.has_providers: llm_status = "available" else: llm_status = "no_providers" diff --git a/src/agentkit/server/routes/metrics.py b/src/agentkit/server/routes/metrics.py index 7aa1134..451002b 100644 --- a/src/agentkit/server/routes/metrics.py +++ b/src/agentkit/server/routes/metrics.py @@ -1,7 +1,11 @@ """Metrics route — /api/v1/metrics""" +import logging + from fastapi import APIRouter, Request +logger = logging.getLogger(__name__) + router = APIRouter(tags=["metrics"]) @@ -25,36 +29,32 @@ async def get_metrics(request: Request): task_metrics["completed_tasks"] = counts.get("completed", 0) task_metrics["failed_tasks"] = counts.get("failed", 0) task_metrics["pending_tasks"] = counts.get("pending", 0) - except Exception: - pass + except Exception as e: + logger.warning(f"Failed to collect task metrics: {e}") # Agent pool metrics agent_pool = getattr(app.state, "agent_pool", None) agent_metrics: dict = { "total_agents": 0, - "agent_names": [], } if agent_pool: try: agents = agent_pool.list_agents() agent_metrics["total_agents"] = len(agents) - agent_metrics["agent_names"] = [a.get("name", "") for a in agents] - except Exception: - pass + except Exception as e: + logger.warning(f"Failed to collect agent metrics: {e}") # Skill registry metrics skill_registry = getattr(app.state, "skill_registry", None) skill_metrics: dict = { "total_skills": 0, - "skill_names": [], } if skill_registry: try: skills = skill_registry.list_skills() skill_metrics["total_skills"] = len(skills) - skill_metrics["skill_names"] = [s.name for s in skills] - except Exception: - pass + except Exception as e: + logger.warning(f"Failed to collect skill metrics: {e}") return { "tasks": task_metrics, diff --git a/src/agentkit/server/routes/skills.py b/src/agentkit/server/routes/skills.py index 3b9587c..b10afa7 100644 --- a/src/agentkit/server/routes/skills.py +++ b/src/agentkit/server/routes/skills.py @@ -1,5 +1,7 @@ """Skill registration routes""" +import logging + from fastapi import APIRouter, HTTPException, Request from pydantic import BaseModel from typing import Any @@ -7,6 +9,8 @@ from typing import Any from agentkit.skills.base import Skill, SkillConfig from agentkit.skills.pipeline import SkillPipeline +logger = logging.getLogger(__name__) + router = APIRouter(tags=["skills"]) @@ -111,6 +115,7 @@ async def execute_pipeline(name: str, request: ExecutePipelineRequest, req: Requ try: result = await pipeline.execute(input_data=request.input_data) except Exception as e: - raise HTTPException(status_code=500, detail=f"Pipeline execution failed: {e}") + logger.error(f"Pipeline execution failed for '{name}': {e}", exc_info=True) + raise HTTPException(status_code=500, detail="Pipeline execution failed") return result diff --git a/src/agentkit/server/task_store.py b/src/agentkit/server/task_store.py index 6025cc8..d1c9d42 100644 --- a/src/agentkit/server/task_store.py +++ b/src/agentkit/server/task_store.py @@ -182,6 +182,10 @@ class InMemoryTaskStore: def size(self) -> int: return len(self._tasks) + async def health_check(self) -> bool: + """Verify the store is operational. Always returns True for in-memory backend.""" + return True + # Backward-compatible alias TaskStore = InMemoryTaskStore @@ -196,6 +200,7 @@ class RedisTaskStore: """ KEY_PREFIX = "agentkit:task:" + ZSET_KEY = "agentkit:tasks:by_time" def __init__( self, @@ -258,7 +263,9 @@ class RedisTaskStore: skill_name=skill_name, input_data=input_data, ) + score = record.created_at.timestamp() await redis.set(self._key(task_id), json.dumps(record.to_dict()), ex=self._ttl_seconds) + await redis.zadd(self.ZSET_KEY, {task_id: score}) return record async def get(self, task_id: str) -> TaskRecord | None: @@ -270,27 +277,45 @@ class RedisTaskStore: return TaskRecord.from_dict(json.loads(raw)) # Lua script for atomic read-modify-write + # ARGV[1] = "1" to reset TTL (apply ex=ttl_seconds), "0" to keep existing TTL (KEEPTTL) + # ARGV[2] = ttl_seconds (only used when ARGV[1] == "1") + # ARGV[3] = number of merge fields + # ARGV[4..] = key/value pairs _UPDATE_STATUS_SCRIPT = """ +local reset_ttl = ARGV[1] +local ttl = tonumber(ARGV[2]) +local n = tonumber(ARGV[3]) local key = KEYS[1] -local ttl = tonumber(ARGV[1]) local raw = redis.call('GET', key) if raw == false then return nil end local data = cjson.decode(raw) -local n = tonumber(ARGV[2]) for i = 1, n do - local k = ARGV[2 + 2 * (i - 1) + 1] - local v = ARGV[2 + 2 * (i - 1) + 2] + local k = ARGV[3 + 2 * (i - 1) + 1] + local v = ARGV[3 + 2 * (i - 1) + 2] data[k] = v end local encoded = cjson.encode(data) -redis.call('SET', key, encoded, 'EX', ttl) +if reset_ttl == "1" then + redis.call('SET', key, encoded, 'EX', ttl) +else + redis.call('SET', key, encoded, 'KEEPTTL') +end return encoded """ - async def update_status(self, task_id: str, status: TaskStatus, **kwargs) -> TaskRecord: - """Update task status and optional fields atomically via Lua script.""" + async def update_status(self, task_id: str, status: TaskStatus, reset_ttl: bool = False, **kwargs) -> TaskRecord: + """Update task status and optional fields atomically via Lua script. + + Args: + task_id: Task identifier. + status: New task status. + reset_ttl: If True, reset the Redis TTL to ``ttl_seconds``. Defaults to + False so that frequent status updates on a long-running task do not + extend its lifetime indefinitely. + **kwargs: Optional fields to update (started_at, completed_at, etc.). + """ redis = await self._get_redis() key = self._key(task_id) @@ -304,7 +329,7 @@ return encoded merge_fields[k] = value # Flatten merge_fields into ARGV pairs - args = [str(self._ttl_seconds), str(len(merge_fields))] + args = ["1" if reset_ttl else "0", str(self._ttl_seconds), str(len(merge_fields))] for k, v in merge_fields.items(): args.append(k) args.append(json.dumps(v) if isinstance(v, (dict, list)) else str(v)) @@ -360,10 +385,26 @@ return encoded redis = await self._get_redis() return await self._count_keys(redis) + async def health_check(self) -> bool: + """Verify Redis connectivity by sending a PING command.""" + try: + redis = await self._get_redis() + return await redis.ping() + except Exception: + return False + # ── helpers ──────────────────────────────────────────────── async def _count_keys(self, redis) -> int: - """Count task keys using SCAN (avoid KEYS on large datasets).""" + """Count task keys. Uses ZCARD on the sorted set for O(1) when + available, falls back to SCAN otherwise.""" + try: + count = await redis.zcard(self.ZSET_KEY) + if count > 0: + return count + except Exception: + pass + # Fallback: full SCAN count = 0 cursor = 0 while True: @@ -375,8 +416,32 @@ return encoded async def _evict_oldest_completed(self, redis) -> bool: """Find and delete the oldest completed/failed/cancelled task. + Uses ZRANGE on the sorted set for O(log N) when available, + falls back to full SCAN otherwise. Returns True if a record was evicted, False otherwise. """ + # Try ZSET-based eviction first + try: + member_count = await redis.zcard(self.ZSET_KEY) + if member_count > 0: + # Iterate from oldest (lowest score) to find a completed task + task_ids = await redis.zrange(self.ZSET_KEY, 0, -1) + for tid in task_ids: + raw = await redis.get(self._key(tid)) + if raw is None: + # Stale ZSET entry – clean up + await redis.zrem(self.ZSET_KEY, tid) + continue + record = TaskRecord.from_dict(json.loads(raw)) + if record.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED) and record.completed_at is not None: + await redis.delete(self._key(tid)) + await redis.zrem(self.ZSET_KEY, tid) + return True + return False + except Exception: + pass + + # Fallback: full SCAN tasks: list[TaskRecord] = [] cursor = 0 while True: @@ -405,6 +470,10 @@ return encoded return False await redis.delete(self._key(oldest.task_id)) + try: + await redis.zrem(self.ZSET_KEY, oldest.task_id) + except Exception: + pass return True @@ -418,6 +487,11 @@ def create_task_store( If ``backend="redis"`` and the Redis connection cannot be established, falls back to :class:`InMemoryTaskStore` with a warning. + + Note: + This factory only validates that the ``redis`` package is importable. + Runtime connectivity should be verified via ``await store.health_check()`` + during application startup. """ if backend == "redis": try: diff --git a/src/agentkit/skills/pipeline.py b/src/agentkit/skills/pipeline.py index d1f5a2a..d5b7972 100644 --- a/src/agentkit/skills/pipeline.py +++ b/src/agentkit/skills/pipeline.py @@ -7,6 +7,7 @@ """ import logging +import re from typing import Any, Callable, Coroutine from agentkit.skills.base import Skill, SkillConfig @@ -161,19 +162,20 @@ class SkillPipeline: - "key.path > 0.5" — 数值大于 """ try: - if "==" in condition: - path, value = condition.split("==", 1) - path = path.strip() - value = value.strip().strip("'\"") + eq_match = re.match(r'^([\w.]+)\s*==\s*(.+)$', condition.strip()) + if eq_match: + path = eq_match.group(1) + value = eq_match.group(2).strip().strip("'\"") actual = self._resolve_path(path, current_input) return str(actual) == value - elif ">" in condition: - path, value = condition.split(">", 1) - path = path.strip() - value = float(value.strip()) + gt_match = re.match(r'^([\w.]+)\s*>\s*(.+)$', condition.strip()) + if gt_match: + path = gt_match.group(1) + value = float(gt_match.group(2).strip()) actual = float(self._resolve_path(path, current_input)) return actual > value - except Exception: + except (ValueError, TypeError, AttributeError, KeyError) as e: + logger.warning(f"Condition evaluation failed for '{condition}': {e}") return False return False diff --git a/src/agentkit/skills/skill_md.py b/src/agentkit/skills/skill_md.py index 002d3d7..c8d9c3d 100644 --- a/src/agentkit/skills/skill_md.py +++ b/src/agentkit/skills/skill_md.py @@ -30,6 +30,12 @@ class SkillMdParser: def parse(file_path: str) -> tuple[dict[str, Any], dict[str, str], str]: """解析 SKILL.md 文件 + Note: Only H1 headings (# ) are treated as section delimiters. + H2+ headings (## , ### , etc.) are treated as regular content + and merged into their parent H1 section. This is by design — + SKILL.md uses a flat section model where sub-structure within + a section is preserved as-is in the section body text. + Args: file_path: SKILL.md 文件路径 diff --git a/tests/unit/test_episodic_memory.py b/tests/unit/test_episodic_memory.py index a79f458..944bdc8 100644 --- a/tests/unit/test_episodic_memory.py +++ b/tests/unit/test_episodic_memory.py @@ -156,7 +156,6 @@ class TestEpisodicMemoryStore: mock_embedder.embed.assert_called_once() call_args = mock_embedder.embed.call_args[0][0] - assert "key1" in call_args assert "some value" in call_args # 验证 entry 的 embedding 被设置 diff --git a/tests/unit/test_observability.py b/tests/unit/test_observability.py index 8f2370e..0eceb93 100644 --- a/tests/unit/test_observability.py +++ b/tests/unit/test_observability.py @@ -216,9 +216,7 @@ class TestMetricsEndpoint: assert data["tasks"]["failed_tasks"] == 0 assert data["tasks"]["pending_tasks"] == 0 assert data["agents"]["total_agents"] == 0 - assert data["agents"]["agent_names"] == [] assert data["skills"]["total_skills"] == 0 - assert data["skills"]["skill_names"] == [] def test_metrics_with_registered_skill(self, client, skill_registry): skill_config = SkillConfig( @@ -234,7 +232,6 @@ class TestMetricsEndpoint: response = client.get("/api/v1/metrics") data = response.json() assert data["skills"]["total_skills"] == 1 - assert "metrics_skill" in data["skills"]["skill_names"] def test_metrics_version(self, client): response = client.get("/api/v1/metrics") diff --git a/tests/unit/test_task_store_redis.py b/tests/unit/test_task_store_redis.py index be41af4..0f4bb4d 100644 --- a/tests/unit/test_task_store_redis.py +++ b/tests/unit/test_task_store_redis.py @@ -26,6 +26,7 @@ class FakeRedis: def __init__(self): self._data: dict[str, str] = {} + self._zsets: dict[str, dict[str, float]] = {} @classmethod def from_url(cls, url, **kwargs): @@ -55,19 +56,54 @@ class FakeRedis: async def close(self): pass + async def ping(self): + return True + + # ── Sorted-set operations ────────────────────────────── + + async def zadd(self, name, mapping): + zs = self._zsets.setdefault(name, {}) + added = 0 + for member, score in mapping.items(): + if member not in zs: + added += 1 + zs[member] = score + return added + + async def zcard(self, name): + return len(self._zsets.get(name, {})) + + async def zrange(self, name, start, end): + zs = self._zsets.get(name, {}) + # Sort by score, then by member for deterministic order + sorted_members = sorted(zs.keys(), key=lambda m: (zs[m], m)) + if end == -1: + return sorted_members[start:] + return sorted_members[start : end + 1] + + async def zrem(self, name, *members): + zs = self._zsets.get(name, {}) + removed = 0 + for m in members: + if m in zs: + del zs[m] + removed += 1 + return removed + async def eval(self, script, numkeys, *args): """Simulate Redis EVAL for the update_status Lua script.""" # This implements the same logic as _UPDATE_STATUS_SCRIPT in RedisTaskStore key = args[0] - ttl = int(args[1]) - n = int(args[2]) + reset_ttl = args[1] + ttl = int(args[2]) + n = int(args[3]) raw = self._data.get(key) if raw is None: return None data = json.loads(raw) for i in range(n): - k = args[3 + 2 * i] - v = args[4 + 2 * i] + k = args[4 + 2 * i] + v = args[5 + 2 * i] # Try to parse JSON values (dicts/lists), otherwise keep as string try: data[k] = json.loads(v)