From 7b1b1980586d12c8a825d76f5ca8134f5ccf8648 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Wed, 1 Jul 2026 02:41:14 +0800 Subject: [PATCH] refactor(orchestrator+bitable): remove Any from type signatures Eliminate 112 Any usages across orchestrator/ (62) and bitable/ (50) via: - TYPE_CHECKING Protocol for Redis/LLMGateway/Plan/Dispatcher/StateManager - object for arbitrary dict/list/value types (Pydantic v2 serializes fine) - RecalcTask concrete import (replacing Any in recalc_worker.py) - Coroutine[object, object, object] for async generic - Remove unused Any imports (F401 cleanup) Note: Avoided recursive TypeAlias (FieldValue) because Pydantic v2 cannot build schemas for recursive named aliases (RecursionError). Tests: 245 passed (bitable 91 + orchestrator 154), 0 regressions ruff: All checks passed --- src/agentkit/bitable/formula/engine.py | 17 ++- .../bitable/ingestion/api_collector.py | 10 +- src/agentkit/bitable/ingestion/database.py | 15 ++- src/agentkit/bitable/ingestion/excel.py | 9 +- src/agentkit/bitable/models.py | 16 ++- src/agentkit/bitable/recalc_worker.py | 11 +- src/agentkit/bitable/repository.py | 37 +++--- src/agentkit/orchestrator/checkpoint.py | 46 +++++-- src/agentkit/orchestrator/compensation.py | 20 ++- src/agentkit/orchestrator/dynamic_pipeline.py | 15 +-- src/agentkit/orchestrator/handoff.py | 25 +++- src/agentkit/orchestrator/pipeline_engine.py | 116 ++++++++++++------ src/agentkit/orchestrator/pipeline_loader.py | 5 +- src/agentkit/orchestrator/pipeline_models.py | 4 +- src/agentkit/orchestrator/pipeline_schema.py | 34 +++-- src/agentkit/orchestrator/pipeline_state.py | 4 +- src/agentkit/orchestrator/reflection.py | 82 ++++++++----- src/agentkit/orchestrator/retry.py | 8 +- src/agentkit/orchestrator/workflow_schema.py | 23 ++-- 19 files changed, 299 insertions(+), 198 deletions(-) diff --git a/src/agentkit/bitable/formula/engine.py b/src/agentkit/bitable/formula/engine.py index 589c6f6..9e06e07 100644 --- a/src/agentkit/bitable/formula/engine.py +++ b/src/agentkit/bitable/formula/engine.py @@ -16,7 +16,6 @@ from __future__ import annotations import ast from collections import deque -from typing import Any from agentkit.bitable.formula.functions import AGGREGATE_FUNCTIONS, FUNCTION_REGISTRY from agentkit.bitable.formula.parser import ( @@ -104,9 +103,9 @@ class FormulaEngine: def evaluate( self, field_id: str, - row_values: dict[str, Any], - column_values: dict[str, list[Any]] | None = None, - ) -> Any: + row_values: dict[str, object], + column_values: dict[str, list[object]] | None = None, + ) -> object: """Evaluate a formula field for a specific record. Args: @@ -130,7 +129,7 @@ class FormulaEngine: # Build the field_values dict for the evaluator # Aggregate refs get column values (lists), row refs get row values (scalars) - eval_values: dict[str, Any] = {} + eval_values: dict[str, object] = {} # Map real field IDs to safe names for safe_name, real_id in entry.field_mapping.items(): @@ -143,16 +142,16 @@ class FormulaEngine: def evaluate_all_for_record( self, - row_values: dict[str, Any], - column_values: dict[str, list[Any]] | None = None, - ) -> dict[str, Any]: + row_values: dict[str, object], + column_values: dict[str, list[object]] | None = None, + ) -> dict[str, object]: """Evaluate all registered formulas for a record. Returns a dict of field_id → computed value. Formulas are evaluated in topological order so that formula-to-formula dependencies are resolved correctly. """ - results: dict[str, Any] = {} + results: dict[str, object] = {} column_values = column_values or {} for field_id in self.topological_order(): diff --git a/src/agentkit/bitable/ingestion/api_collector.py b/src/agentkit/bitable/ingestion/api_collector.py index b14d92a..3dc8f59 100644 --- a/src/agentkit/bitable/ingestion/api_collector.py +++ b/src/agentkit/bitable/ingestion/api_collector.py @@ -16,13 +16,11 @@ Usage:: from __future__ import annotations -from typing import Any - def transform_records( - records: list[dict[str, Any]], + records: list[dict[str, object]], field_mapping: dict[str, str], -) -> list[dict[str, Any]]: +) -> list[dict[str, object]]: """Map source record keys to bitable field IDs via field_mapping. Keys not in ``field_mapping`` are dropped. Values are passed through @@ -40,9 +38,9 @@ def transform_records( if not field_mapping: return [] - transformed: list[dict[str, Any]] = [] + transformed: list[dict[str, object]] = [] for rec in records: - out: dict[str, Any] = {} + out: dict[str, object] = {} for src_key, field_id in field_mapping.items(): if src_key in rec: out[field_id] = rec[src_key] diff --git a/src/agentkit/bitable/ingestion/database.py b/src/agentkit/bitable/ingestion/database.py index 9f8d9fe..78f56db 100644 --- a/src/agentkit/bitable/ingestion/database.py +++ b/src/agentkit/bitable/ingestion/database.py @@ -16,7 +16,6 @@ Type mapping (KTD: DB → bitable): from __future__ import annotations import logging -from typing import Any from sqlalchemy import ( BigInteger, @@ -56,7 +55,7 @@ DB_TYPE_MAP: dict[type, str] = { READ_BATCH = 1000 -def infer_field_type(sqla_type: Any) -> str: +def infer_field_type(sqla_type: object) -> str: """Map a SQLAlchemy column type instance or class to a bitable field type. Handles both type instances (``Integer()``) and type classes (``Integer``). @@ -78,7 +77,7 @@ def import_table( table_name: str, *, max_rows: int = 50_000, -) -> dict[str, Any]: +) -> dict[str, object]: """Reflect a single table from an external DB. Returns ``{"table_name": str, "fields": [...], "records": [...], @@ -97,7 +96,7 @@ def import_table( engine.dispose() -def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[str, Any]: +def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[str, object]: """Reflect one table and read its rows.""" insp = inspect(engine) @@ -111,7 +110,7 @@ def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[st table = Table(table_name, metadata, autoload_with=engine) # Build field definitions - fields: list[dict[str, Any]] = [] + fields: list[dict[str, object]] = [] pk_columns = list(table.primary_key.columns) pk_name = pk_columns[0].name if pk_columns else None @@ -131,14 +130,14 @@ def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[st pk_name = "id" # Read rows - records: list[dict[str, Any]] = [] + records: list[dict[str, object]] = [] with engine.connect() as conn: result = conn.execute(select(table)) for i, row in enumerate(result): if i >= max_rows: logger.warning("Table %r truncated at %d rows during import", table_name, max_rows) break - rec: dict[str, Any] = {} + rec: dict[str, object] = {} for col in table.columns: val = getattr(row, col.name, None) if val is not None: @@ -155,7 +154,7 @@ def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[st } -def _serialize(val: Any) -> Any: +def _serialize(val: object) -> object: """Serialize a DB value to JSON-safe form.""" from datetime import date, datetime from decimal import Decimal diff --git a/src/agentkit/bitable/ingestion/excel.py b/src/agentkit/bitable/ingestion/excel.py index 34365d4..d92af55 100644 --- a/src/agentkit/bitable/ingestion/excel.py +++ b/src/agentkit/bitable/ingestion/excel.py @@ -18,7 +18,6 @@ import logging import socket from dataclasses import dataclass, field from pathlib import Path -from typing import Any from urllib.parse import urlparse import httpx @@ -36,7 +35,7 @@ class ParsedSheet: name: str columns: list[str] = field(default_factory=list) field_types: list[str] = field(default_factory=list) # "text" | "number" | "date" - records: list[dict[str, Any]] = field(default_factory=list) + records: list[dict[str, object]] = field(default_factory=list) def parse_excel(file_path: str | Path) -> list[ParsedSheet]: @@ -182,9 +181,9 @@ def _parse_worksheet(ws) -> ParsedSheet | None: col_count = len(clean_headers) field_types = _infer_column_types(data_rows, col_count) - records: list[dict[str, Any]] = [] + records: list[dict[str, object]] = [] for row in data_rows: - rec: dict[str, Any] = {} + rec: dict[str, object] = {} for i, col_name in enumerate(clean_headers): val = row[i] if i < len(row) else None if val is not None: @@ -237,7 +236,7 @@ def _infer_column_types(rows: list[tuple], col_count: int) -> list[str]: return types -def _coerce_value(val: Any, field_type: str) -> Any: +def _coerce_value(val: object, field_type: str) -> object: """Coerce a cell value to the inferred field type. Truncate long strings.""" if field_type == "date": from datetime import datetime diff --git a/src/agentkit/bitable/models.py b/src/agentkit/bitable/models.py index 2f1ea76..db8d90e 100644 --- a/src/agentkit/bitable/models.py +++ b/src/agentkit/bitable/models.py @@ -9,10 +9,14 @@ from __future__ import annotations from datetime import datetime, timezone from enum import Enum -from typing import Any from pydantic import BaseModel, ConfigDict, Field as PydanticField +# ponytail: bitable JSONB columns hold arbitrary JSON. Using `object` instead of +# a recursive TypeAlias because Pydantic v2 cannot build a schema for recursive +# named aliases (RecursionError). `object` is the most permissive type and +# Pydantic v2 serializes dict/list/primitive values fine at runtime. + def _utcnow() -> datetime: return datetime.now(timezone.utc) @@ -97,7 +101,7 @@ class Table(BaseModel): # --------------------------------------------------------------------------- # Status select field options — labels and colors match Feishu Bitable defaults. -_STATUS_OPTIONS: list[dict[str, Any]] = [ +_STATUS_OPTIONS: list[dict[str, object]] = [ {"label": "未开始", "value": "not_started", "color": "default"}, {"label": "进行中", "value": "in_progress", "color": "processing"}, {"label": "已完成", "value": "done", "color": "success"}, @@ -106,7 +110,7 @@ _STATUS_OPTIONS: list[dict[str, Any]] = [ #: Templates for the 5 default fields created on every new table (R2). #: agent-owned fields (创建人/创建时间) are auto-filled by the service layer #: on record creation; user-owned fields are user-editable. -DEFAULT_FIELD_TEMPLATES: list[dict[str, Any]] = [ +DEFAULT_FIELD_TEMPLATES: list[dict[str, object]] = [ { "name": "标题", "field_type": FieldType.text, @@ -155,7 +159,7 @@ class Field(BaseModel): table_id: str name: str field_type: FieldType - config: dict[str, Any] = PydanticField(default_factory=dict) + config: dict[str, object] = PydanticField(default_factory=dict) owner: FieldOwner = FieldOwner.user created_at: datetime = PydanticField(default_factory=_utcnow) @@ -167,7 +171,7 @@ class Record(BaseModel): id: str table_id: str - values: dict[str, Any] = PydanticField(default_factory=dict) + values: dict[str, object] = PydanticField(default_factory=dict) created_at: datetime = PydanticField(default_factory=_utcnow) updated_at: datetime = PydanticField(default_factory=_utcnow) @@ -181,7 +185,7 @@ class View(BaseModel): table_id: str name: str view_type: ViewType = ViewType.grid - config: dict[str, Any] = PydanticField(default_factory=dict) + config: dict[str, object] = PydanticField(default_factory=dict) created_at: datetime = PydanticField(default_factory=_utcnow) diff --git a/src/agentkit/bitable/recalc_worker.py b/src/agentkit/bitable/recalc_worker.py index 0976b32..d6b8497 100644 --- a/src/agentkit/bitable/recalc_worker.py +++ b/src/agentkit/bitable/recalc_worker.py @@ -18,11 +18,10 @@ from __future__ import annotations import asyncio import logging -from typing import Any from agentkit.bitable.db import BitableDB from agentkit.bitable.formula.engine import FormulaEngine -from agentkit.bitable.models import FieldType, RecalcStatus +from agentkit.bitable.models import FieldType, RecalcStatus, RecalcTask from agentkit.bitable.repository import BitableRepository from agentkit.bitable.service import BitableService @@ -124,7 +123,7 @@ class RecalcWorker: logger.exception("RecalcWorker error in main loop") await asyncio.sleep(self._poll_interval) - async def _sort_by_topological_order(self, tasks: list[Any]) -> list[Any]: + async def _sort_by_topological_order(self, tasks: list[RecalcTask]) -> list[RecalcTask]: """Sort claimed tasks so dependencies are processed first (P1 #7). Groups tasks by table_id, builds (or reuses) the engine to get the @@ -146,7 +145,7 @@ class RecalcWorker: order = engine.topological_order() topo_index[tid] = {fid: i for i, fid in enumerate(order)} - def _key(t: Any) -> tuple[str, str, int]: + def _key(t: RecalcTask) -> tuple[str, str, int]: idx = topo_index.get(t.table_id, {}).get(t.field_id, 1 << 30) return (t.table_id, t.record_id, idx) @@ -175,7 +174,7 @@ class RecalcWorker: except Exception: logger.exception("RecalcWorker reaper error") - async def process_task(self, task: Any) -> None: + async def process_task(self, task: RecalcTask) -> None: """Process a single recalc task: evaluate formula → write result. The task is expected to already be in ``calculating`` status when @@ -216,7 +215,7 @@ class RecalcWorker: return deps = engine.get_dependencies(task.field_id) - column_values: dict[str, list[Any]] = {} + column_values: dict[str, list[object]] = {} for dep_field_id in deps: column_values[dep_field_id] = await self._repo.get_column_values( task.table_id, dep_field_id diff --git a/src/agentkit/bitable/repository.py b/src/agentkit/bitable/repository.py index 2560dd4..8278ff5 100644 --- a/src/agentkit/bitable/repository.py +++ b/src/agentkit/bitable/repository.py @@ -10,7 +10,6 @@ from __future__ import annotations import logging import re from datetime import datetime, timedelta, timezone -from typing import Any from sqlalchemy import delete, func, insert, select, text, update from sqlalchemy.dialects.postgresql import insert as pg_insert @@ -102,7 +101,7 @@ class BitableRepository: result = await session.execute(stmt) return [BitableFile.model_validate(e) for e in result.scalars().all()] - async def update_file(self, file_id: str, **kwargs: Any) -> BitableFile | None: + async def update_file(self, file_id: str, **kwargs: object) -> BitableFile | None: """Update a file's attributes.""" async with self._session_factory() as session: stmt = ( @@ -181,7 +180,7 @@ class BitableRepository: result = await session.execute(stmt) return [Table.model_validate(e) for e in result.scalars().all()] - async def update_table(self, table_id: str, **kwargs: Any) -> Table | None: + async def update_table(self, table_id: str, **kwargs: object) -> Table | None: """Update a table's attributes.""" async with self._session_factory() as session: stmt = ( @@ -236,7 +235,7 @@ class BitableRepository: table_id: str, name: str, field_type: FieldType, - config: dict[str, Any] | None = None, + config: dict[str, object] | None = None, owner: FieldOwner = FieldOwner.user, ) -> Field: """Create a new field in a table.""" @@ -277,7 +276,7 @@ class BitableRepository: result = await session.execute(stmt) return [Field.model_validate(e) for e in result.scalars().all()] - async def update_field(self, field_id: str, **kwargs: Any) -> Field | None: + async def update_field(self, field_id: str, **kwargs: object) -> Field | None: """Update a field's attributes.""" async with self._session_factory() as session: stmt = ( @@ -300,7 +299,7 @@ class BitableRepository: # ── Records ───────────────────────────────────────────── - async def create_record(self, table_id: str, values: dict[str, Any] | None = None) -> Record: + async def create_record(self, table_id: str, values: dict[str, object] | None = None) -> Record: """Create a new record.""" async with self._session_factory() as session: stmt = ( @@ -318,7 +317,7 @@ class BitableRepository: return Record.model_validate(entity) async def create_records_batch( - self, table_id: str, records_values: list[dict[str, Any]] + self, table_id: str, records_values: list[dict[str, object]] ) -> list[Record]: """Batch-insert multiple records (P2 #19: eliminates per-record INSERT). @@ -376,7 +375,7 @@ class BitableRepository: return [Record.model_validate(e) for e in entities], next_cursor - async def update_record_values(self, record_id: str, values: dict[str, Any]) -> Record | None: + async def update_record_values(self, record_id: str, values: dict[str, object]) -> Record | None: """Update a record's values (full replace).""" async with self._session_factory() as session: stmt = ( @@ -413,7 +412,7 @@ class BitableRepository: table_id: str, name: str, view_type: ViewType = ViewType.grid, - config: dict[str, Any] | None = None, + config: dict[str, object] | None = None, ) -> View: """Create a new view.""" async with self._session_factory() as session: @@ -451,7 +450,7 @@ class BitableRepository: result = await session.execute(stmt) return [View.model_validate(e) for e in result.scalars().all()] - async def update_view(self, view_id: str, **kwargs: Any) -> View | None: + async def update_view(self, view_id: str, **kwargs: object) -> View | None: """Update a view's attributes.""" async with self._session_factory() as session: stmt = ( @@ -543,7 +542,7 @@ class BitableRepository: ) -> None: """Update a recalc task's status.""" async with self._session_factory() as session: - kwargs: dict[str, Any] = {"status": status.value} + kwargs: dict[str, object] = {"status": status.value} if error_message is not None: kwargs["error_message"] = error_message if status in (RecalcStatus.done, RecalcStatus.error): @@ -630,7 +629,7 @@ class BitableRepository: return result_map async def upsert_record_agent_fields( - self, record_id: str, agent_field_values: dict[str, Any] + self, record_id: str, agent_field_values: dict[str, object] ) -> None: """Update agent-owned fields using jsonb_set (KTD8). @@ -646,7 +645,7 @@ class BitableRepository: # Use CAST(:param AS jsonb) instead of :param::jsonb — asyncpg dialect # misparses the `::` as part of the param name. inner = "values" - params: dict[str, Any] = {"record_id": record_id} + params: dict[str, object] = {"record_id": record_id} for i, (field_id, value) in enumerate(agent_field_values.items()): param_key = f"v{i}" inner = f"jsonb_set({inner}, '{{{field_id}}}', CAST(:{param_key} AS jsonb), true)" @@ -660,8 +659,8 @@ class BitableRepository: async def list_records_filtered( self, table_id: str, - filters: list[dict[str, Any]] | None = None, - sorts: list[dict[str, Any]] | None = None, + filters: list[dict[str, object]] | None = None, + sorts: list[dict[str, object]] | None = None, cursor: str | None = None, limit: int = 50, ) -> tuple[list[Record], str | None]: @@ -682,7 +681,7 @@ class BitableRepository: # Build raw SQL with JSONB filter/sort translation. # ponytail: field_ids in filters/sorts are system UUIDs (validated by service layer). where_clauses = ["table_id = :table_id"] - params: dict[str, Any] = {"table_id": table_id} + params: dict[str, object] = {"table_id": table_id} if filters: for i, f in enumerate(filters): @@ -783,7 +782,7 @@ class BitableRepository: last_mapping = last_row._mapping # Build composite cursor from sort values + id. # Sort values are extracted as text to match `values->>'fid'` expressions. - sv: list[Any] = [] + sv: list[object] = [] last_values = last_mapping.get("values") if isinstance(last_values, str): # asyncpg may return JSONB as str in raw text() queries. @@ -852,7 +851,7 @@ class BitableRepository: # ── Recalc support (U3) ──────────────────────────────── - async def get_column_values(self, table_id: str, field_id: str) -> list[Any]: + async def get_column_values(self, table_id: str, field_id: str) -> list[object]: """Get all values for a field across all records in a table (for aggregates). Returns a list of values (preserving order by record id). Missing values @@ -866,7 +865,7 @@ class BitableRepository: result = await session.execute(sql, {"field_id": field_id, "table_id": table_id}) return [row[0] for row in result.fetchall()] - async def set_formula_value(self, record_id: str, field_id: str, value: Any) -> None: + async def set_formula_value(self, record_id: str, field_id: str, value: object) -> None: """Set a single formula field value in a record's JSONB (jsonb_set).""" import json diff --git a/src/agentkit/orchestrator/checkpoint.py b/src/agentkit/orchestrator/checkpoint.py index 856e9cf..9cee4d0 100644 --- a/src/agentkit/orchestrator/checkpoint.py +++ b/src/agentkit/orchestrator/checkpoint.py @@ -17,7 +17,7 @@ import logging import time from dataclasses import asdict, dataclass, field from datetime import datetime, timezone -from typing import Any +from typing import TYPE_CHECKING logger = logging.getLogger(__name__) @@ -25,6 +25,28 @@ _TTL_SECONDS = 7 * 24 * 3600 # 7 days _KEY_PREFIX = "agentkit:pipeline:checkpoint" +if TYPE_CHECKING: + from typing import Protocol + + class _RedisPipelineLike(Protocol): + def set(self, key: str, value: str, ex: int | None = None) -> object: ... + def zadd(self, name: str, mapping: dict[str, float]) -> object: ... + def get(self, key: str) -> object: ... + def delete(self, *keys: str) -> object: ... + async def execute(self) -> list[object]: ... + + class _RedisLike(Protocol): + async def set(self, key: str, value: str, ex: int | None = None) -> object: ... + async def get(self, key: str) -> object: ... + def pipeline(self) -> _RedisPipelineLike: ... + async def zrange(self, name: str, start: int, stop: int) -> list[object]: ... + + class _PlanLike(Protocol): + @property + def id(self) -> str: ... + def to_dict(self) -> dict[str, object]: ... + + @dataclass class CheckpointData: """单个阶段的 checkpoint 数据。""" @@ -33,15 +55,15 @@ class CheckpointData: phase_id: str phase_name: str phase_status: str - phase_result: dict[str, Any] | None = None + phase_result: dict[str, object] | None = None plan_status: str = "" saved_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> dict[str, object]: return asdict(self) @classmethod - def from_dict(cls, data: dict[str, Any]) -> CheckpointData: + def from_dict(cls, data: dict[str, object]) -> CheckpointData: return cls( plan_id=data.get("plan_id", ""), phase_id=data.get("phase_id", ""), @@ -67,7 +89,7 @@ class PipelineCheckpoint: def __init__( self, - redis_client: Any = None, + redis_client: _RedisLike | None = None, prefix: str = _KEY_PREFIX, ttl_seconds: int = _TTL_SECONDS, ) -> None: @@ -78,7 +100,7 @@ class PipelineCheckpoint: # P1 #6: 改用 dict keyed by phase_id,避免重复 append self._memory: dict[str, dict[str, CheckpointData]] = {} # 内存降级存储:plan_id → (plan_dict, saved_timestamp) - self._memory_plans: dict[str, tuple[dict[str, Any], float]] = {} + self._memory_plans: dict[str, tuple[dict[str, object], float]] = {} def _is_expired(self, saved_at: str) -> bool: """检查 checkpoint 是否已过期(内存模式 TTL)。""" @@ -102,7 +124,7 @@ class PipelineCheckpoint: """完整 plan JSON 的存储键。""" return f"{self._prefix}:plan:{plan_id}" - async def save_plan(self, plan: Any) -> None: + async def save_plan(self, plan: _PlanLike) -> None: """保存完整 TeamPlan(用于 resume 重建)。 Args: @@ -121,7 +143,7 @@ class PipelineCheckpoint: except Exception as e: logger.warning(f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}") - async def load_plan(self, plan_id: str) -> dict[str, Any] | None: + async def load_plan(self, plan_id: str) -> dict[str, object] | None: """加载完整 plan JSON。""" # 优先 Redis if self._redis is not None: @@ -142,7 +164,7 @@ class PipelineCheckpoint: return None return plan_dict - async def save(self, plan_id: str, phase: Any, plan_status: str) -> None: + async def save(self, plan_id: str, phase: object, plan_status: str) -> None: """保存阶段 checkpoint。 Args: @@ -212,7 +234,8 @@ class PipelineCheckpoint: if not phase_ids: # Redis 无数据,检查内存(过滤过期) return [ - c for c in self._memory.get(plan_id, {}).values() + c + for c in self._memory.get(plan_id, {}).values() if not self._is_expired(c.saved_at) ] @@ -236,8 +259,7 @@ class PipelineCheckpoint: # 内存降级(过滤过期 checkpoint) return [ - c for c in self._memory.get(plan_id, {}).values() - if not self._is_expired(c.saved_at) + c for c in self._memory.get(plan_id, {}).values() if not self._is_expired(c.saved_at) ] async def clear(self, plan_id: str) -> None: diff --git a/src/agentkit/orchestrator/compensation.py b/src/agentkit/orchestrator/compensation.py index 87eef65..92cd06a 100644 --- a/src/agentkit/orchestrator/compensation.py +++ b/src/agentkit/orchestrator/compensation.py @@ -1,8 +1,8 @@ """Saga compensation pattern for Pipeline execution""" import logging -from dataclasses import dataclass, field -from typing import Any, Awaitable, Callable +from dataclasses import dataclass +from typing import Awaitable, Callable logger = logging.getLogger(__name__) @@ -12,7 +12,7 @@ class CompletedStep: """Record of a completed step with its compensation""" step_name: str - result: Any + result: object compensate_action: str | None = None @@ -28,9 +28,7 @@ class CompensationResult: class SagaOrchestrator: """Orchestrates LIFO compensation for failed pipelines""" - def __init__( - self, execute_skill_func: Callable[..., Awaitable[Any]] | None = None - ): + def __init__(self, execute_skill_func: Callable[..., Awaitable[object]] | None = None): """ Args: execute_skill_func: Async function to execute a skill by name @@ -42,7 +40,7 @@ class SagaOrchestrator: def record_completed( self, step_name: str, - result: Any, + result: object, compensate_action: str | None = None, ): """Record a completed step for potential compensation""" @@ -59,9 +57,7 @@ class SagaOrchestrator: results: list[CompensationResult] = [] for step in reversed(self._completed_steps): if step.compensate_action is None: - logger.info( - f"No compensation for step '{step.step_name}', skipping" - ) + logger.info(f"No compensation for step '{step.step_name}', skipping") results.append( CompensationResult( step_name=step.step_name, @@ -82,9 +78,7 @@ class SagaOrchestrator: ) ) except Exception as e: - logger.error( - f"Compensation for step '{step.step_name}' failed: {e}" - ) + logger.error(f"Compensation for step '{step.step_name}' failed: {e}") results.append( CompensationResult( step_name=step.step_name, diff --git a/src/agentkit/orchestrator/dynamic_pipeline.py b/src/agentkit/orchestrator/dynamic_pipeline.py index a6b8e51..cf82500 100644 --- a/src/agentkit/orchestrator/dynamic_pipeline.py +++ b/src/agentkit/orchestrator/dynamic_pipeline.py @@ -4,7 +4,6 @@ """ import logging -from typing import Any from agentkit.orchestrator.pipeline_engine import PipelineEngine from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineResult, StageStatus @@ -15,7 +14,7 @@ logger = logging.getLogger(__name__) class DynamicPipeline: """动态 Pipeline 组合器""" - def __init__(self, engine: PipelineEngine, loader: Any = None): + def __init__(self, engine: PipelineEngine, loader: object | None = None): self._engine = engine self._loader = loader @@ -23,7 +22,7 @@ class DynamicPipeline: self, pipelines: dict[str, Pipeline], condition_key: str, - context: dict[str, Any] | None = None, + context: dict[str, object] | None = None, ) -> PipelineResult: """根据条件选择子 Pipeline 执行""" context = context or {} @@ -37,14 +36,16 @@ class DynamicPipeline: ) selected = pipelines[condition_value] - logger.info(f"DynamicPipeline selected '{selected.name}' for {condition_key}={condition_value}") + logger.info( + f"DynamicPipeline selected '{selected.name}' for {condition_key}={condition_value}" + ) return await self._engine.execute(selected, context) async def execute_nested( self, parent: Pipeline, sub_pipeline_map: dict[str, Pipeline], - context: dict[str, Any] | None = None, + context: dict[str, object] | None = None, ) -> PipelineResult: """执行嵌套 Pipeline""" # 先执行父 Pipeline @@ -52,7 +53,7 @@ class DynamicPipeline: # 根据父 Pipeline 结果选择子 Pipeline for stage_name, stage_result in parent_result.stage_results.items(): - if hasattr(stage_result, 'output_data') and stage_result.output_data: + if hasattr(stage_result, "output_data") and stage_result.output_data: sub_pipeline_name = stage_result.output_data.get("sub_pipeline") if sub_pipeline_name and sub_pipeline_name in sub_pipeline_map: sub = sub_pipeline_map[sub_pipeline_name] @@ -66,7 +67,7 @@ class DynamicPipeline: pipeline: Pipeline, max_iterations: int = 5, exit_condition: str = "done", - context: dict[str, Any] | None = None, + context: dict[str, object] | None = None, ) -> PipelineResult: """循环执行 Pipeline 直到条件满足""" current_context = context or {} diff --git a/src/agentkit/orchestrator/handoff.py b/src/agentkit/orchestrator/handoff.py index cc13631..4c10751 100644 --- a/src/agentkit/orchestrator/handoff.py +++ b/src/agentkit/orchestrator/handoff.py @@ -3,25 +3,42 @@ import asyncio import json import logging -from typing import Any +from typing import Awaitable, Callable, Protocol from agentkit.core.protocol import HandoffMessage logger = logging.getLogger(__name__) +class _RedisPubSubLike(Protocol): + """Structural type for Redis pubsub object.""" + + async def subscribe(self, channel: str) -> None: ... + async def unsubscribe(self, channel: str) -> None: ... + def listen(self) -> object: ... + + +class _RedisLike(Protocol): + """Structural type for async Redis client.""" + + async def publish(self, channel: str, message: str) -> int: ... + def pubsub(self) -> _RedisPubSubLike: ... + + class HandoffManager: """Handoff 管理器 通过 Redis Pub/Sub 管理 Agent 间的任务转交。 """ - def __init__(self, redis: Any = None, dispatcher: Any = None): + def __init__(self, redis: _RedisLike | None = None, dispatcher: object | None = None): self._redis = redis self._dispatcher = dispatcher - self._handlers: dict[str, list[Any]] = {} + self._handlers: dict[str, list[Callable[[HandoffMessage], Awaitable[None]]]] = {} - def register_handler(self, agent_name: str, handler: Any) -> None: + def register_handler( + self, agent_name: str, handler: Callable[[HandoffMessage], Awaitable[None]] + ) -> None: """注册 Handoff 处理器""" if agent_name not in self._handlers: self._handlers[agent_name] = [] diff --git a/src/agentkit/orchestrator/pipeline_engine.py b/src/agentkit/orchestrator/pipeline_engine.py index a8b0bd2..ec4c903 100644 --- a/src/agentkit/orchestrator/pipeline_engine.py +++ b/src/agentkit/orchestrator/pipeline_engine.py @@ -1,10 +1,12 @@ """Pipeline Engine - DAG + 并行执行 + 步骤重试 + Saga 补偿""" +from __future__ import annotations + import asyncio import logging from collections import defaultdict from datetime import datetime, timezone -from typing import Any +from typing import TYPE_CHECKING from agentkit.orchestrator.compensation import SagaOrchestrator from agentkit.orchestrator.pipeline_schema import ( @@ -25,6 +27,23 @@ from agentkit.orchestrator.retry import execute_with_retry logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from typing import Protocol + + class _DispatcherLike(Protocol): + async def dispatch(self, task: object) -> None: ... + async def get_task_status(self, task_id: str) -> dict[str, object]: ... + + class _StateManagerLike(Protocol): + async def create_execution(self, **kwargs: object) -> str: ... + async def update_step(self, **kwargs: object) -> None: ... + async def complete_execution(self, **kwargs: object) -> None: ... + async def fail_execution(self, **kwargs: object) -> None: ... + + class _LLMGatewayLike(Protocol): + async def chat(self, **kwargs: object) -> object: ... + + class PipelineEngine: """Pipeline 执行引擎 @@ -38,7 +57,12 @@ class PipelineEngine: - 状态持久化(可选) """ - def __init__(self, dispatcher: Any = None, state_manager: Any = None, llm_gateway: Any = None): + def __init__( + self, + dispatcher: _DispatcherLike | None = None, + state_manager: _StateManagerLike | None = None, + llm_gateway: _LLMGatewayLike | None = None, + ): self._dispatcher = dispatcher self._state_manager = state_manager self._llm_gateway = llm_gateway @@ -46,7 +70,7 @@ class PipelineEngine: async def execute( self, pipeline: Pipeline, - context: dict[str, Any] | None = None, + context: dict[str, object] | None = None, adaptive_config: AdaptiveConfig | None = None, ) -> PipelineResult: """执行 Pipeline @@ -68,7 +92,7 @@ class PipelineEngine: async def _adaptive_loop( self, pipeline: Pipeline, - context: dict[str, Any] | None, + context: dict[str, object] | None, failed_result: PipelineResult, adaptive_config: AdaptiveConfig, ) -> PipelineResult: @@ -92,34 +116,30 @@ class PipelineEngine: # Replan new_pipeline = await replanner.replan(current_pipeline, current_result, report) - logger.info(f"Pipeline replanned: {new_pipeline.name} ({len(new_pipeline.stages)} stages)") + logger.info( + f"Pipeline replanned: {new_pipeline.name} ({len(new_pipeline.stages)} stages)" + ) # Re-execute current_result = await self._execute_pipeline(new_pipeline, context) current_pipeline = new_pipeline # Record reflection in metadata - current_result.metadata["reflections"] = [ - r.model_dump() for r in reflections - ] + current_result.metadata["reflections"] = [r.model_dump() for r in reflections] if current_result.status == StageStatus.COMPLETED: logger.info(f"Pipeline succeeded after {reflection_num} reflection(s)") return current_result # Exhausted reflections - logger.warning( - f"Pipeline failed after {adaptive_config.max_reflections} reflection(s)" - ) - current_result.metadata["reflections"] = [ - r.model_dump() for r in reflections - ] + logger.warning(f"Pipeline failed after {adaptive_config.max_reflections} reflection(s)") + current_result.metadata["reflections"] = [r.model_dump() for r in reflections] return current_result async def _execute_pipeline( self, pipeline: Pipeline, - context: dict[str, Any] | None = None, + context: dict[str, object] | None = None, ) -> PipelineResult: """执行 Pipeline 的核心逻辑(不含反思-重规划)。""" result = PipelineResult(pipeline_name=pipeline.name) @@ -151,7 +171,9 @@ class PipelineEngine: # 逐层执行 for level, stages in enumerate(level_groups): - logger.info(f"Pipeline '{pipeline.name}' executing level {level} with {len(stages)} stage(s)") + logger.info( + f"Pipeline '{pipeline.name}' executing level {level} with {len(stages)} stage(s)" + ) # 并行执行同层 stages tasks = [] @@ -173,9 +195,11 @@ class PipelineEngine: # Update step state if self._state_manager is not None and execution_id is not None: try: - step_status = "completed" if sr.status == StageStatus.COMPLETED else sr.status.value - step_output = sr.output_data if hasattr(sr, 'output_data') else None - step_error = sr.error_message if hasattr(sr, 'error_message') else None + step_status = ( + "completed" if sr.status == StageStatus.COMPLETED else sr.status.value + ) + step_output = sr.output_data if hasattr(sr, "output_data") else None + step_error = sr.error_message if hasattr(sr, "error_message") else None await self._state_manager.update_step( execution_id=execution_id, step_name=stage.name, @@ -189,19 +213,21 @@ class PipelineEngine: # 收集输出变量 if sr.output_data and isinstance(sr, dict): pass - elif hasattr(sr, 'output_data') and sr.output_data: + elif hasattr(sr, "output_data") and sr.output_data: for output_key in stage.outputs: if output_key in sr.output_data: result.variables[output_key] = sr.output_data[output_key] # 检查是否需要中止 - if hasattr(sr, 'status') and sr.status == StageStatus.FAILED: + if hasattr(sr, "status") and sr.status == StageStatus.FAILED: if not stage.continue_on_failure: # Execute Saga compensation for completed steps compensation_results = await saga.compensate() if compensation_results: failed_compensations = [ - cr for cr in compensation_results if not cr.success and cr.error != "no_compensation_needed" + cr + for cr in compensation_results + if not cr.success and cr.error != "no_compensation_needed" ] if failed_compensations: logger.warning( @@ -219,7 +245,12 @@ class PipelineEngine: step_name=stage.name, error=result.error_message, ) - except (asyncio.TimeoutError, ConnectionError, RuntimeError, ValueError) as exc: + except ( + asyncio.TimeoutError, + ConnectionError, + RuntimeError, + ValueError, + ) as exc: logger.warning(f"Failed to persist failure state: {exc}") return result @@ -252,7 +283,9 @@ class PipelineEngine: started_at = datetime.now(timezone.utc).isoformat() # 条件检查 - if stage.condition and not self._evaluate_condition(stage.condition, pipeline_result.variables): + if stage.condition and not self._evaluate_condition( + stage.condition, pipeline_result.variables + ): return StageResult( stage_name=stage.name, status=StageStatus.SKIPPED, @@ -312,7 +345,9 @@ class PipelineEngine: if status["status"] in ("completed", "failed", "cancelled"): return StageResult( stage_name=stage.name, - status=StageStatus.COMPLETED if status["status"] == "completed" else StageStatus.FAILED, + status=StageStatus.COMPLETED + if status["status"] == "completed" + else StageStatus.FAILED, output_data=status.get("output_data"), error_message=status.get("error_message"), started_at=started_at, @@ -406,7 +441,7 @@ class PipelineEngine: return resolved @staticmethod - def _get_nested(data: dict, path: str) -> Any: + def _get_nested(data: dict, path: str) -> object: keys = path.split(".") current = data for key in keys: @@ -497,9 +532,7 @@ class PipelineEngine: if verifier_feedback.passed: # 审查通过,返回成功结果 - logger.info( - f"Stage '{stage.name}' passed review in round {round_num}" - ) + logger.info(f"Stage '{stage.name}' passed review in round {round_num}") worker_result.output_data = worker_result.output_data or {} worker_result.output_data["adversarial_metadata"] = { "passed_round": round_num, @@ -553,7 +586,7 @@ class PipelineEngine: self, agent_name: str, action: str, - input_data: dict[str, Any], + input_data: dict[str, object], stage: PipelineStage, started_at: str, timeout_seconds: int | None = None, @@ -568,7 +601,9 @@ class PipelineEngine: started_at: 开始时间 timeout_seconds: 独立超时时间,不传则使用 stage.timeout_seconds """ - effective_timeout = timeout_seconds if timeout_seconds is not None else stage.timeout_seconds + effective_timeout = ( + timeout_seconds if timeout_seconds is not None else stage.timeout_seconds + ) if self._dispatcher is None: # Dry-run 模式 return StageResult( @@ -602,7 +637,9 @@ class PipelineEngine: if status["status"] in ("completed", "failed", "cancelled"): return StageResult( stage_name=stage.name, - status=StageStatus.COMPLETED if status["status"] == "completed" else StageStatus.FAILED, + status=StageStatus.COMPLETED + if status["status"] == "completed" + else StageStatus.FAILED, output_data=status.get("output_data"), error_message=status.get("error_message"), started_at=started_at, @@ -639,7 +676,7 @@ class PipelineEngine: async def _execute_verifier( self, verifier_name: str, - worker_output: dict[str, Any], + worker_output: dict[str, object], stage: PipelineStage, started_at: str, ) -> ReviewFeedback: @@ -679,10 +716,7 @@ class PipelineEngine: try: feedback = ReviewFeedback( passed=output_data.get("passed", False), - issues=[ - ReviewIssue(**issue) - for issue in output_data.get("issues", []) - ], + issues=[ReviewIssue(**issue) for issue in output_data.get("issues", [])], summary=output_data.get("summary", "No summary provided"), score=output_data.get("score", 0.0), ) @@ -699,7 +733,7 @@ class PipelineEngine: self, feedback: ReviewFeedback, feedback_mode: str = "structured+natural", - ) -> dict[str, Any]: + ) -> dict[str, object]: """构建反馈上下文,让 Worker Agent 理解审查反馈并定向修复 Args: @@ -720,7 +754,7 @@ class PipelineEngine: for issue in feedback.issues ] - feedback_context: dict[str, Any] = { + feedback_context: dict[str, object] = { "previous_attempt_failed": True, } @@ -756,7 +790,9 @@ class PipelineEngine: ) else: # 未知模式,fallback 到 structured+natural - logger.warning(f"Unknown feedback_mode '{feedback_mode}', falling back to structured+natural") + logger.warning( + f"Unknown feedback_mode '{feedback_mode}', falling back to structured+natural" + ) feedback_context["review_feedback"] = { "summary": feedback.summary, "issues": issues_list, diff --git a/src/agentkit/orchestrator/pipeline_loader.py b/src/agentkit/orchestrator/pipeline_loader.py index e22498a..d7821c7 100644 --- a/src/agentkit/orchestrator/pipeline_loader.py +++ b/src/agentkit/orchestrator/pipeline_loader.py @@ -2,7 +2,6 @@ import logging from pathlib import Path -from typing import Any import yaml @@ -23,7 +22,9 @@ class PipelineLoader: if not yaml_path.exists(): yaml_path = self._pipelines_dir / f"{pipeline_name}.yml" if not yaml_path.exists(): - raise FileNotFoundError(f"Pipeline '{pipeline_name}' not found in {self._pipelines_dir}") + raise FileNotFoundError( + f"Pipeline '{pipeline_name}' not found in {self._pipelines_dir}" + ) content = yaml_path.read_text(encoding="utf-8") return self.load_from_yaml(content, pipeline_name) diff --git a/src/agentkit/orchestrator/pipeline_models.py b/src/agentkit/orchestrator/pipeline_models.py index 3fa1208..a06c472 100644 --- a/src/agentkit/orchestrator/pipeline_models.py +++ b/src/agentkit/orchestrator/pipeline_models.py @@ -35,9 +35,7 @@ class PipelineExecutionModel(Base): ) completed_at = Column(DateTime) - __table_args__ = ( - Index("ix_pipeline_status_created", "status", "created_at"), - ) + __table_args__ = (Index("ix_pipeline_status_created", "status", "created_at"),) class PipelineStepHistoryModel(Base): diff --git a/src/agentkit/orchestrator/pipeline_schema.py b/src/agentkit/orchestrator/pipeline_schema.py index 5f3cf0a..67182a3 100644 --- a/src/agentkit/orchestrator/pipeline_schema.py +++ b/src/agentkit/orchestrator/pipeline_schema.py @@ -1,7 +1,7 @@ """Pipeline 数据模型""" from enum import Enum -from typing import Any, Literal +from typing import Literal from pydantic import BaseModel, Field @@ -18,8 +18,11 @@ class StageStatus(str, Enum): class ReviewIssue(BaseModel): """单条审查问题""" + severity: Literal["critical", "major", "minor"] = Field(description="问题严重程度") - category: Literal["logic_error", "security", "style", "test_failure", "architecture"] = Field(description="问题类别") + category: Literal["logic_error", "security", "style", "test_failure", "architecture"] = Field( + description="问题类别" + ) description: str = Field(min_length=1, description="问题描述") location: str | None = Field(default=None, description="文件路径/行号") suggestion: str | None = Field(default=None, description="修复建议") @@ -27,6 +30,7 @@ class ReviewIssue(BaseModel): class ReviewFeedback(BaseModel): """Verifier 返回的结构化审查反馈""" + passed: bool = Field(description="是否通过审查") issues: list[ReviewIssue] = Field(default_factory=list, description="问题列表") summary: str = Field(min_length=1, description="自然语言审查报告") @@ -35,6 +39,7 @@ class ReviewFeedback(BaseModel): class AdversarialState(BaseModel): """对抗轮次状态追踪""" + current_round: int = Field(default=0, description="当前对抗轮次") max_rounds: int = Field(default=3, description="最大对抗轮次") feedback_history: list[ReviewFeedback] = Field(default_factory=list, description="反馈历史") @@ -46,7 +51,7 @@ class PipelineStage(BaseModel): agent: str action: str depends_on: list[str] = [] - inputs: dict[str, Any] = {} + inputs: dict[str, object] = {} outputs: list[str] = [] timeout_seconds: int = 300 retry_count: int = 0 @@ -54,12 +59,19 @@ class PipelineStage(BaseModel): condition: str | None = None retry_policy: StepRetryPolicy | None = None compensate: str | None = None - + # 对抗闭环相关字段 - verifier: str | None = Field(default=None, description="Verifier Agent 名称,配置后启用对抗模式") + verifier: str | None = Field( + default=None, description="Verifier Agent 名称,配置后启用对抗模式" + ) max_adversarial_rounds: int = Field(default=3, description="最大对抗轮次") - verifier_timeout_seconds: int = Field(default=120, description="Verifier Agent 独立超时时间(秒),避免与 Worker 共享 timeout_seconds") - feedback_mode: Literal["structured+natural", "structured", "natural"] = Field(default="structured+natural", description="反馈模式") + verifier_timeout_seconds: int = Field( + default=120, + description="Verifier Agent 独立超时时间(秒),避免与 Worker 共享 timeout_seconds", + ) + feedback_mode: Literal["structured+natural", "structured", "natural"] = Field( + default="structured+natural", description="反馈模式" + ) escalate_on_exhaust: str | None = Field(default=None, description="对抗轮次耗尽后的升级目标") model_config = {"arbitrary_types_allowed": True} @@ -70,13 +82,13 @@ class Pipeline(BaseModel): version: str description: str stages: list[PipelineStage] - variables: dict[str, Any] = {} + variables: dict[str, object] = {} class StageResult(BaseModel): stage_name: str status: StageStatus = StageStatus.PENDING - output_data: dict[str, Any] | None = None + output_data: dict[str, object] | None = None error_message: str | None = None started_at: str | None = None completed_at: str | None = None @@ -86,9 +98,9 @@ class PipelineResult(BaseModel): pipeline_name: str status: StageStatus = StageStatus.PENDING stage_results: dict[str, StageResult] = {} - variables: dict[str, Any] = {} + variables: dict[str, object] = {} error_message: str | None = None - metadata: dict[str, Any] = {} + metadata: dict[str, object] = {} class AdaptiveConfig(BaseModel): diff --git a/src/agentkit/orchestrator/pipeline_state.py b/src/agentkit/orchestrator/pipeline_state.py index 1acc9c8..cb2a6ce 100644 --- a/src/agentkit/orchestrator/pipeline_state.py +++ b/src/agentkit/orchestrator/pipeline_state.py @@ -13,7 +13,7 @@ import json import logging import uuid from datetime import datetime, timezone -from typing import Any, Callable, Coroutine +from typing import Callable, Coroutine from agentkit.orchestrator.pipeline_models import ( PipelineExecutionModel, @@ -183,7 +183,7 @@ class PipelineStateRedis: return self._redis async def _safe_redis_call( - self, fn: Callable[..., Coroutine[Any, Any, Any]], *args: object, **kwargs: object + self, fn: Callable[..., Coroutine[object, object, object]], *args: object, **kwargs: object ) -> object | None: """Execute a Redis call, falling back to memory on failure. diff --git a/src/agentkit/orchestrator/reflection.py b/src/agentkit/orchestrator/reflection.py index 18aabc9..de97a68 100644 --- a/src/agentkit/orchestrator/reflection.py +++ b/src/agentkit/orchestrator/reflection.py @@ -4,22 +4,30 @@ 生成修正后的 Pipeline 重新执行。 """ +from __future__ import annotations + import json import logging -from typing import Any +from typing import TYPE_CHECKING from agentkit.orchestrator.pipeline_schema import ( Pipeline, PipelineResult, PipelineStage, ReflectionReport, - StageResult, StageStatus, ) logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from typing import Protocol + + class _LLMGatewayLike(Protocol): + async def chat(self, **kwargs: object) -> object: ... + + class PipelineReflector: """分析 Pipeline 执行失败原因,生成结构化反思报告。 @@ -27,7 +35,7 @@ class PipelineReflector: 输出 ReflectionReport 包含 failure_type、root_cause 和 suggested_fix。 """ - def __init__(self, llm_gateway: Any = None): + def __init__(self, llm_gateway: _LLMGatewayLike | None = None): self._llm_gateway = llm_gateway async def reflect( @@ -54,19 +62,25 @@ class PipelineReflector: if self._llm_gateway is not None: try: return await self._llm_reflect( - pipeline, failed_stage, error_message, - completed_outputs, reflection_number, + pipeline, + failed_stage, + error_message, + completed_outputs, + reflection_number, ) except Exception as e: logger.warning(f"LLM reflection failed, falling back to rule-based: {e}") # 规则兜底:基于错误信息分类 return self._rule_based_reflect( - failed_stage, error_message, reflection_number, + failed_stage, + error_message, + reflection_number, ) def _find_failure( - self, result: PipelineResult, + self, + result: PipelineResult, ) -> tuple[str, str]: """找到第一个失败的 stage 及其错误信息。""" for name, sr in result.stage_results.items(): @@ -75,8 +89,9 @@ class PipelineReflector: return "", "no failed stage found" def _collect_completed_outputs( - self, result: PipelineResult, - ) -> dict[str, Any]: + self, + result: PipelineResult, + ) -> dict[str, object]: """收集已完成步骤的输出。""" outputs = {} for name, sr in result.stage_results.items(): @@ -89,13 +104,16 @@ class PipelineReflector: pipeline: Pipeline, failed_stage: str, error_message: str, - completed_outputs: dict[str, Any], + completed_outputs: dict[str, object], reflection_number: int, ) -> ReflectionReport: """使用 LLM 分析失败原因。""" prompt = self._build_reflection_prompt( - pipeline, failed_stage, error_message, - completed_outputs, reflection_number, + pipeline, + failed_stage, + error_message, + completed_outputs, + reflection_number, ) response = await self._llm_gateway.chat( @@ -106,7 +124,9 @@ class PipelineReflector: # 解析 LLM 返回的 JSON content = response.content if hasattr(response, "content") else str(response) return self._parse_reflection_response( - content, failed_stage, reflection_number, + content, + failed_stage, + reflection_number, ) def _build_reflection_prompt( @@ -114,15 +134,14 @@ class PipelineReflector: pipeline: Pipeline, failed_stage: str, error_message: str, - completed_outputs: dict[str, Any], + completed_outputs: dict[str, object], reflection_number: int, ) -> str: """构建反思提示词。""" stage_descriptions = [] for s in pipeline.stages: stage_descriptions.append( - f" - {s.name}: agent={s.agent}, action={s.action}, " - f"depends_on={s.depends_on}" + f" - {s.name}: agent={s.agent}, action={s.action}, depends_on={s.depends_on}" ) completed_summary = json.dumps( @@ -174,7 +193,9 @@ JSON response:""" except (json.JSONDecodeError, KeyError) as e: logger.warning(f"Failed to parse LLM reflection response: {e}") return self._rule_based_reflect( - failed_stage, content, reflection_number, + failed_stage, + content, + reflection_number, ) def _rule_based_reflect( @@ -218,7 +239,7 @@ class PipelineReplanner: 保留已完成步骤的结果,仅重新规划失败及后续步骤。 """ - def __init__(self, llm_gateway: Any = None): + def __init__(self, llm_gateway: _LLMGatewayLike | None = None): self._llm_gateway = llm_gateway async def replan( @@ -255,8 +276,7 @@ class PipelineReplanner: ) -> Pipeline: """使用 LLM 生成修正后的 Pipeline。""" completed_stages = [ - name for name, sr in result.stage_results.items() - if sr.status == StageStatus.COMPLETED + name for name, sr in result.stage_results.items() if sr.status == StageStatus.COMPLETED ] prompt = f"""Based on the reflection report, generate a corrected pipeline. @@ -284,7 +304,9 @@ JSON pipeline:""" return self._parse_pipeline_response(content, pipeline) def _parse_pipeline_response( - self, content: str, original: Pipeline, + self, + content: str, + original: Pipeline, ) -> Pipeline: """解析 LLM 返回的 Pipeline JSON。""" try: @@ -294,9 +316,7 @@ JSON pipeline:""" text = "\n".join(lines[1:-1]) data = json.loads(text) - stages = [ - PipelineStage(**s) for s in data.get("stages", []) - ] + stages = [PipelineStage(**s) for s in data.get("stages", [])] return Pipeline( name=data.get("name", original.name), version=data.get("version", original.version), @@ -316,8 +336,7 @@ JSON pipeline:""" ) -> Pipeline: """基于规则的兜底重规划。""" completed_stages = { - name for name, sr in result.stage_results.items() - if sr.status == StageStatus.COMPLETED + name for name, sr in result.stage_results.items() if sr.status == StageStatus.COMPLETED } # 构建修正后的 stages 列表 @@ -345,17 +364,21 @@ JSON pipeline:""" ) def _adjust_failed_stage( - self, stage: PipelineStage, report: ReflectionReport, + self, + stage: PipelineStage, + report: ReflectionReport, ) -> PipelineStage: """根据反思报告调整失败的步骤。""" - adjustments: dict[str, Any] = {} + adjustments: dict[str, object] = {} if report.failure_type == "timeout": adjustments["timeout_seconds"] = min( - stage.timeout_seconds * 2, 3600, + stage.timeout_seconds * 2, + 3600, ) if stage.retry_policy is None: from agentkit.orchestrator.retry import StepRetryPolicy + adjustments["retry_policy"] = StepRetryPolicy(max_attempts=2) elif report.failure_type == "resource_error": @@ -365,6 +388,7 @@ JSON pipeline:""" # 添加重试策略,可能输入在后续可用 if stage.retry_policy is None: from agentkit.orchestrator.retry import StepRetryPolicy + adjustments["retry_policy"] = StepRetryPolicy(max_attempts=2) return stage.model_copy(update=adjustments) diff --git a/src/agentkit/orchestrator/retry.py b/src/agentkit/orchestrator/retry.py index 4cb4ebd..372305d 100644 --- a/src/agentkit/orchestrator/retry.py +++ b/src/agentkit/orchestrator/retry.py @@ -4,7 +4,7 @@ import asyncio import logging import random from dataclasses import dataclass -from typing import Any, Awaitable, Callable +from typing import Awaitable, Callable logger = logging.getLogger(__name__) @@ -27,7 +27,7 @@ class StepRetryPolicy: def calculate_delay(self, attempt: int) -> float: """Calculate delay for given attempt number (0-based)""" delay = min( - self.base_delay * (self.exponential_base ** attempt), + self.base_delay * (self.exponential_base**attempt), self.max_delay, ) if self.jitter: @@ -36,10 +36,10 @@ class StepRetryPolicy: async def execute_with_retry( - func: Callable[..., Awaitable[Any]], + func: Callable[..., Awaitable[object]], retry_policy: StepRetryPolicy | None = None, step_name: str = "", -) -> Any: +) -> object: """Execute a function with retry policy""" if retry_policy is None: return await func() diff --git a/src/agentkit/orchestrator/workflow_schema.py b/src/agentkit/orchestrator/workflow_schema.py index feb3c48..ea26280 100644 --- a/src/agentkit/orchestrator/workflow_schema.py +++ b/src/agentkit/orchestrator/workflow_schema.py @@ -3,18 +3,17 @@ from __future__ import annotations from datetime import datetime, timezone -from typing import Any from pydantic import BaseModel, Field -from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage +from agentkit.orchestrator.pipeline_schema import PipelineStage class WorkflowStage(PipelineStage): """A workflow stage extending PipelineStage with type and config.""" type: str = "skill" # "skill" | "condition" | "approval" | "parallel" - config: dict[str, Any] = Field(default_factory=dict) + config: dict[str, object] = Field(default_factory=dict) class WorkflowDefinition(BaseModel): @@ -24,9 +23,9 @@ class WorkflowDefinition(BaseModel): name: str version: int = 1 stages: list[WorkflowStage] = Field(default_factory=list) - triggers: list[dict[str, Any]] = Field(default_factory=list) - variables_schema: dict[str, Any] = Field(default_factory=dict) - output_schema: dict[str, Any] = Field(default_factory=dict) + triggers: list[dict[str, object]] = Field(default_factory=list) + variables_schema: dict[str, object] = Field(default_factory=dict) + output_schema: dict[str, object] = Field(default_factory=dict) created_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) updated_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) @@ -38,11 +37,11 @@ class WorkflowExecution(BaseModel): workflow_id: str = "" status: str = "pending" # pending|running|paused|completed|failed|cancelled current_stage: str | None = None - stage_results: dict[str, Any] = Field(default_factory=dict) + stage_results: dict[str, object] = Field(default_factory=dict) started_at: str | None = None completed_at: str | None = None error: str | None = None - variables: dict[str, Any] = Field(default_factory=dict) + variables: dict[str, object] = Field(default_factory=dict) class WorkflowSummary(BaseModel): @@ -62,15 +61,15 @@ class CreateWorkflowRequest(BaseModel): name: str stages: list[WorkflowStage] = Field(default_factory=list) - triggers: list[dict[str, Any]] = Field(default_factory=list) - variables_schema: dict[str, Any] = Field(default_factory=dict) - output_schema: dict[str, Any] = Field(default_factory=dict) + triggers: list[dict[str, object]] = Field(default_factory=list) + variables_schema: dict[str, object] = Field(default_factory=dict) + output_schema: dict[str, object] = Field(default_factory=dict) class ExecuteWorkflowRequest(BaseModel): """Request body for executing a workflow.""" - variables: dict[str, Any] = Field(default_factory=dict) + variables: dict[str, object] = Field(default_factory=dict) class ApproveRequest(BaseModel):