# geo/backend/app/agent_framework/pipeline/loader.py

"""Pipeline YAML loader: load Pipeline definitions from YAML files or strings, validate the stage DAG, and resolve variables."""
import logging
import re
from collections import Counter, deque
from pathlib import Path
from typing import Any

import yaml

from .schema import Pipeline, PipelineStage
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Regex for variable references of the form ${var.path}. Group 1 captures the
# text between the braces (one or more characters other than "}").
VARIABLE_PATTERN = re.compile(r"\$\{([^}]+)\}")
class PipelineLoadError(Exception):
    """Raised when a pipeline definition cannot be loaded.

    Attributes are set from the constructor arguments:
    ``pipeline_name`` is the name of the pipeline being loaded and
    ``reason`` is the free-form explanation appended to the message.
    """

    def __init__(self, pipeline_name: str, reason: str = ""):
        message = f"Failed to load pipeline '{pipeline_name}': {reason}"
        super().__init__(message)
        self.pipeline_name = pipeline_name
        self.reason = reason
class PipelineCyclicError(PipelineLoadError):
    """Raised when the stages of a pipeline form a dependency cycle."""

    def __init__(self, pipeline_name: str):
        reason = "Cyclic dependency detected in stages"
        super().__init__(pipeline_name, reason)
class PipelineValidationError(PipelineLoadError):
    """Raised when a parsed pipeline fails a validation check."""

    def __init__(self, pipeline_name: str, reason: str = ""):
        # The base class stores this prefixed text as its ``reason``.
        detail = f"Validation error: {reason}"
        super().__init__(pipeline_name, detail)
