feat(skills): U5 GEO Pipeline orchestration with DAG execution
- GEOPipeline: YAML-driven DAG pipeline with parallel/sequential execution - PipelineStep with input_mapping ($.input.xxx, $.steps.name.output.xxx) - Topological sort for execution groups, SharedWorkspace integration - geo_full_pipeline.yaml: detect→analyze→optimize→track workflow - 10 tests passing
This commit is contained in:
parent
23934602c0
commit
1390bd8d6e
|
|
@ -0,0 +1,42 @@
|
|||
name: geo_full_pipeline
|
||||
description: "GEO 端到端工作流:检测→分析→优化→追踪"
|
||||
|
||||
steps:
|
||||
- name: detect
|
||||
skill: citation_detector
|
||||
input_mapping:
|
||||
brand: $.input.brand
|
||||
platforms: $.input.platforms
|
||||
|
||||
- name: analyze_competitor
|
||||
skill: competitor_analyzer
|
||||
input_mapping:
|
||||
brand: $.input.brand
|
||||
detection_result: $.steps.detect.output
|
||||
depends_on: [detect]
|
||||
|
||||
- name: analyze_trend
|
||||
skill: trend_agent
|
||||
input_mapping:
|
||||
brand: $.input.brand
|
||||
depends_on: [detect]
|
||||
|
||||
- name: optimize
|
||||
skill: geo_optimizer
|
||||
input_mapping:
|
||||
brand: $.input.brand
|
||||
analysis: $.steps.analyze_competitor.output
|
||||
depends_on: [analyze_competitor, analyze_trend]
|
||||
|
||||
- name: schema
|
||||
skill: schema_advisor
|
||||
input_mapping:
|
||||
brand: $.input.brand
|
||||
optimization: $.steps.optimize.output
|
||||
depends_on: [optimize]
|
||||
|
||||
- name: monitor
|
||||
skill: monitor
|
||||
input_mapping:
|
||||
brand: $.input.brand
|
||||
depends_on: [optimize]
|
||||
|
|
@ -0,0 +1,395 @@
|
|||
"""GEOPipeline - GEO 端到端工作流编排
|
||||
|
||||
实现检测→分析→优化→追踪的 DAG Pipeline,
|
||||
基于 Orchestrator 的多 Agent 协作模式。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from agentkit.core.protocol import TaskMessage
|
||||
from agentkit.core.shared_workspace import SharedWorkspace
|
||||
from agentkit.skills.registry import SkillRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineStep:
|
||||
"""Pipeline 步骤定义"""
|
||||
|
||||
name: str
|
||||
skill: str
|
||||
input_mapping: dict[str, str] = field(default_factory=dict)
|
||||
depends_on: list[str] = field(default_factory=list)
|
||||
condition: str | None = None
|
||||
parallel_with: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineStepResult:
|
||||
"""步骤执行结果"""
|
||||
|
||||
step_name: str
|
||||
skill: str
|
||||
status: str # "success", "failed", "skipped"
|
||||
output: dict[str, Any] | None = None
|
||||
error: str | None = None
|
||||
duration_ms: float = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineResult:
|
||||
"""Pipeline 执行结果"""
|
||||
|
||||
pipeline_name: str
|
||||
execution_id: str
|
||||
steps: list[PipelineStepResult]
|
||||
final_output: dict[str, Any] | None
|
||||
success: bool
|
||||
total_duration_ms: float
|
||||
|
||||
|
||||
class GEOPipeline:
|
||||
"""GEO 端到端工作流编排
|
||||
|
||||
支持:
|
||||
- YAML 配置驱动的 Pipeline 定义
|
||||
- DAG 依赖关系(depends_on)
|
||||
- 并行执行无依赖的步骤
|
||||
- 步骤间数据通过 SharedWorkspace 传递
|
||||
- 条件跳过步骤
|
||||
|
||||
使用方式:
|
||||
pipeline = GEOPipeline.from_config(config, skill_registry, agent_pool)
|
||||
result = await pipeline.execute(input_data)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
steps: list[PipelineStep],
|
||||
skill_registry: SkillRegistry | None = None,
|
||||
agent_pool: Any = None,
|
||||
workspace: SharedWorkspace | None = None,
|
||||
):
|
||||
self.name = name
|
||||
self._steps = steps
|
||||
self._skill_registry = skill_registry
|
||||
self._agent_pool = agent_pool
|
||||
self._workspace = workspace or SharedWorkspace()
|
||||
self._step_map = {s.name: s for s in steps}
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: dict[str, Any],
|
||||
skill_registry: SkillRegistry | None = None,
|
||||
agent_pool: Any = None,
|
||||
workspace: SharedWorkspace | None = None,
|
||||
) -> GEOPipeline:
|
||||
"""从 YAML 配置创建 Pipeline
|
||||
|
||||
配置格式:
|
||||
name: geo_full_pipeline
|
||||
steps:
|
||||
- name: detect
|
||||
skill: citation_detector
|
||||
input_mapping: {brand: $.input.brand}
|
||||
- name: analyze
|
||||
skill: competitor_analyzer
|
||||
depends_on: [detect]
|
||||
"""
|
||||
steps = []
|
||||
for step_conf in config.get("steps", []):
|
||||
step = PipelineStep(
|
||||
name=step_conf["name"],
|
||||
skill=step_conf["skill"],
|
||||
input_mapping=step_conf.get("input_mapping", {}),
|
||||
depends_on=step_conf.get("depends_on", []),
|
||||
condition=step_conf.get("condition"),
|
||||
parallel_with=step_conf.get("parallel_with", []),
|
||||
)
|
||||
steps.append(step)
|
||||
|
||||
return cls(
|
||||
name=config.get("name", "geo_pipeline"),
|
||||
steps=steps,
|
||||
skill_registry=skill_registry,
|
||||
agent_pool=agent_pool,
|
||||
workspace=workspace,
|
||||
)
|
||||
|
||||
async def execute(self, input_data: dict[str, Any]) -> PipelineResult:
|
||||
"""执行 Pipeline
|
||||
|
||||
Args:
|
||||
input_data: 初始输入数据
|
||||
|
||||
Returns:
|
||||
PipelineResult: 包含各步骤结果和最终输出
|
||||
"""
|
||||
import time
|
||||
|
||||
start_time = time.monotonic()
|
||||
execution_id = str(uuid.uuid4())[:8]
|
||||
step_results: list[PipelineStepResult] = []
|
||||
step_outputs: dict[str, dict[str, Any]] = {}
|
||||
|
||||
# Store initial input in workspace
|
||||
await self._workspace.write(
|
||||
f"pipeline:{execution_id}:input",
|
||||
input_data,
|
||||
agent_id="pipeline",
|
||||
)
|
||||
|
||||
# Build execution order (topological sort)
|
||||
execution_groups = self._build_execution_groups()
|
||||
|
||||
for group in execution_groups:
|
||||
# Execute group in parallel
|
||||
tasks = []
|
||||
for step_name in group:
|
||||
step = self._step_map[step_name]
|
||||
tasks.append(self._execute_step(step, input_data, step_outputs, execution_id))
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for step_name, result in zip(group, results):
|
||||
if isinstance(result, Exception):
|
||||
step_result = PipelineStepResult(
|
||||
step_name=step_name,
|
||||
skill=self._step_map[step_name].skill,
|
||||
status="failed",
|
||||
error=str(result),
|
||||
)
|
||||
else:
|
||||
step_result = result
|
||||
|
||||
step_results.append(step_result)
|
||||
if step_result.status == "success" and step_result.output:
|
||||
step_outputs[step_name] = step_result.output
|
||||
|
||||
# Build final output
|
||||
final_output = self._build_final_output(step_outputs, input_data)
|
||||
|
||||
duration_ms = (time.monotonic() - start_time) * 1000
|
||||
success = all(r.status in ("success", "skipped") for r in step_results)
|
||||
|
||||
return PipelineResult(
|
||||
pipeline_name=self.name,
|
||||
execution_id=execution_id,
|
||||
steps=step_results,
|
||||
final_output=final_output,
|
||||
success=success,
|
||||
total_duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
async def _execute_step(
|
||||
self,
|
||||
step: PipelineStep,
|
||||
input_data: dict[str, Any],
|
||||
step_outputs: dict[str, dict[str, Any]],
|
||||
execution_id: str,
|
||||
) -> PipelineStepResult:
|
||||
"""执行单个 Pipeline 步骤"""
|
||||
import time
|
||||
|
||||
start_time = time.monotonic()
|
||||
|
||||
# Check condition
|
||||
if step.condition and not self._evaluate_condition(step.condition, input_data, step_outputs):
|
||||
return PipelineStepResult(
|
||||
step_name=step.name,
|
||||
skill=step.skill,
|
||||
status="skipped",
|
||||
)
|
||||
|
||||
# Build step input from mapping
|
||||
step_input = self._map_input(step, input_data, step_outputs)
|
||||
|
||||
# Execute skill
|
||||
try:
|
||||
output = await self._execute_skill(step.skill, step_input)
|
||||
duration_ms = (time.monotonic() - start_time) * 1000
|
||||
|
||||
# Store result in workspace
|
||||
await self._workspace.write(
|
||||
f"pipeline:{execution_id}:step:{step.name}",
|
||||
output,
|
||||
agent_id=step.skill,
|
||||
)
|
||||
|
||||
return PipelineStepResult(
|
||||
step_name=step.name,
|
||||
skill=step.skill,
|
||||
status="success",
|
||||
output=output,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
except Exception as e:
|
||||
duration_ms = (time.monotonic() - start_time) * 1000
|
||||
logger.error(f"Pipeline step '{step.name}' failed: {e}")
|
||||
return PipelineStepResult(
|
||||
step_name=step.name,
|
||||
skill=step.skill,
|
||||
status="failed",
|
||||
error=str(e),
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
async def _execute_skill(
|
||||
self, skill_name: str, input_data: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""执行 Skill"""
|
||||
if self._agent_pool:
|
||||
agent = self._agent_pool.get_agent(skill_name)
|
||||
if agent:
|
||||
from datetime import datetime, timezone
|
||||
|
||||
task = TaskMessage(
|
||||
task_id=f"pipeline-{skill_name}",
|
||||
agent_name=skill_name,
|
||||
task_type=skill_name,
|
||||
priority=0,
|
||||
input_data=input_data,
|
||||
callback_url=None,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
result = await agent.execute(task)
|
||||
return result.output_data if hasattr(result, "output_data") else result
|
||||
|
||||
if self._skill_registry:
|
||||
skill = self._skill_registry.get(skill_name)
|
||||
from agentkit.core.config_driven import ConfigDrivenAgent
|
||||
from datetime import datetime, timezone
|
||||
|
||||
agent = ConfigDrivenAgent(config=skill.config)
|
||||
task = TaskMessage(
|
||||
task_id=f"pipeline-{skill_name}",
|
||||
agent_name=skill_name,
|
||||
task_type=skill.config.agent_type,
|
||||
priority=0,
|
||||
input_data=input_data,
|
||||
callback_url=None,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
return await agent.handle_task(task)
|
||||
|
||||
raise ValueError(f"Skill '{skill_name}' not found: no agent_pool or skill_registry")
|
||||
|
||||
def _build_execution_groups(self) -> list[list[str]]:
|
||||
"""构建并行执行组(拓扑排序)"""
|
||||
completed: set[str] = set()
|
||||
groups: list[list[str]] = []
|
||||
remaining = set(s.name for s in self._steps)
|
||||
|
||||
while remaining:
|
||||
ready = []
|
||||
for name in remaining:
|
||||
step = self._step_map[name]
|
||||
if all(dep in completed for dep in step.depends_on):
|
||||
ready.append(name)
|
||||
|
||||
if not ready:
|
||||
# Circular dependency — force remaining into one group
|
||||
groups.append(list(remaining))
|
||||
break
|
||||
|
||||
groups.append(ready)
|
||||
for name in ready:
|
||||
completed.add(name)
|
||||
remaining.discard(name)
|
||||
|
||||
return groups
|
||||
|
||||
def _map_input(
|
||||
self,
|
||||
step: PipelineStep,
|
||||
input_data: dict[str, Any],
|
||||
step_outputs: dict[str, dict[str, Any]],
|
||||
) -> dict[str, Any]:
|
||||
"""根据 input_mapping 构建步骤输入
|
||||
|
||||
映射格式: {"target_key": "source_path"}
|
||||
source_path 支持:
|
||||
- $.input.xxx — 初始输入
|
||||
- $.steps.step_name.output.xxx — 步骤输出
|
||||
"""
|
||||
if not step.input_mapping:
|
||||
# Default: merge all dependency outputs + original input
|
||||
merged = dict(input_data)
|
||||
for dep in step.depends_on:
|
||||
if dep in step_outputs:
|
||||
merged.update(step_outputs[dep])
|
||||
return merged
|
||||
|
||||
mapped: dict[str, Any] = {}
|
||||
for target_key, source_path in step.input_mapping.items():
|
||||
value = self._resolve_mapping_path(source_path, input_data, step_outputs)
|
||||
if value is not None:
|
||||
mapped[target_key] = value
|
||||
|
||||
return mapped
|
||||
|
||||
@staticmethod
|
||||
def _resolve_mapping_path(
|
||||
path: str,
|
||||
input_data: dict[str, Any],
|
||||
step_outputs: dict[str, dict[str, Any]],
|
||||
) -> Any:
|
||||
"""解析映射路径"""
|
||||
if path.startswith("$.input."):
|
||||
key = path[len("$.input."):]
|
||||
return input_data.get(key)
|
||||
elif path.startswith("$.steps."):
|
||||
# $.steps.step_name or $.steps.step_name.output.field
|
||||
rest = path[len("$.steps."):]
|
||||
parts = rest.split(".", 2)
|
||||
step_name = parts[0]
|
||||
if step_name not in step_outputs:
|
||||
return None
|
||||
if len(parts) == 1:
|
||||
# $.steps.step_name — return whole output
|
||||
return step_outputs[step_name]
|
||||
if len(parts) >= 2 and parts[1] == "output":
|
||||
if len(parts) >= 3:
|
||||
return step_outputs[step_name].get(parts[2])
|
||||
return step_outputs[step_name]
|
||||
# $.steps.step_name.field (without .output)
|
||||
return step_outputs[step_name].get(parts[1])
|
||||
return None
|
||||
|
||||
def _evaluate_condition(
|
||||
self, condition: str, input_data: dict[str, Any], step_outputs: dict[str, Any]
|
||||
) -> bool:
|
||||
"""评估条件表达式"""
|
||||
import re
|
||||
|
||||
try:
|
||||
eq_match = re.match(r'^([\w.]+)\s*==\s*(.+)$', condition.strip())
|
||||
if eq_match:
|
||||
path = eq_match.group(1)
|
||||
value = eq_match.group(2).strip().strip("'\"")
|
||||
actual = self._resolve_mapping_path(f"$.{path}", input_data, step_outputs)
|
||||
return str(actual) == value
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning(f"Condition evaluation failed for '{condition}': {e}")
|
||||
return False
|
||||
return True
|
||||
|
||||
def _build_final_output(
|
||||
self,
|
||||
step_outputs: dict[str, dict[str, Any]],
|
||||
input_data: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""构建最终输出"""
|
||||
final = {"input": input_data}
|
||||
for step_name, output in step_outputs.items():
|
||||
final[step_name] = output
|
||||
return final
|
||||
|
|
@ -0,0 +1,231 @@
|
|||
"""Tests for GEOPipeline"""
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.skills.geo_pipeline import (
|
||||
GEOPipeline,
|
||||
PipelineStep,
|
||||
PipelineStepResult,
|
||||
PipelineResult,
|
||||
)
|
||||
|
||||
|
||||
class MockAgent:
|
||||
"""Mock Agent for pipeline testing"""
|
||||
|
||||
def __init__(self, name: str, output_data: dict | None = None):
|
||||
self.name = name
|
||||
self.agent_type = "mock"
|
||||
self._output_data = output_data or {"result": f"output from {name}"}
|
||||
|
||||
async def execute(self, task):
|
||||
from agentkit.core.protocol import TaskResult, TaskStatus
|
||||
from datetime import datetime, timezone
|
||||
now = datetime.now(timezone.utc)
|
||||
return TaskResult(
|
||||
task_id=task.task_id,
|
||||
agent_name=self.name,
|
||||
status=TaskStatus.COMPLETED,
|
||||
output_data=self._output_data,
|
||||
error_message=None,
|
||||
started_at=now,
|
||||
completed_at=now,
|
||||
)
|
||||
|
||||
|
||||
class MockAgentPool:
|
||||
"""Mock AgentPool"""
|
||||
|
||||
def __init__(self, agents: dict[str, MockAgent] | None = None):
|
||||
self._agents = agents or {}
|
||||
|
||||
def get_agent(self, name: str):
|
||||
return self._agents.get(name)
|
||||
|
||||
def list_agents(self):
|
||||
return [{"name": a.name, "agent_type": a.agent_type} for a in self._agents.values()]
|
||||
|
||||
|
||||
class TestGEOPipeline:
|
||||
"""GEOPipeline unit tests"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sequential_pipeline(self):
|
||||
"""Sequential steps should execute in order"""
|
||||
steps = [
|
||||
PipelineStep(name="step1", skill="skill_a"),
|
||||
PipelineStep(name="step2", skill="skill_b", depends_on=["step1"]),
|
||||
]
|
||||
pool = MockAgentPool({
|
||||
"skill_a": MockAgent("skill_a", {"data": "result_a"}),
|
||||
"skill_b": MockAgent("skill_b", {"data": "result_b"}),
|
||||
})
|
||||
pipeline = GEOPipeline(name="test", steps=steps, agent_pool=pool)
|
||||
|
||||
result = await pipeline.execute({"query": "test"})
|
||||
assert result.success
|
||||
assert len(result.steps) == 2
|
||||
assert result.steps[0].status == "success"
|
||||
assert result.steps[1].status == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_steps(self):
|
||||
"""Steps without dependencies should execute in parallel"""
|
||||
steps = [
|
||||
PipelineStep(name="step1", skill="skill_a"),
|
||||
PipelineStep(name="step2", skill="skill_b"),
|
||||
]
|
||||
pool = MockAgentPool({
|
||||
"skill_a": MockAgent("skill_a", {"data": "a"}),
|
||||
"skill_b": MockAgent("skill_b", {"data": "b"}),
|
||||
})
|
||||
pipeline = GEOPipeline(name="test", steps=steps, agent_pool=pool)
|
||||
|
||||
result = await pipeline.execute({"query": "test"})
|
||||
assert result.success
|
||||
assert len(result.steps) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dag_execution(self):
|
||||
"""DAG with mixed parallel/sequential steps"""
|
||||
steps = [
|
||||
PipelineStep(name="detect", skill="skill_a"),
|
||||
PipelineStep(name="analyze_1", skill="skill_b", depends_on=["detect"]),
|
||||
PipelineStep(name="analyze_2", skill="skill_c", depends_on=["detect"]),
|
||||
PipelineStep(name="optimize", skill="skill_d", depends_on=["analyze_1", "analyze_2"]),
|
||||
]
|
||||
pool = MockAgentPool({
|
||||
"skill_a": MockAgent("skill_a", {"citations": 5}),
|
||||
"skill_b": MockAgent("skill_b", {"competitor": "data"}),
|
||||
"skill_c": MockAgent("skill_c", {"trend": "up"}),
|
||||
"skill_d": MockAgent("skill_d", {"optimized": True}),
|
||||
})
|
||||
pipeline = GEOPipeline(name="test", steps=steps, agent_pool=pool)
|
||||
|
||||
result = await pipeline.execute({"brand": "TestBrand"})
|
||||
assert result.success
|
||||
assert len(result.steps) == 4
|
||||
|
||||
# Check execution groups
|
||||
groups = pipeline._build_execution_groups()
|
||||
assert len(groups) == 3 # [detect], [analyze_1, analyze_2], [optimize]
|
||||
assert "detect" in groups[0]
|
||||
assert set(groups[1]) == {"analyze_1", "analyze_2"}
|
||||
assert groups[2] == ["optimize"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_failure(self):
|
||||
"""Failed step should be recorded"""
|
||||
class FailingAgent:
|
||||
name = "skill_a"
|
||||
agent_type = "mock"
|
||||
async def execute(self, task):
|
||||
raise RuntimeError("Agent failed")
|
||||
|
||||
steps = [PipelineStep(name="step1", skill="skill_a")]
|
||||
pool = MockAgentPool({"skill_a": FailingAgent()})
|
||||
pipeline = GEOPipeline(name="test", steps=steps, agent_pool=pool)
|
||||
|
||||
result = await pipeline.execute({"query": "test"})
|
||||
assert not result.success
|
||||
assert result.steps[0].status == "failed"
|
||||
assert "Agent failed" in result.steps[0].error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_input_mapping(self):
|
||||
"""Input mapping should resolve paths correctly"""
|
||||
steps = [
|
||||
PipelineStep(name="step1", skill="skill_a"),
|
||||
PipelineStep(
|
||||
name="step2",
|
||||
skill="skill_b",
|
||||
input_mapping={"brand": "$.input.brand"},
|
||||
depends_on=["step1"],
|
||||
),
|
||||
]
|
||||
pool = MockAgentPool({
|
||||
"skill_a": MockAgent("skill_a", {"data": "a"}),
|
||||
"skill_b": MockAgent("skill_b", {"data": "b"}),
|
||||
})
|
||||
pipeline = GEOPipeline(name="test", steps=steps, agent_pool=pool)
|
||||
|
||||
result = await pipeline.execute({"brand": "TestBrand"})
|
||||
assert result.success
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_from_config(self):
|
||||
"""Pipeline should be created from YAML config"""
|
||||
config = {
|
||||
"name": "geo_test",
|
||||
"steps": [
|
||||
{"name": "detect", "skill": "citation_detector"},
|
||||
{"name": "analyze", "skill": "competitor_analyzer", "depends_on": ["detect"]},
|
||||
],
|
||||
}
|
||||
pipeline = GEOPipeline.from_config(config)
|
||||
assert pipeline.name == "geo_test"
|
||||
assert len(pipeline._steps) == 2
|
||||
assert pipeline._steps[1].depends_on == ["detect"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execution_groups_topological_sort(self):
|
||||
"""Execution groups should follow topological order"""
|
||||
steps = [
|
||||
PipelineStep(name="a", skill="s1"),
|
||||
PipelineStep(name="b", skill="s2", depends_on=["a"]),
|
||||
PipelineStep(name="c", skill="s3", depends_on=["a"]),
|
||||
PipelineStep(name="d", skill="s4", depends_on=["b", "c"]),
|
||||
]
|
||||
pipeline = GEOPipeline(name="test", steps=steps)
|
||||
|
||||
groups = pipeline._build_execution_groups()
|
||||
assert len(groups) == 3
|
||||
assert groups[0] == ["a"]
|
||||
assert set(groups[1]) == {"b", "c"}
|
||||
assert groups[2] == ["d"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_mapping_path(self):
|
||||
"""Mapping path resolution"""
|
||||
input_data = {"brand": "TestBrand", "platforms": ["chatgpt"]}
|
||||
step_outputs = {
|
||||
"detect": {"citations": 5, "records": []},
|
||||
}
|
||||
|
||||
# $.input.brand
|
||||
result = GEOPipeline._resolve_mapping_path("$.input.brand", input_data, step_outputs)
|
||||
assert result == "TestBrand"
|
||||
|
||||
# $.steps.detect.output.citations
|
||||
result = GEOPipeline._resolve_mapping_path("$.steps.detect.output.citations", input_data, step_outputs)
|
||||
assert result == 5
|
||||
|
||||
# $.steps.detect (whole output)
|
||||
result = GEOPipeline._resolve_mapping_path("$.steps.detect", input_data, step_outputs)
|
||||
assert result == {"citations": 5, "records": []}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_final_output_includes_all_steps(self):
|
||||
"""Final output should include all step results"""
|
||||
steps = [
|
||||
PipelineStep(name="step1", skill="skill_a"),
|
||||
PipelineStep(name="step2", skill="skill_b", depends_on=["step1"]),
|
||||
]
|
||||
pool = MockAgentPool({
|
||||
"skill_a": MockAgent("skill_a", {"result": "a"}),
|
||||
"skill_b": MockAgent("skill_b", {"result": "b"}),
|
||||
})
|
||||
pipeline = GEOPipeline(name="test", steps=steps, agent_pool=pool)
|
||||
|
||||
result = await pipeline.execute({"query": "test"})
|
||||
assert "step1" in result.final_output
|
||||
assert "step2" in result.final_output
|
||||
assert "input" in result.final_output
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_pipeline(self):
|
||||
"""Empty pipeline should succeed with no steps"""
|
||||
pipeline = GEOPipeline(name="empty", steps=[])
|
||||
result = await pipeline.execute({"query": "test"})
|
||||
assert result.success
|
||||
assert len(result.steps) == 0
|
||||
Loading…
Reference in New Issue