209 lines
7.0 KiB
Python
209 lines
7.0 KiB
Python
"""SkillPipeline - 技能编排,将多个 Skill 串联为 Pipeline 执行
|
||
|
||
复用 PipelineEngine 的设计理念,支持:
|
||
- 顺序执行(skill A → skill B → skill C)
|
||
- 条件分支(if skill A output contains X, run skill B, else skip)
|
||
- 输出映射(将上一步输出字段映射到下一步输入字段)
|
||
"""
|
||
|
||
import logging
|
||
import re
|
||
from typing import Callable, Coroutine
|
||
|
||
from agentkit.skills.registry import SkillRegistry
|
||
|
||
# Module-level logger named after this module; SkillPipeline emits its
# condition-evaluation warnings through it.
logger = logging.getLogger(name=__name__)
|
||
|
||
|
||
class SkillPipeline:
|
||
"""将多个 Skill 串联为 Pipeline 执行
|
||
|
||
每个步骤定义包含:
|
||
- skill_name: str (必需) — 要执行的 Skill 名称
|
||
- input_mapping: dict | None — 将上一步输出映射到当前步骤输入
|
||
- condition: str | None — 条件表达式,不满足则跳过
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
name: str,
|
||
steps: list[dict[str, object]],
|
||
skill_registry: SkillRegistry | None = None,
|
||
):
|
||
"""
|
||
Args:
|
||
name: Pipeline 名称
|
||
steps: 步骤定义列表,每项包含 skill_name、input_mapping、condition
|
||
skill_registry: 用于查找 Skill 的注册中心
|
||
"""
|
||
self.name = name
|
||
self._steps = steps
|
||
self._skill_registry = skill_registry
|
||
|
||
async def execute(
|
||
self,
|
||
input_data: dict[str, object],
|
||
agent_factory: Callable[..., Coroutine] | None = None,
|
||
) -> dict[str, object]:
|
||
"""顺序执行 Pipeline 中所有步骤
|
||
|
||
Args:
|
||
input_data: 初始输入数据
|
||
agent_factory: 可选的 Agent 工厂函数,签名为
|
||
async (skill_name: str, input_data: dict) -> dict
|
||
|
||
Returns:
|
||
包含 pipeline 名称、各步骤结果和最终输出的字典
|
||
"""
|
||
success = True
|
||
current_input: dict[str, object] = input_data
|
||
results: list[dict[str, object]] = []
|
||
|
||
for i, step_def in enumerate(self._steps):
|
||
skill_name = step_def["skill_name"]
|
||
|
||
# 条件检查
|
||
condition = step_def.get("condition")
|
||
if condition and not self._evaluate_condition(condition, current_input, results):
|
||
results.append({
|
||
"step": i,
|
||
"skill": skill_name,
|
||
"status": "skipped",
|
||
"reason": f"Condition not met: {condition}",
|
||
})
|
||
continue
|
||
|
||
# 输入映射
|
||
input_mapping = step_def.get("input_mapping")
|
||
step_input = (
|
||
self._map_input(current_input, input_mapping, results)
|
||
if input_mapping
|
||
else current_input
|
||
)
|
||
|
||
# 执行 Skill
|
||
try:
|
||
step_result = await self._execute_skill(skill_name, step_input, agent_factory)
|
||
results.append({
|
||
"step": i,
|
||
"skill": skill_name,
|
||
"output": step_result,
|
||
"status": "success",
|
||
})
|
||
current_input = step_result
|
||
except Exception as e:
|
||
results.append({
|
||
"step": i,
|
||
"skill": skill_name,
|
||
"error": str(e),
|
||
"status": "failed",
|
||
})
|
||
success = False
|
||
break
|
||
|
||
return {
|
||
"pipeline": self.name,
|
||
"steps": results,
|
||
"final_output": current_input if success else None,
|
||
"success": success,
|
||
}
|
||
|
||
async def _execute_skill(
|
||
self,
|
||
skill_name: str,
|
||
input_data: dict[str, object],
|
||
agent_factory: Callable[..., Coroutine] | None = None,
|
||
) -> dict[str, object]:
|
||
"""执行单个 Skill
|
||
|
||
优先使用 agent_factory,其次通过 SkillRegistry 查找 Skill 并创建 Agent 执行。
|
||
"""
|
||
if agent_factory:
|
||
return await agent_factory(skill_name, input_data)
|
||
|
||
if self._skill_registry:
|
||
try:
|
||
skill = self._skill_registry.get(skill_name)
|
||
except Exception:
|
||
raise ValueError(f"Skill '{skill_name}' not found in registry")
|
||
|
||
from agentkit.core.config_driven import ConfigDrivenAgent
|
||
from agentkit.core.protocol import TaskMessage
|
||
from datetime import datetime, timezone
|
||
|
||
agent = ConfigDrivenAgent(config=skill.config)
|
||
task = TaskMessage(
|
||
task_id=f"pipeline-{skill_name}",
|
||
agent_name=skill_name,
|
||
task_type=skill.config.agent_type,
|
||
priority=0,
|
||
input_data=input_data,
|
||
callback_url=None,
|
||
created_at=datetime.now(timezone.utc),
|
||
)
|
||
return await agent.handle_task(task)
|
||
|
||
raise ValueError(
|
||
f"Cannot execute skill '{skill_name}': "
|
||
"no agent_factory or skill_registry provided"
|
||
)
|
||
|
||
def _evaluate_condition(
|
||
self,
|
||
condition: str,
|
||
current_input: dict[str, object],
|
||
results: list[dict[str, object]],
|
||
) -> bool:
|
||
"""评估简单条件表达式
|
||
|
||
支持格式:
|
||
- "key.path == 'value'" — 字符串相等
|
||
- "key.path > 0.5" — 数值大于
|
||
"""
|
||
try:
|
||
eq_match = re.match(r'^([\w.]+)\s*==\s*(.+)$', condition.strip())
|
||
if eq_match:
|
||
path = eq_match.group(1)
|
||
value = eq_match.group(2).strip().strip("'\"")
|
||
actual = self._resolve_path(path, current_input)
|
||
return str(actual) == value
|
||
gt_match = re.match(r'^([\w.]+)\s*>\s*(.+)$', condition.strip())
|
||
if gt_match:
|
||
path = gt_match.group(1)
|
||
value = float(gt_match.group(2).strip())
|
||
actual = float(self._resolve_path(path, current_input))
|
||
return actual > value
|
||
except (ValueError, TypeError, AttributeError, KeyError) as e:
|
||
logger.warning(f"Condition evaluation failed for '{condition}': {e}")
|
||
return False
|
||
return False
|
||
|
||
@staticmethod
|
||
def _resolve_path(path: str, data: dict[str, object]) -> object:
|
||
"""解析点号路径,如 'output.score'"""
|
||
parts = path.split(".")
|
||
obj: object = data
|
||
for part in parts:
|
||
if isinstance(obj, dict):
|
||
obj = obj.get(part)
|
||
else:
|
||
return None
|
||
return obj
|
||
|
||
def _map_input(
|
||
self,
|
||
current_input: dict[str, object],
|
||
mapping: dict[str, str],
|
||
results: list[dict[str, object]],
|
||
) -> dict[str, object]:
|
||
"""根据映射规则将上一步输出映射到当前步骤输入
|
||
|
||
mapping 格式: {"target_key": "source.path"}
|
||
"""
|
||
mapped: dict[str, object] = {}
|
||
for target_key, source_path in mapping.items():
|
||
value = self._resolve_path(source_path, current_input)
|
||
if value is not None:
|
||
mapped[target_key] = value
|
||
return mapped
|