class PipelineLoader:
    """Load Pipeline definitions from YAML files or YAML strings.

    After parsing, the loader checks that stage names are unique, that every
    ``depends_on`` entry names an existing stage, and that the stage
    dependencies contain no cycle. It also provides static helpers that
    resolve ``${var.path}`` references against a context dict.
    """

    def __init__(self, pipelines_dir: str | None = None):
        """
        Args:
            pipelines_dir: Directory containing the pipeline YAML files.
                Defaults to ``pipelines`` under ``parents[3]`` of this
                module's resolved path (intended to be ``backend/pipelines``).
        """
        if pipelines_dir is None:
            pipelines_dir = str(Path(__file__).resolve().parents[3] / "pipelines")
        self.pipelines_dir = Path(pipelines_dir)

    def load(self, pipeline_name: str) -> Pipeline:
        """Load a Pipeline definition from a YAML file.

        Looks for ``{pipelines_dir}/{pipeline_name}.yaml`` first and falls
        back to the ``.yml`` suffix.

        Args:
            pipeline_name: Pipeline name; also the file stem that is looked up.

        Returns:
            The validated Pipeline object.

        Raises:
            PipelineLoadError: Neither file exists, the file cannot be read
                or decoded as UTF-8, or parsing / schema validation fails.
            PipelineCyclicError: The stage dependencies contain a cycle.
            PipelineValidationError: A validation check fails.
        """
        yaml_path = self.pipelines_dir / f"{pipeline_name}.yaml"
        if not yaml_path.exists():
            yml_path = self.pipelines_dir / f"{pipeline_name}.yml"
            if not yml_path.exists():
                raise PipelineLoadError(
                    pipeline_name,
                    f"Pipeline file not found: {yaml_path}",
                )
            yaml_path = yml_path

        try:
            raw_content = yaml_path.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError) as e:
            # UnicodeDecodeError is a ValueError, not an OSError; without it a
            # badly encoded file would escape the documented error contract.
            raise PipelineLoadError(pipeline_name, f"Failed to read file: {e}") from e

        return self.load_from_yaml(raw_content, pipeline_name=pipeline_name)

    def load_from_yaml(
        self, yaml_content: str, pipeline_name: str | None = None
    ) -> Pipeline:
        """Load a Pipeline definition from a YAML string.

        Args:
            yaml_content: YAML document text.
            pipeline_name: Optional name used in error messages. When omitted,
                schema errors fall back to the ``name`` key of the parsed
                document and other errors use ``"<unknown>"``.

        Returns:
            The validated Pipeline object.

        Raises:
            PipelineLoadError: The YAML cannot be parsed, is not a mapping,
                or fails schema validation.
            PipelineCyclicError: The stage dependencies contain a cycle.
            PipelineValidationError: A validation check fails.
        """
        fallback_name = pipeline_name or "<unknown>"

        try:
            raw_data = yaml.safe_load(yaml_content)
        except yaml.YAMLError as e:
            raise PipelineLoadError(fallback_name, f"YAML parse error: {e}") from e

        if not isinstance(raw_data, dict):
            raise PipelineLoadError(fallback_name, "YAML content must be a mapping")

        # Schema parsing is delegated to the Pipeline model. Its concrete
        # error type is not imported here, so any failure is converted into
        # the loader's own error type at this boundary.
        try:
            pipeline = Pipeline.model_validate(raw_data)
        except Exception as e:
            name = pipeline_name or raw_data.get("name", "<unknown>")
            raise PipelineLoadError(str(name), f"Schema validation error: {e}") from e

        name = pipeline.name
        self._validate_stage_names(pipeline)
        self._validate_dependencies(pipeline)

        if not self.validate_dag(pipeline.stages):
            raise PipelineCyclicError(name)

        logger.info(
            "Pipeline '%s' loaded successfully (%d stages)",
            name,
            len(pipeline.stages),
        )
        return pipeline

    def _validate_stage_names(self, pipeline: Pipeline) -> None:
        """Raise PipelineValidationError if any stage name occurs more than once."""
        counts = Counter(stage.name for stage in pipeline.stages)
        duplicates = sorted(name for name, count in counts.items() if count > 1)
        if duplicates:
            raise PipelineValidationError(
                pipeline.name,
                f"Duplicate stage names: {duplicates}",
            )

    def _validate_dependencies(self, pipeline: Pipeline) -> None:
        """Raise PipelineValidationError if a ``depends_on`` entry names no stage."""
        stage_names = {s.name for s in pipeline.stages}
        for stage in pipeline.stages:
            for dep in stage.depends_on:
                if dep not in stage_names:
                    raise PipelineValidationError(
                        pipeline.name,
                        f"Stage '{stage.name}' depends on unknown stage '{dep}'",
                    )

    @staticmethod
    def validate_dag(stages: list[PipelineStage]) -> bool:
        """Check that the stage dependencies form a directed acyclic graph.

        Uses Kahn's algorithm: repeatedly remove nodes with in-degree zero;
        if some node is never removed, the graph has a cycle.

        Args:
            stages: Stages to check. Only ``name`` and ``depends_on`` are read.
                Dependencies that name no stage in the list are ignored.

        Returns:
            True if the dependencies are acyclic, False if there is a cycle.
        """
        stage_names = {s.name for s in stages}

        # Edges point from a dependency to the stage that depends on it.
        in_degree: dict[str, int] = {name: 0 for name in stage_names}
        adj: dict[str, list[str]] = {name: [] for name in stage_names}
        for stage in stages:
            for dep in stage.depends_on:
                if dep in stage_names:
                    adj[dep].append(stage.name)
                    in_degree[stage.name] += 1

        # deque gives O(1) pops from the left, unlike list.pop(0).
        queue = deque(name for name, deg in in_degree.items() if deg == 0)
        visited_count = 0
        while queue:
            node = queue.popleft()
            visited_count += 1
            for neighbor in adj[node]:
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)

        # Every node was removed exactly when no cycle exists.
        return visited_count == len(stage_names)

    @staticmethod
    def resolve_variables(template: Any, context: dict) -> Any:
        """Resolve ``${var.path}`` references inside a template value.

        Dotted paths walk nested dicts, for example:
            - ``${stages.step1.outputs.result}`` ->
              ``context["stages"]["step1"]["outputs"]["result"]``
            - ``${brand_name}`` -> ``context["brand_name"]``

        Strings have their references substituted, dicts (values only) and
        lists are processed recursively, and any other value is returned as is.

        Args:
            template: Value to resolve (str, dict, list, or anything else).
            context: Variable context to look paths up in.

        Returns:
            The resolved value.
        """
        if isinstance(template, str):
            return PipelineLoader._resolve_string(template, context)
        if isinstance(template, dict):
            return {
                k: PipelineLoader.resolve_variables(v, context)
                for k, v in template.items()
            }
        if isinstance(template, list):
            return [PipelineLoader.resolve_variables(item, context) for item in template]
        # int, float, bool, None, ... are returned unchanged.
        return template

    @staticmethod
    def _resolve_string(template: str, context: dict) -> Any:
        """Resolve the variable references in a single string.

        If the whole string is exactly one reference (e.g. ``"${brand_name}"``)
        the looked-up value is returned directly, keeping its original type.
        Otherwise each reference is replaced by ``str(value)`` and a string is
        returned; a string without references comes back unchanged.
        """
        whole = VARIABLE_PATTERN.fullmatch(template)
        if whole is not None:
            return PipelineLoader._resolve_path(whole.group(1).strip(), context)

        def replacer(match: re.Match) -> str:
            path = match.group(1).strip()
            return str(PipelineLoader._resolve_path(path, context))

        return VARIABLE_PATTERN.sub(replacer, template)

    @staticmethod
    def _resolve_path(path: str, context: dict) -> Any:
        """Look up a dotted path in ``context``.

        For example ``stages.step1.outputs.result`` resolves to
        ``context["stages"]["step1"]["outputs"]["result"]``.

        Returns:
            The value found. If a key is missing or an intermediate value is
            not a dict, a warning is logged and the placeholder ``${path}`` is
            returned so the unresolved reference stays visible.
        """
        placeholder = f"${{{path}}}"
        current: Any = context
        for part in path.split("."):
            if not isinstance(current, dict):
                logger.warning(
                    "Variable path '%s' cannot be resolved (non-dict at '%s')",
                    path,
                    part,
                )
                return placeholder
            if part not in current:
                logger.warning(
                    "Variable path '%s' not found in context (missing key: '%s')",
                    path,
                    part,
                )
                return placeholder
            current = current[part]
        return current