fix(review): address all 14 P2 advisory findings

This commit is contained in:
chiguyong 2026-06-06 18:20:46 +08:00
parent 8620751864
commit 0456429beb
15 changed files with 208 additions and 47 deletions

View File

@ -151,8 +151,8 @@ class ContextCompressor:
result = []
for msg in messages:
content = str(msg.get("content", ""))
if len(content) > self._max_tokens * 2:
msg = {**msg, "content": content[:self._max_tokens * 2] + "...[truncated]"}
if len(content) > self._max_tokens * 4:
msg = {**msg, "content": content[:self._max_tokens * 4] + "...[truncated]"}
result.append(msg)
return result

View File

@ -7,7 +7,7 @@
import time
import uuid
from dataclasses import dataclass, field
from typing import Any
from typing import Any, Callable
@dataclass
@ -83,12 +83,14 @@ class TraceRecorder:
task_id: str = "",
agent_name: str = "",
skill_name: str | None = None,
on_trace_complete: Callable[[ExecutionTrace], None] | None = None,
):
self._trace: ExecutionTrace | None = None
self._completed_trace: ExecutionTrace | None = None
self._completed: bool = False
self._step_start_time: float = 0
self._trace_start_time: float = 0
self._on_trace_complete = on_trace_complete
# 如果构造时提供了参数,自动 start_trace
if task_id:
self.start_trace(task_id=task_id, agent_name=agent_name, skill_name=skill_name)
@ -165,6 +167,10 @@ class TraceRecorder:
self._completed = True
self._completed_trace = result
self._trace = None
if self._on_trace_complete is not None:
self._on_trace_complete(result)
return result
def get_trace(self) -> ExecutionTrace | None:

View File

@ -162,6 +162,18 @@ class PersistentEvolutionStore:
loop = asyncio.get_running_loop()
return loop.run_in_executor(None, func)
async def close(self) -> None:
"""Dispose the SQLAlchemy engine, releasing all pooled connections."""
if self._engine is not None:
await self._run_sync(self._engine.dispose)
self._engine = None
async def __aenter__(self) -> "PersistentEvolutionStore":
return self
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
await self.close()
@staticmethod
def _retry_locked(func, *args, max_retries: int = 5, base_delay: float = 0.05, **kwargs):
"""Retry a function on SQLite 'database is locked' OperationalError."""

View File

@ -24,6 +24,11 @@ class LLMGateway:
self._providers[name] = provider
logger.info(f"LLM provider '{name}' registered")
@property
def has_providers(self) -> bool:
"""Return True if at least one LLM provider is registered."""
return bool(self._providers)
async def chat(
self,
messages: list[dict[str, str]],

View File

@ -1,5 +1,6 @@
"""Episodic Memory - 基于 pgvector + PostgreSQL 的任务经验记忆"""
import json
import logging
import math
from datetime import datetime, timezone
@ -53,7 +54,10 @@ class EpisodicMemory(Memory):
# 生成 embedding
embedding = None
if self._embedder:
text = f"{key} {value}"
if isinstance(value, dict):
text = value.get("output_summary", "") or value.get("input_summary", "") or json.dumps(value, ensure_ascii=False)[:500]
else:
text = str(value)
embedding = await self._embedder.embed(text)
entry = Model(
@ -131,8 +135,16 @@ class EpisodicMemory(Memory):
logger.error(f"Failed to retrieve episodic memory: {e}")
return None
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]:
"""语义检索相似历史案例"""
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None, search_multiplier: int = 5) -> list[MemoryItem]:
"""语义检索相似历史案例
Args:
query: 搜索查询文本
top_k: 返回的最大结果数
filters: 可选过滤条件agent_name, task_type, outcome
search_multiplier: 预取行数倍数fetch top_k * search_multiplier 行后再
排序截断当过滤条件较严格时可增大此值以避免漏掉相关条目
"""
async with self._session_factory() as db:
try:
Model = self._episodic_model
@ -149,7 +161,7 @@ class EpisodicMemory(Memory):
if filters.get("outcome"):
stmt = stmt.where(Model.outcome == filters["outcome"])
stmt = stmt.order_by(Model.created_at.desc()).limit(top_k * 5)
stmt = stmt.order_by(Model.created_at.desc()).limit(top_k * search_multiplier)
result = await db.execute(stmt)
entries = result.scalars().all()
@ -192,6 +204,12 @@ class EpisodicMemory(Memory):
))
items.sort(key=lambda x: x.score, reverse=True)
if len(items) < top_k:
logger.warning(
"EpisodicMemory.search returned %d results after scoring (top_k=%d). "
"Consider increasing search_multiplier (current=%d) to avoid missing relevant entries.",
len(items), top_k, search_multiplier,
)
return items[:top_k]
except Exception as e:

