refactor(orchestrator+bitable): remove Any from type signatures
Test / backend-test (pull_request) Has been cancelled Details
Test / frontend-unit (pull_request) Has been cancelled Details
Test / api-e2e (pull_request) Has been cancelled Details
Test / frontend-e2e (pull_request) Has been cancelled Details

Eliminate 112 Any usages across orchestrator/ (62) and bitable/ (50) via:
- TYPE_CHECKING Protocol for Redis/LLMGateway/Plan/Dispatcher/StateManager
- object for arbitrary dict/list/value types (Pydantic v2 serializes fine)
- RecalcTask concrete import (replacing Any in recalc_worker.py)
- Coroutine[object, object, object] for async generic
- Remove unused Any imports (F401 cleanup)

Note: Avoided recursive TypeAlias (FieldValue) because Pydantic v2 cannot
build schemas for recursive named aliases (RecursionError).

Tests: 245 passed (bitable 91 + orchestrator 154), 0 regressions
ruff: All checks passed
This commit is contained in:
chiguyong 2026-07-01 02:41:14 +08:00
parent 34a89c4873
commit 7b1b198058
19 changed files with 299 additions and 198 deletions

View File

@ -16,7 +16,6 @@ from __future__ import annotations
import ast
from collections import deque
from typing import Any
from agentkit.bitable.formula.functions import AGGREGATE_FUNCTIONS, FUNCTION_REGISTRY
from agentkit.bitable.formula.parser import (
@ -104,9 +103,9 @@ class FormulaEngine:
def evaluate(
self,
field_id: str,
row_values: dict[str, Any],
column_values: dict[str, list[Any]] | None = None,
) -> Any:
row_values: dict[str, object],
column_values: dict[str, list[object]] | None = None,
) -> object:
"""Evaluate a formula field for a specific record.
Args:
@ -130,7 +129,7 @@ class FormulaEngine:
# Build the field_values dict for the evaluator
# Aggregate refs get column values (lists), row refs get row values (scalars)
eval_values: dict[str, Any] = {}
eval_values: dict[str, object] = {}
# Map real field IDs to safe names
for safe_name, real_id in entry.field_mapping.items():
@ -143,16 +142,16 @@ class FormulaEngine:
def evaluate_all_for_record(
self,
row_values: dict[str, Any],
column_values: dict[str, list[Any]] | None = None,
) -> dict[str, Any]:
row_values: dict[str, object],
column_values: dict[str, list[object]] | None = None,
) -> dict[str, object]:
"""Evaluate all registered formulas for a record.
Returns a dict of field_id computed value.
Formulas are evaluated in topological order so that formula-to-formula
dependencies are resolved correctly.
"""
results: dict[str, Any] = {}
results: dict[str, object] = {}
column_values = column_values or {}
for field_id in self.topological_order():

View File

@ -16,13 +16,11 @@ Usage::
from __future__ import annotations
from typing import Any
def transform_records(
records: list[dict[str, Any]],
records: list[dict[str, object]],
field_mapping: dict[str, str],
) -> list[dict[str, Any]]:
) -> list[dict[str, object]]:
"""Map source record keys to bitable field IDs via field_mapping.
Keys not in ``field_mapping`` are dropped. Values are passed through
@ -40,9 +38,9 @@ def transform_records(
if not field_mapping:
return []
transformed: list[dict[str, Any]] = []
transformed: list[dict[str, object]] = []
for rec in records:
out: dict[str, Any] = {}
out: dict[str, object] = {}
for src_key, field_id in field_mapping.items():
if src_key in rec:
out[field_id] = rec[src_key]

View File

@ -16,7 +16,6 @@ Type mapping (KTD: DB → bitable):
from __future__ import annotations
import logging
from typing import Any
from sqlalchemy import (
BigInteger,
@ -56,7 +55,7 @@ DB_TYPE_MAP: dict[type, str] = {
READ_BATCH = 1000
def infer_field_type(sqla_type: Any) -> str:
def infer_field_type(sqla_type: object) -> str:
"""Map a SQLAlchemy column type instance or class to a bitable field type.
Handles both type instances (``Integer()``) and type classes (``Integer``).
@ -78,7 +77,7 @@ def import_table(
table_name: str,
*,
max_rows: int = 50_000,
) -> dict[str, Any]:
) -> dict[str, object]:
"""Reflect a single table from an external DB.
Returns ``{"table_name": str, "fields": [...], "records": [...],
@ -97,7 +96,7 @@ def import_table(
engine.dispose()
def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[str, Any]:
def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[str, object]:
"""Reflect one table and read its rows."""
insp = inspect(engine)
@ -111,7 +110,7 @@ def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[st
table = Table(table_name, metadata, autoload_with=engine)
# Build field definitions
fields: list[dict[str, Any]] = []
fields: list[dict[str, object]] = []
pk_columns = list(table.primary_key.columns)
pk_name = pk_columns[0].name if pk_columns else None
@ -131,14 +130,14 @@ def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[st
pk_name = "id"
# Read rows
records: list[dict[str, Any]] = []
records: list[dict[str, object]] = []
with engine.connect() as conn:
result = conn.execute(select(table))
for i, row in enumerate(result):
if i >= max_rows:
logger.warning("Table %r truncated at %d rows during import", table_name, max_rows)
break
rec: dict[str, Any] = {}
rec: dict[str, object] = {}
for col in table.columns:
val = getattr(row, col.name, None)
if val is not None:
@ -155,7 +154,7 @@ def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[st
}
def _serialize(val: Any) -> Any:
def _serialize(val: object) -> object:
"""Serialize a DB value to JSON-safe form."""
from datetime import date, datetime
from decimal import Decimal

View File

@ -18,7 +18,6 @@ import logging
import socket
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from urllib.parse import urlparse
import httpx
@ -36,7 +35,7 @@ class ParsedSheet:
name: str
columns: list[str] = field(default_factory=list)
field_types: list[str] = field(default_factory=list) # "text" | "number" | "date"
records: list[dict[str, Any]] = field(default_factory=list)
records: list[dict[str, object]] = field(default_factory=list)
def parse_excel(file_path: str | Path) -> list[ParsedSheet]:
@ -182,9 +181,9 @@ def _parse_worksheet(ws) -> ParsedSheet | None:
col_count = len(clean_headers)
field_types = _infer_column_types(data_rows, col_count)
records: list[dict[str, Any]] = []
records: list[dict[str, object]] = []
for row in data_rows:
rec: dict[str, Any] = {}
rec: dict[str, object] = {}
for i, col_name in enumerate(clean_headers):
val = row[i] if i < len(row) else None
if val is not None:
@ -237,7 +236,7 @@ def _infer_column_types(rows: list[tuple], col_count: int) -> list[str]:
return types
def _coerce_value(val: Any, field_type: str) -> Any:
def _coerce_value(val: object, field_type: str) -> object:
"""Coerce a cell value to the inferred field type. Truncate long strings."""
if field_type == "date":
from datetime import datetime

View File

@ -9,10 +9,14 @@ from __future__ import annotations
from datetime import datetime, timezone
from enum import Enum
from typing import Any
from pydantic import BaseModel, ConfigDict, Field as PydanticField
# ponytail: bitable JSONB columns hold arbitrary JSON. Using `object` instead of
# a recursive TypeAlias because Pydantic v2 cannot build a schema for recursive
# named aliases (RecursionError). `object` is the most permissive type and
# Pydantic v2 serializes dict/list/primitive values fine at runtime.
def _utcnow() -> datetime:
return datetime.now(timezone.utc)
@ -97,7 +101,7 @@ class Table(BaseModel):
# ---------------------------------------------------------------------------
# Status select field options — labels and colors match Feishu Bitable defaults.
_STATUS_OPTIONS: list[dict[str, Any]] = [
_STATUS_OPTIONS: list[dict[str, object]] = [
{"label": "未开始", "value": "not_started", "color": "default"},
{"label": "进行中", "value": "in_progress", "color": "processing"},
{"label": "已完成", "value": "done", "color": "success"},
@ -106,7 +110,7 @@ _STATUS_OPTIONS: list[dict[str, Any]] = [
#: Templates for the 5 default fields created on every new table (R2).
#: agent-owned fields (创建人/创建时间) are auto-filled by the service layer
#: on record creation; user-owned fields are user-editable.
DEFAULT_FIELD_TEMPLATES: list[dict[str, Any]] = [
DEFAULT_FIELD_TEMPLATES: list[dict[str, object]] = [
{
"name": "标题",
"field_type": FieldType.text,
@ -155,7 +159,7 @@ class Field(BaseModel):
table_id: str
name: str
field_type: FieldType
config: dict[str, Any] = PydanticField(default_factory=dict)
config: dict[str, object] = PydanticField(default_factory=dict)
owner: FieldOwner = FieldOwner.user
created_at: datetime = PydanticField(default_factory=_utcnow)
@ -167,7 +171,7 @@ class Record(BaseModel):
id: str
table_id: str
values: dict[str, Any] = PydanticField(default_factory=dict)
values: dict[str, object] = PydanticField(default_factory=dict)
created_at: datetime = PydanticField(default_factory=_utcnow)
updated_at: datetime = PydanticField(default_factory=_utcnow)
@ -181,7 +185,7 @@ class View(BaseModel):
table_id: str
name: str
view_type: ViewType = ViewType.grid
config: dict[str, Any] = PydanticField(default_factory=dict)
config: dict[str, object] = PydanticField(default_factory=dict)
created_at: datetime = PydanticField(default_factory=_utcnow)

View File

@ -18,11 +18,10 @@ from __future__ import annotations
import asyncio
import logging
from typing import Any
from agentkit.bitable.db import BitableDB
from agentkit.bitable.formula.engine import FormulaEngine
from agentkit.bitable.models import FieldType, RecalcStatus
from agentkit.bitable.models import FieldType, RecalcStatus, RecalcTask
from agentkit.bitable.repository import BitableRepository
from agentkit.bitable.service import BitableService
@ -124,7 +123,7 @@ class RecalcWorker:
logger.exception("RecalcWorker error in main loop")
await asyncio.sleep(self._poll_interval)
async def _sort_by_topological_order(self, tasks: list[Any]) -> list[Any]:
async def _sort_by_topological_order(self, tasks: list[RecalcTask]) -> list[RecalcTask]:
"""Sort claimed tasks so dependencies are processed first (P1 #7).
Groups tasks by table_id, builds (or reuses) the engine to get the
@ -146,7 +145,7 @@ class RecalcWorker:
order = engine.topological_order()
topo_index[tid] = {fid: i for i, fid in enumerate(order)}
def _key(t: Any) -> tuple[str, str, int]:
def _key(t: RecalcTask) -> tuple[str, str, int]:
idx = topo_index.get(t.table_id, {}).get(t.field_id, 1 << 30)
return (t.table_id, t.record_id, idx)
@ -175,7 +174,7 @@ class RecalcWorker:
except Exception:
logger.exception("RecalcWorker reaper error")
async def process_task(self, task: Any) -> None:
async def process_task(self, task: RecalcTask) -> None:
"""Process a single recalc task: evaluate formula → write result.
The task is expected to already be in ``calculating`` status when
@ -216,7 +215,7 @@ class RecalcWorker:
return
deps = engine.get_dependencies(task.field_id)
column_values: dict[str, list[Any]] = {}
column_values: dict[str, list[object]] = {}
for dep_field_id in deps:
column_values[dep_field_id] = await self._repo.get_column_values(
task.table_id, dep_field_id

View File

@ -10,7 +10,6 @@ from __future__ import annotations
import logging
import re
from datetime import datetime, timedelta, timezone
from typing import Any
from sqlalchemy import delete, func, insert, select, text, update
from sqlalchemy.dialects.postgresql import insert as pg_insert
@ -102,7 +101,7 @@ class BitableRepository:
result = await session.execute(stmt)
return [BitableFile.model_validate(e) for e in result.scalars().all()]
async def update_file(self, file_id: str, **kwargs: Any) -> BitableFile | None:
async def update_file(self, file_id: str, **kwargs: object) -> BitableFile | None:
"""Update a file's attributes."""
async with self._session_factory() as session:
stmt = (
@ -181,7 +180,7 @@ class BitableRepository:
result = await session.execute(stmt)
return [Table.model_validate(e) for e in result.scalars().all()]
async def update_table(self, table_id: str, **kwargs: Any) -> Table | None:
async def update_table(self, table_id: str, **kwargs: object) -> Table | None:
"""Update a table's attributes."""
async with self._session_factory() as session:
stmt = (
@ -236,7 +235,7 @@ class BitableRepository:
table_id: str,
name: str,
field_type: FieldType,
config: dict[str, Any] | None = None,
config: dict[str, object] | None = None,
owner: FieldOwner = FieldOwner.user,
) -> Field:
"""Create a new field in a table."""
@ -277,7 +276,7 @@ class BitableRepository:
result = await session.execute(stmt)
return [Field.model_validate(e) for e in result.scalars().all()]
async def update_field(self, field_id: str, **kwargs: Any) -> Field | None:
async def update_field(self, field_id: str, **kwargs: object) -> Field | None:
"""Update a field's attributes."""
async with self._session_factory() as session:
stmt = (
@ -300,7 +299,7 @@ class BitableRepository:
# ── Records ─────────────────────────────────────────────
async def create_record(self, table_id: str, values: dict[str, Any] | None = None) -> Record:
async def create_record(self, table_id: str, values: dict[str, object] | None = None) -> Record:
"""Create a new record."""
async with self._session_factory() as session:
stmt = (
@ -318,7 +317,7 @@ class BitableRepository:
return Record.model_validate(entity)
async def create_records_batch(
self, table_id: str, records_values: list[dict[str, Any]]
self, table_id: str, records_values: list[dict[str, object]]
) -> list[Record]:
"""Batch-insert multiple records (P2 #19: eliminates per-record INSERT).
@ -376,7 +375,7 @@ class BitableRepository:
return [Record.model_validate(e) for e in entities], next_cursor
async def update_record_values(self, record_id: str, values: dict[str, Any]) -> Record | None:
async def update_record_values(self, record_id: str, values: dict[str, object]) -> Record | None:
"""Update a record's values (full replace)."""
async with self._session_factory() as session:
stmt = (
@ -413,7 +412,7 @@ class BitableRepository:
table_id: str,
name: str,
view_type: ViewType = ViewType.grid,
config: dict[str, Any] | None = None,
config: dict[str, object] | None = None,
) -> View:
"""Create a new view."""
async with self._session_factory() as session:
@ -451,7 +450,7 @@ class BitableRepository:
result = await session.execute(stmt)
return [View.model_validate(e) for e in result.scalars().all()]
async def update_view(self, view_id: str, **kwargs: Any) -> View | None:
async def update_view(self, view_id: str, **kwargs: object) -> View | None:
"""Update a view's attributes."""
async with self._session_factory() as session:
stmt = (
@ -543,7 +542,7 @@ class BitableRepository:
) -> None:
"""Update a recalc task's status."""
async with self._session_factory() as session:
kwargs: dict[str, Any] = {"status": status.value}
kwargs: dict[str, object] = {"status": status.value}
if error_message is not None:
kwargs["error_message"] = error_message
if status in (RecalcStatus.done, RecalcStatus.error):
@ -630,7 +629,7 @@ class BitableRepository:
return result_map
async def upsert_record_agent_fields(
self, record_id: str, agent_field_values: dict[str, Any]
self, record_id: str, agent_field_values: dict[str, object]
) -> None:
"""Update agent-owned fields using jsonb_set (KTD8).
@ -646,7 +645,7 @@ class BitableRepository:
# Use CAST(:param AS jsonb) instead of :param::jsonb — asyncpg dialect
# misparses the `::` as part of the param name.
inner = "values"
params: dict[str, Any] = {"record_id": record_id}
params: dict[str, object] = {"record_id": record_id}
for i, (field_id, value) in enumerate(agent_field_values.items()):
param_key = f"v{i}"
inner = f"jsonb_set({inner}, '{{{field_id}}}', CAST(:{param_key} AS jsonb), true)"
@ -660,8 +659,8 @@ class BitableRepository:
async def list_records_filtered(
self,
table_id: str,
filters: list[dict[str, Any]] | None = None,
sorts: list[dict[str, Any]] | None = None,
filters: list[dict[str, object]] | None = None,
sorts: list[dict[str, object]] | None = None,
cursor: str | None = None,
limit: int = 50,
) -> tuple[list[Record], str | None]:
@ -682,7 +681,7 @@ class BitableRepository:
# Build raw SQL with JSONB filter/sort translation.
# ponytail: field_ids in filters/sorts are system UUIDs (validated by service layer).
where_clauses = ["table_id = :table_id"]
params: dict[str, Any] = {"table_id": table_id}
params: dict[str, object] = {"table_id": table_id}
if filters:
for i, f in enumerate(filters):
@ -783,7 +782,7 @@ class BitableRepository:
last_mapping = last_row._mapping
# Build composite cursor from sort values + id.
# Sort values are extracted as text to match `values->>'fid'` expressions.
sv: list[Any] = []
sv: list[object] = []
last_values = last_mapping.get("values")
if isinstance(last_values, str):
# asyncpg may return JSONB as str in raw text() queries.
@ -852,7 +851,7 @@ class BitableRepository:
# ── Recalc support (U3) ────────────────────────────────
async def get_column_values(self, table_id: str, field_id: str) -> list[Any]:
async def get_column_values(self, table_id: str, field_id: str) -> list[object]:
"""Get all values for a field across all records in a table (for aggregates).
Returns a list of values (preserving order by record id). Missing values
@ -866,7 +865,7 @@ class BitableRepository:
result = await session.execute(sql, {"field_id": field_id, "table_id": table_id})
return [row[0] for row in result.fetchall()]
async def set_formula_value(self, record_id: str, field_id: str, value: Any) -> None:
async def set_formula_value(self, record_id: str, field_id: str, value: object) -> None:
"""Set a single formula field value in a record's JSONB (jsonb_set)."""
import json

View File

@ -17,7 +17,7 @@ import logging
import time
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from typing import Any
from typing import TYPE_CHECKING
logger = logging.getLogger(__name__)
@ -25,6 +25,28 @@ _TTL_SECONDS = 7 * 24 * 3600 # 7 days
_KEY_PREFIX = "agentkit:pipeline:checkpoint"
if TYPE_CHECKING:
from typing import Protocol
class _RedisPipelineLike(Protocol):
def set(self, key: str, value: str, ex: int | None = None) -> object: ...
def zadd(self, name: str, mapping: dict[str, float]) -> object: ...
def get(self, key: str) -> object: ...
def delete(self, *keys: str) -> object: ...
async def execute(self) -> list[object]: ...
class _RedisLike(Protocol):
async def set(self, key: str, value: str, ex: int | None = None) -> object: ...
async def get(self, key: str) -> object: ...
def pipeline(self) -> _RedisPipelineLike: ...
async def zrange(self, name: str, start: int, stop: int) -> list[object]: ...
class _PlanLike(Protocol):
@property
def id(self) -> str: ...
def to_dict(self) -> dict[str, object]: ...
@dataclass
class CheckpointData:
"""单个阶段的 checkpoint 数据。"""
@ -33,15 +55,15 @@ class CheckpointData:
phase_id: str
phase_name: str
phase_status: str
phase_result: dict[str, Any] | None = None
phase_result: dict[str, object] | None = None
plan_status: str = ""
saved_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> dict[str, object]:
return asdict(self)
@classmethod
def from_dict(cls, data: dict[str, Any]) -> CheckpointData:
def from_dict(cls, data: dict[str, object]) -> CheckpointData:
return cls(
plan_id=data.get("plan_id", ""),
phase_id=data.get("phase_id", ""),
@ -67,7 +89,7 @@ class PipelineCheckpoint:
def __init__(
self,
redis_client: Any = None,
redis_client: _RedisLike | None = None,
prefix: str = _KEY_PREFIX,
ttl_seconds: int = _TTL_SECONDS,
) -> None:
@ -78,7 +100,7 @@ class PipelineCheckpoint:
# P1 #6: 改用 dict keyed by phase_id避免重复 append
self._memory: dict[str, dict[str, CheckpointData]] = {}
# 内存降级存储plan_id → (plan_dict, saved_timestamp)
self._memory_plans: dict[str, tuple[dict[str, Any], float]] = {}
self._memory_plans: dict[str, tuple[dict[str, object], float]] = {}
def _is_expired(self, saved_at: str) -> bool:
"""检查 checkpoint 是否已过期(内存模式 TTL"""
@ -102,7 +124,7 @@ class PipelineCheckpoint:
"""完整 plan JSON 的存储键。"""
return f"{self._prefix}:plan:{plan_id}"
async def save_plan(self, plan: Any) -> None:
async def save_plan(self, plan: _PlanLike) -> None:
"""保存完整 TeamPlan用于 resume 重建)。
Args:
@ -121,7 +143,7 @@ class PipelineCheckpoint:
except Exception as e:
logger.warning(f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}")
async def load_plan(self, plan_id: str) -> dict[str, Any] | None:
async def load_plan(self, plan_id: str) -> dict[str, object] | None:
"""加载完整 plan JSON。"""
# 优先 Redis
if self._redis is not None:
@ -142,7 +164,7 @@ class PipelineCheckpoint:
return None
return plan_dict
async def save(self, plan_id: str, phase: Any, plan_status: str) -> None:
async def save(self, plan_id: str, phase: object, plan_status: str) -> None:
"""保存阶段 checkpoint。
Args:
@ -212,7 +234,8 @@ class PipelineCheckpoint:
if not phase_ids:
# Redis 无数据,检查内存(过滤过期)
return [
c for c in self._memory.get(plan_id, {}).values()
c
for c in self._memory.get(plan_id, {}).values()
if not self._is_expired(c.saved_at)
]
@ -236,8 +259,7 @@ class PipelineCheckpoint:
# 内存降级(过滤过期 checkpoint
return [
c for c in self._memory.get(plan_id, {}).values()
if not self._is_expired(c.saved_at)
c for c in self._memory.get(plan_id, {}).values() if not self._is_expired(c.saved_at)
]
async def clear(self, plan_id: str) -> None:

View File

@ -1,8 +1,8 @@
"""Saga compensation pattern for Pipeline execution"""
import logging
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable
from dataclasses import dataclass
from typing import Awaitable, Callable
logger = logging.getLogger(__name__)
@ -12,7 +12,7 @@ class CompletedStep:
"""Record of a completed step with its compensation"""
step_name: str
result: Any
result: object
compensate_action: str | None = None
@ -28,9 +28,7 @@ class CompensationResult:
class SagaOrchestrator:
"""Orchestrates LIFO compensation for failed pipelines"""
def __init__(
self, execute_skill_func: Callable[..., Awaitable[Any]] | None = None
):
def __init__(self, execute_skill_func: Callable[..., Awaitable[object]] | None = None):
"""
Args:
execute_skill_func: Async function to execute a skill by name
@ -42,7 +40,7 @@ class SagaOrchestrator:
def record_completed(
self,
step_name: str,
result: Any,
result: object,
compensate_action: str | None = None,
):
"""Record a completed step for potential compensation"""
@ -59,9 +57,7 @@ class SagaOrchestrator:
results: list[CompensationResult] = []
for step in reversed(self._completed_steps):
if step.compensate_action is None:
logger.info(
f"No compensation for step '{step.step_name}', skipping"
)
logger.info(f"No compensation for step '{step.step_name}', skipping")
results.append(
CompensationResult(
step_name=step.step_name,
@ -82,9 +78,7 @@ class SagaOrchestrator:
)
)
except Exception as e:
logger.error(
f"Compensation for step '{step.step_name}' failed: {e}"
)
logger.error(f"Compensation for step '{step.step_name}' failed: {e}")
results.append(
CompensationResult(
step_name=step.step_name,

View File

@ -4,7 +4,6 @@
"""
import logging
from typing import Any
from agentkit.orchestrator.pipeline_engine import PipelineEngine
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineResult, StageStatus
@ -15,7 +14,7 @@ logger = logging.getLogger(__name__)
class DynamicPipeline:
"""动态 Pipeline 组合器"""
def __init__(self, engine: PipelineEngine, loader: Any = None):
def __init__(self, engine: PipelineEngine, loader: object | None = None):
self._engine = engine
self._loader = loader
@ -23,7 +22,7 @@ class DynamicPipeline:
self,
pipelines: dict[str, Pipeline],
condition_key: str,
context: dict[str, Any] | None = None,
context: dict[str, object] | None = None,
) -> PipelineResult:
"""根据条件选择子 Pipeline 执行"""
context = context or {}
@ -37,14 +36,16 @@ class DynamicPipeline:
)
selected = pipelines[condition_value]
logger.info(f"DynamicPipeline selected '{selected.name}' for {condition_key}={condition_value}")
logger.info(
f"DynamicPipeline selected '{selected.name}' for {condition_key}={condition_value}"
)
return await self._engine.execute(selected, context)
async def execute_nested(
self,
parent: Pipeline,
sub_pipeline_map: dict[str, Pipeline],
context: dict[str, Any] | None = None,
context: dict[str, object] | None = None,
) -> PipelineResult:
"""执行嵌套 Pipeline"""
# 先执行父 Pipeline
@ -52,7 +53,7 @@ class DynamicPipeline:
# 根据父 Pipeline 结果选择子 Pipeline
for stage_name, stage_result in parent_result.stage_results.items():
if hasattr(stage_result, 'output_data') and stage_result.output_data:
if hasattr(stage_result, "output_data") and stage_result.output_data:
sub_pipeline_name = stage_result.output_data.get("sub_pipeline")
if sub_pipeline_name and sub_pipeline_name in sub_pipeline_map:
sub = sub_pipeline_map[sub_pipeline_name]
@ -66,7 +67,7 @@ class DynamicPipeline:
pipeline: Pipeline,
max_iterations: int = 5,
exit_condition: str = "done",
context: dict[str, Any] | None = None,
context: dict[str, object] | None = None,
) -> PipelineResult:
"""循环执行 Pipeline 直到条件满足"""
current_context = context or {}

View File

@ -3,25 +3,42 @@
import asyncio
import json
import logging
from typing import Any
from typing import Awaitable, Callable, Protocol
from agentkit.core.protocol import HandoffMessage
logger = logging.getLogger(__name__)
class _RedisPubSubLike(Protocol):
"""Structural type for Redis pubsub object."""
async def subscribe(self, channel: str) -> None: ...
async def unsubscribe(self, channel: str) -> None: ...
def listen(self) -> object: ...
class _RedisLike(Protocol):
"""Structural type for async Redis client."""
async def publish(self, channel: str, message: str) -> int: ...
def pubsub(self) -> _RedisPubSubLike: ...
class HandoffManager:
"""Handoff 管理器
通过 Redis Pub/Sub 管理 Agent 间的任务转交
"""
def __init__(self, redis: Any = None, dispatcher: Any = None):
def __init__(self, redis: _RedisLike | None = None, dispatcher: object | None = None):
self._redis = redis
self._dispatcher = dispatcher
self._handlers: dict[str, list[Any]] = {}
self._handlers: dict[str, list[Callable[[HandoffMessage], Awaitable[None]]]] = {}
def register_handler(self, agent_name: str, handler: Any) -> None:
def register_handler(
self, agent_name: str, handler: Callable[[HandoffMessage], Awaitable[None]]
) -> None:
"""注册 Handoff 处理器"""
if agent_name not in self._handlers:
self._handlers[agent_name] = []

View File

@ -1,10 +1,12 @@
"""Pipeline Engine - DAG + 并行执行 + 步骤重试 + Saga 补偿"""
from __future__ import annotations
import asyncio
import logging
from collections import defaultdict
from datetime import datetime, timezone
from typing import Any
from typing import TYPE_CHECKING
from agentkit.orchestrator.compensation import SagaOrchestrator
from agentkit.orchestrator.pipeline_schema import (
@ -25,6 +27,23 @@ from agentkit.orchestrator.retry import execute_with_retry
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from typing import Protocol
class _DispatcherLike(Protocol):
async def dispatch(self, task: object) -> None: ...
async def get_task_status(self, task_id: str) -> dict[str, object]: ...
class _StateManagerLike(Protocol):
async def create_execution(self, **kwargs: object) -> str: ...
async def update_step(self, **kwargs: object) -> None: ...
async def complete_execution(self, **kwargs: object) -> None: ...
async def fail_execution(self, **kwargs: object) -> None: ...
class _LLMGatewayLike(Protocol):
async def chat(self, **kwargs: object) -> object: ...
class PipelineEngine:
"""Pipeline 执行引擎
@ -38,7 +57,12 @@ class PipelineEngine:
- 状态持久化可选
"""
def __init__(self, dispatcher: Any = None, state_manager: Any = None, llm_gateway: Any = None):
def __init__(
self,
dispatcher: _DispatcherLike | None = None,
state_manager: _StateManagerLike | None = None,
llm_gateway: _LLMGatewayLike | None = None,
):
self._dispatcher = dispatcher
self._state_manager = state_manager
self._llm_gateway = llm_gateway
@ -46,7 +70,7 @@ class PipelineEngine:
async def execute(
self,
pipeline: Pipeline,
context: dict[str, Any] | None = None,
context: dict[str, object] | None = None,
adaptive_config: AdaptiveConfig | None = None,
) -> PipelineResult:
"""执行 Pipeline
@ -68,7 +92,7 @@ class PipelineEngine:
async def _adaptive_loop(
self,
pipeline: Pipeline,
context: dict[str, Any] | None,
context: dict[str, object] | None,
failed_result: PipelineResult,
adaptive_config: AdaptiveConfig,
) -> PipelineResult:
@ -92,34 +116,30 @@ class PipelineEngine:
# Replan
new_pipeline = await replanner.replan(current_pipeline, current_result, report)
logger.info(f"Pipeline replanned: {new_pipeline.name} ({len(new_pipeline.stages)} stages)")
logger.info(
f"Pipeline replanned: {new_pipeline.name} ({len(new_pipeline.stages)} stages)"
)
# Re-execute
current_result = await self._execute_pipeline(new_pipeline, context)
current_pipeline = new_pipeline
# Record reflection in metadata
current_result.metadata["reflections"] = [
r.model_dump() for r in reflections
]
current_result.metadata["reflections"] = [r.model_dump() for r in reflections]
if current_result.status == StageStatus.COMPLETED:
logger.info(f"Pipeline succeeded after {reflection_num} reflection(s)")
return current_result
# Exhausted reflections
logger.warning(
f"Pipeline failed after {adaptive_config.max_reflections} reflection(s)"
)
current_result.metadata["reflections"] = [
r.model_dump() for r in reflections
]
logger.warning(f"Pipeline failed after {adaptive_config.max_reflections} reflection(s)")
current_result.metadata["reflections"] = [r.model_dump() for r in reflections]
return current_result
async def _execute_pipeline(
self,
pipeline: Pipeline,
context: dict[str, Any] | None = None,
context: dict[str, object] | None = None,
) -> PipelineResult:
"""执行 Pipeline 的核心逻辑(不含反思-重规划)。"""
result = PipelineResult(pipeline_name=pipeline.name)
@ -151,7 +171,9 @@ class PipelineEngine:
# 逐层执行
for level, stages in enumerate(level_groups):
logger.info(f"Pipeline '{pipeline.name}' executing level {level} with {len(stages)} stage(s)")
logger.info(
f"Pipeline '{pipeline.name}' executing level {level} with {len(stages)} stage(s)"
)
# 并行执行同层 stages
tasks = []
@ -173,9 +195,11 @@ class PipelineEngine:
# Update step state
if self._state_manager is not None and execution_id is not None:
try:
step_status = "completed" if sr.status == StageStatus.COMPLETED else sr.status.value
step_output = sr.output_data if hasattr(sr, 'output_data') else None
step_error = sr.error_message if hasattr(sr, 'error_message') else None
step_status = (
"completed" if sr.status == StageStatus.COMPLETED else sr.status.value
)
step_output = sr.output_data if hasattr(sr, "output_data") else None
step_error = sr.error_message if hasattr(sr, "error_message") else None
await self._state_manager.update_step(
execution_id=execution_id,
step_name=stage.name,
@ -189,19 +213,21 @@ class PipelineEngine:
# 收集输出变量
if sr.output_data and isinstance(sr, dict):
pass
elif hasattr(sr, 'output_data') and sr.output_data:
elif hasattr(sr, "output_data") and sr.output_data:
for output_key in stage.outputs:
if output_key in sr.output_data:
result.variables[output_key] = sr.output_data[output_key]
# 检查是否需要中止
if hasattr(sr, 'status') and sr.status == StageStatus.FAILED:
if hasattr(sr, "status") and sr.status == StageStatus.FAILED:
if not stage.continue_on_failure:
# Execute Saga compensation for completed steps
compensation_results = await saga.compensate()
if compensation_results:
failed_compensations = [
cr for cr in compensation_results if not cr.success and cr.error != "no_compensation_needed"
cr
for cr in compensation_results
if not cr.success and cr.error != "no_compensation_needed"
]
if failed_compensations:
logger.warning(
@ -219,7 +245,12 @@ class PipelineEngine:
step_name=stage.name,
error=result.error_message,
)
except (asyncio.TimeoutError, ConnectionError, RuntimeError, ValueError) as exc:
except (
asyncio.TimeoutError,
ConnectionError,
RuntimeError,
ValueError,
) as exc:
logger.warning(f"Failed to persist failure state: {exc}")
return result
@ -252,7 +283,9 @@ class PipelineEngine:
started_at = datetime.now(timezone.utc).isoformat()
# 条件检查
if stage.condition and not self._evaluate_condition(stage.condition, pipeline_result.variables):
if stage.condition and not self._evaluate_condition(
stage.condition, pipeline_result.variables
):
return StageResult(
stage_name=stage.name,
status=StageStatus.SKIPPED,
@ -312,7 +345,9 @@ class PipelineEngine:
if status["status"] in ("completed", "failed", "cancelled"):
return StageResult(
stage_name=stage.name,
status=StageStatus.COMPLETED if status["status"] == "completed" else StageStatus.FAILED,
status=StageStatus.COMPLETED
if status["status"] == "completed"
else StageStatus.FAILED,
output_data=status.get("output_data"),
error_message=status.get("error_message"),
started_at=started_at,
@ -406,7 +441,7 @@ class PipelineEngine:
return resolved
@staticmethod
def _get_nested(data: dict, path: str) -> Any:
def _get_nested(data: dict, path: str) -> object:
keys = path.split(".")
current = data
for key in keys:
@ -497,9 +532,7 @@ class PipelineEngine:
if verifier_feedback.passed:
# 审查通过,返回成功结果
logger.info(
f"Stage '{stage.name}' passed review in round {round_num}"
)
logger.info(f"Stage '{stage.name}' passed review in round {round_num}")
worker_result.output_data = worker_result.output_data or {}
worker_result.output_data["adversarial_metadata"] = {
"passed_round": round_num,
@ -553,7 +586,7 @@ class PipelineEngine:
self,
agent_name: str,
action: str,
input_data: dict[str, Any],
input_data: dict[str, object],
stage: PipelineStage,
started_at: str,
timeout_seconds: int | None = None,
@ -568,7 +601,9 @@ class PipelineEngine:
started_at: 开始时间
timeout_seconds: 独立超时时间不传则使用 stage.timeout_seconds
"""
effective_timeout = timeout_seconds if timeout_seconds is not None else stage.timeout_seconds
effective_timeout = (
timeout_seconds if timeout_seconds is not None else stage.timeout_seconds
)
if self._dispatcher is None:
# Dry-run 模式
return StageResult(
@ -602,7 +637,9 @@ class PipelineEngine:
if status["status"] in ("completed", "failed", "cancelled"):
return StageResult(
stage_name=stage.name,
status=StageStatus.COMPLETED if status["status"] == "completed" else StageStatus.FAILED,
status=StageStatus.COMPLETED
if status["status"] == "completed"
else StageStatus.FAILED,
output_data=status.get("output_data"),
error_message=status.get("error_message"),
started_at=started_at,
@ -639,7 +676,7 @@ class PipelineEngine:
async def _execute_verifier(
self,
verifier_name: str,
worker_output: dict[str, Any],
worker_output: dict[str, object],
stage: PipelineStage,
started_at: str,
) -> ReviewFeedback:
@ -679,10 +716,7 @@ class PipelineEngine:
try:
feedback = ReviewFeedback(
passed=output_data.get("passed", False),
issues=[
ReviewIssue(**issue)
for issue in output_data.get("issues", [])
],
issues=[ReviewIssue(**issue) for issue in output_data.get("issues", [])],
summary=output_data.get("summary", "No summary provided"),
score=output_data.get("score", 0.0),
)
@ -699,7 +733,7 @@ class PipelineEngine:
self,
feedback: ReviewFeedback,
feedback_mode: str = "structured+natural",
) -> dict[str, Any]:
) -> dict[str, object]:
"""构建反馈上下文,让 Worker Agent 理解审查反馈并定向修复
Args:
@ -720,7 +754,7 @@ class PipelineEngine:
for issue in feedback.issues
]
feedback_context: dict[str, Any] = {
feedback_context: dict[str, object] = {
"previous_attempt_failed": True,
}
@ -756,7 +790,9 @@ class PipelineEngine:
)
else:
# 未知模式fallback 到 structured+natural
logger.warning(f"Unknown feedback_mode '{feedback_mode}', falling back to structured+natural")
logger.warning(
f"Unknown feedback_mode '{feedback_mode}', falling back to structured+natural"
)
feedback_context["review_feedback"] = {
"summary": feedback.summary,
"issues": issues_list,

View File

@ -2,7 +2,6 @@
import logging
from pathlib import Path
from typing import Any
import yaml
@ -23,7 +22,9 @@ class PipelineLoader:
if not yaml_path.exists():
yaml_path = self._pipelines_dir / f"{pipeline_name}.yml"
if not yaml_path.exists():
raise FileNotFoundError(f"Pipeline '{pipeline_name}' not found in {self._pipelines_dir}")
raise FileNotFoundError(
f"Pipeline '{pipeline_name}' not found in {self._pipelines_dir}"
)
content = yaml_path.read_text(encoding="utf-8")
return self.load_from_yaml(content, pipeline_name)

View File

@ -35,9 +35,7 @@ class PipelineExecutionModel(Base):
)
completed_at = Column(DateTime)
__table_args__ = (
Index("ix_pipeline_status_created", "status", "created_at"),
)
__table_args__ = (Index("ix_pipeline_status_created", "status", "created_at"),)
class PipelineStepHistoryModel(Base):

View File

@ -1,7 +1,7 @@
"""Pipeline 数据模型"""
from enum import Enum
from typing import Any, Literal
from typing import Literal
from pydantic import BaseModel, Field
@ -18,8 +18,11 @@ class StageStatus(str, Enum):
class ReviewIssue(BaseModel):
"""单条审查问题"""
severity: Literal["critical", "major", "minor"] = Field(description="问题严重程度")
category: Literal["logic_error", "security", "style", "test_failure", "architecture"] = Field(description="问题类别")
category: Literal["logic_error", "security", "style", "test_failure", "architecture"] = Field(
description="问题类别"
)
description: str = Field(min_length=1, description="问题描述")
location: str | None = Field(default=None, description="文件路径/行号")
suggestion: str | None = Field(default=None, description="修复建议")
@ -27,6 +30,7 @@ class ReviewIssue(BaseModel):
class ReviewFeedback(BaseModel):
"""Verifier 返回的结构化审查反馈"""
passed: bool = Field(description="是否通过审查")
issues: list[ReviewIssue] = Field(default_factory=list, description="问题列表")
summary: str = Field(min_length=1, description="自然语言审查报告")
@ -35,6 +39,7 @@ class ReviewFeedback(BaseModel):
class AdversarialState(BaseModel):
"""对抗轮次状态追踪"""
current_round: int = Field(default=0, description="当前对抗轮次")
max_rounds: int = Field(default=3, description="最大对抗轮次")
feedback_history: list[ReviewFeedback] = Field(default_factory=list, description="反馈历史")
@ -46,7 +51,7 @@ class PipelineStage(BaseModel):
agent: str
action: str
depends_on: list[str] = []
inputs: dict[str, Any] = {}
inputs: dict[str, object] = {}
outputs: list[str] = []
timeout_seconds: int = 300
retry_count: int = 0
@ -54,12 +59,19 @@ class PipelineStage(BaseModel):
condition: str | None = None
retry_policy: StepRetryPolicy | None = None
compensate: str | None = None
# 对抗闭环相关字段
verifier: str | None = Field(default=None, description="Verifier Agent 名称,配置后启用对抗模式")
verifier: str | None = Field(
default=None, description="Verifier Agent 名称,配置后启用对抗模式"
)
max_adversarial_rounds: int = Field(default=3, description="最大对抗轮次")
verifier_timeout_seconds: int = Field(default=120, description="Verifier Agent 独立超时时间(秒),避免与 Worker 共享 timeout_seconds")
feedback_mode: Literal["structured+natural", "structured", "natural"] = Field(default="structured+natural", description="反馈模式")
verifier_timeout_seconds: int = Field(
default=120,
description="Verifier Agent 独立超时时间(秒),避免与 Worker 共享 timeout_seconds",
)
feedback_mode: Literal["structured+natural", "structured", "natural"] = Field(
default="structured+natural", description="反馈模式"
)
escalate_on_exhaust: str | None = Field(default=None, description="对抗轮次耗尽后的升级目标")
model_config = {"arbitrary_types_allowed": True}
@ -70,13 +82,13 @@ class Pipeline(BaseModel):
version: str
description: str
stages: list[PipelineStage]
variables: dict[str, Any] = {}
variables: dict[str, object] = {}
class StageResult(BaseModel):
stage_name: str
status: StageStatus = StageStatus.PENDING
output_data: dict[str, Any] | None = None
output_data: dict[str, object] | None = None
error_message: str | None = None
started_at: str | None = None
completed_at: str | None = None
@ -86,9 +98,9 @@ class PipelineResult(BaseModel):
pipeline_name: str
status: StageStatus = StageStatus.PENDING
stage_results: dict[str, StageResult] = {}
variables: dict[str, Any] = {}
variables: dict[str, object] = {}
error_message: str | None = None
metadata: dict[str, Any] = {}
metadata: dict[str, object] = {}
class AdaptiveConfig(BaseModel):

View File

@ -13,7 +13,7 @@ import json
import logging
import uuid
from datetime import datetime, timezone
from typing import Any, Callable, Coroutine
from typing import Callable, Coroutine
from agentkit.orchestrator.pipeline_models import (
PipelineExecutionModel,
@ -183,7 +183,7 @@ class PipelineStateRedis:
return self._redis
async def _safe_redis_call(
self, fn: Callable[..., Coroutine[Any, Any, Any]], *args: object, **kwargs: object
self, fn: Callable[..., Coroutine[object, object, object]], *args: object, **kwargs: object
) -> object | None:
"""Execute a Redis call, falling back to memory on failure.

View File

@ -4,22 +4,30 @@
生成修正后的 Pipeline 重新执行
"""
from __future__ import annotations
import json
import logging
from typing import Any
from typing import TYPE_CHECKING
from agentkit.orchestrator.pipeline_schema import (
Pipeline,
PipelineResult,
PipelineStage,
ReflectionReport,
StageResult,
StageStatus,
)
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from typing import Protocol
class _LLMGatewayLike(Protocol):
async def chat(self, **kwargs: object) -> object: ...
class PipelineReflector:
"""分析 Pipeline 执行失败原因,生成结构化反思报告。
@ -27,7 +35,7 @@ class PipelineReflector:
输出 ReflectionReport 包含 failure_typeroot_cause suggested_fix
"""
def __init__(self, llm_gateway: Any = None):
def __init__(self, llm_gateway: _LLMGatewayLike | None = None):
self._llm_gateway = llm_gateway
async def reflect(
@ -54,19 +62,25 @@ class PipelineReflector:
if self._llm_gateway is not None:
try:
return await self._llm_reflect(
pipeline, failed_stage, error_message,
completed_outputs, reflection_number,
pipeline,
failed_stage,
error_message,
completed_outputs,
reflection_number,
)
except Exception as e:
logger.warning(f"LLM reflection failed, falling back to rule-based: {e}")
# 规则兜底:基于错误信息分类
return self._rule_based_reflect(
failed_stage, error_message, reflection_number,
failed_stage,
error_message,
reflection_number,
)
def _find_failure(
self, result: PipelineResult,
self,
result: PipelineResult,
) -> tuple[str, str]:
"""找到第一个失败的 stage 及其错误信息。"""
for name, sr in result.stage_results.items():
@ -75,8 +89,9 @@ class PipelineReflector:
return "", "no failed stage found"
def _collect_completed_outputs(
self, result: PipelineResult,
) -> dict[str, Any]:
self,
result: PipelineResult,
) -> dict[str, object]:
"""收集已完成步骤的输出。"""
outputs = {}
for name, sr in result.stage_results.items():
@ -89,13 +104,16 @@ class PipelineReflector:
pipeline: Pipeline,
failed_stage: str,
error_message: str,
completed_outputs: dict[str, Any],
completed_outputs: dict[str, object],
reflection_number: int,
) -> ReflectionReport:
"""使用 LLM 分析失败原因。"""
prompt = self._build_reflection_prompt(
pipeline, failed_stage, error_message,
completed_outputs, reflection_number,
pipeline,
failed_stage,
error_message,
completed_outputs,
reflection_number,
)
response = await self._llm_gateway.chat(
@ -106,7 +124,9 @@ class PipelineReflector:
# 解析 LLM 返回的 JSON
content = response.content if hasattr(response, "content") else str(response)
return self._parse_reflection_response(
content, failed_stage, reflection_number,
content,
failed_stage,
reflection_number,
)
def _build_reflection_prompt(
@ -114,15 +134,14 @@ class PipelineReflector:
pipeline: Pipeline,
failed_stage: str,
error_message: str,
completed_outputs: dict[str, Any],
completed_outputs: dict[str, object],
reflection_number: int,
) -> str:
"""构建反思提示词。"""
stage_descriptions = []
for s in pipeline.stages:
stage_descriptions.append(
f" - {s.name}: agent={s.agent}, action={s.action}, "
f"depends_on={s.depends_on}"
f" - {s.name}: agent={s.agent}, action={s.action}, depends_on={s.depends_on}"
)
completed_summary = json.dumps(
@ -174,7 +193,9 @@ JSON response:"""
except (json.JSONDecodeError, KeyError) as e:
logger.warning(f"Failed to parse LLM reflection response: {e}")
return self._rule_based_reflect(
failed_stage, content, reflection_number,
failed_stage,
content,
reflection_number,
)
def _rule_based_reflect(
@ -218,7 +239,7 @@ class PipelineReplanner:
保留已完成步骤的结果仅重新规划失败及后续步骤
"""
def __init__(self, llm_gateway: Any = None):
def __init__(self, llm_gateway: _LLMGatewayLike | None = None):
self._llm_gateway = llm_gateway
async def replan(
@ -255,8 +276,7 @@ class PipelineReplanner:
) -> Pipeline:
"""使用 LLM 生成修正后的 Pipeline。"""
completed_stages = [
name for name, sr in result.stage_results.items()
if sr.status == StageStatus.COMPLETED
name for name, sr in result.stage_results.items() if sr.status == StageStatus.COMPLETED
]
prompt = f"""Based on the reflection report, generate a corrected pipeline.
@ -284,7 +304,9 @@ JSON pipeline:"""
return self._parse_pipeline_response(content, pipeline)
def _parse_pipeline_response(
self, content: str, original: Pipeline,
self,
content: str,
original: Pipeline,
) -> Pipeline:
"""解析 LLM 返回的 Pipeline JSON。"""
try:
@ -294,9 +316,7 @@ JSON pipeline:"""
text = "\n".join(lines[1:-1])
data = json.loads(text)
stages = [
PipelineStage(**s) for s in data.get("stages", [])
]
stages = [PipelineStage(**s) for s in data.get("stages", [])]
return Pipeline(
name=data.get("name", original.name),
version=data.get("version", original.version),
@ -316,8 +336,7 @@ JSON pipeline:"""
) -> Pipeline:
"""基于规则的兜底重规划。"""
completed_stages = {
name for name, sr in result.stage_results.items()
if sr.status == StageStatus.COMPLETED
name for name, sr in result.stage_results.items() if sr.status == StageStatus.COMPLETED
}
# 构建修正后的 stages 列表
@ -345,17 +364,21 @@ JSON pipeline:"""
)
def _adjust_failed_stage(
self, stage: PipelineStage, report: ReflectionReport,
self,
stage: PipelineStage,
report: ReflectionReport,
) -> PipelineStage:
"""根据反思报告调整失败的步骤。"""
adjustments: dict[str, Any] = {}
adjustments: dict[str, object] = {}
if report.failure_type == "timeout":
adjustments["timeout_seconds"] = min(
stage.timeout_seconds * 2, 3600,
stage.timeout_seconds * 2,
3600,
)
if stage.retry_policy is None:
from agentkit.orchestrator.retry import StepRetryPolicy
adjustments["retry_policy"] = StepRetryPolicy(max_attempts=2)
elif report.failure_type == "resource_error":
@ -365,6 +388,7 @@ JSON pipeline:"""
# 添加重试策略,可能输入在后续可用
if stage.retry_policy is None:
from agentkit.orchestrator.retry import StepRetryPolicy
adjustments["retry_policy"] = StepRetryPolicy(max_attempts=2)
return stage.model_copy(update=adjustments)

View File

@ -4,7 +4,7 @@ import asyncio
import logging
import random
from dataclasses import dataclass
from typing import Any, Awaitable, Callable
from typing import Awaitable, Callable
logger = logging.getLogger(__name__)
@ -27,7 +27,7 @@ class StepRetryPolicy:
def calculate_delay(self, attempt: int) -> float:
"""Calculate delay for given attempt number (0-based)"""
delay = min(
self.base_delay * (self.exponential_base ** attempt),
self.base_delay * (self.exponential_base**attempt),
self.max_delay,
)
if self.jitter:
@ -36,10 +36,10 @@ class StepRetryPolicy:
async def execute_with_retry(
func: Callable[..., Awaitable[Any]],
func: Callable[..., Awaitable[object]],
retry_policy: StepRetryPolicy | None = None,
step_name: str = "",
) -> Any:
) -> object:
"""Execute a function with retry policy"""
if retry_policy is None:
return await func()

View File

@ -3,18 +3,17 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any
from pydantic import BaseModel, Field
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage
from agentkit.orchestrator.pipeline_schema import PipelineStage
class WorkflowStage(PipelineStage):
"""A workflow stage extending PipelineStage with type and config."""
type: str = "skill" # "skill" | "condition" | "approval" | "parallel"
config: dict[str, Any] = Field(default_factory=dict)
config: dict[str, object] = Field(default_factory=dict)
class WorkflowDefinition(BaseModel):
@ -24,9 +23,9 @@ class WorkflowDefinition(BaseModel):
name: str
version: int = 1
stages: list[WorkflowStage] = Field(default_factory=list)
triggers: list[dict[str, Any]] = Field(default_factory=list)
variables_schema: dict[str, Any] = Field(default_factory=dict)
output_schema: dict[str, Any] = Field(default_factory=dict)
triggers: list[dict[str, object]] = Field(default_factory=list)
variables_schema: dict[str, object] = Field(default_factory=dict)
output_schema: dict[str, object] = Field(default_factory=dict)
created_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
updated_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
@ -38,11 +37,11 @@ class WorkflowExecution(BaseModel):
workflow_id: str = ""
status: str = "pending" # pending|running|paused|completed|failed|cancelled
current_stage: str | None = None
stage_results: dict[str, Any] = Field(default_factory=dict)
stage_results: dict[str, object] = Field(default_factory=dict)
started_at: str | None = None
completed_at: str | None = None
error: str | None = None
variables: dict[str, Any] = Field(default_factory=dict)
variables: dict[str, object] = Field(default_factory=dict)
class WorkflowSummary(BaseModel):
@ -62,15 +61,15 @@ class CreateWorkflowRequest(BaseModel):
name: str
stages: list[WorkflowStage] = Field(default_factory=list)
triggers: list[dict[str, Any]] = Field(default_factory=list)
variables_schema: dict[str, Any] = Field(default_factory=dict)
output_schema: dict[str, Any] = Field(default_factory=dict)
triggers: list[dict[str, object]] = Field(default_factory=list)
variables_schema: dict[str, object] = Field(default_factory=dict)
output_schema: dict[str, object] = Field(default_factory=dict)
class ExecuteWorkflowRequest(BaseModel):
"""Request body for executing a workflow."""
variables: dict[str, Any] = Field(default_factory=dict)
variables: dict[str, object] = Field(default_factory=dict)
class ApproveRequest(BaseModel):