From 838a05772e700db2f285e34dc811ba6c512f2779 Mon Sep 17 00:00:00 2001 From: Fischer Date: Wed, 1 Jul 2026 03:03:02 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20follow-up=20tech=20debt=20cleanup?= =?UTF-8?q?=20(except=20Exception=20+=20Any=20=E6=B2=BB=E7=90=86)=20(#9)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/agentkit/bitable/formula/engine.py | 17 +- .../bitable/ingestion/api_collector.py | 10 +- src/agentkit/bitable/ingestion/database.py | 15 +- src/agentkit/bitable/ingestion/excel.py | 9 +- src/agentkit/bitable/models.py | 16 +- src/agentkit/bitable/recalc_worker.py | 11 +- src/agentkit/bitable/repository.py | 37 ++--- src/agentkit/client/sync.py | 25 ++- src/agentkit/llm/cache.py | 54 +++++- src/agentkit/llm/cache_key.py | 3 +- src/agentkit/llm/config.py | 14 +- src/agentkit/llm/gateway.py | 43 ++++- src/agentkit/llm/migration.py | 11 +- src/agentkit/llm/protocol.py | 13 +- src/agentkit/llm/providers/anthropic.py | 135 ++++++++------- src/agentkit/llm/providers/doubao.py | 3 +- src/agentkit/llm/providers/gemini.py | 145 +++++++++------- .../llm/providers/litellm_provider.py | 31 ++-- src/agentkit/llm/providers/usage_store.py | 8 +- src/agentkit/llm/providers/wenxin.py | 3 +- src/agentkit/llm/providers/yuanbao.py | 3 +- src/agentkit/llm/remote_provider.py | 11 +- src/agentkit/llm/retry.py | 14 +- src/agentkit/memory/adapters/base.py | 9 +- src/agentkit/memory/adapters/confluence.py | 34 ++-- src/agentkit/memory/adapters/feishu.py | 42 ++--- src/agentkit/memory/adapters/generic_http.py | 34 ++-- src/agentkit/memory/base.py | 26 ++- src/agentkit/memory/chunking.py | 90 ++++++---- src/agentkit/memory/contextual_retrieval.py | 14 +- src/agentkit/memory/document_loader.py | 37 +++-- src/agentkit/memory/embedder.py | 12 +- src/agentkit/memory/episodic.py | 147 ++++++++++------- src/agentkit/memory/http_rag.py | 84 ++++++---- src/agentkit/memory/knowledge_base.py | 10 +- src/agentkit/memory/local_rag.py | 155 ++++++++++-------- src/agentkit/memory/multi_source_retriever.py | 7 +- src/agentkit/memory/profile.py | 23 ++- src/agentkit/memory/query_transformer.py | 24 ++- src/agentkit/memory/rag_loop.py | 29 ++-- src/agentkit/memory/relevance_scorer.py | 10 +- src/agentkit/memory/retriever.py | 71 ++++---- src/agentkit/memory/semantic.py | 103 ++++++++---- src/agentkit/memory/working.py | 29 ++-- src/agentkit/orchestrator/checkpoint.py | 46 ++++-- src/agentkit/orchestrator/compensation.py | 20 +-- src/agentkit/orchestrator/dynamic_pipeline.py | 15 +- src/agentkit/orchestrator/handoff.py | 25 ++- src/agentkit/orchestrator/pipeline_engine.py | 116 ++++++++----- src/agentkit/orchestrator/pipeline_loader.py | 5 +- src/agentkit/orchestrator/pipeline_models.py | 4 +- src/agentkit/orchestrator/pipeline_schema.py | 34 ++-- src/agentkit/orchestrator/pipeline_state.py | 4 +- src/agentkit/orchestrator/reflection.py | 82 +++++---- src/agentkit/orchestrator/retry.py | 8 +- src/agentkit/orchestrator/workflow_schema.py | 23 ++- src/agentkit/server/routes/agents.py | 2 +- src/agentkit/server/routes/auth.py | 8 +- src/agentkit/server/routes/bitable.py | 4 +- src/agentkit/server/routes/channels.py | 10 +- src/agentkit/server/routes/chat.py | 30 +++- src/agentkit/server/routes/config_sync.py | 6 +- src/agentkit/server/routes/documents.py | 5 +- src/agentkit/server/routes/evolution.py | 9 + .../server/routes/evolution_dashboard.py | 14 +- src/agentkit/server/routes/health.py | 12 +- src/agentkit/server/routes/kb_management.py | 11 +- src/agentkit/server/routes/mcp_publish.py | 2 +- src/agentkit/server/routes/memory.py | 7 + src/agentkit/server/routes/metrics.py | 7 +- src/agentkit/server/routes/portal.py | 58 +++++-- .../server/routes/skill_management.py | 8 +- src/agentkit/server/routes/skills.py | 20 +-- src/agentkit/server/routes/system.py | 6 +- src/agentkit/server/routes/tasks.py | 10 +- src/agentkit/server/routes/terminal.py | 20 ++- src/agentkit/server/routes/terminal_server.py | 18 +- src/agentkit/server/routes/workflows.py | 10 +- src/agentkit/server/routes/ws.py | 8 +- 79 files changed, 1388 insertions(+), 900 deletions(-) diff --git a/src/agentkit/bitable/formula/engine.py b/src/agentkit/bitable/formula/engine.py index 589c6f6..9e06e07 100644 --- a/src/agentkit/bitable/formula/engine.py +++ b/src/agentkit/bitable/formula/engine.py @@ -16,7 +16,6 @@ from __future__ import annotations import ast from collections import deque -from typing import Any from agentkit.bitable.formula.functions import AGGREGATE_FUNCTIONS, FUNCTION_REGISTRY from agentkit.bitable.formula.parser import ( @@ -104,9 +103,9 @@ class FormulaEngine: def evaluate( self, field_id: str, - row_values: dict[str, Any], - column_values: dict[str, list[Any]] | None = None, - ) -> Any: + row_values: dict[str, object], + column_values: dict[str, list[object]] | None = None, + ) -> object: """Evaluate a formula field for a specific record. Args: @@ -130,7 +129,7 @@ class FormulaEngine: # Build the field_values dict for the evaluator # Aggregate refs get column values (lists), row refs get row values (scalars) - eval_values: dict[str, Any] = {} + eval_values: dict[str, object] = {} # Map real field IDs to safe names for safe_name, real_id in entry.field_mapping.items(): @@ -143,16 +142,16 @@ class FormulaEngine: def evaluate_all_for_record( self, - row_values: dict[str, Any], - column_values: dict[str, list[Any]] | None = None, - ) -> dict[str, Any]: + row_values: dict[str, object], + column_values: dict[str, list[object]] | None = None, + ) -> dict[str, object]: """Evaluate all registered formulas for a record. Returns a dict of field_id → computed value. Formulas are evaluated in topological order so that formula-to-formula dependencies are resolved correctly. """ - results: dict[str, Any] = {} + results: dict[str, object] = {} column_values = column_values or {} for field_id in self.topological_order(): diff --git a/src/agentkit/bitable/ingestion/api_collector.py b/src/agentkit/bitable/ingestion/api_collector.py index b14d92a..3dc8f59 100644 --- a/src/agentkit/bitable/ingestion/api_collector.py +++ b/src/agentkit/bitable/ingestion/api_collector.py @@ -16,13 +16,11 @@ Usage:: from __future__ import annotations -from typing import Any - def transform_records( - records: list[dict[str, Any]], + records: list[dict[str, object]], field_mapping: dict[str, str], -) -> list[dict[str, Any]]: +) -> list[dict[str, object]]: """Map source record keys to bitable field IDs via field_mapping. Keys not in ``field_mapping`` are dropped. Values are passed through @@ -40,9 +38,9 @@ def transform_records( if not field_mapping: return [] - transformed: list[dict[str, Any]] = [] + transformed: list[dict[str, object]] = [] for rec in records: - out: dict[str, Any] = {} + out: dict[str, object] = {} for src_key, field_id in field_mapping.items(): if src_key in rec: out[field_id] = rec[src_key] diff --git a/src/agentkit/bitable/ingestion/database.py b/src/agentkit/bitable/ingestion/database.py index 9f8d9fe..78f56db 100644 --- a/src/agentkit/bitable/ingestion/database.py +++ b/src/agentkit/bitable/ingestion/database.py @@ -16,7 +16,6 @@ Type mapping (KTD: DB → bitable): from __future__ import annotations import logging -from typing import Any from sqlalchemy import ( BigInteger, @@ -56,7 +55,7 @@ DB_TYPE_MAP: dict[type, str] = { READ_BATCH = 1000 -def infer_field_type(sqla_type: Any) -> str: +def infer_field_type(sqla_type: object) -> str: """Map a SQLAlchemy column type instance or class to a bitable field type. Handles both type instances (``Integer()``) and type classes (``Integer``). @@ -78,7 +77,7 @@ def import_table( table_name: str, *, max_rows: int = 50_000, -) -> dict[str, Any]: +) -> dict[str, object]: """Reflect a single table from an external DB. Returns ``{"table_name": str, "fields": [...], "records": [...], @@ -97,7 +96,7 @@ def import_table( engine.dispose() -def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[str, Any]: +def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[str, object]: """Reflect one table and read its rows.""" insp = inspect(engine) @@ -111,7 +110,7 @@ def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[st table = Table(table_name, metadata, autoload_with=engine) # Build field definitions - fields: list[dict[str, Any]] = [] + fields: list[dict[str, object]] = [] pk_columns = list(table.primary_key.columns) pk_name = pk_columns[0].name if pk_columns else None @@ -131,14 +130,14 @@ def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[st pk_name = "id" # Read rows - records: list[dict[str, Any]] = [] + records: list[dict[str, object]] = [] with engine.connect() as conn: result = conn.execute(select(table)) for i, row in enumerate(result): if i >= max_rows: logger.warning("Table %r truncated at %d rows during import", table_name, max_rows) break - rec: dict[str, Any] = {} + rec: dict[str, object] = {} for col in table.columns: val = getattr(row, col.name, None) if val is not None: @@ -155,7 +154,7 @@ def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[st } -def _serialize(val: Any) -> Any: +def _serialize(val: object) -> object: """Serialize a DB value to JSON-safe form.""" from datetime import date, datetime from decimal import Decimal diff --git a/src/agentkit/bitable/ingestion/excel.py b/src/agentkit/bitable/ingestion/excel.py index 34365d4..d92af55 100644 --- a/src/agentkit/bitable/ingestion/excel.py +++ b/src/agentkit/bitable/ingestion/excel.py @@ -18,7 +18,6 @@ import logging import socket from dataclasses import dataclass, field from pathlib import Path -from typing import Any from urllib.parse import urlparse import httpx @@ -36,7 +35,7 @@ class ParsedSheet: name: str columns: list[str] = field(default_factory=list) field_types: list[str] = field(default_factory=list) # "text" | "number" | "date" - records: list[dict[str, Any]] = field(default_factory=list) + records: list[dict[str, object]] = field(default_factory=list) def parse_excel(file_path: str | Path) -> list[ParsedSheet]: @@ -182,9 +181,9 @@ def _parse_worksheet(ws) -> ParsedSheet | None: col_count = len(clean_headers) field_types = _infer_column_types(data_rows, col_count) - records: list[dict[str, Any]] = [] + records: list[dict[str, object]] = [] for row in data_rows: - rec: dict[str, Any] = {} + rec: dict[str, object] = {} for i, col_name in enumerate(clean_headers): val = row[i] if i < len(row) else None if val is not None: @@ -237,7 +236,7 @@ def _infer_column_types(rows: list[tuple], col_count: int) -> list[str]: return types -def _coerce_value(val: Any, field_type: str) -> Any: +def _coerce_value(val: object, field_type: str) -> object: """Coerce a cell value to the inferred field type. Truncate long strings.""" if field_type == "date": from datetime import datetime diff --git a/src/agentkit/bitable/models.py b/src/agentkit/bitable/models.py index 2f1ea76..db8d90e 100644 --- a/src/agentkit/bitable/models.py +++ b/src/agentkit/bitable/models.py @@ -9,10 +9,14 @@ from __future__ import annotations from datetime import datetime, timezone from enum import Enum -from typing import Any from pydantic import BaseModel, ConfigDict, Field as PydanticField +# ponytail: bitable JSONB columns hold arbitrary JSON. Using `object` instead of +# a recursive TypeAlias because Pydantic v2 cannot build a schema for recursive +# named aliases (RecursionError). `object` is the most permissive type and +# Pydantic v2 serializes dict/list/primitive values fine at runtime. + def _utcnow() -> datetime: return datetime.now(timezone.utc) @@ -97,7 +101,7 @@ class Table(BaseModel): # --------------------------------------------------------------------------- # Status select field options — labels and colors match Feishu Bitable defaults. -_STATUS_OPTIONS: list[dict[str, Any]] = [ +_STATUS_OPTIONS: list[dict[str, object]] = [ {"label": "未开始", "value": "not_started", "color": "default"}, {"label": "进行中", "value": "in_progress", "color": "processing"}, {"label": "已完成", "value": "done", "color": "success"}, @@ -106,7 +110,7 @@ _STATUS_OPTIONS: list[dict[str, Any]] = [ #: Templates for the 5 default fields created on every new table (R2). #: agent-owned fields (创建人/创建时间) are auto-filled by the service layer #: on record creation; user-owned fields are user-editable. -DEFAULT_FIELD_TEMPLATES: list[dict[str, Any]] = [ +DEFAULT_FIELD_TEMPLATES: list[dict[str, object]] = [ { "name": "标题", "field_type": FieldType.text, @@ -155,7 +159,7 @@ class Field(BaseModel): table_id: str name: str field_type: FieldType - config: dict[str, Any] = PydanticField(default_factory=dict) + config: dict[str, object] = PydanticField(default_factory=dict) owner: FieldOwner = FieldOwner.user created_at: datetime = PydanticField(default_factory=_utcnow) @@ -167,7 +171,7 @@ class Record(BaseModel): id: str table_id: str - values: dict[str, Any] = PydanticField(default_factory=dict) + values: dict[str, object] = PydanticField(default_factory=dict) created_at: datetime = PydanticField(default_factory=_utcnow) updated_at: datetime = PydanticField(default_factory=_utcnow) @@ -181,7 +185,7 @@ class View(BaseModel): table_id: str name: str view_type: ViewType = ViewType.grid - config: dict[str, Any] = PydanticField(default_factory=dict) + config: dict[str, object] = PydanticField(default_factory=dict) created_at: datetime = PydanticField(default_factory=_utcnow) diff --git a/src/agentkit/bitable/recalc_worker.py b/src/agentkit/bitable/recalc_worker.py index 0976b32..d6b8497 100644 --- a/src/agentkit/bitable/recalc_worker.py +++ b/src/agentkit/bitable/recalc_worker.py @@ -18,11 +18,10 @@ from __future__ import annotations import asyncio import logging -from typing import Any from agentkit.bitable.db import BitableDB from agentkit.bitable.formula.engine import FormulaEngine -from agentkit.bitable.models import FieldType, RecalcStatus +from agentkit.bitable.models import FieldType, RecalcStatus, RecalcTask from agentkit.bitable.repository import BitableRepository from agentkit.bitable.service import BitableService @@ -124,7 +123,7 @@ class RecalcWorker: logger.exception("RecalcWorker error in main loop") await asyncio.sleep(self._poll_interval) - async def _sort_by_topological_order(self, tasks: list[Any]) -> list[Any]: + async def _sort_by_topological_order(self, tasks: list[RecalcTask]) -> list[RecalcTask]: """Sort claimed tasks so dependencies are processed first (P1 #7). Groups tasks by table_id, builds (or reuses) the engine to get the @@ -146,7 +145,7 @@ class RecalcWorker: order = engine.topological_order() topo_index[tid] = {fid: i for i, fid in enumerate(order)} - def _key(t: Any) -> tuple[str, str, int]: + def _key(t: RecalcTask) -> tuple[str, str, int]: idx = topo_index.get(t.table_id, {}).get(t.field_id, 1 << 30) return (t.table_id, t.record_id, idx) @@ -175,7 +174,7 @@ class RecalcWorker: except Exception: logger.exception("RecalcWorker reaper error") - async def process_task(self, task: Any) -> None: + async def process_task(self, task: RecalcTask) -> None: """Process a single recalc task: evaluate formula → write result. The task is expected to already be in ``calculating`` status when @@ -216,7 +215,7 @@ class RecalcWorker: return deps = engine.get_dependencies(task.field_id) - column_values: dict[str, list[Any]] = {} + column_values: dict[str, list[object]] = {} for dep_field_id in deps: column_values[dep_field_id] = await self._repo.get_column_values( task.table_id, dep_field_id diff --git a/src/agentkit/bitable/repository.py b/src/agentkit/bitable/repository.py index 2560dd4..8278ff5 100644 --- a/src/agentkit/bitable/repository.py +++ b/src/agentkit/bitable/repository.py @@ -10,7 +10,6 @@ from __future__ import annotations import logging import re from datetime import datetime, timedelta, timezone -from typing import Any from sqlalchemy import delete, func, insert, select, text, update from sqlalchemy.dialects.postgresql import insert as pg_insert @@ -102,7 +101,7 @@ class BitableRepository: result = await session.execute(stmt) return [BitableFile.model_validate(e) for e in result.scalars().all()] - async def update_file(self, file_id: str, **kwargs: Any) -> BitableFile | None: + async def update_file(self, file_id: str, **kwargs: object) -> BitableFile | None: """Update a file's attributes.""" async with self._session_factory() as session: stmt = ( @@ -181,7 +180,7 @@ class BitableRepository: result = await session.execute(stmt) return [Table.model_validate(e) for e in result.scalars().all()] - async def update_table(self, table_id: str, **kwargs: Any) -> Table | None: + async def update_table(self, table_id: str, **kwargs: object) -> Table | None: """Update a table's attributes.""" async with self._session_factory() as session: stmt = ( @@ -236,7 +235,7 @@ class BitableRepository: table_id: str, name: str, field_type: FieldType, - config: dict[str, Any] | None = None, + config: dict[str, object] | None = None, owner: FieldOwner = FieldOwner.user, ) -> Field: """Create a new field in a table.""" @@ -277,7 +276,7 @@ class BitableRepository: result = await session.execute(stmt) return [Field.model_validate(e) for e in result.scalars().all()] - async def update_field(self, field_id: str, **kwargs: Any) -> Field | None: + async def update_field(self, field_id: str, **kwargs: object) -> Field | None: """Update a field's attributes.""" async with self._session_factory() as session: stmt = ( @@ -300,7 +299,7 @@ class BitableRepository: # ── Records ───────────────────────────────────────────── - async def create_record(self, table_id: str, values: dict[str, Any] | None = None) -> Record: + async def create_record(self, table_id: str, values: dict[str, object] | None = None) -> Record: """Create a new record.""" async with self._session_factory() as session: stmt = ( @@ -318,7 +317,7 @@ class BitableRepository: return Record.model_validate(entity) async def create_records_batch( - self, table_id: str, records_values: list[dict[str, Any]] + self, table_id: str, records_values: list[dict[str, object]] ) -> list[Record]: """Batch-insert multiple records (P2 #19: eliminates per-record INSERT). @@ -376,7 +375,7 @@ class BitableRepository: return [Record.model_validate(e) for e in entities], next_cursor - async def update_record_values(self, record_id: str, values: dict[str, Any]) -> Record | None: + async def update_record_values(self, record_id: str, values: dict[str, object]) -> Record | None: """Update a record's values (full replace).""" async with self._session_factory() as session: stmt = ( @@ -413,7 +412,7 @@ class BitableRepository: table_id: str, name: str, view_type: ViewType = ViewType.grid, - config: dict[str, Any] | None = None, + config: dict[str, object] | None = None, ) -> View: """Create a new view.""" async with self._session_factory() as session: @@ -451,7 +450,7 @@ class BitableRepository: result = await session.execute(stmt) return [View.model_validate(e) for e in result.scalars().all()] - async def update_view(self, view_id: str, **kwargs: Any) -> View | None: + async def update_view(self, view_id: str, **kwargs: object) -> View | None: """Update a view's attributes.""" async with self._session_factory() as session: stmt = ( @@ -543,7 +542,7 @@ class BitableRepository: ) -> None: """Update a recalc task's status.""" async with self._session_factory() as session: - kwargs: dict[str, Any] = {"status": status.value} + kwargs: dict[str, object] = {"status": status.value} if error_message is not None: kwargs["error_message"] = error_message if status in (RecalcStatus.done, RecalcStatus.error): @@ -630,7 +629,7 @@ class BitableRepository: return result_map async def upsert_record_agent_fields( - self, record_id: str, agent_field_values: dict[str, Any] + self, record_id: str, agent_field_values: dict[str, object] ) -> None: """Update agent-owned fields using jsonb_set (KTD8). @@ -646,7 +645,7 @@ class BitableRepository: # Use CAST(:param AS jsonb) instead of :param::jsonb — asyncpg dialect # misparses the `::` as part of the param name. inner = "values" - params: dict[str, Any] = {"record_id": record_id} + params: dict[str, object] = {"record_id": record_id} for i, (field_id, value) in enumerate(agent_field_values.items()): param_key = f"v{i}" inner = f"jsonb_set({inner}, '{{{field_id}}}', CAST(:{param_key} AS jsonb), true)" @@ -660,8 +659,8 @@ class BitableRepository: async def list_records_filtered( self, table_id: str, - filters: list[dict[str, Any]] | None = None, - sorts: list[dict[str, Any]] | None = None, + filters: list[dict[str, object]] | None = None, + sorts: list[dict[str, object]] | None = None, cursor: str | None = None, limit: int = 50, ) -> tuple[list[Record], str | None]: @@ -682,7 +681,7 @@ class BitableRepository: # Build raw SQL with JSONB filter/sort translation. # ponytail: field_ids in filters/sorts are system UUIDs (validated by service layer). where_clauses = ["table_id = :table_id"] - params: dict[str, Any] = {"table_id": table_id} + params: dict[str, object] = {"table_id": table_id} if filters: for i, f in enumerate(filters): @@ -783,7 +782,7 @@ class BitableRepository: last_mapping = last_row._mapping # Build composite cursor from sort values + id. # Sort values are extracted as text to match `values->>'fid'` expressions. - sv: list[Any] = [] + sv: list[object] = [] last_values = last_mapping.get("values") if isinstance(last_values, str): # asyncpg may return JSONB as str in raw text() queries. @@ -852,7 +851,7 @@ class BitableRepository: # ── Recalc support (U3) ──────────────────────────────── - async def get_column_values(self, table_id: str, field_id: str) -> list[Any]: + async def get_column_values(self, table_id: str, field_id: str) -> list[object]: """Get all values for a field across all records in a table (for aggregates). Returns a list of values (preserving order by record id). Missing values @@ -866,7 +865,7 @@ class BitableRepository: result = await session.execute(sql, {"field_id": field_id, "table_id": table_id}) return [row[0] for row in result.fetchall()] - async def set_formula_value(self, record_id: str, field_id: str, value: Any) -> None: + async def set_formula_value(self, record_id: str, field_id: str, value: object) -> None: """Set a single formula field value in a record's JSONB (jsonb_set).""" import json diff --git a/src/agentkit/client/sync.py b/src/agentkit/client/sync.py index 7edb326..8f5710a 100644 --- a/src/agentkit/client/sync.py +++ b/src/agentkit/client/sync.py @@ -34,12 +34,19 @@ import os import sqlite3 from datetime import datetime, timezone from pathlib import Path -from typing import Any, Callable +from typing import Callable, TypeAlias import httpx logger = logging.getLogger(__name__) +# 缓存的配置项 — 技能/工作流配置为 JSON 反序列化后的字典,值为标量或嵌套结构。 +# 服务器返回的 skills/workflows 列表元素是 dict(model_dump/to_dict 输出), +# 其中可能包含 list/dict 等容器,因此使用 object 作为值类型。 +SkillConfigDict: TypeAlias = dict[str, object] +WorkflowConfigDict: TypeAlias = dict[str, object] +SyncedConfigPayload: TypeAlias = dict[str, object] + # ── Defaults ────────────────────────────────────────────────────────── @@ -100,8 +107,8 @@ class ConfigSync: # In-memory cache (mirrors the SQLite cache for fast access) self._version: str | None = None - self._skills: list[dict[str, Any]] = [] - self._workflows: list[dict[str, Any]] = [] + self._skills: list[SkillConfigDict] = [] + self._workflows: list[WorkflowConfigDict] = [] self._last_synced_at: str | None = None # ── Lifecycle ───────────────────────────────────────────────── @@ -232,15 +239,15 @@ class ConfigSync: """Return the current cached config version hash.""" return self._version - def get_skills(self) -> list[dict[str, Any]]: + def get_skills(self) -> list[SkillConfigDict]: """Return the cached skill configs.""" return list(self._skills) - def get_workflows(self) -> list[dict[str, Any]]: + def get_workflows(self) -> list[WorkflowConfigDict]: """Return the cached workflow configs.""" return list(self._workflows) - def get_all(self) -> dict[str, Any]: + def get_all(self) -> SyncedConfigPayload: """Return all cached configs as a single dict.""" return { "version": self._version, @@ -249,14 +256,14 @@ class ConfigSync: "synced_at": self._last_synced_at, } - def get_skill(self, name: str) -> dict[str, Any] | None: + def get_skill(self, name: str) -> SkillConfigDict | None: """Return a single skill config by name, or ``None``.""" for skill in self._skills: if skill.get("name") == name: return skill return None - def get_workflow(self, workflow_id: str) -> dict[str, Any] | None: + def get_workflow(self, workflow_id: str) -> WorkflowConfigDict | None: """Return a single workflow config by ID, or ``None``.""" for wf in self._workflows: if wf.get("workflow_id") == workflow_id: @@ -281,7 +288,7 @@ class ConfigSync: conn.executescript(_CACHE_SCHEMA) conn.commit() - def _save_to_cache(self, data: dict[str, Any]) -> None: + def _save_to_cache(self, data: SyncedConfigPayload) -> None: """Save the synced configs to the local SQLite cache.""" now = datetime.now(timezone.utc).isoformat() with sqlite3.connect(str(self.cache_db_path)) as conn: diff --git a/src/agentkit/llm/cache.py b/src/agentkit/llm/cache.py index b0f334c..0919920 100644 --- a/src/agentkit/llm/cache.py +++ b/src/agentkit/llm/cache.py @@ -14,7 +14,7 @@ import logging import time from collections import OrderedDict from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Protocol, runtime_checkable from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall from agentkit.utils.vector_math import compute_cosine_similarity @@ -25,6 +25,52 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# --------------------------------------------------------------------------- +# TYPE_CHECKING Protocols — 避免 Any,描述运行时 lazy import 的第三方对象 +# --------------------------------------------------------------------------- + + +if TYPE_CHECKING: + + class _RedisLike(Protocol): + """Redis 客户端最小契约(仅覆盖本模块用到的方法)。""" + + async def get(self, key: str) -> bytes | str | None: ... + + async def mget(self, keys: list[str]) -> list[bytes | str | None]: ... + + async def set(self, key: str, value: bytes | str, ex: int | None = None) -> None: ... + + async def smembers(self, key: str) -> set[bytes | str]: ... + + async def sadd(self, name: str, *values: str) -> int: ... + + async def srem(self, name: str, *values: str) -> int: ... + + async def scard(self, name: str) -> int: ... + + async def delete(self, *names: str) -> int: ... + + def pipeline(self) -> "_RedisPipelineLike": ... + + class _RedisPipelineLike(Protocol): + """Redis pipeline 最小契约。""" + + def get(self, key: str) -> "_RedisPipelineLike": ... + + def set( + self, key: str, value: bytes | str, ex: int | None = None + ) -> "_RedisPipelineLike": ... + + def delete(self, *names: str) -> "_RedisPipelineLike": ... + + def sadd(self, name: str, *values: str) -> "_RedisPipelineLike": ... + + def srem(self, name: str, *values: str) -> "_RedisPipelineLike": ... + + async def execute(self) -> list[object]: ... + + # --------------------------------------------------------------------------- # Data Classes # --------------------------------------------------------------------------- @@ -328,7 +374,7 @@ class RedisLLMCache: self._semantic_ttl = semantic_ttl self._similarity_threshold = similarity_threshold self._max_entries_to_scan = max_entries_to_scan - self._redis: Any = None + self._redis: _RedisLike | None = None self._fallback: InMemoryLLMCache | None = fallback # For auto-degradation self._degraded = False # True if Redis is unreachable @@ -691,7 +737,7 @@ class LitellmCacheManager: def __init__(self, config: LitellmCacheConfig): self._config = config - self._cache_instance: Any = None # litellm.caching.Cache 实例 + self._cache_instance: object | None = None # litellm.caching.Cache 实例 self._hits = 0 self._misses = 0 @@ -709,7 +755,7 @@ class LitellmCacheManager: litellm.cache = None self._cache_instance = None - def _create_cache_instance(self) -> Any: + def _create_cache_instance(self) -> object: """根据 backend 配置创建 LiteLLM Cache 实例。 auto 模式按优先级尝试:RedisSemanticCache → RedisCache → InMemoryCache。 diff --git a/src/agentkit/llm/cache_key.py b/src/agentkit/llm/cache_key.py index 3a5211b..50dc348 100644 --- a/src/agentkit/llm/cache_key.py +++ b/src/agentkit/llm/cache_key.py @@ -2,14 +2,13 @@ import hashlib import json -from typing import Any def generate_cache_key( model: str, messages: list[dict[str, str]], temperature: float, - tools: list[dict[str, Any]] | None = None, + tools: list[dict[str, object]] | None = None, tool_choice: str = "auto", max_tokens: int = 2000, user_id: str | None = None, diff --git a/src/agentkit/llm/config.py b/src/agentkit/llm/config.py index 1ee2aa9..3df6e61 100644 --- a/src/agentkit/llm/config.py +++ b/src/agentkit/llm/config.py @@ -3,12 +3,12 @@ import json import logging from dataclasses import dataclass, field -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING from agentkit.llm.retry import CircuitBreakerConfig, RetryConfig if TYPE_CHECKING: - from agentkit.channels.secrets import SecretsStore + from agentkit.channels.secrets import SecretEntry, SecretsStore logger = logging.getLogger(__name__) @@ -56,7 +56,7 @@ class ProviderConfig: api_key: str base_url: str - models: dict[str, dict[str, Any]] = field(default_factory=dict) + models: dict[str, dict[str, object]] = field(default_factory=dict) type: str = "openai" # "openai" | "anthropic" | "gemini" max_tokens: int = 4096 # Anthropic: default max_tokens timeout: float = 120.0 # Anthropic: request timeout @@ -168,18 +168,18 @@ class ProviderConfig: return f"llm:provider:{self.type}:api_key" @staticmethod - def _encode_secret_entry(entry: Any, key: str) -> str: + def _encode_secret_entry(entry: object, key: str) -> str: """把 SecretEntry 编码为 JSON 字符串(含 key 字段)。""" # entry 是 SecretEntry pydantic 模型,有 model_dump() if hasattr(entry, "model_dump"): - data = entry.model_dump() + data = entry.model_dump() # type: ignore[attr-defined] else: - data = dict(entry) + data = dict(entry) # type: ignore[call-overload] data["key"] = key return json.dumps(data) @staticmethod - def _decode_secret_entry(encoded: str) -> Any: + def _decode_secret_entry(encoded: str) -> "SecretEntry": """从 JSON 字符串解码 SecretEntry。返回带 .key 属性的对象。""" from agentkit.channels.secrets import SecretEntry diff --git a/src/agentkit/llm/gateway.py b/src/agentkit/llm/gateway.py index 0b7be8f..387b578 100644 --- a/src/agentkit/llm/gateway.py +++ b/src/agentkit/llm/gateway.py @@ -5,7 +5,7 @@ import logging import time from datetime import datetime, timezone from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Protocol from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError from agentkit.llm.config import LLMConfig @@ -14,9 +14,40 @@ from agentkit.llm.providers.tracker import UsageSummary, UsageTracker from agentkit.telemetry.tracing import get_tracer, _OTEL_AVAILABLE from agentkit.telemetry.metrics import llm_token_histogram +if TYPE_CHECKING: + from agentkit.llm.cache import LitellmCacheManager + logger = logging.getLogger(__name__) +# --------------------------------------------------------------------------- +# TYPE_CHECKING Protocols — 避免 Any,描述运行时 lazy import 的对象 +# --------------------------------------------------------------------------- + + +if TYPE_CHECKING: + + class _QuotaServiceLike(Protocol): + """Quota service 最小契约(仅覆盖 gateway._enforce_quota 用到的方法)。""" + + async def is_model_allowed( + self, db: Path, department_id: str, model: str + ) -> tuple[bool, str]: ... + + async def check_quota( + self, + db: Path, + department_id: str, + quota_type: str, + period: str, + current: float, + ) -> tuple[bool, str]: ... + + async def get_quota( + self, db: Path, department_id: str, quota_type: str, period: str + ) -> dict[str, object] | None: ... + + class QuotaExceededError(Exception): """Raised when a department's LLM quota is exceeded. @@ -29,8 +60,8 @@ class QuotaExceededError(Exception): department_id: str, quota_type: str, period: str, - limit: Any, - current: Any, + limit: object, + current: object, ) -> None: self.department_id = department_id self.quota_type = quota_type @@ -46,13 +77,13 @@ class QuotaExceededError(Exception): class LLMGateway: """LLM 网关 - Provider 注册、模型别名解析、Fallback、Usage 追踪、Cache""" - def __init__(self, config: LLMConfig | None = None, usage_store: Any = None): + def __init__(self, config: LLMConfig | None = None, usage_store: object | None = None): self._providers: dict[str, LLMProvider] = {} self._usage_tracker = UsageTracker(store=usage_store) if usage_store else UsageTracker() self._config = config or LLMConfig() # Cache (U17 — LiteLLM 缓存管理器,opt-in,默认禁用) - self._cache_manager: Any = None # LitellmCacheManager | None + self._cache_manager: "LitellmCacheManager | None" = None if self._config.cache and self._config.cache.enabled: from agentkit.llm.cache import LitellmCacheConfig, LitellmCacheManager @@ -601,7 +632,7 @@ class LLMGateway: async def _check_quota_value( self, - quota_service: Any, + quota_service: _QuotaServiceLike, db: Path, dept_id: str, period: str, diff --git a/src/agentkit/llm/migration.py b/src/agentkit/llm/migration.py index cae041a..6a9eb1b 100644 --- a/src/agentkit/llm/migration.py +++ b/src/agentkit/llm/migration.py @@ -27,12 +27,11 @@ from __future__ import annotations import logging from pathlib import Path -from typing import Any logger = logging.getLogger(__name__) -def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str, Any]]: +def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str, object]]: """把 agentkit.yaml 中的 plaintext API Key 迁移到 SecretsStore。 流程: @@ -63,8 +62,8 @@ def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str, store = SecretsStore() # master key 从 env 加载 - async def _run() -> dict[str, dict[str, Any]]: - report: dict[str, dict[str, Any]] = {} + async def _run() -> dict[str, dict[str, object]]: + report: dict[str, dict[str, object]] = {} for name, pconf in llm_config.providers.items(): if pconf.api_key_source == "secrets_store" and not pconf.api_key: report[name] = {"status": "skipped", "source": pconf.api_key_source} @@ -93,9 +92,9 @@ def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str, report = asyncio.run(_run()) # 写回 YAML:更新 llm.providers 段,保留其它段 - providers_out: dict[str, dict[str, Any]] = {} + providers_out: dict[str, dict[str, object]] = {} for name, pconf in llm_config.providers.items(): - entry: dict[str, Any] = { + entry: dict[str, object] = { "type": pconf.type, "base_url": pconf.base_url, "models": pconf.models, diff --git a/src/agentkit/llm/protocol.py b/src/agentkit/llm/protocol.py index c6f5d54..21369c3 100644 --- a/src/agentkit/llm/protocol.py +++ b/src/agentkit/llm/protocol.py @@ -2,7 +2,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Any @dataclass @@ -23,7 +22,7 @@ class ToolCall: id: str name: str - arguments: dict[str, Any] + arguments: dict[str, object] @dataclass @@ -32,7 +31,7 @@ class LLMRequest: messages: list[dict[str, str]] model: str - tools: list[dict[str, Any]] | None = None + tools: list[dict[str, object]] | None = None tool_choice: str = "auto" temperature: float = 0.7 max_tokens: int = 2000 @@ -42,13 +41,13 @@ class LLMRequest: self, messages: list[dict[str, str]], model: str, - tools: list[dict[str, Any]] | None = None, + tools: list[dict[str, object]] | None = None, tool_choice: str = "auto", temperature: float = 0.7, max_tokens: int = 2000, timeout: float | None = None, - cache: dict[str, Any] | None = None, - **kwargs: Any, + cache: dict[str, object] | None = None, + **kwargs: object, ): self.messages = messages self.model = model @@ -59,7 +58,7 @@ class LLMRequest: self.timeout = timeout self._extra = kwargs # U17 — LiteLLM cache 参数(cache_key 或 no-cache),透传到 litellm.acompletion - self._cache: dict[str, Any] | None = cache + self._cache: dict[str, object] | None = cache @dataclass diff --git a/src/agentkit/llm/providers/anthropic.py b/src/agentkit/llm/providers/anthropic.py index 2829ac9..3a1c197 100644 --- a/src/agentkit/llm/providers/anthropic.py +++ b/src/agentkit/llm/providers/anthropic.py @@ -3,7 +3,6 @@ import json import logging import time -from typing import Any import httpx @@ -99,7 +98,9 @@ class AnthropicProvider(LLMProvider): "content-type": "application/json", } - def _convert_messages(self, messages: list[dict[str, str]]) -> tuple[str | list[dict[str, Any]] | None, list[dict[str, Any]]]: + def _convert_messages( + self, messages: list[dict[str, str]] + ) -> tuple[str | list[dict[str, object]] | None, list[dict[str, object]]]: """将 OpenAI 风格消息转换为 Anthropic 格式 Returns: @@ -110,8 +111,8 @@ class AnthropicProvider(LLMProvider): - list[dict]: Anthropic content blocks(支持 cache_control,U2/G2) - None: 无 system 消息 """ - system_prompt: str | list[dict[str, Any]] | None = None - anthropic_messages: list[dict[str, Any]] = [] + system_prompt: str | list[dict[str, object]] | None = None + anthropic_messages: list[dict[str, object]] = [] for msg in messages: role = msg.get("role", "") @@ -127,7 +128,7 @@ class AnthropicProvider(LLMProvider): # 检查是否有 tool_calls (OpenAI 格式) tool_calls = msg.get("tool_calls") if tool_calls: - blocks: list[dict[str, Any]] = [] + blocks: list[dict[str, object]] = [] # 如果有文本内容,先添加文本块 if content: blocks.append({"type": "text", "text": content}) @@ -139,25 +140,29 @@ class AnthropicProvider(LLMProvider): arguments = json.loads(arguments) except json.JSONDecodeError: arguments = {"raw": arguments} - blocks.append({ - "type": "tool_use", - "id": tc.get("id", ""), - "name": func.get("name", ""), - "input": arguments, - }) + blocks.append( + { + "type": "tool_use", + "id": tc.get("id", ""), + "name": func.get("name", ""), + "input": arguments, + } + ) anthropic_messages.append({"role": "assistant", "content": blocks}) else: - anthropic_messages.append({ - "role": "assistant", - "content": [{"type": "text", "text": content}], - }) + anthropic_messages.append( + { + "role": "assistant", + "content": [{"type": "text", "text": content}], + } + ) continue if role == "user": # 检查是否是 tool_result 消息 (OpenAI 格式中 tool 角色的结果) # OpenAI 格式: {"role": "tool", "tool_call_id": "...", "content": "..."} if msg.get("tool_call_id"): - tool_result_blocks: list[dict[str, Any]] = [] + tool_result_blocks: list[dict[str, object]] = [] tool_content = msg.get("content", "") # tool_result 的 content 可以是字符串或内容块列表 if isinstance(tool_content, str): @@ -167,56 +172,72 @@ class AnthropicProvider(LLMProvider): else: tool_result_blocks.append({"type": "text", "text": str(tool_content)}) - anthropic_messages.append({ - "role": "user", - "content": [{ - "type": "tool_result", - "tool_use_id": msg.get("tool_call_id", ""), - "content": tool_result_blocks, - }], - }) + anthropic_messages.append( + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": msg.get("tool_call_id", ""), + "content": tool_result_blocks, + } + ], + } + ) else: - anthropic_messages.append({ - "role": "user", - "content": [{"type": "text", "text": content}], - }) + anthropic_messages.append( + { + "role": "user", + "content": [{"type": "text", "text": content}], + } + ) continue if role == "tool": # OpenAI 格式中独立的 tool 消息 tool_content = msg.get("content", "") if isinstance(tool_content, str): - result_content: list[dict[str, Any]] | str = [{"type": "text", "text": tool_content}] + result_content: list[dict[str, object]] | str = [ + {"type": "text", "text": tool_content} + ] elif isinstance(tool_content, list): result_content = tool_content else: result_content = [{"type": "text", "text": str(tool_content)}] - anthropic_messages.append({ - "role": "user", - "content": [{ - "type": "tool_result", - "tool_use_id": msg.get("tool_call_id", ""), - "content": result_content, - }], - }) + anthropic_messages.append( + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": msg.get("tool_call_id", ""), + "content": result_content, + } + ], + } + ) return system_prompt, anthropic_messages - def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: + def _convert_tools(self, tools: list[dict[str, object]]) -> list[dict[str, object]]: """将 OpenAI function 格式转换为 Anthropic tool 格式""" anthropic_tools = [] for tool in tools: if tool.get("type") == "function": func = tool.get("function", {}) - anthropic_tools.append({ - "name": func.get("name", ""), - "description": func.get("description", ""), - "input_schema": func.get("parameters", {"type": "object", "properties": {}}), - }) + anthropic_tools.append( + { + "name": func.get("name", ""), + "description": func.get("description", ""), + "input_schema": func.get( + "parameters", {"type": "object", "properties": {}} + ), + } + ) return anthropic_tools - def _convert_tool_choice(self, tool_choice: str) -> dict[str, Any] | None: + def _convert_tool_choice(self, tool_choice: str) -> dict[str, object] | None: """将 OpenAI tool_choice 格式转换为 Anthropic 格式""" if tool_choice == "auto": return {"type": "auto"} @@ -227,7 +248,7 @@ class AnthropicProvider(LLMProvider): return {"type": "tool", "name": tool_choice} return None - def _parse_response(self, data: dict[str, Any], model: str) -> LLMResponse: + def _parse_response(self, data: dict[str, object], model: str) -> LLMResponse: """将 Anthropic 响应转换为 LLMResponse""" content_blocks = data.get("content", []) text_parts: list[str] = [] @@ -238,11 +259,13 @@ class AnthropicProvider(LLMProvider): if block_type == "text": text_parts.append(block.get("text", "")) elif block_type == "tool_use": - tool_calls.append(ToolCall( - id=block.get("id", ""), - name=block.get("name", ""), - arguments=block.get("input", {}), - )) + tool_calls.append( + ToolCall( + id=block.get("id", ""), + name=block.get("name", ""), + arguments=block.get("input", {}), + ) + ) usage_data = data.get("usage", {}) usage = TokenUsage( @@ -287,7 +310,7 @@ class AnthropicProvider(LLMProvider): system_prompt, anthropic_messages = self._convert_messages(request.messages) - payload: dict[str, Any] = { + payload: dict[str, object] = { "model": request.model, "max_tokens": request.max_tokens or self._max_tokens, "messages": anthropic_messages, @@ -346,7 +369,7 @@ class AnthropicProvider(LLMProvider): system_prompt, anthropic_messages = self._convert_messages(request.messages) - payload: dict[str, Any] = { + payload: dict[str, object] = { "model": request.model, "max_tokens": request.max_tokens or self._max_tokens, "messages": anthropic_messages, @@ -375,7 +398,7 @@ class AnthropicProvider(LLMProvider): async def _iterate_stream(self, response, request: LLMRequest): """Iterate over an already-open SSE stream and yield StreamChunks.""" # Accumulated tool calls: tool_use_id -> {id, name, input_json_str} - accumulated_tool_calls: dict[str, dict[str, Any]] = {} + accumulated_tool_calls: dict[str, dict[str, object]] = {} current_tool_id: str | None = None current_tool_name: str | None = None current_tool_input_json: str = "" @@ -433,7 +456,9 @@ class AnthropicProvider(LLMProvider): # Finalize current tool call if any if current_tool_id is not None: try: - arguments = json.loads(current_tool_input_json) if current_tool_input_json else {} + arguments = ( + json.loads(current_tool_input_json) if current_tool_input_json else {} + ) except json.JSONDecodeError: arguments = {"raw": current_tool_input_json} @@ -510,7 +535,7 @@ class AnthropicProvider(LLMProvider): error_msg = error_info.get("message", "Stream error") raise LLMProviderError("anthropic", error_msg) - def get_model_info(self) -> dict[str, Any]: + def get_model_info(self) -> dict[str, object]: """返回 Provider 和模型信息""" return { "provider": "anthropic", diff --git a/src/agentkit/llm/providers/doubao.py b/src/agentkit/llm/providers/doubao.py index ebd7f9a..9a15994 100644 --- a/src/agentkit/llm/providers/doubao.py +++ b/src/agentkit/llm/providers/doubao.py @@ -8,7 +8,6 @@ API:火山引擎 OpenAI 兼容接口 from __future__ import annotations import logging -from typing import Any from agentkit.llm.providers.openai import OpenAICompatibleProvider @@ -48,7 +47,7 @@ class DoubaoProvider(OpenAICompatibleProvider): api_key: str, base_url: str = DOUBAO_DEFAULT_BASE_URL, default_model: str = "doubao-pro-32k", - **kwargs: Any, + **kwargs: object, ): super().__init__( api_key=api_key, diff --git a/src/agentkit/llm/providers/gemini.py b/src/agentkit/llm/providers/gemini.py index 0b57efe..f267c31 100644 --- a/src/agentkit/llm/providers/gemini.py +++ b/src/agentkit/llm/providers/gemini.py @@ -3,7 +3,6 @@ import json import logging import time -from typing import Any import httpx @@ -90,14 +89,14 @@ class GeminiProvider(LLMProvider): def _convert_messages( self, messages: list[dict[str, str]] - ) -> tuple[dict[str, Any] | None, list[dict[str, Any]]]: + ) -> tuple[dict[str, object] | None, list[dict[str, object]]]: """将 OpenAI 风格消息转换为 Gemini 格式 Returns: (system_instruction, contents) """ - system_instruction: dict[str, Any] | None = None - contents: list[dict[str, Any]] = [] + system_instruction: dict[str, object] | None = None + contents: list[dict[str, object]] = [] for msg in messages: role = msg.get("role", "") @@ -119,28 +118,34 @@ class GeminiProvider(LLMProvider): tool_name = parsed.get("name", "") except (json.JSONDecodeError, AttributeError): pass - contents.append({ - "role": "user", - "parts": [{ - "functionResponse": { - "name": tool_name, - "response": { - "content": content, - }, - }, - }], - }) + contents.append( + { + "role": "user", + "parts": [ + { + "functionResponse": { + "name": tool_name, + "response": { + "content": content, + }, + }, + } + ], + } + ) else: - contents.append({ - "role": "user", - "parts": [{"text": content}], - }) + contents.append( + { + "role": "user", + "parts": [{"text": content}], + } + ) continue if role == "assistant": tool_calls = msg.get("tool_calls") if tool_calls: - parts: list[dict[str, Any]] = [] + parts: list[dict[str, object]] = [] if content: parts.append({"text": content}) for tc in tool_calls: @@ -151,54 +156,64 @@ class GeminiProvider(LLMProvider): arguments = json.loads(arguments) except json.JSONDecodeError: arguments = {"raw": arguments} - parts.append({ - "functionCall": { - "name": func.get("name", ""), - "args": arguments, - }, - }) + parts.append( + { + "functionCall": { + "name": func.get("name", ""), + "args": arguments, + }, + } + ) contents.append({"role": "model", "parts": parts}) else: - contents.append({ - "role": "model", - "parts": [{"text": content}], - }) + contents.append( + { + "role": "model", + "parts": [{"text": content}], + } + ) continue if role == "tool": # OpenAI format: {"role": "tool", "tool_call_id": "...", "content": "..."} tool_name = msg.get("name", "") tool_content = msg.get("content", "") - contents.append({ - "role": "user", - "parts": [{ - "functionResponse": { - "name": tool_name, - "response": { - "content": tool_content, - }, - }, - }], - }) + contents.append( + { + "role": "user", + "parts": [ + { + "functionResponse": { + "name": tool_name, + "response": { + "content": tool_content, + }, + }, + } + ], + } + ) return system_instruction, contents - def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: + def _convert_tools(self, tools: list[dict[str, object]]) -> list[dict[str, object]]: """将 OpenAI function 格式转换为 Gemini functionDeclarations""" declarations = [] for tool in tools: if tool.get("type") == "function": func = tool.get("function", {}) - declarations.append({ - "name": func.get("name", ""), - "description": func.get("description", ""), - "parameters": func.get("parameters", {"type": "object", "properties": {}}), - }) + declarations.append( + { + "name": func.get("name", ""), + "description": func.get("description", ""), + "parameters": func.get("parameters", {"type": "object", "properties": {}}), + } + ) if not declarations: return [] return [{"functionDeclarations": declarations}] - def _convert_tool_choice(self, tool_choice: str) -> dict[str, Any] | None: + def _convert_tool_choice(self, tool_choice: str) -> dict[str, object] | None: """将 OpenAI tool_choice 格式转换为 Gemini toolConfig""" if tool_choice == "auto": return {"functionCallingConfig": {"mode": "AUTO"}} @@ -210,7 +225,7 @@ class GeminiProvider(LLMProvider): return {"functionCallingConfig": {"mode": "NONE"}} return None - def _parse_response(self, data: dict[str, Any], model: str) -> LLMResponse: + def _parse_response(self, data: dict[str, object], model: str) -> LLMResponse: """将 Gemini 响应转换为 LLMResponse""" candidates = data.get("candidates", []) text_parts: list[str] = [] @@ -225,11 +240,13 @@ class GeminiProvider(LLMProvider): text_parts.append(part["text"]) elif "functionCall" in part: fc = part["functionCall"] - tool_calls.append(ToolCall( - id=f"call_{tool_call_index}", - name=fc.get("name", ""), - arguments=fc.get("args", {}), - )) + tool_calls.append( + ToolCall( + id=f"call_{tool_call_index}", + name=fc.get("name", ""), + arguments=fc.get("args", {}), + ) + ) tool_call_index += 1 usage_metadata = data.get("usageMetadata", {}) @@ -275,7 +292,7 @@ class GeminiProvider(LLMProvider): system_instruction, contents = self._convert_messages(request.messages) - payload: dict[str, Any] = { + payload: dict[str, object] = { "contents": contents, "generationConfig": { "temperature": request.temperature, @@ -340,7 +357,7 @@ class GeminiProvider(LLMProvider): system_instruction, contents = self._convert_messages(request.messages) - payload: dict[str, Any] = { + payload: dict[str, object] = { "contents": contents, "generationConfig": { "temperature": request.temperature, @@ -374,7 +391,7 @@ class GeminiProvider(LLMProvider): async def _iterate_stream(self, response, request: LLMRequest): """Iterate over an already-open SSE stream and yield StreamChunks.""" - accumulated_tool_calls: list[dict[str, Any]] = [] + accumulated_tool_calls: list[dict[str, object]] = [] model = request.model or self._model async for line in response.aiter_lines(): @@ -436,11 +453,13 @@ class GeminiProvider(LLMProvider): ) elif "functionCall" in part: fc = part["functionCall"] - accumulated_tool_calls.append({ - "id": f"call_{len(accumulated_tool_calls)}", - "name": fc.get("name", ""), - "arguments": fc.get("args", {}), - }) + accumulated_tool_calls.append( + { + "id": f"call_{len(accumulated_tool_calls)}", + "name": fc.get("name", ""), + "arguments": fc.get("args", {}), + } + ) # Check for finish reason finish_reason = candidates[0].get("finishReason", "") @@ -461,7 +480,7 @@ class GeminiProvider(LLMProvider): ) accumulated_tool_calls = [] - def get_model_info(self) -> dict[str, Any]: + def get_model_info(self) -> dict[str, object]: """返回 Provider 和模型信息""" return { "provider": "gemini", diff --git a/src/agentkit/llm/providers/litellm_provider.py b/src/agentkit/llm/providers/litellm_provider.py index 45d144a..e25784f 100644 --- a/src/agentkit/llm/providers/litellm_provider.py +++ b/src/agentkit/llm/providers/litellm_provider.py @@ -26,8 +26,7 @@ import inspect import json import logging import time -from collections.abc import AsyncGenerator -from typing import Any +from collections.abc import AsyncGenerator, Iterable from agentkit.core.exceptions import LLMProviderError from agentkit.llm.protocol import ( @@ -81,13 +80,13 @@ class LitellmProvider(LLMProvider): api_key: str, base_url: str | None = None, provider_type: str = "openai", - **default_kwargs: Any, + **default_kwargs: object, ) -> None: self._model_prefix = model_prefix self._api_key = api_key self._base_url = base_url or None # 空字符串视作未设置 self._provider_type = provider_type - self._default_kwargs: dict[str, Any] = dict(default_kwargs) + self._default_kwargs: dict[str, object] = dict(default_kwargs) async def chat(self, request: LLMRequest) -> LLMResponse: """非流式 chat — 调用 ``litellm.acompletion`` 并翻译响应。""" @@ -116,7 +115,7 @@ class LitellmProvider(LLMProvider): kwargs = self._build_kwargs(request, stream=True) - accumulated_tool_calls: dict[int, dict[str, Any]] = {} + accumulated_tool_calls: dict[int, dict[str, object]] = {} final_usage: TokenUsage | None = None final_model: str = request.model @@ -158,9 +157,9 @@ class LitellmProvider(LLMProvider): # 内部辅助 # ------------------------------------------------------------------ - def _build_kwargs(self, request: LLMRequest, *, stream: bool) -> dict[str, Any]: + def _build_kwargs(self, request: LLMRequest, *, stream: bool) -> dict[str, object]: """从 LLMRequest 构造 litellm.acompletion kwargs。""" - kwargs: dict[str, Any] = { + kwargs: dict[str, object] = { "model": f"{self._model_prefix}{request.model}", "messages": request.messages, "temperature": request.temperature, @@ -187,7 +186,7 @@ class LitellmProvider(LLMProvider): def _parse_response( self, - response: Any, + response: object, request_model: str, latency_ms: float, ) -> LLMResponse: @@ -229,9 +228,9 @@ class LitellmProvider(LLMProvider): def _parse_stream_chunk( self, - chunk: Any, + chunk: object, request_model: str, - accumulated_tool_calls: dict[int, dict[str, Any]], + accumulated_tool_calls: dict[int, dict[str, object]], ) -> StreamChunk: """解析单个流式 chunk(非 final)。累计 tool_calls 到传入字典。""" choices = getattr(chunk, "choices", None) or [] @@ -262,7 +261,7 @@ class LitellmProvider(LLMProvider): def _finalize_tool_calls( self, - accumulated: dict[int, dict[str, Any]], + accumulated: dict[int, dict[str, object]], ) -> list[ToolCall]: """把累计的流式 tool_calls 字典转成 ToolCall 列表。""" tool_calls: list[ToolCall] = [] @@ -288,7 +287,7 @@ class LitellmProvider(LLMProvider): # ---------------------------------------------------------------------- -def _parse_tool_calls(raw_tool_calls: Any) -> list[ToolCall]: +def _parse_tool_calls(raw_tool_calls: Iterable[object]) -> list[ToolCall]: """解析非流式响应的 tool_calls(OpenAI 格式 list[ChoiceMessageToolCall])。""" result: list[ToolCall] = [] for tc in raw_tool_calls: @@ -312,7 +311,7 @@ def _parse_tool_calls(raw_tool_calls: Any) -> list[ToolCall]: return result -def _parse_usage(usage_obj: Any) -> TokenUsage: +def _parse_usage(usage_obj: object) -> TokenUsage: """解析 usage 对象(OpenAI CompletionUsage 或 dict)。""" prompt = getattr(usage_obj, "prompt_tokens", None) completion = getattr(usage_obj, "completion_tokens", None) @@ -326,8 +325,8 @@ def _parse_usage(usage_obj: Any) -> TokenUsage: def _accumulate_stream_tool_calls( - raw_tool_calls: Any, - accumulated: dict[int, dict[str, Any]], + raw_tool_calls: Iterable[object], + accumulated: dict[int, dict[str, object]], ) -> None: """累计流式 chunk 里的 tool_calls 片段(OpenAI delta.tool_calls 格式)。 @@ -364,7 +363,7 @@ def create_litellm_provider( provider_type: str, api_key: str, base_url: str | None = None, - **kwargs: Any, + **kwargs: object, ) -> LitellmProvider: """根据 provider_type 创建 LitellmProvider 实例。 diff --git a/src/agentkit/llm/providers/usage_store.py b/src/agentkit/llm/providers/usage_store.py index f7cf77d..3bd634d 100644 --- a/src/agentkit/llm/providers/usage_store.py +++ b/src/agentkit/llm/providers/usage_store.py @@ -18,7 +18,7 @@ import json import logging from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone -from typing import Any, Protocol, runtime_checkable +from typing import Protocol, runtime_checkable from agentkit.llm.protocol import TokenUsage @@ -294,8 +294,8 @@ class RedisUsageStore: def __init__(self, redis_url: str = "redis://localhost:6379"): self._redis_url = redis_url - self._redis: Any = None - self._sync_redis: Any = None + self._redis: object | None = None + self._sync_redis: object | None = None self._fallback: InMemoryUsageStore | None = None self._degraded = False self._health_check_task: asyncio.Task[None] | None = None @@ -687,7 +687,7 @@ class RedisUsageStore: @staticmethod def _read_list( - r: Any, + r: object, list_key: str, start: datetime, end: datetime, diff --git a/src/agentkit/llm/providers/wenxin.py b/src/agentkit/llm/providers/wenxin.py index ae6ee8e..bdb173a 100644 --- a/src/agentkit/llm/providers/wenxin.py +++ b/src/agentkit/llm/providers/wenxin.py @@ -9,7 +9,6 @@ from __future__ import annotations import logging import time -from typing import Any from agentkit.llm.providers.openai import OpenAICompatibleProvider from agentkit.llm.protocol import LLMRequest, LLMResponse @@ -51,7 +50,7 @@ class WenxinProvider(OpenAICompatibleProvider): secret_key: str | None = None, base_url: str = WENXIN_DEFAULT_BASE_URL, default_model: str = "ernie-4.5-turbo-128k", - **kwargs: Any, + **kwargs: object, ): # If AK/SK provided, use token-based auth self._access_key = access_key diff --git a/src/agentkit/llm/providers/yuanbao.py b/src/agentkit/llm/providers/yuanbao.py index a055c36..2b61241 100644 --- a/src/agentkit/llm/providers/yuanbao.py +++ b/src/agentkit/llm/providers/yuanbao.py @@ -8,7 +8,6 @@ API:腾讯云 OpenAI 兼容接口 from __future__ import annotations import logging -from typing import Any from agentkit.llm.providers.openai import OpenAICompatibleProvider from agentkit.llm.protocol import LLMRequest, LLMResponse @@ -48,7 +47,7 @@ class YuanbaoProvider(OpenAICompatibleProvider): base_url: str = YUANBAO_DEFAULT_BASE_URL, default_model: str = "hunyuan-turbos-latest", enable_enhancement: bool = False, - **kwargs: Any, + **kwargs: object, ): self._enable_enhancement = enable_enhancement super().__init__( diff --git a/src/agentkit/llm/remote_provider.py b/src/agentkit/llm/remote_provider.py index 1ec54df..706336a 100644 --- a/src/agentkit/llm/remote_provider.py +++ b/src/agentkit/llm/remote_provider.py @@ -3,7 +3,6 @@ import json import logging from collections.abc import AsyncIterator, Awaitable, Callable -from typing import Any import httpx @@ -66,7 +65,7 @@ class RemoteLLMProvider(LLMProvider): "Content-Type": "application/json", } - def _build_payload(self, request: LLMRequest) -> dict[str, Any]: + def _build_payload(self, request: LLMRequest) -> dict[str, object]: """Convert LLMRequest to server API payload.""" return { "messages": request.messages, @@ -91,7 +90,7 @@ class RemoteLLMProvider(LLMProvider): return str(body["error"]) return str(body) - def _parse_response(self, data: dict[str, Any], request: LLMRequest) -> LLMResponse: + def _parse_response(self, data: dict[str, object], request: LLMRequest) -> LLMResponse: """Parse server response JSON into an LLMResponse.""" usage_data = data.get("usage") or {} usage = TokenUsage( @@ -115,7 +114,7 @@ class RemoteLLMProvider(LLMProvider): latency_ms=data.get("latency_ms", 0.0), ) - def _parse_chunk(self, data: dict[str, Any], request: LLMRequest) -> StreamChunk: + def _parse_chunk(self, data: dict[str, object], request: LLMRequest) -> StreamChunk: """Parse a single SSE data payload into a StreamChunk.""" usage: TokenUsage | None = None usage_data = data.get("usage") @@ -218,9 +217,7 @@ class RemoteLLMProvider(LLMProvider): if response.status_code == 502: await response.aread() detail = self._extract_error_detail(response) - raise LLMProviderError( - "remote", f"Server LLM gateway error: {detail}" - ) + raise LLMProviderError("remote", f"Server LLM gateway error: {detail}") if response.status_code != 200: await response.aread() raise LLMProviderError( diff --git a/src/agentkit/llm/retry.py b/src/agentkit/llm/retry.py index cc2990f..9659856 100644 --- a/src/agentkit/llm/retry.py +++ b/src/agentkit/llm/retry.py @@ -5,7 +5,7 @@ import logging import time from dataclasses import dataclass, field from enum import Enum -from typing import Any, Callable +from typing import Callable from agentkit.core.exceptions import LLMProviderError @@ -20,9 +20,7 @@ class RetryConfig: base_delay: float = 1.0 max_delay: float = 30.0 exponential_base: float = 2.0 - retryable_status_codes: set[int] = field( - default_factory=lambda: {429, 500, 502, 503, 529} - ) + retryable_status_codes: set[int] = field(default_factory=lambda: {429, 500, 502, 503, 529}) class CircuitState(Enum): @@ -69,7 +67,7 @@ class RetryPolicy: def __init__(self, config: RetryConfig | None = None): self._config = config or RetryConfig() - async def execute(self, fn: Callable, *args: Any, **kwargs: Any) -> Any: + async def execute(self, fn: Callable, *args: object, **kwargs: object) -> object: """Execute fn with retry on retryable errors.""" last_error: Exception | None = None @@ -84,7 +82,7 @@ class RetryPolicy: raise delay = min( - self._config.base_delay * (self._config.exponential_base ** attempt), + self._config.base_delay * (self._config.exponential_base**attempt), self._config.max_delay, ) logger.warning( @@ -142,7 +140,7 @@ class CircuitBreaker: f"after {self._failure_count} failures" ) - async def execute(self, fn: Callable, *args: Any, **kwargs: Any) -> Any: + async def execute(self, fn: Callable, *args: object, **kwargs: object) -> object: """Execute fn through the circuit breaker.""" current_state = self.state @@ -158,6 +156,6 @@ class CircuitBreaker: result = await fn(*args, **kwargs) self._on_success() return result - except Exception as e: + except Exception: self._on_failure() raise diff --git a/src/agentkit/memory/adapters/base.py b/src/agentkit/memory/adapters/base.py index eddedec..2c2e230 100644 --- a/src/agentkit/memory/adapters/base.py +++ b/src/agentkit/memory/adapters/base.py @@ -10,7 +10,6 @@ from __future__ import annotations import logging from abc import ABC, abstractmethod -from typing import Any import httpx @@ -95,8 +94,7 @@ class KBAdapter(ABC): async def delete_by_id(self, id: str) -> bool: """按文档 ID 删除(子类可覆盖)""" logger.warning( - f"{self.__class__.__name__} does not support delete_by_id; " - f"id '{id}' skipped" + f"{self.__class__.__name__} does not support delete_by_id; id '{id}' skipped" ) return False @@ -127,8 +125,7 @@ class KBAdapter(ABC): async def get_document(self, doc_id: str) -> Document | None: """按 ID 获取单个文档(子类可覆盖)""" logger.warning( - f"{self.__class__.__name__} does not support get_document; " - f"doc_id '{doc_id}' not found" + f"{self.__class__.__name__} does not support get_document; doc_id '{doc_id}' not found" ) return None @@ -156,5 +153,5 @@ class KBAdapter(ABC): async def __aenter__(self) -> KBAdapter: return self - async def __aexit__(self, *args: Any) -> None: + async def __aexit__(self, *args: object) -> None: await self.close() diff --git a/src/agentkit/memory/adapters/confluence.py b/src/agentkit/memory/adapters/confluence.py index 1478fb0..1aeb763 100644 --- a/src/agentkit/memory/adapters/confluence.py +++ b/src/agentkit/memory/adapters/confluence.py @@ -7,7 +7,6 @@ from __future__ import annotations import logging -from typing import Any import httpx @@ -56,7 +55,9 @@ class ConfluenceAdapter(KBAdapter): ) self._base_url = base_url.rstrip("/") if not is_safe_url(self._base_url): - raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.") + raise ValueError( + f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed." + ) self._username = username self._api_token = api_token self._space_keys = space_keys or [] @@ -65,9 +66,7 @@ class ConfluenceAdapter(KBAdapter): """创建 Confluence API HTTP 客户端""" import base64 - credentials = base64.b64encode( - f"{self._username}:{self._api_token}".encode() - ).decode() + credentials = base64.b64encode(f"{self._username}:{self._api_token}".encode()).decode() return httpx.AsyncClient( base_url=self._base_url, headers={ @@ -101,7 +100,7 @@ class ConfluenceAdapter(KBAdapter): space_filter = " OR ".join( f'space = "{_escape_cql(key)}"' for key in self._space_keys ) - cql = f'{cql} AND ({space_filter})' + cql = f"{cql} AND ({space_filter})" resp = await client.get( "/rest/api/content/search", @@ -115,6 +114,7 @@ class ConfluenceAdapter(KBAdapter): body = page.get("body", {}).get("storage", {}).get("value", "") # Strip HTML tags for plain text content import re + content = re.sub(r"<[^>]+>", "", body) if body else page.get("title", "") results.append( @@ -136,8 +136,7 @@ class ConfluenceAdapter(KBAdapter): except httpx.HTTPStatusError as e: logger.error( - f"Confluence search HTTP error: {e.response.status_code} — " - f"{e.response.text[:200]}" + f"Confluence search HTTP error: {e.response.status_code} — {e.response.text[:200]}" ) return [] except Exception as e: @@ -157,6 +156,7 @@ class ConfluenceAdapter(KBAdapter): body = page.get("body", {}).get("storage", {}).get("value", "") import re + content = re.sub(r"<[^>]+>", "", body) if body else "" return Document( @@ -191,13 +191,17 @@ class ConfluenceAdapter(KBAdapter): source_type="confluence", ) ) - return sources if sources else [ - SourceInfo( - source_id=self._source_id, - source_name=self._source_name, - source_type=self._source_type, - ) - ] + return ( + sources + if sources + else [ + SourceInfo( + source_id=self._source_id, + source_name=self._source_name, + source_type=self._source_type, + ) + ] + ) except Exception as e: logger.error(f"Confluence list_sources error: {e}") return [ diff --git a/src/agentkit/memory/adapters/feishu.py b/src/agentkit/memory/adapters/feishu.py index e28dec6..69ef235 100644 --- a/src/agentkit/memory/adapters/feishu.py +++ b/src/agentkit/memory/adapters/feishu.py @@ -8,7 +8,7 @@ from __future__ import annotations import logging import time -from typing import Any +from typing import TypeAlias import httpx @@ -18,6 +18,9 @@ from agentkit.utils.security import is_safe_url logger = logging.getLogger(__name__) +# 飞书搜索请求 payload:search_key/page_size/wiki_space_ids — 值为 str|int|list[str]。 +FeishuSearchPayload: TypeAlias = dict[str, object] + class FeishuKBAdapter(KBAdapter): """飞书知识库适配器 @@ -54,7 +57,9 @@ class FeishuKBAdapter(KBAdapter): self._app_secret = app_secret self._base_url = base_url.rstrip("/") if not is_safe_url(self._base_url): - raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.") + raise ValueError( + f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed." + ) self._space_ids = space_ids or [] self._access_token: str | None = None self._token_expiry: float = 0.0 @@ -94,10 +99,7 @@ class FeishuKBAdapter(KBAdapter): self._client = None return self._access_token else: - logger.error( - f"Feishu auth failed: code={data.get('code')}, " - f"msg={data.get('msg')}" - ) + logger.error(f"Feishu auth failed: code={data.get('code')}, msg={data.get('msg')}") return None except Exception as e: logger.error(f"Feishu auth error: {e}") @@ -121,7 +123,7 @@ class FeishuKBAdapter(KBAdapter): client = self._get_client() try: - payload: dict[str, Any] = { + payload: FeishuSearchPayload = { "search_key": query, "page_size": top_k, } @@ -137,8 +139,7 @@ class FeishuKBAdapter(KBAdapter): if data.get("code") != 0: logger.error( - f"Feishu search failed: code={data.get('code')}, " - f"msg={data.get('msg')}" + f"Feishu search failed: code={data.get('code')}, msg={data.get('msg')}" ) return [] @@ -162,8 +163,7 @@ class FeishuKBAdapter(KBAdapter): except httpx.HTTPStatusError as e: logger.error( - f"Feishu search HTTP error: {e.response.status_code} — " - f"{e.response.text[:200]}" + f"Feishu search HTTP error: {e.response.status_code} — {e.response.text[:200]}" ) return [] except Exception as e: @@ -179,7 +179,7 @@ class FeishuKBAdapter(KBAdapter): client = self._get_client() try: resp = await client.get( - f"/wiki/v2/spaces/get_node", + "/wiki/v2/spaces/get_node", params={"token": doc_id}, ) resp.raise_for_status() @@ -230,13 +230,17 @@ class FeishuKBAdapter(KBAdapter): source_type="feishu", ) ) - return sources if sources else [ - SourceInfo( - source_id=self._source_id, - source_name=self._source_name, - source_type=self._source_type, - ) - ] + return ( + sources + if sources + else [ + SourceInfo( + source_id=self._source_id, + source_name=self._source_name, + source_type=self._source_type, + ) + ] + ) except Exception as e: logger.error(f"Feishu list_sources error: {e}") return [ diff --git a/src/agentkit/memory/adapters/generic_http.py b/src/agentkit/memory/adapters/generic_http.py index 38e8b9c..41dd26b 100644 --- a/src/agentkit/memory/adapters/generic_http.py +++ b/src/agentkit/memory/adapters/generic_http.py @@ -7,7 +7,6 @@ from __future__ import annotations import logging -from typing import Any import httpx @@ -55,7 +54,9 @@ class GenericHTTPAdapter(KBAdapter): ) self._endpoint_url = endpoint_url.rstrip("/") if not is_safe_url(self._endpoint_url): - raise ValueError(f"Unsafe endpoint_url: {self._endpoint_url}. Private/internal URLs are not allowed.") + raise ValueError( + f"Unsafe endpoint_url: {self._endpoint_url}. Private/internal URLs are not allowed." + ) self._auth_config = auth_config or {} self._extra_headers = headers or {} @@ -74,12 +75,11 @@ class GenericHTTPAdapter(KBAdapter): headers["Authorization"] = f"Bearer {token}" elif auth_type == "basic": import base64 + username = self._auth_config.get("username", "") password = self._auth_config.get("password", "") if username and password: - credentials = base64.b64encode( - f"{username}:{password}".encode() - ).decode() + credentials = base64.b64encode(f"{username}:{password}".encode()).decode() headers["Authorization"] = f"Basic {credentials}" elif auth_type == "api_key": key_name = self._auth_config.get("header_name", "X-API-Key") @@ -135,8 +135,7 @@ class GenericHTTPAdapter(KBAdapter): except httpx.HTTPStatusError as e: logger.error( - f"GenericHTTP search HTTP error: {e.response.status_code} — " - f"{e.response.text[:200]}" + f"GenericHTTP search HTTP error: {e.response.status_code} — {e.response.text[:200]}" ) return [] except Exception as e: @@ -177,8 +176,7 @@ class GenericHTTPAdapter(KBAdapter): except httpx.HTTPStatusError as e: logger.error( - f"GenericHTTP ingest HTTP error: {e.response.status_code} — " - f"{e.response.text[:200]}" + f"GenericHTTP ingest HTTP error: {e.response.status_code} — {e.response.text[:200]}" ) return [] except Exception as e: @@ -245,13 +243,17 @@ class GenericHTTPAdapter(KBAdapter): document_count=item.get("document_count", 0), ) ) - return sources if sources else [ - SourceInfo( - source_id=self._source_id, - source_name=self._source_name, - source_type=self._source_type, - ) - ] + return ( + sources + if sources + else [ + SourceInfo( + source_id=self._source_id, + source_name=self._source_name, + source_type=self._source_type, + ) + ] + ) except Exception as e: logger.debug(f"GenericHTTP list_sources error (endpoint may not exist): {e}") diff --git a/src/agentkit/memory/base.py b/src/agentkit/memory/base.py index 930a933..f3b9ae1 100644 --- a/src/agentkit/memory/base.py +++ b/src/agentkit/memory/base.py @@ -3,19 +3,29 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from datetime import datetime, timezone -from typing import Any +from typing import TypeAlias + +# 共享类型别名 — 跨 memory 子系统复用,避免 `Any` 残留。 +# MetadataValue 覆盖 metadata dict 中实际出现的原始类型; +# MemoryValue 额外允许 dict/list 容器以容纳结构化负载(如 episodic 经验字典)。 +MetadataValue: TypeAlias = str | int | float | bool | None +MetadataDict: TypeAlias = dict[str, MetadataValue] +MemoryValue: TypeAlias = ( + str | int | float | bool | None | dict[str, MetadataValue] | list[MetadataValue] +) @dataclass class MemoryItem: """记忆条目""" + key: str - value: Any - metadata: dict[str, Any] = field(default_factory=dict) + value: object + metadata: MetadataDict = field(default_factory=dict) score: float = 1.0 created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - def to_dict(self) -> dict: + def to_dict(self) -> dict[str, object]: return { "key": self.key, "value": self.value, @@ -35,7 +45,7 @@ class Memory(ABC): """ @abstractmethod - async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None: + async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None: """存储记忆""" ... @@ -45,7 +55,9 @@ class Memory(ABC): ... @abstractmethod - async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]: + async def search( + self, query: str, top_k: int = 5, filters: MetadataDict | None = None + ) -> list[MemoryItem]: """语义检索""" ... @@ -54,7 +66,7 @@ class Memory(ABC): """删除记忆""" ... - async def store_batch(self, items: list[tuple[str, Any, dict | None]]) -> None: + async def store_batch(self, items: list[tuple[str, object, MetadataDict | None]]) -> None: """批量存储""" for key, value, metadata in items: await self.store(key, value, metadata) diff --git a/src/agentkit/memory/chunking.py b/src/agentkit/memory/chunking.py index ad0dbe4..67d0bd5 100644 --- a/src/agentkit/memory/chunking.py +++ b/src/agentkit/memory/chunking.py @@ -11,10 +11,18 @@ import logging import re import uuid from dataclasses import dataclass, field -from typing import Any +from typing import TypeAlias + +from agentkit.memory.base import MetadataDict logger = logging.getLogger(__name__) +# 分块元数据:source_doc/position/char_count/chunking_strategy/heading/heading_level +# — 全部为原始标量(str/int)。 +ChunkMetadata: TypeAlias = MetadataDict +# _split_by_headings 返回的节段结构。 +SectionInfo: TypeAlias = dict[str, str | int] + @dataclass class Chunk: @@ -22,7 +30,7 @@ class Chunk: chunk_id: str content: str - metadata: dict[str, Any] = field(default_factory=dict) + metadata: ChunkMetadata = field(default_factory=dict) def __post_init__(self) -> None: if "source_doc" not in self.metadata: @@ -30,7 +38,7 @@ class Chunk: if "position" not in self.metadata: self.metadata["position"] = 0 - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> dict[str, object]: return { "chunk_id": self.chunk_id, "content": self.content, @@ -57,7 +65,9 @@ class TextChunker: separator: 优先分割符 """ if chunk_overlap >= chunk_size: - raise ValueError(f"chunk_overlap ({chunk_overlap}) must be less than chunk_size ({chunk_size})") + raise ValueError( + f"chunk_overlap ({chunk_overlap}) must be less than chunk_size ({chunk_size})" + ) self._chunk_size = chunk_size self._chunk_overlap = chunk_overlap self._separator = separator @@ -66,7 +76,7 @@ class TextChunker: self, text: str, source_doc_id: str = "", - metadata: dict[str, Any] | None = None, + metadata: ChunkMetadata | None = None, ) -> list[Chunk]: """将文本分块 @@ -96,11 +106,13 @@ class TextChunker: chunk_meta = dict(base_meta) chunk_meta["position"] = i chunk_meta["char_count"] = len(chunk_text) - chunks.append(Chunk( - chunk_id=str(uuid.uuid4()), - content=chunk_text, - metadata=chunk_meta, - )) + chunks.append( + Chunk( + chunk_id=str(uuid.uuid4()), + content=chunk_text, + metadata=chunk_meta, + ) + ) return chunks @@ -142,7 +154,9 @@ class TextChunker: overlap_text[overlap_start:], segments ) current = overlap_segments - current_len = sum(len(s) for s in current) + len(self._separator) * max(0, len(current) - 1) + current_len = sum(len(s) for s in current) + len(self._separator) * max( + 0, len(current) - 1 + ) current.append(segment) current_len += seg_len + len(self._separator) @@ -214,7 +228,7 @@ class StructuralChunker: self, text: str, source_doc_id: str = "", - metadata: dict[str, Any] | None = None, + metadata: ChunkMetadata | None = None, ) -> list[Chunk]: """将文本按结构分块 @@ -266,23 +280,25 @@ class StructuralChunker: chunk_meta["heading"] = heading chunk_meta["heading_level"] = level chunk_meta["char_count"] = len(content) - chunks.append(Chunk( - chunk_id=str(uuid.uuid4()), - content=content, - metadata=chunk_meta, - )) + chunks.append( + Chunk( + chunk_id=str(uuid.uuid4()), + content=content, + metadata=chunk_meta, + ) + ) position += 1 return chunks - def _split_by_headings(self, text: str) -> list[dict[str, Any]]: + def _split_by_headings(self, text: str) -> list[SectionInfo]: """按标题分割 Markdown 文本 Returns: 列表,每项包含 heading, content, level """ lines = text.split("\n") - sections: list[dict[str, Any]] = [] + sections: list[SectionInfo] = [] current_heading = "" current_level = 0 current_lines: list[str] = [] @@ -296,11 +312,13 @@ class StructuralChunker: if current_lines: content = "\n".join(current_lines).strip() if content: - sections.append({ - "heading": current_heading, - "content": content, - "level": current_level, - }) + sections.append( + { + "heading": current_heading, + "content": content, + "level": current_level, + } + ) # 开始新节 current_heading = match.group(2).strip() @@ -313,18 +331,22 @@ class StructuralChunker: if current_lines: content = "\n".join(current_lines).strip() if content: - sections.append({ - "heading": current_heading, - "content": content, - "level": current_level, - }) + sections.append( + { + "heading": current_heading, + "content": content, + "level": current_level, + } + ) # 如果没有标题结构,整体作为一个块 if not sections: - sections.append({ - "heading": "", - "content": text.strip(), - "level": 0, - }) + sections.append( + { + "heading": "", + "content": text.strip(), + "level": 0, + } + ) return sections diff --git a/src/agentkit/memory/contextual_retrieval.py b/src/agentkit/memory/contextual_retrieval.py index 93eb47f..aad0388 100644 --- a/src/agentkit/memory/contextual_retrieval.py +++ b/src/agentkit/memory/contextual_retrieval.py @@ -9,10 +9,14 @@ from __future__ import annotations import hashlib import logging from dataclasses import dataclass -from typing import Any +from typing import TYPE_CHECKING +from agentkit.memory.base import MetadataDict from agentkit.memory.embedder import EmbeddingCache +if TYPE_CHECKING: + from agentkit.llm.gateway import LLMGateway + logger = logging.getLogger(__name__) @@ -24,7 +28,7 @@ class ContextualChunk: context_prefix: str enhanced_content: str chunk_index: int - metadata: dict[str, Any] + metadata: MetadataDict @property def content(self) -> str: @@ -65,7 +69,7 @@ class ContextualChunker: def __init__( self, - llm_gateway: Any = None, + llm_gateway: LLMGateway | None = None, cache: EmbeddingCache | None = None, batch_size: int = 8, max_context_length: int = 200, @@ -90,7 +94,7 @@ class ContextualChunker: self, document: str, chunks: list[str], - metadata: dict[str, Any] | None = None, + metadata: MetadataDict | None = None, ) -> list[ContextualChunk]: """为文档块添加上下文前缀 @@ -134,7 +138,7 @@ class ContextualChunker: document: str, chunks: list[str], start_index: int, - metadata: dict[str, Any] | None, + metadata: MetadataDict | None, ) -> list[ContextualChunk]: """处理一批文档块""" results: list[ContextualChunk] = [] diff --git a/src/agentkit/memory/document_loader.py b/src/agentkit/memory/document_loader.py index 522bd53..d991221 100644 --- a/src/agentkit/memory/document_loader.py +++ b/src/agentkit/memory/document_loader.py @@ -12,7 +12,9 @@ import uuid from dataclasses import dataclass, field from datetime import datetime, timezone from pathlib import Path -from typing import Any +from typing import TypeAlias + +from agentkit.memory.base import MetadataDict logger = logging.getLogger(__name__) @@ -23,6 +25,11 @@ MAX_CONTENT_SIZE = 100 * 1024 * 1024 # 100MB MAX_ROWS_PER_SHEET = 10_000 MAX_CELL_CHARS = 10_000 +# 文档元数据:source/format/parser/page_count/table_count/sheet_count/row_count/ +# heading_count/created_at/title/truncated — 全部为原始标量。 +DocumentMetadata: TypeAlias = MetadataDict +ParseResult: TypeAlias = tuple[str, DocumentMetadata] + @dataclass class Document: @@ -31,7 +38,7 @@ class Document: doc_id: str title: str content: str - metadata: dict[str, Any] = field(default_factory=dict) + metadata: DocumentMetadata = field(default_factory=dict) def __post_init__(self) -> None: if "source" not in self.metadata: @@ -43,7 +50,7 @@ class Document: if "created_at" not in self.metadata: self.metadata["created_at"] = datetime.now(timezone.utc).isoformat() - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> dict[str, object]: return { "doc_id": self.doc_id, "title": self.title, @@ -136,12 +143,14 @@ class DocumentLoader: parser = parsers.get(doc_format) if parser is None: - logger.warning(f"Unsupported format '{doc_format}' for {filename}, falling back to text") + logger.warning( + f"Unsupported format '{doc_format}' for {filename}, falling back to text" + ) parser = self._parse_text text, extra_meta = parser(content, filename) - metadata: dict[str, Any] = { + metadata: DocumentMetadata = { "source": filename, "format": doc_format, "created_at": datetime.now(timezone.utc).isoformat(), @@ -159,7 +168,7 @@ class DocumentLoader: metadata=metadata, ) - def _parse_pdf(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]: + def _parse_pdf(self, content: bytes, filename: str) -> ParseResult: """解析 PDF 文件 优先使用 PyMuPDF (fitz),回退到 pdfplumber,最终回退到纯文本。 @@ -215,7 +224,7 @@ class DocumentLoader: logger.warning(f"No PDF parser available for {filename}, falling back to text extraction") return self._parse_text(content, filename) - def _parse_docx(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]: + def _parse_docx(self, content: bytes, filename: str) -> ParseResult: """解析 Word 文件 使用 python-docx,回退到纯文本。 @@ -259,7 +268,7 @@ class DocumentLoader: logger.warning(f"python-docx parsing failed for {filename}: {e}") return self._parse_text(content, filename) - def _parse_xlsx(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]: + def _parse_xlsx(self, content: bytes, filename: str) -> ParseResult: """解析 Excel 文件 使用 openpyxl,回退到纯文本。每个 sheet 转为 Markdown 表格, @@ -313,7 +322,7 @@ class DocumentLoader: finally: wb.close() text = "\n".join(sections).strip() - meta: dict[str, Any] = { + meta: DocumentMetadata = { "parser": "openpyxl", "sheet_count": sheet_count, "row_count": total_rows, @@ -328,7 +337,7 @@ class DocumentLoader: logger.warning(f"openpyxl parsing failed for {filename}: {e}") return self._parse_text(content, filename) - def _parse_markdown(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]: + def _parse_markdown(self, content: bytes, filename: str) -> ParseResult: """解析 Markdown 文件 使用 mistune(如果可用),否则直接读取文本。 @@ -347,7 +356,7 @@ class DocumentLoader: title = line_stripped.lstrip("#").strip() break - meta: dict[str, Any] = { + meta: DocumentMetadata = { "parser": "markdown", } if title: @@ -362,7 +371,7 @@ class DocumentLoader: return text, meta - def _parse_html(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]: + def _parse_html(self, content: bytes, filename: str) -> ParseResult: """解析 HTML 文件 使用 BeautifulSoup 提取文本,回退到纯文本。 @@ -388,7 +397,7 @@ class DocumentLoader: if soup.title and soup.title.string: title = soup.title.string.strip() - meta: dict[str, Any] = { + meta: DocumentMetadata = { "parser": "beautifulsoup", } if title: @@ -402,7 +411,7 @@ class DocumentLoader: logger.warning(f"BeautifulSoup parsing failed for {filename}: {e}") return self._parse_text(content, filename) - def _parse_text(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]: + def _parse_text(self, content: bytes, filename: str) -> ParseResult: """解析纯文本文件""" try: text = content.decode("utf-8") diff --git a/src/agentkit/memory/embedder.py b/src/agentkit/memory/embedder.py index 203ee69..6c47c39 100644 --- a/src/agentkit/memory/embedder.py +++ b/src/agentkit/memory/embedder.py @@ -1,12 +1,17 @@ """Embedder 接口与实现 - 文本向量化""" +from __future__ import annotations + import hashlib import logging import os import time from abc import ABC, abstractmethod from collections import OrderedDict -from typing import Any +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import httpx logger = logging.getLogger(__name__) @@ -97,13 +102,14 @@ class OpenAIEmbedder(Embedder): self._model = model self._base_url = base_url self._dimension = 1536 # text-embedding-3-small 默认维度 - self._client: Any = None + self._client: httpx.AsyncClient | None = None self._cache = cache - def _get_client(self): + def _get_client(self) -> httpx.AsyncClient: """Lazily create and reuse a single httpx.AsyncClient.""" if self._client is None: import httpx + self._client = httpx.AsyncClient(timeout=30.0) return self._client diff --git a/src/agentkit/memory/episodic.py b/src/agentkit/memory/episodic.py index f6817e4..ba8a4c5 100644 --- a/src/agentkit/memory/episodic.py +++ b/src/agentkit/memory/episodic.py @@ -1,17 +1,22 @@ """Episodic Memory - 基于 pgvector + PostgreSQL 的任务经验记忆""" +from __future__ import annotations + import json import logging import math from datetime import datetime, timezone -from typing import Any +from typing import TYPE_CHECKING from sqlalchemy import text -from agentkit.memory.base import Memory, MemoryItem +from agentkit.memory.base import Memory, MemoryItem, MetadataDict from agentkit.memory.embedder import Embedder from agentkit.utils.vector_math import compute_cosine_similarity +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession + logger = logging.getLogger(__name__) @@ -28,8 +33,8 @@ class EpisodicMemory(Memory): def __init__( self, - session_factory: Any, - episodic_model: Any, + session_factory: object, + episodic_model: object, embedder: Embedder | None = None, decay_rate: float = 0.01, alpha: float = 0.7, @@ -57,7 +62,7 @@ class EpisodicMemory(Memory): self._pgvector_enabled = pgvector_enabled self._table_name = table_name - async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None: + async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None: """存储任务经验""" async with self._session_factory() as db: try: @@ -68,7 +73,11 @@ class EpisodicMemory(Memory): embedding = None if self._embedder: if isinstance(value, dict): - text = value.get("output_summary", "") or value.get("input_summary", "") or json.dumps(value, ensure_ascii=False)[:500] + text = ( + value.get("output_summary", "") + or value.get("input_summary", "") + or json.dumps(value, ensure_ascii=False)[:500] + ) else: text = str(value) embedding = await self._embedder.embed(text) @@ -106,13 +115,11 @@ class EpisodicMemory(Memory): logger.error(f"Failed to retrieve episodic memory: {e}") return None - async def _retrieve_pgvector(self, db: Any, query_embedding: list[float]) -> MemoryItem | None: + async def _retrieve_pgvector( + self, db: AsyncSession, query_embedding: list[float] + ) -> MemoryItem | None: """使用 pgvector ``<=>`` 算符检索最相似条目""" - sql = text( - f"SELECT * FROM {self._table_name} " - f"ORDER BY embedding <=> :query_vec " - f"LIMIT :lim" - ) + sql = text(f"SELECT * FROM {self._table_name} ORDER BY embedding <=> :query_vec LIMIT :lim") result = await db.execute(sql, {"query_vec": str(query_embedding), "lim": 1}) row = result.mappings().first() @@ -147,7 +154,9 @@ class EpisodicMemory(Memory): created_at=row.get("created_at") or datetime.now(timezone.utc), ) - async def _retrieve_client_side(self, db: Any, query_embedding: list[float]) -> MemoryItem | None: + async def _retrieve_client_side( + self, db: AsyncSession, query_embedding: list[float] + ) -> MemoryItem | None: """客户端 O(N) cosine similarity 检索(回退路径)""" Model = self._episodic_model from sqlalchemy import select @@ -193,7 +202,13 @@ class EpisodicMemory(Memory): created_at=best_item.created_at or datetime.now(timezone.utc), ) - async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None, search_multiplier: int = 5) -> list[MemoryItem]: + async def search( + self, + query: str, + top_k: int = 5, + filters: MetadataDict | None = None, + search_multiplier: int = 5, + ) -> list[MemoryItem]: """语义检索相似历史案例 Args: @@ -214,10 +229,10 @@ class EpisodicMemory(Memory): async def _search_pgvector( self, - db: Any, + db: AsyncSession, query: str, top_k: int, - filters: dict[str, Any] | None, + filters: MetadataDict | None, search_multiplier: int, ) -> list[MemoryItem]: """使用 pgvector ``<=>`` 算符检索,再 Python 侧 time_decay 重排""" @@ -225,7 +240,7 @@ class EpisodicMemory(Memory): fetch_limit = top_k * search_multiplier where_clauses = [] - params: dict[str, Any] = {"query_vec": str(query_embedding), "lim": fetch_limit} + params: dict[str, object] = {"query_vec": str(query_embedding), "lim": fetch_limit} filters = filters or {} if filters.get("agent_name"): @@ -256,7 +271,11 @@ class EpisodicMemory(Memory): items = [] for row in rows: row_embedding = row.get("embedding") - age_hours = (datetime.now(timezone.utc) - row["created_at"]).total_seconds() / 3600 if row.get("created_at") else 0 + age_hours = ( + (datetime.now(timezone.utc) - row["created_at"]).total_seconds() / 3600 + if row.get("created_at") + else 0 + ) decay = math.exp(-self._decay_rate * age_hours) time_decay_score = (row.get("quality_score") or 0.5) * decay @@ -266,33 +285,37 @@ class EpisodicMemory(Memory): else: score = time_decay_score - items.append(MemoryItem( - key=str(row.get("id", "")), - value={ - "input_summary": row.get("input_summary", ""), - "output_summary": row.get("output_summary", ""), - "outcome": row.get("outcome", "success"), - "quality_score": row.get("quality_score", 0.5), - "reflection": row.get("reflection", ""), - }, - metadata={ - "agent_name": row.get("agent_name", ""), - "task_type": row.get("task_type", ""), - "created_at": row["created_at"].isoformat() if row.get("created_at") else None, - }, - score=score, - created_at=row.get("created_at") or datetime.now(timezone.utc), - )) + items.append( + MemoryItem( + key=str(row.get("id", "")), + value={ + "input_summary": row.get("input_summary", ""), + "output_summary": row.get("output_summary", ""), + "outcome": row.get("outcome", "success"), + "quality_score": row.get("quality_score", 0.5), + "reflection": row.get("reflection", ""), + }, + metadata={ + "agent_name": row.get("agent_name", ""), + "task_type": row.get("task_type", ""), + "created_at": row["created_at"].isoformat() + if row.get("created_at") + else None, + }, + score=score, + created_at=row.get("created_at") or datetime.now(timezone.utc), + ) + ) items.sort(key=lambda x: x.score, reverse=True) return items[:top_k] async def _search_client_side( self, - db: Any, + db: AsyncSession, query: str, top_k: int, - filters: dict[str, Any] | None, + filters: MetadataDict | None, search_multiplier: int, ) -> list[MemoryItem]: """客户端 O(N) cosine similarity 检索(回退路径)""" @@ -300,6 +323,7 @@ class EpisodicMemory(Memory): filters = filters or {} from sqlalchemy import select + stmt = select(Model) if filters.get("agent_name"): @@ -322,7 +346,11 @@ class EpisodicMemory(Memory): # 计算得分并构建 MemoryItem items = [] for entry in entries: - age_hours = (datetime.now(timezone.utc) - entry.created_at).total_seconds() / 3600 if entry.created_at else 0 + age_hours = ( + (datetime.now(timezone.utc) - entry.created_at).total_seconds() / 3600 + if entry.created_at + else 0 + ) decay = math.exp(-self._decay_rate * age_hours) time_decay_score = (entry.quality_score or 0.5) * decay @@ -333,30 +361,34 @@ class EpisodicMemory(Memory): else: score = time_decay_score - items.append(MemoryItem( - key=str(entry.id), - value={ - "input_summary": entry.input_summary, - "output_summary": entry.output_summary, - "outcome": entry.outcome, - "quality_score": entry.quality_score, - "reflection": entry.reflection, - }, - metadata={ - "agent_name": entry.agent_name, - "task_type": entry.task_type, - "created_at": entry.created_at.isoformat() if entry.created_at else None, - }, - score=score, - created_at=entry.created_at or datetime.now(timezone.utc), - )) + items.append( + MemoryItem( + key=str(entry.id), + value={ + "input_summary": entry.input_summary, + "output_summary": entry.output_summary, + "outcome": entry.outcome, + "quality_score": entry.quality_score, + "reflection": entry.reflection, + }, + metadata={ + "agent_name": entry.agent_name, + "task_type": entry.task_type, + "created_at": entry.created_at.isoformat() if entry.created_at else None, + }, + score=score, + created_at=entry.created_at or datetime.now(timezone.utc), + ) + ) items.sort(key=lambda x: x.score, reverse=True) if len(items) < top_k: logger.warning( "EpisodicMemory.search returned %d results after scoring (top_k=%d). " "Consider increasing search_multiplier (current=%d) to avoid missing relevant entries.", - len(items), top_k, search_multiplier, + len(items), + top_k, + search_multiplier, ) return items[:top_k] @@ -364,8 +396,9 @@ class EpisodicMemory(Memory): """删除指定经验""" async with self._session_factory() as db: try: - from sqlalchemy import select, delete as sql_delete + from sqlalchemy import delete as sql_delete import uuid + Model = self._episodic_model stmt = sql_delete(Model).where(Model.id == uuid.UUID(key)) diff --git a/src/agentkit/memory/http_rag.py b/src/agentkit/memory/http_rag.py index 2e4d94f..9f170c8 100644 --- a/src/agentkit/memory/http_rag.py +++ b/src/agentkit/memory/http_rag.py @@ -3,13 +3,26 @@ 配置驱动,不直接依赖业务系统代码,通过 base_url + api_key 连接。 """ +from __future__ import annotations + import logging -from typing import Any +from typing import TYPE_CHECKING, TypeAlias import httpx +from agentkit.memory.base import MetadataDict + +if TYPE_CHECKING: + from agentkit.llm.gateway import LLMGateway + logger = logging.getLogger(__name__) +# 标准化检索结果:id/content/score/source/document_id/document_title/ +# knowledge_base_id/metadata — 值为原始标量或嵌套 dict。 +RAGSearchResult: TypeAlias = dict[str, object] +# ingest() 写入的文档负载:title/content/source_type/metadata。 +RAGIngestPayload: TypeAlias = dict[str, object] + class HttpRAGService: """HTTP 客户端,调用业务系统的知识库检索 API @@ -39,7 +52,7 @@ class HttpRAGService: knowledge_base_ids: list[str] | None = None, timeout: int = 30, contextual_chunking: bool = False, - llm_gateway: Any = None, + llm_gateway: LLMGateway | None = None, ): """ Args: @@ -74,7 +87,7 @@ class HttpRAGService: query: str, knowledge_base_ids: list[str] | None = None, top_k: int = 5, - ) -> list[dict[str, Any]]: + ) -> list[RAGSearchResult]: """语义检索知识库 Args: @@ -113,19 +126,23 @@ class HttpRAGService: normalized = [] for r in results: if isinstance(r, dict): - normalized.append({ - "id": r.get("chunk_id", r.get("id", "")), - "content": r.get("content", ""), - "score": float(r.get("score", 0.0)), - "source": r.get("source", "rag"), - "document_id": r.get("document_id", ""), - "document_title": r.get("document_title", ""), - "metadata": r.get("metadata", {}), - }) + normalized.append( + { + "id": r.get("chunk_id", r.get("id", "")), + "content": r.get("content", ""), + "score": float(r.get("score", 0.0)), + "source": r.get("source", "rag"), + "document_id": r.get("document_id", ""), + "document_title": r.get("document_title", ""), + "metadata": r.get("metadata", {}), + } + ) return normalized except httpx.HTTPStatusError as e: - logger.error(f"RAG search HTTP error: {e.response.status_code} — {e.response.text[:200]}") + logger.error( + f"RAG search HTTP error: {e.response.status_code} — {e.response.text[:200]}" + ) return [] except httpx.RequestError as e: logger.error(f"RAG search request error: {e}") @@ -141,7 +158,7 @@ class HttpRAGService: top_k: int = 5, use_rerank: bool = True, use_compression: bool = False, - ) -> list[dict[str, Any]]: + ) -> list[RAGSearchResult]: """增强语义检索知识库(支持 rerank 和 compression) 对每个知识库分别调用 /bases/{kb_id}/retrieve 接口, @@ -169,7 +186,7 @@ class HttpRAGService: } client = self._get_client() - all_results: list[dict[str, Any]] = [] + all_results: list[RAGSearchResult] = [] for kb_id in kb_ids: try: @@ -189,28 +206,27 @@ class HttpRAGService: # 标准化 for r in results: if isinstance(r, dict): - all_results.append({ - "id": r.get("chunk_id", r.get("id", "")), - "content": r.get("content", ""), - "score": float(r.get("score", 0.0)), - "source": r.get("source", "rag"), - "document_id": r.get("document_id", ""), - "document_title": r.get("document_title", ""), - "knowledge_base_id": kb_id, - "metadata": r.get("metadata", {}), - }) + all_results.append( + { + "id": r.get("chunk_id", r.get("id", "")), + "content": r.get("content", ""), + "score": float(r.get("score", 0.0)), + "source": r.get("source", "rag"), + "document_id": r.get("document_id", ""), + "document_title": r.get("document_title", ""), + "knowledge_base_id": kb_id, + "metadata": r.get("metadata", {}), + } + ) except httpx.HTTPStatusError as e: if e.response.status_code == 404: # This KB doesn't support enhanced search — fall back to # standard search for THIS KB only, not all KBs. logger.info( - f"Enhanced search not available for KB {kb_id}, " - f"using standard search" - ) - std_result = await self.search( - query, knowledge_base_ids=[kb_id], top_k=top_k + f"Enhanced search not available for KB {kb_id}, using standard search" ) + std_result = await self.search(query, knowledge_base_ids=[kb_id], top_k=top_k) all_results.extend(std_result) else: logger.error( @@ -232,9 +248,9 @@ class HttpRAGService: async def ingest( self, key: str, - value: Any, - metadata: dict[str, Any] | None = None, - ) -> dict[str, Any] | None: + value: object, + metadata: MetadataDict | None = None, + ) -> dict[str, object] | None: """写入文档到知识库(可选操作) When contextual_chunking is enabled and llm_gateway is configured, @@ -308,5 +324,5 @@ class HttpRAGService: async def __aenter__(self) -> "HttpRAGService": return self - async def __aexit__(self, *args: Any) -> None: + async def __aexit__(self, *args: object) -> None: await self.close() diff --git a/src/agentkit/memory/knowledge_base.py b/src/agentkit/memory/knowledge_base.py index 84ef46d..2d23b94 100644 --- a/src/agentkit/memory/knowledge_base.py +++ b/src/agentkit/memory/knowledge_base.py @@ -11,7 +11,11 @@ from __future__ import annotations from dataclasses import dataclass, field from datetime import datetime, timezone -from typing import Any, Protocol, runtime_checkable +from typing import Protocol, TypeAlias, runtime_checkable + +from agentkit.memory.base import MetadataDict + +KBMetadata: TypeAlias = MetadataDict @dataclass @@ -22,7 +26,7 @@ class Document: content: str title: str = "" source_id: str = "" - metadata: dict[str, Any] = field(default_factory=dict) + metadata: KBMetadata = field(default_factory=dict) created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) @@ -34,7 +38,7 @@ class QueryResult: source_id: str source_name: str score: float - metadata: dict[str, Any] = field(default_factory=dict) + metadata: KBMetadata = field(default_factory=dict) doc_id: str = "" title: str = "" diff --git a/src/agentkit/memory/local_rag.py b/src/agentkit/memory/local_rag.py index d9a02c1..8820131 100644 --- a/src/agentkit/memory/local_rag.py +++ b/src/agentkit/memory/local_rag.py @@ -11,25 +11,32 @@ from __future__ import annotations import json import logging import re -import uuid from datetime import datetime, timezone -from typing import Any - -_SAFE_TABLE_NAME_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$') +from typing import TYPE_CHECKING, TypeAlias from agentkit.memory.chunking import Chunk, StructuralChunker, TextChunker from agentkit.memory.document_loader import Document as LoaderDocument from agentkit.memory.embedder import Embedder from agentkit.memory.knowledge_base import ( Document, - KnowledgeBase, QueryResult, SourceInfo, ) from agentkit.utils.vector_math import compute_cosine_similarity +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession + +_SAFE_TABLE_NAME_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") + logger = logging.getLogger(__name__) +# InMemoryLocalRAGService 内部存储的文档元信息结构。 +# 字段:title/source_id/format/chunk_ids/metadata/created_at — 值为标量或 list[str]。 +InMemoryDocInfo: TypeAlias = dict[str, object] +# 内部 chunk 存储结构:content/embedding/metadata/source_doc_id。 +InMemoryChunkInfo: TypeAlias = dict[str, object] + def _loader_doc_to_kb_doc(loader_doc: LoaderDocument) -> Document: """将 document_loader.Document 转换为 knowledge_base.Document""" @@ -53,7 +60,7 @@ class LocalRAGService: def __init__( self, - session_factory: Any, + session_factory: object, embedder: Embedder, chunk_size: int = 1000, chunk_overlap: int = 200, @@ -75,10 +82,14 @@ class LocalRAGService: self._chunk_overlap = chunk_overlap self._table_name = table_name if not _SAFE_TABLE_NAME_PATTERN.match(self._table_name): - raise ValueError(f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*") + raise ValueError( + f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*" + ) self._pgvector_enabled = pgvector_enabled self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap) - self._structural_chunker = StructuralChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap) + self._structural_chunker = StructuralChunker( + chunk_size=chunk_size, chunk_overlap=chunk_overlap + ) async def ingest(self, documents: list[Document]) -> list[str]: """摄取文档列表 @@ -136,9 +147,7 @@ class LocalRAGService: try: from sqlalchemy import text as sql_text - sql = sql_text( - f"DELETE FROM {self._table_name} WHERE source_doc_id = :doc_id" - ) + sql = sql_text(f"DELETE FROM {self._table_name} WHERE source_doc_id = :doc_id") await db.execute(sql, {"doc_id": id}) await db.commit() return True @@ -171,20 +180,15 @@ class LocalRAGService: sources = [] for row in rows: - meta = {} - if row.get("doc_metadata"): - try: - meta = json.loads(row["doc_metadata"]) - except (json.JSONDecodeError, TypeError): - pass - - sources.append(SourceInfo( - source_id=row["source_doc_id"], - source_name=row.get("source_title", ""), - source_type=row.get("doc_format", "local"), - document_count=row.get("chunk_count", 0), - last_updated=row["created_at"] if row.get("created_at") else None, - )) + sources.append( + SourceInfo( + source_id=row["source_doc_id"], + source_name=row.get("source_title", ""), + source_type=row.get("doc_format", "local"), + document_count=row.get("chunk_count", 0), + last_updated=row["created_at"] if row.get("created_at") else None, + ) + ) return sources except Exception as e: logger.error(f"Failed to list sources: {e}") @@ -271,7 +275,7 @@ class LocalRAGService: async def _query_pgvector( self, - db: Any, + db: AsyncSession, query_embedding: list[float], top_k: int, ) -> list[QueryResult]: @@ -286,10 +290,13 @@ class LocalRAGService: f"LIMIT :lim" ) - result = await db.execute(sql, { - "query_vec": str(query_embedding), - "lim": top_k, - }) + result = await db.execute( + sql, + { + "query_vec": str(query_embedding), + "lim": top_k, + }, + ) rows = result.mappings().all() results = [] @@ -306,21 +313,23 @@ class LocalRAGService: except (json.JSONDecodeError, TypeError): pass - results.append(QueryResult( - content=row["content"], - source_id=row["source_doc_id"], - source_name=row.get("source_title", ""), - score=cosine, - metadata=chunk_meta, - doc_id=row["source_doc_id"], - title=row.get("source_title", ""), - )) + results.append( + QueryResult( + content=row["content"], + source_id=row["source_doc_id"], + source_name=row.get("source_title", ""), + score=cosine, + metadata=chunk_meta, + doc_id=row["source_doc_id"], + title=row.get("source_title", ""), + ) + ) return results async def _query_client_side( self, - db: Any, + db: AsyncSession, query_embedding: list[float], top_k: int, ) -> list[QueryResult]: @@ -363,15 +372,17 @@ class LocalRAGService: except (json.JSONDecodeError, TypeError): pass - candidates.append(QueryResult( - content=row["content"], - source_id=row["source_doc_id"], - source_name=row.get("source_title", ""), - score=cosine, - metadata=chunk_meta, - doc_id=row["source_doc_id"], - title=row.get("source_title", ""), - )) + candidates.append( + QueryResult( + content=row["content"], + source_id=row["source_doc_id"], + source_name=row.get("source_title", ""), + score=cosine, + metadata=chunk_meta, + doc_id=row["source_doc_id"], + title=row.get("source_title", ""), + ) + ) candidates.sort(key=lambda x: x.score, reverse=True) return candidates[:top_k] @@ -398,11 +409,15 @@ class InMemoryLocalRAGService: """ self._embedder = embedder self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap) - self._structural_chunker = StructuralChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap) + self._structural_chunker = StructuralChunker( + chunk_size=chunk_size, chunk_overlap=chunk_overlap + ) # 内存存储 - self._chunks: dict[str, dict[str, Any]] = {} # chunk_id → {content, embedding, metadata} - self._documents: dict[str, dict[str, Any]] = {} # doc_id → {title, format, chunk_ids, metadata, created_at} + self._chunks: dict[str, InMemoryChunkInfo] = {} # chunk_id → {content, embedding, metadata} + self._documents: dict[ + str, InMemoryDocInfo + ] = {} # doc_id → {title, format, chunk_ids, metadata, created_at} async def ingest(self, documents: list[Document]) -> list[str]: """摄取文档列表 @@ -459,15 +474,17 @@ class InMemoryLocalRAGService: source_doc_id = chunk_data["source_doc_id"] doc_info = self._documents.get(source_doc_id, {}) - candidates.append(QueryResult( - content=chunk_data["content"], - source_id=source_doc_id, - source_name=doc_info.get("title", ""), - score=cosine, - metadata=chunk_data.get("metadata", {}), - doc_id=source_doc_id, - title=doc_info.get("title", ""), - )) + candidates.append( + QueryResult( + content=chunk_data["content"], + source_id=source_doc_id, + source_name=doc_info.get("title", ""), + score=cosine, + metadata=chunk_data.get("metadata", {}), + doc_id=source_doc_id, + title=doc_info.get("title", ""), + ) + ) candidates.sort(key=lambda x: x.score, reverse=True) return candidates[:top_k] @@ -488,13 +505,15 @@ class InMemoryLocalRAGService: """列出已摄取的文档""" sources = [] for doc_id, doc_info in self._documents.items(): - sources.append(SourceInfo( - source_id=doc_id, - source_name=doc_info["title"], - source_type=doc_info.get("format", "local"), - document_count=len(doc_info.get("chunk_ids", [])), - last_updated=doc_info.get("created_at"), - )) + sources.append( + SourceInfo( + source_id=doc_id, + source_name=doc_info["title"], + source_type=doc_info.get("format", "local"), + document_count=len(doc_info.get("chunk_ids", [])), + last_updated=doc_info.get("created_at"), + ) + ) return sources async def health_check(self) -> bool: diff --git a/src/agentkit/memory/multi_source_retriever.py b/src/agentkit/memory/multi_source_retriever.py index 6ccf35d..ac0d40f 100644 --- a/src/agentkit/memory/multi_source_retriever.py +++ b/src/agentkit/memory/multi_source_retriever.py @@ -13,7 +13,6 @@ import asyncio import hashlib import logging from dataclasses import replace -from typing import Any from agentkit.memory.knowledge_base import KnowledgeBase, QueryResult, SourceInfo @@ -186,15 +185,13 @@ class MultiSourceRetriever: Returns: 所有源的检索结果列表(已应用权重) """ + async def _query_one(name: str, kb: KnowledgeBase) -> list[QueryResult]: try: results = await kb.query(query, top_k=top_k) # 应用权重 weight = (weights or {}).get(name, 1.0) - return [ - replace(r, score=r.score * weight, source_name=name) - for r in results - ] + return [replace(r, score=r.score * weight, source_name=name) for r in results] except Exception as e: logger.error(f"Query failed for source '{name}': {e}") return [] diff --git a/src/agentkit/memory/profile.py b/src/agentkit/memory/profile.py index 6f7ace8..b7050d5 100644 --- a/src/agentkit/memory/profile.py +++ b/src/agentkit/memory/profile.py @@ -7,10 +7,10 @@ from __future__ import annotations import re -from dataclasses import dataclass, field +from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path -from typing import Any, Callable +from typing import Callable class MemoryFile: @@ -26,8 +26,9 @@ class MemoryFile: """ - def __init__(self, path: Path, char_budget: int | None = None, - protected_sections: set[str] | None = None): + def __init__( + self, path: Path, char_budget: int | None = None, protected_sections: set[str] | None = None + ): self.path = Path(path) self.char_budget = char_budget self._protected_sections = protected_sections or set() @@ -138,7 +139,7 @@ class MemoryFile: for match in re.finditer(r"^## (.+)$", content, re.MULTILINE): name = match.group(1).strip() start = match.start() - next_match = re.search(r"^## ", content[match.end():], re.MULTILINE) + next_match = re.search(r"^## ", content[match.end() :], re.MULTILINE) if next_match: end = match.end() + next_match.start() else: @@ -146,7 +147,7 @@ class MemoryFile: sections.append((name, start, end)) if not sections: - return content[:self.char_budget] + return content[: self.char_budget] # 保持原始顺序,标记每个 section 是否受保护 ordered: list[tuple[str, str, bool]] = [] # (name, text, is_protected) @@ -222,8 +223,9 @@ class MemoryStore: """ - def __init__(self, base_dir: Path | str | None = None, - on_change: Callable[[str], None] | None = None): + def __init__( + self, base_dir: Path | str | None = None, on_change: Callable[[str], None] | None = None + ): if base_dir is None: base_dir = Path.home() / ".agentkit" self.base_dir = Path(base_dir) @@ -238,7 +240,9 @@ class MemoryStore: protected_sections={"版本", "更新历史"}, ) self._user = MemoryFile(self.base_dir / "memories" / "USER.md", char_budget=USER_BUDGET) - self._memory = MemoryFile(self.base_dir / "memories" / "MEMORY.md", char_budget=MEMORY_BUDGET) + self._memory = MemoryFile( + self.base_dir / "memories" / "MEMORY.md", char_budget=MEMORY_BUDGET + ) self._daily_dir = self.base_dir / "memories" / "daily" self._daily_dir.mkdir(parents=True, exist_ok=True) @@ -376,4 +380,5 @@ class MemoryStore: self._on_change(new_prompt) except Exception: import logging + logging.getLogger(__name__).warning("Memory notify_change failed", exc_info=True) diff --git a/src/agentkit/memory/query_transformer.py b/src/agentkit/memory/query_transformer.py index 4bab9e6..defbb48 100644 --- a/src/agentkit/memory/query_transformer.py +++ b/src/agentkit/memory/query_transformer.py @@ -87,10 +87,22 @@ class RuleQueryTransformer(QueryTransformerBase): """ _FILLER_WORDS_CN: list[str] = [ - "帮我", "请", "一下", "分析", "看看", "告诉我", "想知道", "请问", + "帮我", + "请", + "一下", + "分析", + "看看", + "告诉我", + "想知道", + "请问", ] _FILLER_WORDS_EN: list[str] = [ - "please", "can you", "help me", "could you", "i want to", "i need to", + "please", + "can you", + "help me", + "could you", + "i want to", + "i need to", ] def __init__( @@ -101,9 +113,7 @@ class RuleQueryTransformer(QueryTransformerBase): self._synonyms = synonyms or {} self._max_sub_queries = max_sub_queries # Pre-compile filler patterns - self._filler_patterns_cn = [ - re.compile(re.escape(w)) for w in self._FILLER_WORDS_CN - ] + self._filler_patterns_cn = [re.compile(re.escape(w)) for w in self._FILLER_WORDS_CN] self._filler_patterns_en = [ re.compile(re.escape(w), re.IGNORECASE) for w in self._FILLER_WORDS_EN ] @@ -166,7 +176,9 @@ def create_query_transformer( """工厂函数:根据策略创建查询改写器""" if strategy == "llm": if llm_gateway is None: - logger.warning("LLM strategy requested but no llm_gateway provided, falling back to NoOp") + logger.warning( + "LLM strategy requested but no llm_gateway provided, falling back to NoOp" + ) return NoOpQueryTransformer() return LLMQueryTransformer(llm_gateway, max_sub_queries=max_sub_queries) elif strategy == "rule": diff --git a/src/agentkit/memory/rag_loop.py b/src/agentkit/memory/rag_loop.py index b0d6074..e621bde 100644 --- a/src/agentkit/memory/rag_loop.py +++ b/src/agentkit/memory/rag_loop.py @@ -7,11 +7,11 @@ from __future__ import annotations import logging -from dataclasses import dataclass, field +from dataclasses import dataclass from enum import Enum -from typing import Any +from typing import TYPE_CHECKING -from agentkit.memory.base import MemoryItem +from agentkit.memory.base import MemoryItem, MetadataDict from agentkit.memory.query_transformer import QueryTransformerBase, NoOpQueryTransformer from agentkit.memory.relevance_scorer import ( RelevanceScorer, @@ -19,6 +19,10 @@ from agentkit.memory.relevance_scorer import ( RetrievalEvaluation, ) +if TYPE_CHECKING: + # 避免与 retriever.py 形成运行时循环导入。 + from agentkit.memory.retriever import MemoryRetriever + logger = logging.getLogger(__name__) @@ -70,7 +74,7 @@ class RAGSelfCorrectionLoop: def __init__( self, - retriever: Any, # MemoryRetriever + retriever: MemoryRetriever, scorer: RelevanceScorer | None = None, query_transformer: QueryTransformerBase | None = None, max_retries: int = 3, @@ -87,7 +91,7 @@ class RAGSelfCorrectionLoop: query: str, top_k: int = 5, token_budget: int = 3000, - filters: dict[str, Any] | None = None, + filters: MetadataDict | None = None, ) -> RAGLoopResult: """执行带自纠正的检索 @@ -107,8 +111,11 @@ class RAGSelfCorrectionLoop: while retry_count <= self._max_retries: # RETRIEVE items = await self._retriever.retrieve( - current_query, top_k=top_k, token_budget=token_budget, - filters=filters, _skip_correction=True, + current_query, + top_k=top_k, + token_budget=token_budget, + filters=filters, + _skip_correction=True, ) # EVALUATE @@ -144,9 +151,7 @@ class RAGSelfCorrectionLoop: # CORRECT — rewrite query and retry retry_count += 1 if retry_count <= self._max_retries: - current_query = await self._rewrite_query( - query, current_query, evaluation - ) + current_query = await self._rewrite_query(query, current_query, evaluation) continue # DEGRADE — exceeded max retries @@ -154,9 +159,7 @@ class RAGSelfCorrectionLoop: # Degraded result: filter to relevant items and mark low confidence relevant_items = [ - s.item - for s in evaluation.scores - if s.verdict != RelevanceVerdict.INCORRECT + s.item for s in evaluation.scores if s.verdict != RelevanceVerdict.INCORRECT ] result_items = relevant_items if relevant_items else items diff --git a/src/agentkit/memory/relevance_scorer.py b/src/agentkit/memory/relevance_scorer.py index 7866cce..79d8746 100644 --- a/src/agentkit/memory/relevance_scorer.py +++ b/src/agentkit/memory/relevance_scorer.py @@ -6,11 +6,9 @@ from __future__ import annotations import logging -import math import re from dataclasses import dataclass from enum import Enum -from typing import Any from agentkit.memory.base import MemoryItem @@ -120,9 +118,7 @@ class RelevanceScorer: reason=reason, ) - def evaluate( - self, query: str, items: list[MemoryItem] - ) -> RetrievalEvaluation: + def evaluate(self, query: str, items: list[MemoryItem]) -> RetrievalEvaluation: """评估一次检索的整体质量""" if not items: return RetrievalEvaluation( @@ -134,9 +130,7 @@ class RelevanceScorer: ) scores = [self.score_item(query, item) for item in items] - relevant_count = sum( - 1 for s in scores if s.verdict != RelevanceVerdict.INCORRECT - ) + relevant_count = sum(1 for s in scores if s.verdict != RelevanceVerdict.INCORRECT) avg_score = sum(s.score for s in scores) / len(scores) # Overall verdict based on average score and relevant ratio diff --git a/src/agentkit/memory/retriever.py b/src/agentkit/memory/retriever.py index e1b36d2..01d9110 100644 --- a/src/agentkit/memory/retriever.py +++ b/src/agentkit/memory/retriever.py @@ -7,19 +7,16 @@ from __future__ import annotations import asyncio import logging -import math from dataclasses import replace -from datetime import datetime -from typing import Any -from agentkit.memory.base import Memory, MemoryItem +from agentkit.memory.base import MemoryItem, MetadataDict from agentkit.memory.working import WorkingMemory from agentkit.memory.episodic import EpisodicMemory from agentkit.memory.semantic import SemanticMemory from agentkit.memory.query_transformer import QueryTransformerBase from agentkit.memory.rag_loop import RAGSelfCorrectionLoop from agentkit.memory.relevance_scorer import RelevanceScorer -from agentkit.memory.knowledge_base import KnowledgeBase, QueryResult +from agentkit.memory.knowledge_base import KnowledgeBase from agentkit.memory.multi_source_retriever import MultiSourceRetriever from agentkit.tools.base import Tool @@ -32,11 +29,11 @@ def _estimate_tokens(text: str) -> int: Chinese characters typically use 1-2 tokens each. English words typically use 1 token each. """ - cjk_count = sum(1 for c in text if '\u4e00' <= c <= '\u9fff') + cjk_count = sum(1 for c in text if "\u4e00" <= c <= "\u9fff") non_cjk = text for c in text: - if '\u4e00' <= c <= '\u9fff': - non_cjk = non_cjk.replace(c, ' ') + if "\u4e00" <= c <= "\u9fff": + non_cjk = non_cjk.replace(c, " ") word_count = len(non_cjk.split()) return cjk_count * 2 + word_count @@ -89,7 +86,7 @@ class MemoryRetriever: query: str, top_k: int = 5, token_budget: int = 3000, - filters: dict[str, Any] | None = None, + filters: MetadataDict | None = None, _skip_correction: bool = False, sources: list[str] | None = None, source_weights: dict[str, float] | None = None, @@ -121,9 +118,7 @@ class MemoryRetriever: query, top_k=top_k, token_budget=token_budget, filters=filters ) if result.degraded: - logger.warning( - f"RAG self-correction degraded after {result.total_retries} retries" - ) + logger.warning(f"RAG self-correction degraded after {result.total_retries} retries") return result.items # Query transformation if self._query_transformer is not None: @@ -139,9 +134,7 @@ class MemoryRetriever: # Sub-query search in parallel if sub_queries: - sub_tasks = [ - self._search_layers(sq, top_k, filters) for sq in sub_queries - ] + sub_tasks = [self._search_layers(sq, top_k, filters) for sq in sub_queries] sub_results = await asyncio.gather(*sub_tasks, return_exceptions=True) for result in sub_results: if isinstance(result, Exception): @@ -178,7 +171,7 @@ class MemoryRetriever: self, query: str, top_k: int = 5, - filters: dict[str, Any] | None = None, + filters: MetadataDict | None = None, ) -> list[MemoryItem]: """Search all configured memory layers with a single query""" tasks = [] @@ -237,18 +230,20 @@ class MemoryRetriever: # QueryResult → MemoryItem items = [] for r in kb_results: - items.append(MemoryItem( - key=r.source_id, - value=r.content, - metadata={ - **r.metadata, - "source": "rag", - "source_name": r.source_name, - "doc_id": r.doc_id, - "document_title": r.title, - }, - score=r.score, - )) + items.append( + MemoryItem( + key=r.source_id, + value=r.content, + metadata={ + **r.metadata, + "source": "rag", + "source_name": r.source_name, + "doc_id": r.doc_id, + "document_title": r.title, + }, + score=r.score, + ) + ) # Token 预算管理 selected = [] @@ -318,7 +313,9 @@ class MemoryRetriever: if source == "rag": kb_type = item.metadata.get("kb_type", "知识库") document_title = item.metadata.get("document_title", "未知文档") - return f"### 知识库参考 [来源: {kb_type} | 相关度: {score:.2f} | 文档: {document_title}]" + return ( + f"### 知识库参考 [来源: {kb_type} | 相关度: {score:.2f} | 文档: {document_title}]" + ) elif source == "graph": return f"### 知识图谱 [实体: {item.key} | 相关度: {score:.2f}]" elif source == "episodic": @@ -330,7 +327,7 @@ class MemoryRetriever: return f"### 参考 [来源: {source} | 相关度: {score:.2f}]" async def store_episode( - self, key: str, value: Any, metadata: dict[str, Any] | None = None + self, key: str, value: object, metadata: MetadataDict | None = None ) -> None: """Store an episode into episodic memory if available. @@ -386,12 +383,14 @@ class RetrieveKnowledgeTool(Tool): items = await self._retriever.retrieve(query, top_k=5) results = [] for item in items: - results.append({ - "content": item.value, - "score": item.score, - "source": item.metadata.get("source", "unknown"), - "document_title": item.metadata.get("document_title", ""), - }) + results.append( + { + "content": item.value, + "score": item.score, + "source": item.metadata.get("source", "unknown"), + "document_title": item.metadata.get("document_title", ""), + } + ) return {"query": query, "results": results, "call_count": self._call_count} except Exception as e: return {"error": str(e), "results": []} diff --git a/src/agentkit/memory/semantic.py b/src/agentkit/memory/semantic.py index 181c9e2..2fd6e5e 100644 --- a/src/agentkit/memory/semantic.py +++ b/src/agentkit/memory/semantic.py @@ -3,14 +3,45 @@ 适配器模式,对接外部 RAG 服务和知识图谱。 """ -import logging -from typing import Any +from __future__ import annotations -from agentkit.memory.base import Memory, MemoryItem +import logging +from typing import TYPE_CHECKING, Protocol + +from agentkit.memory.base import Memory, MemoryItem, MetadataDict + +if TYPE_CHECKING: + from agentkit.memory.http_rag import RAGSearchResult logger = logging.getLogger(__name__) +class _RAGServiceLike(Protocol): + """RAG 检索服务最小接口契约(duck-typed)。""" + + async def search( + self, + query: str, + knowledge_base_ids: list[str] | None = ..., + top_k: int = ..., + ) -> list[RAGSearchResult]: ... + + async def enhanced_search( + self, + query: str, + knowledge_base_ids: list[str] | None = ..., + top_k: int = ..., + use_rerank: bool = ..., + use_compression: bool = ..., + ) -> list[RAGSearchResult]: ... + + +class _GraphServiceLike(Protocol): + """知识图谱服务最小接口契约(duck-typed)。""" + + async def query(self, query: str, depth: int = ...) -> list[dict[str, object]]: ... + + class SemanticMemory(Memory): """Semantic Memory - 知识库检索 @@ -19,8 +50,8 @@ class SemanticMemory(Memory): def __init__( self, - rag_service: Any = None, - graph_service: Any = None, + rag_service: _RAGServiceLike | None = None, + graph_service: _GraphServiceLike | None = None, knowledge_base_ids: list[str] | None = None, search_mode: str = "standard", use_rerank: bool = True, @@ -45,9 +76,9 @@ class SemanticMemory(Memory): self._use_compression = use_compression self._kb_weights = kb_weights - async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None: + async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None: """Semantic Memory 通常只读,写入委托给 RAG 服务的 ingest 方法""" - if self._rag_service and hasattr(self._rag_service, 'ingest'): + if self._rag_service and hasattr(self._rag_service, "ingest"): await self._rag_service.ingest(key, value, metadata) else: logger.warning("SemanticMemory.store: no RAG service configured for writing") @@ -56,7 +87,9 @@ class SemanticMemory(Memory): """按 key 精确检索(Semantic Memory 通常不按 key 检索)""" return None - async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]: + async def search( + self, query: str, top_k: int = 5, filters: MetadataDict | None = None + ) -> list[MemoryItem]: """语义检索知识库""" items = [] @@ -64,7 +97,9 @@ class SemanticMemory(Memory): if self._rag_service: try: kb_ids = (filters or {}).get("knowledge_base_ids", self._knowledge_base_ids) - if self._search_mode == "enhanced" and hasattr(self._rag_service, "enhanced_search"): + if self._search_mode == "enhanced" and hasattr( + self._rag_service, "enhanced_search" + ): results = await self._rag_service.enhanced_search( query, knowledge_base_ids=kb_ids, @@ -73,24 +108,28 @@ class SemanticMemory(Memory): use_compression=self._use_compression, ) else: - results = await self._rag_service.search(query, knowledge_base_ids=kb_ids, top_k=top_k) + results = await self._rag_service.search( + query, knowledge_base_ids=kb_ids, top_k=top_k + ) for r in results: kb_id = r.get("knowledge_base_id", "") score = r.get("score", 0.0) # Apply per-KB weights if self._kb_weights and kb_id in self._kb_weights: score *= self._kb_weights[kb_id] - items.append(MemoryItem( - key=r.get("id", ""), - value=r.get("content", ""), - metadata={ - "source": r.get("source", "rag"), - "score": score, - "document_id": r.get("document_id"), - "knowledge_base_id": kb_id, - }, - score=score, - )) + items.append( + MemoryItem( + key=r.get("id", ""), + value=r.get("content", ""), + metadata={ + "source": r.get("source", "rag"), + "score": score, + "document_id": r.get("document_id"), + "knowledge_base_id": kb_id, + }, + score=score, + ) + ) except Exception as e: logger.error(f"RAG search failed: {e}") @@ -99,16 +138,18 @@ class SemanticMemory(Memory): try: graph_results = await self._graph_service.query(query, depth=2) for r in graph_results[:top_k]: - items.append(MemoryItem( - key=r.get("id", ""), - value=r.get("content", ""), - metadata={ - "source": "graph", - "entities": r.get("entities", []), - "relations": r.get("relations", []), - }, - score=r.get("score", 0.0), - )) + items.append( + MemoryItem( + key=r.get("id", ""), + value=r.get("content", ""), + metadata={ + "source": "graph", + "entities": r.get("entities", []), + "relations": r.get("relations", []), + }, + score=r.get("score", 0.0), + ) + ) except Exception as e: logger.error(f"Graph search failed: {e}") diff --git a/src/agentkit/memory/working.py b/src/agentkit/memory/working.py index 3861f50..83e3368 100644 --- a/src/agentkit/memory/working.py +++ b/src/agentkit/memory/working.py @@ -3,11 +3,10 @@ import json import logging from datetime import datetime, timezone -from typing import Any import redis.asyncio as aioredis -from agentkit.memory.base import Memory, MemoryItem +from agentkit.memory.base import Memory, MemoryItem, MetadataDict logger = logging.getLogger(__name__) @@ -32,7 +31,7 @@ class WorkingMemory(Memory): def _make_key(self, key: str) -> str: return f"{self._key_prefix}:{key}" - async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None: + async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None: redis_key = self._make_key(key) item = MemoryItem( key=key, @@ -57,10 +56,14 @@ class WorkingMemory(Memory): value=item_dict["value"], metadata=item_dict.get("metadata", {}), score=item_dict.get("score", 1.0), - created_at=datetime.fromisoformat(item_dict["created_at"]) if item_dict.get("created_at") else datetime.now(timezone.utc), + created_at=datetime.fromisoformat(item_dict["created_at"]) + if item_dict.get("created_at") + else datetime.now(timezone.utc), ) - async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]: + async def search( + self, query: str, top_k: int = 5, filters: MetadataDict | None = None + ) -> list[MemoryItem]: """Working Memory 不支持语义检索,按 key 前缀匹配""" pattern = self._make_key(f"{query}*") keys = [] @@ -74,13 +77,15 @@ class WorkingMemory(Memory): data = await self._redis.get(key) if data: item_dict = json.loads(data) - items.append(MemoryItem( - key=item_dict["key"], - value=item_dict["value"], - metadata=item_dict.get("metadata", {}), - score=1.0, - created_at=datetime.now(timezone.utc), - )) + items.append( + MemoryItem( + key=item_dict["key"], + value=item_dict["value"], + metadata=item_dict.get("metadata", {}), + score=1.0, + created_at=datetime.now(timezone.utc), + ) + ) return items async def delete(self, key: str) -> bool: diff --git a/src/agentkit/orchestrator/checkpoint.py b/src/agentkit/orchestrator/checkpoint.py index 856e9cf..9cee4d0 100644 --- a/src/agentkit/orchestrator/checkpoint.py +++ b/src/agentkit/orchestrator/checkpoint.py @@ -17,7 +17,7 @@ import logging import time from dataclasses import asdict, dataclass, field from datetime import datetime, timezone -from typing import Any +from typing import TYPE_CHECKING logger = logging.getLogger(__name__) @@ -25,6 +25,28 @@ _TTL_SECONDS = 7 * 24 * 3600 # 7 days _KEY_PREFIX = "agentkit:pipeline:checkpoint" +if TYPE_CHECKING: + from typing import Protocol + + class _RedisPipelineLike(Protocol): + def set(self, key: str, value: str, ex: int | None = None) -> object: ... + def zadd(self, name: str, mapping: dict[str, float]) -> object: ... + def get(self, key: str) -> object: ... + def delete(self, *keys: str) -> object: ... + async def execute(self) -> list[object]: ... + + class _RedisLike(Protocol): + async def set(self, key: str, value: str, ex: int | None = None) -> object: ... + async def get(self, key: str) -> object: ... + def pipeline(self) -> _RedisPipelineLike: ... + async def zrange(self, name: str, start: int, stop: int) -> list[object]: ... + + class _PlanLike(Protocol): + @property + def id(self) -> str: ... + def to_dict(self) -> dict[str, object]: ... + + @dataclass class CheckpointData: """单个阶段的 checkpoint 数据。""" @@ -33,15 +55,15 @@ class CheckpointData: phase_id: str phase_name: str phase_status: str - phase_result: dict[str, Any] | None = None + phase_result: dict[str, object] | None = None plan_status: str = "" saved_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> dict[str, object]: return asdict(self) @classmethod - def from_dict(cls, data: dict[str, Any]) -> CheckpointData: + def from_dict(cls, data: dict[str, object]) -> CheckpointData: return cls( plan_id=data.get("plan_id", ""), phase_id=data.get("phase_id", ""), @@ -67,7 +89,7 @@ class PipelineCheckpoint: def __init__( self, - redis_client: Any = None, + redis_client: _RedisLike | None = None, prefix: str = _KEY_PREFIX, ttl_seconds: int = _TTL_SECONDS, ) -> None: @@ -78,7 +100,7 @@ class PipelineCheckpoint: # P1 #6: 改用 dict keyed by phase_id,避免重复 append self._memory: dict[str, dict[str, CheckpointData]] = {} # 内存降级存储:plan_id → (plan_dict, saved_timestamp) - self._memory_plans: dict[str, tuple[dict[str, Any], float]] = {} + self._memory_plans: dict[str, tuple[dict[str, object], float]] = {} def _is_expired(self, saved_at: str) -> bool: """检查 checkpoint 是否已过期(内存模式 TTL)。""" @@ -102,7 +124,7 @@ class PipelineCheckpoint: """完整 plan JSON 的存储键。""" return f"{self._prefix}:plan:{plan_id}" - async def save_plan(self, plan: Any) -> None: + async def save_plan(self, plan: _PlanLike) -> None: """保存完整 TeamPlan(用于 resume 重建)。 Args: @@ -121,7 +143,7 @@ class PipelineCheckpoint: except Exception as e: logger.warning(f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}") - async def load_plan(self, plan_id: str) -> dict[str, Any] | None: + async def load_plan(self, plan_id: str) -> dict[str, object] | None: """加载完整 plan JSON。""" # 优先 Redis if self._redis is not None: @@ -142,7 +164,7 @@ class PipelineCheckpoint: return None return plan_dict - async def save(self, plan_id: str, phase: Any, plan_status: str) -> None: + async def save(self, plan_id: str, phase: object, plan_status: str) -> None: """保存阶段 checkpoint。 Args: @@ -212,7 +234,8 @@ class PipelineCheckpoint: if not phase_ids: # Redis 无数据,检查内存(过滤过期) return [ - c for c in self._memory.get(plan_id, {}).values() + c + for c in self._memory.get(plan_id, {}).values() if not self._is_expired(c.saved_at) ] @@ -236,8 +259,7 @@ class PipelineCheckpoint: # 内存降级(过滤过期 checkpoint) return [ - c for c in self._memory.get(plan_id, {}).values() - if not self._is_expired(c.saved_at) + c for c in self._memory.get(plan_id, {}).values() if not self._is_expired(c.saved_at) ] async def clear(self, plan_id: str) -> None: diff --git a/src/agentkit/orchestrator/compensation.py b/src/agentkit/orchestrator/compensation.py index 87eef65..92cd06a 100644 --- a/src/agentkit/orchestrator/compensation.py +++ b/src/agentkit/orchestrator/compensation.py @@ -1,8 +1,8 @@ """Saga compensation pattern for Pipeline execution""" import logging -from dataclasses import dataclass, field -from typing import Any, Awaitable, Callable +from dataclasses import dataclass +from typing import Awaitable, Callable logger = logging.getLogger(__name__) @@ -12,7 +12,7 @@ class CompletedStep: """Record of a completed step with its compensation""" step_name: str - result: Any + result: object compensate_action: str | None = None @@ -28,9 +28,7 @@ class CompensationResult: class SagaOrchestrator: """Orchestrates LIFO compensation for failed pipelines""" - def __init__( - self, execute_skill_func: Callable[..., Awaitable[Any]] | None = None - ): + def __init__(self, execute_skill_func: Callable[..., Awaitable[object]] | None = None): """ Args: execute_skill_func: Async function to execute a skill by name @@ -42,7 +40,7 @@ class SagaOrchestrator: def record_completed( self, step_name: str, - result: Any, + result: object, compensate_action: str | None = None, ): """Record a completed step for potential compensation""" @@ -59,9 +57,7 @@ class SagaOrchestrator: results: list[CompensationResult] = [] for step in reversed(self._completed_steps): if step.compensate_action is None: - logger.info( - f"No compensation for step '{step.step_name}', skipping" - ) + logger.info(f"No compensation for step '{step.step_name}', skipping") results.append( CompensationResult( step_name=step.step_name, @@ -82,9 +78,7 @@ class SagaOrchestrator: ) ) except Exception as e: - logger.error( - f"Compensation for step '{step.step_name}' failed: {e}" - ) + logger.error(f"Compensation for step '{step.step_name}' failed: {e}") results.append( CompensationResult( step_name=step.step_name, diff --git a/src/agentkit/orchestrator/dynamic_pipeline.py b/src/agentkit/orchestrator/dynamic_pipeline.py index a6b8e51..cf82500 100644 --- a/src/agentkit/orchestrator/dynamic_pipeline.py +++ b/src/agentkit/orchestrator/dynamic_pipeline.py @@ -4,7 +4,6 @@ """ import logging -from typing import Any from agentkit.orchestrator.pipeline_engine import PipelineEngine from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineResult, StageStatus @@ -15,7 +14,7 @@ logger = logging.getLogger(__name__) class DynamicPipeline: """动态 Pipeline 组合器""" - def __init__(self, engine: PipelineEngine, loader: Any = None): + def __init__(self, engine: PipelineEngine, loader: object | None = None): self._engine = engine self._loader = loader @@ -23,7 +22,7 @@ class DynamicPipeline: self, pipelines: dict[str, Pipeline], condition_key: str, - context: dict[str, Any] | None = None, + context: dict[str, object] | None = None, ) -> PipelineResult: """根据条件选择子 Pipeline 执行""" context = context or {} @@ -37,14 +36,16 @@ class DynamicPipeline: ) selected = pipelines[condition_value] - logger.info(f"DynamicPipeline selected '{selected.name}' for {condition_key}={condition_value}") + logger.info( + f"DynamicPipeline selected '{selected.name}' for {condition_key}={condition_value}" + ) return await self._engine.execute(selected, context) async def execute_nested( self, parent: Pipeline, sub_pipeline_map: dict[str, Pipeline], - context: dict[str, Any] | None = None, + context: dict[str, object] | None = None, ) -> PipelineResult: """执行嵌套 Pipeline""" # 先执行父 Pipeline @@ -52,7 +53,7 @@ class DynamicPipeline: # 根据父 Pipeline 结果选择子 Pipeline for stage_name, stage_result in parent_result.stage_results.items(): - if hasattr(stage_result, 'output_data') and stage_result.output_data: + if hasattr(stage_result, "output_data") and stage_result.output_data: sub_pipeline_name = stage_result.output_data.get("sub_pipeline") if sub_pipeline_name and sub_pipeline_name in sub_pipeline_map: sub = sub_pipeline_map[sub_pipeline_name] @@ -66,7 +67,7 @@ class DynamicPipeline: pipeline: Pipeline, max_iterations: int = 5, exit_condition: str = "done", - context: dict[str, Any] | None = None, + context: dict[str, object] | None = None, ) -> PipelineResult: """循环执行 Pipeline 直到条件满足""" current_context = context or {} diff --git a/src/agentkit/orchestrator/handoff.py b/src/agentkit/orchestrator/handoff.py index cc13631..4c10751 100644 --- a/src/agentkit/orchestrator/handoff.py +++ b/src/agentkit/orchestrator/handoff.py @@ -3,25 +3,42 @@ import asyncio import json import logging -from typing import Any +from typing import Awaitable, Callable, Protocol from agentkit.core.protocol import HandoffMessage logger = logging.getLogger(__name__) +class _RedisPubSubLike(Protocol): + """Structural type for Redis pubsub object.""" + + async def subscribe(self, channel: str) -> None: ... + async def unsubscribe(self, channel: str) -> None: ... + def listen(self) -> object: ... + + +class _RedisLike(Protocol): + """Structural type for async Redis client.""" + + async def publish(self, channel: str, message: str) -> int: ... + def pubsub(self) -> _RedisPubSubLike: ... + + class HandoffManager: """Handoff 管理器 通过 Redis Pub/Sub 管理 Agent 间的任务转交。 """ - def __init__(self, redis: Any = None, dispatcher: Any = None): + def __init__(self, redis: _RedisLike | None = None, dispatcher: object | None = None): self._redis = redis self._dispatcher = dispatcher - self._handlers: dict[str, list[Any]] = {} + self._handlers: dict[str, list[Callable[[HandoffMessage], Awaitable[None]]]] = {} - def register_handler(self, agent_name: str, handler: Any) -> None: + def register_handler( + self, agent_name: str, handler: Callable[[HandoffMessage], Awaitable[None]] + ) -> None: """注册 Handoff 处理器""" if agent_name not in self._handlers: self._handlers[agent_name] = [] diff --git a/src/agentkit/orchestrator/pipeline_engine.py b/src/agentkit/orchestrator/pipeline_engine.py index a8b0bd2..ec4c903 100644 --- a/src/agentkit/orchestrator/pipeline_engine.py +++ b/src/agentkit/orchestrator/pipeline_engine.py @@ -1,10 +1,12 @@ """Pipeline Engine - DAG + 并行执行 + 步骤重试 + Saga 补偿""" +from __future__ import annotations + import asyncio import logging from collections import defaultdict from datetime import datetime, timezone -from typing import Any +from typing import TYPE_CHECKING from agentkit.orchestrator.compensation import SagaOrchestrator from agentkit.orchestrator.pipeline_schema import ( @@ -25,6 +27,23 @@ from agentkit.orchestrator.retry import execute_with_retry logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from typing import Protocol + + class _DispatcherLike(Protocol): + async def dispatch(self, task: object) -> None: ... + async def get_task_status(self, task_id: str) -> dict[str, object]: ... + + class _StateManagerLike(Protocol): + async def create_execution(self, **kwargs: object) -> str: ... + async def update_step(self, **kwargs: object) -> None: ... + async def complete_execution(self, **kwargs: object) -> None: ... + async def fail_execution(self, **kwargs: object) -> None: ... + + class _LLMGatewayLike(Protocol): + async def chat(self, **kwargs: object) -> object: ... + + class PipelineEngine: """Pipeline 执行引擎 @@ -38,7 +57,12 @@ class PipelineEngine: - 状态持久化(可选) """ - def __init__(self, dispatcher: Any = None, state_manager: Any = None, llm_gateway: Any = None): + def __init__( + self, + dispatcher: _DispatcherLike | None = None, + state_manager: _StateManagerLike | None = None, + llm_gateway: _LLMGatewayLike | None = None, + ): self._dispatcher = dispatcher self._state_manager = state_manager self._llm_gateway = llm_gateway @@ -46,7 +70,7 @@ class PipelineEngine: async def execute( self, pipeline: Pipeline, - context: dict[str, Any] | None = None, + context: dict[str, object] | None = None, adaptive_config: AdaptiveConfig | None = None, ) -> PipelineResult: """执行 Pipeline @@ -68,7 +92,7 @@ class PipelineEngine: async def _adaptive_loop( self, pipeline: Pipeline, - context: dict[str, Any] | None, + context: dict[str, object] | None, failed_result: PipelineResult, adaptive_config: AdaptiveConfig, ) -> PipelineResult: @@ -92,34 +116,30 @@ class PipelineEngine: # Replan new_pipeline = await replanner.replan(current_pipeline, current_result, report) - logger.info(f"Pipeline replanned: {new_pipeline.name} ({len(new_pipeline.stages)} stages)") + logger.info( + f"Pipeline replanned: {new_pipeline.name} ({len(new_pipeline.stages)} stages)" + ) # Re-execute current_result = await self._execute_pipeline(new_pipeline, context) current_pipeline = new_pipeline # Record reflection in metadata - current_result.metadata["reflections"] = [ - r.model_dump() for r in reflections - ] + current_result.metadata["reflections"] = [r.model_dump() for r in reflections] if current_result.status == StageStatus.COMPLETED: logger.info(f"Pipeline succeeded after {reflection_num} reflection(s)") return current_result # Exhausted reflections - logger.warning( - f"Pipeline failed after {adaptive_config.max_reflections} reflection(s)" - ) - current_result.metadata["reflections"] = [ - r.model_dump() for r in reflections - ] + logger.warning(f"Pipeline failed after {adaptive_config.max_reflections} reflection(s)") + current_result.metadata["reflections"] = [r.model_dump() for r in reflections] return current_result async def _execute_pipeline( self, pipeline: Pipeline, - context: dict[str, Any] | None = None, + context: dict[str, object] | None = None, ) -> PipelineResult: """执行 Pipeline 的核心逻辑(不含反思-重规划)。""" result = PipelineResult(pipeline_name=pipeline.name) @@ -151,7 +171,9 @@ class PipelineEngine: # 逐层执行 for level, stages in enumerate(level_groups): - logger.info(f"Pipeline '{pipeline.name}' executing level {level} with {len(stages)} stage(s)") + logger.info( + f"Pipeline '{pipeline.name}' executing level {level} with {len(stages)} stage(s)" + ) # 并行执行同层 stages tasks = [] @@ -173,9 +195,11 @@ class PipelineEngine: # Update step state if self._state_manager is not None and execution_id is not None: try: - step_status = "completed" if sr.status == StageStatus.COMPLETED else sr.status.value - step_output = sr.output_data if hasattr(sr, 'output_data') else None - step_error = sr.error_message if hasattr(sr, 'error_message') else None + step_status = ( + "completed" if sr.status == StageStatus.COMPLETED else sr.status.value + ) + step_output = sr.output_data if hasattr(sr, "output_data") else None + step_error = sr.error_message if hasattr(sr, "error_message") else None await self._state_manager.update_step( execution_id=execution_id, step_name=stage.name, @@ -189,19 +213,21 @@ class PipelineEngine: # 收集输出变量 if sr.output_data and isinstance(sr, dict): pass - elif hasattr(sr, 'output_data') and sr.output_data: + elif hasattr(sr, "output_data") and sr.output_data: for output_key in stage.outputs: if output_key in sr.output_data: result.variables[output_key] = sr.output_data[output_key] # 检查是否需要中止 - if hasattr(sr, 'status') and sr.status == StageStatus.FAILED: + if hasattr(sr, "status") and sr.status == StageStatus.FAILED: if not stage.continue_on_failure: # Execute Saga compensation for completed steps compensation_results = await saga.compensate() if compensation_results: failed_compensations = [ - cr for cr in compensation_results if not cr.success and cr.error != "no_compensation_needed" + cr + for cr in compensation_results + if not cr.success and cr.error != "no_compensation_needed" ] if failed_compensations: logger.warning( @@ -219,7 +245,12 @@ class PipelineEngine: step_name=stage.name, error=result.error_message, ) - except (asyncio.TimeoutError, ConnectionError, RuntimeError, ValueError) as exc: + except ( + asyncio.TimeoutError, + ConnectionError, + RuntimeError, + ValueError, + ) as exc: logger.warning(f"Failed to persist failure state: {exc}") return result @@ -252,7 +283,9 @@ class PipelineEngine: started_at = datetime.now(timezone.utc).isoformat() # 条件检查 - if stage.condition and not self._evaluate_condition(stage.condition, pipeline_result.variables): + if stage.condition and not self._evaluate_condition( + stage.condition, pipeline_result.variables + ): return StageResult( stage_name=stage.name, status=StageStatus.SKIPPED, @@ -312,7 +345,9 @@ class PipelineEngine: if status["status"] in ("completed", "failed", "cancelled"): return StageResult( stage_name=stage.name, - status=StageStatus.COMPLETED if status["status"] == "completed" else StageStatus.FAILED, + status=StageStatus.COMPLETED + if status["status"] == "completed" + else StageStatus.FAILED, output_data=status.get("output_data"), error_message=status.get("error_message"), started_at=started_at, @@ -406,7 +441,7 @@ class PipelineEngine: return resolved @staticmethod - def _get_nested(data: dict, path: str) -> Any: + def _get_nested(data: dict, path: str) -> object: keys = path.split(".") current = data for key in keys: @@ -497,9 +532,7 @@ class PipelineEngine: if verifier_feedback.passed: # 审查通过,返回成功结果 - logger.info( - f"Stage '{stage.name}' passed review in round {round_num}" - ) + logger.info(f"Stage '{stage.name}' passed review in round {round_num}") worker_result.output_data = worker_result.output_data or {} worker_result.output_data["adversarial_metadata"] = { "passed_round": round_num, @@ -553,7 +586,7 @@ class PipelineEngine: self, agent_name: str, action: str, - input_data: dict[str, Any], + input_data: dict[str, object], stage: PipelineStage, started_at: str, timeout_seconds: int | None = None, @@ -568,7 +601,9 @@ class PipelineEngine: started_at: 开始时间 timeout_seconds: 独立超时时间,不传则使用 stage.timeout_seconds """ - effective_timeout = timeout_seconds if timeout_seconds is not None else stage.timeout_seconds + effective_timeout = ( + timeout_seconds if timeout_seconds is not None else stage.timeout_seconds + ) if self._dispatcher is None: # Dry-run 模式 return StageResult( @@ -602,7 +637,9 @@ class PipelineEngine: if status["status"] in ("completed", "failed", "cancelled"): return StageResult( stage_name=stage.name, - status=StageStatus.COMPLETED if status["status"] == "completed" else StageStatus.FAILED, + status=StageStatus.COMPLETED + if status["status"] == "completed" + else StageStatus.FAILED, output_data=status.get("output_data"), error_message=status.get("error_message"), started_at=started_at, @@ -639,7 +676,7 @@ class PipelineEngine: async def _execute_verifier( self, verifier_name: str, - worker_output: dict[str, Any], + worker_output: dict[str, object], stage: PipelineStage, started_at: str, ) -> ReviewFeedback: @@ -679,10 +716,7 @@ class PipelineEngine: try: feedback = ReviewFeedback( passed=output_data.get("passed", False), - issues=[ - ReviewIssue(**issue) - for issue in output_data.get("issues", []) - ], + issues=[ReviewIssue(**issue) for issue in output_data.get("issues", [])], summary=output_data.get("summary", "No summary provided"), score=output_data.get("score", 0.0), ) @@ -699,7 +733,7 @@ class PipelineEngine: self, feedback: ReviewFeedback, feedback_mode: str = "structured+natural", - ) -> dict[str, Any]: + ) -> dict[str, object]: """构建反馈上下文,让 Worker Agent 理解审查反馈并定向修复 Args: @@ -720,7 +754,7 @@ class PipelineEngine: for issue in feedback.issues ] - feedback_context: dict[str, Any] = { + feedback_context: dict[str, object] = { "previous_attempt_failed": True, } @@ -756,7 +790,9 @@ class PipelineEngine: ) else: # 未知模式,fallback 到 structured+natural - logger.warning(f"Unknown feedback_mode '{feedback_mode}', falling back to structured+natural") + logger.warning( + f"Unknown feedback_mode '{feedback_mode}', falling back to structured+natural" + ) feedback_context["review_feedback"] = { "summary": feedback.summary, "issues": issues_list, diff --git a/src/agentkit/orchestrator/pipeline_loader.py b/src/agentkit/orchestrator/pipeline_loader.py index e22498a..d7821c7 100644 --- a/src/agentkit/orchestrator/pipeline_loader.py +++ b/src/agentkit/orchestrator/pipeline_loader.py @@ -2,7 +2,6 @@ import logging from pathlib import Path -from typing import Any import yaml @@ -23,7 +22,9 @@ class PipelineLoader: if not yaml_path.exists(): yaml_path = self._pipelines_dir / f"{pipeline_name}.yml" if not yaml_path.exists(): - raise FileNotFoundError(f"Pipeline '{pipeline_name}' not found in {self._pipelines_dir}") + raise FileNotFoundError( + f"Pipeline '{pipeline_name}' not found in {self._pipelines_dir}" + ) content = yaml_path.read_text(encoding="utf-8") return self.load_from_yaml(content, pipeline_name) diff --git a/src/agentkit/orchestrator/pipeline_models.py b/src/agentkit/orchestrator/pipeline_models.py index 3fa1208..a06c472 100644 --- a/src/agentkit/orchestrator/pipeline_models.py +++ b/src/agentkit/orchestrator/pipeline_models.py @@ -35,9 +35,7 @@ class PipelineExecutionModel(Base): ) completed_at = Column(DateTime) - __table_args__ = ( - Index("ix_pipeline_status_created", "status", "created_at"), - ) + __table_args__ = (Index("ix_pipeline_status_created", "status", "created_at"),) class PipelineStepHistoryModel(Base): diff --git a/src/agentkit/orchestrator/pipeline_schema.py b/src/agentkit/orchestrator/pipeline_schema.py index 5f3cf0a..67182a3 100644 --- a/src/agentkit/orchestrator/pipeline_schema.py +++ b/src/agentkit/orchestrator/pipeline_schema.py @@ -1,7 +1,7 @@ """Pipeline 数据模型""" from enum import Enum -from typing import Any, Literal +from typing import Literal from pydantic import BaseModel, Field @@ -18,8 +18,11 @@ class StageStatus(str, Enum): class ReviewIssue(BaseModel): """单条审查问题""" + severity: Literal["critical", "major", "minor"] = Field(description="问题严重程度") - category: Literal["logic_error", "security", "style", "test_failure", "architecture"] = Field(description="问题类别") + category: Literal["logic_error", "security", "style", "test_failure", "architecture"] = Field( + description="问题类别" + ) description: str = Field(min_length=1, description="问题描述") location: str | None = Field(default=None, description="文件路径/行号") suggestion: str | None = Field(default=None, description="修复建议") @@ -27,6 +30,7 @@ class ReviewIssue(BaseModel): class ReviewFeedback(BaseModel): """Verifier 返回的结构化审查反馈""" + passed: bool = Field(description="是否通过审查") issues: list[ReviewIssue] = Field(default_factory=list, description="问题列表") summary: str = Field(min_length=1, description="自然语言审查报告") @@ -35,6 +39,7 @@ class ReviewFeedback(BaseModel): class AdversarialState(BaseModel): """对抗轮次状态追踪""" + current_round: int = Field(default=0, description="当前对抗轮次") max_rounds: int = Field(default=3, description="最大对抗轮次") feedback_history: list[ReviewFeedback] = Field(default_factory=list, description="反馈历史") @@ -46,7 +51,7 @@ class PipelineStage(BaseModel): agent: str action: str depends_on: list[str] = [] - inputs: dict[str, Any] = {} + inputs: dict[str, object] = {} outputs: list[str] = [] timeout_seconds: int = 300 retry_count: int = 0 @@ -54,12 +59,19 @@ class PipelineStage(BaseModel): condition: str | None = None retry_policy: StepRetryPolicy | None = None compensate: str | None = None - + # 对抗闭环相关字段 - verifier: str | None = Field(default=None, description="Verifier Agent 名称,配置后启用对抗模式") + verifier: str | None = Field( + default=None, description="Verifier Agent 名称,配置后启用对抗模式" + ) max_adversarial_rounds: int = Field(default=3, description="最大对抗轮次") - verifier_timeout_seconds: int = Field(default=120, description="Verifier Agent 独立超时时间(秒),避免与 Worker 共享 timeout_seconds") - feedback_mode: Literal["structured+natural", "structured", "natural"] = Field(default="structured+natural", description="反馈模式") + verifier_timeout_seconds: int = Field( + default=120, + description="Verifier Agent 独立超时时间(秒),避免与 Worker 共享 timeout_seconds", + ) + feedback_mode: Literal["structured+natural", "structured", "natural"] = Field( + default="structured+natural", description="反馈模式" + ) escalate_on_exhaust: str | None = Field(default=None, description="对抗轮次耗尽后的升级目标") model_config = {"arbitrary_types_allowed": True} @@ -70,13 +82,13 @@ class Pipeline(BaseModel): version: str description: str stages: list[PipelineStage] - variables: dict[str, Any] = {} + variables: dict[str, object] = {} class StageResult(BaseModel): stage_name: str status: StageStatus = StageStatus.PENDING - output_data: dict[str, Any] | None = None + output_data: dict[str, object] | None = None error_message: str | None = None started_at: str | None = None completed_at: str | None = None @@ -86,9 +98,9 @@ class PipelineResult(BaseModel): pipeline_name: str status: StageStatus = StageStatus.PENDING stage_results: dict[str, StageResult] = {} - variables: dict[str, Any] = {} + variables: dict[str, object] = {} error_message: str | None = None - metadata: dict[str, Any] = {} + metadata: dict[str, object] = {} class AdaptiveConfig(BaseModel): diff --git a/src/agentkit/orchestrator/pipeline_state.py b/src/agentkit/orchestrator/pipeline_state.py index 1acc9c8..cb2a6ce 100644 --- a/src/agentkit/orchestrator/pipeline_state.py +++ b/src/agentkit/orchestrator/pipeline_state.py @@ -13,7 +13,7 @@ import json import logging import uuid from datetime import datetime, timezone -from typing import Any, Callable, Coroutine +from typing import Callable, Coroutine from agentkit.orchestrator.pipeline_models import ( PipelineExecutionModel, @@ -183,7 +183,7 @@ class PipelineStateRedis: return self._redis async def _safe_redis_call( - self, fn: Callable[..., Coroutine[Any, Any, Any]], *args: object, **kwargs: object + self, fn: Callable[..., Coroutine[object, object, object]], *args: object, **kwargs: object ) -> object | None: """Execute a Redis call, falling back to memory on failure. diff --git a/src/agentkit/orchestrator/reflection.py b/src/agentkit/orchestrator/reflection.py index 18aabc9..de97a68 100644 --- a/src/agentkit/orchestrator/reflection.py +++ b/src/agentkit/orchestrator/reflection.py @@ -4,22 +4,30 @@ 生成修正后的 Pipeline 重新执行。 """ +from __future__ import annotations + import json import logging -from typing import Any +from typing import TYPE_CHECKING from agentkit.orchestrator.pipeline_schema import ( Pipeline, PipelineResult, PipelineStage, ReflectionReport, - StageResult, StageStatus, ) logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from typing import Protocol + + class _LLMGatewayLike(Protocol): + async def chat(self, **kwargs: object) -> object: ... + + class PipelineReflector: """分析 Pipeline 执行失败原因,生成结构化反思报告。 @@ -27,7 +35,7 @@ class PipelineReflector: 输出 ReflectionReport 包含 failure_type、root_cause 和 suggested_fix。 """ - def __init__(self, llm_gateway: Any = None): + def __init__(self, llm_gateway: _LLMGatewayLike | None = None): self._llm_gateway = llm_gateway async def reflect( @@ -54,19 +62,25 @@ class PipelineReflector: if self._llm_gateway is not None: try: return await self._llm_reflect( - pipeline, failed_stage, error_message, - completed_outputs, reflection_number, + pipeline, + failed_stage, + error_message, + completed_outputs, + reflection_number, ) except Exception as e: logger.warning(f"LLM reflection failed, falling back to rule-based: {e}") # 规则兜底:基于错误信息分类 return self._rule_based_reflect( - failed_stage, error_message, reflection_number, + failed_stage, + error_message, + reflection_number, ) def _find_failure( - self, result: PipelineResult, + self, + result: PipelineResult, ) -> tuple[str, str]: """找到第一个失败的 stage 及其错误信息。""" for name, sr in result.stage_results.items(): @@ -75,8 +89,9 @@ class PipelineReflector: return "", "no failed stage found" def _collect_completed_outputs( - self, result: PipelineResult, - ) -> dict[str, Any]: + self, + result: PipelineResult, + ) -> dict[str, object]: """收集已完成步骤的输出。""" outputs = {} for name, sr in result.stage_results.items(): @@ -89,13 +104,16 @@ class PipelineReflector: pipeline: Pipeline, failed_stage: str, error_message: str, - completed_outputs: dict[str, Any], + completed_outputs: dict[str, object], reflection_number: int, ) -> ReflectionReport: """使用 LLM 分析失败原因。""" prompt = self._build_reflection_prompt( - pipeline, failed_stage, error_message, - completed_outputs, reflection_number, + pipeline, + failed_stage, + error_message, + completed_outputs, + reflection_number, ) response = await self._llm_gateway.chat( @@ -106,7 +124,9 @@ class PipelineReflector: # 解析 LLM 返回的 JSON content = response.content if hasattr(response, "content") else str(response) return self._parse_reflection_response( - content, failed_stage, reflection_number, + content, + failed_stage, + reflection_number, ) def _build_reflection_prompt( @@ -114,15 +134,14 @@ class PipelineReflector: pipeline: Pipeline, failed_stage: str, error_message: str, - completed_outputs: dict[str, Any], + completed_outputs: dict[str, object], reflection_number: int, ) -> str: """构建反思提示词。""" stage_descriptions = [] for s in pipeline.stages: stage_descriptions.append( - f" - {s.name}: agent={s.agent}, action={s.action}, " - f"depends_on={s.depends_on}" + f" - {s.name}: agent={s.agent}, action={s.action}, depends_on={s.depends_on}" ) completed_summary = json.dumps( @@ -174,7 +193,9 @@ JSON response:""" except (json.JSONDecodeError, KeyError) as e: logger.warning(f"Failed to parse LLM reflection response: {e}") return self._rule_based_reflect( - failed_stage, content, reflection_number, + failed_stage, + content, + reflection_number, ) def _rule_based_reflect( @@ -218,7 +239,7 @@ class PipelineReplanner: 保留已完成步骤的结果,仅重新规划失败及后续步骤。 """ - def __init__(self, llm_gateway: Any = None): + def __init__(self, llm_gateway: _LLMGatewayLike | None = None): self._llm_gateway = llm_gateway async def replan( @@ -255,8 +276,7 @@ class PipelineReplanner: ) -> Pipeline: """使用 LLM 生成修正后的 Pipeline。""" completed_stages = [ - name for name, sr in result.stage_results.items() - if sr.status == StageStatus.COMPLETED + name for name, sr in result.stage_results.items() if sr.status == StageStatus.COMPLETED ] prompt = f"""Based on the reflection report, generate a corrected pipeline. @@ -284,7 +304,9 @@ JSON pipeline:""" return self._parse_pipeline_response(content, pipeline) def _parse_pipeline_response( - self, content: str, original: Pipeline, + self, + content: str, + original: Pipeline, ) -> Pipeline: """解析 LLM 返回的 Pipeline JSON。""" try: @@ -294,9 +316,7 @@ JSON pipeline:""" text = "\n".join(lines[1:-1]) data = json.loads(text) - stages = [ - PipelineStage(**s) for s in data.get("stages", []) - ] + stages = [PipelineStage(**s) for s in data.get("stages", [])] return Pipeline( name=data.get("name", original.name), version=data.get("version", original.version), @@ -316,8 +336,7 @@ JSON pipeline:""" ) -> Pipeline: """基于规则的兜底重规划。""" completed_stages = { - name for name, sr in result.stage_results.items() - if sr.status == StageStatus.COMPLETED + name for name, sr in result.stage_results.items() if sr.status == StageStatus.COMPLETED } # 构建修正后的 stages 列表 @@ -345,17 +364,21 @@ JSON pipeline:""" ) def _adjust_failed_stage( - self, stage: PipelineStage, report: ReflectionReport, + self, + stage: PipelineStage, + report: ReflectionReport, ) -> PipelineStage: """根据反思报告调整失败的步骤。""" - adjustments: dict[str, Any] = {} + adjustments: dict[str, object] = {} if report.failure_type == "timeout": adjustments["timeout_seconds"] = min( - stage.timeout_seconds * 2, 3600, + stage.timeout_seconds * 2, + 3600, ) if stage.retry_policy is None: from agentkit.orchestrator.retry import StepRetryPolicy + adjustments["retry_policy"] = StepRetryPolicy(max_attempts=2) elif report.failure_type == "resource_error": @@ -365,6 +388,7 @@ JSON pipeline:""" # 添加重试策略,可能输入在后续可用 if stage.retry_policy is None: from agentkit.orchestrator.retry import StepRetryPolicy + adjustments["retry_policy"] = StepRetryPolicy(max_attempts=2) return stage.model_copy(update=adjustments) diff --git a/src/agentkit/orchestrator/retry.py b/src/agentkit/orchestrator/retry.py index 4cb4ebd..372305d 100644 --- a/src/agentkit/orchestrator/retry.py +++ b/src/agentkit/orchestrator/retry.py @@ -4,7 +4,7 @@ import asyncio import logging import random from dataclasses import dataclass -from typing import Any, Awaitable, Callable +from typing import Awaitable, Callable logger = logging.getLogger(__name__) @@ -27,7 +27,7 @@ class StepRetryPolicy: def calculate_delay(self, attempt: int) -> float: """Calculate delay for given attempt number (0-based)""" delay = min( - self.base_delay * (self.exponential_base ** attempt), + self.base_delay * (self.exponential_base**attempt), self.max_delay, ) if self.jitter: @@ -36,10 +36,10 @@ class StepRetryPolicy: async def execute_with_retry( - func: Callable[..., Awaitable[Any]], + func: Callable[..., Awaitable[object]], retry_policy: StepRetryPolicy | None = None, step_name: str = "", -) -> Any: +) -> object: """Execute a function with retry policy""" if retry_policy is None: return await func() diff --git a/src/agentkit/orchestrator/workflow_schema.py b/src/agentkit/orchestrator/workflow_schema.py index feb3c48..ea26280 100644 --- a/src/agentkit/orchestrator/workflow_schema.py +++ b/src/agentkit/orchestrator/workflow_schema.py @@ -3,18 +3,17 @@ from __future__ import annotations from datetime import datetime, timezone -from typing import Any from pydantic import BaseModel, Field -from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage +from agentkit.orchestrator.pipeline_schema import PipelineStage class WorkflowStage(PipelineStage): """A workflow stage extending PipelineStage with type and config.""" type: str = "skill" # "skill" | "condition" | "approval" | "parallel" - config: dict[str, Any] = Field(default_factory=dict) + config: dict[str, object] = Field(default_factory=dict) class WorkflowDefinition(BaseModel): @@ -24,9 +23,9 @@ class WorkflowDefinition(BaseModel): name: str version: int = 1 stages: list[WorkflowStage] = Field(default_factory=list) - triggers: list[dict[str, Any]] = Field(default_factory=list) - variables_schema: dict[str, Any] = Field(default_factory=dict) - output_schema: dict[str, Any] = Field(default_factory=dict) + triggers: list[dict[str, object]] = Field(default_factory=list) + variables_schema: dict[str, object] = Field(default_factory=dict) + output_schema: dict[str, object] = Field(default_factory=dict) created_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) updated_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) @@ -38,11 +37,11 @@ class WorkflowExecution(BaseModel): workflow_id: str = "" status: str = "pending" # pending|running|paused|completed|failed|cancelled current_stage: str | None = None - stage_results: dict[str, Any] = Field(default_factory=dict) + stage_results: dict[str, object] = Field(default_factory=dict) started_at: str | None = None completed_at: str | None = None error: str | None = None - variables: dict[str, Any] = Field(default_factory=dict) + variables: dict[str, object] = Field(default_factory=dict) class WorkflowSummary(BaseModel): @@ -62,15 +61,15 @@ class CreateWorkflowRequest(BaseModel): name: str stages: list[WorkflowStage] = Field(default_factory=list) - triggers: list[dict[str, Any]] = Field(default_factory=list) - variables_schema: dict[str, Any] = Field(default_factory=dict) - output_schema: dict[str, Any] = Field(default_factory=dict) + triggers: list[dict[str, object]] = Field(default_factory=list) + variables_schema: dict[str, object] = Field(default_factory=dict) + output_schema: dict[str, object] = Field(default_factory=dict) class ExecuteWorkflowRequest(BaseModel): """Request body for executing a workflow.""" - variables: dict[str, Any] = Field(default_factory=dict) + variables: dict[str, object] = Field(default_factory=dict) class ApproveRequest(BaseModel): diff --git a/src/agentkit/server/routes/agents.py b/src/agentkit/server/routes/agents.py index 9e77e72..5f39ac6 100644 --- a/src/agentkit/server/routes/agents.py +++ b/src/agentkit/server/routes/agents.py @@ -37,7 +37,7 @@ async def create_agent(request: CreateAgentRequest, req: Request): config_dict = request.config try: config = SkillConfig.from_dict(config_dict) - except Exception: + except (ValueError, KeyError, TypeError): config = AgentConfig.from_dict(config_dict) agent = await pool.create_agent(config) else: diff --git a/src/agentkit/server/routes/auth.py b/src/agentkit/server/routes/auth.py index aa16d42..88a6c87 100644 --- a/src/agentkit/server/routes/auth.py +++ b/src/agentkit/server/routes/auth.py @@ -53,6 +53,8 @@ from agentkit.server.auth.session_service import ( REVOKE_REASON_PASSWORD_CHANGED, REVOKE_REASON_USER_TERMINATED, SessionCreate, + SessionNotFound, + SessionReuseDetected, SessionService, get_session_service, ) @@ -253,7 +255,7 @@ def _is_legacy_client(request: Request) -> bool: if client_v is None or cutoff_v is None: return False return client_v < cutoff_v - except Exception: # noqa: BLE001 + except (ValueError, TypeError, AttributeError): # noqa: BLE001 logger.debug("Failed to parse X-Client-Version %r", raw) return False @@ -492,7 +494,7 @@ async def refresh(payload: RefreshRequest, request: Request) -> TokenResponse: # 1. Verify signature + type try: refresh_payload = verify_token(payload.refresh_token, secret, expected_type="refresh") - except Exception as exc: # noqa: BLE001 + except (jwt.PyJWTError, ValueError, KeyError) as exc: # noqa: BLE001 raise HTTPException(status_code=401, detail="Invalid refresh token") from exc # 2-3. Validate the session (also handles reuse detection) @@ -510,7 +512,7 @@ async def refresh(payload: RefreshRequest, request: Request) -> TokenResponse: new_refresh_token=new_pair.refresh_token, new_ttl_seconds=int(REFRESH_TOKEN_TTL.total_seconds()), ) - except Exception as exc: # noqa: BLE001 — SessionReuseDetected / SessionNotFound + except (SessionReuseDetected, SessionNotFound, ValueError, KeyError, RuntimeError) as exc: # noqa: BLE001 — SessionReuseDetected / SessionNotFound logger.info("Refresh rejected: %s", exc) raise HTTPException(status_code=401, detail="Invalid refresh token") from exc diff --git a/src/agentkit/server/routes/bitable.py b/src/agentkit/server/routes/bitable.py index 64a11d6..d8e8f95 100644 --- a/src/agentkit/server/routes/bitable.py +++ b/src/agentkit/server/routes/bitable.py @@ -483,7 +483,7 @@ async def validate_formula( parse_formula(body.formula) except (FormulaParseError, FormulaSecurityError, UnknownFunctionError) as e: return {"valid": False, "error": str(e)} - except Exception as e: # pragma: no cover — defensive + except (ValueError, TypeError, KeyError, AttributeError) as e: # pragma: no cover — defensive return {"valid": False, "error": f"Unexpected error: {e}"} return {"valid": True} @@ -750,7 +750,7 @@ async def upload_file( f.write(chunk) except HTTPException: raise - except Exception as exc: + except (OSError, RuntimeError) as exc: file_path.unlink(missing_ok=True) logger.error(f"Failed to save uploaded bitable file: {exc}") raise HTTPException(status_code=500, detail="Failed to save file") from exc diff --git a/src/agentkit/server/routes/channels.py b/src/agentkit/server/routes/channels.py index 8b6033f..f8ecda5 100644 --- a/src/agentkit/server/routes/channels.py +++ b/src/agentkit/server/routes/channels.py @@ -553,7 +553,7 @@ async def _invalidate_adapter_cache(channel_id: str) -> None: if old is not None: try: await old.close() - except Exception: # noqa: BLE001 — 关闭异常不应阻塞配置变更 + except (ConnectionError, RuntimeError, OSError, asyncio.TimeoutError): # noqa: BLE001 — 关闭异常不应阻塞配置变更 logger.debug("关闭旧适配器异常已忽略: channel_id=%s", channel_id) @@ -562,7 +562,7 @@ async def close_all_adapters() -> None: for channel_id, adapter in list(_adapter_cache.items()): try: await adapter.close() - except Exception: # noqa: BLE001 + except (ConnectionError, RuntimeError, OSError, asyncio.TimeoutError): # noqa: BLE001 logger.debug("关闭适配器异常已忽略: channel_id=%s", channel_id) _adapter_cache.clear() @@ -614,6 +614,8 @@ async def _process_inbound_message(app_state: Any, adapter: MessageAdapter, mess model=routing.model or "default", ) final_content = getattr(result, "content", "") or "" + except asyncio.CancelledError: + raise except Exception as exc: # noqa: BLE001 — 回退路径需捕获全部异常 logger.warning("ReActEngine 执行失败,回退到 DIRECT_CHAT: %s", exc) final_content = await _direct_chat(llm_gateway, routing) @@ -628,6 +630,8 @@ async def _process_inbound_message(app_state: Any, adapter: MessageAdapter, mess content=final_content, ) await adapter.send_message(outgoing) + except asyncio.CancelledError: + raise except Exception as exc: # noqa: BLE001 — webhook 必须保持响应能力 logger.exception("处理入站消息失败: %s", exc) @@ -703,7 +707,7 @@ async def channel_webhook(channel_id: str, request: Request) -> Any: except WeComURLVerification as e: # 企微 URL 验证 — 返回 XML 响应 return Response(content=e.response_xml, media_type="application/xml") - except Exception as exc: # noqa: BLE001 — 防止 receive_message 异常导致 500 触发平台重试风暴 + except (ValueError, KeyError, RuntimeError, AttributeError, OSError) as exc: # noqa: BLE001 — 防止 receive_message 异常导致 500 触发平台重试风暴 logger.warning("receive_message 解析失败 channel=%s: %s", channel_id, exc) return {"code": 0, "msg": "invalid_payload"} diff --git a/src/agentkit/server/routes/chat.py b/src/agentkit/server/routes/chat.py index f47b5a7..ae81da5 100644 --- a/src/agentkit/server/routes/chat.py +++ b/src/agentkit/server/routes/chat.py @@ -108,7 +108,7 @@ class ChatConnectionManager: for ws, _ in conns: try: await ws.send_json(message) - except Exception: + except (ConnectionError, RuntimeError, asyncio.TimeoutError): stale.append(ws) for ws in stale: self.remove(session_id, ws) @@ -295,6 +295,8 @@ async def _execute_board_meeting( await team.create_board(topic=routing_result.topic, expert_configs=expert_configs) orchestrator = BoardOrchestrator(team=team) result = await orchestrator.execute(routing_result.topic) + except asyncio.CancelledError: + raise except Exception as e: logger.error(f"Board meeting failed for session {session_id}: {e}", exc_info=True) await websocket.send_json( @@ -302,7 +304,7 @@ async def _execute_board_meeting( ) try: await team.dissolve() - except Exception: + except (RuntimeError, asyncio.TimeoutError, ConnectionError): pass return True finally: @@ -348,7 +350,7 @@ async def _execute_board_meeting( # Dissolve the team to release expert agents try: await team.dissolve() - except Exception as e: + except (RuntimeError, asyncio.TimeoutError, ConnectionError) as e: logger.warning(f"Board team dissolve failed: {e}") return True @@ -467,7 +469,7 @@ async def _execute_team_collab( # Always dissolve the team and remove handler to avoid leaks try: await team.dissolve() - except Exception as e: + except (RuntimeError, asyncio.TimeoutError, ConnectionError) as e: logger.warning(f"Team dissolve failed: {e}") # dissolve() already clears handlers via handoff_transport.close() @@ -585,7 +587,7 @@ def _build_phase_engine( if phase_policy is None: # Empty config (no `plan_exec:` section) → use KTD5 defaults. phase_policy = default_policy() - except Exception as e: + except (ValueError, TypeError, KeyError) as e: logger.error( "PLAN_EXEC phase policy construction failed for session %s: %s", session_id, @@ -695,6 +697,8 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques agent_name=agent.name, system_prompt=system_prompt, ) + except asyncio.CancelledError: + raise except Exception as e: logger.error(f"PLAN_EXEC execution error for session {session_id}: {e}") raise HTTPException(status_code=500, detail=str(e)) @@ -773,6 +777,8 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques return response_dict return response + except asyncio.CancelledError: + raise except Exception as e: logger.error(f"Chat execution error for session {session_id}: {e}") raise HTTPException(status_code=500, detail=str(e)) @@ -829,7 +835,7 @@ async def _resolve_ws_dept_context( db_path_resolved = Path(db_path) try: department_ids = await _fetch_user_department_ids(db_path_resolved, user_id) - except Exception: + except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError): logger.exception( "Failed to fetch department ids for WebSocket user %s — fail-closed", user_id, @@ -946,7 +952,7 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None: "data": {"content": content}, } ) - except Exception as e: + except (asyncio.QueueFull, RuntimeError, ConnectionError) as e: logger.warning(f"Failed to enqueue intervention: {e}") await websocket.send_json( { @@ -1022,11 +1028,13 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None: except WebSocketDisconnect: logger.debug(f"Chat WebSocket disconnected for session {session_id}") + except asyncio.CancelledError: + raise except Exception as e: logger.error(f"Chat WebSocket error for session {session_id}: {e}") try: await websocket.send_json({"type": "error", "data": {"message": str(e)}}) - except Exception: + except (ConnectionError, RuntimeError, asyncio.TimeoutError): pass finally: # Clean up pending futures @@ -1174,6 +1182,8 @@ async def _handle_chat_message( content=final_content, agent_name=agent.name, ) + except asyncio.CancelledError: + raise except Exception as e: # Check if this is a QuotaExceededError (U4: WebSocket quota). from agentkit.llm.gateway import QuotaExceededError @@ -1422,6 +1432,8 @@ async def _handle_chat_message( agent_name=agent.name, ) + except asyncio.CancelledError: + raise except Exception as e: logger.error(f"Chat execution error for session {session_id}: {e}") # Show meaningful error to user, but avoid leaking full stack traces @@ -1473,6 +1485,8 @@ async def upload_chat_file(file: UploadFile = File(...)) -> dict[str, Any]: file_path.write_bytes(contents) except HTTPException: raise + except asyncio.CancelledError: + raise except Exception as exc: logger.error(f"Failed to save uploaded file: {exc}") raise HTTPException(status_code=500, detail="Failed to save file") from exc diff --git a/src/agentkit/server/routes/config_sync.py b/src/agentkit/server/routes/config_sync.py index 50b95b3..8231891 100644 --- a/src/agentkit/server/routes/config_sync.py +++ b/src/agentkit/server/routes/config_sync.py @@ -64,7 +64,7 @@ def _collect_skill_configs(request: Request) -> list[dict[str, Any]]: "version": getattr(skill.config, "version", "1.0.0"), "config": config_dict, }) - except Exception as e: + except (AttributeError, ValueError, KeyError, RuntimeError) as e: logger.warning(f"Failed to collect skill configs: {e}") return configs @@ -81,7 +81,7 @@ def _collect_workflow_configs(request: Request) -> list[dict[str, Any]]: from agentkit.server.routes.workflows import _workflow_store workflow_store = _workflow_store - except Exception: + except (ImportError, AttributeError): return [] configs: list[dict[str, Any]] = [] @@ -94,7 +94,7 @@ def _collect_workflow_configs(request: Request) -> list[dict[str, Any]]: configs.append(wf.to_dict()) else: configs.append(dict(wf)) - except Exception as e: + except (RuntimeError, AttributeError, ValueError) as e: logger.warning(f"Failed to collect workflow configs: {e}") return configs diff --git a/src/agentkit/server/routes/documents.py b/src/agentkit/server/routes/documents.py index 0282f58..554c924 100644 --- a/src/agentkit/server/routes/documents.py +++ b/src/agentkit/server/routes/documents.py @@ -13,6 +13,7 @@ Endpoints: from __future__ import annotations +import asyncio import hmac import logging import uuid @@ -155,6 +156,8 @@ async def create_document( raise HTTPException(status_code=400, detail=str(e)) from e except FileNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) from e + except asyncio.CancelledError: + raise except Exception as e: logger.error(f"Document creation failed: {e}") raise HTTPException(status_code=500, detail="Document creation failed") from e @@ -187,7 +190,7 @@ async def upload_template( file_path.write_bytes(contents) except HTTPException: raise - except Exception as exc: + except (OSError, RuntimeError) as exc: logger.error(f"Failed to save template: {exc}") raise HTTPException(status_code=500, detail="Failed to save template") from exc finally: diff --git a/src/agentkit/server/routes/evolution.py b/src/agentkit/server/routes/evolution.py index 6db3930..10ff964 100644 --- a/src/agentkit/server/routes/evolution.py +++ b/src/agentkit/server/routes/evolution.py @@ -1,5 +1,6 @@ """Evolution API routes""" +import asyncio import logging from fastapi import APIRouter, HTTPException, Request @@ -42,6 +43,8 @@ async def list_evolution_events( agent_name=agent_name, change_type=event_type, ) + except asyncio.CancelledError: + raise except Exception as e: logger.error(f"Failed to list evolution events: {e}") raise HTTPException(status_code=500, detail="Failed to list evolution events") @@ -63,6 +66,8 @@ async def get_skill_versions(skill_name: str, req: Request = None): store = _get_evolution_store(req) try: versions = await store.list_skill_versions(skill_name) + except asyncio.CancelledError: + raise except Exception as e: logger.error(f"Failed to get skill versions for '{skill_name}': {e}") raise HTTPException(status_code=500, detail="Failed to get skill versions") @@ -103,6 +108,8 @@ async def trigger_evolution(request: TriggerEvolutionRequest, req: Request = Non ) try: event_id = await store.record(event) + except asyncio.CancelledError: + raise except Exception as e: logger.error(f"Failed to record trigger event: {e}") raise HTTPException(status_code=500, detail="Failed to trigger evolution") @@ -153,6 +160,8 @@ async def list_ab_tests( for e in entries ] return {"items": results[:limit], "total": len(results)} + except asyncio.CancelledError: + raise except Exception as e: logger.error(f"Failed to list A/B tests from persistent store: {e}") raise HTTPException(status_code=500, detail="Failed to list A/B tests") diff --git a/src/agentkit/server/routes/evolution_dashboard.py b/src/agentkit/server/routes/evolution_dashboard.py index 2c491fb..4168f73 100644 --- a/src/agentkit/server/routes/evolution_dashboard.py +++ b/src/agentkit/server/routes/evolution_dashboard.py @@ -194,7 +194,7 @@ async def list_experiences( } ) return {"experiences": experiences, "total": len(experiences)} - except Exception as e: + except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError) as e: logger.error(f"Failed to list experiences from store: {e}") # Fallback to in-memory store @@ -324,7 +324,7 @@ async def get_metrics( # Generate daily trends from the metrics trends = _generate_trends(metrics_list, period) - except Exception as e: + except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError) as e: logger.error(f"Failed to get metrics from store: {e}") else: # Generate from in-memory experiences @@ -501,7 +501,7 @@ async def get_usage( "errors": 0, "avg_latency_ms": round(d["latency"] / max(d["requests"], 1), 1), }) - except Exception as e: + except (ConnectionError, OSError, ValueError, KeyError, RuntimeError, AttributeError) as e: logger.error(f"Failed to get usage from LLMGateway: {e}") # Fill in missing dates with zero @@ -587,7 +587,7 @@ async def check_pitfalls( }) return {"warnings": warnings_data} - except Exception as e: + except (RuntimeError, ValueError, KeyError, AttributeError, asyncio.TimeoutError, ConnectionError) as e: logger.error(f"Failed to check pitfalls: {e}") return {"warnings": []} @@ -642,7 +642,7 @@ async def list_path_optimizations( else None, } ) - except Exception as e: + except (ValueError, KeyError, RuntimeError, AttributeError) as e: logger.error(f"Failed to get path optimizations: {e}") # Also include in-memory optimizations @@ -767,7 +767,7 @@ async def evolution_dashboard_ws(websocket: WebSocket): ) except WebSocketDisconnect: logger.debug("Evolution dashboard WebSocket disconnected") - except Exception as e: + except (RuntimeError, asyncio.TimeoutError, ConnectionError) as e: logger.error(f"Evolution dashboard WebSocket error: {e}") finally: if websocket in _ws_connections: @@ -781,7 +781,7 @@ async def _broadcast_event(event_type: str, data: dict): for ws in _ws_connections: try: await ws.send_json(message) - except Exception: + except (ConnectionError, RuntimeError, asyncio.TimeoutError): disconnected.append(ws) for ws in disconnected: if ws in _ws_connections: diff --git a/src/agentkit/server/routes/health.py b/src/agentkit/server/routes/health.py index 06b3fe6..64e2ba9 100644 --- a/src/agentkit/server/routes/health.py +++ b/src/agentkit/server/routes/health.py @@ -1,5 +1,7 @@ """Health check route""" +import asyncio + from fastapi import APIRouter, Request router = APIRouter(tags=["health"]) @@ -23,14 +25,14 @@ async def health_check(request: Request): redis_client = await task_store._get_redis() await redis_client.ping() redis_status = "available" - except Exception as ping_exc: + except (ConnectionError, OSError, asyncio.TimeoutError, RuntimeError) as ping_exc: redis_status = f"error: {str(ping_exc)[:100]}" overall_status = "degraded" else: redis_status = "not_configured" else: redis_status = "not_configured" - except Exception as exc: + except (ConnectionError, OSError, asyncio.TimeoutError, RuntimeError, AttributeError) as exc: redis_status = f"error: {str(exc)[:100]}" overall_status = "degraded" checks["redis"] = redis_status @@ -42,7 +44,7 @@ async def health_check(request: Request): try: agents = agent_pool.list_agents() pool_size = len(agents) - except Exception: + except (RuntimeError, AttributeError): pass checks["agent_pool"] = {"status": "available", "size": pool_size} @@ -57,7 +59,7 @@ async def health_check(request: Request): else: llm_status = "no_providers" overall_status = "degraded" - except Exception: + except (RuntimeError, AttributeError, ValueError): llm_status = "error" overall_status = "degraded" checks["llm_gateway"] = llm_status @@ -68,7 +70,7 @@ async def health_check(request: Request): if skill_registry: try: skill_count = len(skill_registry.list_skills()) - except Exception: + except (RuntimeError, AttributeError): pass checks["skill_registry"] = { "status": "available" if skill_registry else "not_configured", diff --git a/src/agentkit/server/routes/kb_management.py b/src/agentkit/server/routes/kb_management.py index 1bc570d..4644828 100644 --- a/src/agentkit/server/routes/kb_management.py +++ b/src/agentkit/server/routes/kb_management.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import hmac import logging import os @@ -252,7 +253,7 @@ async def list_sources( visible_ids = await filter_kb_sources_by_department( db_path, dept_ctx.department_ids, all_ids ) - except Exception: # noqa: BLE001 — never block listing on DB errors + except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError): # noqa: BLE001 — never block listing on DB errors logger.exception("Department KB filtering failed — returning empty list") return {"sources": []} visible_set = set(visible_ids) @@ -398,7 +399,7 @@ async def list_documents( visible_ids = await filter_kb_sources_by_department( db_path, dept_ctx.department_ids, all_source_ids ) - except Exception: # noqa: BLE001 — never block listing on DB errors + except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError): # noqa: BLE001 — never block listing on DB errors logger.exception("Department KB filtering failed — returning empty list") return {"documents": []} visible_set = set(visible_ids) @@ -484,7 +485,7 @@ async def upload_document( try: text = processor.parse(tmp_path, file_type) chunks = processor.segment(text) - except Exception as e: + except (ValueError, OSError, RuntimeError, UnicodeDecodeError) as e: logger.warning("Document parsing failed: %s", e) raise HTTPException(status_code=422, detail=f"Document parsing failed: {e}") from e @@ -567,7 +568,7 @@ async def preview_document( chunk_size=chunk_size, chunk_overlap=chunk_overlap, ) - except Exception as e: + except (ValueError, OSError, RuntimeError, UnicodeDecodeError) as e: logger.warning("Document preview failed: %s", e) raise HTTPException(status_code=422, detail=f"Document preview failed: {e}") from e @@ -628,7 +629,7 @@ async def search_knowledge( for r in results ] } - except Exception as e: + except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError) as e: logger.warning(f"Semantic search failed: {e}") # Fallback: return empty results with a hint diff --git a/src/agentkit/server/routes/mcp_publish.py b/src/agentkit/server/routes/mcp_publish.py index 25397c0..1d9b074 100644 --- a/src/agentkit/server/routes/mcp_publish.py +++ b/src/agentkit/server/routes/mcp_publish.py @@ -115,7 +115,7 @@ async def publish_skill( try: skill = skill_registry.get(skill_name) - except Exception: + except (KeyError, ValueError, AttributeError): raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found") try: diff --git a/src/agentkit/server/routes/memory.py b/src/agentkit/server/routes/memory.py index 7863a5f..8e6f098 100644 --- a/src/agentkit/server/routes/memory.py +++ b/src/agentkit/server/routes/memory.py @@ -1,5 +1,6 @@ """Memory API routes""" +import asyncio import logging from fastapi import APIRouter, HTTPException, Request @@ -40,6 +41,8 @@ async def search_episodic_memory( if agent_name: filters["agent_name"] = agent_name items = await retriever._episodic.search(query, top_k=top_k, filters=filters or None) + except asyncio.CancelledError: + raise except Exception as e: logger.error(f"Failed to search episodic memory: {e}") raise HTTPException(status_code=500, detail="Failed to search episodic memory") @@ -76,6 +79,8 @@ async def search_semantic_memory( if knowledge_base_ids: filters["knowledge_base_ids"] = [kid.strip() for kid in knowledge_base_ids.split(",")] items = await retriever._semantic.search(query, top_k=top_k, filters=filters or None) + except asyncio.CancelledError: + raise except Exception as e: logger.error(f"Failed to search semantic memory: {e}") raise HTTPException(status_code=500, detail="Failed to search semantic memory") @@ -104,6 +109,8 @@ async def delete_episodic_memory(key: str, req: Request = None): try: deleted = await retriever._episodic.delete(key) + except asyncio.CancelledError: + raise except Exception as e: logger.error(f"Failed to delete episodic memory '{key}': {e}") raise HTTPException(status_code=500, detail="Failed to delete episodic memory") diff --git a/src/agentkit/server/routes/metrics.py b/src/agentkit/server/routes/metrics.py index 451002b..b9eb8fd 100644 --- a/src/agentkit/server/routes/metrics.py +++ b/src/agentkit/server/routes/metrics.py @@ -1,5 +1,6 @@ """Metrics route — /api/v1/metrics""" +import asyncio import logging from fastapi import APIRouter, Request @@ -29,7 +30,7 @@ async def get_metrics(request: Request): task_metrics["completed_tasks"] = counts.get("completed", 0) task_metrics["failed_tasks"] = counts.get("failed", 0) task_metrics["pending_tasks"] = counts.get("pending", 0) - except Exception as e: + except (RuntimeError, AttributeError, ConnectionError, asyncio.TimeoutError) as e: logger.warning(f"Failed to collect task metrics: {e}") # Agent pool metrics @@ -41,7 +42,7 @@ async def get_metrics(request: Request): try: agents = agent_pool.list_agents() agent_metrics["total_agents"] = len(agents) - except Exception as e: + except (RuntimeError, AttributeError) as e: logger.warning(f"Failed to collect agent metrics: {e}") # Skill registry metrics @@ -53,7 +54,7 @@ async def get_metrics(request: Request): try: skills = skill_registry.list_skills() skill_metrics["total_skills"] = len(skills) - except Exception as e: + except (RuntimeError, AttributeError) as e: logger.warning(f"Failed to collect skill metrics: {e}") return { diff --git a/src/agentkit/server/routes/portal.py b/src/agentkit/server/routes/portal.py index 80e7f1d..76b597c 100644 --- a/src/agentkit/server/routes/portal.py +++ b/src/agentkit/server/routes/portal.py @@ -94,7 +94,7 @@ async def _emit_event_safe( data=data or {}, ) await event_queue.emit(event) - except Exception as e: + except (asyncio.QueueFull, RuntimeError, ConnectionError) as e: logger.warning(f"EventQueue emit failed (type={event_type}): {e}", exc_info=True) @@ -211,7 +211,7 @@ class PortalConnectionManager: import asyncio asyncio.create_task(oldest.close(code=1008, reason="Connection limit exceeded")) - except Exception: + except (ConnectionError, RuntimeError): pass conns.append(ws) @@ -235,7 +235,7 @@ class PortalConnectionManager: for ws in conns: try: await ws.send_json(message) - except Exception as e: + except (ConnectionError, RuntimeError, asyncio.TimeoutError) as e: logger.debug( "Portal WS send failed for user %s (marking stale): %s", user_id, e ) @@ -285,7 +285,7 @@ async def _build_history_messages( """ try: history = await _conversation_store.get_history(conv_id, limit=limit) - except Exception: + except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError): return [] # The last message in history is the current user message (just added), @@ -553,6 +553,8 @@ async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify ): if event.event_type == "final_answer": collected_output.append(event.data.get("output", "")) + except asyncio.CancelledError: + raise except Exception as e: response_text = f"执行出错: {e}" else: @@ -682,6 +684,8 @@ async def chat_stream(request: ChatRequest, req: Request, _auth: None = Depends( } ), } + except asyncio.CancelledError: + raise except Exception as e: yield { "event": "error", @@ -862,7 +866,7 @@ async def _execute_react_background( await _task_store_update_status( task_store, task_id, TaskStatus.RUNNING, started_at=datetime.now(timezone.utc) ) - except Exception: + except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError): logger.warning("Failed to update TaskStore RUNNING", exc_info=True) async for event in react_engine.execute_stream( @@ -909,7 +913,7 @@ async def _execute_react_background( progress=1.0, progress_message="Completed", ) - except Exception: + except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError): logger.warning("Failed to update TaskStore COMPLETED", exc_info=True) # Emit task.completed so subscribers know the task is done @@ -932,7 +936,15 @@ async def _execute_react_background( partial = _ensure_non_empty("".join(collected_output)) try: await asyncio.shield(conversation_store.add_message(conv_id, "assistant", partial)) - except (Exception, asyncio.CancelledError): + except ( + asyncio.CancelledError, + ConnectionError, + OSError, + asyncio.TimeoutError, + ValueError, + KeyError, + RuntimeError, + ): logger.warning("Failed to persist partial output on cancel") if task_store is not None: try: @@ -945,7 +957,15 @@ async def _execute_react_background( completed_at=datetime.now(timezone.utc), ) ) - except (Exception, asyncio.CancelledError): + except ( + asyncio.CancelledError, + ConnectionError, + OSError, + asyncio.TimeoutError, + ValueError, + KeyError, + RuntimeError, + ): logger.warning("Failed to update TaskStore on cancel", exc_info=True) # P0 #2 fix: _emit_event_safe is async (it awaits event_queue.emit). # Shield it so a re-entrant CancelledError doesn't kill the emit @@ -963,7 +983,7 @@ async def _execute_react_background( }, ) ) - except (Exception, asyncio.CancelledError): + except (asyncio.CancelledError, asyncio.QueueFull, RuntimeError, ConnectionError): logger.warning("Failed to emit TASK_FAILED on cancel") raise # Propagate cancellation @@ -973,7 +993,7 @@ async def _execute_react_background( partial = _ensure_non_empty("".join(collected_output)) try: await conversation_store.add_message(conv_id, "assistant", partial) - except Exception: + except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError): logger.warning("Failed to persist partial output in background task") if task_store is not None: @@ -985,7 +1005,7 @@ async def _execute_react_background( error_message=str(e), completed_at=datetime.now(timezone.utc), ) - except Exception: + except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError): logger.warning("Failed to update TaskStore FAILED", exc_info=True) # Emit task.failed so subscribers know the task failed @@ -1120,7 +1140,7 @@ async def portal_websocket(websocket: WebSocket): try: record = await _task_store_get(resume_task_store, resume_task_id) - except Exception: + except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError): logger.warning("TaskStore.get failed during resume", exc_info=True) record = None if record is not None: @@ -1333,7 +1353,7 @@ async def portal_websocket(websocket: WebSocket): }, ) await _broadcast_dashboard_event("metrics_updated", {"period": "7d"}) - except Exception as e: + except (asyncio.QueueFull, RuntimeError, ConnectionError, ValueError, KeyError) as e: logger.warning(f"Failed to record experience: {e}") # Unified preprocessing via RequestPreprocessor (minimal: @skill prefix + greeting regex + REACT) @@ -1414,7 +1434,7 @@ async def portal_websocket(websocket: WebSocket): TaskStatus.PENDING, metadata={"conversation_id": conv.id}, ) - except Exception: + except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError): logger.warning("Failed to register task in TaskStore", exc_info=True) # Execute based on routing result's execution_mode @@ -1455,7 +1475,7 @@ async def portal_websocket(websocket: WebSocket): progress=1.0, progress_message="Completed", ) - except Exception: + except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError): logger.warning("Failed to update TaskStore for DIRECT_CHAT", exc_info=True) # Emit turn.final_answer and task.completed to EQ @@ -1526,7 +1546,7 @@ async def portal_websocket(websocket: WebSocket): chat_messages.insert( -1, {"role": hist_msg.role, "content": hist_msg.content} ) - except Exception: + except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError): pass response = await llm_gateway.chat( messages=chat_messages, @@ -1627,7 +1647,7 @@ async def portal_websocket(websocket: WebSocket): logger.warning("EventQueue not configured; awaiting background task directly") try: await bg_task - except Exception: + except (RuntimeError, ConnectionError, asyncio.TimeoutError): pass # errors handled inside _execute_react_background active_bg_task = None continue @@ -1734,6 +1754,8 @@ async def portal_websocket(websocket: WebSocket): # kill the task, lose the full output, and mark it FAILED — # defeating layers 2 and 3. The task is only cancelled on explicit # user cancel (msg_type == 'cancel') or application shutdown. + except asyncio.CancelledError: + raise except Exception as e: logger.error(f"Portal WebSocket error: {e}") # P1 #6 fix: Do NOT cancel the background task on connection-level @@ -1758,7 +1780,7 @@ async def portal_websocket(websocket: WebSocket): ) try: await websocket.send_json({"type": "error", "data": {"message": str(e)}}) - except Exception: + except (ConnectionError, RuntimeError, asyncio.TimeoutError): pass finally: # Remove from user-scoped push tracking on any disconnect/error/return. diff --git a/src/agentkit/server/routes/skill_management.py b/src/agentkit/server/routes/skill_management.py index a998c03..12838bb 100644 --- a/src/agentkit/server/routes/skill_management.py +++ b/src/agentkit/server/routes/skill_management.py @@ -119,7 +119,7 @@ def _skill_to_detail(skill: Any) -> dict[str, Any]: if hasattr(skill, "config"): try: config = skill.config.to_dict() if hasattr(skill.config, "to_dict") else {} - except Exception: + except (AttributeError, ValueError, TypeError): config = {} return { @@ -174,7 +174,7 @@ async def get_skill_detail(skill_name: str, req: Request): skill_registry = req.app.state.skill_registry try: skill = skill_registry.get(skill_name) - except Exception: + except (KeyError, ValueError, AttributeError): raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found") return _skill_to_detail(skill) @@ -186,7 +186,7 @@ async def check_skill_health(skill_name: str, req: Request): skill_registry = req.app.state.skill_registry try: skill_registry.get(skill_name) - except Exception: + except (KeyError, ValueError, AttributeError): raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found") # Basic health check - skill exists and is registered @@ -243,7 +243,7 @@ async def reload_skill(skill_name: str, req: Request): # Verify the skill is currently registered (404 if not). try: skill_registry.get(skill_name) - except Exception: + except (KeyError, ValueError, AttributeError): raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found") # Resolve the skills directory (mirrors routes.skills._get_skills_dir). diff --git a/src/agentkit/server/routes/skills.py b/src/agentkit/server/routes/skills.py index abb8a6c..e40a5f8 100644 --- a/src/agentkit/server/routes/skills.py +++ b/src/agentkit/server/routes/skills.py @@ -136,7 +136,7 @@ async def register_skill(request: RegisterSkillRequest, req: Request): try: config = SkillConfig.from_dict(request.config) - except Exception as e: + except (ValueError, TypeError, KeyError) as e: raise HTTPException(status_code=422, detail=f"Invalid skill config: {e}") skill = Skill(config=config) @@ -279,7 +279,7 @@ async def install_skill(request: InstallSkillRequest, req: Request): resp = await client.get(source) resp.raise_for_status() yaml_content = resp.text - except Exception as e: + except (httpx.HTTPError, OSError) as e: raise HTTPException(status_code=400, detail=f"Failed to download from source: {e}") elif source and source.startswith("file://"): # Read from local file path @@ -295,7 +295,7 @@ async def install_skill(request: InstallSkillRequest, req: Request): try: with open(local_path, encoding="utf-8") as f: yaml_content = f.read() - except Exception as e: + except OSError as e: raise HTTPException(status_code=400, detail=f"Failed to read local file: {e}") else: # Search GitHub for skills (YAML config files) @@ -313,7 +313,7 @@ async def install_skill(request: InstallSkillRequest, req: Request): }, ) gh_data = gh_resp.json() - except Exception as e: + except (httpx.HTTPError, OSError, ValueError) as e: raise HTTPException(status_code=502, detail=f"GitHub search failed: {e}") items = gh_data.get("items", []) @@ -334,7 +334,7 @@ async def install_skill(request: InstallSkillRequest, req: Request): }, ) items = gh_resp2.json().get("items", []) - except Exception: + except (httpx.HTTPError, OSError, ValueError, KeyError): items = [] if not items: @@ -362,7 +362,7 @@ async def install_skill(request: InstallSkillRequest, req: Request): resp = await client.get(raw_url) resp.raise_for_status() yaml_content = resp.text - except Exception as e: + except (httpx.HTTPError, OSError) as e: raise HTTPException(status_code=400, detail=f"Failed to download skill: {e}") # Validate YAML content before writing to disk @@ -391,14 +391,14 @@ async def install_skill(request: InstallSkillRequest, req: Request): ) loader.load_from_file(file_path) registration_ok = True - except Exception as e: + except (ValueError, TypeError, KeyError, OSError, RuntimeError) as e: logger.warning(f"Failed to register installed skill: {e}") if not registration_ok: # Remove the invalid YAML file and report error try: os.remove(file_path) - except Exception: + except OSError: pass raise HTTPException(status_code=500, detail="Skill downloaded but registration failed") @@ -419,7 +419,7 @@ async def uninstall_skill(name: str, req: Request): try: skill_registry.get(validated_name) - except Exception: + except (KeyError, ValueError, RuntimeError): raise HTTPException(status_code=404, detail=f"Skill '{name}' not found") # Remove from registry @@ -487,7 +487,7 @@ async def execute_pipeline(name: str, request: ExecutePipelineRequest, req: Requ try: result = await pipeline.execute(input_data=request.input_data) - except Exception as e: + except (ValueError, TypeError, KeyError, RuntimeError, OSError, ConnectionError) as e: logger.error(f"Pipeline execution failed for '{name}': {e}", exc_info=True) raise HTTPException(status_code=500, detail="Pipeline execution failed") diff --git a/src/agentkit/server/routes/system.py b/src/agentkit/server/routes/system.py index d85f8ad..5b7dcaf 100644 --- a/src/agentkit/server/routes/system.py +++ b/src/agentkit/server/routes/system.py @@ -24,7 +24,7 @@ def _read_meminfo() -> dict[str, int]: key = parts[0].strip() value = parts[1].strip().split()[0] values[key] = int(value) * 1024 # kB -> bytes - except Exception as exc: + except (OSError, ValueError, FileNotFoundError) as exc: logger.debug(f"Failed to read /proc/meminfo: {exc}") return values @@ -52,7 +52,7 @@ async def get_system_resources() -> dict[str, Any]: if hasattr(os, "getloadavg"): try: loadavg = list(os.getloadavg()) - except Exception as exc: + except (OSError, AttributeError) as exc: logger.debug(f"Failed to get loadavg: {exc}") meminfo = _read_meminfo() @@ -68,7 +68,7 @@ async def get_system_resources() -> dict[str, Any]: disk_total = du.total disk_used = du.used disk_free = du.free - except Exception as exc: + except (OSError, FileNotFoundError) as exc: logger.debug(f"Failed to get disk usage: {exc}") return { diff --git a/src/agentkit/server/routes/tasks.py b/src/agentkit/server/routes/tasks.py index b274e0e..1e00ca3 100644 --- a/src/agentkit/server/routes/tasks.py +++ b/src/agentkit/server/routes/tasks.py @@ -83,7 +83,7 @@ async def submit_task(request: SubmitTaskRequest, req: Request): elif request.skill_name: try: skill = skill_registry.get(request.skill_name) - except Exception: + except (KeyError, ValueError, AttributeError): raise HTTPException( status_code=404, detail=f"Skill '{request.skill_name}' not found", @@ -145,7 +145,7 @@ async def submit_task(request: SubmitTaskRequest, req: Request): quality_result = await quality_gate.validate( task_result.output_data or {}, skill, skill_context=skill_context ) - except Exception: + except (RuntimeError, ValueError, KeyError, AttributeError, asyncio.TimeoutError): pass # Quality gate failure shouldn't block the response # 7. Standardize output if skill available @@ -167,7 +167,7 @@ async def submit_task(request: SubmitTaskRequest, req: Request): "task_id": task.task_id, "status": task_result.status, } - except Exception: + except (ValueError, KeyError, AttributeError, RuntimeError): pass # Fall through to raw output # 8. Return raw result if no skill or standardization failed @@ -307,7 +307,7 @@ async def resume_task(task_id: str, req: Request, plan_id: str | None = None): finally: try: await team.dissolve() - except Exception: + except (RuntimeError, asyncio.TimeoutError, ConnectionError): pass return { @@ -343,7 +343,7 @@ async def stream_task(request: SubmitTaskRequest, req: Request): elif request.skill_name: try: skill_registry.get(request.skill_name) - except Exception: + except (KeyError, ValueError, AttributeError): raise HTTPException( status_code=404, detail=f"Skill '{request.skill_name}' not found", diff --git a/src/agentkit/server/routes/terminal.py b/src/agentkit/server/routes/terminal.py index fe7f876..0577345 100644 --- a/src/agentkit/server/routes/terminal.py +++ b/src/agentkit/server/routes/terminal.py @@ -359,14 +359,14 @@ async def execute_command( await check_pty.close() if cwd_result.exit_code == 0: state.cwd = cwd_result.output.strip() - except Exception: + except (OSError, RuntimeError, asyncio.TimeoutError): pass except asyncio.TimeoutError: output = f"命令执行超时({request.timeout}s)" exit_code = -1 duration_ms = int((time.monotonic() - start_time) * 1000) - except Exception as e: + except (OSError, RuntimeError, ValueError) as e: output = str(e) exit_code = -1 duration_ms = int((time.monotonic() - start_time) * 1000) @@ -515,7 +515,7 @@ async def terminal_websocket(websocket: WebSocket) -> None: }) await websocket.close(code=4003, reason="Permission denied") return - except Exception: + except (ValueError, KeyError, RuntimeError, OSError): pass # Fall through to API key / dev mode # 2. API key via ?api_key= @@ -721,7 +721,7 @@ async def terminal_websocket(websocket: WebSocket) -> None: }) except asyncio.TimeoutError: continue - except Exception: + except (OSError, RuntimeError): break # Get final result @@ -741,7 +741,7 @@ async def terminal_websocket(websocket: WebSocket) -> None: await check_pty.close() if cwd_result.exit_code == 0: state.cwd = cwd_result.output.strip() - except Exception: + except (OSError, RuntimeError, asyncio.TimeoutError): pass # Record history @@ -790,6 +790,8 @@ async def terminal_websocket(websocket: WebSocket) -> None: "cwd": state.cwd, "duration_ms": duration_ms, }) + except asyncio.CancelledError: + raise except Exception as e: duration_ms = int((time.monotonic() - start_time) * 1000) await websocket.send_json({ @@ -800,7 +802,7 @@ async def terminal_websocket(websocket: WebSocket) -> None: if active_pty and active_pty.is_running: try: await active_pty.close() - except Exception: + except (OSError, RuntimeError): pass active_pty = None @@ -836,11 +838,13 @@ async def terminal_websocket(websocket: WebSocket) -> None: except WebSocketDisconnect: logger.debug(f"Terminal WebSocket disconnected for session {state.session_id}") + except asyncio.CancelledError: + raise except Exception as e: logger.error(f"Terminal WebSocket error for session {state.session_id}: {e}") try: await websocket.send_json({"type": "error", "message": str(e)[:200]}) - except Exception: + except (ConnectionError, RuntimeError, asyncio.TimeoutError): pass finally: # Clean up pending confirmations @@ -851,5 +855,5 @@ async def terminal_websocket(websocket: WebSocket) -> None: if active_pty and active_pty.is_running: try: await active_pty.close() - except Exception: + except (OSError, RuntimeError): pass diff --git a/src/agentkit/server/routes/terminal_server.py b/src/agentkit/server/routes/terminal_server.py index 80df81d..05db4ae 100644 --- a/src/agentkit/server/routes/terminal_server.py +++ b/src/agentkit/server/routes/terminal_server.py @@ -549,7 +549,7 @@ async def server_terminal_websocket(websocket: WebSocket) -> None: }) await websocket.close(code=4003, reason="Permission denied") return - except Exception: + except (ValueError, KeyError, RuntimeError, OSError): pass if not auth_ok: @@ -585,7 +585,7 @@ async def server_terminal_websocket(websocket: WebSocket) -> None: }) await websocket.close(code=4003, reason="Not authorized") return - except Exception as e: + except (aiosqlite.Error, OSError, ValueError, KeyError, RuntimeError) as e: logger.warning(f"Failed to check server terminal authorization: {e}") # If DB check fails, deny access await websocket.send_json({ @@ -822,7 +822,7 @@ async def server_terminal_websocket(websocket: WebSocket) -> None: }) except asyncio.TimeoutError: continue - except Exception: + except (OSError, RuntimeError): break result = await run_task @@ -846,7 +846,7 @@ async def server_terminal_websocket(websocket: WebSocket) -> None: else: base = state.cwd or os.path.expanduser("~") state.cwd = os.path.normpath(os.path.join(base, cd_arg)) - except Exception: + except (ValueError, OSError, RuntimeError): pass # cd parsing failed — keep old cwd # Record history @@ -893,6 +893,8 @@ async def server_terminal_websocket(websocket: WebSocket) -> None: "cwd": state.cwd, "duration_ms": duration_ms, }) + except asyncio.CancelledError: + raise except Exception as e: duration_ms = int((time.monotonic() - start_time) * 1000) await websocket.send_json({ @@ -903,7 +905,7 @@ async def server_terminal_websocket(websocket: WebSocket) -> None: if active_pty and active_pty.is_running: try: await active_pty.close() - except Exception: + except (OSError, RuntimeError): pass active_pty = None @@ -942,17 +944,19 @@ async def server_terminal_websocket(websocket: WebSocket) -> None: except WebSocketDisconnect: logger.debug(f"Server terminal WebSocket disconnected for session {state.session_id}") + except asyncio.CancelledError: + raise except Exception as e: logger.error(f"Server terminal WebSocket error for session {state.session_id}: {e}") try: await websocket.send_json({"type": "error", "message": str(e)[:200]}) - except Exception: + except (ConnectionError, RuntimeError, asyncio.TimeoutError): pass finally: if active_pty and active_pty.is_running: try: await active_pty.close() - except Exception: + except (OSError, RuntimeError): pass # Clean up session from in-memory stores to prevent leaks _server_sessions.pop(state.session_id, None) diff --git a/src/agentkit/server/routes/workflows.py b/src/agentkit/server/routes/workflows.py index a41331f..8676221 100644 --- a/src/agentkit/server/routes/workflows.py +++ b/src/agentkit/server/routes/workflows.py @@ -433,6 +433,8 @@ async def _execute_workflow( ) result = await agent.handle_task(task) stage_result = {"output": result, "skill": stage.action} + except asyncio.CancelledError: + raise except Exception as e: stage_result = {"error": str(e), "skill": stage.action} else: @@ -474,6 +476,8 @@ async def _execute_workflow( ) result = await agent.handle_task(task) return {"output": result, "skill": action} + except asyncio.CancelledError: + raise except Exception as e: return {"error": str(e), "skill": action} return {"dry_run": True, "action": action} @@ -515,6 +519,8 @@ async def _execute_workflow( execution_id=execution.execution_id, ) + except asyncio.CancelledError: + raise except Exception as e: execution.stage_results[stage_name] = { "status": "failed", @@ -633,7 +639,7 @@ async def _broadcast_ws(message: dict[str, Any], execution_id: str | None = None for ws in targets: try: await ws.send_json(message) - except Exception: + except (ConnectionError, RuntimeError, asyncio.TimeoutError): disconnected.append(ws) if disconnected: async with _ws_subscribers_lock: @@ -952,7 +958,7 @@ async def workflow_websocket(websocket: WebSocket): # Keep connection alive - messages are primarily server-push except WebSocketDisconnect: logger.debug("Workflow WebSocket disconnected") - except Exception as e: + except (RuntimeError, asyncio.TimeoutError, ConnectionError) as e: logger.error(f"Workflow WebSocket error: {e}") finally: if subscribed_execution_id: diff --git a/src/agentkit/server/routes/ws.py b/src/agentkit/server/routes/ws.py index 10bd943..338885e 100644 --- a/src/agentkit/server/routes/ws.py +++ b/src/agentkit/server/routes/ws.py @@ -47,7 +47,7 @@ class ConnectionManager: for ws, _ in conns: try: await ws.send_json(message) - except Exception: + except (ConnectionError, RuntimeError, asyncio.TimeoutError): stale.append(ws) for ws in stale: self.remove(task_id, ws) @@ -153,6 +153,8 @@ async def task_websocket(websocket: WebSocket, task_id: str) -> None: except WebSocketDisconnect: logger.debug(f"WebSocket disconnected for task {task_id}") + except asyncio.CancelledError: + raise except Exception as e: logger.error(f"WebSocket error for task {task_id}: {e}") try: @@ -162,7 +164,7 @@ async def task_websocket(websocket: WebSocket, task_id: str) -> None: "data": {"message": str(e)}, } ) - except Exception: + except (ConnectionError, RuntimeError, asyncio.TimeoutError): pass finally: manager.remove(task_id, websocket) @@ -243,6 +245,8 @@ async def _run_react_and_stream( }, ) + except asyncio.CancelledError: + raise except Exception as e: await websocket.send_json( {