Compare commits

...

3 Commits

Author SHA1 Message Date
chiguyong 7b1b198058 refactor(orchestrator+bitable): remove Any from type signatures
Test / backend-test (pull_request) Has been cancelled Details
Test / frontend-unit (pull_request) Has been cancelled Details
Test / api-e2e (pull_request) Has been cancelled Details
Test / frontend-e2e (pull_request) Has been cancelled Details
Eliminate 112 Any usages across orchestrator/ (62) and bitable/ (50) via:
- TYPE_CHECKING Protocol for Redis/LLMGateway/Plan/Dispatcher/StateManager
- object for arbitrary dict/list/value types (Pydantic v2 serializes fine)
- RecalcTask concrete import (replacing Any in recalc_worker.py)
- Coroutine[object, object, object] for async generic
- Remove unused Any imports (F401 cleanup)

Note: Avoided recursive TypeAlias (FieldValue) because Pydantic v2 cannot
build schemas for recursive named aliases (RecursionError).

Tests: 245 passed (bitable 91 + orchestrator 154), 0 regressions
ruff: All checks passed
2026-07-01 02:41:14 +08:00
chiguyong 34a89c4873 refactor(llm+memory+client): remove Any from type signatures
Eliminate 172 Any usages across llm/, memory/, client/ via:
- TypeAlias (MetadataValue, MetadataDict, RAGSearchResult, etc.)
- object for arbitrary dict/value types
- TYPE_CHECKING Protocol for Redis/Quota/RAG/Graph services
- TYPE_CHECKING import + string annotations for forward refs
- Remove unused Any imports (18 F401 fixed)

Tests: 253 passed (llm 21 failures are pre-existing litellm env issue)
ruff: All checks passed
2026-07-01 02:03:51 +08:00
chiguyong aa6367ff9f refactor(server/routes): classify except Exception in 23 route files
Narrow 131 except Exception to specific exception types across all
server/routes/ modules. Framework boundaries (main execute paths,
WebSocket top-level) retain except Exception with asyncio.CancelledError
guard.

Categories:
- WebSocket ops: (ConnectionError, RuntimeError, asyncio.TimeoutError)
- DB/Store ops: (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError)
- EventQueue: (asyncio.QueueFull, RuntimeError, ConnectionError)
- Config construction: (ValueError, TypeError, KeyError)
- Cleanup/dissolve: (RuntimeError, asyncio.TimeoutError, ConnectionError)
- HTTP handlers: business-specific exceptions
- Framework boundaries: retain except Exception + CancelledError guard

Stats: 101 narrowed, 31 framework boundary retained, 2 noqa (DB resilience)

Follow-up to PR #8 (U1-U5 systematic tech debt cleanup).
2026-07-01 01:26:31 +08:00
79 changed files with 1388 additions and 900 deletions

View File

