fix(review): address all 14 P2 advisory findings
This commit is contained in:
parent
8620751864
commit
0456429beb
|
|
@ -151,8 +151,8 @@ class ContextCompressor:
|
||||||
result = []
|
result = []
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
content = str(msg.get("content", ""))
|
content = str(msg.get("content", ""))
|
||||||
if len(content) > self._max_tokens * 2:
|
if len(content) > self._max_tokens * 4:
|
||||||
msg = {**msg, "content": content[:self._max_tokens * 2] + "...[truncated]"}
|
msg = {**msg, "content": content[:self._max_tokens * 4] + "...[truncated]"}
|
||||||
result.append(msg)
|
result.append(msg)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -83,12 +83,14 @@ class TraceRecorder:
|
||||||
task_id: str = "",
|
task_id: str = "",
|
||||||
agent_name: str = "",
|
agent_name: str = "",
|
||||||
skill_name: str | None = None,
|
skill_name: str | None = None,
|
||||||
|
on_trace_complete: Callable[[ExecutionTrace], None] | None = None,
|
||||||
):
|
):
|
||||||
self._trace: ExecutionTrace | None = None
|
self._trace: ExecutionTrace | None = None
|
||||||
self._completed_trace: ExecutionTrace | None = None
|
self._completed_trace: ExecutionTrace | None = None
|
||||||
self._completed: bool = False
|
self._completed: bool = False
|
||||||
self._step_start_time: float = 0
|
self._step_start_time: float = 0
|
||||||
self._trace_start_time: float = 0
|
self._trace_start_time: float = 0
|
||||||
|
self._on_trace_complete = on_trace_complete
|
||||||
# 如果构造时提供了参数,自动 start_trace
|
# 如果构造时提供了参数,自动 start_trace
|
||||||
if task_id:
|
if task_id:
|
||||||
self.start_trace(task_id=task_id, agent_name=agent_name, skill_name=skill_name)
|
self.start_trace(task_id=task_id, agent_name=agent_name, skill_name=skill_name)
|
||||||
|
|
@ -165,6 +167,10 @@ class TraceRecorder:
|
||||||
self._completed = True
|
self._completed = True
|
||||||
self._completed_trace = result
|
self._completed_trace = result
|
||||||
self._trace = None
|
self._trace = None
|
||||||
|
|
||||||
|
if self._on_trace_complete is not None:
|
||||||
|
self._on_trace_complete(result)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_trace(self) -> ExecutionTrace | None:
|
def get_trace(self) -> ExecutionTrace | None:
|
||||||
|
|
|
||||||
|
|
@ -162,6 +162,18 @@ class PersistentEvolutionStore:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
return loop.run_in_executor(None, func)
|
return loop.run_in_executor(None, func)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Dispose the SQLAlchemy engine, releasing all pooled connections."""
|
||||||
|
if self._engine is not None:
|
||||||
|
await self._run_sync(self._engine.dispose)
|
||||||
|
self._engine = None
|
||||||
|
|
||||||
|
async def __aenter__(self) -> "PersistentEvolutionStore":
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||||
|
await self.close()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _retry_locked(func, *args, max_retries: int = 5, base_delay: float = 0.05, **kwargs):
|
def _retry_locked(func, *args, max_retries: int = 5, base_delay: float = 0.05, **kwargs):
|
||||||
"""Retry a function on SQLite 'database is locked' OperationalError."""
|
"""Retry a function on SQLite 'database is locked' OperationalError."""
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,11 @@ class LLMGateway:
|
||||||
self._providers[name] = provider
|
self._providers[name] = provider
|
||||||
logger.info(f"LLM provider '{name}' registered")
|
logger.info(f"LLM provider '{name}' registered")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_providers(self) -> bool:
|
||||||
|
"""Return True if at least one LLM provider is registered."""
|
||||||
|
return bool(self._providers)
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
messages: list[dict[str, str]],
|
messages: list[dict[str, str]],
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
"""Episodic Memory - 基于 pgvector + PostgreSQL 的任务经验记忆"""
|
"""Episodic Memory - 基于 pgvector + PostgreSQL 的任务经验记忆"""
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
@ -53,7 +54,10 @@ class EpisodicMemory(Memory):
|
||||||
# 生成 embedding
|
# 生成 embedding
|
||||||
embedding = None
|
embedding = None
|
||||||
if self._embedder:
|
if self._embedder:
|
||||||
text = f"{key} {value}"
|
if isinstance(value, dict):
|
||||||
|
text = value.get("output_summary", "") or value.get("input_summary", "") or json.dumps(value, ensure_ascii=False)[:500]
|
||||||
|
else:
|
||||||
|
text = str(value)
|
||||||
embedding = await self._embedder.embed(text)
|
embedding = await self._embedder.embed(text)
|
||||||
|
|
||||||
entry = Model(
|
entry = Model(
|
||||||
|
|
@ -131,8 +135,16 @@ class EpisodicMemory(Memory):
|
||||||
logger.error(f"Failed to retrieve episodic memory: {e}")
|
logger.error(f"Failed to retrieve episodic memory: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]:
|
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None, search_multiplier: int = 5) -> list[MemoryItem]:
|
||||||
"""语义检索相似历史案例"""
|
"""语义检索相似历史案例
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 搜索查询文本。
|
||||||
|
top_k: 返回的最大结果数。
|
||||||
|
filters: 可选过滤条件(agent_name, task_type, outcome)。
|
||||||
|
search_multiplier: 预取行数倍数(fetch top_k * search_multiplier 行后再
|
||||||
|
排序截断)。当过滤条件较严格时,可增大此值以避免漏掉相关条目。
|
||||||
|
"""
|
||||||
async with self._session_factory() as db:
|
async with self._session_factory() as db:
|
||||||
try:
|
try:
|
||||||
Model = self._episodic_model
|
Model = self._episodic_model
|
||||||
|
|
@ -149,7 +161,7 @@ class EpisodicMemory(Memory):
|
||||||
if filters.get("outcome"):
|
if filters.get("outcome"):
|
||||||
stmt = stmt.where(Model.outcome == filters["outcome"])
|
stmt = stmt.where(Model.outcome == filters["outcome"])
|
||||||
|
|
||||||
stmt = stmt.order_by(Model.created_at.desc()).limit(top_k * 5)
|
stmt = stmt.order_by(Model.created_at.desc()).limit(top_k * search_multiplier)
|
||||||
|
|
||||||
result = await db.execute(stmt)
|
result = await db.execute(stmt)
|
||||||
entries = result.scalars().all()
|
entries = result.scalars().all()
|
||||||
|
|
@ -192,6 +204,12 @@ class EpisodicMemory(Memory):
|
||||||
))
|
))
|
||||||
|
|
||||||
items.sort(key=lambda x: x.score, reverse=True)
|
items.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
if len(items) < top_k:
|
||||||
|
logger.warning(
|
||||||
|
"EpisodicMemory.search returned %d results after scoring (top_k=%d). "
|
||||||
|
"Consider increasing search_multiplier (current=%d) to avoid missing relevant entries.",
|
||||||
|
len(items), top_k, search_multiplier,
|
||||||
|
)
|
||||||
return items[:top_k]
|
return items[:top_k]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from agentkit.prompts.section import PromptSection
|
from agentkit.prompts.section import PromptSection
|
||||||
|
|
@ -43,7 +44,7 @@ class PromptTemplate:
|
||||||
context = self._sections.context
|
context = self._sections.context
|
||||||
if variables:
|
if variables:
|
||||||
for key, value in variables.items():
|
for key, value in variables.items():
|
||||||
context = context.replace(f"${{{key}}}", str(value))
|
context = re.sub(r'\$\{' + re.escape(key) + r'\}', str(value), context)
|
||||||
system_parts.append(context)
|
system_parts.append(context)
|
||||||
if self._sections.constraints:
|
if self._sections.constraints:
|
||||||
system_parts.append(self._sections.constraints)
|
system_parts.append(self._sections.constraints)
|
||||||
|
|
@ -53,7 +54,7 @@ class PromptTemplate:
|
||||||
instructions = self._sections.instructions
|
instructions = self._sections.instructions
|
||||||
if variables:
|
if variables:
|
||||||
for key, value in variables.items():
|
for key, value in variables.items():
|
||||||
instructions = instructions.replace(f"${{{key}}}", str(value))
|
instructions = re.sub(r'\$\{' + re.escape(key) + r'\}', str(value), instructions)
|
||||||
user_parts.append(instructions)
|
user_parts.append(instructions)
|
||||||
if self._sections.output_format:
|
if self._sections.output_format:
|
||||||
user_parts.append(self._sections.output_format)
|
user_parts.append(self._sections.output_format)
|
||||||
|
|
|
||||||
|
|
@ -52,7 +52,7 @@ async def health_check(request: Request):
|
||||||
if llm_gateway:
|
if llm_gateway:
|
||||||
llm_status = "configured"
|
llm_status = "configured"
|
||||||
try:
|
try:
|
||||||
if hasattr(llm_gateway, "_providers") and llm_gateway._providers:
|
if llm_gateway.has_providers:
|
||||||
llm_status = "available"
|
llm_status = "available"
|
||||||
else:
|
else:
|
||||||
llm_status = "no_providers"
|
llm_status = "no_providers"
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,11 @@
|
||||||
"""Metrics route — /api/v1/metrics"""
|
"""Metrics route — /api/v1/metrics"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
from fastapi import APIRouter, Request
|
from fastapi import APIRouter, Request
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(tags=["metrics"])
|
router = APIRouter(tags=["metrics"])
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -25,36 +29,32 @@ async def get_metrics(request: Request):
|
||||||
task_metrics["completed_tasks"] = counts.get("completed", 0)
|
task_metrics["completed_tasks"] = counts.get("completed", 0)
|
||||||
task_metrics["failed_tasks"] = counts.get("failed", 0)
|
task_metrics["failed_tasks"] = counts.get("failed", 0)
|
||||||
task_metrics["pending_tasks"] = counts.get("pending", 0)
|
task_metrics["pending_tasks"] = counts.get("pending", 0)
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logger.warning(f"Failed to collect task metrics: {e}")
|
||||||
|
|
||||||
# Agent pool metrics
|
# Agent pool metrics
|
||||||
agent_pool = getattr(app.state, "agent_pool", None)
|
agent_pool = getattr(app.state, "agent_pool", None)
|
||||||
agent_metrics: dict = {
|
agent_metrics: dict = {
|
||||||
"total_agents": 0,
|
"total_agents": 0,
|
||||||
"agent_names": [],
|
|
||||||
}
|
}
|
||||||
if agent_pool:
|
if agent_pool:
|
||||||
try:
|
try:
|
||||||
agents = agent_pool.list_agents()
|
agents = agent_pool.list_agents()
|
||||||
agent_metrics["total_agents"] = len(agents)
|
agent_metrics["total_agents"] = len(agents)
|
||||||
agent_metrics["agent_names"] = [a.get("name", "") for a in agents]
|
except Exception as e:
|
||||||
except Exception:
|
logger.warning(f"Failed to collect agent metrics: {e}")
|
||||||
pass
|
|
||||||
|
|
||||||
# Skill registry metrics
|
# Skill registry metrics
|
||||||
skill_registry = getattr(app.state, "skill_registry", None)
|
skill_registry = getattr(app.state, "skill_registry", None)
|
||||||
skill_metrics: dict = {
|
skill_metrics: dict = {
|
||||||
"total_skills": 0,
|
"total_skills": 0,
|
||||||
"skill_names": [],
|
|
||||||
}
|
}
|
||||||
if skill_registry:
|
if skill_registry:
|
||||||
try:
|
try:
|
||||||
skills = skill_registry.list_skills()
|
skills = skill_registry.list_skills()
|
||||||
skill_metrics["total_skills"] = len(skills)
|
skill_metrics["total_skills"] = len(skills)
|
||||||
skill_metrics["skill_names"] = [s.name for s in skills]
|
except Exception as e:
|
||||||
except Exception:
|
logger.warning(f"Failed to collect skill metrics: {e}")
|
||||||
pass
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"tasks": task_metrics,
|
"tasks": task_metrics,
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
"""Skill registration routes"""
|
"""Skill registration routes"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
@ -7,6 +9,8 @@ from typing import Any
|
||||||
from agentkit.skills.base import Skill, SkillConfig
|
from agentkit.skills.base import Skill, SkillConfig
|
||||||
from agentkit.skills.pipeline import SkillPipeline
|
from agentkit.skills.pipeline import SkillPipeline
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(tags=["skills"])
|
router = APIRouter(tags=["skills"])
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -111,6 +115,7 @@ async def execute_pipeline(name: str, request: ExecutePipelineRequest, req: Requ
|
||||||
try:
|
try:
|
||||||
result = await pipeline.execute(input_data=request.input_data)
|
result = await pipeline.execute(input_data=request.input_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"Pipeline execution failed: {e}")
|
logger.error(f"Pipeline execution failed for '{name}': {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail="Pipeline execution failed")
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
|
||||||
|
|
@ -182,6 +182,10 @@ class InMemoryTaskStore:
|
||||||
def size(self) -> int:
|
def size(self) -> int:
|
||||||
return len(self._tasks)
|
return len(self._tasks)
|
||||||
|
|
||||||
|
async def health_check(self) -> bool:
|
||||||
|
"""Verify the store is operational. Always returns True for in-memory backend."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
# Backward-compatible alias
|
# Backward-compatible alias
|
||||||
TaskStore = InMemoryTaskStore
|
TaskStore = InMemoryTaskStore
|
||||||
|
|
@ -196,6 +200,7 @@ class RedisTaskStore:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
KEY_PREFIX = "agentkit:task:"
|
KEY_PREFIX = "agentkit:task:"
|
||||||
|
ZSET_KEY = "agentkit:tasks:by_time"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -258,7 +263,9 @@ class RedisTaskStore:
|
||||||
skill_name=skill_name,
|
skill_name=skill_name,
|
||||||
input_data=input_data,
|
input_data=input_data,
|
||||||
)
|
)
|
||||||
|
score = record.created_at.timestamp()
|
||||||
await redis.set(self._key(task_id), json.dumps(record.to_dict()), ex=self._ttl_seconds)
|
await redis.set(self._key(task_id), json.dumps(record.to_dict()), ex=self._ttl_seconds)
|
||||||
|
await redis.zadd(self.ZSET_KEY, {task_id: score})
|
||||||
return record
|
return record
|
||||||
|
|
||||||
async def get(self, task_id: str) -> TaskRecord | None:
|
async def get(self, task_id: str) -> TaskRecord | None:
|
||||||
|
|
@ -270,27 +277,45 @@ class RedisTaskStore:
|
||||||
return TaskRecord.from_dict(json.loads(raw))
|
return TaskRecord.from_dict(json.loads(raw))
|
||||||
|
|
||||||
# Lua script for atomic read-modify-write
|
# Lua script for atomic read-modify-write
|
||||||
|
# ARGV[1] = "1" to reset TTL (apply ex=ttl_seconds), "0" to keep existing TTL (KEEPTTL)
|
||||||
|
# ARGV[2] = ttl_seconds (only used when ARGV[1] == "1")
|
||||||
|
# ARGV[3] = number of merge fields
|
||||||
|
# ARGV[4..] = key/value pairs
|
||||||
_UPDATE_STATUS_SCRIPT = """
|
_UPDATE_STATUS_SCRIPT = """
|
||||||
|
local reset_ttl = ARGV[1]
|
||||||
|
local ttl = tonumber(ARGV[2])
|
||||||
|
local n = tonumber(ARGV[3])
|
||||||
local key = KEYS[1]
|
local key = KEYS[1]
|
||||||
local ttl = tonumber(ARGV[1])
|
|
||||||
local raw = redis.call('GET', key)
|
local raw = redis.call('GET', key)
|
||||||
if raw == false then
|
if raw == false then
|
||||||
return nil
|
return nil
|
||||||
end
|
end
|
||||||
local data = cjson.decode(raw)
|
local data = cjson.decode(raw)
|
||||||
local n = tonumber(ARGV[2])
|
|
||||||
for i = 1, n do
|
for i = 1, n do
|
||||||
local k = ARGV[2 + 2 * (i - 1) + 1]
|
local k = ARGV[3 + 2 * (i - 1) + 1]
|
||||||
local v = ARGV[2 + 2 * (i - 1) + 2]
|
local v = ARGV[3 + 2 * (i - 1) + 2]
|
||||||
data[k] = v
|
data[k] = v
|
||||||
end
|
end
|
||||||
local encoded = cjson.encode(data)
|
local encoded = cjson.encode(data)
|
||||||
redis.call('SET', key, encoded, 'EX', ttl)
|
if reset_ttl == "1" then
|
||||||
|
redis.call('SET', key, encoded, 'EX', ttl)
|
||||||
|
else
|
||||||
|
redis.call('SET', key, encoded, 'KEEPTTL')
|
||||||
|
end
|
||||||
return encoded
|
return encoded
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def update_status(self, task_id: str, status: TaskStatus, **kwargs) -> TaskRecord:
|
async def update_status(self, task_id: str, status: TaskStatus, reset_ttl: bool = False, **kwargs) -> TaskRecord:
|
||||||
"""Update task status and optional fields atomically via Lua script."""
|
"""Update task status and optional fields atomically via Lua script.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: Task identifier.
|
||||||
|
status: New task status.
|
||||||
|
reset_ttl: If True, reset the Redis TTL to ``ttl_seconds``. Defaults to
|
||||||
|
False so that frequent status updates on a long-running task do not
|
||||||
|
extend its lifetime indefinitely.
|
||||||
|
**kwargs: Optional fields to update (started_at, completed_at, etc.).
|
||||||
|
"""
|
||||||
redis = await self._get_redis()
|
redis = await self._get_redis()
|
||||||
key = self._key(task_id)
|
key = self._key(task_id)
|
||||||
|
|
||||||
|
|
@ -304,7 +329,7 @@ return encoded
|
||||||
merge_fields[k] = value
|
merge_fields[k] = value
|
||||||
|
|
||||||
# Flatten merge_fields into ARGV pairs
|
# Flatten merge_fields into ARGV pairs
|
||||||
args = [str(self._ttl_seconds), str(len(merge_fields))]
|
args = ["1" if reset_ttl else "0", str(self._ttl_seconds), str(len(merge_fields))]
|
||||||
for k, v in merge_fields.items():
|
for k, v in merge_fields.items():
|
||||||
args.append(k)
|
args.append(k)
|
||||||
args.append(json.dumps(v) if isinstance(v, (dict, list)) else str(v))
|
args.append(json.dumps(v) if isinstance(v, (dict, list)) else str(v))
|
||||||
|
|
@ -360,10 +385,26 @@ return encoded
|
||||||
redis = await self._get_redis()
|
redis = await self._get_redis()
|
||||||
return await self._count_keys(redis)
|
return await self._count_keys(redis)
|
||||||
|
|
||||||
|
async def health_check(self) -> bool:
|
||||||
|
"""Verify Redis connectivity by sending a PING command."""
|
||||||
|
try:
|
||||||
|
redis = await self._get_redis()
|
||||||
|
return await redis.ping()
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
# ── helpers ────────────────────────────────────────────────
|
# ── helpers ────────────────────────────────────────────────
|
||||||
|
|
||||||
async def _count_keys(self, redis) -> int:
|
async def _count_keys(self, redis) -> int:
|
||||||
"""Count task keys using SCAN (avoid KEYS on large datasets)."""
|
"""Count task keys. Uses ZCARD on the sorted set for O(1) when
|
||||||
|
available, falls back to SCAN otherwise."""
|
||||||
|
try:
|
||||||
|
count = await redis.zcard(self.ZSET_KEY)
|
||||||
|
if count > 0:
|
||||||
|
return count
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# Fallback: full SCAN
|
||||||
count = 0
|
count = 0
|
||||||
cursor = 0
|
cursor = 0
|
||||||
while True:
|
while True:
|
||||||
|
|
@ -375,8 +416,32 @@ return encoded
|
||||||
|
|
||||||
async def _evict_oldest_completed(self, redis) -> bool:
|
async def _evict_oldest_completed(self, redis) -> bool:
|
||||||
"""Find and delete the oldest completed/failed/cancelled task.
|
"""Find and delete the oldest completed/failed/cancelled task.
|
||||||
|
Uses ZRANGE on the sorted set for O(log N) when available,
|
||||||
|
falls back to full SCAN otherwise.
|
||||||
Returns True if a record was evicted, False otherwise.
|
Returns True if a record was evicted, False otherwise.
|
||||||
"""
|
"""
|
||||||
|
# Try ZSET-based eviction first
|
||||||
|
try:
|
||||||
|
member_count = await redis.zcard(self.ZSET_KEY)
|
||||||
|
if member_count > 0:
|
||||||
|
# Iterate from oldest (lowest score) to find a completed task
|
||||||
|
task_ids = await redis.zrange(self.ZSET_KEY, 0, -1)
|
||||||
|
for tid in task_ids:
|
||||||
|
raw = await redis.get(self._key(tid))
|
||||||
|
if raw is None:
|
||||||
|
# Stale ZSET entry – clean up
|
||||||
|
await redis.zrem(self.ZSET_KEY, tid)
|
||||||
|
continue
|
||||||
|
record = TaskRecord.from_dict(json.loads(raw))
|
||||||
|
if record.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED) and record.completed_at is not None:
|
||||||
|
await redis.delete(self._key(tid))
|
||||||
|
await redis.zrem(self.ZSET_KEY, tid)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Fallback: full SCAN
|
||||||
tasks: list[TaskRecord] = []
|
tasks: list[TaskRecord] = []
|
||||||
cursor = 0
|
cursor = 0
|
||||||
while True:
|
while True:
|
||||||
|
|
@ -405,6 +470,10 @@ return encoded
|
||||||
return False
|
return False
|
||||||
|
|
||||||
await redis.delete(self._key(oldest.task_id))
|
await redis.delete(self._key(oldest.task_id))
|
||||||
|
try:
|
||||||
|
await redis.zrem(self.ZSET_KEY, oldest.task_id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -418,6 +487,11 @@ def create_task_store(
|
||||||
|
|
||||||
If ``backend="redis"`` and the Redis connection cannot be established,
|
If ``backend="redis"`` and the Redis connection cannot be established,
|
||||||
falls back to :class:`InMemoryTaskStore` with a warning.
|
falls back to :class:`InMemoryTaskStore` with a warning.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This factory only validates that the ``redis`` package is importable.
|
||||||
|
Runtime connectivity should be verified via ``await store.health_check()``
|
||||||
|
during application startup.
|
||||||
"""
|
"""
|
||||||
if backend == "redis":
|
if backend == "redis":
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import Any, Callable, Coroutine
|
from typing import Any, Callable, Coroutine
|
||||||
|
|
||||||
from agentkit.skills.base import Skill, SkillConfig
|
from agentkit.skills.base import Skill, SkillConfig
|
||||||
|
|
@ -161,19 +162,20 @@ class SkillPipeline:
|
||||||
- "key.path > 0.5" — 数值大于
|
- "key.path > 0.5" — 数值大于
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if "==" in condition:
|
eq_match = re.match(r'^([\w.]+)\s*==\s*(.+)$', condition.strip())
|
||||||
path, value = condition.split("==", 1)
|
if eq_match:
|
||||||
path = path.strip()
|
path = eq_match.group(1)
|
||||||
value = value.strip().strip("'\"")
|
value = eq_match.group(2).strip().strip("'\"")
|
||||||
actual = self._resolve_path(path, current_input)
|
actual = self._resolve_path(path, current_input)
|
||||||
return str(actual) == value
|
return str(actual) == value
|
||||||
elif ">" in condition:
|
gt_match = re.match(r'^([\w.]+)\s*>\s*(.+)$', condition.strip())
|
||||||
path, value = condition.split(">", 1)
|
if gt_match:
|
||||||
path = path.strip()
|
path = gt_match.group(1)
|
||||||
value = float(value.strip())
|
value = float(gt_match.group(2).strip())
|
||||||
actual = float(self._resolve_path(path, current_input))
|
actual = float(self._resolve_path(path, current_input))
|
||||||
return actual > value
|
return actual > value
|
||||||
except Exception:
|
except (ValueError, TypeError, AttributeError, KeyError) as e:
|
||||||
|
logger.warning(f"Condition evaluation failed for '{condition}': {e}")
|
||||||
return False
|
return False
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,12 @@ class SkillMdParser:
|
||||||
def parse(file_path: str) -> tuple[dict[str, Any], dict[str, str], str]:
|
def parse(file_path: str) -> tuple[dict[str, Any], dict[str, str], str]:
|
||||||
"""解析 SKILL.md 文件
|
"""解析 SKILL.md 文件
|
||||||
|
|
||||||
|
Note: Only H1 headings (# ) are treated as section delimiters.
|
||||||
|
H2+ headings (## , ### , etc.) are treated as regular content
|
||||||
|
and merged into their parent H1 section. This is by design —
|
||||||
|
SKILL.md uses a flat section model where sub-structure within
|
||||||
|
a section is preserved as-is in the section body text.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_path: SKILL.md 文件路径
|
file_path: SKILL.md 文件路径
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -156,7 +156,6 @@ class TestEpisodicMemoryStore:
|
||||||
|
|
||||||
mock_embedder.embed.assert_called_once()
|
mock_embedder.embed.assert_called_once()
|
||||||
call_args = mock_embedder.embed.call_args[0][0]
|
call_args = mock_embedder.embed.call_args[0][0]
|
||||||
assert "key1" in call_args
|
|
||||||
assert "some value" in call_args
|
assert "some value" in call_args
|
||||||
|
|
||||||
# 验证 entry 的 embedding 被设置
|
# 验证 entry 的 embedding 被设置
|
||||||
|
|
|
||||||
|
|
@ -216,9 +216,7 @@ class TestMetricsEndpoint:
|
||||||
assert data["tasks"]["failed_tasks"] == 0
|
assert data["tasks"]["failed_tasks"] == 0
|
||||||
assert data["tasks"]["pending_tasks"] == 0
|
assert data["tasks"]["pending_tasks"] == 0
|
||||||
assert data["agents"]["total_agents"] == 0
|
assert data["agents"]["total_agents"] == 0
|
||||||
assert data["agents"]["agent_names"] == []
|
|
||||||
assert data["skills"]["total_skills"] == 0
|
assert data["skills"]["total_skills"] == 0
|
||||||
assert data["skills"]["skill_names"] == []
|
|
||||||
|
|
||||||
def test_metrics_with_registered_skill(self, client, skill_registry):
|
def test_metrics_with_registered_skill(self, client, skill_registry):
|
||||||
skill_config = SkillConfig(
|
skill_config = SkillConfig(
|
||||||
|
|
@ -234,7 +232,6 @@ class TestMetricsEndpoint:
|
||||||
response = client.get("/api/v1/metrics")
|
response = client.get("/api/v1/metrics")
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["skills"]["total_skills"] == 1
|
assert data["skills"]["total_skills"] == 1
|
||||||
assert "metrics_skill" in data["skills"]["skill_names"]
|
|
||||||
|
|
||||||
def test_metrics_version(self, client):
|
def test_metrics_version(self, client):
|
||||||
response = client.get("/api/v1/metrics")
|
response = client.get("/api/v1/metrics")
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ class FakeRedis:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._data: dict[str, str] = {}
|
self._data: dict[str, str] = {}
|
||||||
|
self._zsets: dict[str, dict[str, float]] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_url(cls, url, **kwargs):
|
def from_url(cls, url, **kwargs):
|
||||||
|
|
@ -55,19 +56,54 @@ class FakeRedis:
|
||||||
async def close(self):
|
async def close(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def ping(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# ── Sorted-set operations ──────────────────────────────
|
||||||
|
|
||||||
|
async def zadd(self, name, mapping):
|
||||||
|
zs = self._zsets.setdefault(name, {})
|
||||||
|
added = 0
|
||||||
|
for member, score in mapping.items():
|
||||||
|
if member not in zs:
|
||||||
|
added += 1
|
||||||
|
zs[member] = score
|
||||||
|
return added
|
||||||
|
|
||||||
|
async def zcard(self, name):
|
||||||
|
return len(self._zsets.get(name, {}))
|
||||||
|
|
||||||
|
async def zrange(self, name, start, end):
|
||||||
|
zs = self._zsets.get(name, {})
|
||||||
|
# Sort by score, then by member for deterministic order
|
||||||
|
sorted_members = sorted(zs.keys(), key=lambda m: (zs[m], m))
|
||||||
|
if end == -1:
|
||||||
|
return sorted_members[start:]
|
||||||
|
return sorted_members[start : end + 1]
|
||||||
|
|
||||||
|
async def zrem(self, name, *members):
|
||||||
|
zs = self._zsets.get(name, {})
|
||||||
|
removed = 0
|
||||||
|
for m in members:
|
||||||
|
if m in zs:
|
||||||
|
del zs[m]
|
||||||
|
removed += 1
|
||||||
|
return removed
|
||||||
|
|
||||||
async def eval(self, script, numkeys, *args):
|
async def eval(self, script, numkeys, *args):
|
||||||
"""Simulate Redis EVAL for the update_status Lua script."""
|
"""Simulate Redis EVAL for the update_status Lua script."""
|
||||||
# This implements the same logic as _UPDATE_STATUS_SCRIPT in RedisTaskStore
|
# This implements the same logic as _UPDATE_STATUS_SCRIPT in RedisTaskStore
|
||||||
key = args[0]
|
key = args[0]
|
||||||
ttl = int(args[1])
|
reset_ttl = args[1]
|
||||||
n = int(args[2])
|
ttl = int(args[2])
|
||||||
|
n = int(args[3])
|
||||||
raw = self._data.get(key)
|
raw = self._data.get(key)
|
||||||
if raw is None:
|
if raw is None:
|
||||||
return None
|
return None
|
||||||
data = json.loads(raw)
|
data = json.loads(raw)
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
k = args[3 + 2 * i]
|
k = args[4 + 2 * i]
|
||||||
v = args[4 + 2 * i]
|
v = args[5 + 2 * i]
|
||||||
# Try to parse JSON values (dicts/lists), otherwise keep as string
|
# Try to parse JSON values (dicts/lists), otherwise keep as string
|
||||||
try:
|
try:
|
||||||
data[k] = json.loads(v)
|
data[k] = json.loads(v)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue