fischer-agentkit/src/agentkit/skills/pipeline.py

209 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

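"""SkillPipeline - skill orchestration: chain multiple Skills into a pipeline.

Follows the design ideas of PipelineEngine and supports:
- Sequential execution: skill A -> skill B -> skill C
- Conditional branching: run skill B only if a condition on skill A's output
  holds (e.g. ``score > 0.5``), otherwise skip it
- Output mapping: map fields of the previous step's output onto fields of
  the next step's input
"""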
import logging
import re
from typing import Callable, Coroutine
from agentkit.skills.registry import SkillRegistry
# Logger named after this module, so its records can be filtered per module.
logger: logging.Logger = logging.getLogger(name=__name__)
class SkillPipeline:
"""将多个 Skill 串联为 Pipeline 执行
每个步骤定义包含:
- skill_name: str (必需) — 要执行的 Skill 名称
- input_mapping: dict | None — 将上一步输出映射到当前步骤输入
- condition: str | None — 条件表达式,不满足则跳过
"""
def __init__(
self,
name: str,
steps: list[dict[str, object]],
skill_registry: SkillRegistry | None = None,
):
"""
Args:
name: Pipeline 名称
steps: 步骤定义列表,每项包含 skill_name、input_mapping、condition
skill_registry: 用于查找 Skill 的注册中心
"""
self.name = name
self._steps = steps
self._skill_registry = skill_registry
async def execute(
self,
input_data: dict[str, object],
agent_factory: Callable[..., Coroutine] | None = None,
) -> dict[str, object]:
"""顺序执行 Pipeline 中所有步骤
Args:
input_data: 初始输入数据
agent_factory: 可选的 Agent 工厂函数,签名为
async (skill_name: str, input_data: dict) -> dict
Returns:
包含 pipeline 名称、各步骤结果和最终输出的字典
"""
success = True
current_input: dict[str, object] = input_data
results: list[dict[str, object]] = []
for i, step_def in enumerate(self._steps):
skill_name = step_def["skill_name"]
# 条件检查
condition = step_def.get("condition")
if condition and not self._evaluate_condition(condition, current_input, results):
results.append({
"step": i,
"skill": skill_name,
"status": "skipped",
"reason": f"Condition not met: {condition}",
})
continue
# 输入映射
input_mapping = step_def.get("input_mapping")
step_input = (
self._map_input(current_input, input_mapping, results)
if input_mapping
else current_input
)
# 执行 Skill
try:
step_result = await self._execute_skill(skill_name, step_input, agent_factory)
results.append({
"step": i,
"skill": skill_name,
"output": step_result,
"status": "success",
})
current_input = step_result
except Exception as e:
results.append({
"step": i,
"skill": skill_name,
"error": str(e),
"status": "failed",
})
success = False
break
return {
"pipeline": self.name,
"steps": results,
"final_output": current_input if success else None,
"success": success,
}
async def _execute_skill(
self,
skill_name: str,
input_data: dict[str, object],
agent_factory: Callable[..., Coroutine] | None = None,
) -> dict[str, object]:
"""执行单个 Skill
优先使用 agent_factory其次通过 SkillRegistry 查找 Skill 并创建 Agent 执行。
"""
if agent_factory:
return await agent_factory(skill_name, input_data)
if self._skill_registry:
try:
skill = self._skill_registry.get(skill_name)
except Exception:
raise ValueError(f"Skill '{skill_name}' not found in registry")
from agentkit.core.config_driven import ConfigDrivenAgent
from agentkit.core.protocol import TaskMessage
from datetime import datetime, timezone
agent = ConfigDrivenAgent(config=skill.config)
task = TaskMessage(
task_id=f"pipeline-{skill_name}",
agent_name=skill_name,
task_type=skill.config.agent_type,
priority=0,
input_data=input_data,
callback_url=None,
created_at=datetime.now(timezone.utc),
)
return await agent.handle_task(task)
raise ValueError(
f"Cannot execute skill '{skill_name}': "
"no agent_factory or skill_registry provided"
)
def _evaluate_condition(
self,
condition: str,
current_input: dict[str, object],
results: list[dict[str, object]],
) -> bool:
"""评估简单条件表达式
支持格式:
- "key.path == 'value'" — 字符串相等
- "key.path > 0.5" — 数值大于
"""
try:
eq_match = re.match(r'^([\w.]+)\s*==\s*(.+)$', condition.strip())
if eq_match:
path = eq_match.group(1)
value = eq_match.group(2).strip().strip("'\"")
actual = self._resolve_path(path, current_input)
return str(actual) == value
gt_match = re.match(r'^([\w.]+)\s*>\s*(.+)$', condition.strip())
if gt_match:
path = gt_match.group(1)
value = float(gt_match.group(2).strip())
actual = float(self._resolve_path(path, current_input))
return actual > value
except (ValueError, TypeError, AttributeError, KeyError) as e:
logger.warning(f"Condition evaluation failed for '{condition}': {e}")
return False
return False
@staticmethod
def _resolve_path(path: str, data: dict[str, object]) -> object:
"""解析点号路径,如 'output.score'"""
parts = path.split(".")
obj: object = data
for part in parts:
if isinstance(obj, dict):
obj = obj.get(part)
else:
return None
return obj
def _map_input(
self,
current_input: dict[str, object],
mapping: dict[str, str],
results: list[dict[str, object]],
) -> dict[str, object]:
"""根据映射规则将上一步输出映射到当前步骤输入
mapping 格式: {"target_key": "source.path"}
"""
mapped: dict[str, object] = {}
for target_key, source_path in mapping.items():
value = self._resolve_path(source_path, current_input)
if value is not None:
mapped[target_key] = value
return mapped