View File

@ -3,6 +3,7 @@
import hashlib
import json
import logging
import re
from typing import Any
from agentkit.prompts.section import PromptSection
@ -43,7 +44,7 @@ class PromptTemplate:
context = self._sections.context
if variables:
for key, value in variables.items():
context = context.replace(f"${{{key}}}", str(value))
context = re.sub(r'\$\{' + re.escape(key) + r'\}', str(value), context)
system_parts.append(context)
if self._sections.constraints:
system_parts.append(self._sections.constraints)
@ -53,7 +54,7 @@ class PromptTemplate:
instructions = self._sections.instructions
if variables:
for key, value in variables.items():
instructions = instructions.replace(f"${{{key}}}", str(value))
instructions = re.sub(r'\$\{' + re.escape(key) + r'\}', str(value), instructions)
user_parts.append(instructions)
if self._sections.output_format:
user_parts.append(self._sections.output_format)

View File

@ -52,7 +52,7 @@ async def health_check(request: Request):
if llm_gateway:
llm_status = "configured"
try:
if hasattr(llm_gateway, "_providers") and llm_gateway._providers:
if llm_gateway.has_providers:
llm_status = "available"
else:
llm_status = "no_providers"

View File

@ -1,7 +1,11 @@
"""Metrics route — /api/v1/metrics"""
import logging
from fastapi import APIRouter, Request
logger = logging.getLogger(__name__)
router = APIRouter(tags=["metrics"])
@ -25,36 +29,32 @@ async def get_metrics(request: Request):
task_metrics["completed_tasks"] = counts.get("completed", 0)
task_metrics["failed_tasks"] = counts.get("failed", 0)
task_metrics["pending_tasks"] = counts.get("pending", 0)
except Exception:
pass
except Exception as e:
logger.warning(f"Failed to collect task metrics: {e}")
# Agent pool metrics
agent_pool = getattr(app.state, "agent_pool", None)
agent_metrics: dict = {
"total_agents": 0,
"agent_names": [],
}
if agent_pool:
try:
agents = agent_pool.list_agents()
agent_metrics["total_agents"] = len(agents)
agent_metrics["agent_names"] = [a.get("name", "") for a in agents]
except Exception:
pass
except Exception as e:
logger.warning(f"Failed to collect agent metrics: {e}")
# Skill registry metrics
skill_registry = getattr(app.state, "skill_registry", None)
skill_metrics: dict = {
"total_skills": 0,
"skill_names": [],
}
if skill_registry:
try:
skills = skill_registry.list_skills()
skill_metrics["total_skills"] = len(skills)
skill_metrics["skill_names"] = [s.name for s in skills]
except Exception:
pass
except Exception as e:
logger.warning(f"Failed to collect skill metrics: {e}")
return {
"tasks": task_metrics,

View File

@ -1,5 +1,7 @@
"""Skill registration routes"""
import logging
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel
from typing import Any
@ -7,6 +9,8 @@ from typing import Any
from agentkit.skills.base import Skill, SkillConfig
from agentkit.skills.pipeline import SkillPipeline
logger = logging.getLogger(__name__)
router = APIRouter(tags=["skills"])
@ -111,6 +115,7 @@ async def execute_pipeline(name: str, request: ExecutePipelineRequest, req: Requ
try:
result = await pipeline.execute(input_data=request.input_data)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Pipeline execution failed: {e}")
logger.error(f"Pipeline execution failed for '{name}': {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Pipeline execution failed")
return result

View File

@ -182,6 +182,10 @@ class InMemoryTaskStore:
def size(self) -> int:
return len(self._tasks)
async def health_check(self) -> bool:
"""Verify the store is operational. Always returns True for in-memory backend."""
return True
# Backward-compatible alias
TaskStore = InMemoryTaskStore
@ -196,6 +200,7 @@ class RedisTaskStore:
"""
KEY_PREFIX = "agentkit:task:"
ZSET_KEY = "agentkit:tasks:by_time"
def __init__(
self,
@ -258,7 +263,9 @@ class RedisTaskStore:
skill_name=skill_name,
input_data=input_data,
)
score = record.created_at.timestamp()
await redis.set(self._key(task_id), json.dumps(record.to_dict()), ex=self._ttl_seconds)
await redis.zadd(self.ZSET_KEY, {task_id: score})
return record
async def get(self, task_id: str) -> TaskRecord | None:
@ -270,27 +277,45 @@ class RedisTaskStore:
return TaskRecord.from_dict(json.loads(raw))
# Lua script for atomic read-modify-write
# ARGV[1] = "1" to reset TTL (apply ex=ttl_seconds), "0" to keep existing TTL (KEEPTTL)
# ARGV[2] = ttl_seconds (only used when ARGV[1] == "1")
# ARGV[3] = number of merge fields
# ARGV[4..] = key/value pairs
_UPDATE_STATUS_SCRIPT = """
local reset_ttl = ARGV[1]
local ttl = tonumber(ARGV[2])
local n = tonumber(ARGV[3])
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local raw = redis.call('GET', key)
if raw == false then
return nil
end
local data = cjson.decode(raw)
local n = tonumber(ARGV[2])
for i = 1, n do
local k = ARGV[2 + 2 * (i - 1) + 1]
local v = ARGV[2 + 2 * (i - 1) + 2]
local k = ARGV[3 + 2 * (i - 1) + 1]
local v = ARGV[3 + 2 * (i - 1) + 2]
data[k] = v
end
local encoded = cjson.encode(data)
if reset_ttl == "1" then
redis.call('SET', key, encoded, 'EX', ttl)
else
redis.call('SET', key, encoded, 'KEEPTTL')
end
return encoded
"""
async def update_status(self, task_id: str, status: TaskStatus, **kwargs) -> TaskRecord:
"""Update task status and optional fields atomically via Lua script."""
async def update_status(self, task_id: str, status: TaskStatus, reset_ttl: bool = False, **kwargs) -> TaskRecord:
"""Update task status and optional fields atomically via Lua script.
Args:
task_id: Task identifier.
status: New task status.
reset_ttl: If True, reset the Redis TTL to ``ttl_seconds``. Defaults to
False so that frequent status updates on a long-running task do not
extend its lifetime indefinitely.
**kwargs: Optional fields to update (started_at, completed_at, etc.).
"""
redis = await self._get_redis()
key = self._key(task_id)
@ -304,7 +329,7 @@ return encoded
merge_fields[k] = value
# Flatten merge_fields into ARGV pairs
args = [str(self._ttl_seconds), str(len(merge_fields))]
args = ["1" if reset_ttl else "0", str(self._ttl_seconds), str(len(merge_fields))]
for k, v in merge_fields.items():
args.append(k)
args.append(json.dumps(v) if isinstance(v, (dict, list)) else str(v))
@ -360,10 +385,26 @@ return encoded
redis = await self._get_redis()
return await self._count_keys(redis)
async def health_check(self) -> bool:
"""Verify Redis connectivity by sending a PING command."""
try:
redis = await self._get_redis()
return await redis.ping()
except Exception:
return False
# ── helpers ────────────────────────────────────────────────
async def _count_keys(self, redis) -> int:
"""Count task keys using SCAN (avoid KEYS on large datasets)."""
"""Count task keys. Uses ZCARD on the sorted set for O(1) when
available, falls back to SCAN otherwise."""
try:
count = await redis.zcard(self.ZSET_KEY)
if count > 0:
return count
except Exception:
pass
# Fallback: full SCAN
count = 0
cursor = 0
while True:
@ -375,8 +416,32 @@ return encoded
async def _evict_oldest_completed(self, redis) -> bool:
"""Find and delete the oldest completed/failed/cancelled task.
Uses ZRANGE on the sorted set for O(log N) when available,
falls back to full SCAN otherwise.
Returns True if a record was evicted, False otherwise.
"""
# Try ZSET-based eviction first
try:
member_count = await redis.zcard(self.ZSET_KEY)
if member_count > 0:
# Iterate from oldest (lowest score) to find a completed task
task_ids = await redis.zrange(self.ZSET_KEY, 0, -1)
for tid in task_ids:
raw = await redis.get(self._key(tid))
if raw is None:
# Stale ZSET entry clean up
await redis.zrem(self.ZSET_KEY, tid)
continue
record = TaskRecord.from_dict(json.loads(raw))
if record.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED) and record.completed_at is not None:
await redis.delete(self._key(tid))
await redis.zrem(self.ZSET_KEY, tid)
return True
return False
except Exception:
pass
# Fallback: full SCAN
tasks: list[TaskRecord] = []
cursor = 0
while True:
@ -405,6 +470,10 @@ return encoded
return False
await redis.delete(self._key(oldest.task_id))
try:
await redis.zrem(self.ZSET_KEY, oldest.task_id)
except Exception:
pass
return True
@ -418,6 +487,11 @@ def create_task_store(
If ``backend="redis"`` and the Redis connection cannot be established,
falls back to :class:`InMemoryTaskStore` with a warning.
Note:
This factory only validates that the ``redis`` package is importable.
Runtime connectivity should be verified via ``await store.health_check()``
during application startup.
"""
if backend == "redis":
try:

View File

@ -7,6 +7,7 @@
"""
import logging
import re
from typing import Any, Callable, Coroutine
from agentkit.skills.base import Skill, SkillConfig
@ -161,19 +162,20 @@ class SkillPipeline:
- "key.path > 0.5" 数值大于
"""
try:
if "==" in condition:
path, value = condition.split("==", 1)
path = path.strip()
value = value.strip().strip("'\"")
eq_match = re.match(r'^([\w.]+)\s*==\s*(.+)$', condition.strip())
if eq_match:
path = eq_match.group(1)
value = eq_match.group(2).strip().strip("'\"")
actual = self._resolve_path(path, current_input)
return str(actual) == value
elif ">" in condition:
path, value = condition.split(">", 1)
path = path.strip()
value = float(value.strip())
gt_match = re.match(r'^([\w.]+)\s*>\s*(.+)$', condition.strip())
if gt_match:
path = gt_match.group(1)
value = float(gt_match.group(2).strip())
actual = float(self._resolve_path(path, current_input))
return actual > value
except Exception:
except (ValueError, TypeError, AttributeError, KeyError) as e:
logger.warning(f"Condition evaluation failed for '{condition}': {e}")
return False
return False

View File

@ -30,6 +30,12 @@ class SkillMdParser:
def parse(file_path: str) -> tuple[dict[str, Any], dict[str, str], str]:
"""解析 SKILL.md 文件
Note: Only H1 headings (# ) are treated as section delimiters.
H2+ headings (## , ### , etc.) are treated as regular content
and merged into their parent H1 section. This is by design
SKILL.md uses a flat section model where sub-structure within
a section is preserved as-is in the section body text.
Args:
file_path: SKILL.md 文件路径

View File

@ -156,7 +156,6 @@ class TestEpisodicMemoryStore:
mock_embedder.embed.assert_called_once()
call_args = mock_embedder.embed.call_args[0][0]
assert "key1" in call_args
assert "some value" in call_args
# 验证 entry 的 embedding 被设置

View File

@ -216,9 +216,7 @@ class TestMetricsEndpoint:
assert data["tasks"]["failed_tasks"] == 0
assert data["tasks"]["pending_tasks"] == 0
assert data["agents"]["total_agents"] == 0
assert data["agents"]["agent_names"] == []
assert data["skills"]["total_skills"] == 0
assert data["skills"]["skill_names"] == []
def test_metrics_with_registered_skill(self, client, skill_registry):
skill_config = SkillConfig(
@ -234,7 +232,6 @@ class TestMetricsEndpoint:
response = client.get("/api/v1/metrics")
data = response.json()
assert data["skills"]["total_skills"] == 1
assert "metrics_skill" in data["skills"]["skill_names"]
def test_metrics_version(self, client):
response = client.get("/api/v1/metrics")

View File

@ -26,6 +26,7 @@ class FakeRedis:
def __init__(self):
self._data: dict[str, str] = {}
self._zsets: dict[str, dict[str, float]] = {}
@classmethod
def from_url(cls, url, **kwargs):
@ -55,19 +56,54 @@ class FakeRedis:
async def close(self):
pass
async def ping(self):
return True
# ── Sorted-set operations ──────────────────────────────
async def zadd(self, name, mapping):
zs = self._zsets.setdefault(name, {})
added = 0
for member, score in mapping.items():
if member not in zs:
added += 1
zs[member] = score
return added
async def zcard(self, name):
return len(self._zsets.get(name, {}))
async def zrange(self, name, start, end):
zs = self._zsets.get(name, {})
# Sort by score, then by member for deterministic order
sorted_members = sorted(zs.keys(), key=lambda m: (zs[m], m))
if end == -1:
return sorted_members[start:]
return sorted_members[start : end + 1]
async def zrem(self, name, *members):
zs = self._zsets.get(name, {})
removed = 0
for m in members:
if m in zs:
del zs[m]
removed += 1
return removed
async def eval(self, script, numkeys, *args):
"""Simulate Redis EVAL for the update_status Lua script."""
# This implements the same logic as _UPDATE_STATUS_SCRIPT in RedisTaskStore
key = args[0]
ttl = int(args[1])
n = int(args[2])
reset_ttl = args[1]
ttl = int(args[2])
n = int(args[3])
raw = self._data.get(key)
if raw is None:
return None
data = json.loads(raw)
for i in range(n):
k = args[3 + 2 * i]
v = args[4 + 2 * i]
k = args[4 + 2 * i]
v = args[5 + 2 * i]
# Try to parse JSON values (dicts/lists), otherwise keep as string
try:
data[k] = json.loads(v)