feat(skills): SkillHarness 前置条件 + 风险守卫学习增强
- cli/skill.py: skill learn 子命令增强 - evolution/risk_guard_learner.py: 风险守卫学习改进 - memory/models.py: 记忆模型扩展 - skills/base.py + loader.py: SkillHarness 前置条件支持 - 对应测试更新
This commit is contained in:
parent
574db8458f
commit
20a4c55d5b
|
|
@ -2,12 +2,16 @@
|
|||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import typer
|
||||
from rich import print as rprint
|
||||
from rich.table import Table
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agentkit.evolution.experience_store import ExperienceStore
|
||||
from agentkit.evolution.risk_guard_learner import RiskGuardLearner
|
||||
|
||||
skill_app = typer.Typer(name="skill", help="Skill management commands", no_args_is_help=True)
|
||||
|
||||
|
||||
|
|
@ -19,6 +23,7 @@ def list_skills(
|
|||
if server_url:
|
||||
# Remote mode: call API
|
||||
import httpx
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=10.0) as client:
|
||||
response = client.get(f"{server_url}/api/v1/skills")
|
||||
|
|
@ -35,7 +40,9 @@ def list_skills(
|
|||
|
||||
registry = SkillRegistry()
|
||||
# Load skills from the default configs/skills/ directory if it exists
|
||||
default_skills_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "configs", "skills")
|
||||
default_skills_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "configs", "skills"
|
||||
)
|
||||
if os.path.isdir(default_skills_dir):
|
||||
loader = SkillLoader(registry, ToolRegistry())
|
||||
loader.load_from_directory(default_skills_dir)
|
||||
|
|
@ -139,6 +146,7 @@ def skill_info(
|
|||
"""Show skill details"""
|
||||
if server_url:
|
||||
import httpx
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=10.0) as client:
|
||||
response = client.get(f"{server_url}/api/v1/skills/{name}")
|
||||
|
|
@ -149,6 +157,7 @@ def skill_info(
|
|||
raise typer.Exit(code=1)
|
||||
else:
|
||||
from agentkit.skills.registry import SkillRegistry
|
||||
|
||||
registry = SkillRegistry()
|
||||
try:
|
||||
skill = registry.get(name)
|
||||
|
|
@ -189,63 +198,104 @@ def learn_risk_guards(
|
|||
|
||||
learner = _build_risk_guard_learner()
|
||||
if learner is None:
|
||||
rprint("[red]Error: 无法构建 RiskGuardLearner——需要 PostgreSQL 与 LLM 配置。[/red]")
|
||||
rprint("[dim]请确保 agentkit.yaml 中已配置数据库与 LLM provider。[/dim]")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
suggestions = asyncio.run(learner.learn(skill_name=skill, top_k=top_k))
|
||||
_render_risk_guard_suggestions(suggestions)
|
||||
|
||||
|
||||
def _build_risk_guard_learner():
|
||||
"""从本地配置构建 RiskGuardLearner,失败返回 None"""
|
||||
try:
|
||||
def _build_risk_guard_learner() -> "RiskGuardLearner | None":
|
||||
"""从本地配置构建 RiskGuardLearner,失败返回 None 并打印真实错误"""
|
||||
from agentkit.cli.chat import _build_gateway
|
||||
from agentkit.evolution.risk_guard_learner import RiskGuardLearner
|
||||
from agentkit.server.config import find_config_path, load_config_with_dotenv
|
||||
|
||||
config_path = find_config_path()
|
||||
server_config = load_config_with_dotenv(config_path)
|
||||
gateway = _build_gateway(server_config)
|
||||
if config_path is None:
|
||||
rprint("[red]Error: 未找到 agentkit.yaml 配置文件。[/red]")
|
||||
rprint("[dim]请运行 `agentkit init` 生成配置,或使用 --config 指定路径。[/dim]")
|
||||
return None
|
||||
|
||||
try:
|
||||
server_config = load_config_with_dotenv(config_path)
|
||||
except Exception as e:
|
||||
rprint(f"[red]Error: 加载配置失败: {e}[/red]")
|
||||
return None
|
||||
|
||||
try:
|
||||
gateway = _build_gateway(server_config)
|
||||
except Exception as e:
|
||||
rprint(f"[red]Error: 构建 LLM Gateway 失败: {e}[/red]")
|
||||
rprint("[dim]请检查 agentkit.yaml 中的 llm 配置(providers + api_key)。[/dim]")
|
||||
return None
|
||||
|
||||
# ExperienceStore 需要 PostgreSQL + ORM model;尝试从 server app 获取
|
||||
experience_store = _try_get_experience_store(server_config)
|
||||
if experience_store is None:
|
||||
rprint("[red]Error: 无法连接 PostgreSQL ExperienceStore。[/red]")
|
||||
rprint(
|
||||
"[dim]请在 agentkit.yaml 的 evolution.database_url 或 "
|
||||
"memory.episodic.database_url 中配置 PostgreSQL 连接串,"
|
||||
"或设置 DATABASE_URL 环境变量。[/dim]"
|
||||
)
|
||||
return None
|
||||
|
||||
return RiskGuardLearner(experience_store, gateway)
|
||||
|
||||
|
||||
def _try_get_experience_store(server_config) -> "ExperienceStore | None":
|
||||
"""尝试从 server_config 构建 PostgreSQL ExperienceStore,不可用时返回 None
|
||||
|
||||
查找 database_url 的优先级:
|
||||
1. server_config.evolution.database_url
|
||||
2. server_config.memory.episodic.database_url
|
||||
3. DATABASE_URL 环境变量
|
||||
"""
|
||||
import os
|
||||
|
||||
database_url: str | None = None
|
||||
|
||||
# 1. evolution config
|
||||
evo_conf = getattr(server_config, "evolution", None) or {}
|
||||
database_url = evo_conf.get("database_url") if isinstance(evo_conf, dict) else None
|
||||
|
||||
# 2. episodic memory config
|
||||
if not database_url:
|
||||
epi_conf = (getattr(server_config, "memory", None) or {}).get("episodic", {})
|
||||
database_url = epi_conf.get("database_url") if isinstance(epi_conf, dict) else None
|
||||
|
||||
# 3. env var
|
||||
if not database_url:
|
||||
database_url = os.environ.get("DATABASE_URL")
|
||||
|
||||
if not database_url:
|
||||
return None
|
||||
|
||||
try:
|
||||
from agentkit.evolution.experience_store import ExperienceStore
|
||||
from agentkit.memory.models import ExperienceModel, create_experience_session_factory
|
||||
|
||||
session_factory = create_experience_session_factory(database_url)
|
||||
return ExperienceStore(
|
||||
session_factory=session_factory,
|
||||
experience_model=ExperienceModel,
|
||||
)
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).warning(f"Failed to build RiskGuardLearner: {e}")
|
||||
logging.getLogger(__name__).warning(f"Failed to create PostgreSQL ExperienceStore: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _try_get_experience_store(_server_config):
|
||||
"""尝试构建 ExperienceStore,PostgreSQL 不可用时返回 None
|
||||
|
||||
ponytail: 当前 codebase 未提供 PostgreSQL ExperienceStore 的 CLI 构建路径
|
||||
(无 ORM model + session factory 的 CLI helper)。回退到 InMemoryExperienceStore,
|
||||
它在无数据时返回空列表——命令会提示"未学习到建议"。
|
||||
升级路径:未来接入 PostgreSQL 后替换为真实 store。
|
||||
"""
|
||||
try:
|
||||
from agentkit.evolution.experience_store import InMemoryExperienceStore
|
||||
|
||||
return InMemoryExperienceStore()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _render_risk_guard_suggestions(suggestions) -> None:
|
||||
def _render_risk_guard_suggestions(suggestions: list) -> None:
|
||||
"""渲染 RiskGuardSuggestion 列表到终端"""
|
||||
rprint(
|
||||
"[bold yellow]⚠ 以下为自动生成的风险守卫建议,"
|
||||
"必须人工审查后手动编辑 YAML 应用,不会自动生效。[/bold yellow]\n"
|
||||
)
|
||||
if not suggestions:
|
||||
rprint("[dim]未从失败轨迹中学习到风险守卫建议[/dim]")
|
||||
return
|
||||
|
||||
rprint(
|
||||
"[bold yellow]⚠ 以下为自动生成的风险守卫建议,"
|
||||
"必须人工审查后手动编辑 YAML 应用,不会自动生效。[/bold yellow]\n"
|
||||
)
|
||||
table = Table(title="Risk Guard Suggestions (待人工审查)")
|
||||
table.add_column("Skill", style="cyan")
|
||||
table.add_column("Precondition")
|
||||
|
|
|
|||
|
|
@ -93,7 +93,11 @@ class RiskGuardLearner:
|
|||
source_ids = [e.experience_id for e in failures if e.experience_id]
|
||||
|
||||
# 2. 构建 LLM prompt
|
||||
try:
|
||||
prompt = self._build_prompt(failures)
|
||||
except Exception as e:
|
||||
logger.warning(f"RiskGuardLearner: failed to build prompt: {e}")
|
||||
return []
|
||||
|
||||
# 3. 调用 LLM
|
||||
system_message = (
|
||||
|
|
@ -118,7 +122,11 @@ class RiskGuardLearner:
|
|||
return []
|
||||
|
||||
# 4. 解析响应
|
||||
return self._parse_response(response.content, failures, source_ids)
|
||||
try:
|
||||
return self._parse_response(response.content, source_ids)
|
||||
except Exception as e:
|
||||
logger.warning(f"RiskGuardLearner: failed to parse response: {e}")
|
||||
return []
|
||||
|
||||
def _build_prompt(self, failures: list[TaskExperience]) -> str:
|
||||
"""构建 LLM 提示词"""
|
||||
|
|
@ -132,9 +140,15 @@ class RiskGuardLearner:
|
|||
lines.append(f"- skill (task_type): {self._sanitize(exp.task_type)}")
|
||||
lines.append(f"- goal: {self._sanitize(exp.goal)}")
|
||||
lines.append(f"- steps_summary: {self._sanitize(exp.steps_summary)}")
|
||||
reasons = "; ".join(exp.failure_reasons) if exp.failure_reasons else "(none)"
|
||||
reasons = (
|
||||
"; ".join(str(r) for r in exp.failure_reasons) if exp.failure_reasons else "(none)"
|
||||
)
|
||||
lines.append(f"- failure_reasons: {self._sanitize(reasons)}")
|
||||
tips = "; ".join(exp.optimization_tips) if exp.optimization_tips else "(none)"
|
||||
tips = (
|
||||
"; ".join(str(t) for t in exp.optimization_tips)
|
||||
if exp.optimization_tips
|
||||
else "(none)"
|
||||
)
|
||||
lines.append(f"- optimization_tips: {self._sanitize(tips)}")
|
||||
lines.append("")
|
||||
|
||||
|
|
@ -149,7 +163,6 @@ class RiskGuardLearner:
|
|||
def _parse_response(
|
||||
self,
|
||||
content: str,
|
||||
failures: list[TaskExperience],
|
||||
source_ids: list[str],
|
||||
) -> list[RiskGuardSuggestion]:
|
||||
"""解析 LLM 响应为 RiskGuardSuggestion 列表"""
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import Column, DateTime, Float, String, Text, create_engine
|
||||
from sqlalchemy import Column, DateTime, Float, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
|
||||
|
|
@ -27,11 +27,11 @@ class EpisodeModel(Base):
|
|||
outcome = Column(String, default="success") # "success", "failure", "partial"
|
||||
quality_score = Column(Float, default=0.5)
|
||||
reflection = Column(Text, default="")
|
||||
embedding = Column(Text, nullable=True) # JSON-encoded float list; pgvector if extension available
|
||||
embedding = Column(
|
||||
Text, nullable=True
|
||||
) # JSON-encoded float list; pgvector if extension available
|
||||
metadata_ = Column("metadata", JSONB, nullable=True) # Additional metadata
|
||||
created_at = Column(
|
||||
DateTime, default=lambda: datetime.now(timezone.utc), index=True
|
||||
)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), index=True)
|
||||
|
||||
|
||||
def create_episodic_session_factory(database_url: str):
|
||||
|
|
@ -51,6 +51,45 @@ def create_episodic_session_factory(database_url: str):
|
|||
return async_session
|
||||
|
||||
|
||||
class ExperienceModel(Base):
|
||||
"""Task experience ORM model for RiskGuardLearner / ExperienceStore.
|
||||
|
||||
Stores task execution outcomes (success/failure/partial) with optional
|
||||
pgvector embeddings for semantic similarity search.
|
||||
"""
|
||||
|
||||
__tablename__ = "task_experiences"
|
||||
|
||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
task_type = Column(String, index=True)
|
||||
goal = Column(Text, default="")
|
||||
steps_summary = Column(Text, default="")
|
||||
outcome = Column(String, default="success") # "success", "failure", "partial"
|
||||
duration_seconds = Column(Float, default=0.0)
|
||||
success_rate = Column(Float, default=1.0)
|
||||
failure_reasons = Column(JSONB, default=list) # list[str]
|
||||
optimization_tips = Column(JSONB, default=list) # list[str]
|
||||
embedding = Column(Text, nullable=True) # JSON-encoded float list
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), index=True)
|
||||
|
||||
|
||||
def create_experience_session_factory(database_url: str):
|
||||
"""Create an async session factory for task experiences.
|
||||
|
||||
Args:
|
||||
database_url: PostgreSQL connection string,
|
||||
e.g. "postgresql+asyncpg://user:pass@localhost/dbname"
|
||||
|
||||
Returns:
|
||||
async_sessionmaker bound to the engine.
|
||||
"""
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
|
||||
engine = create_async_engine(database_url, echo=False)
|
||||
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
return async_session
|
||||
|
||||
|
||||
async def ensure_episodic_table(database_url: str) -> None:
|
||||
"""Create the episodic_memories table if it does not exist.
|
||||
|
||||
|
|
|
|||
|
|
@ -126,6 +126,12 @@ class SkillConfig(AgentConfig):
|
|||
# v6: ReWOO fallback 策略(None 时 ReWOOEngine 用默认值)
|
||||
self.fallback_strategies = fallback_strategies
|
||||
# v7: 激活前置条件(软检查,由 build_skill_system_prompt 注入)+ 来源标记
|
||||
if preconditions is not None and not isinstance(preconditions, list):
|
||||
raise ConfigValidationError(
|
||||
agent_name=name,
|
||||
key="preconditions",
|
||||
reason=f"preconditions must be list[str] or None, got {type(preconditions).__name__}",
|
||||
)
|
||||
self.preconditions = preconditions
|
||||
self.provenance = provenance
|
||||
self._validate_v2()
|
||||
|
|
@ -152,10 +158,7 @@ class SkillConfig(AgentConfig):
|
|||
raise ConfigValidationError(
|
||||
agent_name=self.name,
|
||||
key="fallback_strategies",
|
||||
reason=(
|
||||
f"Invalid fallback_strategies {invalid}, "
|
||||
f"must be subset of {valid}"
|
||||
),
|
||||
reason=(f"Invalid fallback_strategies {invalid}, must be subset of {valid}"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -17,9 +17,14 @@ logger = logging.getLogger(__name__)
|
|||
SKILL_ENTRY_POINT_GROUP = "agentkit.skills"
|
||||
|
||||
# v7: 危险能力标签——entry_points 加载第三方 Skill 时命中则 logger.warning
|
||||
# 同时检查 capabilities 声明和 tools 绑定,防止恶意 skill 隐瞒能力声明
|
||||
_DANGEROUS_CAPABILITIES = frozenset(
|
||||
{"terminal", "code_execution", "file_write", "shell", "system_admin"}
|
||||
)
|
||||
# tools 列表中可能出现的危险工具名(与 _DANGEROUS_CAPABILITIES 部分重叠)
|
||||
_DANGEROUS_TOOL_NAMES = frozenset(
|
||||
{"shell", "terminal", "code_execution", "file_write", "file_system", "subprocess"}
|
||||
)
|
||||
|
||||
|
||||
class SkillLoader:
|
||||
|
|
@ -95,13 +100,18 @@ class SkillLoader:
|
|||
|
||||
frontmatter, sections, body = SkillMdParser.parse(path)
|
||||
config = SkillMdParser.to_skill_config(
|
||||
frontmatter, sections, path, disclosure_level=disclosure_level,
|
||||
frontmatter,
|
||||
sections,
|
||||
path,
|
||||
disclosure_level=disclosure_level,
|
||||
)
|
||||
config.provenance = f"skill_md:{path}"
|
||||
tools = self._bind_tools(config)
|
||||
skill = Skill(config, tools=tools)
|
||||
self._skill_registry.register(skill)
|
||||
logger.info(f"Loaded skill '{skill.name}' from SKILL.md '{path}' (level={disclosure_level})")
|
||||
logger.info(
|
||||
f"Loaded skill '{skill.name}' from SKILL.md '{path}' (level={disclosure_level})"
|
||||
)
|
||||
return skill
|
||||
|
||||
def load_from_entry_points(self, group: str | None = None) -> list[Skill]:
|
||||
|
|
@ -128,9 +138,11 @@ class SkillLoader:
|
|||
# Python 3.12+ 使用 importlib.metadata
|
||||
if sys.version_info >= (3, 12):
|
||||
from importlib.metadata import entry_points as _entry_points
|
||||
|
||||
eps = _entry_points(group=group_name)
|
||||
else:
|
||||
from importlib.metadata import entry_points as _entry_points
|
||||
|
||||
eps = _entry_points().get(group_name, [])
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to discover entry_points for group '{group_name}': {e}")
|
||||
|
|
@ -159,28 +171,29 @@ class SkillLoader:
|
|||
)
|
||||
continue
|
||||
|
||||
# v7: 记录 provenance + 危险能力告警
|
||||
# v7: 记录 provenance + 危险能力告警(同时检查 capabilities 和 tools)
|
||||
skill.config.provenance = f"entry_point:{ep.name}"
|
||||
dangerous = [
|
||||
dangerous_caps = [
|
||||
cap.tag
|
||||
for cap in (skill.config.capabilities or [])
|
||||
if cap.tag in _DANGEROUS_CAPABILITIES
|
||||
]
|
||||
dangerous_tools = [
|
||||
t for t in (skill.config.tools or []) if t in _DANGEROUS_TOOL_NAMES
|
||||
]
|
||||
dangerous = dangerous_caps + dangerous_tools
|
||||
if dangerous:
|
||||
logger.warning(
|
||||
f"Skill '{skill.name}' from entry_point '{ep.name}' "
|
||||
f"declares dangerous capabilities: {dangerous}"
|
||||
f"declares dangerous capabilities/tools: {dangerous}"
|
||||
)
|
||||
self._skill_registry.register(skill)
|
||||
skills.append(skill)
|
||||
logger.info(
|
||||
f"Loaded skill '{skill.name}' v{skill.version} "
|
||||
f"from entry_point '{ep.name}'"
|
||||
f"Loaded skill '{skill.name}' v{skill.version} from entry_point '{ep.name}'"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to load skill from entry_point '{ep.name}': {e}"
|
||||
)
|
||||
logger.warning(f"Failed to load skill from entry_point '{ep.name}': {e}")
|
||||
|
||||
return skills
|
||||
|
||||
|
|
@ -196,7 +209,5 @@ class SkillLoader:
|
|||
tools.append(tool)
|
||||
logger.info(f"Bound tool '{tool_name}' to skill '{config.name}'")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to bind tool '{tool_name}' to skill '{config.name}': {e}"
|
||||
)
|
||||
logger.warning(f"Failed to bind tool '{tool_name}' to skill '{config.name}': {e}")
|
||||
return tools
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from agentkit.evolution.risk_guard_learner import RiskGuardSuggestion
|
||||
|
|
@ -10,7 +9,9 @@ from agentkit.evolution.risk_guard_learner import RiskGuardSuggestion
|
|||
runner = CliRunner()
|
||||
|
||||
|
||||
def _make_suggestion(skill_name="code_reviewer", precondition="需要代码输入", confidence=0.8, reason="避免空输入"):
|
||||
def _make_suggestion(
|
||||
skill_name="code_reviewer", precondition="需要代码输入", confidence=0.8, reason="避免空输入"
|
||||
):
|
||||
return RiskGuardSuggestion(
|
||||
skill_name=skill_name,
|
||||
precondition=precondition,
|
||||
|
|
@ -26,7 +27,9 @@ class TestLearnRiskGuardsCommand:
|
|||
from agentkit.cli.main import app
|
||||
|
||||
mock_learner = MagicMock()
|
||||
mock_learner.learn = AsyncMock(return_value=[_make_suggestion(), _make_suggestion("monitor", "需要网络", 0.6)])
|
||||
mock_learner.learn = AsyncMock(
|
||||
return_value=[_make_suggestion(), _make_suggestion("monitor", "需要网络", 0.6)]
|
||||
)
|
||||
with patch("agentkit.cli.skill._build_risk_guard_learner", return_value=mock_learner):
|
||||
result = runner.invoke(app, ["skill", "learn-risk-guards"])
|
||||
assert result.exit_code == 0
|
||||
|
|
@ -47,13 +50,12 @@ class TestLearnRiskGuardsCommand:
|
|||
assert "未从失败轨迹中学习到风险守卫建议" in result.stdout
|
||||
|
||||
def test_learner_build_failure_exits_nonzero(self):
|
||||
"""_build_risk_guard_learner 返回 None → 错误信息 + 非零退出"""
|
||||
"""_build_risk_guard_learner 返回 None → 非零退出码"""
|
||||
from agentkit.cli.main import app
|
||||
|
||||
with patch("agentkit.cli.skill._build_risk_guard_learner", return_value=None):
|
||||
result = runner.invoke(app, ["skill", "learn-risk-guards"])
|
||||
assert result.exit_code == 1
|
||||
assert "无法构建" in result.stdout or "Error" in result.stdout
|
||||
|
||||
def test_skill_option_passed_to_learn(self):
|
||||
"""--skill 参数透传给 learn(skill_name=...)"""
|
||||
|
|
@ -80,5 +82,75 @@ class TestLearnRiskGuardsCommand:
|
|||
"""--server-url 远程模式暂不支持"""
|
||||
from agentkit.cli.main import app
|
||||
|
||||
result = runner.invoke(app, ["skill", "learn-risk-guards", "--server-url", "http://localhost:8001"])
|
||||
result = runner.invoke(
|
||||
app, ["skill", "learn-risk-guards", "--server-url", "http://localhost:8001"]
|
||||
)
|
||||
assert result.exit_code == 1
|
||||
|
||||
|
||||
class TestBuildRiskGuardLearnerErrorPaths:
|
||||
"""测试 _build_risk_guard_learner 的真实错误路径(不 mock 函数本身)"""
|
||||
|
||||
def test_no_config_file_returns_none(self):
|
||||
"""find_config_path 返回 None → 打印错误 + 返回 None"""
|
||||
from agentkit.cli import skill as skill_module
|
||||
|
||||
with patch("agentkit.server.config.find_config_path", return_value=None):
|
||||
result = skill_module._build_risk_guard_learner()
|
||||
assert result is None
|
||||
|
||||
def test_no_database_url_returns_none(self):
|
||||
"""server_config 无 database_url → 返回 None"""
|
||||
from agentkit.cli import skill as skill_module
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.evolution = {}
|
||||
mock_config.memory = {}
|
||||
with (
|
||||
patch("agentkit.server.config.find_config_path", return_value="/fake/path.yaml"),
|
||||
patch("agentkit.server.config.load_config_with_dotenv", return_value=mock_config),
|
||||
patch("agentkit.cli.chat._build_gateway", return_value=MagicMock()),
|
||||
patch.dict("os.environ", {}, clear=False),
|
||||
):
|
||||
# Ensure DATABASE_URL is not set
|
||||
import os
|
||||
|
||||
old = os.environ.pop("DATABASE_URL", None)
|
||||
try:
|
||||
result = skill_module._build_risk_guard_learner()
|
||||
finally:
|
||||
if old is not None:
|
||||
os.environ["DATABASE_URL"] = old
|
||||
assert result is None
|
||||
|
||||
def test_try_get_experience_store_no_database_url(self):
|
||||
"""_try_get_experience_store 无 database_url → 返回 None"""
|
||||
from agentkit.cli import skill as skill_module
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.evolution = {}
|
||||
mock_config.memory = {"episodic": {}}
|
||||
with patch.dict("os.environ", {}, clear=False):
|
||||
import os
|
||||
|
||||
old = os.environ.pop("DATABASE_URL", None)
|
||||
try:
|
||||
result = skill_module._try_get_experience_store(mock_config)
|
||||
finally:
|
||||
if old is not None:
|
||||
os.environ["DATABASE_URL"] = old
|
||||
assert result is None
|
||||
|
||||
def test_try_get_experience_store_with_database_url(self):
|
||||
"""_try_get_experience_store 有 database_url → 构建 ExperienceStore"""
|
||||
from agentkit.cli import skill as skill_module
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.evolution = {"database_url": "postgresql+asyncpg://localhost/test"}
|
||||
mock_config.memory = {}
|
||||
with patch(
|
||||
"agentkit.memory.models.create_experience_session_factory",
|
||||
return_value=MagicMock(),
|
||||
):
|
||||
result = skill_module._try_get_experience_store(mock_config)
|
||||
assert result is not None
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from unittest.mock import AsyncMock
|
|||
import pytest
|
||||
|
||||
from agentkit.evolution.experience_schema import TaskExperience
|
||||
from agentkit.evolution.risk_guard_learner import RiskGuardLearner, RiskGuardSuggestion
|
||||
from agentkit.evolution.risk_guard_learner import RiskGuardLearner
|
||||
|
||||
|
||||
def _make_experience(
|
||||
|
|
@ -45,7 +45,8 @@ class TestRiskGuardLearner:
|
|||
]
|
||||
llm = AsyncMock()
|
||||
llm.chat.return_value = _make_llm_response(
|
||||
json.dumps([
|
||||
json.dumps(
|
||||
[
|
||||
{
|
||||
"skill_name": "code_reviewer",
|
||||
"precondition": "输入必须包含待审查的代码片段",
|
||||
|
|
@ -58,7 +59,8 @@ class TestRiskGuardLearner:
|
|||
"reason": "过短输入无法有效审查",
|
||||
"confidence": 0.6,
|
||||
},
|
||||
])
|
||||
]
|
||||
)
|
||||
)
|
||||
learner = RiskGuardLearner(store, llm)
|
||||
suggestions = await learner.learn()
|
||||
|
|
@ -77,9 +79,7 @@ class TestRiskGuardLearner:
|
|||
llm.chat.return_value = _make_llm_response("[]")
|
||||
learner = RiskGuardLearner(store, llm)
|
||||
await learner.learn(skill_name="code_reviewer")
|
||||
store.search.assert_called_once_with(
|
||||
query="failure", top_k=20, task_type="code_reviewer"
|
||||
)
|
||||
store.search.assert_called_once_with(query="failure", top_k=20, task_type="code_reviewer")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_learn_llm_exception_returns_empty(self):
|
||||
|
|
@ -119,21 +119,20 @@ class TestRiskGuardLearner:
|
|||
"""只保留 outcome == 'failure' 的轨迹"""
|
||||
store = AsyncMock()
|
||||
store.search.return_value = [
|
||||
_make_experience("e1", outcome="failure"),
|
||||
_make_experience("e2", outcome="success"),
|
||||
_make_experience("e3", outcome="partial"),
|
||||
_make_experience("e1", goal="failure-goal", outcome="failure"),
|
||||
_make_experience("e2", goal="success-goal", outcome="success"),
|
||||
_make_experience("e3", goal="partial-goal", outcome="partial"),
|
||||
]
|
||||
llm = AsyncMock()
|
||||
llm.chat.return_value = _make_llm_response("[]")
|
||||
learner = RiskGuardLearner(store, llm)
|
||||
await learner.learn()
|
||||
# 只有 e1 是 failure,source_experience_ids 应只含 e1
|
||||
# 通过检查 prompt 中是否只含 e1 来验证
|
||||
# 只有 e1 是 failure,prompt 中应含 failure-goal,不含 success/partial 的 goal
|
||||
call_args = llm.chat.call_args
|
||||
prompt = call_args.kwargs["messages"][1]["content"]
|
||||
assert "e1" in prompt or "review code" in prompt
|
||||
# success/partial 的 goal 不应出现(它们 goal 都是 "review code",改用 task_type 区分)
|
||||
# 更精确:检查 prompt 中 failure 轨迹数
|
||||
assert "failure-goal" in prompt
|
||||
assert "success-goal" not in prompt
|
||||
assert "partial-goal" not in prompt
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_confidence_clamped(self):
|
||||
|
|
@ -142,11 +141,13 @@ class TestRiskGuardLearner:
|
|||
store.search.return_value = [_make_experience("e1")]
|
||||
llm = AsyncMock()
|
||||
llm.chat.return_value = _make_llm_response(
|
||||
json.dumps([
|
||||
json.dumps(
|
||||
[
|
||||
{"skill_name": "s", "precondition": "p1", "reason": "r", "confidence": 1.5},
|
||||
{"skill_name": "s", "precondition": "p2", "reason": "r", "confidence": -0.3},
|
||||
{"skill_name": "s", "precondition": "p3", "reason": "r", "confidence": 0.5},
|
||||
])
|
||||
]
|
||||
)
|
||||
)
|
||||
learner = RiskGuardLearner(store, llm)
|
||||
suggestions = await learner.learn()
|
||||
|
|
@ -176,11 +177,13 @@ class TestRiskGuardLearner:
|
|||
store.search.return_value = [_make_experience("e1")]
|
||||
llm = AsyncMock()
|
||||
llm.chat.return_value = _make_llm_response(
|
||||
json.dumps([
|
||||
json.dumps(
|
||||
[
|
||||
{"skill_name": "s", "precondition": "", "reason": "r", "confidence": 0.5},
|
||||
{"skill_name": "", "precondition": "p", "reason": "r", "confidence": 0.5},
|
||||
{"skill_name": "s", "precondition": "valid", "reason": "r", "confidence": 0.5},
|
||||
])
|
||||
]
|
||||
)
|
||||
)
|
||||
learner = RiskGuardLearner(store, llm)
|
||||
suggestions = await learner.learn()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,8 @@
|
|||
"""SkillConfig v7 preconditions + provenance 字段单元测试"""
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.exceptions import ConfigValidationError
|
||||
from agentkit.skills.base import SkillConfig
|
||||
|
||||
# llm_generate 模式要求 prompt,所有构造提供最小 prompt
|
||||
|
|
@ -72,3 +75,25 @@ class TestSkillConfigPreconditions:
|
|||
out = config.to_dict()
|
||||
assert out["preconditions"] == ["条件1", "条件2"]
|
||||
assert out["provenance"] == "skill_md:foo.md"
|
||||
|
||||
def test_preconditions_string_type_rejected(self):
|
||||
"""preconditions 传字符串应抛 ConfigValidationError(防止逐字符迭代)"""
|
||||
with pytest.raises(ConfigValidationError, match="preconditions"):
|
||||
SkillConfig(
|
||||
name="x",
|
||||
agent_type="y",
|
||||
task_mode="llm_generate",
|
||||
prompt=_PROMPT,
|
||||
preconditions="必须提供代码", # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
def test_preconditions_dict_type_rejected(self):
|
||||
"""preconditions 传 dict 应抛 ConfigValidationError"""
|
||||
with pytest.raises(ConfigValidationError, match="preconditions"):
|
||||
SkillConfig(
|
||||
name="x",
|
||||
agent_type="y",
|
||||
task_mode="llm_generate",
|
||||
prompt=_PROMPT,
|
||||
preconditions={"key": "val"}, # type: ignore[arg-type]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import os
|
|||
import tempfile
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from agentkit.skills.base import Skill, SkillConfig
|
||||
|
|
@ -30,13 +29,14 @@ class _FakeEntryPoint:
|
|||
return self._skill
|
||||
|
||||
|
||||
def _make_skill(name: str = "ep_skill", capabilities=None) -> Skill:
|
||||
def _make_skill(name: str = "ep_skill", capabilities=None, tools=None) -> Skill:
|
||||
config = SkillConfig(
|
||||
name=name,
|
||||
agent_type="test",
|
||||
task_mode="llm_generate",
|
||||
prompt={"identity": "test"},
|
||||
capabilities=capabilities,
|
||||
tools=tools,
|
||||
)
|
||||
return Skill(config)
|
||||
|
||||
|
|
@ -46,19 +46,23 @@ class TestSkillLoaderProvenance:
|
|||
registry = SkillRegistry()
|
||||
loader = SkillLoader(skill_registry=registry)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = _write_yaml(tmpdir, "s.yaml", {
|
||||
path = _write_yaml(
|
||||
tmpdir,
|
||||
"s.yaml",
|
||||
{
|
||||
"name": "s",
|
||||
"agent_type": "t",
|
||||
"task_mode": "llm_generate",
|
||||
"prompt": {"identity": "x"},
|
||||
})
|
||||
},
|
||||
)
|
||||
skill = loader.load_from_file(path)
|
||||
assert skill.config.provenance == f"yaml:{path}"
|
||||
|
||||
def test_load_from_skill_md_sets_provenance(self):
|
||||
registry = SkillRegistry()
|
||||
loader = SkillLoader(skill_registry=registry)
|
||||
skill_md = '''\
|
||||
skill_md = """\
|
||||
---
|
||||
name: md-skill
|
||||
description: "test"
|
||||
|
|
@ -77,7 +81,7 @@ execution_mode: react
|
|||
|
||||
# Verification
|
||||
- ok
|
||||
'''
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, "SKILL.md")
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
|
|
@ -113,7 +117,28 @@ execution_mode: react
|
|||
assert skills[0].config.provenance == "entry_point:dangerous_ep"
|
||||
# warning 包含 skill 名与危险能力
|
||||
warnings = [r for r in caplog.records if r.levelno == logging.WARNING]
|
||||
assert any("dangerous_skill" in r.getMessage() and "shell" in r.getMessage() for r in warnings)
|
||||
assert any(
|
||||
"dangerous_skill" in r.getMessage() and "shell" in r.getMessage() for r in warnings
|
||||
)
|
||||
|
||||
def test_entry_points_dangerous_tools_warning(self, caplog):
|
||||
"""entry_points 加载绑定 shell 工具但未声明 capabilities 的 Skill 时触发 warning"""
|
||||
import logging
|
||||
|
||||
registry = SkillRegistry()
|
||||
loader = SkillLoader(skill_registry=registry)
|
||||
# 有危险 tools 但无 capabilities 声明——旧逻辑会漏检
|
||||
dangerous_skill = _make_skill("stealthy_skill", capabilities=None, tools=["shell"])
|
||||
fake_ep = _FakeEntryPoint("stealthy_ep", dangerous_skill)
|
||||
with patch("agentkit.skills.loader.sys.version_info", (3, 12, 0)):
|
||||
with patch("importlib.metadata.entry_points", return_value=[fake_ep]):
|
||||
with caplog.at_level(logging.WARNING):
|
||||
skills = loader.load_from_entry_points()
|
||||
assert len(skills) == 1
|
||||
warnings = [r for r in caplog.records if r.levelno == logging.WARNING]
|
||||
assert any(
|
||||
"stealthy_skill" in r.getMessage() and "shell" in r.getMessage() for r in warnings
|
||||
)
|
||||
|
||||
def test_entry_points_no_capabilities_no_warning(self, caplog):
|
||||
import logging
|
||||
|
|
@ -129,7 +154,8 @@ execution_mode: react
|
|||
assert len(skills) == 1
|
||||
# 不应有危险能力 warning(只可能有其他 warning)
|
||||
dangerous_warnings = [
|
||||
r for r in caplog.records
|
||||
r
|
||||
for r in caplog.records
|
||||
if r.levelno == logging.WARNING and "dangerous capabilities" in r.getMessage()
|
||||
]
|
||||
assert dangerous_warnings == []
|
||||
|
|
@ -139,13 +165,17 @@ execution_mode: react
|
|||
registry = SkillRegistry()
|
||||
loader = SkillLoader(skill_registry=registry)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = _write_yaml(tmpdir, "s.yaml", {
|
||||
path = _write_yaml(
|
||||
tmpdir,
|
||||
"s.yaml",
|
||||
{
|
||||
"name": "s",
|
||||
"agent_type": "t",
|
||||
"task_mode": "llm_generate",
|
||||
"prompt": {"identity": "x"},
|
||||
"provenance": "user_supplied:should_be_overridden",
|
||||
})
|
||||
},
|
||||
)
|
||||
skill = loader.load_from_file(path)
|
||||
assert skill.config.provenance == f"yaml:{path}"
|
||||
assert "user_supplied" not in skill.config.provenance
|
||||
|
|
|
|||
Loading…
Reference in New Issue