@ -16,7 +16,6 @@ from __future__ import annotations
import ast
from collections import deque
from typing import Any
from agentkit.bitable.formula.functions import AGGREGATE_FUNCTIONS, FUNCTION_REGISTRY
from agentkit.bitable.formula.parser import (
@ -104,9 +103,9 @@ class FormulaEngine:
def evaluate(
self,
field_id: str,
row_values: dict[str, Any],
column_values: dict[str, list[Any]] | None = None,
) -> Any:
row_values: dict[str, object],
column_values: dict[str, list[object]] | None = None,
) -> object:
"""Evaluate a formula field for a specific record.
Args:
@ -130,7 +129,7 @@ class FormulaEngine:
# Build the field_values dict for the evaluator
# Aggregate refs get column values (lists), row refs get row values (scalars)
eval_values: dict[str, Any] = {}
eval_values: dict[str, object] = {}
# Map real field IDs to safe names
for safe_name, real_id in entry.field_mapping.items():
@ -143,16 +142,16 @@ class FormulaEngine:
def evaluate_all_for_record(
self,
row_values: dict[str, Any],
column_values: dict[str, list[Any]] | None = None,
) -> dict[str, Any]:
row_values: dict[str, object],
column_values: dict[str, list[object]] | None = None,
) -> dict[str, object]:
"""Evaluate all registered formulas for a record.
Returns a dict of field_id computed value.
Formulas are evaluated in topological order so that formula-to-formula
dependencies are resolved correctly.
"""
results: dict[str, Any] = {}
results: dict[str, object] = {}
column_values = column_values or {}
for field_id in self.topological_order():

View File

@ -16,13 +16,11 @@ Usage::
from __future__ import annotations
from typing import Any
def transform_records(
records: list[dict[str, Any]],
records: list[dict[str, object]],
field_mapping: dict[str, str],
) -> list[dict[str, Any]]:
) -> list[dict[str, object]]:
"""Map source record keys to bitable field IDs via field_mapping.
Keys not in ``field_mapping`` are dropped. Values are passed through
@ -40,9 +38,9 @@ def transform_records(
if not field_mapping:
return []
transformed: list[dict[str, Any]] = []
transformed: list[dict[str, object]] = []
for rec in records:
out: dict[str, Any] = {}
out: dict[str, object] = {}
for src_key, field_id in field_mapping.items():
if src_key in rec:
out[field_id] = rec[src_key]

View File

@ -16,7 +16,6 @@ Type mapping (KTD: DB → bitable):
from __future__ import annotations
import logging
from typing import Any
from sqlalchemy import (
BigInteger,
@ -56,7 +55,7 @@ DB_TYPE_MAP: dict[type, str] = {
READ_BATCH = 1000
def infer_field_type(sqla_type: Any) -> str:
def infer_field_type(sqla_type: object) -> str:
"""Map a SQLAlchemy column type instance or class to a bitable field type.
Handles both type instances (``Integer()``) and type classes (``Integer``).
@ -78,7 +77,7 @@ def import_table(
table_name: str,
*,
max_rows: int = 50_000,
) -> dict[str, Any]:
) -> dict[str, object]:
"""Reflect a single table from an external DB.
Returns ``{"table_name": str, "fields": [...], "records": [...],
@ -97,7 +96,7 @@ def import_table(
engine.dispose()
def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[str, Any]:
def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[str, object]:
"""Reflect one table and read its rows."""
insp = inspect(engine)
@ -111,7 +110,7 @@ def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[st
table = Table(table_name, metadata, autoload_with=engine)
# Build field definitions
fields: list[dict[str, Any]] = []
fields: list[dict[str, object]] = []
pk_columns = list(table.primary_key.columns)
pk_name = pk_columns[0].name if pk_columns else None
@ -131,14 +130,14 @@ def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[st
pk_name = "id"
# Read rows
records: list[dict[str, Any]] = []
records: list[dict[str, object]] = []
with engine.connect() as conn:
result = conn.execute(select(table))
for i, row in enumerate(result):
if i >= max_rows:
logger.warning("Table %r truncated at %d rows during import", table_name, max_rows)
break
rec: dict[str, Any] = {}
rec: dict[str, object] = {}
for col in table.columns:
val = getattr(row, col.name, None)
if val is not None:
@ -155,7 +154,7 @@ def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[st
}
def _serialize(val: Any) -> Any:
def _serialize(val: object) -> object:
"""Serialize a DB value to JSON-safe form."""
from datetime import date, datetime
from decimal import Decimal

View File

@ -18,7 +18,6 @@ import logging
import socket
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from urllib.parse import urlparse
import httpx
@ -36,7 +35,7 @@ class ParsedSheet:
name: str
columns: list[str] = field(default_factory=list)
field_types: list[str] = field(default_factory=list) # "text" | "number" | "date"
records: list[dict[str, Any]] = field(default_factory=list)
records: list[dict[str, object]] = field(default_factory=list)
def parse_excel(file_path: str | Path) -> list[ParsedSheet]:
@ -182,9 +181,9 @@ def _parse_worksheet(ws) -> ParsedSheet | None:
col_count = len(clean_headers)
field_types = _infer_column_types(data_rows, col_count)
records: list[dict[str, Any]] = []
records: list[dict[str, object]] = []
for row in data_rows:
rec: dict[str, Any] = {}
rec: dict[str, object] = {}
for i, col_name in enumerate(clean_headers):
val = row[i] if i < len(row) else None
if val is not None:
@ -237,7 +236,7 @@ def _infer_column_types(rows: list[tuple], col_count: int) -> list[str]:
return types
def _coerce_value(val: Any, field_type: str) -> Any:
def _coerce_value(val: object, field_type: str) -> object:
"""Coerce a cell value to the inferred field type. Truncate long strings."""
if field_type == "date":
from datetime import datetime

View File

@ -9,10 +9,14 @@ from __future__ import annotations
from datetime import datetime, timezone
from enum import Enum
from typing import Any
from pydantic import BaseModel, ConfigDict, Field as PydanticField
# ponytail: bitable JSONB columns hold arbitrary JSON. Using `object` instead of
# a recursive TypeAlias because Pydantic v2 cannot build a schema for recursive
# named aliases (RecursionError). `object` is the most permissive type and
# Pydantic v2 serializes dict/list/primitive values fine at runtime.
def _utcnow() -> datetime:
return datetime.now(timezone.utc)
@ -97,7 +101,7 @@ class Table(BaseModel):
# ---------------------------------------------------------------------------
# Status select field options — labels and colors match Feishu Bitable defaults.
_STATUS_OPTIONS: list[dict[str, Any]] = [
_STATUS_OPTIONS: list[dict[str, object]] = [
{"label": "未开始", "value": "not_started", "color": "default"},
{"label": "进行中", "value": "in_progress", "color": "processing"},
{"label": "已完成", "value": "done", "color": "success"},
@ -106,7 +110,7 @@ _STATUS_OPTIONS: list[dict[str, Any]] = [
#: Templates for the 5 default fields created on every new table (R2).
#: agent-owned fields (创建人/创建时间) are auto-filled by the service layer
#: on record creation; user-owned fields are user-editable.
DEFAULT_FIELD_TEMPLATES: list[dict[str, Any]] = [
DEFAULT_FIELD_TEMPLATES: list[dict[str, object]] = [
{
"name": "标题",
"field_type": FieldType.text,
@ -155,7 +159,7 @@ class Field(BaseModel):
table_id: str
name: str
field_type: FieldType
config: dict[str, Any] = PydanticField(default_factory=dict)
config: dict[str, object] = PydanticField(default_factory=dict)
owner: FieldOwner = FieldOwner.user
created_at: datetime = PydanticField(default_factory=_utcnow)
@ -167,7 +171,7 @@ class Record(BaseModel):
id: str
table_id: str
values: dict[str, Any] = PydanticField(default_factory=dict)
values: dict[str, object] = PydanticField(default_factory=dict)
created_at: datetime = PydanticField(default_factory=_utcnow)
updated_at: datetime = PydanticField(default_factory=_utcnow)
@ -181,7 +185,7 @@ class View(BaseModel):
table_id: str
name: str
view_type: ViewType = ViewType.grid
config: dict[str, Any] = PydanticField(default_factory=dict)
config: dict[str, object] = PydanticField(default_factory=dict)
created_at: datetime = PydanticField(default_factory=_utcnow)

View File

@ -18,11 +18,10 @@ from __future__ import annotations
import asyncio
import logging
from typing import Any
from agentkit.bitable.db import BitableDB
from agentkit.bitable.formula.engine import FormulaEngine
from agentkit.bitable.models import FieldType, RecalcStatus
from agentkit.bitable.models import FieldType, RecalcStatus, RecalcTask
from agentkit.bitable.repository import BitableRepository
from agentkit.bitable.service import BitableService
@ -124,7 +123,7 @@ class RecalcWorker:
logger.exception("RecalcWorker error in main loop")
await asyncio.sleep(self._poll_interval)
async def _sort_by_topological_order(self, tasks: list[Any]) -> list[Any]:
async def _sort_by_topological_order(self, tasks: list[RecalcTask]) -> list[RecalcTask]:
"""Sort claimed tasks so dependencies are processed first (P1 #7).
Groups tasks by table_id, builds (or reuses) the engine to get the
@ -146,7 +145,7 @@ class RecalcWorker:
order = engine.topological_order()
topo_index[tid] = {fid: i for i, fid in enumerate(order)}
def _key(t: Any) -> tuple[str, str, int]:
def _key(t: RecalcTask) -> tuple[str, str, int]:
idx = topo_index.get(t.table_id, {}).get(t.field_id, 1 << 30)
return (t.table_id, t.record_id, idx)
@ -175,7 +174,7 @@ class RecalcWorker:
except Exception:
logger.exception("RecalcWorker reaper error")
async def process_task(self, task: Any) -> None:
async def process_task(self, task: RecalcTask) -> None:
"""Process a single recalc task: evaluate formula → write result.
The task is expected to already be in ``calculating`` status when
@ -216,7 +215,7 @@ class RecalcWorker:
return
deps = engine.get_dependencies(task.field_id)
column_values: dict[str, list[Any]] = {}
column_values: dict[str, list[object]] = {}
for dep_field_id in deps:
column_values[dep_field_id] = await self._repo.get_column_values(
task.table_id, dep_field_id

View File

@ -10,7 +10,6 @@ from __future__ import annotations
import logging
import re
from datetime import datetime, timedelta, timezone
from typing import Any
from sqlalchemy import delete, func, insert, select, text, update
from sqlalchemy.dialects.postgresql import insert as pg_insert
@ -102,7 +101,7 @@ class BitableRepository:
result = await session.execute(stmt)
return [BitableFile.model_validate(e) for e in result.scalars().all()]
async def update_file(self, file_id: str, **kwargs: Any) -> BitableFile | None:
async def update_file(self, file_id: str, **kwargs: object) -> BitableFile | None:
"""Update a file's attributes."""
async with self._session_factory() as session:
stmt = (
@ -181,7 +180,7 @@ class BitableRepository:
result = await session.execute(stmt)
return [Table.model_validate(e) for e in result.scalars().all()]
async def update_table(self, table_id: str, **kwargs: Any) -> Table | None:
async def update_table(self, table_id: str, **kwargs: object) -> Table | None:
"""Update a table's attributes."""
async with self._session_factory() as session:
stmt = (
@ -236,7 +235,7 @@ class BitableRepository:
table_id: str,
name: str,
field_type: FieldType,
config: dict[str, Any] | None = None,
config: dict[str, object] | None = None,
owner: FieldOwner = FieldOwner.user,
) -> Field:
"""Create a new field in a table."""
@ -277,7 +276,7 @@ class BitableRepository:
result = await session.execute(stmt)
return [Field.model_validate(e) for e in result.scalars().all()]
async def update_field(self, field_id: str, **kwargs: Any) -> Field | None:
async def update_field(self, field_id: str, **kwargs: object) -> Field | None:
"""Update a field's attributes."""
async with self._session_factory() as session:
stmt = (
@ -300,7 +299,7 @@ class BitableRepository:
# ── Records ─────────────────────────────────────────────
async def create_record(self, table_id: str, values: dict[str, Any] | None = None) -> Record:
async def create_record(self, table_id: str, values: dict[str, object] | None = None) -> Record:
"""Create a new record."""
async with self._session_factory() as session:
stmt = (
@ -318,7 +317,7 @@ class BitableRepository:
return Record.model_validate(entity)
async def create_records_batch(
self, table_id: str, records_values: list[dict[str, Any]]
self, table_id: str, records_values: list[dict[str, object]]
) -> list[Record]:
"""Batch-insert multiple records (P2 #19: eliminates per-record INSERT).
@ -376,7 +375,7 @@ class BitableRepository:
return [Record.model_validate(e) for e in entities], next_cursor
async def update_record_values(self, record_id: str, values: dict[str, Any]) -> Record | None:
async def update_record_values(self, record_id: str, values: dict[str, object]) -> Record | None:
"""Update a record's values (full replace)."""
async with self._session_factory() as session:
stmt = (
@ -413,7 +412,7 @@ class BitableRepository:
table_id: str,
name: str,
view_type: ViewType = ViewType.grid,
config: dict[str, Any] | None = None,
config: dict[str, object] | None = None,
) -> View:
"""Create a new view."""
async with self._session_factory() as session:
@ -451,7 +450,7 @@ class BitableRepository:
result = await session.execute(stmt)
return [View.model_validate(e) for e in result.scalars().all()]
async def update_view(self, view_id: str, **kwargs: Any) -> View | None:
async def update_view(self, view_id: str, **kwargs: object) -> View | None:
"""Update a view's attributes."""
async with self._session_factory() as session:
stmt = (
@ -543,7 +542,7 @@ class BitableRepository:
) -> None:
"""Update a recalc task's status."""
async with self._session_factory() as session:
kwargs: dict[str, Any] = {"status": status.value}
kwargs: dict[str, object] = {"status": status.value}
if error_message is not None:
kwargs["error_message"] = error_message
if status in (RecalcStatus.done, RecalcStatus.error):
@ -630,7 +629,7 @@ class BitableRepository:
return result_map
async def upsert_record_agent_fields(
self, record_id: str, agent_field_values: dict[str, Any]
self, record_id: str, agent_field_values: dict[str, object]
) -> None:
"""Update agent-owned fields using jsonb_set (KTD8).
@ -646,7 +645,7 @@ class BitableRepository:
# Use CAST(:param AS jsonb) instead of :param::jsonb — asyncpg dialect
# misparses the `::` as part of the param name.
inner = "values"
params: dict[str, Any] = {"record_id": record_id}
params: dict[str, object] = {"record_id": record_id}
for i, (field_id, value) in enumerate(agent_field_values.items()):
param_key = f"v{i}"
inner = f"jsonb_set({inner}, '{{{field_id}}}', CAST(:{param_key} AS jsonb), true)"
@ -660,8 +659,8 @@ class BitableRepository:
async def list_records_filtered(
self,
table_id: str,
filters: list[dict[str, Any]] | None = None,
sorts: list[dict[str, Any]] | None = None,
filters: list[dict[str, object]] | None = None,
sorts: list[dict[str, object]] | None = None,
cursor: str | None = None,
limit: int = 50,
) -> tuple[list[Record], str | None]:
@ -682,7 +681,7 @@ class BitableRepository:
# Build raw SQL with JSONB filter/sort translation.
# ponytail: field_ids in filters/sorts are system UUIDs (validated by service layer).
where_clauses = ["table_id = :table_id"]
params: dict[str, Any] = {"table_id": table_id}
params: dict[str, object] = {"table_id": table_id}
if filters:
for i, f in enumerate(filters):
@ -783,7 +782,7 @@ class BitableRepository:
last_mapping = last_row._mapping
# Build composite cursor from sort values + id.
# Sort values are extracted as text to match `values->>'fid'` expressions.
sv: list[Any] = []
sv: list[object] = []
last_values = last_mapping.get("values")
if isinstance(last_values, str):
# asyncpg may return JSONB as str in raw text() queries.
@ -852,7 +851,7 @@ class BitableRepository:
# ── Recalc support (U3) ────────────────────────────────
async def get_column_values(self, table_id: str, field_id: str) -> list[Any]:
async def get_column_values(self, table_id: str, field_id: str) -> list[object]:
"""Get all values for a field across all records in a table (for aggregates).
Returns a list of values (preserving order by record id). Missing values
@ -866,7 +865,7 @@ class BitableRepository:
result = await session.execute(sql, {"field_id": field_id, "table_id": table_id})
return [row[0] for row in result.fetchall()]
async def set_formula_value(self, record_id: str, field_id: str, value: Any) -> None:
async def set_formula_value(self, record_id: str, field_id: str, value: object) -> None:
"""Set a single formula field value in a record's JSONB (jsonb_set)."""
import json

View File

@ -34,12 +34,19 @@ import os
import sqlite3
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Callable
from typing import Callable, TypeAlias
import httpx
logger = logging.getLogger(__name__)
# 缓存的配置项 — 技能/工作流配置为 JSON 反序列化后的字典,值为标量或嵌套结构。
# 服务器返回的 skills/workflows 列表元素是 dictmodel_dump/to_dict 输出),
# 其中可能包含 list/dict 等容器,因此使用 object 作为值类型。
SkillConfigDict: TypeAlias = dict[str, object]
WorkflowConfigDict: TypeAlias = dict[str, object]
SyncedConfigPayload: TypeAlias = dict[str, object]
# ── Defaults ──────────────────────────────────────────────────────────
@ -100,8 +107,8 @@ class ConfigSync:
# In-memory cache (mirrors the SQLite cache for fast access)
self._version: str | None = None
self._skills: list[dict[str, Any]] = []
self._workflows: list[dict[str, Any]] = []
self._skills: list[SkillConfigDict] = []
self._workflows: list[WorkflowConfigDict] = []
self._last_synced_at: str | None = None
# ── Lifecycle ─────────────────────────────────────────────────
@ -232,15 +239,15 @@ class ConfigSync:
"""Return the current cached config version hash."""
return self._version
def get_skills(self) -> list[dict[str, Any]]:
def get_skills(self) -> list[SkillConfigDict]:
"""Return the cached skill configs."""
return list(self._skills)
def get_workflows(self) -> list[dict[str, Any]]:
def get_workflows(self) -> list[WorkflowConfigDict]:
"""Return the cached workflow configs."""
return list(self._workflows)
def get_all(self) -> dict[str, Any]:
def get_all(self) -> SyncedConfigPayload:
"""Return all cached configs as a single dict."""
return {
"version": self._version,
@ -249,14 +256,14 @@ class ConfigSync:
"synced_at": self._last_synced_at,
}
def get_skill(self, name: str) -> dict[str, Any] | None:
def get_skill(self, name: str) -> SkillConfigDict | None:
"""Return a single skill config by name, or ``None``."""
for skill in self._skills:
if skill.get("name") == name:
return skill
return None
def get_workflow(self, workflow_id: str) -> dict[str, Any] | None:
def get_workflow(self, workflow_id: str) -> WorkflowConfigDict | None:
"""Return a single workflow config by ID, or ``None``."""
for wf in self._workflows:
if wf.get("workflow_id") == workflow_id:
@ -281,7 +288,7 @@ class ConfigSync:
conn.executescript(_CACHE_SCHEMA)
conn.commit()
def _save_to_cache(self, data: dict[str, Any]) -> None:
def _save_to_cache(self, data: SyncedConfigPayload) -> None:
"""Save the synced configs to the local SQLite cache."""
now = datetime.now(timezone.utc).isoformat()
with sqlite3.connect(str(self.cache_db_path)) as conn:

View File

@ -14,7 +14,7 @@ import logging
import time
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
from typing import TYPE_CHECKING, Protocol, runtime_checkable
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
from agentkit.utils.vector_math import compute_cosine_similarity
@ -25,6 +25,52 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# TYPE_CHECKING Protocols — 避免 Any描述运行时 lazy import 的第三方对象
# ---------------------------------------------------------------------------
if TYPE_CHECKING:
class _RedisLike(Protocol):
"""Redis 客户端最小契约(仅覆盖本模块用到的方法)。"""
async def get(self, key: str) -> bytes | str | None: ...
async def mget(self, keys: list[str]) -> list[bytes | str | None]: ...
async def set(self, key: str, value: bytes | str, ex: int | None = None) -> None: ...
async def smembers(self, key: str) -> set[bytes | str]: ...
async def sadd(self, name: str, *values: str) -> int: ...
async def srem(self, name: str, *values: str) -> int: ...
async def scard(self, name: str) -> int: ...
async def delete(self, *names: str) -> int: ...
def pipeline(self) -> "_RedisPipelineLike": ...
class _RedisPipelineLike(Protocol):
"""Redis pipeline 最小契约。"""
def get(self, key: str) -> "_RedisPipelineLike": ...
def set(
self, key: str, value: bytes | str, ex: int | None = None
) -> "_RedisPipelineLike": ...
def delete(self, *names: str) -> "_RedisPipelineLike": ...
def sadd(self, name: str, *values: str) -> "_RedisPipelineLike": ...
def srem(self, name: str, *values: str) -> "_RedisPipelineLike": ...
async def execute(self) -> list[object]: ...
# ---------------------------------------------------------------------------
# Data Classes
# ---------------------------------------------------------------------------
@ -328,7 +374,7 @@ class RedisLLMCache:
self._semantic_ttl = semantic_ttl
self._similarity_threshold = similarity_threshold
self._max_entries_to_scan = max_entries_to_scan
self._redis: Any = None
self._redis: _RedisLike | None = None
self._fallback: InMemoryLLMCache | None = fallback # For auto-degradation
self._degraded = False # True if Redis is unreachable
@ -691,7 +737,7 @@ class LitellmCacheManager:
def __init__(self, config: LitellmCacheConfig):
self._config = config
self._cache_instance: Any = None # litellm.caching.Cache 实例
self._cache_instance: object | None = None # litellm.caching.Cache 实例
self._hits = 0
self._misses = 0
@ -709,7 +755,7 @@ class LitellmCacheManager:
litellm.cache = None
self._cache_instance = None
def _create_cache_instance(self) -> Any:
def _create_cache_instance(self) -> object:
"""根据 backend 配置创建 LiteLLM Cache 实例。
auto 模式按优先级尝试RedisSemanticCache RedisCache InMemoryCache

View File

@ -2,14 +2,13 @@
import hashlib
import json
from typing import Any
def generate_cache_key(
model: str,
messages: list[dict[str, str]],
temperature: float,
tools: list[dict[str, Any]] | None = None,
tools: list[dict[str, object]] | None = None,
tool_choice: str = "auto",
max_tokens: int = 2000,
user_id: str | None = None,

View File

@ -3,12 +3,12 @@
import json
import logging
from dataclasses import dataclass, field
from typing import Any, TYPE_CHECKING
from typing import TYPE_CHECKING
from agentkit.llm.retry import CircuitBreakerConfig, RetryConfig
if TYPE_CHECKING:
from agentkit.channels.secrets import SecretsStore
from agentkit.channels.secrets import SecretEntry, SecretsStore
logger = logging.getLogger(__name__)
@ -56,7 +56,7 @@ class ProviderConfig:
api_key: str
base_url: str
models: dict[str, dict[str, Any]] = field(default_factory=dict)
models: dict[str, dict[str, object]] = field(default_factory=dict)
type: str = "openai" # "openai" | "anthropic" | "gemini"
max_tokens: int = 4096 # Anthropic: default max_tokens
timeout: float = 120.0 # Anthropic: request timeout
@ -168,18 +168,18 @@ class ProviderConfig:
return f"llm:provider:{self.type}:api_key"
@staticmethod
def _encode_secret_entry(entry: Any, key: str) -> str:
def _encode_secret_entry(entry: object, key: str) -> str:
"""把 SecretEntry 编码为 JSON 字符串(含 key 字段)。"""
# entry 是 SecretEntry pydantic 模型,有 model_dump()
if hasattr(entry, "model_dump"):
data = entry.model_dump()
data = entry.model_dump() # type: ignore[attr-defined]
else:
data = dict(entry)
data = dict(entry) # type: ignore[call-overload]
data["key"] = key
return json.dumps(data)
@staticmethod
def _decode_secret_entry(encoded: str) -> Any:
def _decode_secret_entry(encoded: str) -> "SecretEntry":
"""从 JSON 字符串解码 SecretEntry。返回带 .key 属性的对象。"""
from agentkit.channels.secrets import SecretEntry

View File

@ -5,7 +5,7 @@ import logging
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Protocol
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
from agentkit.llm.config import LLMConfig
@ -14,9 +14,40 @@ from agentkit.llm.providers.tracker import UsageSummary, UsageTracker
from agentkit.telemetry.tracing import get_tracer, _OTEL_AVAILABLE
from agentkit.telemetry.metrics import llm_token_histogram
if TYPE_CHECKING:
from agentkit.llm.cache import LitellmCacheManager
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# TYPE_CHECKING Protocols — 避免 Any描述运行时 lazy import 的对象
# ---------------------------------------------------------------------------
if TYPE_CHECKING:
class _QuotaServiceLike(Protocol):
"""Quota service 最小契约(仅覆盖 gateway._enforce_quota 用到的方法)。"""
async def is_model_allowed(
self, db: Path, department_id: str, model: str
) -> tuple[bool, str]: ...
async def check_quota(
self,
db: Path,
department_id: str,
quota_type: str,
period: str,
current: float,
) -> tuple[bool, str]: ...
async def get_quota(
self, db: Path, department_id: str, quota_type: str, period: str
) -> dict[str, object] | None: ...
class QuotaExceededError(Exception):
"""Raised when a department's LLM quota is exceeded.
@ -29,8 +60,8 @@ class QuotaExceededError(Exception):
department_id: str,
quota_type: str,
period: str,
limit: Any,
current: Any,
limit: object,
current: object,
) -> None:
self.department_id = department_id
self.quota_type = quota_type
@ -46,13 +77,13 @@ class QuotaExceededError(Exception):
class LLMGateway:
"""LLM 网关 - Provider 注册、模型别名解析、Fallback、Usage 追踪、Cache"""
def __init__(self, config: LLMConfig | None = None, usage_store: Any = None):
def __init__(self, config: LLMConfig | None = None, usage_store: object | None = None):
self._providers: dict[str, LLMProvider] = {}
self._usage_tracker = UsageTracker(store=usage_store) if usage_store else UsageTracker()
self._config = config or LLMConfig()
# Cache (U17 — LiteLLM 缓存管理器opt-in默认禁用)
self._cache_manager: Any = None # LitellmCacheManager | None
self._cache_manager: "LitellmCacheManager | None" = None
if self._config.cache and self._config.cache.enabled:
from agentkit.llm.cache import LitellmCacheConfig, LitellmCacheManager
@ -601,7 +632,7 @@ class LLMGateway:
async def _check_quota_value(
self,
quota_service: Any,
quota_service: _QuotaServiceLike,
db: Path,
dept_id: str,
period: str,

View File

@ -27,12 +27,11 @@ from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str, Any]]:
def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str, object]]:
"""把 agentkit.yaml 中的 plaintext API Key 迁移到 SecretsStore。
流程
@ -63,8 +62,8 @@ def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str,
store = SecretsStore() # master key 从 env 加载
async def _run() -> dict[str, dict[str, Any]]:
report: dict[str, dict[str, Any]] = {}
async def _run() -> dict[str, dict[str, object]]:
report: dict[str, dict[str, object]] = {}
for name, pconf in llm_config.providers.items():
if pconf.api_key_source == "secrets_store" and not pconf.api_key:
report[name] = {"status": "skipped", "source": pconf.api_key_source}
@ -93,9 +92,9 @@ def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str,
report = asyncio.run(_run())
# 写回 YAML更新 llm.providers 段,保留其它段
providers_out: dict[str, dict[str, Any]] = {}
providers_out: dict[str, dict[str, object]] = {}
for name, pconf in llm_config.providers.items():
entry: dict[str, Any] = {
entry: dict[str, object] = {
"type": pconf.type,
"base_url": pconf.base_url,
"models": pconf.models,

View File

@ -2,7 +2,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
@dataclass
@ -23,7 +22,7 @@ class ToolCall:
id: str
name: str
arguments: dict[str, Any]
arguments: dict[str, object]
@dataclass
@ -32,7 +31,7 @@ class LLMRequest:
messages: list[dict[str, str]]
model: str
tools: list[dict[str, Any]] | None = None
tools: list[dict[str, object]] | None = None
tool_choice: str = "auto"
temperature: float = 0.7
max_tokens: int = 2000
@ -42,13 +41,13 @@ class LLMRequest:
self,
messages: list[dict[str, str]],
model: str,
tools: list[dict[str, Any]] | None = None,
tools: list[dict[str, object]] | None = None,
tool_choice: str = "auto",
temperature: float = 0.7,
max_tokens: int = 2000,
timeout: float | None = None,
cache: dict[str, Any] | None = None,
**kwargs: Any,
cache: dict[str, object] | None = None,
**kwargs: object,
):
self.messages = messages
self.model = model
@ -59,7 +58,7 @@ class LLMRequest:
self.timeout = timeout
self._extra = kwargs
# U17 — LiteLLM cache 参数cache_key 或 no-cache透传到 litellm.acompletion
self._cache: dict[str, Any] | None = cache
self._cache: dict[str, object] | None = cache
@dataclass

View File

@ -3,7 +3,6 @@
import json
import logging
import time
from typing import Any
import httpx
@ -99,7 +98,9 @@ class AnthropicProvider(LLMProvider):
"content-type": "application/json",
}
def _convert_messages(self, messages: list[dict[str, str]]) -> tuple[str | list[dict[str, Any]] | None, list[dict[str, Any]]]:
def _convert_messages(
self, messages: list[dict[str, str]]
) -> tuple[str | list[dict[str, object]] | None, list[dict[str, object]]]:
"""将 OpenAI 风格消息转换为 Anthropic 格式
Returns:
@ -110,8 +111,8 @@ class AnthropicProvider(LLMProvider):
- list[dict]: Anthropic content blocks(支持 cache_control,U2/G2)
- None: system 消息
"""
system_prompt: str | list[dict[str, Any]] | None = None
anthropic_messages: list[dict[str, Any]] = []
system_prompt: str | list[dict[str, object]] | None = None
anthropic_messages: list[dict[str, object]] = []
for msg in messages:
role = msg.get("role", "")
@ -127,7 +128,7 @@ class AnthropicProvider(LLMProvider):
# 检查是否有 tool_calls (OpenAI 格式)
tool_calls = msg.get("tool_calls")
if tool_calls:
blocks: list[dict[str, Any]] = []
blocks: list[dict[str, object]] = []
# 如果有文本内容,先添加文本块
if content:
blocks.append({"type": "text", "text": content})
@ -139,25 +140,29 @@ class AnthropicProvider(LLMProvider):
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {"raw": arguments}
blocks.append({
blocks.append(
{
"type": "tool_use",
"id": tc.get("id", ""),
"name": func.get("name", ""),
"input": arguments,
})
}
)
anthropic_messages.append({"role": "assistant", "content": blocks})
else:
anthropic_messages.append({
anthropic_messages.append(
{
"role": "assistant",
"content": [{"type": "text", "text": content}],
})
}
)
continue
if role == "user":
# 检查是否是 tool_result 消息 (OpenAI 格式中 tool 角色的结果)
# OpenAI 格式: {"role": "tool", "tool_call_id": "...", "content": "..."}
if msg.get("tool_call_id"):
tool_result_blocks: list[dict[str, Any]] = []
tool_result_blocks: list[dict[str, object]] = []
tool_content = msg.get("content", "")
# tool_result 的 content 可以是字符串或内容块列表
if isinstance(tool_content, str):
@ -167,56 +172,72 @@ class AnthropicProvider(LLMProvider):
else:
tool_result_blocks.append({"type": "text", "text": str(tool_content)})
anthropic_messages.append({
anthropic_messages.append(
{
"role": "user",
"content": [{
"content": [
{
"type": "tool_result",
"tool_use_id": msg.get("tool_call_id", ""),
"content": tool_result_blocks,
}],
})
}
],
}
)
else:
anthropic_messages.append({
anthropic_messages.append(
{
"role": "user",
"content": [{"type": "text", "text": content}],
})
}
)
continue
if role == "tool":
# OpenAI 格式中独立的 tool 消息
tool_content = msg.get("content", "")
if isinstance(tool_content, str):
result_content: list[dict[str, Any]] | str = [{"type": "text", "text": tool_content}]
result_content: list[dict[str, object]] | str = [
{"type": "text", "text": tool_content}
]
elif isinstance(tool_content, list):
result_content = tool_content
else:
result_content = [{"type": "text", "text": str(tool_content)}]
anthropic_messages.append({
anthropic_messages.append(
{
"role": "user",
"content": [{
"content": [
{
"type": "tool_result",
"tool_use_id": msg.get("tool_call_id", ""),
"content": result_content,
}],
})
}
],
}
)
return system_prompt, anthropic_messages
def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
def _convert_tools(self, tools: list[dict[str, object]]) -> list[dict[str, object]]:
"""将 OpenAI function 格式转换为 Anthropic tool 格式"""
anthropic_tools = []
for tool in tools:
if tool.get("type") == "function":
func = tool.get("function", {})
anthropic_tools.append({
anthropic_tools.append(
{
"name": func.get("name", ""),
"description": func.get("description", ""),
"input_schema": func.get("parameters", {"type": "object", "properties": {}}),
})
"input_schema": func.get(
"parameters", {"type": "object", "properties": {}}
),
}
)
return anthropic_tools
def _convert_tool_choice(self, tool_choice: str) -> dict[str, Any] | None:
def _convert_tool_choice(self, tool_choice: str) -> dict[str, object] | None:
"""将 OpenAI tool_choice 格式转换为 Anthropic 格式"""
if tool_choice == "auto":
return {"type": "auto"}
@ -227,7 +248,7 @@ class AnthropicProvider(LLMProvider):
return {"type": "tool", "name": tool_choice}
return None
def _parse_response(self, data: dict[str, Any], model: str) -> LLMResponse:
def _parse_response(self, data: dict[str, object], model: str) -> LLMResponse:
"""将 Anthropic 响应转换为 LLMResponse"""
content_blocks = data.get("content", [])
text_parts: list[str] = []
@ -238,11 +259,13 @@ class AnthropicProvider(LLMProvider):
if block_type == "text":
text_parts.append(block.get("text", ""))
elif block_type == "tool_use":
tool_calls.append(ToolCall(
tool_calls.append(
ToolCall(
id=block.get("id", ""),
name=block.get("name", ""),
arguments=block.get("input", {}),
))
)
)
usage_data = data.get("usage", {})
usage = TokenUsage(
@ -287,7 +310,7 @@ class AnthropicProvider(LLMProvider):
system_prompt, anthropic_messages = self._convert_messages(request.messages)
payload: dict[str, Any] = {
payload: dict[str, object] = {
"model": request.model,
"max_tokens": request.max_tokens or self._max_tokens,
"messages": anthropic_messages,
@ -346,7 +369,7 @@ class AnthropicProvider(LLMProvider):
system_prompt, anthropic_messages = self._convert_messages(request.messages)
payload: dict[str, Any] = {
payload: dict[str, object] = {
"model": request.model,
"max_tokens": request.max_tokens or self._max_tokens,
"messages": anthropic_messages,
@ -375,7 +398,7 @@ class AnthropicProvider(LLMProvider):
async def _iterate_stream(self, response, request: LLMRequest):
"""Iterate over an already-open SSE stream and yield StreamChunks."""
# Accumulated tool calls: tool_use_id -> {id, name, input_json_str}
accumulated_tool_calls: dict[str, dict[str, Any]] = {}
accumulated_tool_calls: dict[str, dict[str, object]] = {}
current_tool_id: str | None = None
current_tool_name: str | None = None
current_tool_input_json: str = ""
@ -433,7 +456,9 @@ class AnthropicProvider(LLMProvider):
# Finalize current tool call if any
if current_tool_id is not None:
try:
arguments = json.loads(current_tool_input_json) if current_tool_input_json else {}
arguments = (
json.loads(current_tool_input_json) if current_tool_input_json else {}
)
except json.JSONDecodeError:
arguments = {"raw": current_tool_input_json}
@ -510,7 +535,7 @@ class AnthropicProvider(LLMProvider):
error_msg = error_info.get("message", "Stream error")
raise LLMProviderError("anthropic", error_msg)
def get_model_info(self) -> dict[str, Any]:
def get_model_info(self) -> dict[str, object]:
"""返回 Provider 和模型信息"""
return {
"provider": "anthropic",

View File

@ -8,7 +8,6 @@ API火山引擎 OpenAI 兼容接口
from __future__ import annotations
import logging
from typing import Any
from agentkit.llm.providers.openai import OpenAICompatibleProvider
@ -48,7 +47,7 @@ class DoubaoProvider(OpenAICompatibleProvider):
api_key: str,
base_url: str = DOUBAO_DEFAULT_BASE_URL,
default_model: str = "doubao-pro-32k",
**kwargs: Any,
**kwargs: object,
):
super().__init__(
api_key=api_key,

View File

@ -3,7 +3,6 @@
import json
import logging
import time
from typing import Any
import httpx
@ -90,14 +89,14 @@ class GeminiProvider(LLMProvider):
def _convert_messages(
self, messages: list[dict[str, str]]
) -> tuple[dict[str, Any] | None, list[dict[str, Any]]]:
) -> tuple[dict[str, object] | None, list[dict[str, object]]]:
"""将 OpenAI 风格消息转换为 Gemini 格式
Returns:
(system_instruction, contents)
"""
system_instruction: dict[str, Any] | None = None
contents: list[dict[str, Any]] = []
system_instruction: dict[str, object] | None = None
contents: list[dict[str, object]] = []
for msg in messages:
role = msg.get("role", "")
@ -119,28 +118,34 @@ class GeminiProvider(LLMProvider):
tool_name = parsed.get("name", "")
except (json.JSONDecodeError, AttributeError):
pass
contents.append({
contents.append(
{
"role": "user",
"parts": [{
"parts": [
{
"functionResponse": {
"name": tool_name,
"response": {
"content": content,
},
},
}],
})
}
],
}
)
else:
contents.append({
contents.append(
{
"role": "user",
"parts": [{"text": content}],
})
}
)
continue
if role == "assistant":
tool_calls = msg.get("tool_calls")
if tool_calls:
parts: list[dict[str, Any]] = []
parts: list[dict[str, object]] = []
if content:
parts.append({"text": content})
for tc in tool_calls:
@ -151,54 +156,64 @@ class GeminiProvider(LLMProvider):
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {"raw": arguments}
parts.append({
parts.append(
{
"functionCall": {
"name": func.get("name", ""),
"args": arguments,
},
})
}
)
contents.append({"role": "model", "parts": parts})
else:
contents.append({
contents.append(
{
"role": "model",
"parts": [{"text": content}],
})
}
)
continue
if role == "tool":
# OpenAI format: {"role": "tool", "tool_call_id": "...", "content": "..."}
tool_name = msg.get("name", "")
tool_content = msg.get("content", "")
contents.append({
contents.append(
{
"role": "user",
"parts": [{
"parts": [
{
"functionResponse": {
"name": tool_name,
"response": {
"content": tool_content,
},
},
}],
})
}
],
}
)
return system_instruction, contents
def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
def _convert_tools(self, tools: list[dict[str, object]]) -> list[dict[str, object]]:
"""将 OpenAI function 格式转换为 Gemini functionDeclarations"""
declarations = []
for tool in tools:
if tool.get("type") == "function":
func = tool.get("function", {})
declarations.append({
declarations.append(
{
"name": func.get("name", ""),
"description": func.get("description", ""),
"parameters": func.get("parameters", {"type": "object", "properties": {}}),
})
}
)
if not declarations:
return []
return [{"functionDeclarations": declarations}]
def _convert_tool_choice(self, tool_choice: str) -> dict[str, Any] | None:
def _convert_tool_choice(self, tool_choice: str) -> dict[str, object] | None:
"""将 OpenAI tool_choice 格式转换为 Gemini toolConfig"""
if tool_choice == "auto":
return {"functionCallingConfig": {"mode": "AUTO"}}
@ -210,7 +225,7 @@ class GeminiProvider(LLMProvider):
return {"functionCallingConfig": {"mode": "NONE"}}
return None
def _parse_response(self, data: dict[str, Any], model: str) -> LLMResponse:
def _parse_response(self, data: dict[str, object], model: str) -> LLMResponse:
"""将 Gemini 响应转换为 LLMResponse"""
candidates = data.get("candidates", [])
text_parts: list[str] = []
@ -225,11 +240,13 @@ class GeminiProvider(LLMProvider):
text_parts.append(part["text"])
elif "functionCall" in part:
fc = part["functionCall"]
tool_calls.append(ToolCall(
tool_calls.append(
ToolCall(
id=f"call_{tool_call_index}",
name=fc.get("name", ""),
arguments=fc.get("args", {}),
))
)
)
tool_call_index += 1
usage_metadata = data.get("usageMetadata", {})
@ -275,7 +292,7 @@ class GeminiProvider(LLMProvider):
system_instruction, contents = self._convert_messages(request.messages)
payload: dict[str, Any] = {
payload: dict[str, object] = {
"contents": contents,
"generationConfig": {
"temperature": request.temperature,
@ -340,7 +357,7 @@ class GeminiProvider(LLMProvider):
system_instruction, contents = self._convert_messages(request.messages)
payload: dict[str, Any] = {
payload: dict[str, object] = {
"contents": contents,
"generationConfig": {
"temperature": request.temperature,
@ -374,7 +391,7 @@ class GeminiProvider(LLMProvider):
async def _iterate_stream(self, response, request: LLMRequest):
"""Iterate over an already-open SSE stream and yield StreamChunks."""
accumulated_tool_calls: list[dict[str, Any]] = []
accumulated_tool_calls: list[dict[str, object]] = []
model = request.model or self._model
async for line in response.aiter_lines():
@ -436,11 +453,13 @@ class GeminiProvider(LLMProvider):
)
elif "functionCall" in part:
fc = part["functionCall"]
accumulated_tool_calls.append({
accumulated_tool_calls.append(
{
"id": f"call_{len(accumulated_tool_calls)}",
"name": fc.get("name", ""),
"arguments": fc.get("args", {}),
})
}
)
# Check for finish reason
finish_reason = candidates[0].get("finishReason", "")
@ -461,7 +480,7 @@ class GeminiProvider(LLMProvider):
)
accumulated_tool_calls = []
def get_model_info(self) -> dict[str, Any]:
def get_model_info(self) -> dict[str, object]:
"""返回 Provider 和模型信息"""
return {
"provider": "gemini",

View File

@ -26,8 +26,7 @@ import inspect
import json
import logging
import time
from collections.abc import AsyncGenerator
from typing import Any
from collections.abc import AsyncGenerator, Iterable
from agentkit.core.exceptions import LLMProviderError
from agentkit.llm.protocol import (
@ -81,13 +80,13 @@ class LitellmProvider(LLMProvider):
api_key: str,
base_url: str | None = None,
provider_type: str = "openai",
**default_kwargs: Any,
**default_kwargs: object,
) -> None:
self._model_prefix = model_prefix
self._api_key = api_key
self._base_url = base_url or None # 空字符串视作未设置
self._provider_type = provider_type
self._default_kwargs: dict[str, Any] = dict(default_kwargs)
self._default_kwargs: dict[str, object] = dict(default_kwargs)
async def chat(self, request: LLMRequest) -> LLMResponse:
"""非流式 chat — 调用 ``litellm.acompletion`` 并翻译响应。"""
@ -116,7 +115,7 @@ class LitellmProvider(LLMProvider):
kwargs = self._build_kwargs(request, stream=True)
accumulated_tool_calls: dict[int, dict[str, Any]] = {}
accumulated_tool_calls: dict[int, dict[str, object]] = {}
final_usage: TokenUsage | None = None
final_model: str = request.model
@ -158,9 +157,9 @@ class LitellmProvider(LLMProvider):
# 内部辅助
# ------------------------------------------------------------------
def _build_kwargs(self, request: LLMRequest, *, stream: bool) -> dict[str, Any]:
def _build_kwargs(self, request: LLMRequest, *, stream: bool) -> dict[str, object]:
"""从 LLMRequest 构造 litellm.acompletion kwargs。"""
kwargs: dict[str, Any] = {
kwargs: dict[str, object] = {
"model": f"{self._model_prefix}{request.model}",
"messages": request.messages,
"temperature": request.temperature,
@ -187,7 +186,7 @@ class LitellmProvider(LLMProvider):
def _parse_response(
self,
response: Any,
response: object,
request_model: str,
latency_ms: float,
) -> LLMResponse:
@ -229,9 +228,9 @@ class LitellmProvider(LLMProvider):
def _parse_stream_chunk(
self,
chunk: Any,
chunk: object,
request_model: str,
accumulated_tool_calls: dict[int, dict[str, Any]],
accumulated_tool_calls: dict[int, dict[str, object]],
) -> StreamChunk:
"""解析单个流式 chunk非 final。累计 tool_calls 到传入字典。"""
choices = getattr(chunk, "choices", None) or []
@ -262,7 +261,7 @@ class LitellmProvider(LLMProvider):
def _finalize_tool_calls(
self,
accumulated: dict[int, dict[str, Any]],
accumulated: dict[int, dict[str, object]],
) -> list[ToolCall]:
"""把累计的流式 tool_calls 字典转成 ToolCall 列表。"""
tool_calls: list[ToolCall] = []
@ -288,7 +287,7 @@ class LitellmProvider(LLMProvider):
# ----------------------------------------------------------------------
def _parse_tool_calls(raw_tool_calls: Any) -> list[ToolCall]:
def _parse_tool_calls(raw_tool_calls: Iterable[object]) -> list[ToolCall]:
"""解析非流式响应的 tool_callsOpenAI 格式 list[ChoiceMessageToolCall])。"""
result: list[ToolCall] = []
for tc in raw_tool_calls:
@ -312,7 +311,7 @@ def _parse_tool_calls(raw_tool_calls: Any) -> list[ToolCall]:
return result
def _parse_usage(usage_obj: Any) -> TokenUsage:
def _parse_usage(usage_obj: object) -> TokenUsage:
"""解析 usage 对象OpenAI CompletionUsage 或 dict"""
prompt = getattr(usage_obj, "prompt_tokens", None)
completion = getattr(usage_obj, "completion_tokens", None)
@ -326,8 +325,8 @@ def _parse_usage(usage_obj: Any) -> TokenUsage:
def _accumulate_stream_tool_calls(
raw_tool_calls: Any,
accumulated: dict[int, dict[str, Any]],
raw_tool_calls: Iterable[object],
accumulated: dict[int, dict[str, object]],
) -> None:
"""累计流式 chunk 里的 tool_calls 片段OpenAI delta.tool_calls 格式)。
@ -364,7 +363,7 @@ def create_litellm_provider(
provider_type: str,
api_key: str,
base_url: str | None = None,
**kwargs: Any,
**kwargs: object,
) -> LitellmProvider:
"""根据 provider_type 创建 LitellmProvider 实例。

View File

@ -18,7 +18,7 @@ import json
import logging
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Any, Protocol, runtime_checkable
from typing import Protocol, runtime_checkable
from agentkit.llm.protocol import TokenUsage
@ -294,8 +294,8 @@ class RedisUsageStore:
def __init__(self, redis_url: str = "redis://localhost:6379"):
self._redis_url = redis_url
self._redis: Any = None
self._sync_redis: Any = None
self._redis: object | None = None
self._sync_redis: object | None = None
self._fallback: InMemoryUsageStore | None = None
self._degraded = False
self._health_check_task: asyncio.Task[None] | None = None
@ -687,7 +687,7 @@ class RedisUsageStore:
@staticmethod
def _read_list(
r: Any,
r: object,
list_key: str,
start: datetime,
end: datetime,

View File

@ -9,7 +9,6 @@ from __future__ import annotations
import logging
import time
from typing import Any
from agentkit.llm.providers.openai import OpenAICompatibleProvider
from agentkit.llm.protocol import LLMRequest, LLMResponse
@ -51,7 +50,7 @@ class WenxinProvider(OpenAICompatibleProvider):
secret_key: str | None = None,
base_url: str = WENXIN_DEFAULT_BASE_URL,
default_model: str = "ernie-4.5-turbo-128k",
**kwargs: Any,
**kwargs: object,
):
# If AK/SK provided, use token-based auth
self._access_key = access_key

View File

@ -8,7 +8,6 @@ API腾讯云 OpenAI 兼容接口
from __future__ import annotations
import logging
from typing import Any
from agentkit.llm.providers.openai import OpenAICompatibleProvider
from agentkit.llm.protocol import LLMRequest, LLMResponse
@ -48,7 +47,7 @@ class YuanbaoProvider(OpenAICompatibleProvider):
base_url: str = YUANBAO_DEFAULT_BASE_URL,
default_model: str = "hunyuan-turbos-latest",
enable_enhancement: bool = False,
**kwargs: Any,
**kwargs: object,
):
self._enable_enhancement = enable_enhancement
super().__init__(

View File

@ -3,7 +3,6 @@
import json
import logging
from collections.abc import AsyncIterator, Awaitable, Callable
from typing import Any
import httpx
@ -66,7 +65,7 @@ class RemoteLLMProvider(LLMProvider):
"Content-Type": "application/json",
}
def _build_payload(self, request: LLMRequest) -> dict[str, Any]:
def _build_payload(self, request: LLMRequest) -> dict[str, object]:
"""Convert LLMRequest to server API payload."""
return {
"messages": request.messages,
@ -91,7 +90,7 @@ class RemoteLLMProvider(LLMProvider):
return str(body["error"])
return str(body)
def _parse_response(self, data: dict[str, Any], request: LLMRequest) -> LLMResponse:
def _parse_response(self, data: dict[str, object], request: LLMRequest) -> LLMResponse:
"""Parse server response JSON into an LLMResponse."""
usage_data = data.get("usage") or {}
usage = TokenUsage(
@ -115,7 +114,7 @@ class RemoteLLMProvider(LLMProvider):
latency_ms=data.get("latency_ms", 0.0),
)
def _parse_chunk(self, data: dict[str, Any], request: LLMRequest) -> StreamChunk:
def _parse_chunk(self, data: dict[str, object], request: LLMRequest) -> StreamChunk:
"""Parse a single SSE data payload into a StreamChunk."""
usage: TokenUsage | None = None
usage_data = data.get("usage")
@ -218,9 +217,7 @@ class RemoteLLMProvider(LLMProvider):
if response.status_code == 502:
await response.aread()
detail = self._extract_error_detail(response)
raise LLMProviderError(
"remote", f"Server LLM gateway error: {detail}"
)
raise LLMProviderError("remote", f"Server LLM gateway error: {detail}")
if response.status_code != 200:
await response.aread()
raise LLMProviderError(

View File

@ -5,7 +5,7 @@ import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable
from typing import Callable
from agentkit.core.exceptions import LLMProviderError
@ -20,9 +20,7 @@ class RetryConfig:
base_delay: float = 1.0
max_delay: float = 30.0
exponential_base: float = 2.0
retryable_status_codes: set[int] = field(
default_factory=lambda: {429, 500, 502, 503, 529}
)
retryable_status_codes: set[int] = field(default_factory=lambda: {429, 500, 502, 503, 529})
class CircuitState(Enum):
@ -69,7 +67,7 @@ class RetryPolicy:
def __init__(self, config: RetryConfig | None = None):
self._config = config or RetryConfig()
async def execute(self, fn: Callable, *args: Any, **kwargs: Any) -> Any:
async def execute(self, fn: Callable, *args: object, **kwargs: object) -> object:
"""Execute fn with retry on retryable errors."""
last_error: Exception | None = None
@ -84,7 +82,7 @@ class RetryPolicy:
raise
delay = min(
self._config.base_delay * (self._config.exponential_base ** attempt),
self._config.base_delay * (self._config.exponential_base**attempt),
self._config.max_delay,
)
logger.warning(
@ -142,7 +140,7 @@ class CircuitBreaker:
f"after {self._failure_count} failures"
)
async def execute(self, fn: Callable, *args: Any, **kwargs: Any) -> Any:
async def execute(self, fn: Callable, *args: object, **kwargs: object) -> object:
"""Execute fn through the circuit breaker."""
current_state = self.state
@ -158,6 +156,6 @@ class CircuitBreaker:
result = await fn(*args, **kwargs)
self._on_success()
return result
except Exception as e:
except Exception:
self._on_failure()
raise

View File

@ -10,7 +10,6 @@ from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from typing import Any
import httpx
@ -95,8 +94,7 @@ class KBAdapter(ABC):
async def delete_by_id(self, id: str) -> bool:
"""按文档 ID 删除(子类可覆盖)"""
logger.warning(
f"{self.__class__.__name__} does not support delete_by_id; "
f"id '{id}' skipped"
f"{self.__class__.__name__} does not support delete_by_id; id '{id}' skipped"
)
return False
@ -127,8 +125,7 @@ class KBAdapter(ABC):
async def get_document(self, doc_id: str) -> Document | None:
"""按 ID 获取单个文档(子类可覆盖)"""
logger.warning(
f"{self.__class__.__name__} does not support get_document; "
f"doc_id '{doc_id}' not found"
f"{self.__class__.__name__} does not support get_document; doc_id '{doc_id}' not found"
)
return None
@ -156,5 +153,5 @@ class KBAdapter(ABC):
async def __aenter__(self) -> KBAdapter:
return self
async def __aexit__(self, *args: Any) -> None:
async def __aexit__(self, *args: object) -> None:
await self.close()

View File

@ -7,7 +7,6 @@
from __future__ import annotations
import logging
from typing import Any
import httpx
@ -56,7 +55,9 @@ class ConfluenceAdapter(KBAdapter):
)
self._base_url = base_url.rstrip("/")
if not is_safe_url(self._base_url):
raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.")
raise ValueError(
f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed."
)
self._username = username
self._api_token = api_token
self._space_keys = space_keys or []
@ -65,9 +66,7 @@ class ConfluenceAdapter(KBAdapter):
"""创建 Confluence API HTTP 客户端"""
import base64
credentials = base64.b64encode(
f"{self._username}:{self._api_token}".encode()
).decode()
credentials = base64.b64encode(f"{self._username}:{self._api_token}".encode()).decode()
return httpx.AsyncClient(
base_url=self._base_url,
headers={
@ -101,7 +100,7 @@ class ConfluenceAdapter(KBAdapter):
space_filter = " OR ".join(
f'space = "{_escape_cql(key)}"' for key in self._space_keys
)
cql = f'{cql} AND ({space_filter})'
cql = f"{cql} AND ({space_filter})"
resp = await client.get(
"/rest/api/content/search",
@ -115,6 +114,7 @@ class ConfluenceAdapter(KBAdapter):
body = page.get("body", {}).get("storage", {}).get("value", "")
# Strip HTML tags for plain text content
import re
content = re.sub(r"<[^>]+>", "", body) if body else page.get("title", "")
results.append(
@ -136,8 +136,7 @@ class ConfluenceAdapter(KBAdapter):
except httpx.HTTPStatusError as e:
logger.error(
f"Confluence search HTTP error: {e.response.status_code}"
f"{e.response.text[:200]}"
f"Confluence search HTTP error: {e.response.status_code}{e.response.text[:200]}"
)
return []
except Exception as e:
@ -157,6 +156,7 @@ class ConfluenceAdapter(KBAdapter):
body = page.get("body", {}).get("storage", {}).get("value", "")
import re
content = re.sub(r"<[^>]+>", "", body) if body else ""
return Document(
@ -191,13 +191,17 @@ class ConfluenceAdapter(KBAdapter):
source_type="confluence",
)
)
return sources if sources else [
return (
sources
if sources
else [
SourceInfo(
source_id=self._source_id,
source_name=self._source_name,
source_type=self._source_type,
)
]
)
except Exception as e:
logger.error(f"Confluence list_sources error: {e}")
return [

View File

@ -8,7 +8,7 @@ from __future__ import annotations
import logging
import time
from typing import Any
from typing import TypeAlias
import httpx
@ -18,6 +18,9 @@ from agentkit.utils.security import is_safe_url
logger = logging.getLogger(__name__)
# 飞书搜索请求 payloadsearch_key/page_size/wiki_space_ids — 值为 str|int|list[str]。
FeishuSearchPayload: TypeAlias = dict[str, object]
class FeishuKBAdapter(KBAdapter):
"""飞书知识库适配器
@ -54,7 +57,9 @@ class FeishuKBAdapter(KBAdapter):
self._app_secret = app_secret
self._base_url = base_url.rstrip("/")
if not is_safe_url(self._base_url):
raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.")
raise ValueError(
f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed."
)
self._space_ids = space_ids or []
self._access_token: str | None = None
self._token_expiry: float = 0.0
@ -94,10 +99,7 @@ class FeishuKBAdapter(KBAdapter):
self._client = None
return self._access_token
else:
logger.error(
f"Feishu auth failed: code={data.get('code')}, "
f"msg={data.get('msg')}"
)
logger.error(f"Feishu auth failed: code={data.get('code')}, msg={data.get('msg')}")
return None
except Exception as e:
logger.error(f"Feishu auth error: {e}")
@ -121,7 +123,7 @@ class FeishuKBAdapter(KBAdapter):
client = self._get_client()
try:
payload: dict[str, Any] = {
payload: FeishuSearchPayload = {
"search_key": query,
"page_size": top_k,
}
@ -137,8 +139,7 @@ class FeishuKBAdapter(KBAdapter):
if data.get("code") != 0:
logger.error(
f"Feishu search failed: code={data.get('code')}, "
f"msg={data.get('msg')}"
f"Feishu search failed: code={data.get('code')}, msg={data.get('msg')}"
)
return []
@ -162,8 +163,7 @@ class FeishuKBAdapter(KBAdapter):
except httpx.HTTPStatusError as e:
logger.error(
f"Feishu search HTTP error: {e.response.status_code}"
f"{e.response.text[:200]}"
f"Feishu search HTTP error: {e.response.status_code}{e.response.text[:200]}"
)
return []
except Exception as e:
@ -179,7 +179,7 @@ class FeishuKBAdapter(KBAdapter):
client = self._get_client()
try:
resp = await client.get(
f"/wiki/v2/spaces/get_node",
"/wiki/v2/spaces/get_node",
params={"token": doc_id},
)
resp.raise_for_status()
@ -230,13 +230,17 @@ class FeishuKBAdapter(KBAdapter):
source_type="feishu",
)
)
return sources if sources else [
return (
sources
if sources
else [
SourceInfo(
source_id=self._source_id,
source_name=self._source_name,
source_type=self._source_type,
)
]
)
except Exception as e:
logger.error(f"Feishu list_sources error: {e}")
return [

View File

@ -7,7 +7,6 @@
from __future__ import annotations
import logging
from typing import Any
import httpx
@ -55,7 +54,9 @@ class GenericHTTPAdapter(KBAdapter):
)
self._endpoint_url = endpoint_url.rstrip("/")
if not is_safe_url(self._endpoint_url):
raise ValueError(f"Unsafe endpoint_url: {self._endpoint_url}. Private/internal URLs are not allowed.")
raise ValueError(
f"Unsafe endpoint_url: {self._endpoint_url}. Private/internal URLs are not allowed."
)
self._auth_config = auth_config or {}
self._extra_headers = headers or {}
@ -74,12 +75,11 @@ class GenericHTTPAdapter(KBAdapter):
headers["Authorization"] = f"Bearer {token}"
elif auth_type == "basic":
import base64
username = self._auth_config.get("username", "")
password = self._auth_config.get("password", "")
if username and password:
credentials = base64.b64encode(
f"{username}:{password}".encode()
).decode()
credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
headers["Authorization"] = f"Basic {credentials}"
elif auth_type == "api_key":
key_name = self._auth_config.get("header_name", "X-API-Key")
@ -135,8 +135,7 @@ class GenericHTTPAdapter(KBAdapter):
except httpx.HTTPStatusError as e:
logger.error(
f"GenericHTTP search HTTP error: {e.response.status_code}"
f"{e.response.text[:200]}"
f"GenericHTTP search HTTP error: {e.response.status_code}{e.response.text[:200]}"
)
return []
except Exception as e:
@ -177,8 +176,7 @@ class GenericHTTPAdapter(KBAdapter):
except httpx.HTTPStatusError as e:
logger.error(
f"GenericHTTP ingest HTTP error: {e.response.status_code}"
f"{e.response.text[:200]}"
f"GenericHTTP ingest HTTP error: {e.response.status_code}{e.response.text[:200]}"
)
return []
except Exception as e:
@ -245,13 +243,17 @@ class GenericHTTPAdapter(KBAdapter):
document_count=item.get("document_count", 0),
)
)
return sources if sources else [
return (
sources
if sources
else [
SourceInfo(
source_id=self._source_id,
source_name=self._source_name,
source_type=self._source_type,
)
]
)
except Exception as e:
logger.debug(f"GenericHTTP list_sources error (endpoint may not exist): {e}")

View File

@ -3,19 +3,29 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any
from typing import TypeAlias
# 共享类型别名 — 跨 memory 子系统复用,避免 `Any` 残留。
# MetadataValue 覆盖 metadata dict 中实际出现的原始类型;
# MemoryValue 额外允许 dict/list 容器以容纳结构化负载(如 episodic 经验字典)。
MetadataValue: TypeAlias = str | int | float | bool | None
MetadataDict: TypeAlias = dict[str, MetadataValue]
MemoryValue: TypeAlias = (
str | int | float | bool | None | dict[str, MetadataValue] | list[MetadataValue]
)
@dataclass
class MemoryItem:
"""记忆条目"""
key: str
value: Any
metadata: dict[str, Any] = field(default_factory=dict)
value: object
metadata: MetadataDict = field(default_factory=dict)
score: float = 1.0
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
def to_dict(self) -> dict:
def to_dict(self) -> dict[str, object]:
return {
"key": self.key,
"value": self.value,
@ -35,7 +45,7 @@ class Memory(ABC):
"""
@abstractmethod
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
"""存储记忆"""
...
@ -45,7 +55,9 @@ class Memory(ABC):
...
@abstractmethod
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]:
async def search(
self, query: str, top_k: int = 5, filters: MetadataDict | None = None
) -> list[MemoryItem]:
"""语义检索"""
...
@ -54,7 +66,7 @@ class Memory(ABC):
"""删除记忆"""
...
async def store_batch(self, items: list[tuple[str, Any, dict | None]]) -> None:
async def store_batch(self, items: list[tuple[str, object, MetadataDict | None]]) -> None:
"""批量存储"""
for key, value, metadata in items:
await self.store(key, value, metadata)

View File

@ -11,10 +11,18 @@ import logging
import re
import uuid
from dataclasses import dataclass, field
from typing import Any
from typing import TypeAlias
from agentkit.memory.base import MetadataDict
logger = logging.getLogger(__name__)
# 分块元数据source_doc/position/char_count/chunking_strategy/heading/heading_level
# — 全部为原始标量str/int
ChunkMetadata: TypeAlias = MetadataDict
# _split_by_headings 返回的节段结构。
SectionInfo: TypeAlias = dict[str, str | int]
@dataclass
class Chunk:
@ -22,7 +30,7 @@ class Chunk:
chunk_id: str
content: str
metadata: dict[str, Any] = field(default_factory=dict)
metadata: ChunkMetadata = field(default_factory=dict)
def __post_init__(self) -> None:
if "source_doc" not in self.metadata:
@ -30,7 +38,7 @@ class Chunk:
if "position" not in self.metadata:
self.metadata["position"] = 0
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> dict[str, object]:
return {
"chunk_id": self.chunk_id,
"content": self.content,
@ -57,7 +65,9 @@ class TextChunker:
separator: 优先分割符
"""
if chunk_overlap >= chunk_size:
raise ValueError(f"chunk_overlap ({chunk_overlap}) must be less than chunk_size ({chunk_size})")
raise ValueError(
f"chunk_overlap ({chunk_overlap}) must be less than chunk_size ({chunk_size})"
)
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
self._separator = separator
@ -66,7 +76,7 @@ class TextChunker:
self,
text: str,
source_doc_id: str = "",
metadata: dict[str, Any] | None = None,
metadata: ChunkMetadata | None = None,
) -> list[Chunk]:
"""将文本分块
@ -96,11 +106,13 @@ class TextChunker:
chunk_meta = dict(base_meta)
chunk_meta["position"] = i
chunk_meta["char_count"] = len(chunk_text)
chunks.append(Chunk(
chunks.append(
Chunk(
chunk_id=str(uuid.uuid4()),
content=chunk_text,
metadata=chunk_meta,
))
)
)
return chunks
@ -142,7 +154,9 @@ class TextChunker:
overlap_text[overlap_start:], segments
)
current = overlap_segments
current_len = sum(len(s) for s in current) + len(self._separator) * max(0, len(current) - 1)
current_len = sum(len(s) for s in current) + len(self._separator) * max(
0, len(current) - 1
)
current.append(segment)
current_len += seg_len + len(self._separator)
@ -214,7 +228,7 @@ class StructuralChunker:
self,
text: str,
source_doc_id: str = "",
metadata: dict[str, Any] | None = None,
metadata: ChunkMetadata | None = None,
) -> list[Chunk]:
"""将文本按结构分块
@ -266,23 +280,25 @@ class StructuralChunker:
chunk_meta["heading"] = heading
chunk_meta["heading_level"] = level
chunk_meta["char_count"] = len(content)
chunks.append(Chunk(
chunks.append(
Chunk(
chunk_id=str(uuid.uuid4()),
content=content,
metadata=chunk_meta,
))
)
)
position += 1
return chunks
def _split_by_headings(self, text: str) -> list[dict[str, Any]]:
def _split_by_headings(self, text: str) -> list[SectionInfo]:
"""按标题分割 Markdown 文本
Returns:
列表每项包含 heading, content, level
"""
lines = text.split("\n")
sections: list[dict[str, Any]] = []
sections: list[SectionInfo] = []
current_heading = ""
current_level = 0
current_lines: list[str] = []
@ -296,11 +312,13 @@ class StructuralChunker:
if current_lines:
content = "\n".join(current_lines).strip()
if content:
sections.append({
sections.append(
{
"heading": current_heading,
"content": content,
"level": current_level,
})
}
)
# 开始新节
current_heading = match.group(2).strip()
@ -313,18 +331,22 @@ class StructuralChunker:
if current_lines:
content = "\n".join(current_lines).strip()
if content:
sections.append({
sections.append(
{
"heading": current_heading,
"content": content,
"level": current_level,
})
}
)
# 如果没有标题结构,整体作为一个块
if not sections:
sections.append({
sections.append(
{
"heading": "",
"content": text.strip(),
"level": 0,
})
}
)
return sections

View File

@ -9,10 +9,14 @@ from __future__ import annotations
import hashlib
import logging
from dataclasses import dataclass
from typing import Any
from typing import TYPE_CHECKING
from agentkit.memory.base import MetadataDict
from agentkit.memory.embedder import EmbeddingCache
if TYPE_CHECKING:
from agentkit.llm.gateway import LLMGateway
logger = logging.getLogger(__name__)
@ -24,7 +28,7 @@ class ContextualChunk:
context_prefix: str
enhanced_content: str
chunk_index: int
metadata: dict[str, Any]
metadata: MetadataDict
@property
def content(self) -> str:
@ -65,7 +69,7 @@ class ContextualChunker:
def __init__(
self,
llm_gateway: Any = None,
llm_gateway: LLMGateway | None = None,
cache: EmbeddingCache | None = None,
batch_size: int = 8,
max_context_length: int = 200,
@ -90,7 +94,7 @@ class ContextualChunker:
self,
document: str,
chunks: list[str],
metadata: dict[str, Any] | None = None,
metadata: MetadataDict | None = None,
) -> list[ContextualChunk]:
"""为文档块添加上下文前缀
@ -134,7 +138,7 @@ class ContextualChunker:
document: str,
chunks: list[str],
start_index: int,
metadata: dict[str, Any] | None,
metadata: MetadataDict | None,
) -> list[ContextualChunk]:
"""处理一批文档块"""
results: list[ContextualChunk] = []

View File

@ -12,7 +12,9 @@ import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from typing import TypeAlias
from agentkit.memory.base import MetadataDict
logger = logging.getLogger(__name__)
@ -23,6 +25,11 @@ MAX_CONTENT_SIZE = 100 * 1024 * 1024 # 100MB
MAX_ROWS_PER_SHEET = 10_000
MAX_CELL_CHARS = 10_000
# 文档元数据source/format/parser/page_count/table_count/sheet_count/row_count/
# heading_count/created_at/title/truncated — 全部为原始标量。
DocumentMetadata: TypeAlias = MetadataDict
ParseResult: TypeAlias = tuple[str, DocumentMetadata]
@dataclass
class Document:
@ -31,7 +38,7 @@ class Document:
doc_id: str
title: str
content: str
metadata: dict[str, Any] = field(default_factory=dict)
metadata: DocumentMetadata = field(default_factory=dict)
def __post_init__(self) -> None:
if "source" not in self.metadata:
@ -43,7 +50,7 @@ class Document:
if "created_at" not in self.metadata:
self.metadata["created_at"] = datetime.now(timezone.utc).isoformat()
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> dict[str, object]:
return {
"doc_id": self.doc_id,
"title": self.title,
@ -136,12 +143,14 @@ class DocumentLoader:
parser = parsers.get(doc_format)
if parser is None:
logger.warning(f"Unsupported format '{doc_format}' for {filename}, falling back to text")
logger.warning(
f"Unsupported format '{doc_format}' for {filename}, falling back to text"
)
parser = self._parse_text
text, extra_meta = parser(content, filename)
metadata: dict[str, Any] = {
metadata: DocumentMetadata = {
"source": filename,
"format": doc_format,
"created_at": datetime.now(timezone.utc).isoformat(),
@ -159,7 +168,7 @@ class DocumentLoader:
metadata=metadata,
)
def _parse_pdf(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]:
def _parse_pdf(self, content: bytes, filename: str) -> ParseResult:
"""解析 PDF 文件
优先使用 PyMuPDF (fitz)回退到 pdfplumber最终回退到纯文本
@ -215,7 +224,7 @@ class DocumentLoader:
logger.warning(f"No PDF parser available for {filename}, falling back to text extraction")
return self._parse_text(content, filename)
def _parse_docx(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]:
def _parse_docx(self, content: bytes, filename: str) -> ParseResult:
"""解析 Word 文件
使用 python-docx回退到纯文本
@ -259,7 +268,7 @@ class DocumentLoader:
logger.warning(f"python-docx parsing failed for {filename}: {e}")
return self._parse_text(content, filename)
def _parse_xlsx(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]:
def _parse_xlsx(self, content: bytes, filename: str) -> ParseResult:
"""解析 Excel 文件
使用 openpyxl回退到纯文本每个 sheet 转为 Markdown 表格
@ -313,7 +322,7 @@ class DocumentLoader:
finally:
wb.close()
text = "\n".join(sections).strip()
meta: dict[str, Any] = {
meta: DocumentMetadata = {
"parser": "openpyxl",
"sheet_count": sheet_count,
"row_count": total_rows,
@ -328,7 +337,7 @@ class DocumentLoader:
logger.warning(f"openpyxl parsing failed for {filename}: {e}")
return self._parse_text(content, filename)
def _parse_markdown(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]:
def _parse_markdown(self, content: bytes, filename: str) -> ParseResult:
"""解析 Markdown 文件
使用 mistune如果可用否则直接读取文本
@ -347,7 +356,7 @@ class DocumentLoader:
title = line_stripped.lstrip("#").strip()
break
meta: dict[str, Any] = {
meta: DocumentMetadata = {
"parser": "markdown",
}
if title:
@ -362,7 +371,7 @@ class DocumentLoader:
return text, meta
def _parse_html(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]:
def _parse_html(self, content: bytes, filename: str) -> ParseResult:
"""解析 HTML 文件
使用 BeautifulSoup 提取文本回退到纯文本
@ -388,7 +397,7 @@ class DocumentLoader:
if soup.title and soup.title.string:
title = soup.title.string.strip()
meta: dict[str, Any] = {
meta: DocumentMetadata = {
"parser": "beautifulsoup",
}
if title:
@ -402,7 +411,7 @@ class DocumentLoader:
logger.warning(f"BeautifulSoup parsing failed for {filename}: {e}")
return self._parse_text(content, filename)
def _parse_text(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]:
def _parse_text(self, content: bytes, filename: str) -> ParseResult:
"""解析纯文本文件"""
try:
text = content.decode("utf-8")

View File

@ -1,12 +1,17 @@
"""Embedder 接口与实现 - 文本向量化"""
from __future__ import annotations
import hashlib
import logging
import os
import time
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import httpx
logger = logging.getLogger(__name__)
@ -97,13 +102,14 @@ class OpenAIEmbedder(Embedder):
self._model = model
self._base_url = base_url
self._dimension = 1536 # text-embedding-3-small 默认维度
self._client: Any = None
self._client: httpx.AsyncClient | None = None
self._cache = cache
def _get_client(self):
def _get_client(self) -> httpx.AsyncClient:
"""Lazily create and reuse a single httpx.AsyncClient."""
if self._client is None:
import httpx
self._client = httpx.AsyncClient(timeout=30.0)
return self._client

View File

@ -1,17 +1,22 @@
"""Episodic Memory - 基于 pgvector + PostgreSQL 的任务经验记忆"""
from __future__ import annotations
import json
import logging
import math
from datetime import datetime, timezone
from typing import Any
from typing import TYPE_CHECKING
from sqlalchemy import text
from agentkit.memory.base import Memory, MemoryItem
from agentkit.memory.base import Memory, MemoryItem, MetadataDict
from agentkit.memory.embedder import Embedder
from agentkit.utils.vector_math import compute_cosine_similarity
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
logger = logging.getLogger(__name__)
@ -28,8 +33,8 @@ class EpisodicMemory(Memory):
def __init__(
self,
session_factory: Any,
episodic_model: Any,
session_factory: object,
episodic_model: object,
embedder: Embedder | None = None,
decay_rate: float = 0.01,
alpha: float = 0.7,
@ -57,7 +62,7 @@ class EpisodicMemory(Memory):
self._pgvector_enabled = pgvector_enabled
self._table_name = table_name
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
"""存储任务经验"""
async with self._session_factory() as db:
try:
@ -68,7 +73,11 @@ class EpisodicMemory(Memory):
embedding = None
if self._embedder:
if isinstance(value, dict):
text = value.get("output_summary", "") or value.get("input_summary", "") or json.dumps(value, ensure_ascii=False)[:500]
text = (
value.get("output_summary", "")
or value.get("input_summary", "")
or json.dumps(value, ensure_ascii=False)[:500]
)
else:
text = str(value)
embedding = await self._embedder.embed(text)
@ -106,13 +115,11 @@ class EpisodicMemory(Memory):
logger.error(f"Failed to retrieve episodic memory: {e}")
return None
async def _retrieve_pgvector(self, db: Any, query_embedding: list[float]) -> MemoryItem | None:
async def _retrieve_pgvector(
self, db: AsyncSession, query_embedding: list[float]
) -> MemoryItem | None:
"""使用 pgvector ``<=>`` 算符检索最相似条目"""
sql = text(
f"SELECT * FROM {self._table_name} "
f"ORDER BY embedding <=> :query_vec "
f"LIMIT :lim"
)
sql = text(f"SELECT * FROM {self._table_name} ORDER BY embedding <=> :query_vec LIMIT :lim")
result = await db.execute(sql, {"query_vec": str(query_embedding), "lim": 1})
row = result.mappings().first()
@ -147,7 +154,9 @@ class EpisodicMemory(Memory):
created_at=row.get("created_at") or datetime.now(timezone.utc),
)
async def _retrieve_client_side(self, db: Any, query_embedding: list[float]) -> MemoryItem | None:
async def _retrieve_client_side(
self, db: AsyncSession, query_embedding: list[float]
) -> MemoryItem | None:
"""客户端 O(N) cosine similarity 检索(回退路径)"""
Model = self._episodic_model
from sqlalchemy import select
@ -193,7 +202,13 @@ class EpisodicMemory(Memory):
created_at=best_item.created_at or datetime.now(timezone.utc),
)
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None, search_multiplier: int = 5) -> list[MemoryItem]:
async def search(
self,
query: str,
top_k: int = 5,
filters: MetadataDict | None = None,
search_multiplier: int = 5,
) -> list[MemoryItem]:
"""语义检索相似历史案例
Args:
@ -214,10 +229,10 @@ class EpisodicMemory(Memory):
async def _search_pgvector(
self,
db: Any,
db: AsyncSession,
query: str,
top_k: int,
filters: dict[str, Any] | None,
filters: MetadataDict | None,
search_multiplier: int,
) -> list[MemoryItem]:
"""使用 pgvector ``<=>`` 算符检索,再 Python 侧 time_decay 重排"""
@ -225,7 +240,7 @@ class EpisodicMemory(Memory):
fetch_limit = top_k * search_multiplier
where_clauses = []
params: dict[str, Any] = {"query_vec": str(query_embedding), "lim": fetch_limit}
params: dict[str, object] = {"query_vec": str(query_embedding), "lim": fetch_limit}
filters = filters or {}
if filters.get("agent_name"):
@ -256,7 +271,11 @@ class EpisodicMemory(Memory):
items = []
for row in rows:
row_embedding = row.get("embedding")
age_hours = (datetime.now(timezone.utc) - row["created_at"]).total_seconds() / 3600 if row.get("created_at") else 0
age_hours = (
(datetime.now(timezone.utc) - row["created_at"]).total_seconds() / 3600
if row.get("created_at")
else 0
)
decay = math.exp(-self._decay_rate * age_hours)
time_decay_score = (row.get("quality_score") or 0.5) * decay
@ -266,7 +285,8 @@ class EpisodicMemory(Memory):
else:
score = time_decay_score
items.append(MemoryItem(
items.append(
MemoryItem(
key=str(row.get("id", "")),
value={
"input_summary": row.get("input_summary", ""),
@ -278,21 +298,24 @@ class EpisodicMemory(Memory):
metadata={
"agent_name": row.get("agent_name", ""),
"task_type": row.get("task_type", ""),
"created_at": row["created_at"].isoformat() if row.get("created_at") else None,
"created_at": row["created_at"].isoformat()
if row.get("created_at")
else None,
},
score=score,
created_at=row.get("created_at") or datetime.now(timezone.utc),
))
)
)
items.sort(key=lambda x: x.score, reverse=True)
return items[:top_k]
async def _search_client_side(
self,
db: Any,
db: AsyncSession,
query: str,
top_k: int,
filters: dict[str, Any] | None,
filters: MetadataDict | None,
search_multiplier: int,
) -> list[MemoryItem]:
"""客户端 O(N) cosine similarity 检索(回退路径)"""
@ -300,6 +323,7 @@ class EpisodicMemory(Memory):
filters = filters or {}
from sqlalchemy import select
stmt = select(Model)
if filters.get("agent_name"):
@ -322,7 +346,11 @@ class EpisodicMemory(Memory):
# 计算得分并构建 MemoryItem
items = []
for entry in entries:
age_hours = (datetime.now(timezone.utc) - entry.created_at).total_seconds() / 3600 if entry.created_at else 0
age_hours = (
(datetime.now(timezone.utc) - entry.created_at).total_seconds() / 3600
if entry.created_at
else 0
)
decay = math.exp(-self._decay_rate * age_hours)
time_decay_score = (entry.quality_score or 0.5) * decay
@ -333,7 +361,8 @@ class EpisodicMemory(Memory):
else:
score = time_decay_score
items.append(MemoryItem(
items.append(
MemoryItem(
key=str(entry.id),
value={
"input_summary": entry.input_summary,
@ -349,14 +378,17 @@ class EpisodicMemory(Memory):
},
score=score,
created_at=entry.created_at or datetime.now(timezone.utc),
))
)
)
items.sort(key=lambda x: x.score, reverse=True)
if len(items) < top_k:
logger.warning(
"EpisodicMemory.search returned %d results after scoring (top_k=%d). "
"Consider increasing search_multiplier (current=%d) to avoid missing relevant entries.",
len(items), top_k, search_multiplier,
len(items),
top_k,
search_multiplier,
)
return items[:top_k]
@ -364,8 +396,9 @@ class EpisodicMemory(Memory):
"""删除指定经验"""
async with self._session_factory() as db:
try:
from sqlalchemy import select, delete as sql_delete
from sqlalchemy import delete as sql_delete
import uuid
Model = self._episodic_model
stmt = sql_delete(Model).where(Model.id == uuid.UUID(key))

View File

@ -3,13 +3,26 @@
配置驱动不直接依赖业务系统代码通过 base_url + api_key 连接
"""
from __future__ import annotations
import logging
from typing import Any
from typing import TYPE_CHECKING, TypeAlias
import httpx
from agentkit.memory.base import MetadataDict
if TYPE_CHECKING:
from agentkit.llm.gateway import LLMGateway
logger = logging.getLogger(__name__)
# 标准化检索结果id/content/score/source/document_id/document_title/
# knowledge_base_id/metadata — 值为原始标量或嵌套 dict。
RAGSearchResult: TypeAlias = dict[str, object]
# ingest() 写入的文档负载title/content/source_type/metadata。
RAGIngestPayload: TypeAlias = dict[str, object]
class HttpRAGService:
"""HTTP 客户端,调用业务系统的知识库检索 API
@ -39,7 +52,7 @@ class HttpRAGService:
knowledge_base_ids: list[str] | None = None,
timeout: int = 30,
contextual_chunking: bool = False,
llm_gateway: Any = None,
llm_gateway: LLMGateway | None = None,
):
"""
Args:
@ -74,7 +87,7 @@ class HttpRAGService:
query: str,
knowledge_base_ids: list[str] | None = None,
top_k: int = 5,
) -> list[dict[str, Any]]:
) -> list[RAGSearchResult]:
"""语义检索知识库
Args:
@ -113,7 +126,8 @@ class HttpRAGService:
normalized = []
for r in results:
if isinstance(r, dict):
normalized.append({
normalized.append(
{
"id": r.get("chunk_id", r.get("id", "")),
"content": r.get("content", ""),
"score": float(r.get("score", 0.0)),
@ -121,11 +135,14 @@ class HttpRAGService:
"document_id": r.get("document_id", ""),
"document_title": r.get("document_title", ""),
"metadata": r.get("metadata", {}),
})
}
)
return normalized
except httpx.HTTPStatusError as e:
logger.error(f"RAG search HTTP error: {e.response.status_code}{e.response.text[:200]}")
logger.error(
f"RAG search HTTP error: {e.response.status_code}{e.response.text[:200]}"
)
return []
except httpx.RequestError as e:
logger.error(f"RAG search request error: {e}")
@ -141,7 +158,7 @@ class HttpRAGService:
top_k: int = 5,
use_rerank: bool = True,
use_compression: bool = False,
) -> list[dict[str, Any]]:
) -> list[RAGSearchResult]:
"""增强语义检索知识库(支持 rerank 和 compression
对每个知识库分别调用 /bases/{kb_id}/retrieve 接口
@ -169,7 +186,7 @@ class HttpRAGService:
}
client = self._get_client()
all_results: list[dict[str, Any]] = []
all_results: list[RAGSearchResult] = []
for kb_id in kb_ids:
try:
@ -189,7 +206,8 @@ class HttpRAGService:
# 标准化
for r in results:
if isinstance(r, dict):
all_results.append({
all_results.append(
{
"id": r.get("chunk_id", r.get("id", "")),
"content": r.get("content", ""),
"score": float(r.get("score", 0.0)),
@ -198,19 +216,17 @@ class HttpRAGService:
"document_title": r.get("document_title", ""),
"knowledge_base_id": kb_id,
"metadata": r.get("metadata", {}),
})
}
)
except httpx.HTTPStatusError as e:
if e.response.status_code == 404:
# This KB doesn't support enhanced search — fall back to
# standard search for THIS KB only, not all KBs.
logger.info(
f"Enhanced search not available for KB {kb_id}, "
f"using standard search"
)
std_result = await self.search(
query, knowledge_base_ids=[kb_id], top_k=top_k
f"Enhanced search not available for KB {kb_id}, using standard search"
)
std_result = await self.search(query, knowledge_base_ids=[kb_id], top_k=top_k)
all_results.extend(std_result)
else:
logger.error(
@ -232,9 +248,9 @@ class HttpRAGService:
async def ingest(
self,
key: str,
value: Any,
metadata: dict[str, Any] | None = None,
) -> dict[str, Any] | None:
value: object,
metadata: MetadataDict | None = None,
) -> dict[str, object] | None:
"""写入文档到知识库(可选操作)
When contextual_chunking is enabled and llm_gateway is configured,
@ -308,5 +324,5 @@ class HttpRAGService:
async def __aenter__(self) -> "HttpRAGService":
return self
async def __aexit__(self, *args: Any) -> None:
async def __aexit__(self, *args: object) -> None:
await self.close()

View File

@ -11,7 +11,11 @@ from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Protocol, runtime_checkable
from typing import Protocol, TypeAlias, runtime_checkable
from agentkit.memory.base import MetadataDict
KBMetadata: TypeAlias = MetadataDict
@dataclass
@ -22,7 +26,7 @@ class Document:
content: str
title: str = ""
source_id: str = ""
metadata: dict[str, Any] = field(default_factory=dict)
metadata: KBMetadata = field(default_factory=dict)
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
@ -34,7 +38,7 @@ class QueryResult:
source_id: str
source_name: str
score: float
metadata: dict[str, Any] = field(default_factory=dict)
metadata: KBMetadata = field(default_factory=dict)
doc_id: str = ""
title: str = ""

View File

@ -11,25 +11,32 @@ from __future__ import annotations
import json
import logging
import re
import uuid
from datetime import datetime, timezone
from typing import Any
_SAFE_TABLE_NAME_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
from typing import TYPE_CHECKING, TypeAlias
from agentkit.memory.chunking import Chunk, StructuralChunker, TextChunker
from agentkit.memory.document_loader import Document as LoaderDocument
from agentkit.memory.embedder import Embedder
from agentkit.memory.knowledge_base import (
Document,
KnowledgeBase,
QueryResult,
SourceInfo,
)
from agentkit.utils.vector_math import compute_cosine_similarity
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
_SAFE_TABLE_NAME_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
logger = logging.getLogger(__name__)
# InMemoryLocalRAGService 内部存储的文档元信息结构。
# 字段title/source_id/format/chunk_ids/metadata/created_at — 值为标量或 list[str]。
InMemoryDocInfo: TypeAlias = dict[str, object]
# 内部 chunk 存储结构content/embedding/metadata/source_doc_id。
InMemoryChunkInfo: TypeAlias = dict[str, object]
def _loader_doc_to_kb_doc(loader_doc: LoaderDocument) -> Document:
"""将 document_loader.Document 转换为 knowledge_base.Document"""
@ -53,7 +60,7 @@ class LocalRAGService:
def __init__(
self,
session_factory: Any,
session_factory: object,
embedder: Embedder,
chunk_size: int = 1000,
chunk_overlap: int = 200,
@ -75,10 +82,14 @@ class LocalRAGService:
self._chunk_overlap = chunk_overlap
self._table_name = table_name
if not _SAFE_TABLE_NAME_PATTERN.match(self._table_name):
raise ValueError(f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*")
raise ValueError(
f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*"
)
self._pgvector_enabled = pgvector_enabled
self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
self._structural_chunker = StructuralChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
self._structural_chunker = StructuralChunker(
chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
async def ingest(self, documents: list[Document]) -> list[str]:
"""摄取文档列表
@ -136,9 +147,7 @@ class LocalRAGService:
try:
from sqlalchemy import text as sql_text
sql = sql_text(
f"DELETE FROM {self._table_name} WHERE source_doc_id = :doc_id"
)
sql = sql_text(f"DELETE FROM {self._table_name} WHERE source_doc_id = :doc_id")
await db.execute(sql, {"doc_id": id})
await db.commit()
return True
@ -171,20 +180,15 @@ class LocalRAGService:
sources = []
for row in rows:
meta = {}
if row.get("doc_metadata"):
try:
meta = json.loads(row["doc_metadata"])
except (json.JSONDecodeError, TypeError):
pass
sources.append(SourceInfo(
sources.append(
SourceInfo(
source_id=row["source_doc_id"],
source_name=row.get("source_title", ""),
source_type=row.get("doc_format", "local"),
document_count=row.get("chunk_count", 0),
last_updated=row["created_at"] if row.get("created_at") else None,
))
)
)
return sources
except Exception as e:
logger.error(f"Failed to list sources: {e}")
@ -271,7 +275,7 @@ class LocalRAGService:
async def _query_pgvector(
self,
db: Any,
db: AsyncSession,
query_embedding: list[float],
top_k: int,
) -> list[QueryResult]:
@ -286,10 +290,13 @@ class LocalRAGService:
f"LIMIT :lim"
)
result = await db.execute(sql, {
result = await db.execute(
sql,
{
"query_vec": str(query_embedding),
"lim": top_k,
})
},
)
rows = result.mappings().all()
results = []
@ -306,7 +313,8 @@ class LocalRAGService:
except (json.JSONDecodeError, TypeError):
pass
results.append(QueryResult(
results.append(
QueryResult(
content=row["content"],
source_id=row["source_doc_id"],
source_name=row.get("source_title", ""),
@ -314,13 +322,14 @@ class LocalRAGService:
metadata=chunk_meta,
doc_id=row["source_doc_id"],
title=row.get("source_title", ""),
))
)
)
return results
async def _query_client_side(
self,
db: Any,
db: AsyncSession,
query_embedding: list[float],
top_k: int,
) -> list[QueryResult]:
@ -363,7 +372,8 @@ class LocalRAGService:
except (json.JSONDecodeError, TypeError):
pass
candidates.append(QueryResult(
candidates.append(
QueryResult(
content=row["content"],
source_id=row["source_doc_id"],
source_name=row.get("source_title", ""),
@ -371,7 +381,8 @@ class LocalRAGService:
metadata=chunk_meta,
doc_id=row["source_doc_id"],
title=row.get("source_title", ""),
))
)
)
candidates.sort(key=lambda x: x.score, reverse=True)
return candidates[:top_k]
@ -398,11 +409,15 @@ class InMemoryLocalRAGService:
"""
self._embedder = embedder
self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
self._structural_chunker = StructuralChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
self._structural_chunker = StructuralChunker(
chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
# 内存存储
self._chunks: dict[str, dict[str, Any]] = {} # chunk_id → {content, embedding, metadata}
self._documents: dict[str, dict[str, Any]] = {} # doc_id → {title, format, chunk_ids, metadata, created_at}
self._chunks: dict[str, InMemoryChunkInfo] = {} # chunk_id → {content, embedding, metadata}
self._documents: dict[
str, InMemoryDocInfo
] = {} # doc_id → {title, format, chunk_ids, metadata, created_at}
async def ingest(self, documents: list[Document]) -> list[str]:
"""摄取文档列表
@ -459,7 +474,8 @@ class InMemoryLocalRAGService:
source_doc_id = chunk_data["source_doc_id"]
doc_info = self._documents.get(source_doc_id, {})
candidates.append(QueryResult(
candidates.append(
QueryResult(
content=chunk_data["content"],
source_id=source_doc_id,
source_name=doc_info.get("title", ""),
@ -467,7 +483,8 @@ class InMemoryLocalRAGService:
metadata=chunk_data.get("metadata", {}),
doc_id=source_doc_id,
title=doc_info.get("title", ""),
))
)
)
candidates.sort(key=lambda x: x.score, reverse=True)
return candidates[:top_k]
@ -488,13 +505,15 @@ class InMemoryLocalRAGService:
"""列出已摄取的文档"""
sources = []
for doc_id, doc_info in self._documents.items():
sources.append(SourceInfo(
sources.append(
SourceInfo(
source_id=doc_id,
source_name=doc_info["title"],
source_type=doc_info.get("format", "local"),
document_count=len(doc_info.get("chunk_ids", [])),
last_updated=doc_info.get("created_at"),
))
)
)
return sources
async def health_check(self) -> bool:

View File

@ -13,7 +13,6 @@ import asyncio
import hashlib
import logging
from dataclasses import replace
from typing import Any
from agentkit.memory.knowledge_base import KnowledgeBase, QueryResult, SourceInfo
@ -186,15 +185,13 @@ class MultiSourceRetriever:
Returns:
所有源的检索结果列表已应用权重
"""
async def _query_one(name: str, kb: KnowledgeBase) -> list[QueryResult]:
try:
results = await kb.query(query, top_k=top_k)
# 应用权重
weight = (weights or {}).get(name, 1.0)
return [
replace(r, score=r.score * weight, source_name=name)
for r in results
]
return [replace(r, score=r.score * weight, source_name=name) for r in results]
except Exception as e:
logger.error(f"Query failed for source '{name}': {e}")
return []

View File

@ -7,10 +7,10 @@
from __future__ import annotations
import re
from dataclasses import dataclass, field
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Callable
from typing import Callable
class MemoryFile:
@ -26,8 +26,9 @@ class MemoryFile:
"""
def __init__(self, path: Path, char_budget: int | None = None,
protected_sections: set[str] | None = None):
def __init__(
self, path: Path, char_budget: int | None = None, protected_sections: set[str] | None = None
):
self.path = Path(path)
self.char_budget = char_budget
self._protected_sections = protected_sections or set()
@ -138,7 +139,7 @@ class MemoryFile:
for match in re.finditer(r"^## (.+)$", content, re.MULTILINE):
name = match.group(1).strip()
start = match.start()
next_match = re.search(r"^## ", content[match.end():], re.MULTILINE)
next_match = re.search(r"^## ", content[match.end() :], re.MULTILINE)
if next_match:
end = match.end() + next_match.start()
else:
@ -146,7 +147,7 @@ class MemoryFile:
sections.append((name, start, end))
if not sections:
return content[:self.char_budget]
return content[: self.char_budget]
# 保持原始顺序,标记每个 section 是否受保护
ordered: list[tuple[str, str, bool]] = [] # (name, text, is_protected)
@ -222,8 +223,9 @@ class MemoryStore:
"""
def __init__(self, base_dir: Path | str | None = None,
on_change: Callable[[str], None] | None = None):
def __init__(
self, base_dir: Path | str | None = None, on_change: Callable[[str], None] | None = None
):
if base_dir is None:
base_dir = Path.home() / ".agentkit"
self.base_dir = Path(base_dir)
@ -238,7 +240,9 @@ class MemoryStore:
protected_sections={"版本", "更新历史"},
)
self._user = MemoryFile(self.base_dir / "memories" / "USER.md", char_budget=USER_BUDGET)
self._memory = MemoryFile(self.base_dir / "memories" / "MEMORY.md", char_budget=MEMORY_BUDGET)
self._memory = MemoryFile(
self.base_dir / "memories" / "MEMORY.md", char_budget=MEMORY_BUDGET
)
self._daily_dir = self.base_dir / "memories" / "daily"
self._daily_dir.mkdir(parents=True, exist_ok=True)
@ -376,4 +380,5 @@ class MemoryStore:
self._on_change(new_prompt)
except Exception:
import logging
logging.getLogger(__name__).warning("Memory notify_change failed", exc_info=True)

View File

@ -87,10 +87,22 @@ class RuleQueryTransformer(QueryTransformerBase):
"""
_FILLER_WORDS_CN: list[str] = [
"帮我", "", "一下", "分析", "看看", "告诉我", "想知道", "请问",
"帮我",
"",
"一下",
"分析",
"看看",
"告诉我",
"想知道",
"请问",
]
_FILLER_WORDS_EN: list[str] = [
"please", "can you", "help me", "could you", "i want to", "i need to",
"please",
"can you",
"help me",
"could you",
"i want to",
"i need to",
]
def __init__(
@ -101,9 +113,7 @@ class RuleQueryTransformer(QueryTransformerBase):
self._synonyms = synonyms or {}
self._max_sub_queries = max_sub_queries
# Pre-compile filler patterns
self._filler_patterns_cn = [
re.compile(re.escape(w)) for w in self._FILLER_WORDS_CN
]
self._filler_patterns_cn = [re.compile(re.escape(w)) for w in self._FILLER_WORDS_CN]
self._filler_patterns_en = [
re.compile(re.escape(w), re.IGNORECASE) for w in self._FILLER_WORDS_EN
]
@ -166,7 +176,9 @@ def create_query_transformer(
"""工厂函数:根据策略创建查询改写器"""
if strategy == "llm":
if llm_gateway is None:
logger.warning("LLM strategy requested but no llm_gateway provided, falling back to NoOp")
logger.warning(
"LLM strategy requested but no llm_gateway provided, falling back to NoOp"
)
return NoOpQueryTransformer()
return LLMQueryTransformer(llm_gateway, max_sub_queries=max_sub_queries)
elif strategy == "rule":

View File

@ -7,11 +7,11 @@
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from dataclasses import dataclass
from enum import Enum
from typing import Any
from typing import TYPE_CHECKING
from agentkit.memory.base import MemoryItem
from agentkit.memory.base import MemoryItem, MetadataDict
from agentkit.memory.query_transformer import QueryTransformerBase, NoOpQueryTransformer
from agentkit.memory.relevance_scorer import (
RelevanceScorer,
@ -19,6 +19,10 @@ from agentkit.memory.relevance_scorer import (
RetrievalEvaluation,
)
if TYPE_CHECKING:
# 避免与 retriever.py 形成运行时循环导入。
from agentkit.memory.retriever import MemoryRetriever
logger = logging.getLogger(__name__)
@ -70,7 +74,7 @@ class RAGSelfCorrectionLoop:
def __init__(
self,
retriever: Any, # MemoryRetriever
retriever: MemoryRetriever,
scorer: RelevanceScorer | None = None,
query_transformer: QueryTransformerBase | None = None,
max_retries: int = 3,
@ -87,7 +91,7 @@ class RAGSelfCorrectionLoop:
query: str,
top_k: int = 5,
token_budget: int = 3000,
filters: dict[str, Any] | None = None,
filters: MetadataDict | None = None,
) -> RAGLoopResult:
"""执行带自纠正的检索
@ -107,8 +111,11 @@ class RAGSelfCorrectionLoop:
while retry_count <= self._max_retries:
# RETRIEVE
items = await self._retriever.retrieve(
current_query, top_k=top_k, token_budget=token_budget,
filters=filters, _skip_correction=True,
current_query,
top_k=top_k,
token_budget=token_budget,
filters=filters,
_skip_correction=True,
)
# EVALUATE
@ -144,9 +151,7 @@ class RAGSelfCorrectionLoop:
# CORRECT — rewrite query and retry
retry_count += 1
if retry_count <= self._max_retries:
current_query = await self._rewrite_query(
query, current_query, evaluation
)
current_query = await self._rewrite_query(query, current_query, evaluation)
continue
# DEGRADE — exceeded max retries
@ -154,9 +159,7 @@ class RAGSelfCorrectionLoop:
# Degraded result: filter to relevant items and mark low confidence
relevant_items = [
s.item
for s in evaluation.scores
if s.verdict != RelevanceVerdict.INCORRECT
s.item for s in evaluation.scores if s.verdict != RelevanceVerdict.INCORRECT
]
result_items = relevant_items if relevant_items else items

View File

@ -6,11 +6,9 @@
from __future__ import annotations
import logging
import math
import re
from dataclasses import dataclass
from enum import Enum
from typing import Any
from agentkit.memory.base import MemoryItem
@ -120,9 +118,7 @@ class RelevanceScorer:
reason=reason,
)
def evaluate(
self, query: str, items: list[MemoryItem]
) -> RetrievalEvaluation:
def evaluate(self, query: str, items: list[MemoryItem]) -> RetrievalEvaluation:
"""评估一次检索的整体质量"""
if not items:
return RetrievalEvaluation(
@ -134,9 +130,7 @@ class RelevanceScorer:
)
scores = [self.score_item(query, item) for item in items]
relevant_count = sum(
1 for s in scores if s.verdict != RelevanceVerdict.INCORRECT
)
relevant_count = sum(1 for s in scores if s.verdict != RelevanceVerdict.INCORRECT)
avg_score = sum(s.score for s in scores) / len(scores)
# Overall verdict based on average score and relevant ratio

View File

@ -7,19 +7,16 @@ from __future__ import annotations
import asyncio
import logging
import math
from dataclasses import replace
from datetime import datetime
from typing import Any
from agentkit.memory.base import Memory, MemoryItem
from agentkit.memory.base import MemoryItem, MetadataDict
from agentkit.memory.working import WorkingMemory
from agentkit.memory.episodic import EpisodicMemory
from agentkit.memory.semantic import SemanticMemory
from agentkit.memory.query_transformer import QueryTransformerBase
from agentkit.memory.rag_loop import RAGSelfCorrectionLoop
from agentkit.memory.relevance_scorer import RelevanceScorer
from agentkit.memory.knowledge_base import KnowledgeBase, QueryResult
from agentkit.memory.knowledge_base import KnowledgeBase
from agentkit.memory.multi_source_retriever import MultiSourceRetriever
from agentkit.tools.base import Tool
@ -32,11 +29,11 @@ def _estimate_tokens(text: str) -> int:
Chinese characters typically use 1-2 tokens each.
English words typically use 1 token each.
"""
cjk_count = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
cjk_count = sum(1 for c in text if "\u4e00" <= c <= "\u9fff")
non_cjk = text
for c in text:
if '\u4e00' <= c <= '\u9fff':
non_cjk = non_cjk.replace(c, ' ')
if "\u4e00" <= c <= "\u9fff":
non_cjk = non_cjk.replace(c, " ")
word_count = len(non_cjk.split())
return cjk_count * 2 + word_count
@ -89,7 +86,7 @@ class MemoryRetriever:
query: str,
top_k: int = 5,
token_budget: int = 3000,
filters: dict[str, Any] | None = None,
filters: MetadataDict | None = None,
_skip_correction: bool = False,
sources: list[str] | None = None,
source_weights: dict[str, float] | None = None,
@ -121,9 +118,7 @@ class MemoryRetriever:
query, top_k=top_k, token_budget=token_budget, filters=filters
)
if result.degraded:
logger.warning(
f"RAG self-correction degraded after {result.total_retries} retries"
)
logger.warning(f"RAG self-correction degraded after {result.total_retries} retries")
return result.items
# Query transformation
if self._query_transformer is not None:
@ -139,9 +134,7 @@ class MemoryRetriever:
# Sub-query search in parallel
if sub_queries:
sub_tasks = [
self._search_layers(sq, top_k, filters) for sq in sub_queries
]
sub_tasks = [self._search_layers(sq, top_k, filters) for sq in sub_queries]
sub_results = await asyncio.gather(*sub_tasks, return_exceptions=True)
for result in sub_results:
if isinstance(result, Exception):
@ -178,7 +171,7 @@ class MemoryRetriever:
self,
query: str,
top_k: int = 5,
filters: dict[str, Any] | None = None,
filters: MetadataDict | None = None,
) -> list[MemoryItem]:
"""Search all configured memory layers with a single query"""
tasks = []
@ -237,7 +230,8 @@ class MemoryRetriever:
# QueryResult → MemoryItem
items = []
for r in kb_results:
items.append(MemoryItem(
items.append(
MemoryItem(
key=r.source_id,
value=r.content,
metadata={
@ -248,7 +242,8 @@ class MemoryRetriever:
"document_title": r.title,
},
score=r.score,
))
)
)
# Token 预算管理
selected = []
@ -318,7 +313,9 @@ class MemoryRetriever:
if source == "rag":
kb_type = item.metadata.get("kb_type", "知识库")
document_title = item.metadata.get("document_title", "未知文档")
return f"### 知识库参考 [来源: {kb_type} | 相关度: {score:.2f} | 文档: {document_title}]"
return (
f"### 知识库参考 [来源: {kb_type} | 相关度: {score:.2f} | 文档: {document_title}]"
)
elif source == "graph":
return f"### 知识图谱 [实体: {item.key} | 相关度: {score:.2f}]"
elif source == "episodic":
@ -330,7 +327,7 @@ class MemoryRetriever:
return f"### 参考 [来源: {source} | 相关度: {score:.2f}]"
async def store_episode(
self, key: str, value: Any, metadata: dict[str, Any] | None = None
self, key: str, value: object, metadata: MetadataDict | None = None
) -> None:
"""Store an episode into episodic memory if available.
@ -386,12 +383,14 @@ class RetrieveKnowledgeTool(Tool):
items = await self._retriever.retrieve(query, top_k=5)
results = []
for item in items:
results.append({
results.append(
{
"content": item.value,
"score": item.score,
"source": item.metadata.get("source", "unknown"),
"document_title": item.metadata.get("document_title", ""),
})
}
)
return {"query": query, "results": results, "call_count": self._call_count}
except Exception as e:
return {"error": str(e), "results": []}

View File

@ -3,14 +3,45 @@
适配器模式对接外部 RAG 服务和知识图谱
"""
import logging
from typing import Any
from __future__ import annotations
from agentkit.memory.base import Memory, MemoryItem
import logging
from typing import TYPE_CHECKING, Protocol
from agentkit.memory.base import Memory, MemoryItem, MetadataDict
if TYPE_CHECKING:
from agentkit.memory.http_rag import RAGSearchResult
logger = logging.getLogger(__name__)
class _RAGServiceLike(Protocol):
"""RAG 检索服务最小接口契约duck-typed"""
async def search(
self,
query: str,
knowledge_base_ids: list[str] | None = ...,
top_k: int = ...,
) -> list[RAGSearchResult]: ...
async def enhanced_search(
self,
query: str,
knowledge_base_ids: list[str] | None = ...,
top_k: int = ...,
use_rerank: bool = ...,
use_compression: bool = ...,
) -> list[RAGSearchResult]: ...
class _GraphServiceLike(Protocol):
"""知识图谱服务最小接口契约duck-typed"""
async def query(self, query: str, depth: int = ...) -> list[dict[str, object]]: ...
class SemanticMemory(Memory):
"""Semantic Memory - 知识库检索
@ -19,8 +50,8 @@ class SemanticMemory(Memory):
def __init__(
self,
rag_service: Any = None,
graph_service: Any = None,
rag_service: _RAGServiceLike | None = None,
graph_service: _GraphServiceLike | None = None,
knowledge_base_ids: list[str] | None = None,
search_mode: str = "standard",
use_rerank: bool = True,
@ -45,9 +76,9 @@ class SemanticMemory(Memory):
self._use_compression = use_compression
self._kb_weights = kb_weights
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
"""Semantic Memory 通常只读,写入委托给 RAG 服务的 ingest 方法"""
if self._rag_service and hasattr(self._rag_service, 'ingest'):
if self._rag_service and hasattr(self._rag_service, "ingest"):
await self._rag_service.ingest(key, value, metadata)
else:
logger.warning("SemanticMemory.store: no RAG service configured for writing")
@ -56,7 +87,9 @@ class SemanticMemory(Memory):
"""按 key 精确检索Semantic Memory 通常不按 key 检索)"""
return None
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]:
async def search(
self, query: str, top_k: int = 5, filters: MetadataDict | None = None
) -> list[MemoryItem]:
"""语义检索知识库"""
items = []
@ -64,7 +97,9 @@ class SemanticMemory(Memory):
if self._rag_service:
try:
kb_ids = (filters or {}).get("knowledge_base_ids", self._knowledge_base_ids)
if self._search_mode == "enhanced" and hasattr(self._rag_service, "enhanced_search"):
if self._search_mode == "enhanced" and hasattr(
self._rag_service, "enhanced_search"
):
results = await self._rag_service.enhanced_search(
query,
knowledge_base_ids=kb_ids,
@ -73,14 +108,17 @@ class SemanticMemory(Memory):
use_compression=self._use_compression,
)
else:
results = await self._rag_service.search(query, knowledge_base_ids=kb_ids, top_k=top_k)
results = await self._rag_service.search(
query, knowledge_base_ids=kb_ids, top_k=top_k
)
for r in results:
kb_id = r.get("knowledge_base_id", "")
score = r.get("score", 0.0)
# Apply per-KB weights
if self._kb_weights and kb_id in self._kb_weights:
score *= self._kb_weights[kb_id]
items.append(MemoryItem(
items.append(
MemoryItem(
key=r.get("id", ""),
value=r.get("content", ""),
metadata={
@ -90,7 +128,8 @@ class SemanticMemory(Memory):
"knowledge_base_id": kb_id,
},
score=score,
))
)
)
except Exception as e:
logger.error(f"RAG search failed: {e}")
@ -99,7 +138,8 @@ class SemanticMemory(Memory):
try:
graph_results = await self._graph_service.query(query, depth=2)
for r in graph_results[:top_k]:
items.append(MemoryItem(
items.append(
MemoryItem(
key=r.get("id", ""),
value=r.get("content", ""),
metadata={
@ -108,7 +148,8 @@ class SemanticMemory(Memory):
"relations": r.get("relations", []),
},
score=r.get("score", 0.0),
))
)
)
except Exception as e:
logger.error(f"Graph search failed: {e}")

View File

@ -3,11 +3,10 @@
import json
import logging
from datetime import datetime, timezone
from typing import Any
import redis.asyncio as aioredis
from agentkit.memory.base import Memory, MemoryItem
from agentkit.memory.base import Memory, MemoryItem, MetadataDict
logger = logging.getLogger(__name__)
@ -32,7 +31,7 @@ class WorkingMemory(Memory):
def _make_key(self, key: str) -> str:
return f"{self._key_prefix}:{key}"
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
redis_key = self._make_key(key)
item = MemoryItem(
key=key,
@ -57,10 +56,14 @@ class WorkingMemory(Memory):
value=item_dict["value"],
metadata=item_dict.get("metadata", {}),
score=item_dict.get("score", 1.0),
created_at=datetime.fromisoformat(item_dict["created_at"]) if item_dict.get("created_at") else datetime.now(timezone.utc),
created_at=datetime.fromisoformat(item_dict["created_at"])
if item_dict.get("created_at")
else datetime.now(timezone.utc),
)
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]:
async def search(
self, query: str, top_k: int = 5, filters: MetadataDict | None = None
) -> list[MemoryItem]:
"""Working Memory 不支持语义检索,按 key 前缀匹配"""
pattern = self._make_key(f"{query}*")
keys = []
@ -74,13 +77,15 @@ class WorkingMemory(Memory):
data = await self._redis.get(key)
if data:
item_dict = json.loads(data)
items.append(MemoryItem(
items.append(
MemoryItem(
key=item_dict["key"],
value=item_dict["value"],
metadata=item_dict.get("metadata", {}),
score=1.0,
created_at=datetime.now(timezone.utc),
))
)
)
return items
async def delete(self, key: str) -> bool:

View File

@ -17,7 +17,7 @@ import logging
import time
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from typing import Any
from typing import TYPE_CHECKING
logger = logging.getLogger(__name__)
@ -25,6 +25,28 @@ _TTL_SECONDS = 7 * 24 * 3600 # 7 days
_KEY_PREFIX = "agentkit:pipeline:checkpoint"
if TYPE_CHECKING:
from typing import Protocol
class _RedisPipelineLike(Protocol):
def set(self, key: str, value: str, ex: int | None = None) -> object: ...
def zadd(self, name: str, mapping: dict[str, float]) -> object: ...
def get(self, key: str) -> object: ...
def delete(self, *keys: str) -> object: ...
async def execute(self) -> list[object]: ...
class _RedisLike(Protocol):
async def set(self, key: str, value: str, ex: int | None = None) -> object: ...
async def get(self, key: str) -> object: ...
def pipeline(self) -> _RedisPipelineLike: ...
async def zrange(self, name: str, start: int, stop: int) -> list[object]: ...
class _PlanLike(Protocol):
@property
def id(self) -> str: ...
def to_dict(self) -> dict[str, object]: ...
@dataclass
class CheckpointData:
"""单个阶段的 checkpoint 数据。"""
@ -33,15 +55,15 @@ class CheckpointData:
phase_id: str
phase_name: str
phase_status: str
phase_result: dict[str, Any] | None = None
phase_result: dict[str, object] | None = None
plan_status: str = ""
saved_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> dict[str, object]:
return asdict(self)
@classmethod
def from_dict(cls, data: dict[str, Any]) -> CheckpointData:
def from_dict(cls, data: dict[str, object]) -> CheckpointData:
return cls(
plan_id=data.get("plan_id", ""),
phase_id=data.get("phase_id", ""),
@ -67,7 +89,7 @@ class PipelineCheckpoint:
def __init__(
self,
redis_client: Any = None,
redis_client: _RedisLike | None = None,
prefix: str = _KEY_PREFIX,
ttl_seconds: int = _TTL_SECONDS,
) -> None:
@ -78,7 +100,7 @@ class PipelineCheckpoint:
# P1 #6: 改用 dict keyed by phase_id避免重复 append
self._memory: dict[str, dict[str, CheckpointData]] = {}
# 内存降级存储plan_id → (plan_dict, saved_timestamp)
self._memory_plans: dict[str, tuple[dict[str, Any], float]] = {}
self._memory_plans: dict[str, tuple[dict[str, object], float]] = {}
def _is_expired(self, saved_at: str) -> bool:
"""检查 checkpoint 是否已过期(内存模式 TTL"""
@ -102,7 +124,7 @@ class PipelineCheckpoint:
"""完整 plan JSON 的存储键。"""
return f"{self._prefix}:plan:{plan_id}"
async def save_plan(self, plan: Any) -> None:
async def save_plan(self, plan: _PlanLike) -> None:
"""保存完整 TeamPlan用于 resume 重建)。
Args:
@ -121,7 +143,7 @@ class PipelineCheckpoint:
except Exception as e:
logger.warning(f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}")
async def load_plan(self, plan_id: str) -> dict[str, Any] | None:
async def load_plan(self, plan_id: str) -> dict[str, object] | None:
"""加载完整 plan JSON。"""
# 优先 Redis
if self._redis is not None:
@ -142,7 +164,7 @@ class PipelineCheckpoint:
return None
return plan_dict
async def save(self, plan_id: str, phase: Any, plan_status: str) -> None:
async def save(self, plan_id: str, phase: object, plan_status: str) -> None:
"""保存阶段 checkpoint。
Args:
@ -212,7 +234,8 @@ class PipelineCheckpoint:
if not phase_ids:
# Redis 无数据,检查内存(过滤过期)
return [
c for c in self._memory.get(plan_id, {}).values()
c
for c in self._memory.get(plan_id, {}).values()
if not self._is_expired(c.saved_at)
]
@ -236,8 +259,7 @@ class PipelineCheckpoint:
# 内存降级(过滤过期 checkpoint
return [
c for c in self._memory.get(plan_id, {}).values()
if not self._is_expired(c.saved_at)
c for c in self._memory.get(plan_id, {}).values() if not self._is_expired(c.saved_at)
]
async def clear(self, plan_id: str) -> None:

View File

@ -1,8 +1,8 @@
"""Saga compensation pattern for Pipeline execution"""
import logging
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable
from dataclasses import dataclass
from typing import Awaitable, Callable
logger = logging.getLogger(__name__)
@ -12,7 +12,7 @@ class CompletedStep:
"""Record of a completed step with its compensation"""
step_name: str
result: Any
result: object
compensate_action: str | None = None
@ -28,9 +28,7 @@ class CompensationResult:
class SagaOrchestrator:
"""Orchestrates LIFO compensation for failed pipelines"""
def __init__(
self, execute_skill_func: Callable[..., Awaitable[Any]] | None = None
):
def __init__(self, execute_skill_func: Callable[..., Awaitable[object]] | None = None):
"""
Args:
execute_skill_func: Async function to execute a skill by name
@ -42,7 +40,7 @@ class SagaOrchestrator:
def record_completed(
self,
step_name: str,
result: Any,
result: object,
compensate_action: str | None = None,
):
"""Record a completed step for potential compensation"""
@ -59,9 +57,7 @@ class SagaOrchestrator:
results: list[CompensationResult] = []
for step in reversed(self._completed_steps):
if step.compensate_action is None:
logger.info(
f"No compensation for step '{step.step_name}', skipping"
)
logger.info(f"No compensation for step '{step.step_name}', skipping")
results.append(
CompensationResult(
step_name=step.step_name,
@ -82,9 +78,7 @@ class SagaOrchestrator:
)
)
except Exception as e:
logger.error(
f"Compensation for step '{step.step_name}' failed: {e}"
)
logger.error(f"Compensation for step '{step.step_name}' failed: {e}")
results.append(
CompensationResult(
step_name=step.step_name,

View File

@ -4,7 +4,6 @@
"""
import logging
from typing import Any
from agentkit.orchestrator.pipeline_engine import PipelineEngine
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineResult, StageStatus
@ -15,7 +14,7 @@ logger = logging.getLogger(__name__)
class DynamicPipeline:
"""动态 Pipeline 组合器"""
def __init__(self, engine: PipelineEngine, loader: Any = None):
def __init__(self, engine: PipelineEngine, loader: object | None = None):
self._engine = engine
self._loader = loader
@ -23,7 +22,7 @@ class DynamicPipeline:
self,
pipelines: dict[str, Pipeline],
condition_key: str,
context: dict[str, Any] | None = None,
context: dict[str, object] | None = None,
) -> PipelineResult:
"""根据条件选择子 Pipeline 执行"""
context = context or {}
@ -37,14 +36,16 @@ class DynamicPipeline:
)
selected = pipelines[condition_value]
logger.info(f"DynamicPipeline selected '{selected.name}' for {condition_key}={condition_value}")
logger.info(
f"DynamicPipeline selected '{selected.name}' for {condition_key}={condition_value}"
)
return await self._engine.execute(selected, context)
async def execute_nested(
self,
parent: Pipeline,
sub_pipeline_map: dict[str, Pipeline],
context: dict[str, Any] | None = None,
context: dict[str, object] | None = None,
) -> PipelineResult:
"""执行嵌套 Pipeline"""
# 先执行父 Pipeline
@ -52,7 +53,7 @@ class DynamicPipeline:
# 根据父 Pipeline 结果选择子 Pipeline
for stage_name, stage_result in parent_result.stage_results.items():
if hasattr(stage_result, 'output_data') and stage_result.output_data:
if hasattr(stage_result, "output_data") and stage_result.output_data:
sub_pipeline_name = stage_result.output_data.get("sub_pipeline")
if sub_pipeline_name and sub_pipeline_name in sub_pipeline_map:
sub = sub_pipeline_map[sub_pipeline_name]
@ -66,7 +67,7 @@ class DynamicPipeline:
pipeline: Pipeline,
max_iterations: int = 5,
exit_condition: str = "done",
context: dict[str, Any] | None = None,
context: dict[str, object] | None = None,
) -> PipelineResult:
"""循环执行 Pipeline 直到条件满足"""
current_context = context or {}

View File

@ -3,25 +3,42 @@
import asyncio
import json
import logging
from typing import Any
from typing import Awaitable, Callable, Protocol
from agentkit.core.protocol import HandoffMessage
logger = logging.getLogger(__name__)
class _RedisPubSubLike(Protocol):
"""Structural type for Redis pubsub object."""
async def subscribe(self, channel: str) -> None: ...
async def unsubscribe(self, channel: str) -> None: ...
def listen(self) -> object: ...
class _RedisLike(Protocol):
"""Structural type for async Redis client."""
async def publish(self, channel: str, message: str) -> int: ...
def pubsub(self) -> _RedisPubSubLike: ...
class HandoffManager:
"""Handoff 管理器
通过 Redis Pub/Sub 管理 Agent 间的任务转交
"""
def __init__(self, redis: Any = None, dispatcher: Any = None):
def __init__(self, redis: _RedisLike | None = None, dispatcher: object | None = None):
self._redis = redis
self._dispatcher = dispatcher
self._handlers: dict[str, list[Any]] = {}
self._handlers: dict[str, list[Callable[[HandoffMessage], Awaitable[None]]]] = {}
def register_handler(self, agent_name: str, handler: Any) -> None:
def register_handler(
self, agent_name: str, handler: Callable[[HandoffMessage], Awaitable[None]]
) -> None:
"""注册 Handoff 处理器"""
if agent_name not in self._handlers:
self._handlers[agent_name] = []

View File

@ -1,10 +1,12 @@
"""Pipeline Engine - DAG + 并行执行 + 步骤重试 + Saga 补偿"""
from __future__ import annotations
import asyncio
import logging
from collections import defaultdict
from datetime import datetime, timezone
from typing import Any
from typing import TYPE_CHECKING
from agentkit.orchestrator.compensation import SagaOrchestrator
from agentkit.orchestrator.pipeline_schema import (
@ -25,6 +27,23 @@ from agentkit.orchestrator.retry import execute_with_retry
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from typing import Protocol
class _DispatcherLike(Protocol):
async def dispatch(self, task: object) -> None: ...
async def get_task_status(self, task_id: str) -> dict[str, object]: ...
class _StateManagerLike(Protocol):
async def create_execution(self, **kwargs: object) -> str: ...
async def update_step(self, **kwargs: object) -> None: ...
async def complete_execution(self, **kwargs: object) -> None: ...
async def fail_execution(self, **kwargs: object) -> None: ...
class _LLMGatewayLike(Protocol):
async def chat(self, **kwargs: object) -> object: ...
class PipelineEngine:
"""Pipeline 执行引擎
@ -38,7 +57,12 @@ class PipelineEngine:
- 状态持久化可选
"""
def __init__(self, dispatcher: Any = None, state_manager: Any = None, llm_gateway: Any = None):
def __init__(
self,
dispatcher: _DispatcherLike | None = None,
state_manager: _StateManagerLike | None = None,
llm_gateway: _LLMGatewayLike | None = None,
):
self._dispatcher = dispatcher
self._state_manager = state_manager
self._llm_gateway = llm_gateway
@ -46,7 +70,7 @@ class PipelineEngine:
async def execute(
self,
pipeline: Pipeline,
context: dict[str, Any] | None = None,
context: dict[str, object] | None = None,
adaptive_config: AdaptiveConfig | None = None,
) -> PipelineResult:
"""执行 Pipeline
@ -68,7 +92,7 @@ class PipelineEngine:
async def _adaptive_loop(
self,
pipeline: Pipeline,
context: dict[str, Any] | None,
context: dict[str, object] | None,
failed_result: PipelineResult,
adaptive_config: AdaptiveConfig,
) -> PipelineResult:
@ -92,34 +116,30 @@ class PipelineEngine:
# Replan
new_pipeline = await replanner.replan(current_pipeline, current_result, report)
logger.info(f"Pipeline replanned: {new_pipeline.name} ({len(new_pipeline.stages)} stages)")
logger.info(
f"Pipeline replanned: {new_pipeline.name} ({len(new_pipeline.stages)} stages)"
)
# Re-execute
current_result = await self._execute_pipeline(new_pipeline, context)
current_pipeline = new_pipeline
# Record reflection in metadata
current_result.metadata["reflections"] = [
r.model_dump() for r in reflections
]
current_result.metadata["reflections"] = [r.model_dump() for r in reflections]
if current_result.status == StageStatus.COMPLETED:
logger.info(f"Pipeline succeeded after {reflection_num} reflection(s)")
return current_result
# Exhausted reflections
logger.warning(
f"Pipeline failed after {adaptive_config.max_reflections} reflection(s)"
)
current_result.metadata["reflections"] = [
r.model_dump() for r in reflections
]
logger.warning(f"Pipeline failed after {adaptive_config.max_reflections} reflection(s)")
current_result.metadata["reflections"] = [r.model_dump() for r in reflections]
return current_result
async def _execute_pipeline(
self,
pipeline: Pipeline,
context: dict[str, Any] | None = None,
context: dict[str, object] | None = None,
) -> PipelineResult:
"""执行 Pipeline 的核心逻辑(不含反思-重规划)。"""
result = PipelineResult(pipeline_name=pipeline.name)
@ -151,7 +171,9 @@ class PipelineEngine:
# 逐层执行
for level, stages in enumerate(level_groups):
logger.info(f"Pipeline '{pipeline.name}' executing level {level} with {len(stages)} stage(s)")
logger.info(
f"Pipeline '{pipeline.name}' executing level {level} with {len(stages)} stage(s)"
)
# 并行执行同层 stages
tasks = []
@ -173,9 +195,11 @@ class PipelineEngine:
# Update step state
if self._state_manager is not None and execution_id is not None:
try:
step_status = "completed" if sr.status == StageStatus.COMPLETED else sr.status.value
step_output = sr.output_data if hasattr(sr, 'output_data') else None
step_error = sr.error_message if hasattr(sr, 'error_message') else None
step_status = (
"completed" if sr.status == StageStatus.COMPLETED else sr.status.value
)
step_output = sr.output_data if hasattr(sr, "output_data") else None
step_error = sr.error_message if hasattr(sr, "error_message") else None
await self._state_manager.update_step(
execution_id=execution_id,
step_name=stage.name,
@ -189,19 +213,21 @@ class PipelineEngine:
# 收集输出变量
if sr.output_data and isinstance(sr, dict):
pass
elif hasattr(sr, 'output_data') and sr.output_data:
elif hasattr(sr, "output_data") and sr.output_data:
for output_key in stage.outputs:
if output_key in sr.output_data:
result.variables[output_key] = sr.output_data[output_key]
# 检查是否需要中止
if hasattr(sr, 'status') and sr.status == StageStatus.FAILED:
if hasattr(sr, "status") and sr.status == StageStatus.FAILED:
if not stage.continue_on_failure:
# Execute Saga compensation for completed steps
compensation_results = await saga.compensate()
if compensation_results:
failed_compensations = [
cr for cr in compensation_results if not cr.success and cr.error != "no_compensation_needed"
cr
for cr in compensation_results
if not cr.success and cr.error != "no_compensation_needed"
]
if failed_compensations:
logger.warning(
@ -219,7 +245,12 @@ class PipelineEngine:
step_name=stage.name,
error=result.error_message,
)
except (asyncio.TimeoutError, ConnectionError, RuntimeError, ValueError) as exc:
except (
asyncio.TimeoutError,
ConnectionError,
RuntimeError,
ValueError,
) as exc:
logger.warning(f"Failed to persist failure state: {exc}")
return result
@ -252,7 +283,9 @@ class PipelineEngine:
started_at = datetime.now(timezone.utc).isoformat()
# 条件检查
if stage.condition and not self._evaluate_condition(stage.condition, pipeline_result.variables):
if stage.condition and not self._evaluate_condition(
stage.condition, pipeline_result.variables
):
return StageResult(
stage_name=stage.name,
status=StageStatus.SKIPPED,
@ -312,7 +345,9 @@ class PipelineEngine:
if status["status"] in ("completed", "failed", "cancelled"):
return StageResult(
stage_name=stage.name,
status=StageStatus.COMPLETED if status["status"] == "completed" else StageStatus.FAILED,
status=StageStatus.COMPLETED
if status["status"] == "completed"
else StageStatus.FAILED,
output_data=status.get("output_data"),
error_message=status.get("error_message"),
started_at=started_at,
@ -406,7 +441,7 @@ class PipelineEngine:
return resolved
@staticmethod
def _get_nested(data: dict, path: str) -> Any:
def _get_nested(data: dict, path: str) -> object:
keys = path.split(".")
current = data
for key in keys:
@ -497,9 +532,7 @@ class PipelineEngine:
if verifier_feedback.passed:
# 审查通过,返回成功结果
logger.info(
f"Stage '{stage.name}' passed review in round {round_num}"
)
logger.info(f"Stage '{stage.name}' passed review in round {round_num}")
worker_result.output_data = worker_result.output_data or {}
worker_result.output_data["adversarial_metadata"] = {
"passed_round": round_num,
@ -553,7 +586,7 @@ class PipelineEngine:
self,
agent_name: str,
action: str,
input_data: dict[str, Any],
input_data: dict[str, object],
stage: PipelineStage,
started_at: str,
timeout_seconds: int | None = None,
@ -568,7 +601,9 @@ class PipelineEngine:
started_at: 开始时间
timeout_seconds: 独立超时时间不传则使用 stage.timeout_seconds
"""
effective_timeout = timeout_seconds if timeout_seconds is not None else stage.timeout_seconds
effective_timeout = (
timeout_seconds if timeout_seconds is not None else stage.timeout_seconds
)
if self._dispatcher is None:
# Dry-run 模式
return StageResult(
@ -602,7 +637,9 @@ class PipelineEngine:
if status["status"] in ("completed", "failed", "cancelled"):
return StageResult(
stage_name=stage.name,
status=StageStatus.COMPLETED if status["status"] == "completed" else StageStatus.FAILED,
status=StageStatus.COMPLETED
if status["status"] == "completed"
else StageStatus.FAILED,
output_data=status.get("output_data"),
error_message=status.get("error_message"),
started_at=started_at,
@ -639,7 +676,7 @@ class PipelineEngine:
async def _execute_verifier(
self,
verifier_name: str,
worker_output: dict[str, Any],
worker_output: dict[str, object],
stage: PipelineStage,
started_at: str,
) -> ReviewFeedback:
@ -679,10 +716,7 @@ class PipelineEngine:
try:
feedback = ReviewFeedback(
passed=output_data.get("passed", False),
issues=[
ReviewIssue(**issue)
for issue in output_data.get("issues", [])
],
issues=[ReviewIssue(**issue) for issue in output_data.get("issues", [])],
summary=output_data.get("summary", "No summary provided"),
score=output_data.get("score", 0.0),
)
@ -699,7 +733,7 @@ class PipelineEngine:
self,
feedback: ReviewFeedback,
feedback_mode: str = "structured+natural",
) -> dict[str, Any]:
) -> dict[str, object]:
"""构建反馈上下文,让 Worker Agent 理解审查反馈并定向修复
Args:
@ -720,7 +754,7 @@ class PipelineEngine:
for issue in feedback.issues
]
feedback_context: dict[str, Any] = {
feedback_context: dict[str, object] = {
"previous_attempt_failed": True,
}
@ -756,7 +790,9 @@ class PipelineEngine:
)
else:
# 未知模式fallback 到 structured+natural
logger.warning(f"Unknown feedback_mode '{feedback_mode}', falling back to structured+natural")
logger.warning(
f"Unknown feedback_mode '{feedback_mode}', falling back to structured+natural"
)
feedback_context["review_feedback"] = {
"summary": feedback.summary,
"issues": issues_list,

View File

@ -2,7 +2,6 @@
import logging
from pathlib import Path
from typing import Any
import yaml
@ -23,7 +22,9 @@ class PipelineLoader:
if not yaml_path.exists():
yaml_path = self._pipelines_dir / f"{pipeline_name}.yml"
if not yaml_path.exists():
raise FileNotFoundError(f"Pipeline '{pipeline_name}' not found in {self._pipelines_dir}")
raise FileNotFoundError(
f"Pipeline '{pipeline_name}' not found in {self._pipelines_dir}"
)
content = yaml_path.read_text(encoding="utf-8")
return self.load_from_yaml(content, pipeline_name)

View File

@ -35,9 +35,7 @@ class PipelineExecutionModel(Base):
)
completed_at = Column(DateTime)
__table_args__ = (
Index("ix_pipeline_status_created", "status", "created_at"),
)
__table_args__ = (Index("ix_pipeline_status_created", "status", "created_at"),)
class PipelineStepHistoryModel(Base):

View File

@ -1,7 +1,7 @@
"""Pipeline 数据模型"""
from enum import Enum
from typing import Any, Literal
from typing import Literal
from pydantic import BaseModel, Field
@ -18,8 +18,11 @@ class StageStatus(str, Enum):
class ReviewIssue(BaseModel):
"""单条审查问题"""
severity: Literal["critical", "major", "minor"] = Field(description="问题严重程度")
category: Literal["logic_error", "security", "style", "test_failure", "architecture"] = Field(description="问题类别")
category: Literal["logic_error", "security", "style", "test_failure", "architecture"] = Field(
description="问题类别"
)
description: str = Field(min_length=1, description="问题描述")
location: str | None = Field(default=None, description="文件路径/行号")
suggestion: str | None = Field(default=None, description="修复建议")
@ -27,6 +30,7 @@ class ReviewIssue(BaseModel):
class ReviewFeedback(BaseModel):
"""Verifier 返回的结构化审查反馈"""
passed: bool = Field(description="是否通过审查")
issues: list[ReviewIssue] = Field(default_factory=list, description="问题列表")
summary: str = Field(min_length=1, description="自然语言审查报告")
@ -35,6 +39,7 @@ class ReviewFeedback(BaseModel):
class AdversarialState(BaseModel):
"""对抗轮次状态追踪"""
current_round: int = Field(default=0, description="当前对抗轮次")
max_rounds: int = Field(default=3, description="最大对抗轮次")
feedback_history: list[ReviewFeedback] = Field(default_factory=list, description="反馈历史")
@ -46,7 +51,7 @@ class PipelineStage(BaseModel):
agent: str
action: str
depends_on: list[str] = []
inputs: dict[str, Any] = {}
inputs: dict[str, object] = {}
outputs: list[str] = []
timeout_seconds: int = 300
retry_count: int = 0
@ -56,10 +61,17 @@ class PipelineStage(BaseModel):
compensate: str | None = None
# 对抗闭环相关字段
verifier: str | None = Field(default=None, description="Verifier Agent 名称,配置后启用对抗模式")
verifier: str | None = Field(
default=None, description="Verifier Agent 名称,配置后启用对抗模式"
)
max_adversarial_rounds: int = Field(default=3, description="最大对抗轮次")
verifier_timeout_seconds: int = Field(default=120, description="Verifier Agent 独立超时时间(秒),避免与 Worker 共享 timeout_seconds")
feedback_mode: Literal["structured+natural", "structured", "natural"] = Field(default="structured+natural", description="反馈模式")
verifier_timeout_seconds: int = Field(
default=120,
description="Verifier Agent 独立超时时间(秒),避免与 Worker 共享 timeout_seconds",
)
feedback_mode: Literal["structured+natural", "structured", "natural"] = Field(
default="structured+natural", description="反馈模式"
)
escalate_on_exhaust: str | None = Field(default=None, description="对抗轮次耗尽后的升级目标")
model_config = {"arbitrary_types_allowed": True}
@ -70,13 +82,13 @@ class Pipeline(BaseModel):
version: str
description: str
stages: list[PipelineStage]
variables: dict[str, Any] = {}
variables: dict[str, object] = {}
class StageResult(BaseModel):
stage_name: str
status: StageStatus = StageStatus.PENDING
output_data: dict[str, Any] | None = None
output_data: dict[str, object] | None = None
error_message: str | None = None
started_at: str | None = None
completed_at: str | None = None
@ -86,9 +98,9 @@ class PipelineResult(BaseModel):
pipeline_name: str
status: StageStatus = StageStatus.PENDING
stage_results: dict[str, StageResult] = {}
variables: dict[str, Any] = {}
variables: dict[str, object] = {}
error_message: str | None = None
metadata: dict[str, Any] = {}
metadata: dict[str, object] = {}
class AdaptiveConfig(BaseModel):

View File

@ -13,7 +13,7 @@ import json
import logging
import uuid
from datetime import datetime, timezone
from typing import Any, Callable, Coroutine
from typing import Callable, Coroutine
from agentkit.orchestrator.pipeline_models import (
PipelineExecutionModel,
@ -183,7 +183,7 @@ class PipelineStateRedis:
return self._redis
async def _safe_redis_call(
self, fn: Callable[..., Coroutine[Any, Any, Any]], *args: object, **kwargs: object
self, fn: Callable[..., Coroutine[object, object, object]], *args: object, **kwargs: object
) -> object | None:
"""Execute a Redis call, falling back to memory on failure.

View File

@ -4,22 +4,30 @@
生成修正后的 Pipeline 重新执行
"""
from __future__ import annotations
import json
import logging
from typing import Any
from typing import TYPE_CHECKING
from agentkit.orchestrator.pipeline_schema import (
Pipeline,
PipelineResult,
PipelineStage,
ReflectionReport,
StageResult,
StageStatus,
)
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from typing import Protocol
class _LLMGatewayLike(Protocol):
async def chat(self, **kwargs: object) -> object: ...
class PipelineReflector:
"""分析 Pipeline 执行失败原因,生成结构化反思报告。
@ -27,7 +35,7 @@ class PipelineReflector:
输出 ReflectionReport 包含 failure_typeroot_cause suggested_fix
"""
def __init__(self, llm_gateway: Any = None):
def __init__(self, llm_gateway: _LLMGatewayLike | None = None):
self._llm_gateway = llm_gateway
async def reflect(
@ -54,19 +62,25 @@ class PipelineReflector:
if self._llm_gateway is not None:
try:
return await self._llm_reflect(
pipeline, failed_stage, error_message,
completed_outputs, reflection_number,
pipeline,
failed_stage,
error_message,
completed_outputs,
reflection_number,
)
except Exception as e:
logger.warning(f"LLM reflection failed, falling back to rule-based: {e}")
# 规则兜底:基于错误信息分类
return self._rule_based_reflect(
failed_stage, error_message, reflection_number,
failed_stage,
error_message,
reflection_number,
)
def _find_failure(
self, result: PipelineResult,
self,
result: PipelineResult,
) -> tuple[str, str]:
"""找到第一个失败的 stage 及其错误信息。"""
for name, sr in result.stage_results.items():
@ -75,8 +89,9 @@ class PipelineReflector:
return "", "no failed stage found"
def _collect_completed_outputs(
self, result: PipelineResult,
) -> dict[str, Any]:
self,
result: PipelineResult,
) -> dict[str, object]:
"""收集已完成步骤的输出。"""
outputs = {}
for name, sr in result.stage_results.items():
@ -89,13 +104,16 @@ class PipelineReflector:
pipeline: Pipeline,
failed_stage: str,
error_message: str,
completed_outputs: dict[str, Any],
completed_outputs: dict[str, object],
reflection_number: int,
) -> ReflectionReport:
"""使用 LLM 分析失败原因。"""
prompt = self._build_reflection_prompt(
pipeline, failed_stage, error_message,
completed_outputs, reflection_number,
pipeline,
failed_stage,
error_message,
completed_outputs,
reflection_number,
)
response = await self._llm_gateway.chat(
@ -106,7 +124,9 @@ class PipelineReflector:
# 解析 LLM 返回的 JSON
content = response.content if hasattr(response, "content") else str(response)
return self._parse_reflection_response(
content, failed_stage, reflection_number,
content,
failed_stage,
reflection_number,
)
def _build_reflection_prompt(
@ -114,15 +134,14 @@ class PipelineReflector:
pipeline: Pipeline,
failed_stage: str,
error_message: str,
completed_outputs: dict[str, Any],
completed_outputs: dict[str, object],
reflection_number: int,
) -> str:
"""构建反思提示词。"""
stage_descriptions = []
for s in pipeline.stages:
stage_descriptions.append(
f" - {s.name}: agent={s.agent}, action={s.action}, "
f"depends_on={s.depends_on}"
f" - {s.name}: agent={s.agent}, action={s.action}, depends_on={s.depends_on}"
)
completed_summary = json.dumps(
@ -174,7 +193,9 @@ JSON response:"""
except (json.JSONDecodeError, KeyError) as e:
logger.warning(f"Failed to parse LLM reflection response: {e}")
return self._rule_based_reflect(
failed_stage, content, reflection_number,
failed_stage,
content,
reflection_number,
)
def _rule_based_reflect(
@ -218,7 +239,7 @@ class PipelineReplanner:
保留已完成步骤的结果仅重新规划失败及后续步骤
"""
def __init__(self, llm_gateway: Any = None):
def __init__(self, llm_gateway: _LLMGatewayLike | None = None):
self._llm_gateway = llm_gateway
async def replan(
@ -255,8 +276,7 @@ class PipelineReplanner:
) -> Pipeline:
"""使用 LLM 生成修正后的 Pipeline。"""
completed_stages = [
name for name, sr in result.stage_results.items()
if sr.status == StageStatus.COMPLETED
name for name, sr in result.stage_results.items() if sr.status == StageStatus.COMPLETED
]
prompt = f"""Based on the reflection report, generate a corrected pipeline.
@ -284,7 +304,9 @@ JSON pipeline:"""
return self._parse_pipeline_response(content, pipeline)
def _parse_pipeline_response(
self, content: str, original: Pipeline,
self,
content: str,
original: Pipeline,
) -> Pipeline:
"""解析 LLM 返回的 Pipeline JSON。"""
try:
@ -294,9 +316,7 @@ JSON pipeline:"""
text = "\n".join(lines[1:-1])
data = json.loads(text)
stages = [
PipelineStage(**s) for s in data.get("stages", [])
]
stages = [PipelineStage(**s) for s in data.get("stages", [])]
return Pipeline(
name=data.get("name", original.name),
version=data.get("version", original.version),
@ -316,8 +336,7 @@ JSON pipeline:"""
) -> Pipeline:
"""基于规则的兜底重规划。"""
completed_stages = {
name for name, sr in result.stage_results.items()
if sr.status == StageStatus.COMPLETED
name for name, sr in result.stage_results.items() if sr.status == StageStatus.COMPLETED
}
# 构建修正后的 stages 列表
@ -345,17 +364,21 @@ JSON pipeline:"""
)
def _adjust_failed_stage(
self, stage: PipelineStage, report: ReflectionReport,
self,
stage: PipelineStage,
report: ReflectionReport,
) -> PipelineStage:
"""根据反思报告调整失败的步骤。"""
adjustments: dict[str, Any] = {}
adjustments: dict[str, object] = {}
if report.failure_type == "timeout":
adjustments["timeout_seconds"] = min(
stage.timeout_seconds * 2, 3600,
stage.timeout_seconds * 2,
3600,
)
if stage.retry_policy is None:
from agentkit.orchestrator.retry import StepRetryPolicy
adjustments["retry_policy"] = StepRetryPolicy(max_attempts=2)
elif report.failure_type == "resource_error":
@ -365,6 +388,7 @@ JSON pipeline:"""
# 添加重试策略,可能输入在后续可用
if stage.retry_policy is None:
from agentkit.orchestrator.retry import StepRetryPolicy
adjustments["retry_policy"] = StepRetryPolicy(max_attempts=2)
return stage.model_copy(update=adjustments)

View File

@ -4,7 +4,7 @@ import asyncio
import logging
import random
from dataclasses import dataclass
from typing import Any, Awaitable, Callable
from typing import Awaitable, Callable
logger = logging.getLogger(__name__)
@ -27,7 +27,7 @@ class StepRetryPolicy:
def calculate_delay(self, attempt: int) -> float:
"""Calculate delay for given attempt number (0-based)"""
delay = min(
self.base_delay * (self.exponential_base ** attempt),
self.base_delay * (self.exponential_base**attempt),
self.max_delay,
)
if self.jitter:
@ -36,10 +36,10 @@ class StepRetryPolicy:
async def execute_with_retry(
func: Callable[..., Awaitable[Any]],
func: Callable[..., Awaitable[object]],
retry_policy: StepRetryPolicy | None = None,
step_name: str = "",
) -> Any:
) -> object:
"""Execute a function with retry policy"""
if retry_policy is None:
return await func()

View File

@ -3,18 +3,17 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any
from pydantic import BaseModel, Field
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage
from agentkit.orchestrator.pipeline_schema import PipelineStage
class WorkflowStage(PipelineStage):
"""A workflow stage extending PipelineStage with type and config."""
type: str = "skill" # "skill" | "condition" | "approval" | "parallel"
config: dict[str, Any] = Field(default_factory=dict)
config: dict[str, object] = Field(default_factory=dict)
class WorkflowDefinition(BaseModel):
@ -24,9 +23,9 @@ class WorkflowDefinition(BaseModel):
name: str
version: int = 1
stages: list[WorkflowStage] = Field(default_factory=list)
triggers: list[dict[str, Any]] = Field(default_factory=list)
variables_schema: dict[str, Any] = Field(default_factory=dict)
output_schema: dict[str, Any] = Field(default_factory=dict)
triggers: list[dict[str, object]] = Field(default_factory=list)
variables_schema: dict[str, object] = Field(default_factory=dict)
output_schema: dict[str, object] = Field(default_factory=dict)
created_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
updated_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
@ -38,11 +37,11 @@ class WorkflowExecution(BaseModel):
workflow_id: str = ""
status: str = "pending" # pending|running|paused|completed|failed|cancelled
current_stage: str | None = None
stage_results: dict[str, Any] = Field(default_factory=dict)
stage_results: dict[str, object] = Field(default_factory=dict)
started_at: str | None = None
completed_at: str | None = None
error: str | None = None
variables: dict[str, Any] = Field(default_factory=dict)
variables: dict[str, object] = Field(default_factory=dict)
class WorkflowSummary(BaseModel):
@ -62,15 +61,15 @@ class CreateWorkflowRequest(BaseModel):
name: str
stages: list[WorkflowStage] = Field(default_factory=list)
triggers: list[dict[str, Any]] = Field(default_factory=list)
variables_schema: dict[str, Any] = Field(default_factory=dict)
output_schema: dict[str, Any] = Field(default_factory=dict)
triggers: list[dict[str, object]] = Field(default_factory=list)
variables_schema: dict[str, object] = Field(default_factory=dict)
output_schema: dict[str, object] = Field(default_factory=dict)
class ExecuteWorkflowRequest(BaseModel):
"""Request body for executing a workflow."""
variables: dict[str, Any] = Field(default_factory=dict)
variables: dict[str, object] = Field(default_factory=dict)
class ApproveRequest(BaseModel):

View File

@ -37,7 +37,7 @@ async def create_agent(request: CreateAgentRequest, req: Request):
config_dict = request.config
try:
config = SkillConfig.from_dict(config_dict)
except Exception:
except (ValueError, KeyError, TypeError):
config = AgentConfig.from_dict(config_dict)
agent = await pool.create_agent(config)
else:

View File

@ -53,6 +53,8 @@ from agentkit.server.auth.session_service import (
REVOKE_REASON_PASSWORD_CHANGED,
REVOKE_REASON_USER_TERMINATED,
SessionCreate,
SessionNotFound,
SessionReuseDetected,
SessionService,
get_session_service,
)
@ -253,7 +255,7 @@ def _is_legacy_client(request: Request) -> bool:
if client_v is None or cutoff_v is None:
return False
return client_v < cutoff_v
except Exception: # noqa: BLE001
except (ValueError, TypeError, AttributeError): # noqa: BLE001
logger.debug("Failed to parse X-Client-Version %r", raw)
return False
@ -492,7 +494,7 @@ async def refresh(payload: RefreshRequest, request: Request) -> TokenResponse:
# 1. Verify signature + type
try:
refresh_payload = verify_token(payload.refresh_token, secret, expected_type="refresh")
except Exception as exc: # noqa: BLE001
except (jwt.PyJWTError, ValueError, KeyError) as exc: # noqa: BLE001
raise HTTPException(status_code=401, detail="Invalid refresh token") from exc
# 2-3. Validate the session (also handles reuse detection)
@ -510,7 +512,7 @@ async def refresh(payload: RefreshRequest, request: Request) -> TokenResponse:
new_refresh_token=new_pair.refresh_token,
new_ttl_seconds=int(REFRESH_TOKEN_TTL.total_seconds()),
)
except Exception as exc: # noqa: BLE001 — SessionReuseDetected / SessionNotFound
except (SessionReuseDetected, SessionNotFound, ValueError, KeyError, RuntimeError) as exc: # noqa: BLE001 — SessionReuseDetected / SessionNotFound
logger.info("Refresh rejected: %s", exc)
raise HTTPException(status_code=401, detail="Invalid refresh token") from exc

View File

@ -483,7 +483,7 @@ async def validate_formula(
parse_formula(body.formula)
except (FormulaParseError, FormulaSecurityError, UnknownFunctionError) as e:
return {"valid": False, "error": str(e)}
except Exception as e: # pragma: no cover — defensive
except (ValueError, TypeError, KeyError, AttributeError) as e: # pragma: no cover — defensive
return {"valid": False, "error": f"Unexpected error: {e}"}
return {"valid": True}
@ -750,7 +750,7 @@ async def upload_file(
f.write(chunk)
except HTTPException:
raise
except Exception as exc:
except (OSError, RuntimeError) as exc:
file_path.unlink(missing_ok=True)
logger.error(f"Failed to save uploaded bitable file: {exc}")
raise HTTPException(status_code=500, detail="Failed to save file") from exc

View File

@ -553,7 +553,7 @@ async def _invalidate_adapter_cache(channel_id: str) -> None:
if old is not None:
try:
await old.close()
except Exception: # noqa: BLE001 — 关闭异常不应阻塞配置变更
except (ConnectionError, RuntimeError, OSError, asyncio.TimeoutError): # noqa: BLE001 — 关闭异常不应阻塞配置变更
logger.debug("关闭旧适配器异常已忽略: channel_id=%s", channel_id)
@ -562,7 +562,7 @@ async def close_all_adapters() -> None:
for channel_id, adapter in list(_adapter_cache.items()):
try:
await adapter.close()
except Exception: # noqa: BLE001
except (ConnectionError, RuntimeError, OSError, asyncio.TimeoutError): # noqa: BLE001
logger.debug("关闭适配器异常已忽略: channel_id=%s", channel_id)
_adapter_cache.clear()
@ -614,6 +614,8 @@ async def _process_inbound_message(app_state: Any, adapter: MessageAdapter, mess
model=routing.model or "default",
)
final_content = getattr(result, "content", "") or ""
except asyncio.CancelledError:
raise
except Exception as exc: # noqa: BLE001 — 回退路径需捕获全部异常
logger.warning("ReActEngine 执行失败,回退到 DIRECT_CHAT: %s", exc)
final_content = await _direct_chat(llm_gateway, routing)
@ -628,6 +630,8 @@ async def _process_inbound_message(app_state: Any, adapter: MessageAdapter, mess
content=final_content,
)
await adapter.send_message(outgoing)
except asyncio.CancelledError:
raise
except Exception as exc: # noqa: BLE001 — webhook 必须保持响应能力
logger.exception("处理入站消息失败: %s", exc)
@ -703,7 +707,7 @@ async def channel_webhook(channel_id: str, request: Request) -> Any:
except WeComURLVerification as e:
# 企微 URL 验证 — 返回 XML 响应
return Response(content=e.response_xml, media_type="application/xml")
except Exception as exc: # noqa: BLE001 — 防止 receive_message 异常导致 500 触发平台重试风暴
except (ValueError, KeyError, RuntimeError, AttributeError, OSError) as exc: # noqa: BLE001 — 防止 receive_message 异常导致 500 触发平台重试风暴
logger.warning("receive_message 解析失败 channel=%s: %s", channel_id, exc)
return {"code": 0, "msg": "invalid_payload"}

View File

@ -108,7 +108,7 @@ class ChatConnectionManager:
for ws, _ in conns:
try:
await ws.send_json(message)
except Exception:
except (ConnectionError, RuntimeError, asyncio.TimeoutError):
stale.append(ws)
for ws in stale:
self.remove(session_id, ws)
@ -295,6 +295,8 @@ async def _execute_board_meeting(
await team.create_board(topic=routing_result.topic, expert_configs=expert_configs)
orchestrator = BoardOrchestrator(team=team)
result = await orchestrator.execute(routing_result.topic)
except asyncio.CancelledError:
raise
except Exception as e:
logger.error(f"Board meeting failed for session {session_id}: {e}", exc_info=True)
await websocket.send_json(
@ -302,7 +304,7 @@ async def _execute_board_meeting(
)
try:
await team.dissolve()
except Exception:
except (RuntimeError, asyncio.TimeoutError, ConnectionError):
pass
return True
finally:
@ -348,7 +350,7 @@ async def _execute_board_meeting(
# Dissolve the team to release expert agents
try:
await team.dissolve()
except Exception as e:
except (RuntimeError, asyncio.TimeoutError, ConnectionError) as e:
logger.warning(f"Board team dissolve failed: {e}")
return True
@ -467,7 +469,7 @@ async def _execute_team_collab(
# Always dissolve the team and remove handler to avoid leaks
try:
await team.dissolve()
except Exception as e:
except (RuntimeError, asyncio.TimeoutError, ConnectionError) as e:
logger.warning(f"Team dissolve failed: {e}")
# dissolve() already clears handlers via handoff_transport.close()
@ -585,7 +587,7 @@ def _build_phase_engine(
if phase_policy is None:
# Empty config (no `plan_exec:` section) → use KTD5 defaults.
phase_policy = default_policy()
except Exception as e:
except (ValueError, TypeError, KeyError) as e:
logger.error(
"PLAN_EXEC phase policy construction failed for session %s: %s",
session_id,
@ -695,6 +697,8 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
agent_name=agent.name,
system_prompt=system_prompt,
)
except asyncio.CancelledError:
raise
except Exception as e:
logger.error(f"PLAN_EXEC execution error for session {session_id}: {e}")
raise HTTPException(status_code=500, detail=str(e))
@ -773,6 +777,8 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
return response_dict
return response
except asyncio.CancelledError:
raise
except Exception as e:
logger.error(f"Chat execution error for session {session_id}: {e}")
raise HTTPException(status_code=500, detail=str(e))
@ -829,7 +835,7 @@ async def _resolve_ws_dept_context(
db_path_resolved = Path(db_path)
try:
department_ids = await _fetch_user_department_ids(db_path_resolved, user_id)
except Exception:
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
logger.exception(
"Failed to fetch department ids for WebSocket user %s — fail-closed",
user_id,
@ -946,7 +952,7 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
"data": {"content": content},
}
)
except Exception as e:
except (asyncio.QueueFull, RuntimeError, ConnectionError) as e:
logger.warning(f"Failed to enqueue intervention: {e}")
await websocket.send_json(
{
@ -1022,11 +1028,13 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
except WebSocketDisconnect:
logger.debug(f"Chat WebSocket disconnected for session {session_id}")
except asyncio.CancelledError:
raise
except Exception as e:
logger.error(f"Chat WebSocket error for session {session_id}: {e}")
try:
await websocket.send_json({"type": "error", "data": {"message": str(e)}})
except Exception:
except (ConnectionError, RuntimeError, asyncio.TimeoutError):
pass
finally:
# Clean up pending futures
@ -1174,6 +1182,8 @@ async def _handle_chat_message(
content=final_content,
agent_name=agent.name,
)
except asyncio.CancelledError:
raise
except Exception as e:
# Check if this is a QuotaExceededError (U4: WebSocket quota).
from agentkit.llm.gateway import QuotaExceededError
@ -1422,6 +1432,8 @@ async def _handle_chat_message(
agent_name=agent.name,
)
except asyncio.CancelledError:
raise
except Exception as e:
logger.error(f"Chat execution error for session {session_id}: {e}")
# Show meaningful error to user, but avoid leaking full stack traces
@ -1473,6 +1485,8 @@ async def upload_chat_file(file: UploadFile = File(...)) -> dict[str, Any]:
file_path.write_bytes(contents)
except HTTPException:
raise
except asyncio.CancelledError:
raise
except Exception as exc:
logger.error(f"Failed to save uploaded file: {exc}")
raise HTTPException(status_code=500, detail="Failed to save file") from exc

View File

@ -64,7 +64,7 @@ def _collect_skill_configs(request: Request) -> list[dict[str, Any]]:
"version": getattr(skill.config, "version", "1.0.0"),
"config": config_dict,
})
except Exception as e:
except (AttributeError, ValueError, KeyError, RuntimeError) as e:
logger.warning(f"Failed to collect skill configs: {e}")
return configs
@ -81,7 +81,7 @@ def _collect_workflow_configs(request: Request) -> list[dict[str, Any]]:
from agentkit.server.routes.workflows import _workflow_store
workflow_store = _workflow_store
except Exception:
except (ImportError, AttributeError):
return []
configs: list[dict[str, Any]] = []
@ -94,7 +94,7 @@ def _collect_workflow_configs(request: Request) -> list[dict[str, Any]]:
configs.append(wf.to_dict())
else:
configs.append(dict(wf))
except Exception as e:
except (RuntimeError, AttributeError, ValueError) as e:
logger.warning(f"Failed to collect workflow configs: {e}")
return configs

View File

@ -13,6 +13,7 @@ Endpoints:
from __future__ import annotations
import asyncio
import hmac
import logging
import uuid
@ -155,6 +156,8 @@ async def create_document(
raise HTTPException(status_code=400, detail=str(e)) from e
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e)) from e
except asyncio.CancelledError:
raise
except Exception as e:
logger.error(f"Document creation failed: {e}")
raise HTTPException(status_code=500, detail="Document creation failed") from e
@ -187,7 +190,7 @@ async def upload_template(
file_path.write_bytes(contents)
except HTTPException:
raise
except Exception as exc:
except (OSError, RuntimeError) as exc:
logger.error(f"Failed to save template: {exc}")
raise HTTPException(status_code=500, detail="Failed to save template") from exc
finally:

View File

@ -1,5 +1,6 @@
"""Evolution API routes"""
import asyncio
import logging
from fastapi import APIRouter, HTTPException, Request
@ -42,6 +43,8 @@ async def list_evolution_events(
agent_name=agent_name,
change_type=event_type,
)
except asyncio.CancelledError:
raise
except Exception as e:
logger.error(f"Failed to list evolution events: {e}")
raise HTTPException(status_code=500, detail="Failed to list evolution events")
@ -63,6 +66,8 @@ async def get_skill_versions(skill_name: str, req: Request = None):
store = _get_evolution_store(req)
try:
versions = await store.list_skill_versions(skill_name)
except asyncio.CancelledError:
raise
except Exception as e:
logger.error(f"Failed to get skill versions for '{skill_name}': {e}")
raise HTTPException(status_code=500, detail="Failed to get skill versions")
@ -103,6 +108,8 @@ async def trigger_evolution(request: TriggerEvolutionRequest, req: Request = Non
)
try:
event_id = await store.record(event)
except asyncio.CancelledError:
raise
except Exception as e:
logger.error(f"Failed to record trigger event: {e}")
raise HTTPException(status_code=500, detail="Failed to trigger evolution")
@ -153,6 +160,8 @@ async def list_ab_tests(
for e in entries
]
return {"items": results[:limit], "total": len(results)}
except asyncio.CancelledError:
raise
except Exception as e:
logger.error(f"Failed to list A/B tests from persistent store: {e}")
raise HTTPException(status_code=500, detail="Failed to list A/B tests")

View File

@ -194,7 +194,7 @@ async def list_experiences(
}
)
return {"experiences": experiences, "total": len(experiences)}
except Exception as e:
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError) as e:
logger.error(f"Failed to list experiences from store: {e}")
# Fallback to in-memory store
@ -324,7 +324,7 @@ async def get_metrics(
# Generate daily trends from the metrics
trends = _generate_trends(metrics_list, period)
except Exception as e:
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError) as e:
logger.error(f"Failed to get metrics from store: {e}")
else:
# Generate from in-memory experiences
@ -501,7 +501,7 @@ async def get_usage(
"errors": 0,
"avg_latency_ms": round(d["latency"] / max(d["requests"], 1), 1),
})
except Exception as e:
except (ConnectionError, OSError, ValueError, KeyError, RuntimeError, AttributeError) as e:
logger.error(f"Failed to get usage from LLMGateway: {e}")
# Fill in missing dates with zero
@ -587,7 +587,7 @@ async def check_pitfalls(
})
return {"warnings": warnings_data}
except Exception as e:
except (RuntimeError, ValueError, KeyError, AttributeError, asyncio.TimeoutError, ConnectionError) as e:
logger.error(f"Failed to check pitfalls: {e}")
return {"warnings": []}
@ -642,7 +642,7 @@ async def list_path_optimizations(
else None,
}
)
except Exception as e:
except (ValueError, KeyError, RuntimeError, AttributeError) as e:
logger.error(f"Failed to get path optimizations: {e}")
# Also include in-memory optimizations
@ -767,7 +767,7 @@ async def evolution_dashboard_ws(websocket: WebSocket):
)
except WebSocketDisconnect:
logger.debug("Evolution dashboard WebSocket disconnected")
except Exception as e:
except (RuntimeError, asyncio.TimeoutError, ConnectionError) as e:
logger.error(f"Evolution dashboard WebSocket error: {e}")
finally:
if websocket in _ws_connections:
@ -781,7 +781,7 @@ async def _broadcast_event(event_type: str, data: dict):
for ws in _ws_connections:
try:
await ws.send_json(message)
except Exception:
except (ConnectionError, RuntimeError, asyncio.TimeoutError):
disconnected.append(ws)
for ws in disconnected:
if ws in _ws_connections:

View File

@ -1,5 +1,7 @@
"""Health check route"""
import asyncio
from fastapi import APIRouter, Request
router = APIRouter(tags=["health"])
@ -23,14 +25,14 @@ async def health_check(request: Request):
redis_client = await task_store._get_redis()
await redis_client.ping()
redis_status = "available"
except Exception as ping_exc:
except (ConnectionError, OSError, asyncio.TimeoutError, RuntimeError) as ping_exc:
redis_status = f"error: {str(ping_exc)[:100]}"
overall_status = "degraded"
else:
redis_status = "not_configured"
else:
redis_status = "not_configured"
except Exception as exc:
except (ConnectionError, OSError, asyncio.TimeoutError, RuntimeError, AttributeError) as exc:
redis_status = f"error: {str(exc)[:100]}"
overall_status = "degraded"
checks["redis"] = redis_status
@ -42,7 +44,7 @@ async def health_check(request: Request):
try:
agents = agent_pool.list_agents()
pool_size = len(agents)
except Exception:
except (RuntimeError, AttributeError):
pass
checks["agent_pool"] = {"status": "available", "size": pool_size}
@ -57,7 +59,7 @@ async def health_check(request: Request):
else:
llm_status = "no_providers"
overall_status = "degraded"
except Exception:
except (RuntimeError, AttributeError, ValueError):
llm_status = "error"
overall_status = "degraded"
checks["llm_gateway"] = llm_status
@ -68,7 +70,7 @@ async def health_check(request: Request):
if skill_registry:
try:
skill_count = len(skill_registry.list_skills())
except Exception:
except (RuntimeError, AttributeError):
pass
checks["skill_registry"] = {
"status": "available" if skill_registry else "not_configured",

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
import hmac
import logging
import os
@ -252,7 +253,7 @@ async def list_sources(
visible_ids = await filter_kb_sources_by_department(
db_path, dept_ctx.department_ids, all_ids
)
except Exception: # noqa: BLE001 — never block listing on DB errors
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError): # noqa: BLE001 — never block listing on DB errors
logger.exception("Department KB filtering failed — returning empty list")
return {"sources": []}
visible_set = set(visible_ids)
@ -398,7 +399,7 @@ async def list_documents(
visible_ids = await filter_kb_sources_by_department(
db_path, dept_ctx.department_ids, all_source_ids
)
except Exception: # noqa: BLE001 — never block listing on DB errors
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError): # noqa: BLE001 — never block listing on DB errors
logger.exception("Department KB filtering failed — returning empty list")
return {"documents": []}
visible_set = set(visible_ids)
@ -484,7 +485,7 @@ async def upload_document(
try:
text = processor.parse(tmp_path, file_type)
chunks = processor.segment(text)
except Exception as e:
except (ValueError, OSError, RuntimeError, UnicodeDecodeError) as e:
logger.warning("Document parsing failed: %s", e)
raise HTTPException(status_code=422, detail=f"Document parsing failed: {e}") from e
@ -567,7 +568,7 @@ async def preview_document(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
except Exception as e:
except (ValueError, OSError, RuntimeError, UnicodeDecodeError) as e:
logger.warning("Document preview failed: %s", e)
raise HTTPException(status_code=422, detail=f"Document preview failed: {e}") from e
@ -628,7 +629,7 @@ async def search_knowledge(
for r in results
]
}
except Exception as e:
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError) as e:
logger.warning(f"Semantic search failed: {e}")
# Fallback: return empty results with a hint

View File

@ -115,7 +115,7 @@ async def publish_skill(
try:
skill = skill_registry.get(skill_name)
except Exception:
except (KeyError, ValueError, AttributeError):
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
try:

View File

@ -1,5 +1,6 @@
"""Memory API routes"""
import asyncio
import logging
from fastapi import APIRouter, HTTPException, Request
@ -40,6 +41,8 @@ async def search_episodic_memory(
if agent_name:
filters["agent_name"] = agent_name
items = await retriever._episodic.search(query, top_k=top_k, filters=filters or None)
except asyncio.CancelledError:
raise
except Exception as e:
logger.error(f"Failed to search episodic memory: {e}")
raise HTTPException(status_code=500, detail="Failed to search episodic memory")
@ -76,6 +79,8 @@ async def search_semantic_memory(
if knowledge_base_ids:
filters["knowledge_base_ids"] = [kid.strip() for kid in knowledge_base_ids.split(",")]
items = await retriever._semantic.search(query, top_k=top_k, filters=filters or None)
except asyncio.CancelledError:
raise
except Exception as e:
logger.error(f"Failed to search semantic memory: {e}")
raise HTTPException(status_code=500, detail="Failed to search semantic memory")
@ -104,6 +109,8 @@ async def delete_episodic_memory(key: str, req: Request = None):
try:
deleted = await retriever._episodic.delete(key)
except asyncio.CancelledError:
raise
except Exception as e:
logger.error(f"Failed to delete episodic memory '{key}': {e}")
raise HTTPException(status_code=500, detail="Failed to delete episodic memory")

View File

@ -1,5 +1,6 @@
"""Metrics route — /api/v1/metrics"""
import asyncio
import logging
from fastapi import APIRouter, Request
@ -29,7 +30,7 @@ async def get_metrics(request: Request):
task_metrics["completed_tasks"] = counts.get("completed", 0)
task_metrics["failed_tasks"] = counts.get("failed", 0)
task_metrics["pending_tasks"] = counts.get("pending", 0)
except Exception as e:
except (RuntimeError, AttributeError, ConnectionError, asyncio.TimeoutError) as e:
logger.warning(f"Failed to collect task metrics: {e}")
# Agent pool metrics
@ -41,7 +42,7 @@ async def get_metrics(request: Request):
try:
agents = agent_pool.list_agents()
agent_metrics["total_agents"] = len(agents)
except Exception as e:
except (RuntimeError, AttributeError) as e:
logger.warning(f"Failed to collect agent metrics: {e}")
# Skill registry metrics
@ -53,7 +54,7 @@ async def get_metrics(request: Request):
try:
skills = skill_registry.list_skills()
skill_metrics["total_skills"] = len(skills)
except Exception as e:
except (RuntimeError, AttributeError) as e:
logger.warning(f"Failed to collect skill metrics: {e}")
return {

View File

@ -94,7 +94,7 @@ async def _emit_event_safe(
data=data or {},
)
await event_queue.emit(event)
except Exception as e:
except (asyncio.QueueFull, RuntimeError, ConnectionError) as e:
logger.warning(f"EventQueue emit failed (type={event_type}): {e}", exc_info=True)
@ -211,7 +211,7 @@ class PortalConnectionManager:
import asyncio
asyncio.create_task(oldest.close(code=1008, reason="Connection limit exceeded"))
except Exception:
except (ConnectionError, RuntimeError):
pass
conns.append(ws)
@ -235,7 +235,7 @@ class PortalConnectionManager:
for ws in conns:
try:
await ws.send_json(message)
except Exception as e:
except (ConnectionError, RuntimeError, asyncio.TimeoutError) as e:
logger.debug(
"Portal WS send failed for user %s (marking stale): %s", user_id, e
)
@ -285,7 +285,7 @@ async def _build_history_messages(
"""
try:
history = await _conversation_store.get_history(conv_id, limit=limit)
except Exception:
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
return []
# The last message in history is the current user message (just added),
@ -553,6 +553,8 @@ async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify
):
if event.event_type == "final_answer":
collected_output.append(event.data.get("output", ""))
except asyncio.CancelledError:
raise
except Exception as e:
response_text = f"执行出错: {e}"
else:
@ -682,6 +684,8 @@ async def chat_stream(request: ChatRequest, req: Request, _auth: None = Depends(
}
),
}
except asyncio.CancelledError:
raise
except Exception as e:
yield {
"event": "error",
@ -862,7 +866,7 @@ async def _execute_react_background(
await _task_store_update_status(
task_store, task_id, TaskStatus.RUNNING, started_at=datetime.now(timezone.utc)
)
except Exception:
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
logger.warning("Failed to update TaskStore RUNNING", exc_info=True)
async for event in react_engine.execute_stream(
@ -909,7 +913,7 @@ async def _execute_react_background(
progress=1.0,
progress_message="Completed",
)
except Exception:
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
logger.warning("Failed to update TaskStore COMPLETED", exc_info=True)
# Emit task.completed so subscribers know the task is done
@ -932,7 +936,15 @@ async def _execute_react_background(
partial = _ensure_non_empty("".join(collected_output))
try:
await asyncio.shield(conversation_store.add_message(conv_id, "assistant", partial))
except (Exception, asyncio.CancelledError):
except (
asyncio.CancelledError,
ConnectionError,
OSError,
asyncio.TimeoutError,
ValueError,
KeyError,
RuntimeError,
):
logger.warning("Failed to persist partial output on cancel")
if task_store is not None:
try:
@ -945,7 +957,15 @@ async def _execute_react_background(
completed_at=datetime.now(timezone.utc),
)
)
except (Exception, asyncio.CancelledError):
except (
asyncio.CancelledError,
ConnectionError,
OSError,
asyncio.TimeoutError,
ValueError,
KeyError,
RuntimeError,
):
logger.warning("Failed to update TaskStore on cancel", exc_info=True)
# P0 #2 fix: _emit_event_safe is async (it awaits event_queue.emit).
# Shield it so a re-entrant CancelledError doesn't kill the emit
@ -963,7 +983,7 @@ async def _execute_react_background(
},
)
)
except (Exception, asyncio.CancelledError):
except (asyncio.CancelledError, asyncio.QueueFull, RuntimeError, ConnectionError):
logger.warning("Failed to emit TASK_FAILED on cancel")
raise # Propagate cancellation
@ -973,7 +993,7 @@ async def _execute_react_background(
partial = _ensure_non_empty("".join(collected_output))
try:
await conversation_store.add_message(conv_id, "assistant", partial)
except Exception:
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
logger.warning("Failed to persist partial output in background task")
if task_store is not None:
@ -985,7 +1005,7 @@ async def _execute_react_background(
error_message=str(e),
completed_at=datetime.now(timezone.utc),
)
except Exception:
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
logger.warning("Failed to update TaskStore FAILED", exc_info=True)
# Emit task.failed so subscribers know the task failed
@ -1120,7 +1140,7 @@ async def portal_websocket(websocket: WebSocket):
try:
record = await _task_store_get(resume_task_store, resume_task_id)
except Exception:
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
logger.warning("TaskStore.get failed during resume", exc_info=True)
record = None
if record is not None:
@ -1333,7 +1353,7 @@ async def portal_websocket(websocket: WebSocket):
},
)
await _broadcast_dashboard_event("metrics_updated", {"period": "7d"})
except Exception as e:
except (asyncio.QueueFull, RuntimeError, ConnectionError, ValueError, KeyError) as e:
logger.warning(f"Failed to record experience: {e}")
# Unified preprocessing via RequestPreprocessor (minimal: @skill prefix + greeting regex + REACT)
@ -1414,7 +1434,7 @@ async def portal_websocket(websocket: WebSocket):
TaskStatus.PENDING,
metadata={"conversation_id": conv.id},
)
except Exception:
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
logger.warning("Failed to register task in TaskStore", exc_info=True)
# Execute based on routing result's execution_mode
@ -1455,7 +1475,7 @@ async def portal_websocket(websocket: WebSocket):
progress=1.0,
progress_message="Completed",
)
except Exception:
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
logger.warning("Failed to update TaskStore for DIRECT_CHAT", exc_info=True)
# Emit turn.final_answer and task.completed to EQ
@ -1526,7 +1546,7 @@ async def portal_websocket(websocket: WebSocket):
chat_messages.insert(
-1, {"role": hist_msg.role, "content": hist_msg.content}
)
except Exception:
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
pass
response = await llm_gateway.chat(
messages=chat_messages,
@ -1627,7 +1647,7 @@ async def portal_websocket(websocket: WebSocket):
logger.warning("EventQueue not configured; awaiting background task directly")
try:
await bg_task
except Exception:
except (RuntimeError, ConnectionError, asyncio.TimeoutError):
pass # errors handled inside _execute_react_background
active_bg_task = None
continue
@ -1734,6 +1754,8 @@ async def portal_websocket(websocket: WebSocket):
# kill the task, lose the full output, and mark it FAILED —
# defeating layers 2 and 3. The task is only cancelled on explicit
# user cancel (msg_type == 'cancel') or application shutdown.
except asyncio.CancelledError:
raise
except Exception as e:
logger.error(f"Portal WebSocket error: {e}")
# P1 #6 fix: Do NOT cancel the background task on connection-level
@ -1758,7 +1780,7 @@ async def portal_websocket(websocket: WebSocket):
)
try:
await websocket.send_json({"type": "error", "data": {"message": str(e)}})
except Exception:
except (ConnectionError, RuntimeError, asyncio.TimeoutError):
pass
finally:
# Remove from user-scoped push tracking on any disconnect/error/return.

View File

@ -119,7 +119,7 @@ def _skill_to_detail(skill: Any) -> dict[str, Any]:
if hasattr(skill, "config"):
try:
config = skill.config.to_dict() if hasattr(skill.config, "to_dict") else {}
except Exception:
except (AttributeError, ValueError, TypeError):
config = {}
return {
@ -174,7 +174,7 @@ async def get_skill_detail(skill_name: str, req: Request):
skill_registry = req.app.state.skill_registry
try:
skill = skill_registry.get(skill_name)
except Exception:
except (KeyError, ValueError, AttributeError):
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
return _skill_to_detail(skill)
@ -186,7 +186,7 @@ async def check_skill_health(skill_name: str, req: Request):
skill_registry = req.app.state.skill_registry
try:
skill_registry.get(skill_name)
except Exception:
except (KeyError, ValueError, AttributeError):
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
# Basic health check - skill exists and is registered
@ -243,7 +243,7 @@ async def reload_skill(skill_name: str, req: Request):
# Verify the skill is currently registered (404 if not).
try:
skill_registry.get(skill_name)
except Exception:
except (KeyError, ValueError, AttributeError):
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
# Resolve the skills directory (mirrors routes.skills._get_skills_dir).

View File

@ -136,7 +136,7 @@ async def register_skill(request: RegisterSkillRequest, req: Request):
try:
config = SkillConfig.from_dict(request.config)
except Exception as e:
except (ValueError, TypeError, KeyError) as e:
raise HTTPException(status_code=422, detail=f"Invalid skill config: {e}")
skill = Skill(config=config)
@ -279,7 +279,7 @@ async def install_skill(request: InstallSkillRequest, req: Request):
resp = await client.get(source)
resp.raise_for_status()
yaml_content = resp.text
except Exception as e:
except (httpx.HTTPError, OSError) as e:
raise HTTPException(status_code=400, detail=f"Failed to download from source: {e}")
elif source and source.startswith("file://"):
# Read from local file path
@ -295,7 +295,7 @@ async def install_skill(request: InstallSkillRequest, req: Request):
try:
with open(local_path, encoding="utf-8") as f:
yaml_content = f.read()
except Exception as e:
except OSError as e:
raise HTTPException(status_code=400, detail=f"Failed to read local file: {e}")
else:
# Search GitHub for skills (YAML config files)
@ -313,7 +313,7 @@ async def install_skill(request: InstallSkillRequest, req: Request):
},
)
gh_data = gh_resp.json()
except Exception as e:
except (httpx.HTTPError, OSError, ValueError) as e:
raise HTTPException(status_code=502, detail=f"GitHub search failed: {e}")
items = gh_data.get("items", [])
@ -334,7 +334,7 @@ async def install_skill(request: InstallSkillRequest, req: Request):
},
)
items = gh_resp2.json().get("items", [])
except Exception:
except (httpx.HTTPError, OSError, ValueError, KeyError):
items = []
if not items:
@ -362,7 +362,7 @@ async def install_skill(request: InstallSkillRequest, req: Request):
resp = await client.get(raw_url)
resp.raise_for_status()
yaml_content = resp.text
except Exception as e:
except (httpx.HTTPError, OSError) as e:
raise HTTPException(status_code=400, detail=f"Failed to download skill: {e}")
# Validate YAML content before writing to disk
@ -391,14 +391,14 @@ async def install_skill(request: InstallSkillRequest, req: Request):
)
loader.load_from_file(file_path)
registration_ok = True
except Exception as e:
except (ValueError, TypeError, KeyError, OSError, RuntimeError) as e:
logger.warning(f"Failed to register installed skill: {e}")
if not registration_ok:
# Remove the invalid YAML file and report error
try:
os.remove(file_path)
except Exception:
except OSError:
pass
raise HTTPException(status_code=500, detail="Skill downloaded but registration failed")
@ -419,7 +419,7 @@ async def uninstall_skill(name: str, req: Request):
try:
skill_registry.get(validated_name)
except Exception:
except (KeyError, ValueError, RuntimeError):
raise HTTPException(status_code=404, detail=f"Skill '{name}' not found")
# Remove from registry
@ -487,7 +487,7 @@ async def execute_pipeline(name: str, request: ExecutePipelineRequest, req: Requ
try:
result = await pipeline.execute(input_data=request.input_data)
except Exception as e:
except (ValueError, TypeError, KeyError, RuntimeError, OSError, ConnectionError) as e:
logger.error(f"Pipeline execution failed for '{name}': {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Pipeline execution failed")

View File

@ -24,7 +24,7 @@ def _read_meminfo() -> dict[str, int]:
key = parts[0].strip()
value = parts[1].strip().split()[0]
values[key] = int(value) * 1024 # kB -> bytes
except Exception as exc:
except (OSError, ValueError, FileNotFoundError) as exc:
logger.debug(f"Failed to read /proc/meminfo: {exc}")
return values
@ -52,7 +52,7 @@ async def get_system_resources() -> dict[str, Any]:
if hasattr(os, "getloadavg"):
try:
loadavg = list(os.getloadavg())
except Exception as exc:
except (OSError, AttributeError) as exc:
logger.debug(f"Failed to get loadavg: {exc}")
meminfo = _read_meminfo()
@ -68,7 +68,7 @@ async def get_system_resources() -> dict[str, Any]:
disk_total = du.total
disk_used = du.used
disk_free = du.free
except Exception as exc:
except (OSError, FileNotFoundError) as exc:
logger.debug(f"Failed to get disk usage: {exc}")
return {

View File

@ -83,7 +83,7 @@ async def submit_task(request: SubmitTaskRequest, req: Request):
elif request.skill_name:
try:
skill = skill_registry.get(request.skill_name)
except Exception:
except (KeyError, ValueError, AttributeError):
raise HTTPException(
status_code=404,
detail=f"Skill '{request.skill_name}' not found",
@ -145,7 +145,7 @@ async def submit_task(request: SubmitTaskRequest, req: Request):
quality_result = await quality_gate.validate(
task_result.output_data or {}, skill, skill_context=skill_context
)
except Exception:
except (RuntimeError, ValueError, KeyError, AttributeError, asyncio.TimeoutError):
pass # Quality gate failure shouldn't block the response
# 7. Standardize output if skill available
@ -167,7 +167,7 @@ async def submit_task(request: SubmitTaskRequest, req: Request):
"task_id": task.task_id,
"status": task_result.status,
}
except Exception:
except (ValueError, KeyError, AttributeError, RuntimeError):
pass # Fall through to raw output
# 8. Return raw result if no skill or standardization failed
@ -307,7 +307,7 @@ async def resume_task(task_id: str, req: Request, plan_id: str | None = None):
finally:
try:
await team.dissolve()
except Exception:
except (RuntimeError, asyncio.TimeoutError, ConnectionError):
pass
return {
@ -343,7 +343,7 @@ async def stream_task(request: SubmitTaskRequest, req: Request):
elif request.skill_name:
try:
skill_registry.get(request.skill_name)
except Exception:
except (KeyError, ValueError, AttributeError):
raise HTTPException(
status_code=404,
detail=f"Skill '{request.skill_name}' not found",

View File

@ -359,14 +359,14 @@ async def execute_command(
await check_pty.close()
if cwd_result.exit_code == 0:
state.cwd = cwd_result.output.strip()
except Exception:
except (OSError, RuntimeError, asyncio.TimeoutError):
pass
except asyncio.TimeoutError:
output = f"命令执行超时({request.timeout}s"
exit_code = -1
duration_ms = int((time.monotonic() - start_time) * 1000)
except Exception as e:
except (OSError, RuntimeError, ValueError) as e:
output = str(e)
exit_code = -1
duration_ms = int((time.monotonic() - start_time) * 1000)
@ -515,7 +515,7 @@ async def terminal_websocket(websocket: WebSocket) -> None:
})
await websocket.close(code=4003, reason="Permission denied")
return
except Exception:
except (ValueError, KeyError, RuntimeError, OSError):
pass # Fall through to API key / dev mode
# 2. API key via ?api_key=
@ -721,7 +721,7 @@ async def terminal_websocket(websocket: WebSocket) -> None:
})
except asyncio.TimeoutError:
continue
except Exception:
except (OSError, RuntimeError):
break
# Get final result
@ -741,7 +741,7 @@ async def terminal_websocket(websocket: WebSocket) -> None:
await check_pty.close()
if cwd_result.exit_code == 0:
state.cwd = cwd_result.output.strip()
except Exception:
except (OSError, RuntimeError, asyncio.TimeoutError):
pass
# Record history
@ -790,6 +790,8 @@ async def terminal_websocket(websocket: WebSocket) -> None:
"cwd": state.cwd,
"duration_ms": duration_ms,
})
except asyncio.CancelledError:
raise
except Exception as e:
duration_ms = int((time.monotonic() - start_time) * 1000)
await websocket.send_json({
@ -800,7 +802,7 @@ async def terminal_websocket(websocket: WebSocket) -> None:
if active_pty and active_pty.is_running:
try:
await active_pty.close()
except Exception:
except (OSError, RuntimeError):
pass
active_pty = None
@ -836,11 +838,13 @@ async def terminal_websocket(websocket: WebSocket) -> None:
except WebSocketDisconnect:
logger.debug(f"Terminal WebSocket disconnected for session {state.session_id}")
except asyncio.CancelledError:
raise
except Exception as e:
logger.error(f"Terminal WebSocket error for session {state.session_id}: {e}")
try:
await websocket.send_json({"type": "error", "message": str(e)[:200]})
except Exception:
except (ConnectionError, RuntimeError, asyncio.TimeoutError):
pass
finally:
# Clean up pending confirmations
@ -851,5 +855,5 @@ async def terminal_websocket(websocket: WebSocket) -> None:
if active_pty and active_pty.is_running:
try:
await active_pty.close()
except Exception:
except (OSError, RuntimeError):
pass

View File

@ -549,7 +549,7 @@ async def server_terminal_websocket(websocket: WebSocket) -> None:
})
await websocket.close(code=4003, reason="Permission denied")
return
except Exception:
except (ValueError, KeyError, RuntimeError, OSError):
pass
if not auth_ok:
@ -585,7 +585,7 @@ async def server_terminal_websocket(websocket: WebSocket) -> None:
})
await websocket.close(code=4003, reason="Not authorized")
return
except Exception as e:
except (aiosqlite.Error, OSError, ValueError, KeyError, RuntimeError) as e:
logger.warning(f"Failed to check server terminal authorization: {e}")
# If DB check fails, deny access
await websocket.send_json({
@ -822,7 +822,7 @@ async def server_terminal_websocket(websocket: WebSocket) -> None:
})
except asyncio.TimeoutError:
continue
except Exception:
except (OSError, RuntimeError):
break
result = await run_task
@ -846,7 +846,7 @@ async def server_terminal_websocket(websocket: WebSocket) -> None:
else:
base = state.cwd or os.path.expanduser("~")
state.cwd = os.path.normpath(os.path.join(base, cd_arg))
except Exception:
except (ValueError, OSError, RuntimeError):
pass # cd parsing failed — keep old cwd
# Record history
@ -893,6 +893,8 @@ async def server_terminal_websocket(websocket: WebSocket) -> None:
"cwd": state.cwd,
"duration_ms": duration_ms,
})
except asyncio.CancelledError:
raise
except Exception as e:
duration_ms = int((time.monotonic() - start_time) * 1000)
await websocket.send_json({
@ -903,7 +905,7 @@ async def server_terminal_websocket(websocket: WebSocket) -> None:
if active_pty and active_pty.is_running:
try:
await active_pty.close()
except Exception:
except (OSError, RuntimeError):
pass
active_pty = None
@ -942,17 +944,19 @@ async def server_terminal_websocket(websocket: WebSocket) -> None:
except WebSocketDisconnect:
logger.debug(f"Server terminal WebSocket disconnected for session {state.session_id}")
except asyncio.CancelledError:
raise
except Exception as e:
logger.error(f"Server terminal WebSocket error for session {state.session_id}: {e}")
try:
await websocket.send_json({"type": "error", "message": str(e)[:200]})
except Exception:
except (ConnectionError, RuntimeError, asyncio.TimeoutError):
pass
finally:
if active_pty and active_pty.is_running:
try:
await active_pty.close()
except Exception:
except (OSError, RuntimeError):
pass
# Clean up session from in-memory stores to prevent leaks
_server_sessions.pop(state.session_id, None)

View File

@ -433,6 +433,8 @@ async def _execute_workflow(
)
result = await agent.handle_task(task)
stage_result = {"output": result, "skill": stage.action}
except asyncio.CancelledError:
raise
except Exception as e:
stage_result = {"error": str(e), "skill": stage.action}
else:
@ -474,6 +476,8 @@ async def _execute_workflow(
)
result = await agent.handle_task(task)
return {"output": result, "skill": action}
except asyncio.CancelledError:
raise
except Exception as e:
return {"error": str(e), "skill": action}
return {"dry_run": True, "action": action}
@ -515,6 +519,8 @@ async def _execute_workflow(
execution_id=execution.execution_id,
)
except asyncio.CancelledError:
raise
except Exception as e:
execution.stage_results[stage_name] = {
"status": "failed",
@ -633,7 +639,7 @@ async def _broadcast_ws(message: dict[str, Any], execution_id: str | None = None
for ws in targets:
try:
await ws.send_json(message)
except Exception:
except (ConnectionError, RuntimeError, asyncio.TimeoutError):
disconnected.append(ws)
if disconnected:
async with _ws_subscribers_lock:
@ -952,7 +958,7 @@ async def workflow_websocket(websocket: WebSocket):
# Keep connection alive - messages are primarily server-push
except WebSocketDisconnect:
logger.debug("Workflow WebSocket disconnected")
except Exception as e:
except (RuntimeError, asyncio.TimeoutError, ConnectionError) as e:
logger.error(f"Workflow WebSocket error: {e}")
finally:
if subscribed_execution_id:

View File

@ -47,7 +47,7 @@ class ConnectionManager:
for ws, _ in conns:
try:
await ws.send_json(message)
except Exception:
except (ConnectionError, RuntimeError, asyncio.TimeoutError):
stale.append(ws)
for ws in stale:
self.remove(task_id, ws)
@ -153,6 +153,8 @@ async def task_websocket(websocket: WebSocket, task_id: str) -> None:
except WebSocketDisconnect:
logger.debug(f"WebSocket disconnected for task {task_id}")
except asyncio.CancelledError:
raise
except Exception as e:
logger.error(f"WebSocket error for task {task_id}: {e}")
try:
@ -162,7 +164,7 @@ async def task_websocket(websocket: WebSocket, task_id: str) -> None:
"data": {"message": str(e)},
}
)
except Exception:
except (ConnectionError, RuntimeError, asyncio.TimeoutError):
pass
finally:
manager.remove(task_id, websocket)
@ -243,6 +245,8 @@ async def _run_react_and_stream(
},
)
except asyncio.CancelledError:
raise
except Exception as e:
await websocket.send_json(
{