refactor(orchestrator+bitable): remove Any from type signatures
Eliminate 112 Any usages across orchestrator/ (62) and bitable/ (50) via: - TYPE_CHECKING Protocol for Redis/LLMGateway/Plan/Dispatcher/StateManager - object for arbitrary dict/list/value types (Pydantic v2 serializes fine) - RecalcTask concrete import (replacing Any in recalc_worker.py) - Coroutine[object, object, object] for async generic - Remove unused Any imports (F401 cleanup) Note: Avoided recursive TypeAlias (FieldValue) because Pydantic v2 cannot build schemas for recursive named aliases (RecursionError). Tests: 245 passed (bitable 91 + orchestrator 154), 0 regressions ruff: All checks passed
This commit is contained in:
parent
34a89c4873
commit
7b1b198058
|
|
@ -16,7 +16,6 @@ from __future__ import annotations
|
||||||
|
|
||||||
import ast
|
import ast
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from agentkit.bitable.formula.functions import AGGREGATE_FUNCTIONS, FUNCTION_REGISTRY
|
from agentkit.bitable.formula.functions import AGGREGATE_FUNCTIONS, FUNCTION_REGISTRY
|
||||||
from agentkit.bitable.formula.parser import (
|
from agentkit.bitable.formula.parser import (
|
||||||
|
|
@ -104,9 +103,9 @@ class FormulaEngine:
|
||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
field_id: str,
|
field_id: str,
|
||||||
row_values: dict[str, Any],
|
row_values: dict[str, object],
|
||||||
column_values: dict[str, list[Any]] | None = None,
|
column_values: dict[str, list[object]] | None = None,
|
||||||
) -> Any:
|
) -> object:
|
||||||
"""Evaluate a formula field for a specific record.
|
"""Evaluate a formula field for a specific record.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -130,7 +129,7 @@ class FormulaEngine:
|
||||||
|
|
||||||
# Build the field_values dict for the evaluator
|
# Build the field_values dict for the evaluator
|
||||||
# Aggregate refs get column values (lists), row refs get row values (scalars)
|
# Aggregate refs get column values (lists), row refs get row values (scalars)
|
||||||
eval_values: dict[str, Any] = {}
|
eval_values: dict[str, object] = {}
|
||||||
|
|
||||||
# Map real field IDs to safe names
|
# Map real field IDs to safe names
|
||||||
for safe_name, real_id in entry.field_mapping.items():
|
for safe_name, real_id in entry.field_mapping.items():
|
||||||
|
|
@ -143,16 +142,16 @@ class FormulaEngine:
|
||||||
|
|
||||||
def evaluate_all_for_record(
|
def evaluate_all_for_record(
|
||||||
self,
|
self,
|
||||||
row_values: dict[str, Any],
|
row_values: dict[str, object],
|
||||||
column_values: dict[str, list[Any]] | None = None,
|
column_values: dict[str, list[object]] | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, object]:
|
||||||
"""Evaluate all registered formulas for a record.
|
"""Evaluate all registered formulas for a record.
|
||||||
|
|
||||||
Returns a dict of field_id → computed value.
|
Returns a dict of field_id → computed value.
|
||||||
Formulas are evaluated in topological order so that formula-to-formula
|
Formulas are evaluated in topological order so that formula-to-formula
|
||||||
dependencies are resolved correctly.
|
dependencies are resolved correctly.
|
||||||
"""
|
"""
|
||||||
results: dict[str, Any] = {}
|
results: dict[str, object] = {}
|
||||||
column_values = column_values or {}
|
column_values = column_values or {}
|
||||||
|
|
||||||
for field_id in self.topological_order():
|
for field_id in self.topological_order():
|
||||||
|
|
|
||||||
|
|
@ -16,13 +16,11 @@ Usage::
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
def transform_records(
|
def transform_records(
|
||||||
records: list[dict[str, Any]],
|
records: list[dict[str, object]],
|
||||||
field_mapping: dict[str, str],
|
field_mapping: dict[str, str],
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, object]]:
|
||||||
"""Map source record keys to bitable field IDs via field_mapping.
|
"""Map source record keys to bitable field IDs via field_mapping.
|
||||||
|
|
||||||
Keys not in ``field_mapping`` are dropped. Values are passed through
|
Keys not in ``field_mapping`` are dropped. Values are passed through
|
||||||
|
|
@ -40,9 +38,9 @@ def transform_records(
|
||||||
if not field_mapping:
|
if not field_mapping:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
transformed: list[dict[str, Any]] = []
|
transformed: list[dict[str, object]] = []
|
||||||
for rec in records:
|
for rec in records:
|
||||||
out: dict[str, Any] = {}
|
out: dict[str, object] = {}
|
||||||
for src_key, field_id in field_mapping.items():
|
for src_key, field_id in field_mapping.items():
|
||||||
if src_key in rec:
|
if src_key in rec:
|
||||||
out[field_id] = rec[src_key]
|
out[field_id] = rec[src_key]
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,6 @@ Type mapping (KTD: DB → bitable):
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
BigInteger,
|
BigInteger,
|
||||||
|
|
@ -56,7 +55,7 @@ DB_TYPE_MAP: dict[type, str] = {
|
||||||
READ_BATCH = 1000
|
READ_BATCH = 1000
|
||||||
|
|
||||||
|
|
||||||
def infer_field_type(sqla_type: Any) -> str:
|
def infer_field_type(sqla_type: object) -> str:
|
||||||
"""Map a SQLAlchemy column type instance or class to a bitable field type.
|
"""Map a SQLAlchemy column type instance or class to a bitable field type.
|
||||||
|
|
||||||
Handles both type instances (``Integer()``) and type classes (``Integer``).
|
Handles both type instances (``Integer()``) and type classes (``Integer``).
|
||||||
|
|
@ -78,7 +77,7 @@ def import_table(
|
||||||
table_name: str,
|
table_name: str,
|
||||||
*,
|
*,
|
||||||
max_rows: int = 50_000,
|
max_rows: int = 50_000,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, object]:
|
||||||
"""Reflect a single table from an external DB.
|
"""Reflect a single table from an external DB.
|
||||||
|
|
||||||
Returns ``{"table_name": str, "fields": [...], "records": [...],
|
Returns ``{"table_name": str, "fields": [...], "records": [...],
|
||||||
|
|
@ -97,7 +96,7 @@ def import_table(
|
||||||
engine.dispose()
|
engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[str, Any]:
|
def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[str, object]:
|
||||||
"""Reflect one table and read its rows."""
|
"""Reflect one table and read its rows."""
|
||||||
insp = inspect(engine)
|
insp = inspect(engine)
|
||||||
|
|
||||||
|
|
@ -111,7 +110,7 @@ def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[st
|
||||||
table = Table(table_name, metadata, autoload_with=engine)
|
table = Table(table_name, metadata, autoload_with=engine)
|
||||||
|
|
||||||
# Build field definitions
|
# Build field definitions
|
||||||
fields: list[dict[str, Any]] = []
|
fields: list[dict[str, object]] = []
|
||||||
pk_columns = list(table.primary_key.columns)
|
pk_columns = list(table.primary_key.columns)
|
||||||
pk_name = pk_columns[0].name if pk_columns else None
|
pk_name = pk_columns[0].name if pk_columns else None
|
||||||
|
|
||||||
|
|
@ -131,14 +130,14 @@ def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[st
|
||||||
pk_name = "id"
|
pk_name = "id"
|
||||||
|
|
||||||
# Read rows
|
# Read rows
|
||||||
records: list[dict[str, Any]] = []
|
records: list[dict[str, object]] = []
|
||||||
with engine.connect() as conn:
|
with engine.connect() as conn:
|
||||||
result = conn.execute(select(table))
|
result = conn.execute(select(table))
|
||||||
for i, row in enumerate(result):
|
for i, row in enumerate(result):
|
||||||
if i >= max_rows:
|
if i >= max_rows:
|
||||||
logger.warning("Table %r truncated at %d rows during import", table_name, max_rows)
|
logger.warning("Table %r truncated at %d rows during import", table_name, max_rows)
|
||||||
break
|
break
|
||||||
rec: dict[str, Any] = {}
|
rec: dict[str, object] = {}
|
||||||
for col in table.columns:
|
for col in table.columns:
|
||||||
val = getattr(row, col.name, None)
|
val = getattr(row, col.name, None)
|
||||||
if val is not None:
|
if val is not None:
|
||||||
|
|
@ -155,7 +154,7 @@ def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[st
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _serialize(val: Any) -> Any:
|
def _serialize(val: object) -> object:
|
||||||
"""Serialize a DB value to JSON-safe form."""
|
"""Serialize a DB value to JSON-safe form."""
|
||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,6 @@ import logging
|
||||||
import socket
|
import socket
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
@ -36,7 +35,7 @@ class ParsedSheet:
|
||||||
name: str
|
name: str
|
||||||
columns: list[str] = field(default_factory=list)
|
columns: list[str] = field(default_factory=list)
|
||||||
field_types: list[str] = field(default_factory=list) # "text" | "number" | "date"
|
field_types: list[str] = field(default_factory=list) # "text" | "number" | "date"
|
||||||
records: list[dict[str, Any]] = field(default_factory=list)
|
records: list[dict[str, object]] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
def parse_excel(file_path: str | Path) -> list[ParsedSheet]:
|
def parse_excel(file_path: str | Path) -> list[ParsedSheet]:
|
||||||
|
|
@ -182,9 +181,9 @@ def _parse_worksheet(ws) -> ParsedSheet | None:
|
||||||
col_count = len(clean_headers)
|
col_count = len(clean_headers)
|
||||||
field_types = _infer_column_types(data_rows, col_count)
|
field_types = _infer_column_types(data_rows, col_count)
|
||||||
|
|
||||||
records: list[dict[str, Any]] = []
|
records: list[dict[str, object]] = []
|
||||||
for row in data_rows:
|
for row in data_rows:
|
||||||
rec: dict[str, Any] = {}
|
rec: dict[str, object] = {}
|
||||||
for i, col_name in enumerate(clean_headers):
|
for i, col_name in enumerate(clean_headers):
|
||||||
val = row[i] if i < len(row) else None
|
val = row[i] if i < len(row) else None
|
||||||
if val is not None:
|
if val is not None:
|
||||||
|
|
@ -237,7 +236,7 @@ def _infer_column_types(rows: list[tuple], col_count: int) -> list[str]:
|
||||||
return types
|
return types
|
||||||
|
|
||||||
|
|
||||||
def _coerce_value(val: Any, field_type: str) -> Any:
|
def _coerce_value(val: object, field_type: str) -> object:
|
||||||
"""Coerce a cell value to the inferred field type. Truncate long strings."""
|
"""Coerce a cell value to the inferred field type. Truncate long strings."""
|
||||||
if field_type == "date":
|
if field_type == "date":
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
|
||||||
|
|
@ -9,10 +9,14 @@ from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field as PydanticField
|
from pydantic import BaseModel, ConfigDict, Field as PydanticField
|
||||||
|
|
||||||
|
# ponytail: bitable JSONB columns hold arbitrary JSON. Using `object` instead of
|
||||||
|
# a recursive TypeAlias because Pydantic v2 cannot build a schema for recursive
|
||||||
|
# named aliases (RecursionError). `object` is the most permissive type and
|
||||||
|
# Pydantic v2 serializes dict/list/primitive values fine at runtime.
|
||||||
|
|
||||||
|
|
||||||
def _utcnow() -> datetime:
|
def _utcnow() -> datetime:
|
||||||
return datetime.now(timezone.utc)
|
return datetime.now(timezone.utc)
|
||||||
|
|
@ -97,7 +101,7 @@ class Table(BaseModel):
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
# Status select field options — labels and colors match Feishu Bitable defaults.
|
# Status select field options — labels and colors match Feishu Bitable defaults.
|
||||||
_STATUS_OPTIONS: list[dict[str, Any]] = [
|
_STATUS_OPTIONS: list[dict[str, object]] = [
|
||||||
{"label": "未开始", "value": "not_started", "color": "default"},
|
{"label": "未开始", "value": "not_started", "color": "default"},
|
||||||
{"label": "进行中", "value": "in_progress", "color": "processing"},
|
{"label": "进行中", "value": "in_progress", "color": "processing"},
|
||||||
{"label": "已完成", "value": "done", "color": "success"},
|
{"label": "已完成", "value": "done", "color": "success"},
|
||||||
|
|
@ -106,7 +110,7 @@ _STATUS_OPTIONS: list[dict[str, Any]] = [
|
||||||
#: Templates for the 5 default fields created on every new table (R2).
|
#: Templates for the 5 default fields created on every new table (R2).
|
||||||
#: agent-owned fields (创建人/创建时间) are auto-filled by the service layer
|
#: agent-owned fields (创建人/创建时间) are auto-filled by the service layer
|
||||||
#: on record creation; user-owned fields are user-editable.
|
#: on record creation; user-owned fields are user-editable.
|
||||||
DEFAULT_FIELD_TEMPLATES: list[dict[str, Any]] = [
|
DEFAULT_FIELD_TEMPLATES: list[dict[str, object]] = [
|
||||||
{
|
{
|
||||||
"name": "标题",
|
"name": "标题",
|
||||||
"field_type": FieldType.text,
|
"field_type": FieldType.text,
|
||||||
|
|
@ -155,7 +159,7 @@ class Field(BaseModel):
|
||||||
table_id: str
|
table_id: str
|
||||||
name: str
|
name: str
|
||||||
field_type: FieldType
|
field_type: FieldType
|
||||||
config: dict[str, Any] = PydanticField(default_factory=dict)
|
config: dict[str, object] = PydanticField(default_factory=dict)
|
||||||
owner: FieldOwner = FieldOwner.user
|
owner: FieldOwner = FieldOwner.user
|
||||||
created_at: datetime = PydanticField(default_factory=_utcnow)
|
created_at: datetime = PydanticField(default_factory=_utcnow)
|
||||||
|
|
||||||
|
|
@ -167,7 +171,7 @@ class Record(BaseModel):
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
table_id: str
|
table_id: str
|
||||||
values: dict[str, Any] = PydanticField(default_factory=dict)
|
values: dict[str, object] = PydanticField(default_factory=dict)
|
||||||
created_at: datetime = PydanticField(default_factory=_utcnow)
|
created_at: datetime = PydanticField(default_factory=_utcnow)
|
||||||
updated_at: datetime = PydanticField(default_factory=_utcnow)
|
updated_at: datetime = PydanticField(default_factory=_utcnow)
|
||||||
|
|
||||||
|
|
@ -181,7 +185,7 @@ class View(BaseModel):
|
||||||
table_id: str
|
table_id: str
|
||||||
name: str
|
name: str
|
||||||
view_type: ViewType = ViewType.grid
|
view_type: ViewType = ViewType.grid
|
||||||
config: dict[str, Any] = PydanticField(default_factory=dict)
|
config: dict[str, object] = PydanticField(default_factory=dict)
|
||||||
created_at: datetime = PydanticField(default_factory=_utcnow)
|
created_at: datetime = PydanticField(default_factory=_utcnow)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,11 +18,10 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from agentkit.bitable.db import BitableDB
|
from agentkit.bitable.db import BitableDB
|
||||||
from agentkit.bitable.formula.engine import FormulaEngine
|
from agentkit.bitable.formula.engine import FormulaEngine
|
||||||
from agentkit.bitable.models import FieldType, RecalcStatus
|
from agentkit.bitable.models import FieldType, RecalcStatus, RecalcTask
|
||||||
from agentkit.bitable.repository import BitableRepository
|
from agentkit.bitable.repository import BitableRepository
|
||||||
from agentkit.bitable.service import BitableService
|
from agentkit.bitable.service import BitableService
|
||||||
|
|
||||||
|
|
@ -124,7 +123,7 @@ class RecalcWorker:
|
||||||
logger.exception("RecalcWorker error in main loop")
|
logger.exception("RecalcWorker error in main loop")
|
||||||
await asyncio.sleep(self._poll_interval)
|
await asyncio.sleep(self._poll_interval)
|
||||||
|
|
||||||
async def _sort_by_topological_order(self, tasks: list[Any]) -> list[Any]:
|
async def _sort_by_topological_order(self, tasks: list[RecalcTask]) -> list[RecalcTask]:
|
||||||
"""Sort claimed tasks so dependencies are processed first (P1 #7).
|
"""Sort claimed tasks so dependencies are processed first (P1 #7).
|
||||||
|
|
||||||
Groups tasks by table_id, builds (or reuses) the engine to get the
|
Groups tasks by table_id, builds (or reuses) the engine to get the
|
||||||
|
|
@ -146,7 +145,7 @@ class RecalcWorker:
|
||||||
order = engine.topological_order()
|
order = engine.topological_order()
|
||||||
topo_index[tid] = {fid: i for i, fid in enumerate(order)}
|
topo_index[tid] = {fid: i for i, fid in enumerate(order)}
|
||||||
|
|
||||||
def _key(t: Any) -> tuple[str, str, int]:
|
def _key(t: RecalcTask) -> tuple[str, str, int]:
|
||||||
idx = topo_index.get(t.table_id, {}).get(t.field_id, 1 << 30)
|
idx = topo_index.get(t.table_id, {}).get(t.field_id, 1 << 30)
|
||||||
return (t.table_id, t.record_id, idx)
|
return (t.table_id, t.record_id, idx)
|
||||||
|
|
||||||
|
|
@ -175,7 +174,7 @@ class RecalcWorker:
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("RecalcWorker reaper error")
|
logger.exception("RecalcWorker reaper error")
|
||||||
|
|
||||||
async def process_task(self, task: Any) -> None:
|
async def process_task(self, task: RecalcTask) -> None:
|
||||||
"""Process a single recalc task: evaluate formula → write result.
|
"""Process a single recalc task: evaluate formula → write result.
|
||||||
|
|
||||||
The task is expected to already be in ``calculating`` status when
|
The task is expected to already be in ``calculating`` status when
|
||||||
|
|
@ -216,7 +215,7 @@ class RecalcWorker:
|
||||||
return
|
return
|
||||||
|
|
||||||
deps = engine.get_dependencies(task.field_id)
|
deps = engine.get_dependencies(task.field_id)
|
||||||
column_values: dict[str, list[Any]] = {}
|
column_values: dict[str, list[object]] = {}
|
||||||
for dep_field_id in deps:
|
for dep_field_id in deps:
|
||||||
column_values[dep_field_id] = await self._repo.get_column_values(
|
column_values[dep_field_id] = await self._repo.get_column_values(
|
||||||
task.table_id, dep_field_id
|
task.table_id, dep_field_id
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,6 @@ from __future__ import annotations
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from sqlalchemy import delete, func, insert, select, text, update
|
from sqlalchemy import delete, func, insert, select, text, update
|
||||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
|
|
@ -102,7 +101,7 @@ class BitableRepository:
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
return [BitableFile.model_validate(e) for e in result.scalars().all()]
|
return [BitableFile.model_validate(e) for e in result.scalars().all()]
|
||||||
|
|
||||||
async def update_file(self, file_id: str, **kwargs: Any) -> BitableFile | None:
|
async def update_file(self, file_id: str, **kwargs: object) -> BitableFile | None:
|
||||||
"""Update a file's attributes."""
|
"""Update a file's attributes."""
|
||||||
async with self._session_factory() as session:
|
async with self._session_factory() as session:
|
||||||
stmt = (
|
stmt = (
|
||||||
|
|
@ -181,7 +180,7 @@ class BitableRepository:
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
return [Table.model_validate(e) for e in result.scalars().all()]
|
return [Table.model_validate(e) for e in result.scalars().all()]
|
||||||
|
|
||||||
async def update_table(self, table_id: str, **kwargs: Any) -> Table | None:
|
async def update_table(self, table_id: str, **kwargs: object) -> Table | None:
|
||||||
"""Update a table's attributes."""
|
"""Update a table's attributes."""
|
||||||
async with self._session_factory() as session:
|
async with self._session_factory() as session:
|
||||||
stmt = (
|
stmt = (
|
||||||
|
|
@ -236,7 +235,7 @@ class BitableRepository:
|
||||||
table_id: str,
|
table_id: str,
|
||||||
name: str,
|
name: str,
|
||||||
field_type: FieldType,
|
field_type: FieldType,
|
||||||
config: dict[str, Any] | None = None,
|
config: dict[str, object] | None = None,
|
||||||
owner: FieldOwner = FieldOwner.user,
|
owner: FieldOwner = FieldOwner.user,
|
||||||
) -> Field:
|
) -> Field:
|
||||||
"""Create a new field in a table."""
|
"""Create a new field in a table."""
|
||||||
|
|
@ -277,7 +276,7 @@ class BitableRepository:
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
return [Field.model_validate(e) for e in result.scalars().all()]
|
return [Field.model_validate(e) for e in result.scalars().all()]
|
||||||
|
|
||||||
async def update_field(self, field_id: str, **kwargs: Any) -> Field | None:
|
async def update_field(self, field_id: str, **kwargs: object) -> Field | None:
|
||||||
"""Update a field's attributes."""
|
"""Update a field's attributes."""
|
||||||
async with self._session_factory() as session:
|
async with self._session_factory() as session:
|
||||||
stmt = (
|
stmt = (
|
||||||
|
|
@ -300,7 +299,7 @@ class BitableRepository:
|
||||||
|
|
||||||
# ── Records ─────────────────────────────────────────────
|
# ── Records ─────────────────────────────────────────────
|
||||||
|
|
||||||
async def create_record(self, table_id: str, values: dict[str, Any] | None = None) -> Record:
|
async def create_record(self, table_id: str, values: dict[str, object] | None = None) -> Record:
|
||||||
"""Create a new record."""
|
"""Create a new record."""
|
||||||
async with self._session_factory() as session:
|
async with self._session_factory() as session:
|
||||||
stmt = (
|
stmt = (
|
||||||
|
|
@ -318,7 +317,7 @@ class BitableRepository:
|
||||||
return Record.model_validate(entity)
|
return Record.model_validate(entity)
|
||||||
|
|
||||||
async def create_records_batch(
|
async def create_records_batch(
|
||||||
self, table_id: str, records_values: list[dict[str, Any]]
|
self, table_id: str, records_values: list[dict[str, object]]
|
||||||
) -> list[Record]:
|
) -> list[Record]:
|
||||||
"""Batch-insert multiple records (P2 #19: eliminates per-record INSERT).
|
"""Batch-insert multiple records (P2 #19: eliminates per-record INSERT).
|
||||||
|
|
||||||
|
|
@ -376,7 +375,7 @@ class BitableRepository:
|
||||||
|
|
||||||
return [Record.model_validate(e) for e in entities], next_cursor
|
return [Record.model_validate(e) for e in entities], next_cursor
|
||||||
|
|
||||||
async def update_record_values(self, record_id: str, values: dict[str, Any]) -> Record | None:
|
async def update_record_values(self, record_id: str, values: dict[str, object]) -> Record | None:
|
||||||
"""Update a record's values (full replace)."""
|
"""Update a record's values (full replace)."""
|
||||||
async with self._session_factory() as session:
|
async with self._session_factory() as session:
|
||||||
stmt = (
|
stmt = (
|
||||||
|
|
@ -413,7 +412,7 @@ class BitableRepository:
|
||||||
table_id: str,
|
table_id: str,
|
||||||
name: str,
|
name: str,
|
||||||
view_type: ViewType = ViewType.grid,
|
view_type: ViewType = ViewType.grid,
|
||||||
config: dict[str, Any] | None = None,
|
config: dict[str, object] | None = None,
|
||||||
) -> View:
|
) -> View:
|
||||||
"""Create a new view."""
|
"""Create a new view."""
|
||||||
async with self._session_factory() as session:
|
async with self._session_factory() as session:
|
||||||
|
|
@ -451,7 +450,7 @@ class BitableRepository:
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
return [View.model_validate(e) for e in result.scalars().all()]
|
return [View.model_validate(e) for e in result.scalars().all()]
|
||||||
|
|
||||||
async def update_view(self, view_id: str, **kwargs: Any) -> View | None:
|
async def update_view(self, view_id: str, **kwargs: object) -> View | None:
|
||||||
"""Update a view's attributes."""
|
"""Update a view's attributes."""
|
||||||
async with self._session_factory() as session:
|
async with self._session_factory() as session:
|
||||||
stmt = (
|
stmt = (
|
||||||
|
|
@ -543,7 +542,7 @@ class BitableRepository:
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update a recalc task's status."""
|
"""Update a recalc task's status."""
|
||||||
async with self._session_factory() as session:
|
async with self._session_factory() as session:
|
||||||
kwargs: dict[str, Any] = {"status": status.value}
|
kwargs: dict[str, object] = {"status": status.value}
|
||||||
if error_message is not None:
|
if error_message is not None:
|
||||||
kwargs["error_message"] = error_message
|
kwargs["error_message"] = error_message
|
||||||
if status in (RecalcStatus.done, RecalcStatus.error):
|
if status in (RecalcStatus.done, RecalcStatus.error):
|
||||||
|
|
@ -630,7 +629,7 @@ class BitableRepository:
|
||||||
return result_map
|
return result_map
|
||||||
|
|
||||||
async def upsert_record_agent_fields(
|
async def upsert_record_agent_fields(
|
||||||
self, record_id: str, agent_field_values: dict[str, Any]
|
self, record_id: str, agent_field_values: dict[str, object]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update agent-owned fields using jsonb_set (KTD8).
|
"""Update agent-owned fields using jsonb_set (KTD8).
|
||||||
|
|
||||||
|
|
@ -646,7 +645,7 @@ class BitableRepository:
|
||||||
# Use CAST(:param AS jsonb) instead of :param::jsonb — asyncpg dialect
|
# Use CAST(:param AS jsonb) instead of :param::jsonb — asyncpg dialect
|
||||||
# misparses the `::` as part of the param name.
|
# misparses the `::` as part of the param name.
|
||||||
inner = "values"
|
inner = "values"
|
||||||
params: dict[str, Any] = {"record_id": record_id}
|
params: dict[str, object] = {"record_id": record_id}
|
||||||
for i, (field_id, value) in enumerate(agent_field_values.items()):
|
for i, (field_id, value) in enumerate(agent_field_values.items()):
|
||||||
param_key = f"v{i}"
|
param_key = f"v{i}"
|
||||||
inner = f"jsonb_set({inner}, '{{{field_id}}}', CAST(:{param_key} AS jsonb), true)"
|
inner = f"jsonb_set({inner}, '{{{field_id}}}', CAST(:{param_key} AS jsonb), true)"
|
||||||
|
|
@ -660,8 +659,8 @@ class BitableRepository:
|
||||||
async def list_records_filtered(
|
async def list_records_filtered(
|
||||||
self,
|
self,
|
||||||
table_id: str,
|
table_id: str,
|
||||||
filters: list[dict[str, Any]] | None = None,
|
filters: list[dict[str, object]] | None = None,
|
||||||
sorts: list[dict[str, Any]] | None = None,
|
sorts: list[dict[str, object]] | None = None,
|
||||||
cursor: str | None = None,
|
cursor: str | None = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
) -> tuple[list[Record], str | None]:
|
) -> tuple[list[Record], str | None]:
|
||||||
|
|
@ -682,7 +681,7 @@ class BitableRepository:
|
||||||
# Build raw SQL with JSONB filter/sort translation.
|
# Build raw SQL with JSONB filter/sort translation.
|
||||||
# ponytail: field_ids in filters/sorts are system UUIDs (validated by service layer).
|
# ponytail: field_ids in filters/sorts are system UUIDs (validated by service layer).
|
||||||
where_clauses = ["table_id = :table_id"]
|
where_clauses = ["table_id = :table_id"]
|
||||||
params: dict[str, Any] = {"table_id": table_id}
|
params: dict[str, object] = {"table_id": table_id}
|
||||||
|
|
||||||
if filters:
|
if filters:
|
||||||
for i, f in enumerate(filters):
|
for i, f in enumerate(filters):
|
||||||
|
|
@ -783,7 +782,7 @@ class BitableRepository:
|
||||||
last_mapping = last_row._mapping
|
last_mapping = last_row._mapping
|
||||||
# Build composite cursor from sort values + id.
|
# Build composite cursor from sort values + id.
|
||||||
# Sort values are extracted as text to match `values->>'fid'` expressions.
|
# Sort values are extracted as text to match `values->>'fid'` expressions.
|
||||||
sv: list[Any] = []
|
sv: list[object] = []
|
||||||
last_values = last_mapping.get("values")
|
last_values = last_mapping.get("values")
|
||||||
if isinstance(last_values, str):
|
if isinstance(last_values, str):
|
||||||
# asyncpg may return JSONB as str in raw text() queries.
|
# asyncpg may return JSONB as str in raw text() queries.
|
||||||
|
|
@ -852,7 +851,7 @@ class BitableRepository:
|
||||||
|
|
||||||
# ── Recalc support (U3) ────────────────────────────────
|
# ── Recalc support (U3) ────────────────────────────────
|
||||||
|
|
||||||
async def get_column_values(self, table_id: str, field_id: str) -> list[Any]:
|
async def get_column_values(self, table_id: str, field_id: str) -> list[object]:
|
||||||
"""Get all values for a field across all records in a table (for aggregates).
|
"""Get all values for a field across all records in a table (for aggregates).
|
||||||
|
|
||||||
Returns a list of values (preserving order by record id). Missing values
|
Returns a list of values (preserving order by record id). Missing values
|
||||||
|
|
@ -866,7 +865,7 @@ class BitableRepository:
|
||||||
result = await session.execute(sql, {"field_id": field_id, "table_id": table_id})
|
result = await session.execute(sql, {"field_id": field_id, "table_id": table_id})
|
||||||
return [row[0] for row in result.fetchall()]
|
return [row[0] for row in result.fetchall()]
|
||||||
|
|
||||||
async def set_formula_value(self, record_id: str, field_id: str, value: Any) -> None:
|
async def set_formula_value(self, record_id: str, field_id: str, value: object) -> None:
|
||||||
"""Set a single formula field value in a record's JSONB (jsonb_set)."""
|
"""Set a single formula field value in a record's JSONB (jsonb_set)."""
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ import logging
|
||||||
import time
|
import time
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -25,6 +25,28 @@ _TTL_SECONDS = 7 * 24 * 3600 # 7 days
|
||||||
_KEY_PREFIX = "agentkit:pipeline:checkpoint"
|
_KEY_PREFIX = "agentkit:pipeline:checkpoint"
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
class _RedisPipelineLike(Protocol):
|
||||||
|
def set(self, key: str, value: str, ex: int | None = None) -> object: ...
|
||||||
|
def zadd(self, name: str, mapping: dict[str, float]) -> object: ...
|
||||||
|
def get(self, key: str) -> object: ...
|
||||||
|
def delete(self, *keys: str) -> object: ...
|
||||||
|
async def execute(self) -> list[object]: ...
|
||||||
|
|
||||||
|
class _RedisLike(Protocol):
|
||||||
|
async def set(self, key: str, value: str, ex: int | None = None) -> object: ...
|
||||||
|
async def get(self, key: str) -> object: ...
|
||||||
|
def pipeline(self) -> _RedisPipelineLike: ...
|
||||||
|
async def zrange(self, name: str, start: int, stop: int) -> list[object]: ...
|
||||||
|
|
||||||
|
class _PlanLike(Protocol):
|
||||||
|
@property
|
||||||
|
def id(self) -> str: ...
|
||||||
|
def to_dict(self) -> dict[str, object]: ...
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CheckpointData:
|
class CheckpointData:
|
||||||
"""单个阶段的 checkpoint 数据。"""
|
"""单个阶段的 checkpoint 数据。"""
|
||||||
|
|
@ -33,15 +55,15 @@ class CheckpointData:
|
||||||
phase_id: str
|
phase_id: str
|
||||||
phase_name: str
|
phase_name: str
|
||||||
phase_status: str
|
phase_status: str
|
||||||
phase_result: dict[str, Any] | None = None
|
phase_result: dict[str, object] | None = None
|
||||||
plan_status: str = ""
|
plan_status: str = ""
|
||||||
saved_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
saved_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, object]:
|
||||||
return asdict(self)
|
return asdict(self)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict[str, Any]) -> CheckpointData:
|
def from_dict(cls, data: dict[str, object]) -> CheckpointData:
|
||||||
return cls(
|
return cls(
|
||||||
plan_id=data.get("plan_id", ""),
|
plan_id=data.get("plan_id", ""),
|
||||||
phase_id=data.get("phase_id", ""),
|
phase_id=data.get("phase_id", ""),
|
||||||
|
|
@ -67,7 +89,7 @@ class PipelineCheckpoint:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
redis_client: Any = None,
|
redis_client: _RedisLike | None = None,
|
||||||
prefix: str = _KEY_PREFIX,
|
prefix: str = _KEY_PREFIX,
|
||||||
ttl_seconds: int = _TTL_SECONDS,
|
ttl_seconds: int = _TTL_SECONDS,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
@ -78,7 +100,7 @@ class PipelineCheckpoint:
|
||||||
# P1 #6: 改用 dict keyed by phase_id,避免重复 append
|
# P1 #6: 改用 dict keyed by phase_id,避免重复 append
|
||||||
self._memory: dict[str, dict[str, CheckpointData]] = {}
|
self._memory: dict[str, dict[str, CheckpointData]] = {}
|
||||||
# 内存降级存储:plan_id → (plan_dict, saved_timestamp)
|
# 内存降级存储:plan_id → (plan_dict, saved_timestamp)
|
||||||
self._memory_plans: dict[str, tuple[dict[str, Any], float]] = {}
|
self._memory_plans: dict[str, tuple[dict[str, object], float]] = {}
|
||||||
|
|
||||||
def _is_expired(self, saved_at: str) -> bool:
|
def _is_expired(self, saved_at: str) -> bool:
|
||||||
"""检查 checkpoint 是否已过期(内存模式 TTL)。"""
|
"""检查 checkpoint 是否已过期(内存模式 TTL)。"""
|
||||||
|
|
@ -102,7 +124,7 @@ class PipelineCheckpoint:
|
||||||
"""完整 plan JSON 的存储键。"""
|
"""完整 plan JSON 的存储键。"""
|
||||||
return f"{self._prefix}:plan:{plan_id}"
|
return f"{self._prefix}:plan:{plan_id}"
|
||||||
|
|
||||||
async def save_plan(self, plan: Any) -> None:
|
async def save_plan(self, plan: _PlanLike) -> None:
|
||||||
"""保存完整 TeamPlan(用于 resume 重建)。
|
"""保存完整 TeamPlan(用于 resume 重建)。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -121,7 +143,7 @@ class PipelineCheckpoint:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}")
|
logger.warning(f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}")
|
||||||
|
|
||||||
async def load_plan(self, plan_id: str) -> dict[str, Any] | None:
|
async def load_plan(self, plan_id: str) -> dict[str, object] | None:
|
||||||
"""加载完整 plan JSON。"""
|
"""加载完整 plan JSON。"""
|
||||||
# 优先 Redis
|
# 优先 Redis
|
||||||
if self._redis is not None:
|
if self._redis is not None:
|
||||||
|
|
@ -142,7 +164,7 @@ class PipelineCheckpoint:
|
||||||
return None
|
return None
|
||||||
return plan_dict
|
return plan_dict
|
||||||
|
|
||||||
async def save(self, plan_id: str, phase: Any, plan_status: str) -> None:
|
async def save(self, plan_id: str, phase: object, plan_status: str) -> None:
|
||||||
"""保存阶段 checkpoint。
|
"""保存阶段 checkpoint。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -212,7 +234,8 @@ class PipelineCheckpoint:
|
||||||
if not phase_ids:
|
if not phase_ids:
|
||||||
# Redis 无数据,检查内存(过滤过期)
|
# Redis 无数据,检查内存(过滤过期)
|
||||||
return [
|
return [
|
||||||
c for c in self._memory.get(plan_id, {}).values()
|
c
|
||||||
|
for c in self._memory.get(plan_id, {}).values()
|
||||||
if not self._is_expired(c.saved_at)
|
if not self._is_expired(c.saved_at)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -236,8 +259,7 @@ class PipelineCheckpoint:
|
||||||
|
|
||||||
# 内存降级(过滤过期 checkpoint)
|
# 内存降级(过滤过期 checkpoint)
|
||||||
return [
|
return [
|
||||||
c for c in self._memory.get(plan_id, {}).values()
|
c for c in self._memory.get(plan_id, {}).values() if not self._is_expired(c.saved_at)
|
||||||
if not self._is_expired(c.saved_at)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
async def clear(self, plan_id: str) -> None:
|
async def clear(self, plan_id: str) -> None:
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
"""Saga compensation pattern for Pipeline execution"""
|
"""Saga compensation pattern for Pipeline execution"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from typing import Any, Awaitable, Callable
|
from typing import Awaitable, Callable
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -12,7 +12,7 @@ class CompletedStep:
|
||||||
"""Record of a completed step with its compensation"""
|
"""Record of a completed step with its compensation"""
|
||||||
|
|
||||||
step_name: str
|
step_name: str
|
||||||
result: Any
|
result: object
|
||||||
compensate_action: str | None = None
|
compensate_action: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -28,9 +28,7 @@ class CompensationResult:
|
||||||
class SagaOrchestrator:
|
class SagaOrchestrator:
|
||||||
"""Orchestrates LIFO compensation for failed pipelines"""
|
"""Orchestrates LIFO compensation for failed pipelines"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, execute_skill_func: Callable[..., Awaitable[object]] | None = None):
|
||||||
self, execute_skill_func: Callable[..., Awaitable[Any]] | None = None
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
execute_skill_func: Async function to execute a skill by name
|
execute_skill_func: Async function to execute a skill by name
|
||||||
|
|
@ -42,7 +40,7 @@ class SagaOrchestrator:
|
||||||
def record_completed(
|
def record_completed(
|
||||||
self,
|
self,
|
||||||
step_name: str,
|
step_name: str,
|
||||||
result: Any,
|
result: object,
|
||||||
compensate_action: str | None = None,
|
compensate_action: str | None = None,
|
||||||
):
|
):
|
||||||
"""Record a completed step for potential compensation"""
|
"""Record a completed step for potential compensation"""
|
||||||
|
|
@ -59,9 +57,7 @@ class SagaOrchestrator:
|
||||||
results: list[CompensationResult] = []
|
results: list[CompensationResult] = []
|
||||||
for step in reversed(self._completed_steps):
|
for step in reversed(self._completed_steps):
|
||||||
if step.compensate_action is None:
|
if step.compensate_action is None:
|
||||||
logger.info(
|
logger.info(f"No compensation for step '{step.step_name}', skipping")
|
||||||
f"No compensation for step '{step.step_name}', skipping"
|
|
||||||
)
|
|
||||||
results.append(
|
results.append(
|
||||||
CompensationResult(
|
CompensationResult(
|
||||||
step_name=step.step_name,
|
step_name=step.step_name,
|
||||||
|
|
@ -82,9 +78,7 @@ class SagaOrchestrator:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"Compensation for step '{step.step_name}' failed: {e}")
|
||||||
f"Compensation for step '{step.step_name}' failed: {e}"
|
|
||||||
)
|
|
||||||
results.append(
|
results.append(
|
||||||
CompensationResult(
|
CompensationResult(
|
||||||
step_name=step.step_name,
|
step_name=step.step_name,
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from agentkit.orchestrator.pipeline_engine import PipelineEngine
|
from agentkit.orchestrator.pipeline_engine import PipelineEngine
|
||||||
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineResult, StageStatus
|
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineResult, StageStatus
|
||||||
|
|
@ -15,7 +14,7 @@ logger = logging.getLogger(__name__)
|
||||||
class DynamicPipeline:
|
class DynamicPipeline:
|
||||||
"""动态 Pipeline 组合器"""
|
"""动态 Pipeline 组合器"""
|
||||||
|
|
||||||
def __init__(self, engine: PipelineEngine, loader: Any = None):
|
def __init__(self, engine: PipelineEngine, loader: object | None = None):
|
||||||
self._engine = engine
|
self._engine = engine
|
||||||
self._loader = loader
|
self._loader = loader
|
||||||
|
|
||||||
|
|
@ -23,7 +22,7 @@ class DynamicPipeline:
|
||||||
self,
|
self,
|
||||||
pipelines: dict[str, Pipeline],
|
pipelines: dict[str, Pipeline],
|
||||||
condition_key: str,
|
condition_key: str,
|
||||||
context: dict[str, Any] | None = None,
|
context: dict[str, object] | None = None,
|
||||||
) -> PipelineResult:
|
) -> PipelineResult:
|
||||||
"""根据条件选择子 Pipeline 执行"""
|
"""根据条件选择子 Pipeline 执行"""
|
||||||
context = context or {}
|
context = context or {}
|
||||||
|
|
@ -37,14 +36,16 @@ class DynamicPipeline:
|
||||||
)
|
)
|
||||||
|
|
||||||
selected = pipelines[condition_value]
|
selected = pipelines[condition_value]
|
||||||
logger.info(f"DynamicPipeline selected '{selected.name}' for {condition_key}={condition_value}")
|
logger.info(
|
||||||
|
f"DynamicPipeline selected '{selected.name}' for {condition_key}={condition_value}"
|
||||||
|
)
|
||||||
return await self._engine.execute(selected, context)
|
return await self._engine.execute(selected, context)
|
||||||
|
|
||||||
async def execute_nested(
|
async def execute_nested(
|
||||||
self,
|
self,
|
||||||
parent: Pipeline,
|
parent: Pipeline,
|
||||||
sub_pipeline_map: dict[str, Pipeline],
|
sub_pipeline_map: dict[str, Pipeline],
|
||||||
context: dict[str, Any] | None = None,
|
context: dict[str, object] | None = None,
|
||||||
) -> PipelineResult:
|
) -> PipelineResult:
|
||||||
"""执行嵌套 Pipeline"""
|
"""执行嵌套 Pipeline"""
|
||||||
# 先执行父 Pipeline
|
# 先执行父 Pipeline
|
||||||
|
|
@ -52,7 +53,7 @@ class DynamicPipeline:
|
||||||
|
|
||||||
# 根据父 Pipeline 结果选择子 Pipeline
|
# 根据父 Pipeline 结果选择子 Pipeline
|
||||||
for stage_name, stage_result in parent_result.stage_results.items():
|
for stage_name, stage_result in parent_result.stage_results.items():
|
||||||
if hasattr(stage_result, 'output_data') and stage_result.output_data:
|
if hasattr(stage_result, "output_data") and stage_result.output_data:
|
||||||
sub_pipeline_name = stage_result.output_data.get("sub_pipeline")
|
sub_pipeline_name = stage_result.output_data.get("sub_pipeline")
|
||||||
if sub_pipeline_name and sub_pipeline_name in sub_pipeline_map:
|
if sub_pipeline_name and sub_pipeline_name in sub_pipeline_map:
|
||||||
sub = sub_pipeline_map[sub_pipeline_name]
|
sub = sub_pipeline_map[sub_pipeline_name]
|
||||||
|
|
@ -66,7 +67,7 @@ class DynamicPipeline:
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
max_iterations: int = 5,
|
max_iterations: int = 5,
|
||||||
exit_condition: str = "done",
|
exit_condition: str = "done",
|
||||||
context: dict[str, Any] | None = None,
|
context: dict[str, object] | None = None,
|
||||||
) -> PipelineResult:
|
) -> PipelineResult:
|
||||||
"""循环执行 Pipeline 直到条件满足"""
|
"""循环执行 Pipeline 直到条件满足"""
|
||||||
current_context = context or {}
|
current_context = context or {}
|
||||||
|
|
|
||||||
|
|
@ -3,25 +3,42 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Awaitable, Callable, Protocol
|
||||||
|
|
||||||
from agentkit.core.protocol import HandoffMessage
|
from agentkit.core.protocol import HandoffMessage
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class _RedisPubSubLike(Protocol):
|
||||||
|
"""Structural type for Redis pubsub object."""
|
||||||
|
|
||||||
|
async def subscribe(self, channel: str) -> None: ...
|
||||||
|
async def unsubscribe(self, channel: str) -> None: ...
|
||||||
|
def listen(self) -> object: ...
|
||||||
|
|
||||||
|
|
||||||
|
class _RedisLike(Protocol):
|
||||||
|
"""Structural type for async Redis client."""
|
||||||
|
|
||||||
|
async def publish(self, channel: str, message: str) -> int: ...
|
||||||
|
def pubsub(self) -> _RedisPubSubLike: ...
|
||||||
|
|
||||||
|
|
||||||
class HandoffManager:
|
class HandoffManager:
|
||||||
"""Handoff 管理器
|
"""Handoff 管理器
|
||||||
|
|
||||||
通过 Redis Pub/Sub 管理 Agent 间的任务转交。
|
通过 Redis Pub/Sub 管理 Agent 间的任务转交。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, redis: Any = None, dispatcher: Any = None):
|
def __init__(self, redis: _RedisLike | None = None, dispatcher: object | None = None):
|
||||||
self._redis = redis
|
self._redis = redis
|
||||||
self._dispatcher = dispatcher
|
self._dispatcher = dispatcher
|
||||||
self._handlers: dict[str, list[Any]] = {}
|
self._handlers: dict[str, list[Callable[[HandoffMessage], Awaitable[None]]]] = {}
|
||||||
|
|
||||||
def register_handler(self, agent_name: str, handler: Any) -> None:
|
def register_handler(
|
||||||
|
self, agent_name: str, handler: Callable[[HandoffMessage], Awaitable[None]]
|
||||||
|
) -> None:
|
||||||
"""注册 Handoff 处理器"""
|
"""注册 Handoff 处理器"""
|
||||||
if agent_name not in self._handlers:
|
if agent_name not in self._handlers:
|
||||||
self._handlers[agent_name] = []
|
self._handlers[agent_name] = []
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,12 @@
|
||||||
"""Pipeline Engine - DAG + 并行执行 + 步骤重试 + Saga 补偿"""
|
"""Pipeline Engine - DAG + 并行执行 + 步骤重试 + Saga 补偿"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from agentkit.orchestrator.compensation import SagaOrchestrator
|
from agentkit.orchestrator.compensation import SagaOrchestrator
|
||||||
from agentkit.orchestrator.pipeline_schema import (
|
from agentkit.orchestrator.pipeline_schema import (
|
||||||
|
|
@ -25,6 +27,23 @@ from agentkit.orchestrator.retry import execute_with_retry
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
class _DispatcherLike(Protocol):
|
||||||
|
async def dispatch(self, task: object) -> None: ...
|
||||||
|
async def get_task_status(self, task_id: str) -> dict[str, object]: ...
|
||||||
|
|
||||||
|
class _StateManagerLike(Protocol):
|
||||||
|
async def create_execution(self, **kwargs: object) -> str: ...
|
||||||
|
async def update_step(self, **kwargs: object) -> None: ...
|
||||||
|
async def complete_execution(self, **kwargs: object) -> None: ...
|
||||||
|
async def fail_execution(self, **kwargs: object) -> None: ...
|
||||||
|
|
||||||
|
class _LLMGatewayLike(Protocol):
|
||||||
|
async def chat(self, **kwargs: object) -> object: ...
|
||||||
|
|
||||||
|
|
||||||
class PipelineEngine:
|
class PipelineEngine:
|
||||||
"""Pipeline 执行引擎
|
"""Pipeline 执行引擎
|
||||||
|
|
||||||
|
|
@ -38,7 +57,12 @@ class PipelineEngine:
|
||||||
- 状态持久化(可选)
|
- 状态持久化(可选)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dispatcher: Any = None, state_manager: Any = None, llm_gateway: Any = None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
dispatcher: _DispatcherLike | None = None,
|
||||||
|
state_manager: _StateManagerLike | None = None,
|
||||||
|
llm_gateway: _LLMGatewayLike | None = None,
|
||||||
|
):
|
||||||
self._dispatcher = dispatcher
|
self._dispatcher = dispatcher
|
||||||
self._state_manager = state_manager
|
self._state_manager = state_manager
|
||||||
self._llm_gateway = llm_gateway
|
self._llm_gateway = llm_gateway
|
||||||
|
|
@ -46,7 +70,7 @@ class PipelineEngine:
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self,
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
context: dict[str, Any] | None = None,
|
context: dict[str, object] | None = None,
|
||||||
adaptive_config: AdaptiveConfig | None = None,
|
adaptive_config: AdaptiveConfig | None = None,
|
||||||
) -> PipelineResult:
|
) -> PipelineResult:
|
||||||
"""执行 Pipeline
|
"""执行 Pipeline
|
||||||
|
|
@ -68,7 +92,7 @@ class PipelineEngine:
|
||||||
async def _adaptive_loop(
|
async def _adaptive_loop(
|
||||||
self,
|
self,
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
context: dict[str, Any] | None,
|
context: dict[str, object] | None,
|
||||||
failed_result: PipelineResult,
|
failed_result: PipelineResult,
|
||||||
adaptive_config: AdaptiveConfig,
|
adaptive_config: AdaptiveConfig,
|
||||||
) -> PipelineResult:
|
) -> PipelineResult:
|
||||||
|
|
@ -92,34 +116,30 @@ class PipelineEngine:
|
||||||
|
|
||||||
# Replan
|
# Replan
|
||||||
new_pipeline = await replanner.replan(current_pipeline, current_result, report)
|
new_pipeline = await replanner.replan(current_pipeline, current_result, report)
|
||||||
logger.info(f"Pipeline replanned: {new_pipeline.name} ({len(new_pipeline.stages)} stages)")
|
logger.info(
|
||||||
|
f"Pipeline replanned: {new_pipeline.name} ({len(new_pipeline.stages)} stages)"
|
||||||
|
)
|
||||||
|
|
||||||
# Re-execute
|
# Re-execute
|
||||||
current_result = await self._execute_pipeline(new_pipeline, context)
|
current_result = await self._execute_pipeline(new_pipeline, context)
|
||||||
current_pipeline = new_pipeline
|
current_pipeline = new_pipeline
|
||||||
|
|
||||||
# Record reflection in metadata
|
# Record reflection in metadata
|
||||||
current_result.metadata["reflections"] = [
|
current_result.metadata["reflections"] = [r.model_dump() for r in reflections]
|
||||||
r.model_dump() for r in reflections
|
|
||||||
]
|
|
||||||
|
|
||||||
if current_result.status == StageStatus.COMPLETED:
|
if current_result.status == StageStatus.COMPLETED:
|
||||||
logger.info(f"Pipeline succeeded after {reflection_num} reflection(s)")
|
logger.info(f"Pipeline succeeded after {reflection_num} reflection(s)")
|
||||||
return current_result
|
return current_result
|
||||||
|
|
||||||
# Exhausted reflections
|
# Exhausted reflections
|
||||||
logger.warning(
|
logger.warning(f"Pipeline failed after {adaptive_config.max_reflections} reflection(s)")
|
||||||
f"Pipeline failed after {adaptive_config.max_reflections} reflection(s)"
|
current_result.metadata["reflections"] = [r.model_dump() for r in reflections]
|
||||||
)
|
|
||||||
current_result.metadata["reflections"] = [
|
|
||||||
r.model_dump() for r in reflections
|
|
||||||
]
|
|
||||||
return current_result
|
return current_result
|
||||||
|
|
||||||
async def _execute_pipeline(
|
async def _execute_pipeline(
|
||||||
self,
|
self,
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
context: dict[str, Any] | None = None,
|
context: dict[str, object] | None = None,
|
||||||
) -> PipelineResult:
|
) -> PipelineResult:
|
||||||
"""执行 Pipeline 的核心逻辑(不含反思-重规划)。"""
|
"""执行 Pipeline 的核心逻辑(不含反思-重规划)。"""
|
||||||
result = PipelineResult(pipeline_name=pipeline.name)
|
result = PipelineResult(pipeline_name=pipeline.name)
|
||||||
|
|
@ -151,7 +171,9 @@ class PipelineEngine:
|
||||||
|
|
||||||
# 逐层执行
|
# 逐层执行
|
||||||
for level, stages in enumerate(level_groups):
|
for level, stages in enumerate(level_groups):
|
||||||
logger.info(f"Pipeline '{pipeline.name}' executing level {level} with {len(stages)} stage(s)")
|
logger.info(
|
||||||
|
f"Pipeline '{pipeline.name}' executing level {level} with {len(stages)} stage(s)"
|
||||||
|
)
|
||||||
|
|
||||||
# 并行执行同层 stages
|
# 并行执行同层 stages
|
||||||
tasks = []
|
tasks = []
|
||||||
|
|
@ -173,9 +195,11 @@ class PipelineEngine:
|
||||||
# Update step state
|
# Update step state
|
||||||
if self._state_manager is not None and execution_id is not None:
|
if self._state_manager is not None and execution_id is not None:
|
||||||
try:
|
try:
|
||||||
step_status = "completed" if sr.status == StageStatus.COMPLETED else sr.status.value
|
step_status = (
|
||||||
step_output = sr.output_data if hasattr(sr, 'output_data') else None
|
"completed" if sr.status == StageStatus.COMPLETED else sr.status.value
|
||||||
step_error = sr.error_message if hasattr(sr, 'error_message') else None
|
)
|
||||||
|
step_output = sr.output_data if hasattr(sr, "output_data") else None
|
||||||
|
step_error = sr.error_message if hasattr(sr, "error_message") else None
|
||||||
await self._state_manager.update_step(
|
await self._state_manager.update_step(
|
||||||
execution_id=execution_id,
|
execution_id=execution_id,
|
||||||
step_name=stage.name,
|
step_name=stage.name,
|
||||||
|
|
@ -189,19 +213,21 @@ class PipelineEngine:
|
||||||
# 收集输出变量
|
# 收集输出变量
|
||||||
if sr.output_data and isinstance(sr, dict):
|
if sr.output_data and isinstance(sr, dict):
|
||||||
pass
|
pass
|
||||||
elif hasattr(sr, 'output_data') and sr.output_data:
|
elif hasattr(sr, "output_data") and sr.output_data:
|
||||||
for output_key in stage.outputs:
|
for output_key in stage.outputs:
|
||||||
if output_key in sr.output_data:
|
if output_key in sr.output_data:
|
||||||
result.variables[output_key] = sr.output_data[output_key]
|
result.variables[output_key] = sr.output_data[output_key]
|
||||||
|
|
||||||
# 检查是否需要中止
|
# 检查是否需要中止
|
||||||
if hasattr(sr, 'status') and sr.status == StageStatus.FAILED:
|
if hasattr(sr, "status") and sr.status == StageStatus.FAILED:
|
||||||
if not stage.continue_on_failure:
|
if not stage.continue_on_failure:
|
||||||
# Execute Saga compensation for completed steps
|
# Execute Saga compensation for completed steps
|
||||||
compensation_results = await saga.compensate()
|
compensation_results = await saga.compensate()
|
||||||
if compensation_results:
|
if compensation_results:
|
||||||
failed_compensations = [
|
failed_compensations = [
|
||||||
cr for cr in compensation_results if not cr.success and cr.error != "no_compensation_needed"
|
cr
|
||||||
|
for cr in compensation_results
|
||||||
|
if not cr.success and cr.error != "no_compensation_needed"
|
||||||
]
|
]
|
||||||
if failed_compensations:
|
if failed_compensations:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
@ -219,7 +245,12 @@ class PipelineEngine:
|
||||||
step_name=stage.name,
|
step_name=stage.name,
|
||||||
error=result.error_message,
|
error=result.error_message,
|
||||||
)
|
)
|
||||||
except (asyncio.TimeoutError, ConnectionError, RuntimeError, ValueError) as exc:
|
except (
|
||||||
|
asyncio.TimeoutError,
|
||||||
|
ConnectionError,
|
||||||
|
RuntimeError,
|
||||||
|
ValueError,
|
||||||
|
) as exc:
|
||||||
logger.warning(f"Failed to persist failure state: {exc}")
|
logger.warning(f"Failed to persist failure state: {exc}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
@ -252,7 +283,9 @@ class PipelineEngine:
|
||||||
started_at = datetime.now(timezone.utc).isoformat()
|
started_at = datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
# 条件检查
|
# 条件检查
|
||||||
if stage.condition and not self._evaluate_condition(stage.condition, pipeline_result.variables):
|
if stage.condition and not self._evaluate_condition(
|
||||||
|
stage.condition, pipeline_result.variables
|
||||||
|
):
|
||||||
return StageResult(
|
return StageResult(
|
||||||
stage_name=stage.name,
|
stage_name=stage.name,
|
||||||
status=StageStatus.SKIPPED,
|
status=StageStatus.SKIPPED,
|
||||||
|
|
@ -312,7 +345,9 @@ class PipelineEngine:
|
||||||
if status["status"] in ("completed", "failed", "cancelled"):
|
if status["status"] in ("completed", "failed", "cancelled"):
|
||||||
return StageResult(
|
return StageResult(
|
||||||
stage_name=stage.name,
|
stage_name=stage.name,
|
||||||
status=StageStatus.COMPLETED if status["status"] == "completed" else StageStatus.FAILED,
|
status=StageStatus.COMPLETED
|
||||||
|
if status["status"] == "completed"
|
||||||
|
else StageStatus.FAILED,
|
||||||
output_data=status.get("output_data"),
|
output_data=status.get("output_data"),
|
||||||
error_message=status.get("error_message"),
|
error_message=status.get("error_message"),
|
||||||
started_at=started_at,
|
started_at=started_at,
|
||||||
|
|
@ -406,7 +441,7 @@ class PipelineEngine:
|
||||||
return resolved
|
return resolved
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_nested(data: dict, path: str) -> Any:
|
def _get_nested(data: dict, path: str) -> object:
|
||||||
keys = path.split(".")
|
keys = path.split(".")
|
||||||
current = data
|
current = data
|
||||||
for key in keys:
|
for key in keys:
|
||||||
|
|
@ -497,9 +532,7 @@ class PipelineEngine:
|
||||||
|
|
||||||
if verifier_feedback.passed:
|
if verifier_feedback.passed:
|
||||||
# 审查通过,返回成功结果
|
# 审查通过,返回成功结果
|
||||||
logger.info(
|
logger.info(f"Stage '{stage.name}' passed review in round {round_num}")
|
||||||
f"Stage '{stage.name}' passed review in round {round_num}"
|
|
||||||
)
|
|
||||||
worker_result.output_data = worker_result.output_data or {}
|
worker_result.output_data = worker_result.output_data or {}
|
||||||
worker_result.output_data["adversarial_metadata"] = {
|
worker_result.output_data["adversarial_metadata"] = {
|
||||||
"passed_round": round_num,
|
"passed_round": round_num,
|
||||||
|
|
@ -553,7 +586,7 @@ class PipelineEngine:
|
||||||
self,
|
self,
|
||||||
agent_name: str,
|
agent_name: str,
|
||||||
action: str,
|
action: str,
|
||||||
input_data: dict[str, Any],
|
input_data: dict[str, object],
|
||||||
stage: PipelineStage,
|
stage: PipelineStage,
|
||||||
started_at: str,
|
started_at: str,
|
||||||
timeout_seconds: int | None = None,
|
timeout_seconds: int | None = None,
|
||||||
|
|
@ -568,7 +601,9 @@ class PipelineEngine:
|
||||||
started_at: 开始时间
|
started_at: 开始时间
|
||||||
timeout_seconds: 独立超时时间,不传则使用 stage.timeout_seconds
|
timeout_seconds: 独立超时时间,不传则使用 stage.timeout_seconds
|
||||||
"""
|
"""
|
||||||
effective_timeout = timeout_seconds if timeout_seconds is not None else stage.timeout_seconds
|
effective_timeout = (
|
||||||
|
timeout_seconds if timeout_seconds is not None else stage.timeout_seconds
|
||||||
|
)
|
||||||
if self._dispatcher is None:
|
if self._dispatcher is None:
|
||||||
# Dry-run 模式
|
# Dry-run 模式
|
||||||
return StageResult(
|
return StageResult(
|
||||||
|
|
@ -602,7 +637,9 @@ class PipelineEngine:
|
||||||
if status["status"] in ("completed", "failed", "cancelled"):
|
if status["status"] in ("completed", "failed", "cancelled"):
|
||||||
return StageResult(
|
return StageResult(
|
||||||
stage_name=stage.name,
|
stage_name=stage.name,
|
||||||
status=StageStatus.COMPLETED if status["status"] == "completed" else StageStatus.FAILED,
|
status=StageStatus.COMPLETED
|
||||||
|
if status["status"] == "completed"
|
||||||
|
else StageStatus.FAILED,
|
||||||
output_data=status.get("output_data"),
|
output_data=status.get("output_data"),
|
||||||
error_message=status.get("error_message"),
|
error_message=status.get("error_message"),
|
||||||
started_at=started_at,
|
started_at=started_at,
|
||||||
|
|
@ -639,7 +676,7 @@ class PipelineEngine:
|
||||||
async def _execute_verifier(
|
async def _execute_verifier(
|
||||||
self,
|
self,
|
||||||
verifier_name: str,
|
verifier_name: str,
|
||||||
worker_output: dict[str, Any],
|
worker_output: dict[str, object],
|
||||||
stage: PipelineStage,
|
stage: PipelineStage,
|
||||||
started_at: str,
|
started_at: str,
|
||||||
) -> ReviewFeedback:
|
) -> ReviewFeedback:
|
||||||
|
|
@ -679,10 +716,7 @@ class PipelineEngine:
|
||||||
try:
|
try:
|
||||||
feedback = ReviewFeedback(
|
feedback = ReviewFeedback(
|
||||||
passed=output_data.get("passed", False),
|
passed=output_data.get("passed", False),
|
||||||
issues=[
|
issues=[ReviewIssue(**issue) for issue in output_data.get("issues", [])],
|
||||||
ReviewIssue(**issue)
|
|
||||||
for issue in output_data.get("issues", [])
|
|
||||||
],
|
|
||||||
summary=output_data.get("summary", "No summary provided"),
|
summary=output_data.get("summary", "No summary provided"),
|
||||||
score=output_data.get("score", 0.0),
|
score=output_data.get("score", 0.0),
|
||||||
)
|
)
|
||||||
|
|
@ -699,7 +733,7 @@ class PipelineEngine:
|
||||||
self,
|
self,
|
||||||
feedback: ReviewFeedback,
|
feedback: ReviewFeedback,
|
||||||
feedback_mode: str = "structured+natural",
|
feedback_mode: str = "structured+natural",
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, object]:
|
||||||
"""构建反馈上下文,让 Worker Agent 理解审查反馈并定向修复
|
"""构建反馈上下文,让 Worker Agent 理解审查反馈并定向修复
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -720,7 +754,7 @@ class PipelineEngine:
|
||||||
for issue in feedback.issues
|
for issue in feedback.issues
|
||||||
]
|
]
|
||||||
|
|
||||||
feedback_context: dict[str, Any] = {
|
feedback_context: dict[str, object] = {
|
||||||
"previous_attempt_failed": True,
|
"previous_attempt_failed": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -756,7 +790,9 @@ class PipelineEngine:
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# 未知模式,fallback 到 structured+natural
|
# 未知模式,fallback 到 structured+natural
|
||||||
logger.warning(f"Unknown feedback_mode '{feedback_mode}', falling back to structured+natural")
|
logger.warning(
|
||||||
|
f"Unknown feedback_mode '{feedback_mode}', falling back to structured+natural"
|
||||||
|
)
|
||||||
feedback_context["review_feedback"] = {
|
feedback_context["review_feedback"] = {
|
||||||
"summary": feedback.summary,
|
"summary": feedback.summary,
|
||||||
"issues": issues_list,
|
"issues": issues_list,
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
@ -23,7 +22,9 @@ class PipelineLoader:
|
||||||
if not yaml_path.exists():
|
if not yaml_path.exists():
|
||||||
yaml_path = self._pipelines_dir / f"{pipeline_name}.yml"
|
yaml_path = self._pipelines_dir / f"{pipeline_name}.yml"
|
||||||
if not yaml_path.exists():
|
if not yaml_path.exists():
|
||||||
raise FileNotFoundError(f"Pipeline '{pipeline_name}' not found in {self._pipelines_dir}")
|
raise FileNotFoundError(
|
||||||
|
f"Pipeline '{pipeline_name}' not found in {self._pipelines_dir}"
|
||||||
|
)
|
||||||
|
|
||||||
content = yaml_path.read_text(encoding="utf-8")
|
content = yaml_path.read_text(encoding="utf-8")
|
||||||
return self.load_from_yaml(content, pipeline_name)
|
return self.load_from_yaml(content, pipeline_name)
|
||||||
|
|
|
||||||
|
|
@ -35,9 +35,7 @@ class PipelineExecutionModel(Base):
|
||||||
)
|
)
|
||||||
completed_at = Column(DateTime)
|
completed_at = Column(DateTime)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (Index("ix_pipeline_status_created", "status", "created_at"),)
|
||||||
Index("ix_pipeline_status_created", "status", "created_at"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PipelineStepHistoryModel(Base):
|
class PipelineStepHistoryModel(Base):
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
"""Pipeline 数据模型"""
|
"""Pipeline 数据模型"""
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Literal
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
@ -18,8 +18,11 @@ class StageStatus(str, Enum):
|
||||||
|
|
||||||
class ReviewIssue(BaseModel):
|
class ReviewIssue(BaseModel):
|
||||||
"""单条审查问题"""
|
"""单条审查问题"""
|
||||||
|
|
||||||
severity: Literal["critical", "major", "minor"] = Field(description="问题严重程度")
|
severity: Literal["critical", "major", "minor"] = Field(description="问题严重程度")
|
||||||
category: Literal["logic_error", "security", "style", "test_failure", "architecture"] = Field(description="问题类别")
|
category: Literal["logic_error", "security", "style", "test_failure", "architecture"] = Field(
|
||||||
|
description="问题类别"
|
||||||
|
)
|
||||||
description: str = Field(min_length=1, description="问题描述")
|
description: str = Field(min_length=1, description="问题描述")
|
||||||
location: str | None = Field(default=None, description="文件路径/行号")
|
location: str | None = Field(default=None, description="文件路径/行号")
|
||||||
suggestion: str | None = Field(default=None, description="修复建议")
|
suggestion: str | None = Field(default=None, description="修复建议")
|
||||||
|
|
@ -27,6 +30,7 @@ class ReviewIssue(BaseModel):
|
||||||
|
|
||||||
class ReviewFeedback(BaseModel):
|
class ReviewFeedback(BaseModel):
|
||||||
"""Verifier 返回的结构化审查反馈"""
|
"""Verifier 返回的结构化审查反馈"""
|
||||||
|
|
||||||
passed: bool = Field(description="是否通过审查")
|
passed: bool = Field(description="是否通过审查")
|
||||||
issues: list[ReviewIssue] = Field(default_factory=list, description="问题列表")
|
issues: list[ReviewIssue] = Field(default_factory=list, description="问题列表")
|
||||||
summary: str = Field(min_length=1, description="自然语言审查报告")
|
summary: str = Field(min_length=1, description="自然语言审查报告")
|
||||||
|
|
@ -35,6 +39,7 @@ class ReviewFeedback(BaseModel):
|
||||||
|
|
||||||
class AdversarialState(BaseModel):
|
class AdversarialState(BaseModel):
|
||||||
"""对抗轮次状态追踪"""
|
"""对抗轮次状态追踪"""
|
||||||
|
|
||||||
current_round: int = Field(default=0, description="当前对抗轮次")
|
current_round: int = Field(default=0, description="当前对抗轮次")
|
||||||
max_rounds: int = Field(default=3, description="最大对抗轮次")
|
max_rounds: int = Field(default=3, description="最大对抗轮次")
|
||||||
feedback_history: list[ReviewFeedback] = Field(default_factory=list, description="反馈历史")
|
feedback_history: list[ReviewFeedback] = Field(default_factory=list, description="反馈历史")
|
||||||
|
|
@ -46,7 +51,7 @@ class PipelineStage(BaseModel):
|
||||||
agent: str
|
agent: str
|
||||||
action: str
|
action: str
|
||||||
depends_on: list[str] = []
|
depends_on: list[str] = []
|
||||||
inputs: dict[str, Any] = {}
|
inputs: dict[str, object] = {}
|
||||||
outputs: list[str] = []
|
outputs: list[str] = []
|
||||||
timeout_seconds: int = 300
|
timeout_seconds: int = 300
|
||||||
retry_count: int = 0
|
retry_count: int = 0
|
||||||
|
|
@ -56,10 +61,17 @@ class PipelineStage(BaseModel):
|
||||||
compensate: str | None = None
|
compensate: str | None = None
|
||||||
|
|
||||||
# 对抗闭环相关字段
|
# 对抗闭环相关字段
|
||||||
verifier: str | None = Field(default=None, description="Verifier Agent 名称,配置后启用对抗模式")
|
verifier: str | None = Field(
|
||||||
|
default=None, description="Verifier Agent 名称,配置后启用对抗模式"
|
||||||
|
)
|
||||||
max_adversarial_rounds: int = Field(default=3, description="最大对抗轮次")
|
max_adversarial_rounds: int = Field(default=3, description="最大对抗轮次")
|
||||||
verifier_timeout_seconds: int = Field(default=120, description="Verifier Agent 独立超时时间(秒),避免与 Worker 共享 timeout_seconds")
|
verifier_timeout_seconds: int = Field(
|
||||||
feedback_mode: Literal["structured+natural", "structured", "natural"] = Field(default="structured+natural", description="反馈模式")
|
default=120,
|
||||||
|
description="Verifier Agent 独立超时时间(秒),避免与 Worker 共享 timeout_seconds",
|
||||||
|
)
|
||||||
|
feedback_mode: Literal["structured+natural", "structured", "natural"] = Field(
|
||||||
|
default="structured+natural", description="反馈模式"
|
||||||
|
)
|
||||||
escalate_on_exhaust: str | None = Field(default=None, description="对抗轮次耗尽后的升级目标")
|
escalate_on_exhaust: str | None = Field(default=None, description="对抗轮次耗尽后的升级目标")
|
||||||
|
|
||||||
model_config = {"arbitrary_types_allowed": True}
|
model_config = {"arbitrary_types_allowed": True}
|
||||||
|
|
@ -70,13 +82,13 @@ class Pipeline(BaseModel):
|
||||||
version: str
|
version: str
|
||||||
description: str
|
description: str
|
||||||
stages: list[PipelineStage]
|
stages: list[PipelineStage]
|
||||||
variables: dict[str, Any] = {}
|
variables: dict[str, object] = {}
|
||||||
|
|
||||||
|
|
||||||
class StageResult(BaseModel):
|
class StageResult(BaseModel):
|
||||||
stage_name: str
|
stage_name: str
|
||||||
status: StageStatus = StageStatus.PENDING
|
status: StageStatus = StageStatus.PENDING
|
||||||
output_data: dict[str, Any] | None = None
|
output_data: dict[str, object] | None = None
|
||||||
error_message: str | None = None
|
error_message: str | None = None
|
||||||
started_at: str | None = None
|
started_at: str | None = None
|
||||||
completed_at: str | None = None
|
completed_at: str | None = None
|
||||||
|
|
@ -86,9 +98,9 @@ class PipelineResult(BaseModel):
|
||||||
pipeline_name: str
|
pipeline_name: str
|
||||||
status: StageStatus = StageStatus.PENDING
|
status: StageStatus = StageStatus.PENDING
|
||||||
stage_results: dict[str, StageResult] = {}
|
stage_results: dict[str, StageResult] = {}
|
||||||
variables: dict[str, Any] = {}
|
variables: dict[str, object] = {}
|
||||||
error_message: str | None = None
|
error_message: str | None = None
|
||||||
metadata: dict[str, Any] = {}
|
metadata: dict[str, object] = {}
|
||||||
|
|
||||||
|
|
||||||
class AdaptiveConfig(BaseModel):
|
class AdaptiveConfig(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Callable, Coroutine
|
from typing import Callable, Coroutine
|
||||||
|
|
||||||
from agentkit.orchestrator.pipeline_models import (
|
from agentkit.orchestrator.pipeline_models import (
|
||||||
PipelineExecutionModel,
|
PipelineExecutionModel,
|
||||||
|
|
@ -183,7 +183,7 @@ class PipelineStateRedis:
|
||||||
return self._redis
|
return self._redis
|
||||||
|
|
||||||
async def _safe_redis_call(
|
async def _safe_redis_call(
|
||||||
self, fn: Callable[..., Coroutine[Any, Any, Any]], *args: object, **kwargs: object
|
self, fn: Callable[..., Coroutine[object, object, object]], *args: object, **kwargs: object
|
||||||
) -> object | None:
|
) -> object | None:
|
||||||
"""Execute a Redis call, falling back to memory on failure.
|
"""Execute a Redis call, falling back to memory on failure.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,22 +4,30 @@
|
||||||
生成修正后的 Pipeline 重新执行。
|
生成修正后的 Pipeline 重新执行。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from agentkit.orchestrator.pipeline_schema import (
|
from agentkit.orchestrator.pipeline_schema import (
|
||||||
Pipeline,
|
Pipeline,
|
||||||
PipelineResult,
|
PipelineResult,
|
||||||
PipelineStage,
|
PipelineStage,
|
||||||
ReflectionReport,
|
ReflectionReport,
|
||||||
StageResult,
|
|
||||||
StageStatus,
|
StageStatus,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
class _LLMGatewayLike(Protocol):
|
||||||
|
async def chat(self, **kwargs: object) -> object: ...
|
||||||
|
|
||||||
|
|
||||||
class PipelineReflector:
|
class PipelineReflector:
|
||||||
"""分析 Pipeline 执行失败原因,生成结构化反思报告。
|
"""分析 Pipeline 执行失败原因,生成结构化反思报告。
|
||||||
|
|
||||||
|
|
@ -27,7 +35,7 @@ class PipelineReflector:
|
||||||
输出 ReflectionReport 包含 failure_type、root_cause 和 suggested_fix。
|
输出 ReflectionReport 包含 failure_type、root_cause 和 suggested_fix。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, llm_gateway: Any = None):
|
def __init__(self, llm_gateway: _LLMGatewayLike | None = None):
|
||||||
self._llm_gateway = llm_gateway
|
self._llm_gateway = llm_gateway
|
||||||
|
|
||||||
async def reflect(
|
async def reflect(
|
||||||
|
|
@ -54,19 +62,25 @@ class PipelineReflector:
|
||||||
if self._llm_gateway is not None:
|
if self._llm_gateway is not None:
|
||||||
try:
|
try:
|
||||||
return await self._llm_reflect(
|
return await self._llm_reflect(
|
||||||
pipeline, failed_stage, error_message,
|
pipeline,
|
||||||
completed_outputs, reflection_number,
|
failed_stage,
|
||||||
|
error_message,
|
||||||
|
completed_outputs,
|
||||||
|
reflection_number,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"LLM reflection failed, falling back to rule-based: {e}")
|
logger.warning(f"LLM reflection failed, falling back to rule-based: {e}")
|
||||||
|
|
||||||
# 规则兜底:基于错误信息分类
|
# 规则兜底:基于错误信息分类
|
||||||
return self._rule_based_reflect(
|
return self._rule_based_reflect(
|
||||||
failed_stage, error_message, reflection_number,
|
failed_stage,
|
||||||
|
error_message,
|
||||||
|
reflection_number,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _find_failure(
|
def _find_failure(
|
||||||
self, result: PipelineResult,
|
self,
|
||||||
|
result: PipelineResult,
|
||||||
) -> tuple[str, str]:
|
) -> tuple[str, str]:
|
||||||
"""找到第一个失败的 stage 及其错误信息。"""
|
"""找到第一个失败的 stage 及其错误信息。"""
|
||||||
for name, sr in result.stage_results.items():
|
for name, sr in result.stage_results.items():
|
||||||
|
|
@ -75,8 +89,9 @@ class PipelineReflector:
|
||||||
return "", "no failed stage found"
|
return "", "no failed stage found"
|
||||||
|
|
||||||
def _collect_completed_outputs(
|
def _collect_completed_outputs(
|
||||||
self, result: PipelineResult,
|
self,
|
||||||
) -> dict[str, Any]:
|
result: PipelineResult,
|
||||||
|
) -> dict[str, object]:
|
||||||
"""收集已完成步骤的输出。"""
|
"""收集已完成步骤的输出。"""
|
||||||
outputs = {}
|
outputs = {}
|
||||||
for name, sr in result.stage_results.items():
|
for name, sr in result.stage_results.items():
|
||||||
|
|
@ -89,13 +104,16 @@ class PipelineReflector:
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
failed_stage: str,
|
failed_stage: str,
|
||||||
error_message: str,
|
error_message: str,
|
||||||
completed_outputs: dict[str, Any],
|
completed_outputs: dict[str, object],
|
||||||
reflection_number: int,
|
reflection_number: int,
|
||||||
) -> ReflectionReport:
|
) -> ReflectionReport:
|
||||||
"""使用 LLM 分析失败原因。"""
|
"""使用 LLM 分析失败原因。"""
|
||||||
prompt = self._build_reflection_prompt(
|
prompt = self._build_reflection_prompt(
|
||||||
pipeline, failed_stage, error_message,
|
pipeline,
|
||||||
completed_outputs, reflection_number,
|
failed_stage,
|
||||||
|
error_message,
|
||||||
|
completed_outputs,
|
||||||
|
reflection_number,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await self._llm_gateway.chat(
|
response = await self._llm_gateway.chat(
|
||||||
|
|
@ -106,7 +124,9 @@ class PipelineReflector:
|
||||||
# 解析 LLM 返回的 JSON
|
# 解析 LLM 返回的 JSON
|
||||||
content = response.content if hasattr(response, "content") else str(response)
|
content = response.content if hasattr(response, "content") else str(response)
|
||||||
return self._parse_reflection_response(
|
return self._parse_reflection_response(
|
||||||
content, failed_stage, reflection_number,
|
content,
|
||||||
|
failed_stage,
|
||||||
|
reflection_number,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _build_reflection_prompt(
|
def _build_reflection_prompt(
|
||||||
|
|
@ -114,15 +134,14 @@ class PipelineReflector:
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
failed_stage: str,
|
failed_stage: str,
|
||||||
error_message: str,
|
error_message: str,
|
||||||
completed_outputs: dict[str, Any],
|
completed_outputs: dict[str, object],
|
||||||
reflection_number: int,
|
reflection_number: int,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""构建反思提示词。"""
|
"""构建反思提示词。"""
|
||||||
stage_descriptions = []
|
stage_descriptions = []
|
||||||
for s in pipeline.stages:
|
for s in pipeline.stages:
|
||||||
stage_descriptions.append(
|
stage_descriptions.append(
|
||||||
f" - {s.name}: agent={s.agent}, action={s.action}, "
|
f" - {s.name}: agent={s.agent}, action={s.action}, depends_on={s.depends_on}"
|
||||||
f"depends_on={s.depends_on}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
completed_summary = json.dumps(
|
completed_summary = json.dumps(
|
||||||
|
|
@ -174,7 +193,9 @@ JSON response:"""
|
||||||
except (json.JSONDecodeError, KeyError) as e:
|
except (json.JSONDecodeError, KeyError) as e:
|
||||||
logger.warning(f"Failed to parse LLM reflection response: {e}")
|
logger.warning(f"Failed to parse LLM reflection response: {e}")
|
||||||
return self._rule_based_reflect(
|
return self._rule_based_reflect(
|
||||||
failed_stage, content, reflection_number,
|
failed_stage,
|
||||||
|
content,
|
||||||
|
reflection_number,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _rule_based_reflect(
|
def _rule_based_reflect(
|
||||||
|
|
@ -218,7 +239,7 @@ class PipelineReplanner:
|
||||||
保留已完成步骤的结果,仅重新规划失败及后续步骤。
|
保留已完成步骤的结果,仅重新规划失败及后续步骤。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, llm_gateway: Any = None):
|
def __init__(self, llm_gateway: _LLMGatewayLike | None = None):
|
||||||
self._llm_gateway = llm_gateway
|
self._llm_gateway = llm_gateway
|
||||||
|
|
||||||
async def replan(
|
async def replan(
|
||||||
|
|
@ -255,8 +276,7 @@ class PipelineReplanner:
|
||||||
) -> Pipeline:
|
) -> Pipeline:
|
||||||
"""使用 LLM 生成修正后的 Pipeline。"""
|
"""使用 LLM 生成修正后的 Pipeline。"""
|
||||||
completed_stages = [
|
completed_stages = [
|
||||||
name for name, sr in result.stage_results.items()
|
name for name, sr in result.stage_results.items() if sr.status == StageStatus.COMPLETED
|
||||||
if sr.status == StageStatus.COMPLETED
|
|
||||||
]
|
]
|
||||||
|
|
||||||
prompt = f"""Based on the reflection report, generate a corrected pipeline.
|
prompt = f"""Based on the reflection report, generate a corrected pipeline.
|
||||||
|
|
@ -284,7 +304,9 @@ JSON pipeline:"""
|
||||||
return self._parse_pipeline_response(content, pipeline)
|
return self._parse_pipeline_response(content, pipeline)
|
||||||
|
|
||||||
def _parse_pipeline_response(
|
def _parse_pipeline_response(
|
||||||
self, content: str, original: Pipeline,
|
self,
|
||||||
|
content: str,
|
||||||
|
original: Pipeline,
|
||||||
) -> Pipeline:
|
) -> Pipeline:
|
||||||
"""解析 LLM 返回的 Pipeline JSON。"""
|
"""解析 LLM 返回的 Pipeline JSON。"""
|
||||||
try:
|
try:
|
||||||
|
|
@ -294,9 +316,7 @@ JSON pipeline:"""
|
||||||
text = "\n".join(lines[1:-1])
|
text = "\n".join(lines[1:-1])
|
||||||
|
|
||||||
data = json.loads(text)
|
data = json.loads(text)
|
||||||
stages = [
|
stages = [PipelineStage(**s) for s in data.get("stages", [])]
|
||||||
PipelineStage(**s) for s in data.get("stages", [])
|
|
||||||
]
|
|
||||||
return Pipeline(
|
return Pipeline(
|
||||||
name=data.get("name", original.name),
|
name=data.get("name", original.name),
|
||||||
version=data.get("version", original.version),
|
version=data.get("version", original.version),
|
||||||
|
|
@ -316,8 +336,7 @@ JSON pipeline:"""
|
||||||
) -> Pipeline:
|
) -> Pipeline:
|
||||||
"""基于规则的兜底重规划。"""
|
"""基于规则的兜底重规划。"""
|
||||||
completed_stages = {
|
completed_stages = {
|
||||||
name for name, sr in result.stage_results.items()
|
name for name, sr in result.stage_results.items() if sr.status == StageStatus.COMPLETED
|
||||||
if sr.status == StageStatus.COMPLETED
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# 构建修正后的 stages 列表
|
# 构建修正后的 stages 列表
|
||||||
|
|
@ -345,17 +364,21 @@ JSON pipeline:"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def _adjust_failed_stage(
|
def _adjust_failed_stage(
|
||||||
self, stage: PipelineStage, report: ReflectionReport,
|
self,
|
||||||
|
stage: PipelineStage,
|
||||||
|
report: ReflectionReport,
|
||||||
) -> PipelineStage:
|
) -> PipelineStage:
|
||||||
"""根据反思报告调整失败的步骤。"""
|
"""根据反思报告调整失败的步骤。"""
|
||||||
adjustments: dict[str, Any] = {}
|
adjustments: dict[str, object] = {}
|
||||||
|
|
||||||
if report.failure_type == "timeout":
|
if report.failure_type == "timeout":
|
||||||
adjustments["timeout_seconds"] = min(
|
adjustments["timeout_seconds"] = min(
|
||||||
stage.timeout_seconds * 2, 3600,
|
stage.timeout_seconds * 2,
|
||||||
|
3600,
|
||||||
)
|
)
|
||||||
if stage.retry_policy is None:
|
if stage.retry_policy is None:
|
||||||
from agentkit.orchestrator.retry import StepRetryPolicy
|
from agentkit.orchestrator.retry import StepRetryPolicy
|
||||||
|
|
||||||
adjustments["retry_policy"] = StepRetryPolicy(max_attempts=2)
|
adjustments["retry_policy"] = StepRetryPolicy(max_attempts=2)
|
||||||
|
|
||||||
elif report.failure_type == "resource_error":
|
elif report.failure_type == "resource_error":
|
||||||
|
|
@ -365,6 +388,7 @@ JSON pipeline:"""
|
||||||
# 添加重试策略,可能输入在后续可用
|
# 添加重试策略,可能输入在后续可用
|
||||||
if stage.retry_policy is None:
|
if stage.retry_policy is None:
|
||||||
from agentkit.orchestrator.retry import StepRetryPolicy
|
from agentkit.orchestrator.retry import StepRetryPolicy
|
||||||
|
|
||||||
adjustments["retry_policy"] = StepRetryPolicy(max_attempts=2)
|
adjustments["retry_policy"] = StepRetryPolicy(max_attempts=2)
|
||||||
|
|
||||||
return stage.model_copy(update=adjustments)
|
return stage.model_copy(update=adjustments)
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Awaitable, Callable
|
from typing import Awaitable, Callable
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -27,7 +27,7 @@ class StepRetryPolicy:
|
||||||
def calculate_delay(self, attempt: int) -> float:
|
def calculate_delay(self, attempt: int) -> float:
|
||||||
"""Calculate delay for given attempt number (0-based)"""
|
"""Calculate delay for given attempt number (0-based)"""
|
||||||
delay = min(
|
delay = min(
|
||||||
self.base_delay * (self.exponential_base ** attempt),
|
self.base_delay * (self.exponential_base**attempt),
|
||||||
self.max_delay,
|
self.max_delay,
|
||||||
)
|
)
|
||||||
if self.jitter:
|
if self.jitter:
|
||||||
|
|
@ -36,10 +36,10 @@ class StepRetryPolicy:
|
||||||
|
|
||||||
|
|
||||||
async def execute_with_retry(
|
async def execute_with_retry(
|
||||||
func: Callable[..., Awaitable[Any]],
|
func: Callable[..., Awaitable[object]],
|
||||||
retry_policy: StepRetryPolicy | None = None,
|
retry_policy: StepRetryPolicy | None = None,
|
||||||
step_name: str = "",
|
step_name: str = "",
|
||||||
) -> Any:
|
) -> object:
|
||||||
"""Execute a function with retry policy"""
|
"""Execute a function with retry policy"""
|
||||||
if retry_policy is None:
|
if retry_policy is None:
|
||||||
return await func()
|
return await func()
|
||||||
|
|
|
||||||
|
|
@ -3,18 +3,17 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage
|
from agentkit.orchestrator.pipeline_schema import PipelineStage
|
||||||
|
|
||||||
|
|
||||||
class WorkflowStage(PipelineStage):
|
class WorkflowStage(PipelineStage):
|
||||||
"""A workflow stage extending PipelineStage with type and config."""
|
"""A workflow stage extending PipelineStage with type and config."""
|
||||||
|
|
||||||
type: str = "skill" # "skill" | "condition" | "approval" | "parallel"
|
type: str = "skill" # "skill" | "condition" | "approval" | "parallel"
|
||||||
config: dict[str, Any] = Field(default_factory=dict)
|
config: dict[str, object] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowDefinition(BaseModel):
|
class WorkflowDefinition(BaseModel):
|
||||||
|
|
@ -24,9 +23,9 @@ class WorkflowDefinition(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
version: int = 1
|
version: int = 1
|
||||||
stages: list[WorkflowStage] = Field(default_factory=list)
|
stages: list[WorkflowStage] = Field(default_factory=list)
|
||||||
triggers: list[dict[str, Any]] = Field(default_factory=list)
|
triggers: list[dict[str, object]] = Field(default_factory=list)
|
||||||
variables_schema: dict[str, Any] = Field(default_factory=dict)
|
variables_schema: dict[str, object] = Field(default_factory=dict)
|
||||||
output_schema: dict[str, Any] = Field(default_factory=dict)
|
output_schema: dict[str, object] = Field(default_factory=dict)
|
||||||
created_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
created_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||||
updated_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
updated_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||||
|
|
||||||
|
|
@ -38,11 +37,11 @@ class WorkflowExecution(BaseModel):
|
||||||
workflow_id: str = ""
|
workflow_id: str = ""
|
||||||
status: str = "pending" # pending|running|paused|completed|failed|cancelled
|
status: str = "pending" # pending|running|paused|completed|failed|cancelled
|
||||||
current_stage: str | None = None
|
current_stage: str | None = None
|
||||||
stage_results: dict[str, Any] = Field(default_factory=dict)
|
stage_results: dict[str, object] = Field(default_factory=dict)
|
||||||
started_at: str | None = None
|
started_at: str | None = None
|
||||||
completed_at: str | None = None
|
completed_at: str | None = None
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
variables: dict[str, Any] = Field(default_factory=dict)
|
variables: dict[str, object] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowSummary(BaseModel):
|
class WorkflowSummary(BaseModel):
|
||||||
|
|
@ -62,15 +61,15 @@ class CreateWorkflowRequest(BaseModel):
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
stages: list[WorkflowStage] = Field(default_factory=list)
|
stages: list[WorkflowStage] = Field(default_factory=list)
|
||||||
triggers: list[dict[str, Any]] = Field(default_factory=list)
|
triggers: list[dict[str, object]] = Field(default_factory=list)
|
||||||
variables_schema: dict[str, Any] = Field(default_factory=dict)
|
variables_schema: dict[str, object] = Field(default_factory=dict)
|
||||||
output_schema: dict[str, Any] = Field(default_factory=dict)
|
output_schema: dict[str, object] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class ExecuteWorkflowRequest(BaseModel):
|
class ExecuteWorkflowRequest(BaseModel):
|
||||||
"""Request body for executing a workflow."""
|
"""Request body for executing a workflow."""
|
||||||
|
|
||||||
variables: dict[str, Any] = Field(default_factory=dict)
|
variables: dict[str, object] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class ApproveRequest(BaseModel):
|
class ApproveRequest(BaseModel):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue