283 lines
9.5 KiB
Python
283 lines
9.5 KiB
Python
"""Pipeline YAML加载器 - 从YAML文件/字符串加载Pipeline定义,验证DAG,解析变量"""
|
||
|
||
import logging
|
||
import re
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
import yaml
|
||
|
||
from .schema import Pipeline, PipelineStage
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# Variable-reference regex: matches the ${var.path} format
|
||
VARIABLE_PATTERN = re.compile(r"\$\{([^}]+)\}")
|
||
|
||
|
||
class PipelineLoadError(Exception):
    """Raised when a pipeline definition cannot be loaded.

    Attributes:
        pipeline_name: Name of the pipeline that failed to load.
        reason: Human-readable explanation of the failure (may be empty).
    """

    def __init__(self, pipeline_name: str, reason: str = ""):
        # Build the final message first so the base Exception carries it as
        # its only argument; the structured fields are kept separately.
        message = f"Failed to load pipeline '{pipeline_name}': {reason}"
        super().__init__(message)
        self.pipeline_name = pipeline_name
        self.reason = reason
|
||
|
||
|
||
class PipelineCyclicError(PipelineLoadError):
    """Raised when the stages of a pipeline contain a dependency cycle."""

    def __init__(self, pipeline_name: str):
        # The reason text is fixed for this error type; only the pipeline
        # name varies between instances.
        fixed_reason = "Cyclic dependency detected in stages"
        super().__init__(pipeline_name, fixed_reason)
|
||
|
||
|
||
class PipelineValidationError(PipelineLoadError):
    """Raised when a parsed pipeline fails a validation check."""

    def __init__(self, pipeline_name: str, reason: str = ""):
        # Prefix the caller's reason so validation failures are
        # distinguishable from other load errors in the final message.
        prefixed_reason = f"Validation error: {reason}"
        super().__init__(pipeline_name, prefixed_reason)
|
||
|
||
|
||
class PipelineLoader:
    """Load Pipeline definitions from YAML files or YAML strings.

    Loading parses the YAML, builds a ``Pipeline`` through
    ``Pipeline.model_validate``, checks that stage names are unique and that
    every ``depends_on`` entry names an existing stage, and rejects cyclic
    stage dependencies. The static ``resolve_variables`` helper substitutes
    ``${var.path}`` references against a context dict.
    """

    def __init__(self, pipelines_dir: str | None = None):
        """
        Args:
            pipelines_dir: Directory containing the pipeline YAML files. When
                omitted, defaults to the ``pipelines`` directory three levels
                above this file's directory (described by the original author
                as ``backend/pipelines`` relative to the project root).
        """
        if pipelines_dir is None:
            pipelines_dir = str(Path(__file__).resolve().parents[3] / "pipelines")
        self.pipelines_dir = Path(pipelines_dir)

    def load(self, pipeline_name: str) -> Pipeline:
        """
        Load a Pipeline definition from a YAML file.

        Looks for ``{pipelines_dir}/{pipeline_name}.yaml`` first and falls
        back to the ``.yml`` suffix.

        Args:
            pipeline_name: Pipeline name, i.e. the file stem inside
                ``pipelines_dir``.

        Returns:
            The parsed and validated Pipeline object.

        Raises:
            PipelineLoadError: The file does not exist, cannot be read or
                decoded as UTF-8, or cannot be parsed.
            PipelineCyclicError: The stages contain a dependency cycle.
            PipelineValidationError: Validation of the stages failed.
        """
        yaml_path = self.pipelines_dir / f"{pipeline_name}.yaml"

        if not yaml_path.exists():
            # Fall back to the .yml suffix.
            yml_path = self.pipelines_dir / f"{pipeline_name}.yml"
            if not yml_path.exists():
                raise PipelineLoadError(
                    pipeline_name,
                    f"Pipeline file not found: {yaml_path}",
                )
            yaml_path = yml_path

        try:
            raw_content = yaml_path.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError) as e:
            # UnicodeDecodeError is a ValueError, not an OSError; catch it too
            # so callers only ever see PipelineLoadError for unreadable files.
            raise PipelineLoadError(pipeline_name, f"Failed to read file: {e}") from e

        return self.load_from_yaml(raw_content, pipeline_name=pipeline_name)

    def load_from_yaml(
        self, yaml_content: str, pipeline_name: str | None = None
    ) -> Pipeline:
        """
        Load a Pipeline definition from a YAML string.

        Args:
            yaml_content: YAML-formatted string.
            pipeline_name: Optional pipeline name used in error messages
                raised before the pipeline is parsed. Once parsing succeeds,
                the name from the YAML content is used instead.

        Returns:
            The parsed and validated Pipeline object.

        Raises:
            PipelineLoadError: YAML parsing or schema validation failed, or
                the top-level YAML value is not a mapping.
            PipelineCyclicError: The stages contain a dependency cycle.
            PipelineValidationError: Stage names are duplicated or a
                dependency refers to an unknown stage.
        """
        try:
            raw_data = yaml.safe_load(yaml_content)
        except yaml.YAMLError as e:
            raise PipelineLoadError(
                pipeline_name or "<unknown>", f"YAML parse error: {e}"
            ) from e

        if not isinstance(raw_data, dict):
            raise PipelineLoadError(
                pipeline_name or "<unknown>", "YAML content must be a mapping"
            )

        # Parse with the Pydantic model. The broad except is deliberate: this
        # is the boundary where any schema failure becomes PipelineLoadError,
        # and the original exception stays attached as __cause__.
        try:
            pipeline = Pipeline.model_validate(raw_data)
        except Exception as e:
            name = pipeline_name or raw_data.get("name", "<unknown>")
            raise PipelineLoadError(str(name), f"Schema validation error: {e}") from e

        # Structural validation.
        name = pipeline.name
        self._validate_stage_names(pipeline)
        self._validate_dependencies(pipeline)

        # The dependency graph must be acyclic.
        if not self.validate_dag(pipeline.stages):
            raise PipelineCyclicError(name)

        logger.info(
            "Pipeline '%s' loaded successfully (%d stages)",
            name,
            len(pipeline.stages),
        )
        return pipeline

    def _validate_stage_names(self, pipeline: Pipeline) -> None:
        """Ensure every stage name is unique.

        Raises:
            PipelineValidationError: One or more names occur more than once;
                the message lists the duplicated names in sorted order.
        """
        # Single pass with two sets instead of calling list.count per element.
        seen: set[str] = set()
        duplicates: set[str] = set()
        for stage in pipeline.stages:
            if stage.name in seen:
                duplicates.add(stage.name)
            else:
                seen.add(stage.name)

        if duplicates:
            unique_dups = sorted(duplicates)
            raise PipelineValidationError(
                pipeline.name,
                f"Duplicate stage names: {unique_dups}",
            )

    def _validate_dependencies(self, pipeline: Pipeline) -> None:
        """Ensure every ``depends_on`` entry refers to an existing stage.

        Raises:
            PipelineValidationError: On the first dependency that names a
                stage not present in the pipeline.
        """
        stage_names = {s.name for s in pipeline.stages}
        for stage in pipeline.stages:
            for dep in stage.depends_on:
                if dep not in stage_names:
                    raise PipelineValidationError(
                        pipeline.name,
                        f"Stage '{stage.name}' depends on unknown stage '{dep}'",
                    )

    @staticmethod
    def validate_dag(stages: list[PipelineStage]) -> bool:
        """
        Check that the stage dependencies form a directed acyclic graph.

        Uses Kahn's algorithm: repeatedly remove nodes with in-degree zero
        and count how many nodes could be removed.

        Args:
            stages: Stages to check; each must expose ``name`` and
                ``depends_on``. Dependencies that name a stage not in this
                list are ignored here.

        Returns:
            True if the graph is acyclic, False if it contains a cycle.
        """
        # Build the in-degree table and the adjacency list (dependency -> dependents).
        in_degree: dict[str, int] = {s.name: 0 for s in stages}
        adj: dict[str, list[str]] = {name: [] for name in in_degree}

        for stage in stages:
            for dep in stage.depends_on:
                if dep in in_degree:
                    adj[dep].append(stage.name)
                    in_degree[stage.name] += 1

        # Nodes without unmet dependencies are ready. Only the number of
        # processed nodes matters for cycle detection, not their order, so
        # the ready list is used as a stack (O(1) pop from the end).
        ready: list[str] = [name for name, deg in in_degree.items() if deg == 0]
        visited_count = 0

        while ready:
            node = ready.pop()
            visited_count += 1
            for neighbor in adj[node]:
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    ready.append(neighbor)

        # Every node was processed exactly when there is no cycle.
        return visited_count == len(in_degree)

    @staticmethod
    def resolve_variables(template: Any, context: dict) -> Any:
        """
        Resolve ``${var.path}`` variable references.

        Nested paths are supported, for example:
        - ``${stages.step1.outputs.result}`` -> ``context["stages"]["step1"]["outputs"]["result"]``
        - ``${brand_name}`` -> ``context["brand_name"]``

        Strings have their references substituted, dicts and lists are
        processed recursively (dict keys are left untouched), and any other
        type is returned unchanged.

        Args:
            template: Value to resolve (str, dict, list, or anything else).
            context: Variable context to look references up in.

        Returns:
            The resolved value; dicts and lists are returned as new objects.
        """
        if isinstance(template, str):
            return PipelineLoader._resolve_string(template, context)
        if isinstance(template, dict):
            return {
                k: PipelineLoader.resolve_variables(v, context)
                for k, v in template.items()
            }
        if isinstance(template, list):
            return [PipelineLoader.resolve_variables(item, context) for item in template]
        # int, float, bool, None, etc. are returned as-is.
        return template

    @staticmethod
    def _resolve_string(template: str, context: dict) -> Any:
        """
        Resolve the variable references inside a string.

        If the whole string is exactly one reference (e.g. ``"${brand_name}"``),
        the referenced value is returned directly, preserving its type.
        Otherwise every reference is replaced by ``str(value)`` and a string
        is returned; a string without references comes back unchanged.
        """
        # Whole string is a single reference -> keep the value's original type.
        whole = VARIABLE_PATTERN.fullmatch(template)
        if whole is not None:
            return PipelineLoader._resolve_path(whole.group(1).strip(), context)

        # Mixed text (or no reference at all) -> textual substitution.
        def replacer(match) -> str:
            value = PipelineLoader._resolve_path(match.group(1).strip(), context)
            return str(value)

        return VARIABLE_PATTERN.sub(replacer, template)

    @staticmethod
    def _resolve_path(path: str, context: dict) -> Any:
        """
        Resolve a dot-separated path against ``context``.

        For example ``stages.step1.outputs.result`` yields
        ``context["stages"]["step1"]["outputs"]["result"]``.

        If a key is missing, or an intermediate value is not a dict, a
        warning is logged and the placeholder ``${path}`` is returned
        instead of raising.
        """
        current: Any = context

        for part in path.split("."):
            if not isinstance(current, dict):
                logger.warning(
                    "Variable path '%s' cannot be resolved (non-dict at '%s')",
                    path,
                    part,
                )
                return f"${{{path}}}"
            if part not in current:
                logger.warning(
                    "Variable path '%s' not found in context (missing key: '%s')",
                    path,
                    part,
                )
                # Hand back the placeholder so the unresolved reference stays visible.
                return f"${{{path}}}"
            current = current[part]

        return current
|