refactor: follow-up tech debt cleanup (except Exception + Any 治理) (#9)
Deploy to Production / deploy (push) Waiting to run Details
Test / backend-test (push) Waiting to run Details
Test / frontend-unit (push) Waiting to run Details
Test / api-e2e (push) Waiting to run Details
Test / frontend-e2e (push) Waiting to run Details

This commit is contained in:
Fischer 2026-07-01 03:03:02 +08:00
parent cc531d0663
commit 838a05772e
79 changed files with 1388 additions and 900 deletions

View File

@ -16,7 +16,6 @@ from __future__ import annotations
import ast import ast
from collections import deque from collections import deque
from typing import Any
from agentkit.bitable.formula.functions import AGGREGATE_FUNCTIONS, FUNCTION_REGISTRY from agentkit.bitable.formula.functions import AGGREGATE_FUNCTIONS, FUNCTION_REGISTRY
from agentkit.bitable.formula.parser import ( from agentkit.bitable.formula.parser import (
@ -104,9 +103,9 @@ class FormulaEngine:
def evaluate( def evaluate(
self, self,
field_id: str, field_id: str,
row_values: dict[str, Any], row_values: dict[str, object],
column_values: dict[str, list[Any]] | None = None, column_values: dict[str, list[object]] | None = None,
) -> Any: ) -> object:
"""Evaluate a formula field for a specific record. """Evaluate a formula field for a specific record.
Args: Args:
@ -130,7 +129,7 @@ class FormulaEngine:
# Build the field_values dict for the evaluator # Build the field_values dict for the evaluator
# Aggregate refs get column values (lists), row refs get row values (scalars) # Aggregate refs get column values (lists), row refs get row values (scalars)
eval_values: dict[str, Any] = {} eval_values: dict[str, object] = {}
# Map real field IDs to safe names # Map real field IDs to safe names
for safe_name, real_id in entry.field_mapping.items(): for safe_name, real_id in entry.field_mapping.items():
@ -143,16 +142,16 @@ class FormulaEngine:
def evaluate_all_for_record( def evaluate_all_for_record(
self, self,
row_values: dict[str, Any], row_values: dict[str, object],
column_values: dict[str, list[Any]] | None = None, column_values: dict[str, list[object]] | None = None,
) -> dict[str, Any]: ) -> dict[str, object]:
"""Evaluate all registered formulas for a record. """Evaluate all registered formulas for a record.
Returns a dict of field_id computed value. Returns a dict of field_id computed value.
Formulas are evaluated in topological order so that formula-to-formula Formulas are evaluated in topological order so that formula-to-formula
dependencies are resolved correctly. dependencies are resolved correctly.
""" """
results: dict[str, Any] = {} results: dict[str, object] = {}
column_values = column_values or {} column_values = column_values or {}
for field_id in self.topological_order(): for field_id in self.topological_order():

View File

@ -16,13 +16,11 @@ Usage::
from __future__ import annotations from __future__ import annotations
from typing import Any
def transform_records( def transform_records(
records: list[dict[str, Any]], records: list[dict[str, object]],
field_mapping: dict[str, str], field_mapping: dict[str, str],
) -> list[dict[str, Any]]: ) -> list[dict[str, object]]:
"""Map source record keys to bitable field IDs via field_mapping. """Map source record keys to bitable field IDs via field_mapping.
Keys not in ``field_mapping`` are dropped. Values are passed through Keys not in ``field_mapping`` are dropped. Values are passed through
@ -40,9 +38,9 @@ def transform_records(
if not field_mapping: if not field_mapping:
return [] return []
transformed: list[dict[str, Any]] = [] transformed: list[dict[str, object]] = []
for rec in records: for rec in records:
out: dict[str, Any] = {} out: dict[str, object] = {}
for src_key, field_id in field_mapping.items(): for src_key, field_id in field_mapping.items():
if src_key in rec: if src_key in rec:
out[field_id] = rec[src_key] out[field_id] = rec[src_key]

View File

@ -16,7 +16,6 @@ Type mapping (KTD: DB → bitable):
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any
from sqlalchemy import ( from sqlalchemy import (
BigInteger, BigInteger,
@ -56,7 +55,7 @@ DB_TYPE_MAP: dict[type, str] = {
READ_BATCH = 1000 READ_BATCH = 1000
def infer_field_type(sqla_type: Any) -> str: def infer_field_type(sqla_type: object) -> str:
"""Map a SQLAlchemy column type instance or class to a bitable field type. """Map a SQLAlchemy column type instance or class to a bitable field type.
Handles both type instances (``Integer()``) and type classes (``Integer``). Handles both type instances (``Integer()``) and type classes (``Integer``).
@ -78,7 +77,7 @@ def import_table(
table_name: str, table_name: str,
*, *,
max_rows: int = 50_000, max_rows: int = 50_000,
) -> dict[str, Any]: ) -> dict[str, object]:
"""Reflect a single table from an external DB. """Reflect a single table from an external DB.
Returns ``{"table_name": str, "fields": [...], "records": [...], Returns ``{"table_name": str, "fields": [...], "records": [...],
@ -97,7 +96,7 @@ def import_table(
engine.dispose() engine.dispose()
def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[str, Any]: def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[str, object]:
"""Reflect one table and read its rows.""" """Reflect one table and read its rows."""
insp = inspect(engine) insp = inspect(engine)
@ -111,7 +110,7 @@ def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[st
table = Table(table_name, metadata, autoload_with=engine) table = Table(table_name, metadata, autoload_with=engine)
# Build field definitions # Build field definitions
fields: list[dict[str, Any]] = [] fields: list[dict[str, object]] = []
pk_columns = list(table.primary_key.columns) pk_columns = list(table.primary_key.columns)
pk_name = pk_columns[0].name if pk_columns else None pk_name = pk_columns[0].name if pk_columns else None
@ -131,14 +130,14 @@ def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[st
pk_name = "id" pk_name = "id"
# Read rows # Read rows
records: list[dict[str, Any]] = [] records: list[dict[str, object]] = []
with engine.connect() as conn: with engine.connect() as conn:
result = conn.execute(select(table)) result = conn.execute(select(table))
for i, row in enumerate(result): for i, row in enumerate(result):
if i >= max_rows: if i >= max_rows:
logger.warning("Table %r truncated at %d rows during import", table_name, max_rows) logger.warning("Table %r truncated at %d rows during import", table_name, max_rows)
break break
rec: dict[str, Any] = {} rec: dict[str, object] = {}
for col in table.columns: for col in table.columns:
val = getattr(row, col.name, None) val = getattr(row, col.name, None)
if val is not None: if val is not None:
@ -155,7 +154,7 @@ def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[st
} }
def _serialize(val: Any) -> Any: def _serialize(val: object) -> object:
"""Serialize a DB value to JSON-safe form.""" """Serialize a DB value to JSON-safe form."""
from datetime import date, datetime from datetime import date, datetime
from decimal import Decimal from decimal import Decimal

View File

@ -18,7 +18,6 @@ import logging
import socket import socket
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
import httpx import httpx
@ -36,7 +35,7 @@ class ParsedSheet:
name: str name: str
columns: list[str] = field(default_factory=list) columns: list[str] = field(default_factory=list)
field_types: list[str] = field(default_factory=list) # "text" | "number" | "date" field_types: list[str] = field(default_factory=list) # "text" | "number" | "date"
records: list[dict[str, Any]] = field(default_factory=list) records: list[dict[str, object]] = field(default_factory=list)
def parse_excel(file_path: str | Path) -> list[ParsedSheet]: def parse_excel(file_path: str | Path) -> list[ParsedSheet]:
@ -182,9 +181,9 @@ def _parse_worksheet(ws) -> ParsedSheet | None:
col_count = len(clean_headers) col_count = len(clean_headers)
field_types = _infer_column_types(data_rows, col_count) field_types = _infer_column_types(data_rows, col_count)
records: list[dict[str, Any]] = [] records: list[dict[str, object]] = []
for row in data_rows: for row in data_rows:
rec: dict[str, Any] = {} rec: dict[str, object] = {}
for i, col_name in enumerate(clean_headers): for i, col_name in enumerate(clean_headers):
val = row[i] if i < len(row) else None val = row[i] if i < len(row) else None
if val is not None: if val is not None:
@ -237,7 +236,7 @@ def _infer_column_types(rows: list[tuple], col_count: int) -> list[str]:
return types return types
def _coerce_value(val: Any, field_type: str) -> Any: def _coerce_value(val: object, field_type: str) -> object:
"""Coerce a cell value to the inferred field type. Truncate long strings.""" """Coerce a cell value to the inferred field type. Truncate long strings."""
if field_type == "date": if field_type == "date":
from datetime import datetime from datetime import datetime

View File

@ -9,10 +9,14 @@ from __future__ import annotations
from datetime import datetime, timezone from datetime import datetime, timezone
from enum import Enum from enum import Enum
from typing import Any
from pydantic import BaseModel, ConfigDict, Field as PydanticField from pydantic import BaseModel, ConfigDict, Field as PydanticField
# ponytail: bitable JSONB columns hold arbitrary JSON. Using `object` instead of
# a recursive TypeAlias because Pydantic v2 cannot build a schema for recursive
# named aliases (RecursionError). `object` is the most permissive type and
# Pydantic v2 serializes dict/list/primitive values fine at runtime.
def _utcnow() -> datetime: def _utcnow() -> datetime:
return datetime.now(timezone.utc) return datetime.now(timezone.utc)
@ -97,7 +101,7 @@ class Table(BaseModel):
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Status select field options — labels and colors match Feishu Bitable defaults. # Status select field options — labels and colors match Feishu Bitable defaults.
_STATUS_OPTIONS: list[dict[str, Any]] = [ _STATUS_OPTIONS: list[dict[str, object]] = [
{"label": "未开始", "value": "not_started", "color": "default"}, {"label": "未开始", "value": "not_started", "color": "default"},
{"label": "进行中", "value": "in_progress", "color": "processing"}, {"label": "进行中", "value": "in_progress", "color": "processing"},
{"label": "已完成", "value": "done", "color": "success"}, {"label": "已完成", "value": "done", "color": "success"},
@ -106,7 +110,7 @@ _STATUS_OPTIONS: list[dict[str, Any]] = [
#: Templates for the 5 default fields created on every new table (R2). #: Templates for the 5 default fields created on every new table (R2).
#: agent-owned fields (创建人/创建时间) are auto-filled by the service layer #: agent-owned fields (创建人/创建时间) are auto-filled by the service layer
#: on record creation; user-owned fields are user-editable. #: on record creation; user-owned fields are user-editable.
DEFAULT_FIELD_TEMPLATES: list[dict[str, Any]] = [ DEFAULT_FIELD_TEMPLATES: list[dict[str, object]] = [
{ {
"name": "标题", "name": "标题",
"field_type": FieldType.text, "field_type": FieldType.text,
@ -155,7 +159,7 @@ class Field(BaseModel):
table_id: str table_id: str
name: str name: str
field_type: FieldType field_type: FieldType
config: dict[str, Any] = PydanticField(default_factory=dict) config: dict[str, object] = PydanticField(default_factory=dict)
owner: FieldOwner = FieldOwner.user owner: FieldOwner = FieldOwner.user
created_at: datetime = PydanticField(default_factory=_utcnow) created_at: datetime = PydanticField(default_factory=_utcnow)
@ -167,7 +171,7 @@ class Record(BaseModel):
id: str id: str
table_id: str table_id: str
values: dict[str, Any] = PydanticField(default_factory=dict) values: dict[str, object] = PydanticField(default_factory=dict)
created_at: datetime = PydanticField(default_factory=_utcnow) created_at: datetime = PydanticField(default_factory=_utcnow)
updated_at: datetime = PydanticField(default_factory=_utcnow) updated_at: datetime = PydanticField(default_factory=_utcnow)
@ -181,7 +185,7 @@ class View(BaseModel):
table_id: str table_id: str
name: str name: str
view_type: ViewType = ViewType.grid view_type: ViewType = ViewType.grid
config: dict[str, Any] = PydanticField(default_factory=dict) config: dict[str, object] = PydanticField(default_factory=dict)
created_at: datetime = PydanticField(default_factory=_utcnow) created_at: datetime = PydanticField(default_factory=_utcnow)

View File

@ -18,11 +18,10 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
from typing import Any
from agentkit.bitable.db import BitableDB from agentkit.bitable.db import BitableDB
from agentkit.bitable.formula.engine import FormulaEngine from agentkit.bitable.formula.engine import FormulaEngine
from agentkit.bitable.models import FieldType, RecalcStatus from agentkit.bitable.models import FieldType, RecalcStatus, RecalcTask
from agentkit.bitable.repository import BitableRepository from agentkit.bitable.repository import BitableRepository
from agentkit.bitable.service import BitableService from agentkit.bitable.service import BitableService
@ -124,7 +123,7 @@ class RecalcWorker:
logger.exception("RecalcWorker error in main loop") logger.exception("RecalcWorker error in main loop")
await asyncio.sleep(self._poll_interval) await asyncio.sleep(self._poll_interval)
async def _sort_by_topological_order(self, tasks: list[Any]) -> list[Any]: async def _sort_by_topological_order(self, tasks: list[RecalcTask]) -> list[RecalcTask]:
"""Sort claimed tasks so dependencies are processed first (P1 #7). """Sort claimed tasks so dependencies are processed first (P1 #7).
Groups tasks by table_id, builds (or reuses) the engine to get the Groups tasks by table_id, builds (or reuses) the engine to get the
@ -146,7 +145,7 @@ class RecalcWorker:
order = engine.topological_order() order = engine.topological_order()
topo_index[tid] = {fid: i for i, fid in enumerate(order)} topo_index[tid] = {fid: i for i, fid in enumerate(order)}
def _key(t: Any) -> tuple[str, str, int]: def _key(t: RecalcTask) -> tuple[str, str, int]:
idx = topo_index.get(t.table_id, {}).get(t.field_id, 1 << 30) idx = topo_index.get(t.table_id, {}).get(t.field_id, 1 << 30)
return (t.table_id, t.record_id, idx) return (t.table_id, t.record_id, idx)
@ -175,7 +174,7 @@ class RecalcWorker:
except Exception: except Exception:
logger.exception("RecalcWorker reaper error") logger.exception("RecalcWorker reaper error")
async def process_task(self, task: Any) -> None: async def process_task(self, task: RecalcTask) -> None:
"""Process a single recalc task: evaluate formula → write result. """Process a single recalc task: evaluate formula → write result.
The task is expected to already be in ``calculating`` status when The task is expected to already be in ``calculating`` status when
@ -216,7 +215,7 @@ class RecalcWorker:
return return
deps = engine.get_dependencies(task.field_id) deps = engine.get_dependencies(task.field_id)
column_values: dict[str, list[Any]] = {} column_values: dict[str, list[object]] = {}
for dep_field_id in deps: for dep_field_id in deps:
column_values[dep_field_id] = await self._repo.get_column_values( column_values[dep_field_id] = await self._repo.get_column_values(
task.table_id, dep_field_id task.table_id, dep_field_id

View File

@ -10,7 +10,6 @@ from __future__ import annotations
import logging import logging
import re import re
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Any
from sqlalchemy import delete, func, insert, select, text, update from sqlalchemy import delete, func, insert, select, text, update
from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.dialects.postgresql import insert as pg_insert
@ -102,7 +101,7 @@ class BitableRepository:
result = await session.execute(stmt) result = await session.execute(stmt)
return [BitableFile.model_validate(e) for e in result.scalars().all()] return [BitableFile.model_validate(e) for e in result.scalars().all()]
async def update_file(self, file_id: str, **kwargs: Any) -> BitableFile | None: async def update_file(self, file_id: str, **kwargs: object) -> BitableFile | None:
"""Update a file's attributes.""" """Update a file's attributes."""
async with self._session_factory() as session: async with self._session_factory() as session:
stmt = ( stmt = (
@ -181,7 +180,7 @@ class BitableRepository:
result = await session.execute(stmt) result = await session.execute(stmt)
return [Table.model_validate(e) for e in result.scalars().all()] return [Table.model_validate(e) for e in result.scalars().all()]
async def update_table(self, table_id: str, **kwargs: Any) -> Table | None: async def update_table(self, table_id: str, **kwargs: object) -> Table | None:
"""Update a table's attributes.""" """Update a table's attributes."""
async with self._session_factory() as session: async with self._session_factory() as session:
stmt = ( stmt = (
@ -236,7 +235,7 @@ class BitableRepository:
table_id: str, table_id: str,
name: str, name: str,
field_type: FieldType, field_type: FieldType,
config: dict[str, Any] | None = None, config: dict[str, object] | None = None,
owner: FieldOwner = FieldOwner.user, owner: FieldOwner = FieldOwner.user,
) -> Field: ) -> Field:
"""Create a new field in a table.""" """Create a new field in a table."""
@ -277,7 +276,7 @@ class BitableRepository:
result = await session.execute(stmt) result = await session.execute(stmt)
return [Field.model_validate(e) for e in result.scalars().all()] return [Field.model_validate(e) for e in result.scalars().all()]
async def update_field(self, field_id: str, **kwargs: Any) -> Field | None: async def update_field(self, field_id: str, **kwargs: object) -> Field | None:
"""Update a field's attributes.""" """Update a field's attributes."""
async with self._session_factory() as session: async with self._session_factory() as session:
stmt = ( stmt = (
@ -300,7 +299,7 @@ class BitableRepository:
# ── Records ───────────────────────────────────────────── # ── Records ─────────────────────────────────────────────
async def create_record(self, table_id: str, values: dict[str, Any] | None = None) -> Record: async def create_record(self, table_id: str, values: dict[str, object] | None = None) -> Record:
"""Create a new record.""" """Create a new record."""
async with self._session_factory() as session: async with self._session_factory() as session:
stmt = ( stmt = (
@ -318,7 +317,7 @@ class BitableRepository:
return Record.model_validate(entity) return Record.model_validate(entity)
async def create_records_batch( async def create_records_batch(
self, table_id: str, records_values: list[dict[str, Any]] self, table_id: str, records_values: list[dict[str, object]]
) -> list[Record]: ) -> list[Record]:
"""Batch-insert multiple records (P2 #19: eliminates per-record INSERT). """Batch-insert multiple records (P2 #19: eliminates per-record INSERT).
@ -376,7 +375,7 @@ class BitableRepository:
return [Record.model_validate(e) for e in entities], next_cursor return [Record.model_validate(e) for e in entities], next_cursor
async def update_record_values(self, record_id: str, values: dict[str, Any]) -> Record | None: async def update_record_values(self, record_id: str, values: dict[str, object]) -> Record | None:
"""Update a record's values (full replace).""" """Update a record's values (full replace)."""
async with self._session_factory() as session: async with self._session_factory() as session:
stmt = ( stmt = (
@ -413,7 +412,7 @@ class BitableRepository:
table_id: str, table_id: str,
name: str, name: str,
view_type: ViewType = ViewType.grid, view_type: ViewType = ViewType.grid,
config: dict[str, Any] | None = None, config: dict[str, object] | None = None,
) -> View: ) -> View:
"""Create a new view.""" """Create a new view."""
async with self._session_factory() as session: async with self._session_factory() as session:
@ -451,7 +450,7 @@ class BitableRepository:
result = await session.execute(stmt) result = await session.execute(stmt)
return [View.model_validate(e) for e in result.scalars().all()] return [View.model_validate(e) for e in result.scalars().all()]
async def update_view(self, view_id: str, **kwargs: Any) -> View | None: async def update_view(self, view_id: str, **kwargs: object) -> View | None:
"""Update a view's attributes.""" """Update a view's attributes."""
async with self._session_factory() as session: async with self._session_factory() as session:
stmt = ( stmt = (
@ -543,7 +542,7 @@ class BitableRepository:
) -> None: ) -> None:
"""Update a recalc task's status.""" """Update a recalc task's status."""
async with self._session_factory() as session: async with self._session_factory() as session:
kwargs: dict[str, Any] = {"status": status.value} kwargs: dict[str, object] = {"status": status.value}
if error_message is not None: if error_message is not None:
kwargs["error_message"] = error_message kwargs["error_message"] = error_message
if status in (RecalcStatus.done, RecalcStatus.error): if status in (RecalcStatus.done, RecalcStatus.error):
@ -630,7 +629,7 @@ class BitableRepository:
return result_map return result_map
async def upsert_record_agent_fields( async def upsert_record_agent_fields(
self, record_id: str, agent_field_values: dict[str, Any] self, record_id: str, agent_field_values: dict[str, object]
) -> None: ) -> None:
"""Update agent-owned fields using jsonb_set (KTD8). """Update agent-owned fields using jsonb_set (KTD8).
@ -646,7 +645,7 @@ class BitableRepository:
# Use CAST(:param AS jsonb) instead of :param::jsonb — asyncpg dialect # Use CAST(:param AS jsonb) instead of :param::jsonb — asyncpg dialect
# misparses the `::` as part of the param name. # misparses the `::` as part of the param name.
inner = "values" inner = "values"
params: dict[str, Any] = {"record_id": record_id} params: dict[str, object] = {"record_id": record_id}
for i, (field_id, value) in enumerate(agent_field_values.items()): for i, (field_id, value) in enumerate(agent_field_values.items()):
param_key = f"v{i}" param_key = f"v{i}"
inner = f"jsonb_set({inner}, '{{{field_id}}}', CAST(:{param_key} AS jsonb), true)" inner = f"jsonb_set({inner}, '{{{field_id}}}', CAST(:{param_key} AS jsonb), true)"
@ -660,8 +659,8 @@ class BitableRepository:
async def list_records_filtered( async def list_records_filtered(
self, self,
table_id: str, table_id: str,
filters: list[dict[str, Any]] | None = None, filters: list[dict[str, object]] | None = None,
sorts: list[dict[str, Any]] | None = None, sorts: list[dict[str, object]] | None = None,
cursor: str | None = None, cursor: str | None = None,
limit: int = 50, limit: int = 50,
) -> tuple[list[Record], str | None]: ) -> tuple[list[Record], str | None]:
@ -682,7 +681,7 @@ class BitableRepository:
# Build raw SQL with JSONB filter/sort translation. # Build raw SQL with JSONB filter/sort translation.
# ponytail: field_ids in filters/sorts are system UUIDs (validated by service layer). # ponytail: field_ids in filters/sorts are system UUIDs (validated by service layer).
where_clauses = ["table_id = :table_id"] where_clauses = ["table_id = :table_id"]
params: dict[str, Any] = {"table_id": table_id} params: dict[str, object] = {"table_id": table_id}
if filters: if filters:
for i, f in enumerate(filters): for i, f in enumerate(filters):
@ -783,7 +782,7 @@ class BitableRepository:
last_mapping = last_row._mapping last_mapping = last_row._mapping
# Build composite cursor from sort values + id. # Build composite cursor from sort values + id.
# Sort values are extracted as text to match `values->>'fid'` expressions. # Sort values are extracted as text to match `values->>'fid'` expressions.
sv: list[Any] = [] sv: list[object] = []
last_values = last_mapping.get("values") last_values = last_mapping.get("values")
if isinstance(last_values, str): if isinstance(last_values, str):
# asyncpg may return JSONB as str in raw text() queries. # asyncpg may return JSONB as str in raw text() queries.
@ -852,7 +851,7 @@ class BitableRepository:
# ── Recalc support (U3) ──────────────────────────────── # ── Recalc support (U3) ────────────────────────────────
async def get_column_values(self, table_id: str, field_id: str) -> list[Any]: async def get_column_values(self, table_id: str, field_id: str) -> list[object]:
"""Get all values for a field across all records in a table (for aggregates). """Get all values for a field across all records in a table (for aggregates).
Returns a list of values (preserving order by record id). Missing values Returns a list of values (preserving order by record id). Missing values
@ -866,7 +865,7 @@ class BitableRepository:
result = await session.execute(sql, {"field_id": field_id, "table_id": table_id}) result = await session.execute(sql, {"field_id": field_id, "table_id": table_id})
return [row[0] for row in result.fetchall()] return [row[0] for row in result.fetchall()]
async def set_formula_value(self, record_id: str, field_id: str, value: Any) -> None: async def set_formula_value(self, record_id: str, field_id: str, value: object) -> None:
"""Set a single formula field value in a record's JSONB (jsonb_set).""" """Set a single formula field value in a record's JSONB (jsonb_set)."""
import json import json

View File

@ -34,12 +34,19 @@ import os
import sqlite3 import sqlite3
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Any, Callable from typing import Callable, TypeAlias
import httpx import httpx
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# 缓存的配置项 — 技能/工作流配置为 JSON 反序列化后的字典,值为标量或嵌套结构。
# 服务器返回的 skills/workflows 列表元素是 dictmodel_dump/to_dict 输出),
# 其中可能包含 list/dict 等容器,因此使用 object 作为值类型。
SkillConfigDict: TypeAlias = dict[str, object]
WorkflowConfigDict: TypeAlias = dict[str, object]
SyncedConfigPayload: TypeAlias = dict[str, object]
# ── Defaults ────────────────────────────────────────────────────────── # ── Defaults ──────────────────────────────────────────────────────────
@ -100,8 +107,8 @@ class ConfigSync:
# In-memory cache (mirrors the SQLite cache for fast access) # In-memory cache (mirrors the SQLite cache for fast access)
self._version: str | None = None self._version: str | None = None
self._skills: list[dict[str, Any]] = [] self._skills: list[SkillConfigDict] = []
self._workflows: list[dict[str, Any]] = [] self._workflows: list[WorkflowConfigDict] = []
self._last_synced_at: str | None = None self._last_synced_at: str | None = None
# ── Lifecycle ───────────────────────────────────────────────── # ── Lifecycle ─────────────────────────────────────────────────
@ -232,15 +239,15 @@ class ConfigSync:
"""Return the current cached config version hash.""" """Return the current cached config version hash."""
return self._version return self._version
def get_skills(self) -> list[dict[str, Any]]: def get_skills(self) -> list[SkillConfigDict]:
"""Return the cached skill configs.""" """Return the cached skill configs."""
return list(self._skills) return list(self._skills)
def get_workflows(self) -> list[dict[str, Any]]: def get_workflows(self) -> list[WorkflowConfigDict]:
"""Return the cached workflow configs.""" """Return the cached workflow configs."""
return list(self._workflows) return list(self._workflows)
def get_all(self) -> dict[str, Any]: def get_all(self) -> SyncedConfigPayload:
"""Return all cached configs as a single dict.""" """Return all cached configs as a single dict."""
return { return {
"version": self._version, "version": self._version,
@ -249,14 +256,14 @@ class ConfigSync:
"synced_at": self._last_synced_at, "synced_at": self._last_synced_at,
} }
def get_skill(self, name: str) -> dict[str, Any] | None: def get_skill(self, name: str) -> SkillConfigDict | None:
"""Return a single skill config by name, or ``None``.""" """Return a single skill config by name, or ``None``."""
for skill in self._skills: for skill in self._skills:
if skill.get("name") == name: if skill.get("name") == name:
return skill return skill
return None return None
def get_workflow(self, workflow_id: str) -> dict[str, Any] | None: def get_workflow(self, workflow_id: str) -> WorkflowConfigDict | None:
"""Return a single workflow config by ID, or ``None``.""" """Return a single workflow config by ID, or ``None``."""
for wf in self._workflows: for wf in self._workflows:
if wf.get("workflow_id") == workflow_id: if wf.get("workflow_id") == workflow_id:
@ -281,7 +288,7 @@ class ConfigSync:
conn.executescript(_CACHE_SCHEMA) conn.executescript(_CACHE_SCHEMA)
conn.commit() conn.commit()
def _save_to_cache(self, data: dict[str, Any]) -> None: def _save_to_cache(self, data: SyncedConfigPayload) -> None:
"""Save the synced configs to the local SQLite cache.""" """Save the synced configs to the local SQLite cache."""
now = datetime.now(timezone.utc).isoformat() now = datetime.now(timezone.utc).isoformat()
with sqlite3.connect(str(self.cache_db_path)) as conn: with sqlite3.connect(str(self.cache_db_path)) as conn:

View File

@ -14,7 +14,7 @@ import logging
import time import time
from collections import OrderedDict from collections import OrderedDict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable from typing import TYPE_CHECKING, Protocol, runtime_checkable
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
from agentkit.utils.vector_math import compute_cosine_similarity from agentkit.utils.vector_math import compute_cosine_similarity
@ -25,6 +25,52 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# TYPE_CHECKING Protocols — 避免 Any描述运行时 lazy import 的第三方对象
# ---------------------------------------------------------------------------
if TYPE_CHECKING:
class _RedisLike(Protocol):
"""Redis 客户端最小契约(仅覆盖本模块用到的方法)。"""
async def get(self, key: str) -> bytes | str | None: ...
async def mget(self, keys: list[str]) -> list[bytes | str | None]: ...
async def set(self, key: str, value: bytes | str, ex: int | None = None) -> None: ...
async def smembers(self, key: str) -> set[bytes | str]: ...
async def sadd(self, name: str, *values: str) -> int: ...
async def srem(self, name: str, *values: str) -> int: ...
async def scard(self, name: str) -> int: ...
async def delete(self, *names: str) -> int: ...
def pipeline(self) -> "_RedisPipelineLike": ...
class _RedisPipelineLike(Protocol):
"""Redis pipeline 最小契约。"""
def get(self, key: str) -> "_RedisPipelineLike": ...
def set(
self, key: str, value: bytes | str, ex: int | None = None
) -> "_RedisPipelineLike": ...
def delete(self, *names: str) -> "_RedisPipelineLike": ...
def sadd(self, name: str, *values: str) -> "_RedisPipelineLike": ...
def srem(self, name: str, *values: str) -> "_RedisPipelineLike": ...
async def execute(self) -> list[object]: ...
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Data Classes # Data Classes
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -328,7 +374,7 @@ class RedisLLMCache:
self._semantic_ttl = semantic_ttl self._semantic_ttl = semantic_ttl
self._similarity_threshold = similarity_threshold self._similarity_threshold = similarity_threshold
self._max_entries_to_scan = max_entries_to_scan self._max_entries_to_scan = max_entries_to_scan
self._redis: Any = None self._redis: _RedisLike | None = None
self._fallback: InMemoryLLMCache | None = fallback # For auto-degradation self._fallback: InMemoryLLMCache | None = fallback # For auto-degradation
self._degraded = False # True if Redis is unreachable self._degraded = False # True if Redis is unreachable
@ -691,7 +737,7 @@ class LitellmCacheManager:
def __init__(self, config: LitellmCacheConfig): def __init__(self, config: LitellmCacheConfig):
self._config = config self._config = config
self._cache_instance: Any = None # litellm.caching.Cache 实例 self._cache_instance: object | None = None # litellm.caching.Cache 实例
self._hits = 0 self._hits = 0
self._misses = 0 self._misses = 0
@ -709,7 +755,7 @@ class LitellmCacheManager:
litellm.cache = None litellm.cache = None
self._cache_instance = None self._cache_instance = None
def _create_cache_instance(self) -> Any: def _create_cache_instance(self) -> object:
"""根据 backend 配置创建 LiteLLM Cache 实例。 """根据 backend 配置创建 LiteLLM Cache 实例。
auto 模式按优先级尝试RedisSemanticCache RedisCache InMemoryCache auto 模式按优先级尝试RedisSemanticCache RedisCache InMemoryCache

View File

@ -2,14 +2,13 @@
import hashlib import hashlib
import json import json
from typing import Any
def generate_cache_key( def generate_cache_key(
model: str, model: str,
messages: list[dict[str, str]], messages: list[dict[str, str]],
temperature: float, temperature: float,
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, object]] | None = None,
tool_choice: str = "auto", tool_choice: str = "auto",
max_tokens: int = 2000, max_tokens: int = 2000,
user_id: str | None = None, user_id: str | None = None,

View File

@ -3,12 +3,12 @@
import json import json
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, TYPE_CHECKING from typing import TYPE_CHECKING
from agentkit.llm.retry import CircuitBreakerConfig, RetryConfig from agentkit.llm.retry import CircuitBreakerConfig, RetryConfig
if TYPE_CHECKING: if TYPE_CHECKING:
from agentkit.channels.secrets import SecretsStore from agentkit.channels.secrets import SecretEntry, SecretsStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,7 +56,7 @@ class ProviderConfig:
api_key: str api_key: str
base_url: str base_url: str
models: dict[str, dict[str, Any]] = field(default_factory=dict) models: dict[str, dict[str, object]] = field(default_factory=dict)
type: str = "openai" # "openai" | "anthropic" | "gemini" type: str = "openai" # "openai" | "anthropic" | "gemini"
max_tokens: int = 4096 # Anthropic: default max_tokens max_tokens: int = 4096 # Anthropic: default max_tokens
timeout: float = 120.0 # Anthropic: request timeout timeout: float = 120.0 # Anthropic: request timeout
@ -168,18 +168,18 @@ class ProviderConfig:
return f"llm:provider:{self.type}:api_key" return f"llm:provider:{self.type}:api_key"
@staticmethod @staticmethod
def _encode_secret_entry(entry: Any, key: str) -> str: def _encode_secret_entry(entry: object, key: str) -> str:
"""把 SecretEntry 编码为 JSON 字符串(含 key 字段)。""" """把 SecretEntry 编码为 JSON 字符串(含 key 字段)。"""
# entry 是 SecretEntry pydantic 模型,有 model_dump() # entry 是 SecretEntry pydantic 模型,有 model_dump()
if hasattr(entry, "model_dump"): if hasattr(entry, "model_dump"):
data = entry.model_dump() data = entry.model_dump() # type: ignore[attr-defined]
else: else:
data = dict(entry) data = dict(entry) # type: ignore[call-overload]
data["key"] = key data["key"] = key
return json.dumps(data) return json.dumps(data)
@staticmethod @staticmethod
def _decode_secret_entry(encoded: str) -> Any: def _decode_secret_entry(encoded: str) -> "SecretEntry":
"""从 JSON 字符串解码 SecretEntry。返回带 .key 属性的对象。""" """从 JSON 字符串解码 SecretEntry。返回带 .key 属性的对象。"""
from agentkit.channels.secrets import SecretEntry from agentkit.channels.secrets import SecretEntry

View File

@ -5,7 +5,7 @@ import logging
import time import time
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Any from typing import TYPE_CHECKING, Protocol
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
from agentkit.llm.config import LLMConfig from agentkit.llm.config import LLMConfig
@ -14,9 +14,40 @@ from agentkit.llm.providers.tracker import UsageSummary, UsageTracker
from agentkit.telemetry.tracing import get_tracer, _OTEL_AVAILABLE from agentkit.telemetry.tracing import get_tracer, _OTEL_AVAILABLE
from agentkit.telemetry.metrics import llm_token_histogram from agentkit.telemetry.metrics import llm_token_histogram
if TYPE_CHECKING:
from agentkit.llm.cache import LitellmCacheManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# TYPE_CHECKING Protocols — 避免 Any描述运行时 lazy import 的对象
# ---------------------------------------------------------------------------
if TYPE_CHECKING:
class _QuotaServiceLike(Protocol):
"""Quota service 最小契约(仅覆盖 gateway._enforce_quota 用到的方法)。"""
async def is_model_allowed(
self, db: Path, department_id: str, model: str
) -> tuple[bool, str]: ...
async def check_quota(
self,
db: Path,
department_id: str,
quota_type: str,
period: str,
current: float,
) -> tuple[bool, str]: ...
async def get_quota(
self, db: Path, department_id: str, quota_type: str, period: str
) -> dict[str, object] | None: ...
class QuotaExceededError(Exception): class QuotaExceededError(Exception):
"""Raised when a department's LLM quota is exceeded. """Raised when a department's LLM quota is exceeded.
@ -29,8 +60,8 @@ class QuotaExceededError(Exception):
department_id: str, department_id: str,
quota_type: str, quota_type: str,
period: str, period: str,
limit: Any, limit: object,
current: Any, current: object,
) -> None: ) -> None:
self.department_id = department_id self.department_id = department_id
self.quota_type = quota_type self.quota_type = quota_type
@ -46,13 +77,13 @@ class QuotaExceededError(Exception):
class LLMGateway: class LLMGateway:
"""LLM 网关 - Provider 注册、模型别名解析、Fallback、Usage 追踪、Cache""" """LLM 网关 - Provider 注册、模型别名解析、Fallback、Usage 追踪、Cache"""
def __init__(self, config: LLMConfig | None = None, usage_store: Any = None): def __init__(self, config: LLMConfig | None = None, usage_store: object | None = None):
self._providers: dict[str, LLMProvider] = {} self._providers: dict[str, LLMProvider] = {}
self._usage_tracker = UsageTracker(store=usage_store) if usage_store else UsageTracker() self._usage_tracker = UsageTracker(store=usage_store) if usage_store else UsageTracker()
self._config = config or LLMConfig() self._config = config or LLMConfig()
# Cache (U17 — LiteLLM 缓存管理器opt-in默认禁用) # Cache (U17 — LiteLLM 缓存管理器opt-in默认禁用)
self._cache_manager: Any = None # LitellmCacheManager | None self._cache_manager: "LitellmCacheManager | None" = None
if self._config.cache and self._config.cache.enabled: if self._config.cache and self._config.cache.enabled:
from agentkit.llm.cache import LitellmCacheConfig, LitellmCacheManager from agentkit.llm.cache import LitellmCacheConfig, LitellmCacheManager
@ -601,7 +632,7 @@ class LLMGateway:
async def _check_quota_value( async def _check_quota_value(
self, self,
quota_service: Any, quota_service: _QuotaServiceLike,
db: Path, db: Path,
dept_id: str, dept_id: str,
period: str, period: str,

View File

@ -27,12 +27,11 @@ from __future__ import annotations
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str, Any]]: def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str, object]]:
"""把 agentkit.yaml 中的 plaintext API Key 迁移到 SecretsStore。 """把 agentkit.yaml 中的 plaintext API Key 迁移到 SecretsStore。
流程 流程
@ -63,8 +62,8 @@ def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str,
store = SecretsStore() # master key 从 env 加载 store = SecretsStore() # master key 从 env 加载
async def _run() -> dict[str, dict[str, Any]]: async def _run() -> dict[str, dict[str, object]]:
report: dict[str, dict[str, Any]] = {} report: dict[str, dict[str, object]] = {}
for name, pconf in llm_config.providers.items(): for name, pconf in llm_config.providers.items():
if pconf.api_key_source == "secrets_store" and not pconf.api_key: if pconf.api_key_source == "secrets_store" and not pconf.api_key:
report[name] = {"status": "skipped", "source": pconf.api_key_source} report[name] = {"status": "skipped", "source": pconf.api_key_source}
@ -93,9 +92,9 @@ def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str,
report = asyncio.run(_run()) report = asyncio.run(_run())
# 写回 YAML更新 llm.providers 段,保留其它段 # 写回 YAML更新 llm.providers 段,保留其它段
providers_out: dict[str, dict[str, Any]] = {} providers_out: dict[str, dict[str, object]] = {}
for name, pconf in llm_config.providers.items(): for name, pconf in llm_config.providers.items():
entry: dict[str, Any] = { entry: dict[str, object] = {
"type": pconf.type, "type": pconf.type,
"base_url": pconf.base_url, "base_url": pconf.base_url,
"models": pconf.models, "models": pconf.models,

View File

@ -2,7 +2,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any
@dataclass @dataclass
@ -23,7 +22,7 @@ class ToolCall:
id: str id: str
name: str name: str
arguments: dict[str, Any] arguments: dict[str, object]
@dataclass @dataclass
@ -32,7 +31,7 @@ class LLMRequest:
messages: list[dict[str, str]] messages: list[dict[str, str]]
model: str model: str
tools: list[dict[str, Any]] | None = None tools: list[dict[str, object]] | None = None
tool_choice: str = "auto" tool_choice: str = "auto"
temperature: float = 0.7 temperature: float = 0.7
max_tokens: int = 2000 max_tokens: int = 2000
@ -42,13 +41,13 @@ class LLMRequest:
self, self,
messages: list[dict[str, str]], messages: list[dict[str, str]],
model: str, model: str,
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, object]] | None = None,
tool_choice: str = "auto", tool_choice: str = "auto",
temperature: float = 0.7, temperature: float = 0.7,
max_tokens: int = 2000, max_tokens: int = 2000,
timeout: float | None = None, timeout: float | None = None,
cache: dict[str, Any] | None = None, cache: dict[str, object] | None = None,
**kwargs: Any, **kwargs: object,
): ):
self.messages = messages self.messages = messages
self.model = model self.model = model
@ -59,7 +58,7 @@ class LLMRequest:
self.timeout = timeout self.timeout = timeout
self._extra = kwargs self._extra = kwargs
# U17 — LiteLLM cache 参数cache_key 或 no-cache透传到 litellm.acompletion # U17 — LiteLLM cache 参数cache_key 或 no-cache透传到 litellm.acompletion
self._cache: dict[str, Any] | None = cache self._cache: dict[str, object] | None = cache
@dataclass @dataclass

View File

@ -3,7 +3,6 @@
import json import json
import logging import logging
import time import time
from typing import Any
import httpx import httpx
@ -99,7 +98,9 @@ class AnthropicProvider(LLMProvider):
"content-type": "application/json", "content-type": "application/json",
} }
def _convert_messages(self, messages: list[dict[str, str]]) -> tuple[str | list[dict[str, Any]] | None, list[dict[str, Any]]]: def _convert_messages(
self, messages: list[dict[str, str]]
) -> tuple[str | list[dict[str, object]] | None, list[dict[str, object]]]:
"""将 OpenAI 风格消息转换为 Anthropic 格式 """将 OpenAI 风格消息转换为 Anthropic 格式
Returns: Returns:
@ -110,8 +111,8 @@ class AnthropicProvider(LLMProvider):
- list[dict]: Anthropic content blocks(支持 cache_control,U2/G2) - list[dict]: Anthropic content blocks(支持 cache_control,U2/G2)
- None: system 消息 - None: system 消息
""" """
system_prompt: str | list[dict[str, Any]] | None = None system_prompt: str | list[dict[str, object]] | None = None
anthropic_messages: list[dict[str, Any]] = [] anthropic_messages: list[dict[str, object]] = []
for msg in messages: for msg in messages:
role = msg.get("role", "") role = msg.get("role", "")
@ -127,7 +128,7 @@ class AnthropicProvider(LLMProvider):
# 检查是否有 tool_calls (OpenAI 格式) # 检查是否有 tool_calls (OpenAI 格式)
tool_calls = msg.get("tool_calls") tool_calls = msg.get("tool_calls")
if tool_calls: if tool_calls:
blocks: list[dict[str, Any]] = [] blocks: list[dict[str, object]] = []
# 如果有文本内容,先添加文本块 # 如果有文本内容,先添加文本块
if content: if content:
blocks.append({"type": "text", "text": content}) blocks.append({"type": "text", "text": content})
@ -139,25 +140,29 @@ class AnthropicProvider(LLMProvider):
arguments = json.loads(arguments) arguments = json.loads(arguments)
except json.JSONDecodeError: except json.JSONDecodeError:
arguments = {"raw": arguments} arguments = {"raw": arguments}
blocks.append({ blocks.append(
"type": "tool_use", {
"id": tc.get("id", ""), "type": "tool_use",
"name": func.get("name", ""), "id": tc.get("id", ""),
"input": arguments, "name": func.get("name", ""),
}) "input": arguments,
}
)
anthropic_messages.append({"role": "assistant", "content": blocks}) anthropic_messages.append({"role": "assistant", "content": blocks})
else: else:
anthropic_messages.append({ anthropic_messages.append(
"role": "assistant", {
"content": [{"type": "text", "text": content}], "role": "assistant",
}) "content": [{"type": "text", "text": content}],
}
)
continue continue
if role == "user": if role == "user":
# 检查是否是 tool_result 消息 (OpenAI 格式中 tool 角色的结果) # 检查是否是 tool_result 消息 (OpenAI 格式中 tool 角色的结果)
# OpenAI 格式: {"role": "tool", "tool_call_id": "...", "content": "..."} # OpenAI 格式: {"role": "tool", "tool_call_id": "...", "content": "..."}
if msg.get("tool_call_id"): if msg.get("tool_call_id"):
tool_result_blocks: list[dict[str, Any]] = [] tool_result_blocks: list[dict[str, object]] = []
tool_content = msg.get("content", "") tool_content = msg.get("content", "")
# tool_result 的 content 可以是字符串或内容块列表 # tool_result 的 content 可以是字符串或内容块列表
if isinstance(tool_content, str): if isinstance(tool_content, str):
@ -167,56 +172,72 @@ class AnthropicProvider(LLMProvider):
else: else:
tool_result_blocks.append({"type": "text", "text": str(tool_content)}) tool_result_blocks.append({"type": "text", "text": str(tool_content)})
anthropic_messages.append({ anthropic_messages.append(
"role": "user", {
"content": [{ "role": "user",
"type": "tool_result", "content": [
"tool_use_id": msg.get("tool_call_id", ""), {
"content": tool_result_blocks, "type": "tool_result",
}], "tool_use_id": msg.get("tool_call_id", ""),
}) "content": tool_result_blocks,
}
],
}
)
else: else:
anthropic_messages.append({ anthropic_messages.append(
"role": "user", {
"content": [{"type": "text", "text": content}], "role": "user",
}) "content": [{"type": "text", "text": content}],
}
)
continue continue
if role == "tool": if role == "tool":
# OpenAI 格式中独立的 tool 消息 # OpenAI 格式中独立的 tool 消息
tool_content = msg.get("content", "") tool_content = msg.get("content", "")
if isinstance(tool_content, str): if isinstance(tool_content, str):
result_content: list[dict[str, Any]] | str = [{"type": "text", "text": tool_content}] result_content: list[dict[str, object]] | str = [
{"type": "text", "text": tool_content}
]
elif isinstance(tool_content, list): elif isinstance(tool_content, list):
result_content = tool_content result_content = tool_content
else: else:
result_content = [{"type": "text", "text": str(tool_content)}] result_content = [{"type": "text", "text": str(tool_content)}]
anthropic_messages.append({ anthropic_messages.append(
"role": "user", {
"content": [{ "role": "user",
"type": "tool_result", "content": [
"tool_use_id": msg.get("tool_call_id", ""), {
"content": result_content, "type": "tool_result",
}], "tool_use_id": msg.get("tool_call_id", ""),
}) "content": result_content,
}
],
}
)
return system_prompt, anthropic_messages return system_prompt, anthropic_messages
def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: def _convert_tools(self, tools: list[dict[str, object]]) -> list[dict[str, object]]:
"""将 OpenAI function 格式转换为 Anthropic tool 格式""" """将 OpenAI function 格式转换为 Anthropic tool 格式"""
anthropic_tools = [] anthropic_tools = []
for tool in tools: for tool in tools:
if tool.get("type") == "function": if tool.get("type") == "function":
func = tool.get("function", {}) func = tool.get("function", {})
anthropic_tools.append({ anthropic_tools.append(
"name": func.get("name", ""), {
"description": func.get("description", ""), "name": func.get("name", ""),
"input_schema": func.get("parameters", {"type": "object", "properties": {}}), "description": func.get("description", ""),
}) "input_schema": func.get(
"parameters", {"type": "object", "properties": {}}
),
}
)
return anthropic_tools return anthropic_tools
def _convert_tool_choice(self, tool_choice: str) -> dict[str, Any] | None: def _convert_tool_choice(self, tool_choice: str) -> dict[str, object] | None:
"""将 OpenAI tool_choice 格式转换为 Anthropic 格式""" """将 OpenAI tool_choice 格式转换为 Anthropic 格式"""
if tool_choice == "auto": if tool_choice == "auto":
return {"type": "auto"} return {"type": "auto"}
@ -227,7 +248,7 @@ class AnthropicProvider(LLMProvider):
return {"type": "tool", "name": tool_choice} return {"type": "tool", "name": tool_choice}
return None return None
def _parse_response(self, data: dict[str, Any], model: str) -> LLMResponse: def _parse_response(self, data: dict[str, object], model: str) -> LLMResponse:
"""将 Anthropic 响应转换为 LLMResponse""" """将 Anthropic 响应转换为 LLMResponse"""
content_blocks = data.get("content", []) content_blocks = data.get("content", [])
text_parts: list[str] = [] text_parts: list[str] = []
@ -238,11 +259,13 @@ class AnthropicProvider(LLMProvider):
if block_type == "text": if block_type == "text":
text_parts.append(block.get("text", "")) text_parts.append(block.get("text", ""))
elif block_type == "tool_use": elif block_type == "tool_use":
tool_calls.append(ToolCall( tool_calls.append(
id=block.get("id", ""), ToolCall(
name=block.get("name", ""), id=block.get("id", ""),
arguments=block.get("input", {}), name=block.get("name", ""),
)) arguments=block.get("input", {}),
)
)
usage_data = data.get("usage", {}) usage_data = data.get("usage", {})
usage = TokenUsage( usage = TokenUsage(
@ -287,7 +310,7 @@ class AnthropicProvider(LLMProvider):
system_prompt, anthropic_messages = self._convert_messages(request.messages) system_prompt, anthropic_messages = self._convert_messages(request.messages)
payload: dict[str, Any] = { payload: dict[str, object] = {
"model": request.model, "model": request.model,
"max_tokens": request.max_tokens or self._max_tokens, "max_tokens": request.max_tokens or self._max_tokens,
"messages": anthropic_messages, "messages": anthropic_messages,
@ -346,7 +369,7 @@ class AnthropicProvider(LLMProvider):
system_prompt, anthropic_messages = self._convert_messages(request.messages) system_prompt, anthropic_messages = self._convert_messages(request.messages)
payload: dict[str, Any] = { payload: dict[str, object] = {
"model": request.model, "model": request.model,
"max_tokens": request.max_tokens or self._max_tokens, "max_tokens": request.max_tokens or self._max_tokens,
"messages": anthropic_messages, "messages": anthropic_messages,
@ -375,7 +398,7 @@ class AnthropicProvider(LLMProvider):
async def _iterate_stream(self, response, request: LLMRequest): async def _iterate_stream(self, response, request: LLMRequest):
"""Iterate over an already-open SSE stream and yield StreamChunks.""" """Iterate over an already-open SSE stream and yield StreamChunks."""
# Accumulated tool calls: tool_use_id -> {id, name, input_json_str} # Accumulated tool calls: tool_use_id -> {id, name, input_json_str}
accumulated_tool_calls: dict[str, dict[str, Any]] = {} accumulated_tool_calls: dict[str, dict[str, object]] = {}
current_tool_id: str | None = None current_tool_id: str | None = None
current_tool_name: str | None = None current_tool_name: str | None = None
current_tool_input_json: str = "" current_tool_input_json: str = ""
@ -433,7 +456,9 @@ class AnthropicProvider(LLMProvider):
# Finalize current tool call if any # Finalize current tool call if any
if current_tool_id is not None: if current_tool_id is not None:
try: try:
arguments = json.loads(current_tool_input_json) if current_tool_input_json else {} arguments = (
json.loads(current_tool_input_json) if current_tool_input_json else {}
)
except json.JSONDecodeError: except json.JSONDecodeError:
arguments = {"raw": current_tool_input_json} arguments = {"raw": current_tool_input_json}
@ -510,7 +535,7 @@ class AnthropicProvider(LLMProvider):
error_msg = error_info.get("message", "Stream error") error_msg = error_info.get("message", "Stream error")
raise LLMProviderError("anthropic", error_msg) raise LLMProviderError("anthropic", error_msg)
def get_model_info(self) -> dict[str, Any]: def get_model_info(self) -> dict[str, object]:
"""返回 Provider 和模型信息""" """返回 Provider 和模型信息"""
return { return {
"provider": "anthropic", "provider": "anthropic",

View File

@ -8,7 +8,6 @@ API火山引擎 OpenAI 兼容接口
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any
from agentkit.llm.providers.openai import OpenAICompatibleProvider from agentkit.llm.providers.openai import OpenAICompatibleProvider
@ -48,7 +47,7 @@ class DoubaoProvider(OpenAICompatibleProvider):
api_key: str, api_key: str,
base_url: str = DOUBAO_DEFAULT_BASE_URL, base_url: str = DOUBAO_DEFAULT_BASE_URL,
default_model: str = "doubao-pro-32k", default_model: str = "doubao-pro-32k",
**kwargs: Any, **kwargs: object,
): ):
super().__init__( super().__init__(
api_key=api_key, api_key=api_key,

View File

@ -3,7 +3,6 @@
import json import json
import logging import logging
import time import time
from typing import Any
import httpx import httpx
@ -90,14 +89,14 @@ class GeminiProvider(LLMProvider):
def _convert_messages( def _convert_messages(
self, messages: list[dict[str, str]] self, messages: list[dict[str, str]]
) -> tuple[dict[str, Any] | None, list[dict[str, Any]]]: ) -> tuple[dict[str, object] | None, list[dict[str, object]]]:
"""将 OpenAI 风格消息转换为 Gemini 格式 """将 OpenAI 风格消息转换为 Gemini 格式
Returns: Returns:
(system_instruction, contents) (system_instruction, contents)
""" """
system_instruction: dict[str, Any] | None = None system_instruction: dict[str, object] | None = None
contents: list[dict[str, Any]] = [] contents: list[dict[str, object]] = []
for msg in messages: for msg in messages:
role = msg.get("role", "") role = msg.get("role", "")
@ -119,28 +118,34 @@ class GeminiProvider(LLMProvider):
tool_name = parsed.get("name", "") tool_name = parsed.get("name", "")
except (json.JSONDecodeError, AttributeError): except (json.JSONDecodeError, AttributeError):
pass pass
contents.append({ contents.append(
"role": "user", {
"parts": [{ "role": "user",
"functionResponse": { "parts": [
"name": tool_name, {
"response": { "functionResponse": {
"content": content, "name": tool_name,
}, "response": {
}, "content": content,
}], },
}) },
}
],
}
)
else: else:
contents.append({ contents.append(
"role": "user", {
"parts": [{"text": content}], "role": "user",
}) "parts": [{"text": content}],
}
)
continue continue
if role == "assistant": if role == "assistant":
tool_calls = msg.get("tool_calls") tool_calls = msg.get("tool_calls")
if tool_calls: if tool_calls:
parts: list[dict[str, Any]] = [] parts: list[dict[str, object]] = []
if content: if content:
parts.append({"text": content}) parts.append({"text": content})
for tc in tool_calls: for tc in tool_calls:
@ -151,54 +156,64 @@ class GeminiProvider(LLMProvider):
arguments = json.loads(arguments) arguments = json.loads(arguments)
except json.JSONDecodeError: except json.JSONDecodeError:
arguments = {"raw": arguments} arguments = {"raw": arguments}
parts.append({ parts.append(
"functionCall": { {
"name": func.get("name", ""), "functionCall": {
"args": arguments, "name": func.get("name", ""),
}, "args": arguments,
}) },
}
)
contents.append({"role": "model", "parts": parts}) contents.append({"role": "model", "parts": parts})
else: else:
contents.append({ contents.append(
"role": "model", {
"parts": [{"text": content}], "role": "model",
}) "parts": [{"text": content}],
}
)
continue continue
if role == "tool": if role == "tool":
# OpenAI format: {"role": "tool", "tool_call_id": "...", "content": "..."} # OpenAI format: {"role": "tool", "tool_call_id": "...", "content": "..."}
tool_name = msg.get("name", "") tool_name = msg.get("name", "")
tool_content = msg.get("content", "") tool_content = msg.get("content", "")
contents.append({ contents.append(
"role": "user", {
"parts": [{ "role": "user",
"functionResponse": { "parts": [
"name": tool_name, {
"response": { "functionResponse": {
"content": tool_content, "name": tool_name,
}, "response": {
}, "content": tool_content,
}], },
}) },
}
],
}
)
return system_instruction, contents return system_instruction, contents
def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: def _convert_tools(self, tools: list[dict[str, object]]) -> list[dict[str, object]]:
"""将 OpenAI function 格式转换为 Gemini functionDeclarations""" """将 OpenAI function 格式转换为 Gemini functionDeclarations"""
declarations = [] declarations = []
for tool in tools: for tool in tools:
if tool.get("type") == "function": if tool.get("type") == "function":
func = tool.get("function", {}) func = tool.get("function", {})
declarations.append({ declarations.append(
"name": func.get("name", ""), {
"description": func.get("description", ""), "name": func.get("name", ""),
"parameters": func.get("parameters", {"type": "object", "properties": {}}), "description": func.get("description", ""),
}) "parameters": func.get("parameters", {"type": "object", "properties": {}}),
}
)
if not declarations: if not declarations:
return [] return []
return [{"functionDeclarations": declarations}] return [{"functionDeclarations": declarations}]
def _convert_tool_choice(self, tool_choice: str) -> dict[str, Any] | None: def _convert_tool_choice(self, tool_choice: str) -> dict[str, object] | None:
"""将 OpenAI tool_choice 格式转换为 Gemini toolConfig""" """将 OpenAI tool_choice 格式转换为 Gemini toolConfig"""
if tool_choice == "auto": if tool_choice == "auto":
return {"functionCallingConfig": {"mode": "AUTO"}} return {"functionCallingConfig": {"mode": "AUTO"}}
@ -210,7 +225,7 @@ class GeminiProvider(LLMProvider):
return {"functionCallingConfig": {"mode": "NONE"}} return {"functionCallingConfig": {"mode": "NONE"}}
return None return None
def _parse_response(self, data: dict[str, Any], model: str) -> LLMResponse: def _parse_response(self, data: dict[str, object], model: str) -> LLMResponse:
"""将 Gemini 响应转换为 LLMResponse""" """将 Gemini 响应转换为 LLMResponse"""
candidates = data.get("candidates", []) candidates = data.get("candidates", [])
text_parts: list[str] = [] text_parts: list[str] = []
@ -225,11 +240,13 @@ class GeminiProvider(LLMProvider):
text_parts.append(part["text"]) text_parts.append(part["text"])
elif "functionCall" in part: elif "functionCall" in part:
fc = part["functionCall"] fc = part["functionCall"]
tool_calls.append(ToolCall( tool_calls.append(
id=f"call_{tool_call_index}", ToolCall(
name=fc.get("name", ""), id=f"call_{tool_call_index}",
arguments=fc.get("args", {}), name=fc.get("name", ""),
)) arguments=fc.get("args", {}),
)
)
tool_call_index += 1 tool_call_index += 1
usage_metadata = data.get("usageMetadata", {}) usage_metadata = data.get("usageMetadata", {})
@ -275,7 +292,7 @@ class GeminiProvider(LLMProvider):
system_instruction, contents = self._convert_messages(request.messages) system_instruction, contents = self._convert_messages(request.messages)
payload: dict[str, Any] = { payload: dict[str, object] = {
"contents": contents, "contents": contents,
"generationConfig": { "generationConfig": {
"temperature": request.temperature, "temperature": request.temperature,
@ -340,7 +357,7 @@ class GeminiProvider(LLMProvider):
system_instruction, contents = self._convert_messages(request.messages) system_instruction, contents = self._convert_messages(request.messages)
payload: dict[str, Any] = { payload: dict[str, object] = {
"contents": contents, "contents": contents,
"generationConfig": { "generationConfig": {
"temperature": request.temperature, "temperature": request.temperature,
@ -374,7 +391,7 @@ class GeminiProvider(LLMProvider):
async def _iterate_stream(self, response, request: LLMRequest): async def _iterate_stream(self, response, request: LLMRequest):
"""Iterate over an already-open SSE stream and yield StreamChunks.""" """Iterate over an already-open SSE stream and yield StreamChunks."""
accumulated_tool_calls: list[dict[str, Any]] = [] accumulated_tool_calls: list[dict[str, object]] = []
model = request.model or self._model model = request.model or self._model
async for line in response.aiter_lines(): async for line in response.aiter_lines():
@ -436,11 +453,13 @@ class GeminiProvider(LLMProvider):
) )
elif "functionCall" in part: elif "functionCall" in part:
fc = part["functionCall"] fc = part["functionCall"]
accumulated_tool_calls.append({ accumulated_tool_calls.append(
"id": f"call_{len(accumulated_tool_calls)}", {
"name": fc.get("name", ""), "id": f"call_{len(accumulated_tool_calls)}",
"arguments": fc.get("args", {}), "name": fc.get("name", ""),
}) "arguments": fc.get("args", {}),
}
)
# Check for finish reason # Check for finish reason
finish_reason = candidates[0].get("finishReason", "") finish_reason = candidates[0].get("finishReason", "")
@ -461,7 +480,7 @@ class GeminiProvider(LLMProvider):
) )
accumulated_tool_calls = [] accumulated_tool_calls = []
def get_model_info(self) -> dict[str, Any]: def get_model_info(self) -> dict[str, object]:
"""返回 Provider 和模型信息""" """返回 Provider 和模型信息"""
return { return {
"provider": "gemini", "provider": "gemini",

View File

@ -26,8 +26,7 @@ import inspect
import json import json
import logging import logging
import time import time
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator, Iterable
from typing import Any
from agentkit.core.exceptions import LLMProviderError from agentkit.core.exceptions import LLMProviderError
from agentkit.llm.protocol import ( from agentkit.llm.protocol import (
@ -81,13 +80,13 @@ class LitellmProvider(LLMProvider):
api_key: str, api_key: str,
base_url: str | None = None, base_url: str | None = None,
provider_type: str = "openai", provider_type: str = "openai",
**default_kwargs: Any, **default_kwargs: object,
) -> None: ) -> None:
self._model_prefix = model_prefix self._model_prefix = model_prefix
self._api_key = api_key self._api_key = api_key
self._base_url = base_url or None # 空字符串视作未设置 self._base_url = base_url or None # 空字符串视作未设置
self._provider_type = provider_type self._provider_type = provider_type
self._default_kwargs: dict[str, Any] = dict(default_kwargs) self._default_kwargs: dict[str, object] = dict(default_kwargs)
async def chat(self, request: LLMRequest) -> LLMResponse: async def chat(self, request: LLMRequest) -> LLMResponse:
"""非流式 chat — 调用 ``litellm.acompletion`` 并翻译响应。""" """非流式 chat — 调用 ``litellm.acompletion`` 并翻译响应。"""
@ -116,7 +115,7 @@ class LitellmProvider(LLMProvider):
kwargs = self._build_kwargs(request, stream=True) kwargs = self._build_kwargs(request, stream=True)
accumulated_tool_calls: dict[int, dict[str, Any]] = {} accumulated_tool_calls: dict[int, dict[str, object]] = {}
final_usage: TokenUsage | None = None final_usage: TokenUsage | None = None
final_model: str = request.model final_model: str = request.model
@ -158,9 +157,9 @@ class LitellmProvider(LLMProvider):
# 内部辅助 # 内部辅助
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def _build_kwargs(self, request: LLMRequest, *, stream: bool) -> dict[str, Any]: def _build_kwargs(self, request: LLMRequest, *, stream: bool) -> dict[str, object]:
"""从 LLMRequest 构造 litellm.acompletion kwargs。""" """从 LLMRequest 构造 litellm.acompletion kwargs。"""
kwargs: dict[str, Any] = { kwargs: dict[str, object] = {
"model": f"{self._model_prefix}{request.model}", "model": f"{self._model_prefix}{request.model}",
"messages": request.messages, "messages": request.messages,
"temperature": request.temperature, "temperature": request.temperature,
@ -187,7 +186,7 @@ class LitellmProvider(LLMProvider):
def _parse_response( def _parse_response(
self, self,
response: Any, response: object,
request_model: str, request_model: str,
latency_ms: float, latency_ms: float,
) -> LLMResponse: ) -> LLMResponse:
@ -229,9 +228,9 @@ class LitellmProvider(LLMProvider):
def _parse_stream_chunk( def _parse_stream_chunk(
self, self,
chunk: Any, chunk: object,
request_model: str, request_model: str,
accumulated_tool_calls: dict[int, dict[str, Any]], accumulated_tool_calls: dict[int, dict[str, object]],
) -> StreamChunk: ) -> StreamChunk:
"""解析单个流式 chunk非 final。累计 tool_calls 到传入字典。""" """解析单个流式 chunk非 final。累计 tool_calls 到传入字典。"""
choices = getattr(chunk, "choices", None) or [] choices = getattr(chunk, "choices", None) or []
@ -262,7 +261,7 @@ class LitellmProvider(LLMProvider):
def _finalize_tool_calls( def _finalize_tool_calls(
self, self,
accumulated: dict[int, dict[str, Any]], accumulated: dict[int, dict[str, object]],
) -> list[ToolCall]: ) -> list[ToolCall]:
"""把累计的流式 tool_calls 字典转成 ToolCall 列表。""" """把累计的流式 tool_calls 字典转成 ToolCall 列表。"""
tool_calls: list[ToolCall] = [] tool_calls: list[ToolCall] = []
@ -288,7 +287,7 @@ class LitellmProvider(LLMProvider):
# ---------------------------------------------------------------------- # ----------------------------------------------------------------------
def _parse_tool_calls(raw_tool_calls: Any) -> list[ToolCall]: def _parse_tool_calls(raw_tool_calls: Iterable[object]) -> list[ToolCall]:
"""解析非流式响应的 tool_callsOpenAI 格式 list[ChoiceMessageToolCall])。""" """解析非流式响应的 tool_callsOpenAI 格式 list[ChoiceMessageToolCall])。"""
result: list[ToolCall] = [] result: list[ToolCall] = []
for tc in raw_tool_calls: for tc in raw_tool_calls:
@ -312,7 +311,7 @@ def _parse_tool_calls(raw_tool_calls: Any) -> list[ToolCall]:
return result return result
def _parse_usage(usage_obj: Any) -> TokenUsage: def _parse_usage(usage_obj: object) -> TokenUsage:
"""解析 usage 对象OpenAI CompletionUsage 或 dict""" """解析 usage 对象OpenAI CompletionUsage 或 dict"""
prompt = getattr(usage_obj, "prompt_tokens", None) prompt = getattr(usage_obj, "prompt_tokens", None)
completion = getattr(usage_obj, "completion_tokens", None) completion = getattr(usage_obj, "completion_tokens", None)
@ -326,8 +325,8 @@ def _parse_usage(usage_obj: Any) -> TokenUsage:
def _accumulate_stream_tool_calls( def _accumulate_stream_tool_calls(
raw_tool_calls: Any, raw_tool_calls: Iterable[object],
accumulated: dict[int, dict[str, Any]], accumulated: dict[int, dict[str, object]],
) -> None: ) -> None:
"""累计流式 chunk 里的 tool_calls 片段OpenAI delta.tool_calls 格式)。 """累计流式 chunk 里的 tool_calls 片段OpenAI delta.tool_calls 格式)。
@ -364,7 +363,7 @@ def create_litellm_provider(
provider_type: str, provider_type: str,
api_key: str, api_key: str,
base_url: str | None = None, base_url: str | None = None,
**kwargs: Any, **kwargs: object,
) -> LitellmProvider: ) -> LitellmProvider:
"""根据 provider_type 创建 LitellmProvider 实例。 """根据 provider_type 创建 LitellmProvider 实例。

View File

@ -18,7 +18,7 @@ import json
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Any, Protocol, runtime_checkable from typing import Protocol, runtime_checkable
from agentkit.llm.protocol import TokenUsage from agentkit.llm.protocol import TokenUsage
@ -294,8 +294,8 @@ class RedisUsageStore:
def __init__(self, redis_url: str = "redis://localhost:6379"): def __init__(self, redis_url: str = "redis://localhost:6379"):
self._redis_url = redis_url self._redis_url = redis_url
self._redis: Any = None self._redis: object | None = None
self._sync_redis: Any = None self._sync_redis: object | None = None
self._fallback: InMemoryUsageStore | None = None self._fallback: InMemoryUsageStore | None = None
self._degraded = False self._degraded = False
self._health_check_task: asyncio.Task[None] | None = None self._health_check_task: asyncio.Task[None] | None = None
@ -687,7 +687,7 @@ class RedisUsageStore:
@staticmethod @staticmethod
def _read_list( def _read_list(
r: Any, r: object,
list_key: str, list_key: str,
start: datetime, start: datetime,
end: datetime, end: datetime,

View File

@ -9,7 +9,6 @@ from __future__ import annotations
import logging import logging
import time import time
from typing import Any
from agentkit.llm.providers.openai import OpenAICompatibleProvider from agentkit.llm.providers.openai import OpenAICompatibleProvider
from agentkit.llm.protocol import LLMRequest, LLMResponse from agentkit.llm.protocol import LLMRequest, LLMResponse
@ -51,7 +50,7 @@ class WenxinProvider(OpenAICompatibleProvider):
secret_key: str | None = None, secret_key: str | None = None,
base_url: str = WENXIN_DEFAULT_BASE_URL, base_url: str = WENXIN_DEFAULT_BASE_URL,
default_model: str = "ernie-4.5-turbo-128k", default_model: str = "ernie-4.5-turbo-128k",
**kwargs: Any, **kwargs: object,
): ):
# If AK/SK provided, use token-based auth # If AK/SK provided, use token-based auth
self._access_key = access_key self._access_key = access_key

View File

@ -8,7 +8,6 @@ API腾讯云 OpenAI 兼容接口
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any
from agentkit.llm.providers.openai import OpenAICompatibleProvider from agentkit.llm.providers.openai import OpenAICompatibleProvider
from agentkit.llm.protocol import LLMRequest, LLMResponse from agentkit.llm.protocol import LLMRequest, LLMResponse
@ -48,7 +47,7 @@ class YuanbaoProvider(OpenAICompatibleProvider):
base_url: str = YUANBAO_DEFAULT_BASE_URL, base_url: str = YUANBAO_DEFAULT_BASE_URL,
default_model: str = "hunyuan-turbos-latest", default_model: str = "hunyuan-turbos-latest",
enable_enhancement: bool = False, enable_enhancement: bool = False,
**kwargs: Any, **kwargs: object,
): ):
self._enable_enhancement = enable_enhancement self._enable_enhancement = enable_enhancement
super().__init__( super().__init__(

View File

@ -3,7 +3,6 @@
import json import json
import logging import logging
from collections.abc import AsyncIterator, Awaitable, Callable from collections.abc import AsyncIterator, Awaitable, Callable
from typing import Any
import httpx import httpx
@ -66,7 +65,7 @@ class RemoteLLMProvider(LLMProvider):
"Content-Type": "application/json", "Content-Type": "application/json",
} }
def _build_payload(self, request: LLMRequest) -> dict[str, Any]: def _build_payload(self, request: LLMRequest) -> dict[str, object]:
"""Convert LLMRequest to server API payload.""" """Convert LLMRequest to server API payload."""
return { return {
"messages": request.messages, "messages": request.messages,
@ -91,7 +90,7 @@ class RemoteLLMProvider(LLMProvider):
return str(body["error"]) return str(body["error"])
return str(body) return str(body)
def _parse_response(self, data: dict[str, Any], request: LLMRequest) -> LLMResponse: def _parse_response(self, data: dict[str, object], request: LLMRequest) -> LLMResponse:
"""Parse server response JSON into an LLMResponse.""" """Parse server response JSON into an LLMResponse."""
usage_data = data.get("usage") or {} usage_data = data.get("usage") or {}
usage = TokenUsage( usage = TokenUsage(
@ -115,7 +114,7 @@ class RemoteLLMProvider(LLMProvider):
latency_ms=data.get("latency_ms", 0.0), latency_ms=data.get("latency_ms", 0.0),
) )
def _parse_chunk(self, data: dict[str, Any], request: LLMRequest) -> StreamChunk: def _parse_chunk(self, data: dict[str, object], request: LLMRequest) -> StreamChunk:
"""Parse a single SSE data payload into a StreamChunk.""" """Parse a single SSE data payload into a StreamChunk."""
usage: TokenUsage | None = None usage: TokenUsage | None = None
usage_data = data.get("usage") usage_data = data.get("usage")
@ -218,9 +217,7 @@ class RemoteLLMProvider(LLMProvider):
if response.status_code == 502: if response.status_code == 502:
await response.aread() await response.aread()
detail = self._extract_error_detail(response) detail = self._extract_error_detail(response)
raise LLMProviderError( raise LLMProviderError("remote", f"Server LLM gateway error: {detail}")
"remote", f"Server LLM gateway error: {detail}"
)
if response.status_code != 200: if response.status_code != 200:
await response.aread() await response.aread()
raise LLMProviderError( raise LLMProviderError(

View File

@ -5,7 +5,7 @@ import logging
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import Any, Callable from typing import Callable
from agentkit.core.exceptions import LLMProviderError from agentkit.core.exceptions import LLMProviderError
@ -20,9 +20,7 @@ class RetryConfig:
base_delay: float = 1.0 base_delay: float = 1.0
max_delay: float = 30.0 max_delay: float = 30.0
exponential_base: float = 2.0 exponential_base: float = 2.0
retryable_status_codes: set[int] = field( retryable_status_codes: set[int] = field(default_factory=lambda: {429, 500, 502, 503, 529})
default_factory=lambda: {429, 500, 502, 503, 529}
)
class CircuitState(Enum): class CircuitState(Enum):
@ -69,7 +67,7 @@ class RetryPolicy:
def __init__(self, config: RetryConfig | None = None): def __init__(self, config: RetryConfig | None = None):
self._config = config or RetryConfig() self._config = config or RetryConfig()
async def execute(self, fn: Callable, *args: Any, **kwargs: Any) -> Any: async def execute(self, fn: Callable, *args: object, **kwargs: object) -> object:
"""Execute fn with retry on retryable errors.""" """Execute fn with retry on retryable errors."""
last_error: Exception | None = None last_error: Exception | None = None
@ -84,7 +82,7 @@ class RetryPolicy:
raise raise
delay = min( delay = min(
self._config.base_delay * (self._config.exponential_base ** attempt), self._config.base_delay * (self._config.exponential_base**attempt),
self._config.max_delay, self._config.max_delay,
) )
logger.warning( logger.warning(
@ -142,7 +140,7 @@ class CircuitBreaker:
f"after {self._failure_count} failures" f"after {self._failure_count} failures"
) )
async def execute(self, fn: Callable, *args: Any, **kwargs: Any) -> Any: async def execute(self, fn: Callable, *args: object, **kwargs: object) -> object:
"""Execute fn through the circuit breaker.""" """Execute fn through the circuit breaker."""
current_state = self.state current_state = self.state
@ -158,6 +156,6 @@ class CircuitBreaker:
result = await fn(*args, **kwargs) result = await fn(*args, **kwargs)
self._on_success() self._on_success()
return result return result
except Exception as e: except Exception:
self._on_failure() self._on_failure()
raise raise

View File

@ -10,7 +10,6 @@ from __future__ import annotations
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any
import httpx import httpx
@ -95,8 +94,7 @@ class KBAdapter(ABC):
async def delete_by_id(self, id: str) -> bool: async def delete_by_id(self, id: str) -> bool:
"""按文档 ID 删除(子类可覆盖)""" """按文档 ID 删除(子类可覆盖)"""
logger.warning( logger.warning(
f"{self.__class__.__name__} does not support delete_by_id; " f"{self.__class__.__name__} does not support delete_by_id; id '{id}' skipped"
f"id '{id}' skipped"
) )
return False return False
@ -127,8 +125,7 @@ class KBAdapter(ABC):
async def get_document(self, doc_id: str) -> Document | None: async def get_document(self, doc_id: str) -> Document | None:
"""按 ID 获取单个文档(子类可覆盖)""" """按 ID 获取单个文档(子类可覆盖)"""
logger.warning( logger.warning(
f"{self.__class__.__name__} does not support get_document; " f"{self.__class__.__name__} does not support get_document; doc_id '{doc_id}' not found"
f"doc_id '{doc_id}' not found"
) )
return None return None
@ -156,5 +153,5 @@ class KBAdapter(ABC):
async def __aenter__(self) -> KBAdapter: async def __aenter__(self) -> KBAdapter:
return self return self
async def __aexit__(self, *args: Any) -> None: async def __aexit__(self, *args: object) -> None:
await self.close() await self.close()

View File

@ -7,7 +7,6 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any
import httpx import httpx
@ -56,7 +55,9 @@ class ConfluenceAdapter(KBAdapter):
) )
self._base_url = base_url.rstrip("/") self._base_url = base_url.rstrip("/")
if not is_safe_url(self._base_url): if not is_safe_url(self._base_url):
raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.") raise ValueError(
f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed."
)
self._username = username self._username = username
self._api_token = api_token self._api_token = api_token
self._space_keys = space_keys or [] self._space_keys = space_keys or []
@ -65,9 +66,7 @@ class ConfluenceAdapter(KBAdapter):
"""创建 Confluence API HTTP 客户端""" """创建 Confluence API HTTP 客户端"""
import base64 import base64
credentials = base64.b64encode( credentials = base64.b64encode(f"{self._username}:{self._api_token}".encode()).decode()
f"{self._username}:{self._api_token}".encode()
).decode()
return httpx.AsyncClient( return httpx.AsyncClient(
base_url=self._base_url, base_url=self._base_url,
headers={ headers={
@ -101,7 +100,7 @@ class ConfluenceAdapter(KBAdapter):
space_filter = " OR ".join( space_filter = " OR ".join(
f'space = "{_escape_cql(key)}"' for key in self._space_keys f'space = "{_escape_cql(key)}"' for key in self._space_keys
) )
cql = f'{cql} AND ({space_filter})' cql = f"{cql} AND ({space_filter})"
resp = await client.get( resp = await client.get(
"/rest/api/content/search", "/rest/api/content/search",
@ -115,6 +114,7 @@ class ConfluenceAdapter(KBAdapter):
body = page.get("body", {}).get("storage", {}).get("value", "") body = page.get("body", {}).get("storage", {}).get("value", "")
# Strip HTML tags for plain text content # Strip HTML tags for plain text content
import re import re
content = re.sub(r"<[^>]+>", "", body) if body else page.get("title", "") content = re.sub(r"<[^>]+>", "", body) if body else page.get("title", "")
results.append( results.append(
@ -136,8 +136,7 @@ class ConfluenceAdapter(KBAdapter):
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.error( logger.error(
f"Confluence search HTTP error: {e.response.status_code}" f"Confluence search HTTP error: {e.response.status_code}{e.response.text[:200]}"
f"{e.response.text[:200]}"
) )
return [] return []
except Exception as e: except Exception as e:
@ -157,6 +156,7 @@ class ConfluenceAdapter(KBAdapter):
body = page.get("body", {}).get("storage", {}).get("value", "") body = page.get("body", {}).get("storage", {}).get("value", "")
import re import re
content = re.sub(r"<[^>]+>", "", body) if body else "" content = re.sub(r"<[^>]+>", "", body) if body else ""
return Document( return Document(
@ -191,13 +191,17 @@ class ConfluenceAdapter(KBAdapter):
source_type="confluence", source_type="confluence",
) )
) )
return sources if sources else [ return (
SourceInfo( sources
source_id=self._source_id, if sources
source_name=self._source_name, else [
source_type=self._source_type, SourceInfo(
) source_id=self._source_id,
] source_name=self._source_name,
source_type=self._source_type,
)
]
)
except Exception as e: except Exception as e:
logger.error(f"Confluence list_sources error: {e}") logger.error(f"Confluence list_sources error: {e}")
return [ return [

View File

@ -8,7 +8,7 @@ from __future__ import annotations
import logging import logging
import time import time
from typing import Any from typing import TypeAlias
import httpx import httpx
@ -18,6 +18,9 @@ from agentkit.utils.security import is_safe_url
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# 飞书搜索请求 payloadsearch_key/page_size/wiki_space_ids — 值为 str|int|list[str]。
FeishuSearchPayload: TypeAlias = dict[str, object]
class FeishuKBAdapter(KBAdapter): class FeishuKBAdapter(KBAdapter):
"""飞书知识库适配器 """飞书知识库适配器
@ -54,7 +57,9 @@ class FeishuKBAdapter(KBAdapter):
self._app_secret = app_secret self._app_secret = app_secret
self._base_url = base_url.rstrip("/") self._base_url = base_url.rstrip("/")
if not is_safe_url(self._base_url): if not is_safe_url(self._base_url):
raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.") raise ValueError(
f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed."
)
self._space_ids = space_ids or [] self._space_ids = space_ids or []
self._access_token: str | None = None self._access_token: str | None = None
self._token_expiry: float = 0.0 self._token_expiry: float = 0.0
@ -94,10 +99,7 @@ class FeishuKBAdapter(KBAdapter):
self._client = None self._client = None
return self._access_token return self._access_token
else: else:
logger.error( logger.error(f"Feishu auth failed: code={data.get('code')}, msg={data.get('msg')}")
f"Feishu auth failed: code={data.get('code')}, "
f"msg={data.get('msg')}"
)
return None return None
except Exception as e: except Exception as e:
logger.error(f"Feishu auth error: {e}") logger.error(f"Feishu auth error: {e}")
@ -121,7 +123,7 @@ class FeishuKBAdapter(KBAdapter):
client = self._get_client() client = self._get_client()
try: try:
payload: dict[str, Any] = { payload: FeishuSearchPayload = {
"search_key": query, "search_key": query,
"page_size": top_k, "page_size": top_k,
} }
@ -137,8 +139,7 @@ class FeishuKBAdapter(KBAdapter):
if data.get("code") != 0: if data.get("code") != 0:
logger.error( logger.error(
f"Feishu search failed: code={data.get('code')}, " f"Feishu search failed: code={data.get('code')}, msg={data.get('msg')}"
f"msg={data.get('msg')}"
) )
return [] return []
@ -162,8 +163,7 @@ class FeishuKBAdapter(KBAdapter):
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.error( logger.error(
f"Feishu search HTTP error: {e.response.status_code}" f"Feishu search HTTP error: {e.response.status_code}{e.response.text[:200]}"
f"{e.response.text[:200]}"
) )
return [] return []
except Exception as e: except Exception as e:
@ -179,7 +179,7 @@ class FeishuKBAdapter(KBAdapter):
client = self._get_client() client = self._get_client()
try: try:
resp = await client.get( resp = await client.get(
f"/wiki/v2/spaces/get_node", "/wiki/v2/spaces/get_node",
params={"token": doc_id}, params={"token": doc_id},
) )
resp.raise_for_status() resp.raise_for_status()
@ -230,13 +230,17 @@ class FeishuKBAdapter(KBAdapter):
source_type="feishu", source_type="feishu",
) )
) )
return sources if sources else [ return (
SourceInfo( sources
source_id=self._source_id, if sources
source_name=self._source_name, else [
source_type=self._source_type, SourceInfo(
) source_id=self._source_id,
] source_name=self._source_name,
source_type=self._source_type,
)
]
)
except Exception as e: except Exception as e:
logger.error(f"Feishu list_sources error: {e}") logger.error(f"Feishu list_sources error: {e}")
return [ return [

View File

@ -7,7 +7,6 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any
import httpx import httpx
@ -55,7 +54,9 @@ class GenericHTTPAdapter(KBAdapter):
) )
self._endpoint_url = endpoint_url.rstrip("/") self._endpoint_url = endpoint_url.rstrip("/")
if not is_safe_url(self._endpoint_url): if not is_safe_url(self._endpoint_url):
raise ValueError(f"Unsafe endpoint_url: {self._endpoint_url}. Private/internal URLs are not allowed.") raise ValueError(
f"Unsafe endpoint_url: {self._endpoint_url}. Private/internal URLs are not allowed."
)
self._auth_config = auth_config or {} self._auth_config = auth_config or {}
self._extra_headers = headers or {} self._extra_headers = headers or {}
@ -74,12 +75,11 @@ class GenericHTTPAdapter(KBAdapter):
headers["Authorization"] = f"Bearer {token}" headers["Authorization"] = f"Bearer {token}"
elif auth_type == "basic": elif auth_type == "basic":
import base64 import base64
username = self._auth_config.get("username", "") username = self._auth_config.get("username", "")
password = self._auth_config.get("password", "") password = self._auth_config.get("password", "")
if username and password: if username and password:
credentials = base64.b64encode( credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
f"{username}:{password}".encode()
).decode()
headers["Authorization"] = f"Basic {credentials}" headers["Authorization"] = f"Basic {credentials}"
elif auth_type == "api_key": elif auth_type == "api_key":
key_name = self._auth_config.get("header_name", "X-API-Key") key_name = self._auth_config.get("header_name", "X-API-Key")
@ -135,8 +135,7 @@ class GenericHTTPAdapter(KBAdapter):
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.error( logger.error(
f"GenericHTTP search HTTP error: {e.response.status_code}" f"GenericHTTP search HTTP error: {e.response.status_code}{e.response.text[:200]}"
f"{e.response.text[:200]}"
) )
return [] return []
except Exception as e: except Exception as e:
@ -177,8 +176,7 @@ class GenericHTTPAdapter(KBAdapter):
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.error( logger.error(
f"GenericHTTP ingest HTTP error: {e.response.status_code}" f"GenericHTTP ingest HTTP error: {e.response.status_code}{e.response.text[:200]}"
f"{e.response.text[:200]}"
) )
return [] return []
except Exception as e: except Exception as e:
@ -245,13 +243,17 @@ class GenericHTTPAdapter(KBAdapter):
document_count=item.get("document_count", 0), document_count=item.get("document_count", 0),
) )
) )
return sources if sources else [ return (
SourceInfo( sources
source_id=self._source_id, if sources
source_name=self._source_name, else [
source_type=self._source_type, SourceInfo(
) source_id=self._source_id,
] source_name=self._source_name,
source_type=self._source_type,
)
]
)
except Exception as e: except Exception as e:
logger.debug(f"GenericHTTP list_sources error (endpoint may not exist): {e}") logger.debug(f"GenericHTTP list_sources error (endpoint may not exist): {e}")

View File

@ -3,19 +3,29 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any from typing import TypeAlias
# 共享类型别名 — 跨 memory 子系统复用,避免 `Any` 残留。
# MetadataValue 覆盖 metadata dict 中实际出现的原始类型;
# MemoryValue 额外允许 dict/list 容器以容纳结构化负载(如 episodic 经验字典)。
MetadataValue: TypeAlias = str | int | float | bool | None
MetadataDict: TypeAlias = dict[str, MetadataValue]
MemoryValue: TypeAlias = (
str | int | float | bool | None | dict[str, MetadataValue] | list[MetadataValue]
)
@dataclass @dataclass
class MemoryItem: class MemoryItem:
"""记忆条目""" """记忆条目"""
key: str key: str
value: Any value: object
metadata: dict[str, Any] = field(default_factory=dict) metadata: MetadataDict = field(default_factory=dict)
score: float = 1.0 score: float = 1.0
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
def to_dict(self) -> dict: def to_dict(self) -> dict[str, object]:
return { return {
"key": self.key, "key": self.key,
"value": self.value, "value": self.value,
@ -35,7 +45,7 @@ class Memory(ABC):
""" """
@abstractmethod @abstractmethod
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None: async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
"""存储记忆""" """存储记忆"""
... ...
@ -45,7 +55,9 @@ class Memory(ABC):
... ...
@abstractmethod @abstractmethod
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]: async def search(
self, query: str, top_k: int = 5, filters: MetadataDict | None = None
) -> list[MemoryItem]:
"""语义检索""" """语义检索"""
... ...
@ -54,7 +66,7 @@ class Memory(ABC):
"""删除记忆""" """删除记忆"""
... ...
async def store_batch(self, items: list[tuple[str, Any, dict | None]]) -> None: async def store_batch(self, items: list[tuple[str, object, MetadataDict | None]]) -> None:
"""批量存储""" """批量存储"""
for key, value, metadata in items: for key, value, metadata in items:
await self.store(key, value, metadata) await self.store(key, value, metadata)

View File

@ -11,10 +11,18 @@ import logging
import re import re
import uuid import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import TypeAlias
from agentkit.memory.base import MetadataDict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# 分块元数据source_doc/position/char_count/chunking_strategy/heading/heading_level
# — 全部为原始标量str/int
ChunkMetadata: TypeAlias = MetadataDict
# _split_by_headings 返回的节段结构。
SectionInfo: TypeAlias = dict[str, str | int]
@dataclass @dataclass
class Chunk: class Chunk:
@ -22,7 +30,7 @@ class Chunk:
chunk_id: str chunk_id: str
content: str content: str
metadata: dict[str, Any] = field(default_factory=dict) metadata: ChunkMetadata = field(default_factory=dict)
def __post_init__(self) -> None: def __post_init__(self) -> None:
if "source_doc" not in self.metadata: if "source_doc" not in self.metadata:
@ -30,7 +38,7 @@ class Chunk:
if "position" not in self.metadata: if "position" not in self.metadata:
self.metadata["position"] = 0 self.metadata["position"] = 0
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, object]:
return { return {
"chunk_id": self.chunk_id, "chunk_id": self.chunk_id,
"content": self.content, "content": self.content,
@ -57,7 +65,9 @@ class TextChunker:
separator: 优先分割符 separator: 优先分割符
""" """
if chunk_overlap >= chunk_size: if chunk_overlap >= chunk_size:
raise ValueError(f"chunk_overlap ({chunk_overlap}) must be less than chunk_size ({chunk_size})") raise ValueError(
f"chunk_overlap ({chunk_overlap}) must be less than chunk_size ({chunk_size})"
)
self._chunk_size = chunk_size self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap self._chunk_overlap = chunk_overlap
self._separator = separator self._separator = separator
@ -66,7 +76,7 @@ class TextChunker:
self, self,
text: str, text: str,
source_doc_id: str = "", source_doc_id: str = "",
metadata: dict[str, Any] | None = None, metadata: ChunkMetadata | None = None,
) -> list[Chunk]: ) -> list[Chunk]:
"""将文本分块 """将文本分块
@ -96,11 +106,13 @@ class TextChunker:
chunk_meta = dict(base_meta) chunk_meta = dict(base_meta)
chunk_meta["position"] = i chunk_meta["position"] = i
chunk_meta["char_count"] = len(chunk_text) chunk_meta["char_count"] = len(chunk_text)
chunks.append(Chunk( chunks.append(
chunk_id=str(uuid.uuid4()), Chunk(
content=chunk_text, chunk_id=str(uuid.uuid4()),
metadata=chunk_meta, content=chunk_text,
)) metadata=chunk_meta,
)
)
return chunks return chunks
@ -142,7 +154,9 @@ class TextChunker:
overlap_text[overlap_start:], segments overlap_text[overlap_start:], segments
) )
current = overlap_segments current = overlap_segments
current_len = sum(len(s) for s in current) + len(self._separator) * max(0, len(current) - 1) current_len = sum(len(s) for s in current) + len(self._separator) * max(
0, len(current) - 1
)
current.append(segment) current.append(segment)
current_len += seg_len + len(self._separator) current_len += seg_len + len(self._separator)
@ -214,7 +228,7 @@ class StructuralChunker:
self, self,
text: str, text: str,
source_doc_id: str = "", source_doc_id: str = "",
metadata: dict[str, Any] | None = None, metadata: ChunkMetadata | None = None,
) -> list[Chunk]: ) -> list[Chunk]:
"""将文本按结构分块 """将文本按结构分块
@ -266,23 +280,25 @@ class StructuralChunker:
chunk_meta["heading"] = heading chunk_meta["heading"] = heading
chunk_meta["heading_level"] = level chunk_meta["heading_level"] = level
chunk_meta["char_count"] = len(content) chunk_meta["char_count"] = len(content)
chunks.append(Chunk( chunks.append(
chunk_id=str(uuid.uuid4()), Chunk(
content=content, chunk_id=str(uuid.uuid4()),
metadata=chunk_meta, content=content,
)) metadata=chunk_meta,
)
)
position += 1 position += 1
return chunks return chunks
def _split_by_headings(self, text: str) -> list[dict[str, Any]]: def _split_by_headings(self, text: str) -> list[SectionInfo]:
"""按标题分割 Markdown 文本 """按标题分割 Markdown 文本
Returns: Returns:
列表每项包含 heading, content, level 列表每项包含 heading, content, level
""" """
lines = text.split("\n") lines = text.split("\n")
sections: list[dict[str, Any]] = [] sections: list[SectionInfo] = []
current_heading = "" current_heading = ""
current_level = 0 current_level = 0
current_lines: list[str] = [] current_lines: list[str] = []
@ -296,11 +312,13 @@ class StructuralChunker:
if current_lines: if current_lines:
content = "\n".join(current_lines).strip() content = "\n".join(current_lines).strip()
if content: if content:
sections.append({ sections.append(
"heading": current_heading, {
"content": content, "heading": current_heading,
"level": current_level, "content": content,
}) "level": current_level,
}
)
# 开始新节 # 开始新节
current_heading = match.group(2).strip() current_heading = match.group(2).strip()
@ -313,18 +331,22 @@ class StructuralChunker:
if current_lines: if current_lines:
content = "\n".join(current_lines).strip() content = "\n".join(current_lines).strip()
if content: if content:
sections.append({ sections.append(
"heading": current_heading, {
"content": content, "heading": current_heading,
"level": current_level, "content": content,
}) "level": current_level,
}
)
# 如果没有标题结构,整体作为一个块 # 如果没有标题结构,整体作为一个块
if not sections: if not sections:
sections.append({ sections.append(
"heading": "", {
"content": text.strip(), "heading": "",
"level": 0, "content": text.strip(),
}) "level": 0,
}
)
return sections return sections

View File

@ -9,10 +9,14 @@ from __future__ import annotations
import hashlib import hashlib
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import TYPE_CHECKING
from agentkit.memory.base import MetadataDict
from agentkit.memory.embedder import EmbeddingCache from agentkit.memory.embedder import EmbeddingCache
if TYPE_CHECKING:
from agentkit.llm.gateway import LLMGateway
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -24,7 +28,7 @@ class ContextualChunk:
context_prefix: str context_prefix: str
enhanced_content: str enhanced_content: str
chunk_index: int chunk_index: int
metadata: dict[str, Any] metadata: MetadataDict
@property @property
def content(self) -> str: def content(self) -> str:
@ -65,7 +69,7 @@ class ContextualChunker:
def __init__( def __init__(
self, self,
llm_gateway: Any = None, llm_gateway: LLMGateway | None = None,
cache: EmbeddingCache | None = None, cache: EmbeddingCache | None = None,
batch_size: int = 8, batch_size: int = 8,
max_context_length: int = 200, max_context_length: int = 200,
@ -90,7 +94,7 @@ class ContextualChunker:
self, self,
document: str, document: str,
chunks: list[str], chunks: list[str],
metadata: dict[str, Any] | None = None, metadata: MetadataDict | None = None,
) -> list[ContextualChunk]: ) -> list[ContextualChunk]:
"""为文档块添加上下文前缀 """为文档块添加上下文前缀
@ -134,7 +138,7 @@ class ContextualChunker:
document: str, document: str,
chunks: list[str], chunks: list[str],
start_index: int, start_index: int,
metadata: dict[str, Any] | None, metadata: MetadataDict | None,
) -> list[ContextualChunk]: ) -> list[ContextualChunk]:
"""处理一批文档块""" """处理一批文档块"""
results: list[ContextualChunk] = [] results: list[ContextualChunk] = []

View File

@ -12,7 +12,9 @@ import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Any from typing import TypeAlias
from agentkit.memory.base import MetadataDict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -23,6 +25,11 @@ MAX_CONTENT_SIZE = 100 * 1024 * 1024 # 100MB
MAX_ROWS_PER_SHEET = 10_000 MAX_ROWS_PER_SHEET = 10_000
MAX_CELL_CHARS = 10_000 MAX_CELL_CHARS = 10_000
# 文档元数据source/format/parser/page_count/table_count/sheet_count/row_count/
# heading_count/created_at/title/truncated — 全部为原始标量。
DocumentMetadata: TypeAlias = MetadataDict
ParseResult: TypeAlias = tuple[str, DocumentMetadata]
@dataclass @dataclass
class Document: class Document:
@ -31,7 +38,7 @@ class Document:
doc_id: str doc_id: str
title: str title: str
content: str content: str
metadata: dict[str, Any] = field(default_factory=dict) metadata: DocumentMetadata = field(default_factory=dict)
def __post_init__(self) -> None: def __post_init__(self) -> None:
if "source" not in self.metadata: if "source" not in self.metadata:
@ -43,7 +50,7 @@ class Document:
if "created_at" not in self.metadata: if "created_at" not in self.metadata:
self.metadata["created_at"] = datetime.now(timezone.utc).isoformat() self.metadata["created_at"] = datetime.now(timezone.utc).isoformat()
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, object]:
return { return {
"doc_id": self.doc_id, "doc_id": self.doc_id,
"title": self.title, "title": self.title,
@ -136,12 +143,14 @@ class DocumentLoader:
parser = parsers.get(doc_format) parser = parsers.get(doc_format)
if parser is None: if parser is None:
logger.warning(f"Unsupported format '{doc_format}' for {filename}, falling back to text") logger.warning(
f"Unsupported format '{doc_format}' for {filename}, falling back to text"
)
parser = self._parse_text parser = self._parse_text
text, extra_meta = parser(content, filename) text, extra_meta = parser(content, filename)
metadata: dict[str, Any] = { metadata: DocumentMetadata = {
"source": filename, "source": filename,
"format": doc_format, "format": doc_format,
"created_at": datetime.now(timezone.utc).isoformat(), "created_at": datetime.now(timezone.utc).isoformat(),
@ -159,7 +168,7 @@ class DocumentLoader:
metadata=metadata, metadata=metadata,
) )
def _parse_pdf(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]: def _parse_pdf(self, content: bytes, filename: str) -> ParseResult:
"""解析 PDF 文件 """解析 PDF 文件
优先使用 PyMuPDF (fitz)回退到 pdfplumber最终回退到纯文本 优先使用 PyMuPDF (fitz)回退到 pdfplumber最终回退到纯文本
@ -215,7 +224,7 @@ class DocumentLoader:
logger.warning(f"No PDF parser available for {filename}, falling back to text extraction") logger.warning(f"No PDF parser available for {filename}, falling back to text extraction")
return self._parse_text(content, filename) return self._parse_text(content, filename)
def _parse_docx(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]: def _parse_docx(self, content: bytes, filename: str) -> ParseResult:
"""解析 Word 文件 """解析 Word 文件
使用 python-docx回退到纯文本 使用 python-docx回退到纯文本
@ -259,7 +268,7 @@ class DocumentLoader:
logger.warning(f"python-docx parsing failed for {filename}: {e}") logger.warning(f"python-docx parsing failed for {filename}: {e}")
return self._parse_text(content, filename) return self._parse_text(content, filename)
def _parse_xlsx(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]: def _parse_xlsx(self, content: bytes, filename: str) -> ParseResult:
"""解析 Excel 文件 """解析 Excel 文件
使用 openpyxl回退到纯文本每个 sheet 转为 Markdown 表格 使用 openpyxl回退到纯文本每个 sheet 转为 Markdown 表格
@ -313,7 +322,7 @@ class DocumentLoader:
finally: finally:
wb.close() wb.close()
text = "\n".join(sections).strip() text = "\n".join(sections).strip()
meta: dict[str, Any] = { meta: DocumentMetadata = {
"parser": "openpyxl", "parser": "openpyxl",
"sheet_count": sheet_count, "sheet_count": sheet_count,
"row_count": total_rows, "row_count": total_rows,
@ -328,7 +337,7 @@ class DocumentLoader:
logger.warning(f"openpyxl parsing failed for {filename}: {e}") logger.warning(f"openpyxl parsing failed for {filename}: {e}")
return self._parse_text(content, filename) return self._parse_text(content, filename)
def _parse_markdown(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]: def _parse_markdown(self, content: bytes, filename: str) -> ParseResult:
"""解析 Markdown 文件 """解析 Markdown 文件
使用 mistune如果可用否则直接读取文本 使用 mistune如果可用否则直接读取文本
@ -347,7 +356,7 @@ class DocumentLoader:
title = line_stripped.lstrip("#").strip() title = line_stripped.lstrip("#").strip()
break break
meta: dict[str, Any] = { meta: DocumentMetadata = {
"parser": "markdown", "parser": "markdown",
} }
if title: if title:
@ -362,7 +371,7 @@ class DocumentLoader:
return text, meta return text, meta
def _parse_html(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]: def _parse_html(self, content: bytes, filename: str) -> ParseResult:
"""解析 HTML 文件 """解析 HTML 文件
使用 BeautifulSoup 提取文本回退到纯文本 使用 BeautifulSoup 提取文本回退到纯文本
@ -388,7 +397,7 @@ class DocumentLoader:
if soup.title and soup.title.string: if soup.title and soup.title.string:
title = soup.title.string.strip() title = soup.title.string.strip()
meta: dict[str, Any] = { meta: DocumentMetadata = {
"parser": "beautifulsoup", "parser": "beautifulsoup",
} }
if title: if title:
@ -402,7 +411,7 @@ class DocumentLoader:
logger.warning(f"BeautifulSoup parsing failed for {filename}: {e}") logger.warning(f"BeautifulSoup parsing failed for {filename}: {e}")
return self._parse_text(content, filename) return self._parse_text(content, filename)
def _parse_text(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]: def _parse_text(self, content: bytes, filename: str) -> ParseResult:
"""解析纯文本文件""" """解析纯文本文件"""
try: try:
text = content.decode("utf-8") text = content.decode("utf-8")

View File

@ -1,12 +1,17 @@
"""Embedder 接口与实现 - 文本向量化""" """Embedder 接口与实现 - 文本向量化"""
from __future__ import annotations
import hashlib import hashlib
import logging import logging
import os import os
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import OrderedDict from collections import OrderedDict
from typing import Any from typing import TYPE_CHECKING
if TYPE_CHECKING:
import httpx
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -97,13 +102,14 @@ class OpenAIEmbedder(Embedder):
self._model = model self._model = model
self._base_url = base_url self._base_url = base_url
self._dimension = 1536 # text-embedding-3-small 默认维度 self._dimension = 1536 # text-embedding-3-small 默认维度
self._client: Any = None self._client: httpx.AsyncClient | None = None
self._cache = cache self._cache = cache
def _get_client(self): def _get_client(self) -> httpx.AsyncClient:
"""Lazily create and reuse a single httpx.AsyncClient.""" """Lazily create and reuse a single httpx.AsyncClient."""
if self._client is None: if self._client is None:
import httpx import httpx
self._client = httpx.AsyncClient(timeout=30.0) self._client = httpx.AsyncClient(timeout=30.0)
return self._client return self._client

View File

@ -1,17 +1,22 @@
"""Episodic Memory - 基于 pgvector + PostgreSQL 的任务经验记忆""" """Episodic Memory - 基于 pgvector + PostgreSQL 的任务经验记忆"""
from __future__ import annotations
import json import json
import logging import logging
import math import math
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any from typing import TYPE_CHECKING
from sqlalchemy import text from sqlalchemy import text
from agentkit.memory.base import Memory, MemoryItem from agentkit.memory.base import Memory, MemoryItem, MetadataDict
from agentkit.memory.embedder import Embedder from agentkit.memory.embedder import Embedder
from agentkit.utils.vector_math import compute_cosine_similarity from agentkit.utils.vector_math import compute_cosine_similarity
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -28,8 +33,8 @@ class EpisodicMemory(Memory):
def __init__( def __init__(
self, self,
session_factory: Any, session_factory: object,
episodic_model: Any, episodic_model: object,
embedder: Embedder | None = None, embedder: Embedder | None = None,
decay_rate: float = 0.01, decay_rate: float = 0.01,
alpha: float = 0.7, alpha: float = 0.7,
@ -57,7 +62,7 @@ class EpisodicMemory(Memory):
self._pgvector_enabled = pgvector_enabled self._pgvector_enabled = pgvector_enabled
self._table_name = table_name self._table_name = table_name
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None: async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
"""存储任务经验""" """存储任务经验"""
async with self._session_factory() as db: async with self._session_factory() as db:
try: try:
@ -68,7 +73,11 @@ class EpisodicMemory(Memory):
embedding = None embedding = None
if self._embedder: if self._embedder:
if isinstance(value, dict): if isinstance(value, dict):
text = value.get("output_summary", "") or value.get("input_summary", "") or json.dumps(value, ensure_ascii=False)[:500] text = (
value.get("output_summary", "")
or value.get("input_summary", "")
or json.dumps(value, ensure_ascii=False)[:500]
)
else: else:
text = str(value) text = str(value)
embedding = await self._embedder.embed(text) embedding = await self._embedder.embed(text)
@ -106,13 +115,11 @@ class EpisodicMemory(Memory):
logger.error(f"Failed to retrieve episodic memory: {e}") logger.error(f"Failed to retrieve episodic memory: {e}")
return None return None
async def _retrieve_pgvector(self, db: Any, query_embedding: list[float]) -> MemoryItem | None: async def _retrieve_pgvector(
self, db: AsyncSession, query_embedding: list[float]
) -> MemoryItem | None:
"""使用 pgvector ``<=>`` 算符检索最相似条目""" """使用 pgvector ``<=>`` 算符检索最相似条目"""
sql = text( sql = text(f"SELECT * FROM {self._table_name} ORDER BY embedding <=> :query_vec LIMIT :lim")
f"SELECT * FROM {self._table_name} "
f"ORDER BY embedding <=> :query_vec "
f"LIMIT :lim"
)
result = await db.execute(sql, {"query_vec": str(query_embedding), "lim": 1}) result = await db.execute(sql, {"query_vec": str(query_embedding), "lim": 1})
row = result.mappings().first() row = result.mappings().first()
@ -147,7 +154,9 @@ class EpisodicMemory(Memory):
created_at=row.get("created_at") or datetime.now(timezone.utc), created_at=row.get("created_at") or datetime.now(timezone.utc),
) )
async def _retrieve_client_side(self, db: Any, query_embedding: list[float]) -> MemoryItem | None: async def _retrieve_client_side(
self, db: AsyncSession, query_embedding: list[float]
) -> MemoryItem | None:
"""客户端 O(N) cosine similarity 检索(回退路径)""" """客户端 O(N) cosine similarity 检索(回退路径)"""
Model = self._episodic_model Model = self._episodic_model
from sqlalchemy import select from sqlalchemy import select
@ -193,7 +202,13 @@ class EpisodicMemory(Memory):
created_at=best_item.created_at or datetime.now(timezone.utc), created_at=best_item.created_at or datetime.now(timezone.utc),
) )
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None, search_multiplier: int = 5) -> list[MemoryItem]: async def search(
self,
query: str,
top_k: int = 5,
filters: MetadataDict | None = None,
search_multiplier: int = 5,
) -> list[MemoryItem]:
"""语义检索相似历史案例 """语义检索相似历史案例
Args: Args:
@ -214,10 +229,10 @@ class EpisodicMemory(Memory):
async def _search_pgvector( async def _search_pgvector(
self, self,
db: Any, db: AsyncSession,
query: str, query: str,
top_k: int, top_k: int,
filters: dict[str, Any] | None, filters: MetadataDict | None,
search_multiplier: int, search_multiplier: int,
) -> list[MemoryItem]: ) -> list[MemoryItem]:
"""使用 pgvector ``<=>`` 算符检索,再 Python 侧 time_decay 重排""" """使用 pgvector ``<=>`` 算符检索,再 Python 侧 time_decay 重排"""
@ -225,7 +240,7 @@ class EpisodicMemory(Memory):
fetch_limit = top_k * search_multiplier fetch_limit = top_k * search_multiplier
where_clauses = [] where_clauses = []
params: dict[str, Any] = {"query_vec": str(query_embedding), "lim": fetch_limit} params: dict[str, object] = {"query_vec": str(query_embedding), "lim": fetch_limit}
filters = filters or {} filters = filters or {}
if filters.get("agent_name"): if filters.get("agent_name"):
@ -256,7 +271,11 @@ class EpisodicMemory(Memory):
items = [] items = []
for row in rows: for row in rows:
row_embedding = row.get("embedding") row_embedding = row.get("embedding")
age_hours = (datetime.now(timezone.utc) - row["created_at"]).total_seconds() / 3600 if row.get("created_at") else 0 age_hours = (
(datetime.now(timezone.utc) - row["created_at"]).total_seconds() / 3600
if row.get("created_at")
else 0
)
decay = math.exp(-self._decay_rate * age_hours) decay = math.exp(-self._decay_rate * age_hours)
time_decay_score = (row.get("quality_score") or 0.5) * decay time_decay_score = (row.get("quality_score") or 0.5) * decay
@ -266,33 +285,37 @@ class EpisodicMemory(Memory):
else: else:
score = time_decay_score score = time_decay_score
items.append(MemoryItem( items.append(
key=str(row.get("id", "")), MemoryItem(
value={ key=str(row.get("id", "")),
"input_summary": row.get("input_summary", ""), value={
"output_summary": row.get("output_summary", ""), "input_summary": row.get("input_summary", ""),
"outcome": row.get("outcome", "success"), "output_summary": row.get("output_summary", ""),
"quality_score": row.get("quality_score", 0.5), "outcome": row.get("outcome", "success"),
"reflection": row.get("reflection", ""), "quality_score": row.get("quality_score", 0.5),
}, "reflection": row.get("reflection", ""),
metadata={ },
"agent_name": row.get("agent_name", ""), metadata={
"task_type": row.get("task_type", ""), "agent_name": row.get("agent_name", ""),
"created_at": row["created_at"].isoformat() if row.get("created_at") else None, "task_type": row.get("task_type", ""),
}, "created_at": row["created_at"].isoformat()
score=score, if row.get("created_at")
created_at=row.get("created_at") or datetime.now(timezone.utc), else None,
)) },
score=score,
created_at=row.get("created_at") or datetime.now(timezone.utc),
)
)
items.sort(key=lambda x: x.score, reverse=True) items.sort(key=lambda x: x.score, reverse=True)
return items[:top_k] return items[:top_k]
async def _search_client_side( async def _search_client_side(
self, self,
db: Any, db: AsyncSession,
query: str, query: str,
top_k: int, top_k: int,
filters: dict[str, Any] | None, filters: MetadataDict | None,
search_multiplier: int, search_multiplier: int,
) -> list[MemoryItem]: ) -> list[MemoryItem]:
"""客户端 O(N) cosine similarity 检索(回退路径)""" """客户端 O(N) cosine similarity 检索(回退路径)"""
@ -300,6 +323,7 @@ class EpisodicMemory(Memory):
filters = filters or {} filters = filters or {}
from sqlalchemy import select from sqlalchemy import select
stmt = select(Model) stmt = select(Model)
if filters.get("agent_name"): if filters.get("agent_name"):
@ -322,7 +346,11 @@ class EpisodicMemory(Memory):
# 计算得分并构建 MemoryItem # 计算得分并构建 MemoryItem
items = [] items = []
for entry in entries: for entry in entries:
age_hours = (datetime.now(timezone.utc) - entry.created_at).total_seconds() / 3600 if entry.created_at else 0 age_hours = (
(datetime.now(timezone.utc) - entry.created_at).total_seconds() / 3600
if entry.created_at
else 0
)
decay = math.exp(-self._decay_rate * age_hours) decay = math.exp(-self._decay_rate * age_hours)
time_decay_score = (entry.quality_score or 0.5) * decay time_decay_score = (entry.quality_score or 0.5) * decay
@ -333,30 +361,34 @@ class EpisodicMemory(Memory):
else: else:
score = time_decay_score score = time_decay_score
items.append(MemoryItem( items.append(
key=str(entry.id), MemoryItem(
value={ key=str(entry.id),
"input_summary": entry.input_summary, value={
"output_summary": entry.output_summary, "input_summary": entry.input_summary,
"outcome": entry.outcome, "output_summary": entry.output_summary,
"quality_score": entry.quality_score, "outcome": entry.outcome,
"reflection": entry.reflection, "quality_score": entry.quality_score,
}, "reflection": entry.reflection,
metadata={ },
"agent_name": entry.agent_name, metadata={
"task_type": entry.task_type, "agent_name": entry.agent_name,
"created_at": entry.created_at.isoformat() if entry.created_at else None, "task_type": entry.task_type,
}, "created_at": entry.created_at.isoformat() if entry.created_at else None,
score=score, },
created_at=entry.created_at or datetime.now(timezone.utc), score=score,
)) created_at=entry.created_at or datetime.now(timezone.utc),
)
)
items.sort(key=lambda x: x.score, reverse=True) items.sort(key=lambda x: x.score, reverse=True)
if len(items) < top_k: if len(items) < top_k:
logger.warning( logger.warning(
"EpisodicMemory.search returned %d results after scoring (top_k=%d). " "EpisodicMemory.search returned %d results after scoring (top_k=%d). "
"Consider increasing search_multiplier (current=%d) to avoid missing relevant entries.", "Consider increasing search_multiplier (current=%d) to avoid missing relevant entries.",
len(items), top_k, search_multiplier, len(items),
top_k,
search_multiplier,
) )
return items[:top_k] return items[:top_k]
@ -364,8 +396,9 @@ class EpisodicMemory(Memory):
"""删除指定经验""" """删除指定经验"""
async with self._session_factory() as db: async with self._session_factory() as db:
try: try:
from sqlalchemy import select, delete as sql_delete from sqlalchemy import delete as sql_delete
import uuid import uuid
Model = self._episodic_model Model = self._episodic_model
stmt = sql_delete(Model).where(Model.id == uuid.UUID(key)) stmt = sql_delete(Model).where(Model.id == uuid.UUID(key))

View File

@ -3,13 +3,26 @@
配置驱动不直接依赖业务系统代码通过 base_url + api_key 连接 配置驱动不直接依赖业务系统代码通过 base_url + api_key 连接
""" """
from __future__ import annotations
import logging import logging
from typing import Any from typing import TYPE_CHECKING, TypeAlias
import httpx import httpx
from agentkit.memory.base import MetadataDict
if TYPE_CHECKING:
from agentkit.llm.gateway import LLMGateway
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# 标准化检索结果id/content/score/source/document_id/document_title/
# knowledge_base_id/metadata — 值为原始标量或嵌套 dict。
RAGSearchResult: TypeAlias = dict[str, object]
# ingest() 写入的文档负载title/content/source_type/metadata。
RAGIngestPayload: TypeAlias = dict[str, object]
class HttpRAGService: class HttpRAGService:
"""HTTP 客户端,调用业务系统的知识库检索 API """HTTP 客户端,调用业务系统的知识库检索 API
@ -39,7 +52,7 @@ class HttpRAGService:
knowledge_base_ids: list[str] | None = None, knowledge_base_ids: list[str] | None = None,
timeout: int = 30, timeout: int = 30,
contextual_chunking: bool = False, contextual_chunking: bool = False,
llm_gateway: Any = None, llm_gateway: LLMGateway | None = None,
): ):
""" """
Args: Args:
@ -74,7 +87,7 @@ class HttpRAGService:
query: str, query: str,
knowledge_base_ids: list[str] | None = None, knowledge_base_ids: list[str] | None = None,
top_k: int = 5, top_k: int = 5,
) -> list[dict[str, Any]]: ) -> list[RAGSearchResult]:
"""语义检索知识库 """语义检索知识库
Args: Args:
@ -113,19 +126,23 @@ class HttpRAGService:
normalized = [] normalized = []
for r in results: for r in results:
if isinstance(r, dict): if isinstance(r, dict):
normalized.append({ normalized.append(
"id": r.get("chunk_id", r.get("id", "")), {
"content": r.get("content", ""), "id": r.get("chunk_id", r.get("id", "")),
"score": float(r.get("score", 0.0)), "content": r.get("content", ""),
"source": r.get("source", "rag"), "score": float(r.get("score", 0.0)),
"document_id": r.get("document_id", ""), "source": r.get("source", "rag"),
"document_title": r.get("document_title", ""), "document_id": r.get("document_id", ""),
"metadata": r.get("metadata", {}), "document_title": r.get("document_title", ""),
}) "metadata": r.get("metadata", {}),
}
)
return normalized return normalized
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.error(f"RAG search HTTP error: {e.response.status_code}{e.response.text[:200]}") logger.error(
f"RAG search HTTP error: {e.response.status_code}{e.response.text[:200]}"
)
return [] return []
except httpx.RequestError as e: except httpx.RequestError as e:
logger.error(f"RAG search request error: {e}") logger.error(f"RAG search request error: {e}")
@ -141,7 +158,7 @@ class HttpRAGService:
top_k: int = 5, top_k: int = 5,
use_rerank: bool = True, use_rerank: bool = True,
use_compression: bool = False, use_compression: bool = False,
) -> list[dict[str, Any]]: ) -> list[RAGSearchResult]:
"""增强语义检索知识库(支持 rerank 和 compression """增强语义检索知识库(支持 rerank 和 compression
对每个知识库分别调用 /bases/{kb_id}/retrieve 接口 对每个知识库分别调用 /bases/{kb_id}/retrieve 接口
@ -169,7 +186,7 @@ class HttpRAGService:
} }
client = self._get_client() client = self._get_client()
all_results: list[dict[str, Any]] = [] all_results: list[RAGSearchResult] = []
for kb_id in kb_ids: for kb_id in kb_ids:
try: try:
@ -189,28 +206,27 @@ class HttpRAGService:
# 标准化 # 标准化
for r in results: for r in results:
if isinstance(r, dict): if isinstance(r, dict):
all_results.append({ all_results.append(
"id": r.get("chunk_id", r.get("id", "")), {
"content": r.get("content", ""), "id": r.get("chunk_id", r.get("id", "")),
"score": float(r.get("score", 0.0)), "content": r.get("content", ""),
"source": r.get("source", "rag"), "score": float(r.get("score", 0.0)),
"document_id": r.get("document_id", ""), "source": r.get("source", "rag"),
"document_title": r.get("document_title", ""), "document_id": r.get("document_id", ""),
"knowledge_base_id": kb_id, "document_title": r.get("document_title", ""),
"metadata": r.get("metadata", {}), "knowledge_base_id": kb_id,
}) "metadata": r.get("metadata", {}),
}
)
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
if e.response.status_code == 404: if e.response.status_code == 404:
# This KB doesn't support enhanced search — fall back to # This KB doesn't support enhanced search — fall back to
# standard search for THIS KB only, not all KBs. # standard search for THIS KB only, not all KBs.
logger.info( logger.info(
f"Enhanced search not available for KB {kb_id}, " f"Enhanced search not available for KB {kb_id}, using standard search"
f"using standard search"
)
std_result = await self.search(
query, knowledge_base_ids=[kb_id], top_k=top_k
) )
std_result = await self.search(query, knowledge_base_ids=[kb_id], top_k=top_k)
all_results.extend(std_result) all_results.extend(std_result)
else: else:
logger.error( logger.error(
@ -232,9 +248,9 @@ class HttpRAGService:
async def ingest( async def ingest(
self, self,
key: str, key: str,
value: Any, value: object,
metadata: dict[str, Any] | None = None, metadata: MetadataDict | None = None,
) -> dict[str, Any] | None: ) -> dict[str, object] | None:
"""写入文档到知识库(可选操作) """写入文档到知识库(可选操作)
When contextual_chunking is enabled and llm_gateway is configured, When contextual_chunking is enabled and llm_gateway is configured,
@ -308,5 +324,5 @@ class HttpRAGService:
async def __aenter__(self) -> "HttpRAGService": async def __aenter__(self) -> "HttpRAGService":
return self return self
async def __aexit__(self, *args: Any) -> None: async def __aexit__(self, *args: object) -> None:
await self.close() await self.close()

View File

@ -11,7 +11,11 @@ from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Protocol, runtime_checkable from typing import Protocol, TypeAlias, runtime_checkable
from agentkit.memory.base import MetadataDict
KBMetadata: TypeAlias = MetadataDict
@dataclass @dataclass
@ -22,7 +26,7 @@ class Document:
content: str content: str
title: str = "" title: str = ""
source_id: str = "" source_id: str = ""
metadata: dict[str, Any] = field(default_factory=dict) metadata: KBMetadata = field(default_factory=dict)
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
@ -34,7 +38,7 @@ class QueryResult:
source_id: str source_id: str
source_name: str source_name: str
score: float score: float
metadata: dict[str, Any] = field(default_factory=dict) metadata: KBMetadata = field(default_factory=dict)
doc_id: str = "" doc_id: str = ""
title: str = "" title: str = ""

View File

@ -11,25 +11,32 @@ from __future__ import annotations
import json import json
import logging import logging
import re import re
import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any from typing import TYPE_CHECKING, TypeAlias
_SAFE_TABLE_NAME_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
from agentkit.memory.chunking import Chunk, StructuralChunker, TextChunker from agentkit.memory.chunking import Chunk, StructuralChunker, TextChunker
from agentkit.memory.document_loader import Document as LoaderDocument from agentkit.memory.document_loader import Document as LoaderDocument
from agentkit.memory.embedder import Embedder from agentkit.memory.embedder import Embedder
from agentkit.memory.knowledge_base import ( from agentkit.memory.knowledge_base import (
Document, Document,
KnowledgeBase,
QueryResult, QueryResult,
SourceInfo, SourceInfo,
) )
from agentkit.utils.vector_math import compute_cosine_similarity from agentkit.utils.vector_math import compute_cosine_similarity
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
_SAFE_TABLE_NAME_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# InMemoryLocalRAGService 内部存储的文档元信息结构。
# 字段title/source_id/format/chunk_ids/metadata/created_at — 值为标量或 list[str]。
InMemoryDocInfo: TypeAlias = dict[str, object]
# 内部 chunk 存储结构content/embedding/metadata/source_doc_id。
InMemoryChunkInfo: TypeAlias = dict[str, object]
def _loader_doc_to_kb_doc(loader_doc: LoaderDocument) -> Document: def _loader_doc_to_kb_doc(loader_doc: LoaderDocument) -> Document:
"""将 document_loader.Document 转换为 knowledge_base.Document""" """将 document_loader.Document 转换为 knowledge_base.Document"""
@ -53,7 +60,7 @@ class LocalRAGService:
def __init__( def __init__(
self, self,
session_factory: Any, session_factory: object,
embedder: Embedder, embedder: Embedder,
chunk_size: int = 1000, chunk_size: int = 1000,
chunk_overlap: int = 200, chunk_overlap: int = 200,
@ -75,10 +82,14 @@ class LocalRAGService:
self._chunk_overlap = chunk_overlap self._chunk_overlap = chunk_overlap
self._table_name = table_name self._table_name = table_name
if not _SAFE_TABLE_NAME_PATTERN.match(self._table_name): if not _SAFE_TABLE_NAME_PATTERN.match(self._table_name):
raise ValueError(f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*") raise ValueError(
f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*"
)
self._pgvector_enabled = pgvector_enabled self._pgvector_enabled = pgvector_enabled
self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap) self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
self._structural_chunker = StructuralChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap) self._structural_chunker = StructuralChunker(
chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
async def ingest(self, documents: list[Document]) -> list[str]: async def ingest(self, documents: list[Document]) -> list[str]:
"""摄取文档列表 """摄取文档列表
@ -136,9 +147,7 @@ class LocalRAGService:
try: try:
from sqlalchemy import text as sql_text from sqlalchemy import text as sql_text
sql = sql_text( sql = sql_text(f"DELETE FROM {self._table_name} WHERE source_doc_id = :doc_id")
f"DELETE FROM {self._table_name} WHERE source_doc_id = :doc_id"
)
await db.execute(sql, {"doc_id": id}) await db.execute(sql, {"doc_id": id})
await db.commit() await db.commit()
return True return True
@ -171,20 +180,15 @@ class LocalRAGService:
sources = [] sources = []
for row in rows: for row in rows:
meta = {} sources.append(
if row.get("doc_metadata"): SourceInfo(
try: source_id=row["source_doc_id"],
meta = json.loads(row["doc_metadata"]) source_name=row.get("source_title", ""),
except (json.JSONDecodeError, TypeError): source_type=row.get("doc_format", "local"),
pass document_count=row.get("chunk_count", 0),
last_updated=row["created_at"] if row.get("created_at") else None,
sources.append(SourceInfo( )
source_id=row["source_doc_id"], )
source_name=row.get("source_title", ""),
source_type=row.get("doc_format", "local"),
document_count=row.get("chunk_count", 0),
last_updated=row["created_at"] if row.get("created_at") else None,
))
return sources return sources
except Exception as e: except Exception as e:
logger.error(f"Failed to list sources: {e}") logger.error(f"Failed to list sources: {e}")
@ -271,7 +275,7 @@ class LocalRAGService:
async def _query_pgvector( async def _query_pgvector(
self, self,
db: Any, db: AsyncSession,
query_embedding: list[float], query_embedding: list[float],
top_k: int, top_k: int,
) -> list[QueryResult]: ) -> list[QueryResult]:
@ -286,10 +290,13 @@ class LocalRAGService:
f"LIMIT :lim" f"LIMIT :lim"
) )
result = await db.execute(sql, { result = await db.execute(
"query_vec": str(query_embedding), sql,
"lim": top_k, {
}) "query_vec": str(query_embedding),
"lim": top_k,
},
)
rows = result.mappings().all() rows = result.mappings().all()
results = [] results = []
@ -306,21 +313,23 @@ class LocalRAGService:
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
pass pass
results.append(QueryResult( results.append(
content=row["content"], QueryResult(
source_id=row["source_doc_id"], content=row["content"],
source_name=row.get("source_title", ""), source_id=row["source_doc_id"],
score=cosine, source_name=row.get("source_title", ""),
metadata=chunk_meta, score=cosine,
doc_id=row["source_doc_id"], metadata=chunk_meta,
title=row.get("source_title", ""), doc_id=row["source_doc_id"],
)) title=row.get("source_title", ""),
)
)
return results return results
async def _query_client_side( async def _query_client_side(
self, self,
db: Any, db: AsyncSession,
query_embedding: list[float], query_embedding: list[float],
top_k: int, top_k: int,
) -> list[QueryResult]: ) -> list[QueryResult]:
@ -363,15 +372,17 @@ class LocalRAGService:
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
pass pass
candidates.append(QueryResult( candidates.append(
content=row["content"], QueryResult(
source_id=row["source_doc_id"], content=row["content"],
source_name=row.get("source_title", ""), source_id=row["source_doc_id"],
score=cosine, source_name=row.get("source_title", ""),
metadata=chunk_meta, score=cosine,
doc_id=row["source_doc_id"], metadata=chunk_meta,
title=row.get("source_title", ""), doc_id=row["source_doc_id"],
)) title=row.get("source_title", ""),
)
)
candidates.sort(key=lambda x: x.score, reverse=True) candidates.sort(key=lambda x: x.score, reverse=True)
return candidates[:top_k] return candidates[:top_k]
@ -398,11 +409,15 @@ class InMemoryLocalRAGService:
""" """
self._embedder = embedder self._embedder = embedder
self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap) self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
self._structural_chunker = StructuralChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap) self._structural_chunker = StructuralChunker(
chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
# 内存存储 # 内存存储
self._chunks: dict[str, dict[str, Any]] = {} # chunk_id → {content, embedding, metadata} self._chunks: dict[str, InMemoryChunkInfo] = {} # chunk_id → {content, embedding, metadata}
self._documents: dict[str, dict[str, Any]] = {} # doc_id → {title, format, chunk_ids, metadata, created_at} self._documents: dict[
str, InMemoryDocInfo
] = {} # doc_id → {title, format, chunk_ids, metadata, created_at}
async def ingest(self, documents: list[Document]) -> list[str]: async def ingest(self, documents: list[Document]) -> list[str]:
"""摄取文档列表 """摄取文档列表
@ -459,15 +474,17 @@ class InMemoryLocalRAGService:
source_doc_id = chunk_data["source_doc_id"] source_doc_id = chunk_data["source_doc_id"]
doc_info = self._documents.get(source_doc_id, {}) doc_info = self._documents.get(source_doc_id, {})
candidates.append(QueryResult( candidates.append(
content=chunk_data["content"], QueryResult(
source_id=source_doc_id, content=chunk_data["content"],
source_name=doc_info.get("title", ""), source_id=source_doc_id,
score=cosine, source_name=doc_info.get("title", ""),
metadata=chunk_data.get("metadata", {}), score=cosine,
doc_id=source_doc_id, metadata=chunk_data.get("metadata", {}),
title=doc_info.get("title", ""), doc_id=source_doc_id,
)) title=doc_info.get("title", ""),
)
)
candidates.sort(key=lambda x: x.score, reverse=True) candidates.sort(key=lambda x: x.score, reverse=True)
return candidates[:top_k] return candidates[:top_k]
@ -488,13 +505,15 @@ class InMemoryLocalRAGService:
"""列出已摄取的文档""" """列出已摄取的文档"""
sources = [] sources = []
for doc_id, doc_info in self._documents.items(): for doc_id, doc_info in self._documents.items():
sources.append(SourceInfo( sources.append(
source_id=doc_id, SourceInfo(
source_name=doc_info["title"], source_id=doc_id,
source_type=doc_info.get("format", "local"), source_name=doc_info["title"],
document_count=len(doc_info.get("chunk_ids", [])), source_type=doc_info.get("format", "local"),
last_updated=doc_info.get("created_at"), document_count=len(doc_info.get("chunk_ids", [])),
)) last_updated=doc_info.get("created_at"),
)
)
return sources return sources
async def health_check(self) -> bool: async def health_check(self) -> bool:

View File

@ -13,7 +13,6 @@ import asyncio
import hashlib import hashlib
import logging import logging
from dataclasses import replace from dataclasses import replace
from typing import Any
from agentkit.memory.knowledge_base import KnowledgeBase, QueryResult, SourceInfo from agentkit.memory.knowledge_base import KnowledgeBase, QueryResult, SourceInfo
@ -186,15 +185,13 @@ class MultiSourceRetriever:
Returns: Returns:
所有源的检索结果列表已应用权重 所有源的检索结果列表已应用权重
""" """
async def _query_one(name: str, kb: KnowledgeBase) -> list[QueryResult]: async def _query_one(name: str, kb: KnowledgeBase) -> list[QueryResult]:
try: try:
results = await kb.query(query, top_k=top_k) results = await kb.query(query, top_k=top_k)
# 应用权重 # 应用权重
weight = (weights or {}).get(name, 1.0) weight = (weights or {}).get(name, 1.0)
return [ return [replace(r, score=r.score * weight, source_name=name) for r in results]
replace(r, score=r.score * weight, source_name=name)
for r in results
]
except Exception as e: except Exception as e:
logger.error(f"Query failed for source '{name}': {e}") logger.error(f"Query failed for source '{name}': {e}")
return [] return []

View File

@ -7,10 +7,10 @@
from __future__ import annotations from __future__ import annotations
import re import re
from dataclasses import dataclass, field from dataclasses import dataclass
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Any, Callable from typing import Callable
class MemoryFile: class MemoryFile:
@ -26,8 +26,9 @@ class MemoryFile:
""" """
def __init__(self, path: Path, char_budget: int | None = None, def __init__(
protected_sections: set[str] | None = None): self, path: Path, char_budget: int | None = None, protected_sections: set[str] | None = None
):
self.path = Path(path) self.path = Path(path)
self.char_budget = char_budget self.char_budget = char_budget
self._protected_sections = protected_sections or set() self._protected_sections = protected_sections or set()
@ -138,7 +139,7 @@ class MemoryFile:
for match in re.finditer(r"^## (.+)$", content, re.MULTILINE): for match in re.finditer(r"^## (.+)$", content, re.MULTILINE):
name = match.group(1).strip() name = match.group(1).strip()
start = match.start() start = match.start()
next_match = re.search(r"^## ", content[match.end():], re.MULTILINE) next_match = re.search(r"^## ", content[match.end() :], re.MULTILINE)
if next_match: if next_match:
end = match.end() + next_match.start() end = match.end() + next_match.start()
else: else:
@ -146,7 +147,7 @@ class MemoryFile:
sections.append((name, start, end)) sections.append((name, start, end))
if not sections: if not sections:
return content[:self.char_budget] return content[: self.char_budget]
# 保持原始顺序,标记每个 section 是否受保护 # 保持原始顺序,标记每个 section 是否受保护
ordered: list[tuple[str, str, bool]] = [] # (name, text, is_protected) ordered: list[tuple[str, str, bool]] = [] # (name, text, is_protected)
@ -222,8 +223,9 @@ class MemoryStore:
""" """
def __init__(self, base_dir: Path | str | None = None, def __init__(
on_change: Callable[[str], None] | None = None): self, base_dir: Path | str | None = None, on_change: Callable[[str], None] | None = None
):
if base_dir is None: if base_dir is None:
base_dir = Path.home() / ".agentkit" base_dir = Path.home() / ".agentkit"
self.base_dir = Path(base_dir) self.base_dir = Path(base_dir)
@ -238,7 +240,9 @@ class MemoryStore:
protected_sections={"版本", "更新历史"}, protected_sections={"版本", "更新历史"},
) )
self._user = MemoryFile(self.base_dir / "memories" / "USER.md", char_budget=USER_BUDGET) self._user = MemoryFile(self.base_dir / "memories" / "USER.md", char_budget=USER_BUDGET)
self._memory = MemoryFile(self.base_dir / "memories" / "MEMORY.md", char_budget=MEMORY_BUDGET) self._memory = MemoryFile(
self.base_dir / "memories" / "MEMORY.md", char_budget=MEMORY_BUDGET
)
self._daily_dir = self.base_dir / "memories" / "daily" self._daily_dir = self.base_dir / "memories" / "daily"
self._daily_dir.mkdir(parents=True, exist_ok=True) self._daily_dir.mkdir(parents=True, exist_ok=True)
@ -376,4 +380,5 @@ class MemoryStore:
self._on_change(new_prompt) self._on_change(new_prompt)
except Exception: except Exception:
import logging import logging
logging.getLogger(__name__).warning("Memory notify_change failed", exc_info=True) logging.getLogger(__name__).warning("Memory notify_change failed", exc_info=True)

View File

@ -87,10 +87,22 @@ class RuleQueryTransformer(QueryTransformerBase):
""" """
_FILLER_WORDS_CN: list[str] = [ _FILLER_WORDS_CN: list[str] = [
"帮我", "", "一下", "分析", "看看", "告诉我", "想知道", "请问", "帮我",
"",
"一下",
"分析",
"看看",
"告诉我",
"想知道",
"请问",
] ]
_FILLER_WORDS_EN: list[str] = [ _FILLER_WORDS_EN: list[str] = [
"please", "can you", "help me", "could you", "i want to", "i need to", "please",
"can you",
"help me",
"could you",
"i want to",
"i need to",
] ]
def __init__( def __init__(
@ -101,9 +113,7 @@ class RuleQueryTransformer(QueryTransformerBase):
self._synonyms = synonyms or {} self._synonyms = synonyms or {}
self._max_sub_queries = max_sub_queries self._max_sub_queries = max_sub_queries
# Pre-compile filler patterns # Pre-compile filler patterns
self._filler_patterns_cn = [ self._filler_patterns_cn = [re.compile(re.escape(w)) for w in self._FILLER_WORDS_CN]
re.compile(re.escape(w)) for w in self._FILLER_WORDS_CN
]
self._filler_patterns_en = [ self._filler_patterns_en = [
re.compile(re.escape(w), re.IGNORECASE) for w in self._FILLER_WORDS_EN re.compile(re.escape(w), re.IGNORECASE) for w in self._FILLER_WORDS_EN
] ]
@ -166,7 +176,9 @@ def create_query_transformer(
"""工厂函数:根据策略创建查询改写器""" """工厂函数:根据策略创建查询改写器"""
if strategy == "llm": if strategy == "llm":
if llm_gateway is None: if llm_gateway is None:
logger.warning("LLM strategy requested but no llm_gateway provided, falling back to NoOp") logger.warning(
"LLM strategy requested but no llm_gateway provided, falling back to NoOp"
)
return NoOpQueryTransformer() return NoOpQueryTransformer()
return LLMQueryTransformer(llm_gateway, max_sub_queries=max_sub_queries) return LLMQueryTransformer(llm_gateway, max_sub_queries=max_sub_queries)
elif strategy == "rule": elif strategy == "rule":

View File

@ -7,11 +7,11 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Any from typing import TYPE_CHECKING
from agentkit.memory.base import MemoryItem from agentkit.memory.base import MemoryItem, MetadataDict
from agentkit.memory.query_transformer import QueryTransformerBase, NoOpQueryTransformer from agentkit.memory.query_transformer import QueryTransformerBase, NoOpQueryTransformer
from agentkit.memory.relevance_scorer import ( from agentkit.memory.relevance_scorer import (
RelevanceScorer, RelevanceScorer,
@ -19,6 +19,10 @@ from agentkit.memory.relevance_scorer import (
RetrievalEvaluation, RetrievalEvaluation,
) )
if TYPE_CHECKING:
# 避免与 retriever.py 形成运行时循环导入。
from agentkit.memory.retriever import MemoryRetriever
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -70,7 +74,7 @@ class RAGSelfCorrectionLoop:
def __init__( def __init__(
self, self,
retriever: Any, # MemoryRetriever retriever: MemoryRetriever,
scorer: RelevanceScorer | None = None, scorer: RelevanceScorer | None = None,
query_transformer: QueryTransformerBase | None = None, query_transformer: QueryTransformerBase | None = None,
max_retries: int = 3, max_retries: int = 3,
@ -87,7 +91,7 @@ class RAGSelfCorrectionLoop:
query: str, query: str,
top_k: int = 5, top_k: int = 5,
token_budget: int = 3000, token_budget: int = 3000,
filters: dict[str, Any] | None = None, filters: MetadataDict | None = None,
) -> RAGLoopResult: ) -> RAGLoopResult:
"""执行带自纠正的检索 """执行带自纠正的检索
@ -107,8 +111,11 @@ class RAGSelfCorrectionLoop:
while retry_count <= self._max_retries: while retry_count <= self._max_retries:
# RETRIEVE # RETRIEVE
items = await self._retriever.retrieve( items = await self._retriever.retrieve(
current_query, top_k=top_k, token_budget=token_budget, current_query,
filters=filters, _skip_correction=True, top_k=top_k,
token_budget=token_budget,
filters=filters,
_skip_correction=True,
) )
# EVALUATE # EVALUATE
@ -144,9 +151,7 @@ class RAGSelfCorrectionLoop:
# CORRECT — rewrite query and retry # CORRECT — rewrite query and retry
retry_count += 1 retry_count += 1
if retry_count <= self._max_retries: if retry_count <= self._max_retries:
current_query = await self._rewrite_query( current_query = await self._rewrite_query(query, current_query, evaluation)
query, current_query, evaluation
)
continue continue
# DEGRADE — exceeded max retries # DEGRADE — exceeded max retries
@ -154,9 +159,7 @@ class RAGSelfCorrectionLoop:
# Degraded result: filter to relevant items and mark low confidence # Degraded result: filter to relevant items and mark low confidence
relevant_items = [ relevant_items = [
s.item s.item for s in evaluation.scores if s.verdict != RelevanceVerdict.INCORRECT
for s in evaluation.scores
if s.verdict != RelevanceVerdict.INCORRECT
] ]
result_items = relevant_items if relevant_items else items result_items = relevant_items if relevant_items else items

View File

@ -6,11 +6,9 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import math
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Any
from agentkit.memory.base import MemoryItem from agentkit.memory.base import MemoryItem
@ -120,9 +118,7 @@ class RelevanceScorer:
reason=reason, reason=reason,
) )
def evaluate( def evaluate(self, query: str, items: list[MemoryItem]) -> RetrievalEvaluation:
self, query: str, items: list[MemoryItem]
) -> RetrievalEvaluation:
"""评估一次检索的整体质量""" """评估一次检索的整体质量"""
if not items: if not items:
return RetrievalEvaluation( return RetrievalEvaluation(
@ -134,9 +130,7 @@ class RelevanceScorer:
) )
scores = [self.score_item(query, item) for item in items] scores = [self.score_item(query, item) for item in items]
relevant_count = sum( relevant_count = sum(1 for s in scores if s.verdict != RelevanceVerdict.INCORRECT)
1 for s in scores if s.verdict != RelevanceVerdict.INCORRECT
)
avg_score = sum(s.score for s in scores) / len(scores) avg_score = sum(s.score for s in scores) / len(scores)
# Overall verdict based on average score and relevant ratio # Overall verdict based on average score and relevant ratio

View File

@ -7,19 +7,16 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import math
from dataclasses import replace from dataclasses import replace
from datetime import datetime
from typing import Any
from agentkit.memory.base import Memory, MemoryItem from agentkit.memory.base import MemoryItem, MetadataDict
from agentkit.memory.working import WorkingMemory from agentkit.memory.working import WorkingMemory
from agentkit.memory.episodic import EpisodicMemory from agentkit.memory.episodic import EpisodicMemory
from agentkit.memory.semantic import SemanticMemory from agentkit.memory.semantic import SemanticMemory
from agentkit.memory.query_transformer import QueryTransformerBase from agentkit.memory.query_transformer import QueryTransformerBase
from agentkit.memory.rag_loop import RAGSelfCorrectionLoop from agentkit.memory.rag_loop import RAGSelfCorrectionLoop
from agentkit.memory.relevance_scorer import RelevanceScorer from agentkit.memory.relevance_scorer import RelevanceScorer
from agentkit.memory.knowledge_base import KnowledgeBase, QueryResult from agentkit.memory.knowledge_base import KnowledgeBase
from agentkit.memory.multi_source_retriever import MultiSourceRetriever from agentkit.memory.multi_source_retriever import MultiSourceRetriever
from agentkit.tools.base import Tool from agentkit.tools.base import Tool
@ -32,11 +29,11 @@ def _estimate_tokens(text: str) -> int:
Chinese characters typically use 1-2 tokens each. Chinese characters typically use 1-2 tokens each.
English words typically use 1 token each. English words typically use 1 token each.
""" """
cjk_count = sum(1 for c in text if '\u4e00' <= c <= '\u9fff') cjk_count = sum(1 for c in text if "\u4e00" <= c <= "\u9fff")
non_cjk = text non_cjk = text
for c in text: for c in text:
if '\u4e00' <= c <= '\u9fff': if "\u4e00" <= c <= "\u9fff":
non_cjk = non_cjk.replace(c, ' ') non_cjk = non_cjk.replace(c, " ")
word_count = len(non_cjk.split()) word_count = len(non_cjk.split())
return cjk_count * 2 + word_count return cjk_count * 2 + word_count
@ -89,7 +86,7 @@ class MemoryRetriever:
query: str, query: str,
top_k: int = 5, top_k: int = 5,
token_budget: int = 3000, token_budget: int = 3000,
filters: dict[str, Any] | None = None, filters: MetadataDict | None = None,
_skip_correction: bool = False, _skip_correction: bool = False,
sources: list[str] | None = None, sources: list[str] | None = None,
source_weights: dict[str, float] | None = None, source_weights: dict[str, float] | None = None,
@ -121,9 +118,7 @@ class MemoryRetriever:
query, top_k=top_k, token_budget=token_budget, filters=filters query, top_k=top_k, token_budget=token_budget, filters=filters
) )
if result.degraded: if result.degraded:
logger.warning( logger.warning(f"RAG self-correction degraded after {result.total_retries} retries")
f"RAG self-correction degraded after {result.total_retries} retries"
)
return result.items return result.items
# Query transformation # Query transformation
if self._query_transformer is not None: if self._query_transformer is not None:
@ -139,9 +134,7 @@ class MemoryRetriever:
# Sub-query search in parallel # Sub-query search in parallel
if sub_queries: if sub_queries:
sub_tasks = [ sub_tasks = [self._search_layers(sq, top_k, filters) for sq in sub_queries]
self._search_layers(sq, top_k, filters) for sq in sub_queries
]
sub_results = await asyncio.gather(*sub_tasks, return_exceptions=True) sub_results = await asyncio.gather(*sub_tasks, return_exceptions=True)
for result in sub_results: for result in sub_results:
if isinstance(result, Exception): if isinstance(result, Exception):
@ -178,7 +171,7 @@ class MemoryRetriever:
self, self,
query: str, query: str,
top_k: int = 5, top_k: int = 5,
filters: dict[str, Any] | None = None, filters: MetadataDict | None = None,
) -> list[MemoryItem]: ) -> list[MemoryItem]:
"""Search all configured memory layers with a single query""" """Search all configured memory layers with a single query"""
tasks = [] tasks = []
@ -237,18 +230,20 @@ class MemoryRetriever:
# QueryResult → MemoryItem # QueryResult → MemoryItem
items = [] items = []
for r in kb_results: for r in kb_results:
items.append(MemoryItem( items.append(
key=r.source_id, MemoryItem(
value=r.content, key=r.source_id,
metadata={ value=r.content,
**r.metadata, metadata={
"source": "rag", **r.metadata,
"source_name": r.source_name, "source": "rag",
"doc_id": r.doc_id, "source_name": r.source_name,
"document_title": r.title, "doc_id": r.doc_id,
}, "document_title": r.title,
score=r.score, },
)) score=r.score,
)
)
# Token 预算管理 # Token 预算管理
selected = [] selected = []
@ -318,7 +313,9 @@ class MemoryRetriever:
if source == "rag": if source == "rag":
kb_type = item.metadata.get("kb_type", "知识库") kb_type = item.metadata.get("kb_type", "知识库")
document_title = item.metadata.get("document_title", "未知文档") document_title = item.metadata.get("document_title", "未知文档")
return f"### 知识库参考 [来源: {kb_type} | 相关度: {score:.2f} | 文档: {document_title}]" return (
f"### 知识库参考 [来源: {kb_type} | 相关度: {score:.2f} | 文档: {document_title}]"
)
elif source == "graph": elif source == "graph":
return f"### 知识图谱 [实体: {item.key} | 相关度: {score:.2f}]" return f"### 知识图谱 [实体: {item.key} | 相关度: {score:.2f}]"
elif source == "episodic": elif source == "episodic":
@ -330,7 +327,7 @@ class MemoryRetriever:
return f"### 参考 [来源: {source} | 相关度: {score:.2f}]" return f"### 参考 [来源: {source} | 相关度: {score:.2f}]"
async def store_episode( async def store_episode(
self, key: str, value: Any, metadata: dict[str, Any] | None = None self, key: str, value: object, metadata: MetadataDict | None = None
) -> None: ) -> None:
"""Store an episode into episodic memory if available. """Store an episode into episodic memory if available.
@ -386,12 +383,14 @@ class RetrieveKnowledgeTool(Tool):
items = await self._retriever.retrieve(query, top_k=5) items = await self._retriever.retrieve(query, top_k=5)
results = [] results = []
for item in items: for item in items:
results.append({ results.append(
"content": item.value, {
"score": item.score, "content": item.value,
"source": item.metadata.get("source", "unknown"), "score": item.score,
"document_title": item.metadata.get("document_title", ""), "source": item.metadata.get("source", "unknown"),
}) "document_title": item.metadata.get("document_title", ""),
}
)
return {"query": query, "results": results, "call_count": self._call_count} return {"query": query, "results": results, "call_count": self._call_count}
except Exception as e: except Exception as e:
return {"error": str(e), "results": []} return {"error": str(e), "results": []}

View File

@ -3,14 +3,45 @@
适配器模式对接外部 RAG 服务和知识图谱 适配器模式对接外部 RAG 服务和知识图谱
""" """
import logging from __future__ import annotations
from typing import Any
from agentkit.memory.base import Memory, MemoryItem import logging
from typing import TYPE_CHECKING, Protocol
from agentkit.memory.base import Memory, MemoryItem, MetadataDict
if TYPE_CHECKING:
from agentkit.memory.http_rag import RAGSearchResult
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class _RAGServiceLike(Protocol):
"""RAG 检索服务最小接口契约duck-typed"""
async def search(
self,
query: str,
knowledge_base_ids: list[str] | None = ...,
top_k: int = ...,
) -> list[RAGSearchResult]: ...
async def enhanced_search(
self,
query: str,
knowledge_base_ids: list[str] | None = ...,
top_k: int = ...,
use_rerank: bool = ...,
use_compression: bool = ...,
) -> list[RAGSearchResult]: ...
class _GraphServiceLike(Protocol):
"""知识图谱服务最小接口契约duck-typed"""
async def query(self, query: str, depth: int = ...) -> list[dict[str, object]]: ...
class SemanticMemory(Memory): class SemanticMemory(Memory):
"""Semantic Memory - 知识库检索 """Semantic Memory - 知识库检索
@ -19,8 +50,8 @@ class SemanticMemory(Memory):
def __init__( def __init__(
self, self,
rag_service: Any = None, rag_service: _RAGServiceLike | None = None,
graph_service: Any = None, graph_service: _GraphServiceLike | None = None,
knowledge_base_ids: list[str] | None = None, knowledge_base_ids: list[str] | None = None,
search_mode: str = "standard", search_mode: str = "standard",
use_rerank: bool = True, use_rerank: bool = True,
@ -45,9 +76,9 @@ class SemanticMemory(Memory):
self._use_compression = use_compression self._use_compression = use_compression
self._kb_weights = kb_weights self._kb_weights = kb_weights
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None: async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
"""Semantic Memory 通常只读,写入委托给 RAG 服务的 ingest 方法""" """Semantic Memory 通常只读,写入委托给 RAG 服务的 ingest 方法"""
if self._rag_service and hasattr(self._rag_service, 'ingest'): if self._rag_service and hasattr(self._rag_service, "ingest"):
await self._rag_service.ingest(key, value, metadata) await self._rag_service.ingest(key, value, metadata)
else: else:
logger.warning("SemanticMemory.store: no RAG service configured for writing") logger.warning("SemanticMemory.store: no RAG service configured for writing")
@ -56,7 +87,9 @@ class SemanticMemory(Memory):
"""按 key 精确检索Semantic Memory 通常不按 key 检索)""" """按 key 精确检索Semantic Memory 通常不按 key 检索)"""
return None return None
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]: async def search(
self, query: str, top_k: int = 5, filters: MetadataDict | None = None
) -> list[MemoryItem]:
"""语义检索知识库""" """语义检索知识库"""
items = [] items = []
@ -64,7 +97,9 @@ class SemanticMemory(Memory):
if self._rag_service: if self._rag_service:
try: try:
kb_ids = (filters or {}).get("knowledge_base_ids", self._knowledge_base_ids) kb_ids = (filters or {}).get("knowledge_base_ids", self._knowledge_base_ids)
if self._search_mode == "enhanced" and hasattr(self._rag_service, "enhanced_search"): if self._search_mode == "enhanced" and hasattr(
self._rag_service, "enhanced_search"
):
results = await self._rag_service.enhanced_search( results = await self._rag_service.enhanced_search(
query, query,
knowledge_base_ids=kb_ids, knowledge_base_ids=kb_ids,
@ -73,24 +108,28 @@ class SemanticMemory(Memory):
use_compression=self._use_compression, use_compression=self._use_compression,
) )
else: else:
results = await self._rag_service.search(query, knowledge_base_ids=kb_ids, top_k=top_k) results = await self._rag_service.search(
query, knowledge_base_ids=kb_ids, top_k=top_k
)
for r in results: for r in results:
kb_id = r.get("knowledge_base_id", "") kb_id = r.get("knowledge_base_id", "")
score = r.get("score", 0.0) score = r.get("score", 0.0)
# Apply per-KB weights # Apply per-KB weights
if self._kb_weights and kb_id in self._kb_weights: if self._kb_weights and kb_id in self._kb_weights:
score *= self._kb_weights[kb_id] score *= self._kb_weights[kb_id]
items.append(MemoryItem( items.append(
key=r.get("id", ""), MemoryItem(
value=r.get("content", ""), key=r.get("id", ""),
metadata={ value=r.get("content", ""),
"source": r.get("source", "rag"), metadata={
"score": score, "source": r.get("source", "rag"),
"document_id": r.get("document_id"), "score": score,
"knowledge_base_id": kb_id, "document_id": r.get("document_id"),
}, "knowledge_base_id": kb_id,
score=score, },
)) score=score,
)
)
except Exception as e: except Exception as e:
logger.error(f"RAG search failed: {e}") logger.error(f"RAG search failed: {e}")
@ -99,16 +138,18 @@ class SemanticMemory(Memory):
try: try:
graph_results = await self._graph_service.query(query, depth=2) graph_results = await self._graph_service.query(query, depth=2)
for r in graph_results[:top_k]: for r in graph_results[:top_k]:
items.append(MemoryItem( items.append(
key=r.get("id", ""), MemoryItem(
value=r.get("content", ""), key=r.get("id", ""),
metadata={ value=r.get("content", ""),
"source": "graph", metadata={
"entities": r.get("entities", []), "source": "graph",
"relations": r.get("relations", []), "entities": r.get("entities", []),
}, "relations": r.get("relations", []),
score=r.get("score", 0.0), },
)) score=r.get("score", 0.0),
)
)
except Exception as e: except Exception as e:
logger.error(f"Graph search failed: {e}") logger.error(f"Graph search failed: {e}")

View File

@ -3,11 +3,10 @@
import json import json
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any
import redis.asyncio as aioredis import redis.asyncio as aioredis
from agentkit.memory.base import Memory, MemoryItem from agentkit.memory.base import Memory, MemoryItem, MetadataDict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -32,7 +31,7 @@ class WorkingMemory(Memory):
def _make_key(self, key: str) -> str: def _make_key(self, key: str) -> str:
return f"{self._key_prefix}:{key}" return f"{self._key_prefix}:{key}"
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None: async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
redis_key = self._make_key(key) redis_key = self._make_key(key)
item = MemoryItem( item = MemoryItem(
key=key, key=key,
@ -57,10 +56,14 @@ class WorkingMemory(Memory):
value=item_dict["value"], value=item_dict["value"],
metadata=item_dict.get("metadata", {}), metadata=item_dict.get("metadata", {}),
score=item_dict.get("score", 1.0), score=item_dict.get("score", 1.0),
created_at=datetime.fromisoformat(item_dict["created_at"]) if item_dict.get("created_at") else datetime.now(timezone.utc), created_at=datetime.fromisoformat(item_dict["created_at"])
if item_dict.get("created_at")
else datetime.now(timezone.utc),
) )
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]: async def search(
self, query: str, top_k: int = 5, filters: MetadataDict | None = None
) -> list[MemoryItem]:
"""Working Memory 不支持语义检索,按 key 前缀匹配""" """Working Memory 不支持语义检索,按 key 前缀匹配"""
pattern = self._make_key(f"{query}*") pattern = self._make_key(f"{query}*")
keys = [] keys = []
@ -74,13 +77,15 @@ class WorkingMemory(Memory):
data = await self._redis.get(key) data = await self._redis.get(key)
if data: if data:
item_dict = json.loads(data) item_dict = json.loads(data)
items.append(MemoryItem( items.append(
key=item_dict["key"], MemoryItem(
value=item_dict["value"], key=item_dict["key"],
metadata=item_dict.get("metadata", {}), value=item_dict["value"],
score=1.0, metadata=item_dict.get("metadata", {}),
created_at=datetime.now(timezone.utc), score=1.0,
)) created_at=datetime.now(timezone.utc),
)
)
return items return items
async def delete(self, key: str) -> bool: async def delete(self, key: str) -> bool:

View File

@ -17,7 +17,7 @@ import logging
import time import time
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any from typing import TYPE_CHECKING
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -25,6 +25,28 @@ _TTL_SECONDS = 7 * 24 * 3600 # 7 days
_KEY_PREFIX = "agentkit:pipeline:checkpoint" _KEY_PREFIX = "agentkit:pipeline:checkpoint"
if TYPE_CHECKING:
from typing import Protocol
class _RedisPipelineLike(Protocol):
def set(self, key: str, value: str, ex: int | None = None) -> object: ...
def zadd(self, name: str, mapping: dict[str, float]) -> object: ...
def get(self, key: str) -> object: ...
def delete(self, *keys: str) -> object: ...
async def execute(self) -> list[object]: ...
class _RedisLike(Protocol):
async def set(self, key: str, value: str, ex: int | None = None) -> object: ...
async def get(self, key: str) -> object: ...
def pipeline(self) -> _RedisPipelineLike: ...
async def zrange(self, name: str, start: int, stop: int) -> list[object]: ...
class _PlanLike(Protocol):
@property
def id(self) -> str: ...
def to_dict(self) -> dict[str, object]: ...
@dataclass @dataclass
class CheckpointData: class CheckpointData:
"""单个阶段的 checkpoint 数据。""" """单个阶段的 checkpoint 数据。"""
@ -33,15 +55,15 @@ class CheckpointData:
phase_id: str phase_id: str
phase_name: str phase_name: str
phase_status: str phase_status: str
phase_result: dict[str, Any] | None = None phase_result: dict[str, object] | None = None
plan_status: str = "" plan_status: str = ""
saved_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) saved_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, object]:
return asdict(self) return asdict(self)
@classmethod @classmethod
def from_dict(cls, data: dict[str, Any]) -> CheckpointData: def from_dict(cls, data: dict[str, object]) -> CheckpointData:
return cls( return cls(
plan_id=data.get("plan_id", ""), plan_id=data.get("plan_id", ""),
phase_id=data.get("phase_id", ""), phase_id=data.get("phase_id", ""),
@ -67,7 +89,7 @@ class PipelineCheckpoint:
def __init__( def __init__(
self, self,
redis_client: Any = None, redis_client: _RedisLike | None = None,
prefix: str = _KEY_PREFIX, prefix: str = _KEY_PREFIX,
ttl_seconds: int = _TTL_SECONDS, ttl_seconds: int = _TTL_SECONDS,
) -> None: ) -> None:
@ -78,7 +100,7 @@ class PipelineCheckpoint:
# P1 #6: 改用 dict keyed by phase_id避免重复 append # P1 #6: 改用 dict keyed by phase_id避免重复 append
self._memory: dict[str, dict[str, CheckpointData]] = {} self._memory: dict[str, dict[str, CheckpointData]] = {}
# 内存降级存储plan_id → (plan_dict, saved_timestamp) # 内存降级存储plan_id → (plan_dict, saved_timestamp)
self._memory_plans: dict[str, tuple[dict[str, Any], float]] = {} self._memory_plans: dict[str, tuple[dict[str, object], float]] = {}
def _is_expired(self, saved_at: str) -> bool: def _is_expired(self, saved_at: str) -> bool:
"""检查 checkpoint 是否已过期(内存模式 TTL""" """检查 checkpoint 是否已过期(内存模式 TTL"""
@ -102,7 +124,7 @@ class PipelineCheckpoint:
"""完整 plan JSON 的存储键。""" """完整 plan JSON 的存储键。"""
return f"{self._prefix}:plan:{plan_id}" return f"{self._prefix}:plan:{plan_id}"
async def save_plan(self, plan: Any) -> None: async def save_plan(self, plan: _PlanLike) -> None:
"""保存完整 TeamPlan用于 resume 重建)。 """保存完整 TeamPlan用于 resume 重建)。
Args: Args:
@ -121,7 +143,7 @@ class PipelineCheckpoint:
except Exception as e: except Exception as e:
logger.warning(f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}") logger.warning(f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}")
async def load_plan(self, plan_id: str) -> dict[str, Any] | None: async def load_plan(self, plan_id: str) -> dict[str, object] | None:
"""加载完整 plan JSON。""" """加载完整 plan JSON。"""
# 优先 Redis # 优先 Redis
if self._redis is not None: if self._redis is not None:
@ -142,7 +164,7 @@ class PipelineCheckpoint:
return None return None
return plan_dict return plan_dict
async def save(self, plan_id: str, phase: Any, plan_status: str) -> None: async def save(self, plan_id: str, phase: object, plan_status: str) -> None:
"""保存阶段 checkpoint。 """保存阶段 checkpoint。
Args: Args:
@ -212,7 +234,8 @@ class PipelineCheckpoint:
if not phase_ids: if not phase_ids:
# Redis 无数据,检查内存(过滤过期) # Redis 无数据,检查内存(过滤过期)
return [ return [
c for c in self._memory.get(plan_id, {}).values() c
for c in self._memory.get(plan_id, {}).values()
if not self._is_expired(c.saved_at) if not self._is_expired(c.saved_at)
] ]
@ -236,8 +259,7 @@ class PipelineCheckpoint:
# 内存降级(过滤过期 checkpoint # 内存降级(过滤过期 checkpoint
return [ return [
c for c in self._memory.get(plan_id, {}).values() c for c in self._memory.get(plan_id, {}).values() if not self._is_expired(c.saved_at)
if not self._is_expired(c.saved_at)
] ]
async def clear(self, plan_id: str) -> None: async def clear(self, plan_id: str) -> None:

View File

@ -1,8 +1,8 @@
"""Saga compensation pattern for Pipeline execution""" """Saga compensation pattern for Pipeline execution"""
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass
from typing import Any, Awaitable, Callable from typing import Awaitable, Callable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -12,7 +12,7 @@ class CompletedStep:
"""Record of a completed step with its compensation""" """Record of a completed step with its compensation"""
step_name: str step_name: str
result: Any result: object
compensate_action: str | None = None compensate_action: str | None = None
@ -28,9 +28,7 @@ class CompensationResult:
class SagaOrchestrator: class SagaOrchestrator:
"""Orchestrates LIFO compensation for failed pipelines""" """Orchestrates LIFO compensation for failed pipelines"""
def __init__( def __init__(self, execute_skill_func: Callable[..., Awaitable[object]] | None = None):
self, execute_skill_func: Callable[..., Awaitable[Any]] | None = None
):
""" """
Args: Args:
execute_skill_func: Async function to execute a skill by name execute_skill_func: Async function to execute a skill by name
@ -42,7 +40,7 @@ class SagaOrchestrator:
def record_completed( def record_completed(
self, self,
step_name: str, step_name: str,
result: Any, result: object,
compensate_action: str | None = None, compensate_action: str | None = None,
): ):
"""Record a completed step for potential compensation""" """Record a completed step for potential compensation"""
@ -59,9 +57,7 @@ class SagaOrchestrator:
results: list[CompensationResult] = [] results: list[CompensationResult] = []
for step in reversed(self._completed_steps): for step in reversed(self._completed_steps):
if step.compensate_action is None: if step.compensate_action is None:
logger.info( logger.info(f"No compensation for step '{step.step_name}', skipping")
f"No compensation for step '{step.step_name}', skipping"
)
results.append( results.append(
CompensationResult( CompensationResult(
step_name=step.step_name, step_name=step.step_name,
@ -82,9 +78,7 @@ class SagaOrchestrator:
) )
) )
except Exception as e: except Exception as e:
logger.error( logger.error(f"Compensation for step '{step.step_name}' failed: {e}")
f"Compensation for step '{step.step_name}' failed: {e}"
)
results.append( results.append(
CompensationResult( CompensationResult(
step_name=step.step_name, step_name=step.step_name,

View File

@ -4,7 +4,6 @@
""" """
import logging import logging
from typing import Any
from agentkit.orchestrator.pipeline_engine import PipelineEngine from agentkit.orchestrator.pipeline_engine import PipelineEngine
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineResult, StageStatus from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineResult, StageStatus
@ -15,7 +14,7 @@ logger = logging.getLogger(__name__)
class DynamicPipeline: class DynamicPipeline:
"""动态 Pipeline 组合器""" """动态 Pipeline 组合器"""
def __init__(self, engine: PipelineEngine, loader: Any = None): def __init__(self, engine: PipelineEngine, loader: object | None = None):
self._engine = engine self._engine = engine
self._loader = loader self._loader = loader
@ -23,7 +22,7 @@ class DynamicPipeline:
self, self,
pipelines: dict[str, Pipeline], pipelines: dict[str, Pipeline],
condition_key: str, condition_key: str,
context: dict[str, Any] | None = None, context: dict[str, object] | None = None,
) -> PipelineResult: ) -> PipelineResult:
"""根据条件选择子 Pipeline 执行""" """根据条件选择子 Pipeline 执行"""
context = context or {} context = context or {}
@ -37,14 +36,16 @@ class DynamicPipeline:
) )
selected = pipelines[condition_value] selected = pipelines[condition_value]
logger.info(f"DynamicPipeline selected '{selected.name}' for {condition_key}={condition_value}") logger.info(
f"DynamicPipeline selected '{selected.name}' for {condition_key}={condition_value}"
)
return await self._engine.execute(selected, context) return await self._engine.execute(selected, context)
async def execute_nested( async def execute_nested(
self, self,
parent: Pipeline, parent: Pipeline,
sub_pipeline_map: dict[str, Pipeline], sub_pipeline_map: dict[str, Pipeline],
context: dict[str, Any] | None = None, context: dict[str, object] | None = None,
) -> PipelineResult: ) -> PipelineResult:
"""执行嵌套 Pipeline""" """执行嵌套 Pipeline"""
# 先执行父 Pipeline # 先执行父 Pipeline
@ -52,7 +53,7 @@ class DynamicPipeline:
# 根据父 Pipeline 结果选择子 Pipeline # 根据父 Pipeline 结果选择子 Pipeline
for stage_name, stage_result in parent_result.stage_results.items(): for stage_name, stage_result in parent_result.stage_results.items():
if hasattr(stage_result, 'output_data') and stage_result.output_data: if hasattr(stage_result, "output_data") and stage_result.output_data:
sub_pipeline_name = stage_result.output_data.get("sub_pipeline") sub_pipeline_name = stage_result.output_data.get("sub_pipeline")
if sub_pipeline_name and sub_pipeline_name in sub_pipeline_map: if sub_pipeline_name and sub_pipeline_name in sub_pipeline_map:
sub = sub_pipeline_map[sub_pipeline_name] sub = sub_pipeline_map[sub_pipeline_name]
@ -66,7 +67,7 @@ class DynamicPipeline:
pipeline: Pipeline, pipeline: Pipeline,
max_iterations: int = 5, max_iterations: int = 5,
exit_condition: str = "done", exit_condition: str = "done",
context: dict[str, Any] | None = None, context: dict[str, object] | None = None,
) -> PipelineResult: ) -> PipelineResult:
"""循环执行 Pipeline 直到条件满足""" """循环执行 Pipeline 直到条件满足"""
current_context = context or {} current_context = context or {}

View File

@ -3,25 +3,42 @@
import asyncio import asyncio
import json import json
import logging import logging
from typing import Any from typing import Awaitable, Callable, Protocol
from agentkit.core.protocol import HandoffMessage from agentkit.core.protocol import HandoffMessage
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class _RedisPubSubLike(Protocol):
"""Structural type for Redis pubsub object."""
async def subscribe(self, channel: str) -> None: ...
async def unsubscribe(self, channel: str) -> None: ...
def listen(self) -> object: ...
class _RedisLike(Protocol):
"""Structural type for async Redis client."""
async def publish(self, channel: str, message: str) -> int: ...
def pubsub(self) -> _RedisPubSubLike: ...
class HandoffManager: class HandoffManager:
"""Handoff 管理器 """Handoff 管理器
通过 Redis Pub/Sub 管理 Agent 间的任务转交 通过 Redis Pub/Sub 管理 Agent 间的任务转交
""" """
def __init__(self, redis: Any = None, dispatcher: Any = None): def __init__(self, redis: _RedisLike | None = None, dispatcher: object | None = None):
self._redis = redis self._redis = redis
self._dispatcher = dispatcher self._dispatcher = dispatcher
self._handlers: dict[str, list[Any]] = {} self._handlers: dict[str, list[Callable[[HandoffMessage], Awaitable[None]]]] = {}
def register_handler(self, agent_name: str, handler: Any) -> None: def register_handler(
self, agent_name: str, handler: Callable[[HandoffMessage], Awaitable[None]]
) -> None:
"""注册 Handoff 处理器""" """注册 Handoff 处理器"""
if agent_name not in self._handlers: if agent_name not in self._handlers:
self._handlers[agent_name] = [] self._handlers[agent_name] = []

View File

@ -1,10 +1,12 @@
"""Pipeline Engine - DAG + 并行执行 + 步骤重试 + Saga 补偿""" """Pipeline Engine - DAG + 并行执行 + 步骤重试 + Saga 补偿"""
from __future__ import annotations
import asyncio import asyncio
import logging import logging
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any from typing import TYPE_CHECKING
from agentkit.orchestrator.compensation import SagaOrchestrator from agentkit.orchestrator.compensation import SagaOrchestrator
from agentkit.orchestrator.pipeline_schema import ( from agentkit.orchestrator.pipeline_schema import (
@ -25,6 +27,23 @@ from agentkit.orchestrator.retry import execute_with_retry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from typing import Protocol
class _DispatcherLike(Protocol):
async def dispatch(self, task: object) -> None: ...
async def get_task_status(self, task_id: str) -> dict[str, object]: ...
class _StateManagerLike(Protocol):
async def create_execution(self, **kwargs: object) -> str: ...
async def update_step(self, **kwargs: object) -> None: ...
async def complete_execution(self, **kwargs: object) -> None: ...
async def fail_execution(self, **kwargs: object) -> None: ...
class _LLMGatewayLike(Protocol):
async def chat(self, **kwargs: object) -> object: ...
class PipelineEngine: class PipelineEngine:
"""Pipeline 执行引擎 """Pipeline 执行引擎
@ -38,7 +57,12 @@ class PipelineEngine:
- 状态持久化可选 - 状态持久化可选
""" """
def __init__(self, dispatcher: Any = None, state_manager: Any = None, llm_gateway: Any = None): def __init__(
self,
dispatcher: _DispatcherLike | None = None,
state_manager: _StateManagerLike | None = None,
llm_gateway: _LLMGatewayLike | None = None,
):
self._dispatcher = dispatcher self._dispatcher = dispatcher
self._state_manager = state_manager self._state_manager = state_manager
self._llm_gateway = llm_gateway self._llm_gateway = llm_gateway
@ -46,7 +70,7 @@ class PipelineEngine:
async def execute( async def execute(
self, self,
pipeline: Pipeline, pipeline: Pipeline,
context: dict[str, Any] | None = None, context: dict[str, object] | None = None,
adaptive_config: AdaptiveConfig | None = None, adaptive_config: AdaptiveConfig | None = None,
) -> PipelineResult: ) -> PipelineResult:
"""执行 Pipeline """执行 Pipeline
@ -68,7 +92,7 @@ class PipelineEngine:
async def _adaptive_loop( async def _adaptive_loop(
self, self,
pipeline: Pipeline, pipeline: Pipeline,
context: dict[str, Any] | None, context: dict[str, object] | None,
failed_result: PipelineResult, failed_result: PipelineResult,
adaptive_config: AdaptiveConfig, adaptive_config: AdaptiveConfig,
) -> PipelineResult: ) -> PipelineResult:
@ -92,34 +116,30 @@ class PipelineEngine:
# Replan # Replan
new_pipeline = await replanner.replan(current_pipeline, current_result, report) new_pipeline = await replanner.replan(current_pipeline, current_result, report)
logger.info(f"Pipeline replanned: {new_pipeline.name} ({len(new_pipeline.stages)} stages)") logger.info(
f"Pipeline replanned: {new_pipeline.name} ({len(new_pipeline.stages)} stages)"
)
# Re-execute # Re-execute
current_result = await self._execute_pipeline(new_pipeline, context) current_result = await self._execute_pipeline(new_pipeline, context)
current_pipeline = new_pipeline current_pipeline = new_pipeline
# Record reflection in metadata # Record reflection in metadata
current_result.metadata["reflections"] = [ current_result.metadata["reflections"] = [r.model_dump() for r in reflections]
r.model_dump() for r in reflections
]
if current_result.status == StageStatus.COMPLETED: if current_result.status == StageStatus.COMPLETED:
logger.info(f"Pipeline succeeded after {reflection_num} reflection(s)") logger.info(f"Pipeline succeeded after {reflection_num} reflection(s)")
return current_result return current_result
# Exhausted reflections # Exhausted reflections
logger.warning( logger.warning(f"Pipeline failed after {adaptive_config.max_reflections} reflection(s)")
f"Pipeline failed after {adaptive_config.max_reflections} reflection(s)" current_result.metadata["reflections"] = [r.model_dump() for r in reflections]
)
current_result.metadata["reflections"] = [
r.model_dump() for r in reflections
]
return current_result return current_result
async def _execute_pipeline( async def _execute_pipeline(
self, self,
pipeline: Pipeline, pipeline: Pipeline,
context: dict[str, Any] | None = None, context: dict[str, object] | None = None,
) -> PipelineResult: ) -> PipelineResult:
"""执行 Pipeline 的核心逻辑(不含反思-重规划)。""" """执行 Pipeline 的核心逻辑(不含反思-重规划)。"""
result = PipelineResult(pipeline_name=pipeline.name) result = PipelineResult(pipeline_name=pipeline.name)
@ -151,7 +171,9 @@ class PipelineEngine:
# 逐层执行 # 逐层执行
for level, stages in enumerate(level_groups): for level, stages in enumerate(level_groups):
logger.info(f"Pipeline '{pipeline.name}' executing level {level} with {len(stages)} stage(s)") logger.info(
f"Pipeline '{pipeline.name}' executing level {level} with {len(stages)} stage(s)"
)
# 并行执行同层 stages # 并行执行同层 stages
tasks = [] tasks = []
@ -173,9 +195,11 @@ class PipelineEngine:
# Update step state # Update step state
if self._state_manager is not None and execution_id is not None: if self._state_manager is not None and execution_id is not None:
try: try:
step_status = "completed" if sr.status == StageStatus.COMPLETED else sr.status.value step_status = (
step_output = sr.output_data if hasattr(sr, 'output_data') else None "completed" if sr.status == StageStatus.COMPLETED else sr.status.value
step_error = sr.error_message if hasattr(sr, 'error_message') else None )
step_output = sr.output_data if hasattr(sr, "output_data") else None
step_error = sr.error_message if hasattr(sr, "error_message") else None
await self._state_manager.update_step( await self._state_manager.update_step(
execution_id=execution_id, execution_id=execution_id,
step_name=stage.name, step_name=stage.name,
@ -189,19 +213,21 @@ class PipelineEngine:
# 收集输出变量 # 收集输出变量
if sr.output_data and isinstance(sr, dict): if sr.output_data and isinstance(sr, dict):
pass pass
elif hasattr(sr, 'output_data') and sr.output_data: elif hasattr(sr, "output_data") and sr.output_data:
for output_key in stage.outputs: for output_key in stage.outputs:
if output_key in sr.output_data: if output_key in sr.output_data:
result.variables[output_key] = sr.output_data[output_key] result.variables[output_key] = sr.output_data[output_key]
# 检查是否需要中止 # 检查是否需要中止
if hasattr(sr, 'status') and sr.status == StageStatus.FAILED: if hasattr(sr, "status") and sr.status == StageStatus.FAILED:
if not stage.continue_on_failure: if not stage.continue_on_failure:
# Execute Saga compensation for completed steps # Execute Saga compensation for completed steps
compensation_results = await saga.compensate() compensation_results = await saga.compensate()
if compensation_results: if compensation_results:
failed_compensations = [ failed_compensations = [
cr for cr in compensation_results if not cr.success and cr.error != "no_compensation_needed" cr
for cr in compensation_results
if not cr.success and cr.error != "no_compensation_needed"
] ]
if failed_compensations: if failed_compensations:
logger.warning( logger.warning(
@ -219,7 +245,12 @@ class PipelineEngine:
step_name=stage.name, step_name=stage.name,
error=result.error_message, error=result.error_message,
) )
except (asyncio.TimeoutError, ConnectionError, RuntimeError, ValueError) as exc: except (
asyncio.TimeoutError,
ConnectionError,
RuntimeError,
ValueError,
) as exc:
logger.warning(f"Failed to persist failure state: {exc}") logger.warning(f"Failed to persist failure state: {exc}")
return result return result
@ -252,7 +283,9 @@ class PipelineEngine:
started_at = datetime.now(timezone.utc).isoformat() started_at = datetime.now(timezone.utc).isoformat()
# 条件检查 # 条件检查
if stage.condition and not self._evaluate_condition(stage.condition, pipeline_result.variables): if stage.condition and not self._evaluate_condition(
stage.condition, pipeline_result.variables
):
return StageResult( return StageResult(
stage_name=stage.name, stage_name=stage.name,
status=StageStatus.SKIPPED, status=StageStatus.SKIPPED,
@ -312,7 +345,9 @@ class PipelineEngine:
if status["status"] in ("completed", "failed", "cancelled"): if status["status"] in ("completed", "failed", "cancelled"):
return StageResult( return StageResult(
stage_name=stage.name, stage_name=stage.name,
status=StageStatus.COMPLETED if status["status"] == "completed" else StageStatus.FAILED, status=StageStatus.COMPLETED
if status["status"] == "completed"
else StageStatus.FAILED,
output_data=status.get("output_data"), output_data=status.get("output_data"),
error_message=status.get("error_message"), error_message=status.get("error_message"),
started_at=started_at, started_at=started_at,
@ -406,7 +441,7 @@ class PipelineEngine:
return resolved return resolved
@staticmethod @staticmethod
def _get_nested(data: dict, path: str) -> Any: def _get_nested(data: dict, path: str) -> object:
keys = path.split(".") keys = path.split(".")
current = data current = data
for key in keys: for key in keys:
@ -497,9 +532,7 @@ class PipelineEngine:
if verifier_feedback.passed: if verifier_feedback.passed:
# 审查通过,返回成功结果 # 审查通过,返回成功结果
logger.info( logger.info(f"Stage '{stage.name}' passed review in round {round_num}")
f"Stage '{stage.name}' passed review in round {round_num}"
)
worker_result.output_data = worker_result.output_data or {} worker_result.output_data = worker_result.output_data or {}
worker_result.output_data["adversarial_metadata"] = { worker_result.output_data["adversarial_metadata"] = {
"passed_round": round_num, "passed_round": round_num,
@ -553,7 +586,7 @@ class PipelineEngine:
self, self,
agent_name: str, agent_name: str,
action: str, action: str,
input_data: dict[str, Any], input_data: dict[str, object],
stage: PipelineStage, stage: PipelineStage,
started_at: str, started_at: str,
timeout_seconds: int | None = None, timeout_seconds: int | None = None,
@ -568,7 +601,9 @@ class PipelineEngine:
started_at: 开始时间 started_at: 开始时间
timeout_seconds: 独立超时时间不传则使用 stage.timeout_seconds timeout_seconds: 独立超时时间不传则使用 stage.timeout_seconds
""" """
effective_timeout = timeout_seconds if timeout_seconds is not None else stage.timeout_seconds effective_timeout = (
timeout_seconds if timeout_seconds is not None else stage.timeout_seconds
)
if self._dispatcher is None: if self._dispatcher is None:
# Dry-run 模式 # Dry-run 模式
return StageResult( return StageResult(
@ -602,7 +637,9 @@ class PipelineEngine:
if status["status"] in ("completed", "failed", "cancelled"): if status["status"] in ("completed", "failed", "cancelled"):
return StageResult( return StageResult(
stage_name=stage.name, stage_name=stage.name,
status=StageStatus.COMPLETED if status["status"] == "completed" else StageStatus.FAILED, status=StageStatus.COMPLETED
if status["status"] == "completed"
else StageStatus.FAILED,
output_data=status.get("output_data"), output_data=status.get("output_data"),
error_message=status.get("error_message"), error_message=status.get("error_message"),
started_at=started_at, started_at=started_at,
@ -639,7 +676,7 @@ class PipelineEngine:
async def _execute_verifier( async def _execute_verifier(
self, self,
verifier_name: str, verifier_name: str,
worker_output: dict[str, Any], worker_output: dict[str, object],
stage: PipelineStage, stage: PipelineStage,
started_at: str, started_at: str,
) -> ReviewFeedback: ) -> ReviewFeedback:
@ -679,10 +716,7 @@ class PipelineEngine:
try: try:
feedback = ReviewFeedback( feedback = ReviewFeedback(
passed=output_data.get("passed", False), passed=output_data.get("passed", False),
issues=[ issues=[ReviewIssue(**issue) for issue in output_data.get("issues", [])],
ReviewIssue(**issue)
for issue in output_data.get("issues", [])
],
summary=output_data.get("summary", "No summary provided"), summary=output_data.get("summary", "No summary provided"),
score=output_data.get("score", 0.0), score=output_data.get("score", 0.0),
) )
@ -699,7 +733,7 @@ class PipelineEngine:
self, self,
feedback: ReviewFeedback, feedback: ReviewFeedback,
feedback_mode: str = "structured+natural", feedback_mode: str = "structured+natural",
) -> dict[str, Any]: ) -> dict[str, object]:
"""构建反馈上下文,让 Worker Agent 理解审查反馈并定向修复 """构建反馈上下文,让 Worker Agent 理解审查反馈并定向修复
Args: Args:
@ -720,7 +754,7 @@ class PipelineEngine:
for issue in feedback.issues for issue in feedback.issues
] ]
feedback_context: dict[str, Any] = { feedback_context: dict[str, object] = {
"previous_attempt_failed": True, "previous_attempt_failed": True,
} }
@ -756,7 +790,9 @@ class PipelineEngine:
) )
else: else:
# 未知模式fallback 到 structured+natural # 未知模式fallback 到 structured+natural
logger.warning(f"Unknown feedback_mode '{feedback_mode}', falling back to structured+natural") logger.warning(
f"Unknown feedback_mode '{feedback_mode}', falling back to structured+natural"
)
feedback_context["review_feedback"] = { feedback_context["review_feedback"] = {
"summary": feedback.summary, "summary": feedback.summary,
"issues": issues_list, "issues": issues_list,

View File

@ -2,7 +2,6 @@
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Any
import yaml import yaml
@ -23,7 +22,9 @@ class PipelineLoader:
if not yaml_path.exists(): if not yaml_path.exists():
yaml_path = self._pipelines_dir / f"{pipeline_name}.yml" yaml_path = self._pipelines_dir / f"{pipeline_name}.yml"
if not yaml_path.exists(): if not yaml_path.exists():
raise FileNotFoundError(f"Pipeline '{pipeline_name}' not found in {self._pipelines_dir}") raise FileNotFoundError(
f"Pipeline '{pipeline_name}' not found in {self._pipelines_dir}"
)
content = yaml_path.read_text(encoding="utf-8") content = yaml_path.read_text(encoding="utf-8")
return self.load_from_yaml(content, pipeline_name) return self.load_from_yaml(content, pipeline_name)

View File

@ -35,9 +35,7 @@ class PipelineExecutionModel(Base):
) )
completed_at = Column(DateTime) completed_at = Column(DateTime)
__table_args__ = ( __table_args__ = (Index("ix_pipeline_status_created", "status", "created_at"),)
Index("ix_pipeline_status_created", "status", "created_at"),
)
class PipelineStepHistoryModel(Base): class PipelineStepHistoryModel(Base):

View File

@ -1,7 +1,7 @@
"""Pipeline 数据模型""" """Pipeline 数据模型"""
from enum import Enum from enum import Enum
from typing import Any, Literal from typing import Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -18,8 +18,11 @@ class StageStatus(str, Enum):
class ReviewIssue(BaseModel): class ReviewIssue(BaseModel):
"""单条审查问题""" """单条审查问题"""
severity: Literal["critical", "major", "minor"] = Field(description="问题严重程度") severity: Literal["critical", "major", "minor"] = Field(description="问题严重程度")
category: Literal["logic_error", "security", "style", "test_failure", "architecture"] = Field(description="问题类别") category: Literal["logic_error", "security", "style", "test_failure", "architecture"] = Field(
description="问题类别"
)
description: str = Field(min_length=1, description="问题描述") description: str = Field(min_length=1, description="问题描述")
location: str | None = Field(default=None, description="文件路径/行号") location: str | None = Field(default=None, description="文件路径/行号")
suggestion: str | None = Field(default=None, description="修复建议") suggestion: str | None = Field(default=None, description="修复建议")
@ -27,6 +30,7 @@ class ReviewIssue(BaseModel):
class ReviewFeedback(BaseModel): class ReviewFeedback(BaseModel):
"""Verifier 返回的结构化审查反馈""" """Verifier 返回的结构化审查反馈"""
passed: bool = Field(description="是否通过审查") passed: bool = Field(description="是否通过审查")
issues: list[ReviewIssue] = Field(default_factory=list, description="问题列表") issues: list[ReviewIssue] = Field(default_factory=list, description="问题列表")
summary: str = Field(min_length=1, description="自然语言审查报告") summary: str = Field(min_length=1, description="自然语言审查报告")
@ -35,6 +39,7 @@ class ReviewFeedback(BaseModel):
class AdversarialState(BaseModel): class AdversarialState(BaseModel):
"""对抗轮次状态追踪""" """对抗轮次状态追踪"""
current_round: int = Field(default=0, description="当前对抗轮次") current_round: int = Field(default=0, description="当前对抗轮次")
max_rounds: int = Field(default=3, description="最大对抗轮次") max_rounds: int = Field(default=3, description="最大对抗轮次")
feedback_history: list[ReviewFeedback] = Field(default_factory=list, description="反馈历史") feedback_history: list[ReviewFeedback] = Field(default_factory=list, description="反馈历史")
@ -46,7 +51,7 @@ class PipelineStage(BaseModel):
agent: str agent: str
action: str action: str
depends_on: list[str] = [] depends_on: list[str] = []
inputs: dict[str, Any] = {} inputs: dict[str, object] = {}
outputs: list[str] = [] outputs: list[str] = []
timeout_seconds: int = 300 timeout_seconds: int = 300
retry_count: int = 0 retry_count: int = 0
@ -54,12 +59,19 @@ class PipelineStage(BaseModel):
condition: str | None = None condition: str | None = None
retry_policy: StepRetryPolicy | None = None retry_policy: StepRetryPolicy | None = None
compensate: str | None = None compensate: str | None = None
# 对抗闭环相关字段 # 对抗闭环相关字段
verifier: str | None = Field(default=None, description="Verifier Agent 名称,配置后启用对抗模式") verifier: str | None = Field(
default=None, description="Verifier Agent 名称,配置后启用对抗模式"
)
max_adversarial_rounds: int = Field(default=3, description="最大对抗轮次") max_adversarial_rounds: int = Field(default=3, description="最大对抗轮次")
verifier_timeout_seconds: int = Field(default=120, description="Verifier Agent 独立超时时间(秒),避免与 Worker 共享 timeout_seconds") verifier_timeout_seconds: int = Field(
feedback_mode: Literal["structured+natural", "structured", "natural"] = Field(default="structured+natural", description="反馈模式") default=120,
description="Verifier Agent 独立超时时间(秒),避免与 Worker 共享 timeout_seconds",
)
feedback_mode: Literal["structured+natural", "structured", "natural"] = Field(
default="structured+natural", description="反馈模式"
)
escalate_on_exhaust: str | None = Field(default=None, description="对抗轮次耗尽后的升级目标") escalate_on_exhaust: str | None = Field(default=None, description="对抗轮次耗尽后的升级目标")
model_config = {"arbitrary_types_allowed": True} model_config = {"arbitrary_types_allowed": True}
@ -70,13 +82,13 @@ class Pipeline(BaseModel):
version: str version: str
description: str description: str
stages: list[PipelineStage] stages: list[PipelineStage]
variables: dict[str, Any] = {} variables: dict[str, object] = {}
class StageResult(BaseModel): class StageResult(BaseModel):
stage_name: str stage_name: str
status: StageStatus = StageStatus.PENDING status: StageStatus = StageStatus.PENDING
output_data: dict[str, Any] | None = None output_data: dict[str, object] | None = None
error_message: str | None = None error_message: str | None = None
started_at: str | None = None started_at: str | None = None
completed_at: str | None = None completed_at: str | None = None
@ -86,9 +98,9 @@ class PipelineResult(BaseModel):
pipeline_name: str pipeline_name: str
status: StageStatus = StageStatus.PENDING status: StageStatus = StageStatus.PENDING
stage_results: dict[str, StageResult] = {} stage_results: dict[str, StageResult] = {}
variables: dict[str, Any] = {} variables: dict[str, object] = {}
error_message: str | None = None error_message: str | None = None
metadata: dict[str, Any] = {} metadata: dict[str, object] = {}
class AdaptiveConfig(BaseModel): class AdaptiveConfig(BaseModel):

View File

@ -13,7 +13,7 @@ import json
import logging import logging
import uuid import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Callable, Coroutine from typing import Callable, Coroutine
from agentkit.orchestrator.pipeline_models import ( from agentkit.orchestrator.pipeline_models import (
PipelineExecutionModel, PipelineExecutionModel,
@ -183,7 +183,7 @@ class PipelineStateRedis:
return self._redis return self._redis
async def _safe_redis_call( async def _safe_redis_call(
self, fn: Callable[..., Coroutine[Any, Any, Any]], *args: object, **kwargs: object self, fn: Callable[..., Coroutine[object, object, object]], *args: object, **kwargs: object
) -> object | None: ) -> object | None:
"""Execute a Redis call, falling back to memory on failure. """Execute a Redis call, falling back to memory on failure.

View File

@ -4,22 +4,30 @@
生成修正后的 Pipeline 重新执行 生成修正后的 Pipeline 重新执行
""" """
from __future__ import annotations
import json import json
import logging import logging
from typing import Any from typing import TYPE_CHECKING
from agentkit.orchestrator.pipeline_schema import ( from agentkit.orchestrator.pipeline_schema import (
Pipeline, Pipeline,
PipelineResult, PipelineResult,
PipelineStage, PipelineStage,
ReflectionReport, ReflectionReport,
StageResult,
StageStatus, StageStatus,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from typing import Protocol
class _LLMGatewayLike(Protocol):
async def chat(self, **kwargs: object) -> object: ...
class PipelineReflector: class PipelineReflector:
"""分析 Pipeline 执行失败原因,生成结构化反思报告。 """分析 Pipeline 执行失败原因,生成结构化反思报告。
@ -27,7 +35,7 @@ class PipelineReflector:
输出 ReflectionReport 包含 failure_typeroot_cause suggested_fix 输出 ReflectionReport 包含 failure_typeroot_cause suggested_fix
""" """
def __init__(self, llm_gateway: Any = None): def __init__(self, llm_gateway: _LLMGatewayLike | None = None):
self._llm_gateway = llm_gateway self._llm_gateway = llm_gateway
async def reflect( async def reflect(
@ -54,19 +62,25 @@ class PipelineReflector:
if self._llm_gateway is not None: if self._llm_gateway is not None:
try: try:
return await self._llm_reflect( return await self._llm_reflect(
pipeline, failed_stage, error_message, pipeline,
completed_outputs, reflection_number, failed_stage,
error_message,
completed_outputs,
reflection_number,
) )
except Exception as e: except Exception as e:
logger.warning(f"LLM reflection failed, falling back to rule-based: {e}") logger.warning(f"LLM reflection failed, falling back to rule-based: {e}")
# 规则兜底:基于错误信息分类 # 规则兜底:基于错误信息分类
return self._rule_based_reflect( return self._rule_based_reflect(
failed_stage, error_message, reflection_number, failed_stage,
error_message,
reflection_number,
) )
def _find_failure( def _find_failure(
self, result: PipelineResult, self,
result: PipelineResult,
) -> tuple[str, str]: ) -> tuple[str, str]:
"""找到第一个失败的 stage 及其错误信息。""" """找到第一个失败的 stage 及其错误信息。"""
for name, sr in result.stage_results.items(): for name, sr in result.stage_results.items():
@ -75,8 +89,9 @@ class PipelineReflector:
return "", "no failed stage found" return "", "no failed stage found"
def _collect_completed_outputs( def _collect_completed_outputs(
self, result: PipelineResult, self,
) -> dict[str, Any]: result: PipelineResult,
) -> dict[str, object]:
"""收集已完成步骤的输出。""" """收集已完成步骤的输出。"""
outputs = {} outputs = {}
for name, sr in result.stage_results.items(): for name, sr in result.stage_results.items():
@ -89,13 +104,16 @@ class PipelineReflector:
pipeline: Pipeline, pipeline: Pipeline,
failed_stage: str, failed_stage: str,
error_message: str, error_message: str,
completed_outputs: dict[str, Any], completed_outputs: dict[str, object],
reflection_number: int, reflection_number: int,
) -> ReflectionReport: ) -> ReflectionReport:
"""使用 LLM 分析失败原因。""" """使用 LLM 分析失败原因。"""
prompt = self._build_reflection_prompt( prompt = self._build_reflection_prompt(
pipeline, failed_stage, error_message, pipeline,
completed_outputs, reflection_number, failed_stage,
error_message,
completed_outputs,
reflection_number,
) )
response = await self._llm_gateway.chat( response = await self._llm_gateway.chat(
@ -106,7 +124,9 @@ class PipelineReflector:
# 解析 LLM 返回的 JSON # 解析 LLM 返回的 JSON
content = response.content if hasattr(response, "content") else str(response) content = response.content if hasattr(response, "content") else str(response)
return self._parse_reflection_response( return self._parse_reflection_response(
content, failed_stage, reflection_number, content,
failed_stage,
reflection_number,
) )
def _build_reflection_prompt( def _build_reflection_prompt(
@ -114,15 +134,14 @@ class PipelineReflector:
pipeline: Pipeline, pipeline: Pipeline,
failed_stage: str, failed_stage: str,
error_message: str, error_message: str,
completed_outputs: dict[str, Any], completed_outputs: dict[str, object],
reflection_number: int, reflection_number: int,
) -> str: ) -> str:
"""构建反思提示词。""" """构建反思提示词。"""
stage_descriptions = [] stage_descriptions = []
for s in pipeline.stages: for s in pipeline.stages:
stage_descriptions.append( stage_descriptions.append(
f" - {s.name}: agent={s.agent}, action={s.action}, " f" - {s.name}: agent={s.agent}, action={s.action}, depends_on={s.depends_on}"
f"depends_on={s.depends_on}"
) )
completed_summary = json.dumps( completed_summary = json.dumps(
@ -174,7 +193,9 @@ JSON response:"""
except (json.JSONDecodeError, KeyError) as e: except (json.JSONDecodeError, KeyError) as e:
logger.warning(f"Failed to parse LLM reflection response: {e}") logger.warning(f"Failed to parse LLM reflection response: {e}")
return self._rule_based_reflect( return self._rule_based_reflect(
failed_stage, content, reflection_number, failed_stage,
content,
reflection_number,
) )
def _rule_based_reflect( def _rule_based_reflect(
@ -218,7 +239,7 @@ class PipelineReplanner:
保留已完成步骤的结果仅重新规划失败及后续步骤 保留已完成步骤的结果仅重新规划失败及后续步骤
""" """
def __init__(self, llm_gateway: Any = None): def __init__(self, llm_gateway: _LLMGatewayLike | None = None):
self._llm_gateway = llm_gateway self._llm_gateway = llm_gateway
async def replan( async def replan(
@ -255,8 +276,7 @@ class PipelineReplanner:
) -> Pipeline: ) -> Pipeline:
"""使用 LLM 生成修正后的 Pipeline。""" """使用 LLM 生成修正后的 Pipeline。"""
completed_stages = [ completed_stages = [
name for name, sr in result.stage_results.items() name for name, sr in result.stage_results.items() if sr.status == StageStatus.COMPLETED
if sr.status == StageStatus.COMPLETED
] ]
prompt = f"""Based on the reflection report, generate a corrected pipeline. prompt = f"""Based on the reflection report, generate a corrected pipeline.
@ -284,7 +304,9 @@ JSON pipeline:"""
return self._parse_pipeline_response(content, pipeline) return self._parse_pipeline_response(content, pipeline)
def _parse_pipeline_response( def _parse_pipeline_response(
self, content: str, original: Pipeline, self,
content: str,
original: Pipeline,
) -> Pipeline: ) -> Pipeline:
"""解析 LLM 返回的 Pipeline JSON。""" """解析 LLM 返回的 Pipeline JSON。"""
try: try:
@ -294,9 +316,7 @@ JSON pipeline:"""
text = "\n".join(lines[1:-1]) text = "\n".join(lines[1:-1])
data = json.loads(text) data = json.loads(text)
stages = [ stages = [PipelineStage(**s) for s in data.get("stages", [])]
PipelineStage(**s) for s in data.get("stages", [])
]
return Pipeline( return Pipeline(
name=data.get("name", original.name), name=data.get("name", original.name),
version=data.get("version", original.version), version=data.get("version", original.version),
@ -316,8 +336,7 @@ JSON pipeline:"""
) -> Pipeline: ) -> Pipeline:
"""基于规则的兜底重规划。""" """基于规则的兜底重规划。"""
completed_stages = { completed_stages = {
name for name, sr in result.stage_results.items() name for name, sr in result.stage_results.items() if sr.status == StageStatus.COMPLETED
if sr.status == StageStatus.COMPLETED
} }
# 构建修正后的 stages 列表 # 构建修正后的 stages 列表
@ -345,17 +364,21 @@ JSON pipeline:"""
) )
def _adjust_failed_stage( def _adjust_failed_stage(
self, stage: PipelineStage, report: ReflectionReport, self,
stage: PipelineStage,
report: ReflectionReport,
) -> PipelineStage: ) -> PipelineStage:
"""根据反思报告调整失败的步骤。""" """根据反思报告调整失败的步骤。"""
adjustments: dict[str, Any] = {} adjustments: dict[str, object] = {}
if report.failure_type == "timeout": if report.failure_type == "timeout":
adjustments["timeout_seconds"] = min( adjustments["timeout_seconds"] = min(
stage.timeout_seconds * 2, 3600, stage.timeout_seconds * 2,
3600,
) )
if stage.retry_policy is None: if stage.retry_policy is None:
from agentkit.orchestrator.retry import StepRetryPolicy from agentkit.orchestrator.retry import StepRetryPolicy
adjustments["retry_policy"] = StepRetryPolicy(max_attempts=2) adjustments["retry_policy"] = StepRetryPolicy(max_attempts=2)
elif report.failure_type == "resource_error": elif report.failure_type == "resource_error":
@ -365,6 +388,7 @@ JSON pipeline:"""
# 添加重试策略,可能输入在后续可用 # 添加重试策略,可能输入在后续可用
if stage.retry_policy is None: if stage.retry_policy is None:
from agentkit.orchestrator.retry import StepRetryPolicy from agentkit.orchestrator.retry import StepRetryPolicy
adjustments["retry_policy"] = StepRetryPolicy(max_attempts=2) adjustments["retry_policy"] = StepRetryPolicy(max_attempts=2)
return stage.model_copy(update=adjustments) return stage.model_copy(update=adjustments)

View File

@ -4,7 +4,7 @@ import asyncio
import logging import logging
import random import random
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Awaitable, Callable from typing import Awaitable, Callable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -27,7 +27,7 @@ class StepRetryPolicy:
def calculate_delay(self, attempt: int) -> float: def calculate_delay(self, attempt: int) -> float:
"""Calculate delay for given attempt number (0-based)""" """Calculate delay for given attempt number (0-based)"""
delay = min( delay = min(
self.base_delay * (self.exponential_base ** attempt), self.base_delay * (self.exponential_base**attempt),
self.max_delay, self.max_delay,
) )
if self.jitter: if self.jitter:
@ -36,10 +36,10 @@ class StepRetryPolicy:
async def execute_with_retry( async def execute_with_retry(
func: Callable[..., Awaitable[Any]], func: Callable[..., Awaitable[object]],
retry_policy: StepRetryPolicy | None = None, retry_policy: StepRetryPolicy | None = None,
step_name: str = "", step_name: str = "",
) -> Any: ) -> object:
"""Execute a function with retry policy""" """Execute a function with retry policy"""
if retry_policy is None: if retry_policy is None:
return await func() return await func()

View File

@ -3,18 +3,17 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage from agentkit.orchestrator.pipeline_schema import PipelineStage
class WorkflowStage(PipelineStage): class WorkflowStage(PipelineStage):
"""A workflow stage extending PipelineStage with type and config.""" """A workflow stage extending PipelineStage with type and config."""
type: str = "skill" # "skill" | "condition" | "approval" | "parallel" type: str = "skill" # "skill" | "condition" | "approval" | "parallel"
config: dict[str, Any] = Field(default_factory=dict) config: dict[str, object] = Field(default_factory=dict)
class WorkflowDefinition(BaseModel): class WorkflowDefinition(BaseModel):
@ -24,9 +23,9 @@ class WorkflowDefinition(BaseModel):
name: str name: str
version: int = 1 version: int = 1
stages: list[WorkflowStage] = Field(default_factory=list) stages: list[WorkflowStage] = Field(default_factory=list)
triggers: list[dict[str, Any]] = Field(default_factory=list) triggers: list[dict[str, object]] = Field(default_factory=list)
variables_schema: dict[str, Any] = Field(default_factory=dict) variables_schema: dict[str, object] = Field(default_factory=dict)
output_schema: dict[str, Any] = Field(default_factory=dict) output_schema: dict[str, object] = Field(default_factory=dict)
created_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) created_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
updated_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) updated_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
@ -38,11 +37,11 @@ class WorkflowExecution(BaseModel):
workflow_id: str = "" workflow_id: str = ""
status: str = "pending" # pending|running|paused|completed|failed|cancelled status: str = "pending" # pending|running|paused|completed|failed|cancelled
current_stage: str | None = None current_stage: str | None = None
stage_results: dict[str, Any] = Field(default_factory=dict) stage_results: dict[str, object] = Field(default_factory=dict)
started_at: str | None = None started_at: str | None = None
completed_at: str | None = None completed_at: str | None = None
error: str | None = None error: str | None = None
variables: dict[str, Any] = Field(default_factory=dict) variables: dict[str, object] = Field(default_factory=dict)
class WorkflowSummary(BaseModel): class WorkflowSummary(BaseModel):
@ -62,15 +61,15 @@ class CreateWorkflowRequest(BaseModel):
name: str name: str
stages: list[WorkflowStage] = Field(default_factory=list) stages: list[WorkflowStage] = Field(default_factory=list)
triggers: list[dict[str, Any]] = Field(default_factory=list) triggers: list[dict[str, object]] = Field(default_factory=list)
variables_schema: dict[str, Any] = Field(default_factory=dict) variables_schema: dict[str, object] = Field(default_factory=dict)
output_schema: dict[str, Any] = Field(default_factory=dict) output_schema: dict[str, object] = Field(default_factory=dict)
class ExecuteWorkflowRequest(BaseModel): class ExecuteWorkflowRequest(BaseModel):
"""Request body for executing a workflow.""" """Request body for executing a workflow."""
variables: dict[str, Any] = Field(default_factory=dict) variables: dict[str, object] = Field(default_factory=dict)
class ApproveRequest(BaseModel): class ApproveRequest(BaseModel):

View File

@ -37,7 +37,7 @@ async def create_agent(request: CreateAgentRequest, req: Request):
config_dict = request.config config_dict = request.config
try: try:
config = SkillConfig.from_dict(config_dict) config = SkillConfig.from_dict(config_dict)
except Exception: except (ValueError, KeyError, TypeError):
config = AgentConfig.from_dict(config_dict) config = AgentConfig.from_dict(config_dict)
agent = await pool.create_agent(config) agent = await pool.create_agent(config)
else: else:

View File

@ -53,6 +53,8 @@ from agentkit.server.auth.session_service import (
REVOKE_REASON_PASSWORD_CHANGED, REVOKE_REASON_PASSWORD_CHANGED,
REVOKE_REASON_USER_TERMINATED, REVOKE_REASON_USER_TERMINATED,
SessionCreate, SessionCreate,
SessionNotFound,
SessionReuseDetected,
SessionService, SessionService,
get_session_service, get_session_service,
) )
@ -253,7 +255,7 @@ def _is_legacy_client(request: Request) -> bool:
if client_v is None or cutoff_v is None: if client_v is None or cutoff_v is None:
return False return False
return client_v < cutoff_v return client_v < cutoff_v
except Exception: # noqa: BLE001 except (ValueError, TypeError, AttributeError): # noqa: BLE001
logger.debug("Failed to parse X-Client-Version %r", raw) logger.debug("Failed to parse X-Client-Version %r", raw)
return False return False
@ -492,7 +494,7 @@ async def refresh(payload: RefreshRequest, request: Request) -> TokenResponse:
# 1. Verify signature + type # 1. Verify signature + type
try: try:
refresh_payload = verify_token(payload.refresh_token, secret, expected_type="refresh") refresh_payload = verify_token(payload.refresh_token, secret, expected_type="refresh")
except Exception as exc: # noqa: BLE001 except (jwt.PyJWTError, ValueError, KeyError) as exc: # noqa: BLE001
raise HTTPException(status_code=401, detail="Invalid refresh token") from exc raise HTTPException(status_code=401, detail="Invalid refresh token") from exc
# 2-3. Validate the session (also handles reuse detection) # 2-3. Validate the session (also handles reuse detection)
@ -510,7 +512,7 @@ async def refresh(payload: RefreshRequest, request: Request) -> TokenResponse:
new_refresh_token=new_pair.refresh_token, new_refresh_token=new_pair.refresh_token,
new_ttl_seconds=int(REFRESH_TOKEN_TTL.total_seconds()), new_ttl_seconds=int(REFRESH_TOKEN_TTL.total_seconds()),
) )
except Exception as exc: # noqa: BLE001 — SessionReuseDetected / SessionNotFound except (SessionReuseDetected, SessionNotFound, ValueError, KeyError, RuntimeError) as exc: # noqa: BLE001 — SessionReuseDetected / SessionNotFound
logger.info("Refresh rejected: %s", exc) logger.info("Refresh rejected: %s", exc)
raise HTTPException(status_code=401, detail="Invalid refresh token") from exc raise HTTPException(status_code=401, detail="Invalid refresh token") from exc

View File

@ -483,7 +483,7 @@ async def validate_formula(
parse_formula(body.formula) parse_formula(body.formula)
except (FormulaParseError, FormulaSecurityError, UnknownFunctionError) as e: except (FormulaParseError, FormulaSecurityError, UnknownFunctionError) as e:
return {"valid": False, "error": str(e)} return {"valid": False, "error": str(e)}
except Exception as e: # pragma: no cover — defensive except (ValueError, TypeError, KeyError, AttributeError) as e: # pragma: no cover — defensive
return {"valid": False, "error": f"Unexpected error: {e}"} return {"valid": False, "error": f"Unexpected error: {e}"}
return {"valid": True} return {"valid": True}
@ -750,7 +750,7 @@ async def upload_file(
f.write(chunk) f.write(chunk)
except HTTPException: except HTTPException:
raise raise
except Exception as exc: except (OSError, RuntimeError) as exc:
file_path.unlink(missing_ok=True) file_path.unlink(missing_ok=True)
logger.error(f"Failed to save uploaded bitable file: {exc}") logger.error(f"Failed to save uploaded bitable file: {exc}")
raise HTTPException(status_code=500, detail="Failed to save file") from exc raise HTTPException(status_code=500, detail="Failed to save file") from exc

View File

@ -553,7 +553,7 @@ async def _invalidate_adapter_cache(channel_id: str) -> None:
if old is not None: if old is not None:
try: try:
await old.close() await old.close()
except Exception: # noqa: BLE001 — 关闭异常不应阻塞配置变更 except (ConnectionError, RuntimeError, OSError, asyncio.TimeoutError): # noqa: BLE001 — 关闭异常不应阻塞配置变更
logger.debug("关闭旧适配器异常已忽略: channel_id=%s", channel_id) logger.debug("关闭旧适配器异常已忽略: channel_id=%s", channel_id)
@ -562,7 +562,7 @@ async def close_all_adapters() -> None:
for channel_id, adapter in list(_adapter_cache.items()): for channel_id, adapter in list(_adapter_cache.items()):
try: try:
await adapter.close() await adapter.close()
except Exception: # noqa: BLE001 except (ConnectionError, RuntimeError, OSError, asyncio.TimeoutError): # noqa: BLE001
logger.debug("关闭适配器异常已忽略: channel_id=%s", channel_id) logger.debug("关闭适配器异常已忽略: channel_id=%s", channel_id)
_adapter_cache.clear() _adapter_cache.clear()
@ -614,6 +614,8 @@ async def _process_inbound_message(app_state: Any, adapter: MessageAdapter, mess
model=routing.model or "default", model=routing.model or "default",
) )
final_content = getattr(result, "content", "") or "" final_content = getattr(result, "content", "") or ""
except asyncio.CancelledError:
raise
except Exception as exc: # noqa: BLE001 — 回退路径需捕获全部异常 except Exception as exc: # noqa: BLE001 — 回退路径需捕获全部异常
logger.warning("ReActEngine 执行失败,回退到 DIRECT_CHAT: %s", exc) logger.warning("ReActEngine 执行失败,回退到 DIRECT_CHAT: %s", exc)
final_content = await _direct_chat(llm_gateway, routing) final_content = await _direct_chat(llm_gateway, routing)
@ -628,6 +630,8 @@ async def _process_inbound_message(app_state: Any, adapter: MessageAdapter, mess
content=final_content, content=final_content,
) )
await adapter.send_message(outgoing) await adapter.send_message(outgoing)
except asyncio.CancelledError:
raise
except Exception as exc: # noqa: BLE001 — webhook 必须保持响应能力 except Exception as exc: # noqa: BLE001 — webhook 必须保持响应能力
logger.exception("处理入站消息失败: %s", exc) logger.exception("处理入站消息失败: %s", exc)
@ -703,7 +707,7 @@ async def channel_webhook(channel_id: str, request: Request) -> Any:
except WeComURLVerification as e: except WeComURLVerification as e:
# 企微 URL 验证 — 返回 XML 响应 # 企微 URL 验证 — 返回 XML 响应
return Response(content=e.response_xml, media_type="application/xml") return Response(content=e.response_xml, media_type="application/xml")
except Exception as exc: # noqa: BLE001 — 防止 receive_message 异常导致 500 触发平台重试风暴 except (ValueError, KeyError, RuntimeError, AttributeError, OSError) as exc: # noqa: BLE001 — 防止 receive_message 异常导致 500 触发平台重试风暴
logger.warning("receive_message 解析失败 channel=%s: %s", channel_id, exc) logger.warning("receive_message 解析失败 channel=%s: %s", channel_id, exc)
return {"code": 0, "msg": "invalid_payload"} return {"code": 0, "msg": "invalid_payload"}

View File

@ -108,7 +108,7 @@ class ChatConnectionManager:
for ws, _ in conns: for ws, _ in conns:
try: try:
await ws.send_json(message) await ws.send_json(message)
except Exception: except (ConnectionError, RuntimeError, asyncio.TimeoutError):
stale.append(ws) stale.append(ws)
for ws in stale: for ws in stale:
self.remove(session_id, ws) self.remove(session_id, ws)
@ -295,6 +295,8 @@ async def _execute_board_meeting(
await team.create_board(topic=routing_result.topic, expert_configs=expert_configs) await team.create_board(topic=routing_result.topic, expert_configs=expert_configs)
orchestrator = BoardOrchestrator(team=team) orchestrator = BoardOrchestrator(team=team)
result = await orchestrator.execute(routing_result.topic) result = await orchestrator.execute(routing_result.topic)
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
logger.error(f"Board meeting failed for session {session_id}: {e}", exc_info=True) logger.error(f"Board meeting failed for session {session_id}: {e}", exc_info=True)
await websocket.send_json( await websocket.send_json(
@ -302,7 +304,7 @@ async def _execute_board_meeting(
) )
try: try:
await team.dissolve() await team.dissolve()
except Exception: except (RuntimeError, asyncio.TimeoutError, ConnectionError):
pass pass
return True return True
finally: finally:
@ -348,7 +350,7 @@ async def _execute_board_meeting(
# Dissolve the team to release expert agents # Dissolve the team to release expert agents
try: try:
await team.dissolve() await team.dissolve()
except Exception as e: except (RuntimeError, asyncio.TimeoutError, ConnectionError) as e:
logger.warning(f"Board team dissolve failed: {e}") logger.warning(f"Board team dissolve failed: {e}")
return True return True
@ -467,7 +469,7 @@ async def _execute_team_collab(
# Always dissolve the team and remove handler to avoid leaks # Always dissolve the team and remove handler to avoid leaks
try: try:
await team.dissolve() await team.dissolve()
except Exception as e: except (RuntimeError, asyncio.TimeoutError, ConnectionError) as e:
logger.warning(f"Team dissolve failed: {e}") logger.warning(f"Team dissolve failed: {e}")
# dissolve() already clears handlers via handoff_transport.close() # dissolve() already clears handlers via handoff_transport.close()
@ -585,7 +587,7 @@ def _build_phase_engine(
if phase_policy is None: if phase_policy is None:
# Empty config (no `plan_exec:` section) → use KTD5 defaults. # Empty config (no `plan_exec:` section) → use KTD5 defaults.
phase_policy = default_policy() phase_policy = default_policy()
except Exception as e: except (ValueError, TypeError, KeyError) as e:
logger.error( logger.error(
"PLAN_EXEC phase policy construction failed for session %s: %s", "PLAN_EXEC phase policy construction failed for session %s: %s",
session_id, session_id,
@ -695,6 +697,8 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
agent_name=agent.name, agent_name=agent.name,
system_prompt=system_prompt, system_prompt=system_prompt,
) )
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
logger.error(f"PLAN_EXEC execution error for session {session_id}: {e}") logger.error(f"PLAN_EXEC execution error for session {session_id}: {e}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@ -773,6 +777,8 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
return response_dict return response_dict
return response return response
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
logger.error(f"Chat execution error for session {session_id}: {e}") logger.error(f"Chat execution error for session {session_id}: {e}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@ -829,7 +835,7 @@ async def _resolve_ws_dept_context(
db_path_resolved = Path(db_path) db_path_resolved = Path(db_path)
try: try:
department_ids = await _fetch_user_department_ids(db_path_resolved, user_id) department_ids = await _fetch_user_department_ids(db_path_resolved, user_id)
except Exception: except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
logger.exception( logger.exception(
"Failed to fetch department ids for WebSocket user %s — fail-closed", "Failed to fetch department ids for WebSocket user %s — fail-closed",
user_id, user_id,
@ -946,7 +952,7 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
"data": {"content": content}, "data": {"content": content},
} }
) )
except Exception as e: except (asyncio.QueueFull, RuntimeError, ConnectionError) as e:
logger.warning(f"Failed to enqueue intervention: {e}") logger.warning(f"Failed to enqueue intervention: {e}")
await websocket.send_json( await websocket.send_json(
{ {
@ -1022,11 +1028,13 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
except WebSocketDisconnect: except WebSocketDisconnect:
logger.debug(f"Chat WebSocket disconnected for session {session_id}") logger.debug(f"Chat WebSocket disconnected for session {session_id}")
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
logger.error(f"Chat WebSocket error for session {session_id}: {e}") logger.error(f"Chat WebSocket error for session {session_id}: {e}")
try: try:
await websocket.send_json({"type": "error", "data": {"message": str(e)}}) await websocket.send_json({"type": "error", "data": {"message": str(e)}})
except Exception: except (ConnectionError, RuntimeError, asyncio.TimeoutError):
pass pass
finally: finally:
# Clean up pending futures # Clean up pending futures
@ -1174,6 +1182,8 @@ async def _handle_chat_message(
content=final_content, content=final_content,
agent_name=agent.name, agent_name=agent.name,
) )
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
# Check if this is a QuotaExceededError (U4: WebSocket quota). # Check if this is a QuotaExceededError (U4: WebSocket quota).
from agentkit.llm.gateway import QuotaExceededError from agentkit.llm.gateway import QuotaExceededError
@ -1422,6 +1432,8 @@ async def _handle_chat_message(
agent_name=agent.name, agent_name=agent.name,
) )
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
logger.error(f"Chat execution error for session {session_id}: {e}") logger.error(f"Chat execution error for session {session_id}: {e}")
# Show meaningful error to user, but avoid leaking full stack traces # Show meaningful error to user, but avoid leaking full stack traces
@ -1473,6 +1485,8 @@ async def upload_chat_file(file: UploadFile = File(...)) -> dict[str, Any]:
file_path.write_bytes(contents) file_path.write_bytes(contents)
except HTTPException: except HTTPException:
raise raise
except asyncio.CancelledError:
raise
except Exception as exc: except Exception as exc:
logger.error(f"Failed to save uploaded file: {exc}") logger.error(f"Failed to save uploaded file: {exc}")
raise HTTPException(status_code=500, detail="Failed to save file") from exc raise HTTPException(status_code=500, detail="Failed to save file") from exc

View File

@ -64,7 +64,7 @@ def _collect_skill_configs(request: Request) -> list[dict[str, Any]]:
"version": getattr(skill.config, "version", "1.0.0"), "version": getattr(skill.config, "version", "1.0.0"),
"config": config_dict, "config": config_dict,
}) })
except Exception as e: except (AttributeError, ValueError, KeyError, RuntimeError) as e:
logger.warning(f"Failed to collect skill configs: {e}") logger.warning(f"Failed to collect skill configs: {e}")
return configs return configs
@ -81,7 +81,7 @@ def _collect_workflow_configs(request: Request) -> list[dict[str, Any]]:
from agentkit.server.routes.workflows import _workflow_store from agentkit.server.routes.workflows import _workflow_store
workflow_store = _workflow_store workflow_store = _workflow_store
except Exception: except (ImportError, AttributeError):
return [] return []
configs: list[dict[str, Any]] = [] configs: list[dict[str, Any]] = []
@ -94,7 +94,7 @@ def _collect_workflow_configs(request: Request) -> list[dict[str, Any]]:
configs.append(wf.to_dict()) configs.append(wf.to_dict())
else: else:
configs.append(dict(wf)) configs.append(dict(wf))
except Exception as e: except (RuntimeError, AttributeError, ValueError) as e:
logger.warning(f"Failed to collect workflow configs: {e}") logger.warning(f"Failed to collect workflow configs: {e}")
return configs return configs

View File

@ -13,6 +13,7 @@ Endpoints:
from __future__ import annotations from __future__ import annotations
import asyncio
import hmac import hmac
import logging import logging
import uuid import uuid
@ -155,6 +156,8 @@ async def create_document(
raise HTTPException(status_code=400, detail=str(e)) from e raise HTTPException(status_code=400, detail=str(e)) from e
except FileNotFoundError as e: except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e)) from e raise HTTPException(status_code=404, detail=str(e)) from e
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
logger.error(f"Document creation failed: {e}") logger.error(f"Document creation failed: {e}")
raise HTTPException(status_code=500, detail="Document creation failed") from e raise HTTPException(status_code=500, detail="Document creation failed") from e
@ -187,7 +190,7 @@ async def upload_template(
file_path.write_bytes(contents) file_path.write_bytes(contents)
except HTTPException: except HTTPException:
raise raise
except Exception as exc: except (OSError, RuntimeError) as exc:
logger.error(f"Failed to save template: {exc}") logger.error(f"Failed to save template: {exc}")
raise HTTPException(status_code=500, detail="Failed to save template") from exc raise HTTPException(status_code=500, detail="Failed to save template") from exc
finally: finally:

View File

@ -1,5 +1,6 @@
"""Evolution API routes""" """Evolution API routes"""
import asyncio
import logging import logging
from fastapi import APIRouter, HTTPException, Request from fastapi import APIRouter, HTTPException, Request
@ -42,6 +43,8 @@ async def list_evolution_events(
agent_name=agent_name, agent_name=agent_name,
change_type=event_type, change_type=event_type,
) )
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
logger.error(f"Failed to list evolution events: {e}") logger.error(f"Failed to list evolution events: {e}")
raise HTTPException(status_code=500, detail="Failed to list evolution events") raise HTTPException(status_code=500, detail="Failed to list evolution events")
@ -63,6 +66,8 @@ async def get_skill_versions(skill_name: str, req: Request = None):
store = _get_evolution_store(req) store = _get_evolution_store(req)
try: try:
versions = await store.list_skill_versions(skill_name) versions = await store.list_skill_versions(skill_name)
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
logger.error(f"Failed to get skill versions for '{skill_name}': {e}") logger.error(f"Failed to get skill versions for '{skill_name}': {e}")
raise HTTPException(status_code=500, detail="Failed to get skill versions") raise HTTPException(status_code=500, detail="Failed to get skill versions")
@ -103,6 +108,8 @@ async def trigger_evolution(request: TriggerEvolutionRequest, req: Request = Non
) )
try: try:
event_id = await store.record(event) event_id = await store.record(event)
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
logger.error(f"Failed to record trigger event: {e}") logger.error(f"Failed to record trigger event: {e}")
raise HTTPException(status_code=500, detail="Failed to trigger evolution") raise HTTPException(status_code=500, detail="Failed to trigger evolution")
@ -153,6 +160,8 @@ async def list_ab_tests(
for e in entries for e in entries
] ]
return {"items": results[:limit], "total": len(results)} return {"items": results[:limit], "total": len(results)}
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
logger.error(f"Failed to list A/B tests from persistent store: {e}") logger.error(f"Failed to list A/B tests from persistent store: {e}")
raise HTTPException(status_code=500, detail="Failed to list A/B tests") raise HTTPException(status_code=500, detail="Failed to list A/B tests")

View File

@ -194,7 +194,7 @@ async def list_experiences(
} }
) )
return {"experiences": experiences, "total": len(experiences)} return {"experiences": experiences, "total": len(experiences)}
except Exception as e: except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError) as e:
logger.error(f"Failed to list experiences from store: {e}") logger.error(f"Failed to list experiences from store: {e}")
# Fallback to in-memory store # Fallback to in-memory store
@ -324,7 +324,7 @@ async def get_metrics(
# Generate daily trends from the metrics # Generate daily trends from the metrics
trends = _generate_trends(metrics_list, period) trends = _generate_trends(metrics_list, period)
except Exception as e: except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError) as e:
logger.error(f"Failed to get metrics from store: {e}") logger.error(f"Failed to get metrics from store: {e}")
else: else:
# Generate from in-memory experiences # Generate from in-memory experiences
@ -501,7 +501,7 @@ async def get_usage(
"errors": 0, "errors": 0,
"avg_latency_ms": round(d["latency"] / max(d["requests"], 1), 1), "avg_latency_ms": round(d["latency"] / max(d["requests"], 1), 1),
}) })
except Exception as e: except (ConnectionError, OSError, ValueError, KeyError, RuntimeError, AttributeError) as e:
logger.error(f"Failed to get usage from LLMGateway: {e}") logger.error(f"Failed to get usage from LLMGateway: {e}")
# Fill in missing dates with zero # Fill in missing dates with zero
@ -587,7 +587,7 @@ async def check_pitfalls(
}) })
return {"warnings": warnings_data} return {"warnings": warnings_data}
except Exception as e: except (RuntimeError, ValueError, KeyError, AttributeError, asyncio.TimeoutError, ConnectionError) as e:
logger.error(f"Failed to check pitfalls: {e}") logger.error(f"Failed to check pitfalls: {e}")
return {"warnings": []} return {"warnings": []}
@ -642,7 +642,7 @@ async def list_path_optimizations(
else None, else None,
} }
) )
except Exception as e: except (ValueError, KeyError, RuntimeError, AttributeError) as e:
logger.error(f"Failed to get path optimizations: {e}") logger.error(f"Failed to get path optimizations: {e}")
# Also include in-memory optimizations # Also include in-memory optimizations
@ -767,7 +767,7 @@ async def evolution_dashboard_ws(websocket: WebSocket):
) )
except WebSocketDisconnect: except WebSocketDisconnect:
logger.debug("Evolution dashboard WebSocket disconnected") logger.debug("Evolution dashboard WebSocket disconnected")
except Exception as e: except (RuntimeError, asyncio.TimeoutError, ConnectionError) as e:
logger.error(f"Evolution dashboard WebSocket error: {e}") logger.error(f"Evolution dashboard WebSocket error: {e}")
finally: finally:
if websocket in _ws_connections: if websocket in _ws_connections:
@ -781,7 +781,7 @@ async def _broadcast_event(event_type: str, data: dict):
for ws in _ws_connections: for ws in _ws_connections:
try: try:
await ws.send_json(message) await ws.send_json(message)
except Exception: except (ConnectionError, RuntimeError, asyncio.TimeoutError):
disconnected.append(ws) disconnected.append(ws)
for ws in disconnected: for ws in disconnected:
if ws in _ws_connections: if ws in _ws_connections:

View File

@ -1,5 +1,7 @@
"""Health check route""" """Health check route"""
import asyncio
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
router = APIRouter(tags=["health"]) router = APIRouter(tags=["health"])
@ -23,14 +25,14 @@ async def health_check(request: Request):
redis_client = await task_store._get_redis() redis_client = await task_store._get_redis()
await redis_client.ping() await redis_client.ping()
redis_status = "available" redis_status = "available"
except Exception as ping_exc: except (ConnectionError, OSError, asyncio.TimeoutError, RuntimeError) as ping_exc:
redis_status = f"error: {str(ping_exc)[:100]}" redis_status = f"error: {str(ping_exc)[:100]}"
overall_status = "degraded" overall_status = "degraded"
else: else:
redis_status = "not_configured" redis_status = "not_configured"
else: else:
redis_status = "not_configured" redis_status = "not_configured"
except Exception as exc: except (ConnectionError, OSError, asyncio.TimeoutError, RuntimeError, AttributeError) as exc:
redis_status = f"error: {str(exc)[:100]}" redis_status = f"error: {str(exc)[:100]}"
overall_status = "degraded" overall_status = "degraded"
checks["redis"] = redis_status checks["redis"] = redis_status
@ -42,7 +44,7 @@ async def health_check(request: Request):
try: try:
agents = agent_pool.list_agents() agents = agent_pool.list_agents()
pool_size = len(agents) pool_size = len(agents)
except Exception: except (RuntimeError, AttributeError):
pass pass
checks["agent_pool"] = {"status": "available", "size": pool_size} checks["agent_pool"] = {"status": "available", "size": pool_size}
@ -57,7 +59,7 @@ async def health_check(request: Request):
else: else:
llm_status = "no_providers" llm_status = "no_providers"
overall_status = "degraded" overall_status = "degraded"
except Exception: except (RuntimeError, AttributeError, ValueError):
llm_status = "error" llm_status = "error"
overall_status = "degraded" overall_status = "degraded"
checks["llm_gateway"] = llm_status checks["llm_gateway"] = llm_status
@ -68,7 +70,7 @@ async def health_check(request: Request):
if skill_registry: if skill_registry:
try: try:
skill_count = len(skill_registry.list_skills()) skill_count = len(skill_registry.list_skills())
except Exception: except (RuntimeError, AttributeError):
pass pass
checks["skill_registry"] = { checks["skill_registry"] = {
"status": "available" if skill_registry else "not_configured", "status": "available" if skill_registry else "not_configured",

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import hmac import hmac
import logging import logging
import os import os
@ -252,7 +253,7 @@ async def list_sources(
visible_ids = await filter_kb_sources_by_department( visible_ids = await filter_kb_sources_by_department(
db_path, dept_ctx.department_ids, all_ids db_path, dept_ctx.department_ids, all_ids
) )
except Exception: # noqa: BLE001 — never block listing on DB errors except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError): # noqa: BLE001 — never block listing on DB errors
logger.exception("Department KB filtering failed — returning empty list") logger.exception("Department KB filtering failed — returning empty list")
return {"sources": []} return {"sources": []}
visible_set = set(visible_ids) visible_set = set(visible_ids)
@ -398,7 +399,7 @@ async def list_documents(
visible_ids = await filter_kb_sources_by_department( visible_ids = await filter_kb_sources_by_department(
db_path, dept_ctx.department_ids, all_source_ids db_path, dept_ctx.department_ids, all_source_ids
) )
except Exception: # noqa: BLE001 — never block listing on DB errors except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError): # noqa: BLE001 — never block listing on DB errors
logger.exception("Department KB filtering failed — returning empty list") logger.exception("Department KB filtering failed — returning empty list")
return {"documents": []} return {"documents": []}
visible_set = set(visible_ids) visible_set = set(visible_ids)
@ -484,7 +485,7 @@ async def upload_document(
try: try:
text = processor.parse(tmp_path, file_type) text = processor.parse(tmp_path, file_type)
chunks = processor.segment(text) chunks = processor.segment(text)
except Exception as e: except (ValueError, OSError, RuntimeError, UnicodeDecodeError) as e:
logger.warning("Document parsing failed: %s", e) logger.warning("Document parsing failed: %s", e)
raise HTTPException(status_code=422, detail=f"Document parsing failed: {e}") from e raise HTTPException(status_code=422, detail=f"Document parsing failed: {e}") from e
@ -567,7 +568,7 @@ async def preview_document(
chunk_size=chunk_size, chunk_size=chunk_size,
chunk_overlap=chunk_overlap, chunk_overlap=chunk_overlap,
) )
except Exception as e: except (ValueError, OSError, RuntimeError, UnicodeDecodeError) as e:
logger.warning("Document preview failed: %s", e) logger.warning("Document preview failed: %s", e)
raise HTTPException(status_code=422, detail=f"Document preview failed: {e}") from e raise HTTPException(status_code=422, detail=f"Document preview failed: {e}") from e
@ -628,7 +629,7 @@ async def search_knowledge(
for r in results for r in results
] ]
} }
except Exception as e: except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError) as e:
logger.warning(f"Semantic search failed: {e}") logger.warning(f"Semantic search failed: {e}")
# Fallback: return empty results with a hint # Fallback: return empty results with a hint

View File

@ -115,7 +115,7 @@ async def publish_skill(
try: try:
skill = skill_registry.get(skill_name) skill = skill_registry.get(skill_name)
except Exception: except (KeyError, ValueError, AttributeError):
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found") raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
try: try:

View File

@ -1,5 +1,6 @@
"""Memory API routes""" """Memory API routes"""
import asyncio
import logging import logging
from fastapi import APIRouter, HTTPException, Request from fastapi import APIRouter, HTTPException, Request
@ -40,6 +41,8 @@ async def search_episodic_memory(
if agent_name: if agent_name:
filters["agent_name"] = agent_name filters["agent_name"] = agent_name
items = await retriever._episodic.search(query, top_k=top_k, filters=filters or None) items = await retriever._episodic.search(query, top_k=top_k, filters=filters or None)
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
logger.error(f"Failed to search episodic memory: {e}") logger.error(f"Failed to search episodic memory: {e}")
raise HTTPException(status_code=500, detail="Failed to search episodic memory") raise HTTPException(status_code=500, detail="Failed to search episodic memory")
@ -76,6 +79,8 @@ async def search_semantic_memory(
if knowledge_base_ids: if knowledge_base_ids:
filters["knowledge_base_ids"] = [kid.strip() for kid in knowledge_base_ids.split(",")] filters["knowledge_base_ids"] = [kid.strip() for kid in knowledge_base_ids.split(",")]
items = await retriever._semantic.search(query, top_k=top_k, filters=filters or None) items = await retriever._semantic.search(query, top_k=top_k, filters=filters or None)
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
logger.error(f"Failed to search semantic memory: {e}") logger.error(f"Failed to search semantic memory: {e}")
raise HTTPException(status_code=500, detail="Failed to search semantic memory") raise HTTPException(status_code=500, detail="Failed to search semantic memory")
@ -104,6 +109,8 @@ async def delete_episodic_memory(key: str, req: Request = None):
try: try:
deleted = await retriever._episodic.delete(key) deleted = await retriever._episodic.delete(key)
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
logger.error(f"Failed to delete episodic memory '{key}': {e}") logger.error(f"Failed to delete episodic memory '{key}': {e}")
raise HTTPException(status_code=500, detail="Failed to delete episodic memory") raise HTTPException(status_code=500, detail="Failed to delete episodic memory")

View File

@ -1,5 +1,6 @@
"""Metrics route — /api/v1/metrics""" """Metrics route — /api/v1/metrics"""
import asyncio
import logging import logging
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
@ -29,7 +30,7 @@ async def get_metrics(request: Request):
task_metrics["completed_tasks"] = counts.get("completed", 0) task_metrics["completed_tasks"] = counts.get("completed", 0)
task_metrics["failed_tasks"] = counts.get("failed", 0) task_metrics["failed_tasks"] = counts.get("failed", 0)
task_metrics["pending_tasks"] = counts.get("pending", 0) task_metrics["pending_tasks"] = counts.get("pending", 0)
except Exception as e: except (RuntimeError, AttributeError, ConnectionError, asyncio.TimeoutError) as e:
logger.warning(f"Failed to collect task metrics: {e}") logger.warning(f"Failed to collect task metrics: {e}")
# Agent pool metrics # Agent pool metrics
@ -41,7 +42,7 @@ async def get_metrics(request: Request):
try: try:
agents = agent_pool.list_agents() agents = agent_pool.list_agents()
agent_metrics["total_agents"] = len(agents) agent_metrics["total_agents"] = len(agents)
except Exception as e: except (RuntimeError, AttributeError) as e:
logger.warning(f"Failed to collect agent metrics: {e}") logger.warning(f"Failed to collect agent metrics: {e}")
# Skill registry metrics # Skill registry metrics
@ -53,7 +54,7 @@ async def get_metrics(request: Request):
try: try:
skills = skill_registry.list_skills() skills = skill_registry.list_skills()
skill_metrics["total_skills"] = len(skills) skill_metrics["total_skills"] = len(skills)
except Exception as e: except (RuntimeError, AttributeError) as e:
logger.warning(f"Failed to collect skill metrics: {e}") logger.warning(f"Failed to collect skill metrics: {e}")
return { return {

View File

@ -94,7 +94,7 @@ async def _emit_event_safe(
data=data or {}, data=data or {},
) )
await event_queue.emit(event) await event_queue.emit(event)
except Exception as e: except (asyncio.QueueFull, RuntimeError, ConnectionError) as e:
logger.warning(f"EventQueue emit failed (type={event_type}): {e}", exc_info=True) logger.warning(f"EventQueue emit failed (type={event_type}): {e}", exc_info=True)
@ -211,7 +211,7 @@ class PortalConnectionManager:
import asyncio import asyncio
asyncio.create_task(oldest.close(code=1008, reason="Connection limit exceeded")) asyncio.create_task(oldest.close(code=1008, reason="Connection limit exceeded"))
except Exception: except (ConnectionError, RuntimeError):
pass pass
conns.append(ws) conns.append(ws)
@ -235,7 +235,7 @@ class PortalConnectionManager:
for ws in conns: for ws in conns:
try: try:
await ws.send_json(message) await ws.send_json(message)
except Exception as e: except (ConnectionError, RuntimeError, asyncio.TimeoutError) as e:
logger.debug( logger.debug(
"Portal WS send failed for user %s (marking stale): %s", user_id, e "Portal WS send failed for user %s (marking stale): %s", user_id, e
) )
@ -285,7 +285,7 @@ async def _build_history_messages(
""" """
try: try:
history = await _conversation_store.get_history(conv_id, limit=limit) history = await _conversation_store.get_history(conv_id, limit=limit)
except Exception: except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
return [] return []
# The last message in history is the current user message (just added), # The last message in history is the current user message (just added),
@ -553,6 +553,8 @@ async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify
): ):
if event.event_type == "final_answer": if event.event_type == "final_answer":
collected_output.append(event.data.get("output", "")) collected_output.append(event.data.get("output", ""))
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
response_text = f"执行出错: {e}" response_text = f"执行出错: {e}"
else: else:
@ -682,6 +684,8 @@ async def chat_stream(request: ChatRequest, req: Request, _auth: None = Depends(
} }
), ),
} }
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
yield { yield {
"event": "error", "event": "error",
@ -862,7 +866,7 @@ async def _execute_react_background(
await _task_store_update_status( await _task_store_update_status(
task_store, task_id, TaskStatus.RUNNING, started_at=datetime.now(timezone.utc) task_store, task_id, TaskStatus.RUNNING, started_at=datetime.now(timezone.utc)
) )
except Exception: except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
logger.warning("Failed to update TaskStore RUNNING", exc_info=True) logger.warning("Failed to update TaskStore RUNNING", exc_info=True)
async for event in react_engine.execute_stream( async for event in react_engine.execute_stream(
@ -909,7 +913,7 @@ async def _execute_react_background(
progress=1.0, progress=1.0,
progress_message="Completed", progress_message="Completed",
) )
except Exception: except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
logger.warning("Failed to update TaskStore COMPLETED", exc_info=True) logger.warning("Failed to update TaskStore COMPLETED", exc_info=True)
# Emit task.completed so subscribers know the task is done # Emit task.completed so subscribers know the task is done
@ -932,7 +936,15 @@ async def _execute_react_background(
partial = _ensure_non_empty("".join(collected_output)) partial = _ensure_non_empty("".join(collected_output))
try: try:
await asyncio.shield(conversation_store.add_message(conv_id, "assistant", partial)) await asyncio.shield(conversation_store.add_message(conv_id, "assistant", partial))
except (Exception, asyncio.CancelledError): except (
asyncio.CancelledError,
ConnectionError,
OSError,
asyncio.TimeoutError,
ValueError,
KeyError,
RuntimeError,
):
logger.warning("Failed to persist partial output on cancel") logger.warning("Failed to persist partial output on cancel")
if task_store is not None: if task_store is not None:
try: try:
@ -945,7 +957,15 @@ async def _execute_react_background(
completed_at=datetime.now(timezone.utc), completed_at=datetime.now(timezone.utc),
) )
) )
except (Exception, asyncio.CancelledError): except (
asyncio.CancelledError,
ConnectionError,
OSError,
asyncio.TimeoutError,
ValueError,
KeyError,
RuntimeError,
):
logger.warning("Failed to update TaskStore on cancel", exc_info=True) logger.warning("Failed to update TaskStore on cancel", exc_info=True)
# P0 #2 fix: _emit_event_safe is async (it awaits event_queue.emit). # P0 #2 fix: _emit_event_safe is async (it awaits event_queue.emit).
# Shield it so a re-entrant CancelledError doesn't kill the emit # Shield it so a re-entrant CancelledError doesn't kill the emit
@ -963,7 +983,7 @@ async def _execute_react_background(
}, },
) )
) )
except (Exception, asyncio.CancelledError): except (asyncio.CancelledError, asyncio.QueueFull, RuntimeError, ConnectionError):
logger.warning("Failed to emit TASK_FAILED on cancel") logger.warning("Failed to emit TASK_FAILED on cancel")
raise # Propagate cancellation raise # Propagate cancellation
@ -973,7 +993,7 @@ async def _execute_react_background(
partial = _ensure_non_empty("".join(collected_output)) partial = _ensure_non_empty("".join(collected_output))
try: try:
await conversation_store.add_message(conv_id, "assistant", partial) await conversation_store.add_message(conv_id, "assistant", partial)
except Exception: except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
logger.warning("Failed to persist partial output in background task") logger.warning("Failed to persist partial output in background task")
if task_store is not None: if task_store is not None:
@ -985,7 +1005,7 @@ async def _execute_react_background(
error_message=str(e), error_message=str(e),
completed_at=datetime.now(timezone.utc), completed_at=datetime.now(timezone.utc),
) )
except Exception: except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
logger.warning("Failed to update TaskStore FAILED", exc_info=True) logger.warning("Failed to update TaskStore FAILED", exc_info=True)
# Emit task.failed so subscribers know the task failed # Emit task.failed so subscribers know the task failed
@ -1120,7 +1140,7 @@ async def portal_websocket(websocket: WebSocket):
try: try:
record = await _task_store_get(resume_task_store, resume_task_id) record = await _task_store_get(resume_task_store, resume_task_id)
except Exception: except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
logger.warning("TaskStore.get failed during resume", exc_info=True) logger.warning("TaskStore.get failed during resume", exc_info=True)
record = None record = None
if record is not None: if record is not None:
@ -1333,7 +1353,7 @@ async def portal_websocket(websocket: WebSocket):
}, },
) )
await _broadcast_dashboard_event("metrics_updated", {"period": "7d"}) await _broadcast_dashboard_event("metrics_updated", {"period": "7d"})
except Exception as e: except (asyncio.QueueFull, RuntimeError, ConnectionError, ValueError, KeyError) as e:
logger.warning(f"Failed to record experience: {e}") logger.warning(f"Failed to record experience: {e}")
# Unified preprocessing via RequestPreprocessor (minimal: @skill prefix + greeting regex + REACT) # Unified preprocessing via RequestPreprocessor (minimal: @skill prefix + greeting regex + REACT)
@ -1414,7 +1434,7 @@ async def portal_websocket(websocket: WebSocket):
TaskStatus.PENDING, TaskStatus.PENDING,
metadata={"conversation_id": conv.id}, metadata={"conversation_id": conv.id},
) )
except Exception: except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
logger.warning("Failed to register task in TaskStore", exc_info=True) logger.warning("Failed to register task in TaskStore", exc_info=True)
# Execute based on routing result's execution_mode # Execute based on routing result's execution_mode
@ -1455,7 +1475,7 @@ async def portal_websocket(websocket: WebSocket):
progress=1.0, progress=1.0,
progress_message="Completed", progress_message="Completed",
) )
except Exception: except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
logger.warning("Failed to update TaskStore for DIRECT_CHAT", exc_info=True) logger.warning("Failed to update TaskStore for DIRECT_CHAT", exc_info=True)
# Emit turn.final_answer and task.completed to EQ # Emit turn.final_answer and task.completed to EQ
@ -1526,7 +1546,7 @@ async def portal_websocket(websocket: WebSocket):
chat_messages.insert( chat_messages.insert(
-1, {"role": hist_msg.role, "content": hist_msg.content} -1, {"role": hist_msg.role, "content": hist_msg.content}
) )
except Exception: except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
pass pass
response = await llm_gateway.chat( response = await llm_gateway.chat(
messages=chat_messages, messages=chat_messages,
@ -1627,7 +1647,7 @@ async def portal_websocket(websocket: WebSocket):
logger.warning("EventQueue not configured; awaiting background task directly") logger.warning("EventQueue not configured; awaiting background task directly")
try: try:
await bg_task await bg_task
except Exception: except (RuntimeError, ConnectionError, asyncio.TimeoutError):
pass # errors handled inside _execute_react_background pass # errors handled inside _execute_react_background
active_bg_task = None active_bg_task = None
continue continue
@ -1734,6 +1754,8 @@ async def portal_websocket(websocket: WebSocket):
# kill the task, lose the full output, and mark it FAILED — # kill the task, lose the full output, and mark it FAILED —
# defeating layers 2 and 3. The task is only cancelled on explicit # defeating layers 2 and 3. The task is only cancelled on explicit
# user cancel (msg_type == 'cancel') or application shutdown. # user cancel (msg_type == 'cancel') or application shutdown.
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
logger.error(f"Portal WebSocket error: {e}") logger.error(f"Portal WebSocket error: {e}")
# P1 #6 fix: Do NOT cancel the background task on connection-level # P1 #6 fix: Do NOT cancel the background task on connection-level
@ -1758,7 +1780,7 @@ async def portal_websocket(websocket: WebSocket):
) )
try: try:
await websocket.send_json({"type": "error", "data": {"message": str(e)}}) await websocket.send_json({"type": "error", "data": {"message": str(e)}})
except Exception: except (ConnectionError, RuntimeError, asyncio.TimeoutError):
pass pass
finally: finally:
# Remove from user-scoped push tracking on any disconnect/error/return. # Remove from user-scoped push tracking on any disconnect/error/return.

View File

@ -119,7 +119,7 @@ def _skill_to_detail(skill: Any) -> dict[str, Any]:
if hasattr(skill, "config"): if hasattr(skill, "config"):
try: try:
config = skill.config.to_dict() if hasattr(skill.config, "to_dict") else {} config = skill.config.to_dict() if hasattr(skill.config, "to_dict") else {}
except Exception: except (AttributeError, ValueError, TypeError):
config = {} config = {}
return { return {
@ -174,7 +174,7 @@ async def get_skill_detail(skill_name: str, req: Request):
skill_registry = req.app.state.skill_registry skill_registry = req.app.state.skill_registry
try: try:
skill = skill_registry.get(skill_name) skill = skill_registry.get(skill_name)
except Exception: except (KeyError, ValueError, AttributeError):
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found") raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
return _skill_to_detail(skill) return _skill_to_detail(skill)
@ -186,7 +186,7 @@ async def check_skill_health(skill_name: str, req: Request):
skill_registry = req.app.state.skill_registry skill_registry = req.app.state.skill_registry
try: try:
skill_registry.get(skill_name) skill_registry.get(skill_name)
except Exception: except (KeyError, ValueError, AttributeError):
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found") raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
# Basic health check - skill exists and is registered # Basic health check - skill exists and is registered
@ -243,7 +243,7 @@ async def reload_skill(skill_name: str, req: Request):
# Verify the skill is currently registered (404 if not). # Verify the skill is currently registered (404 if not).
try: try:
skill_registry.get(skill_name) skill_registry.get(skill_name)
except Exception: except (KeyError, ValueError, AttributeError):
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found") raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
# Resolve the skills directory (mirrors routes.skills._get_skills_dir). # Resolve the skills directory (mirrors routes.skills._get_skills_dir).

View File

@ -136,7 +136,7 @@ async def register_skill(request: RegisterSkillRequest, req: Request):
try: try:
config = SkillConfig.from_dict(request.config) config = SkillConfig.from_dict(request.config)
except Exception as e: except (ValueError, TypeError, KeyError) as e:
raise HTTPException(status_code=422, detail=f"Invalid skill config: {e}") raise HTTPException(status_code=422, detail=f"Invalid skill config: {e}")
skill = Skill(config=config) skill = Skill(config=config)
@ -279,7 +279,7 @@ async def install_skill(request: InstallSkillRequest, req: Request):
resp = await client.get(source) resp = await client.get(source)
resp.raise_for_status() resp.raise_for_status()
yaml_content = resp.text yaml_content = resp.text
except Exception as e: except (httpx.HTTPError, OSError) as e:
raise HTTPException(status_code=400, detail=f"Failed to download from source: {e}") raise HTTPException(status_code=400, detail=f"Failed to download from source: {e}")
elif source and source.startswith("file://"): elif source and source.startswith("file://"):
# Read from local file path # Read from local file path
@ -295,7 +295,7 @@ async def install_skill(request: InstallSkillRequest, req: Request):
try: try:
with open(local_path, encoding="utf-8") as f: with open(local_path, encoding="utf-8") as f:
yaml_content = f.read() yaml_content = f.read()
except Exception as e: except OSError as e:
raise HTTPException(status_code=400, detail=f"Failed to read local file: {e}") raise HTTPException(status_code=400, detail=f"Failed to read local file: {e}")
else: else:
# Search GitHub for skills (YAML config files) # Search GitHub for skills (YAML config files)
@ -313,7 +313,7 @@ async def install_skill(request: InstallSkillRequest, req: Request):
}, },
) )
gh_data = gh_resp.json() gh_data = gh_resp.json()
except Exception as e: except (httpx.HTTPError, OSError, ValueError) as e:
raise HTTPException(status_code=502, detail=f"GitHub search failed: {e}") raise HTTPException(status_code=502, detail=f"GitHub search failed: {e}")
items = gh_data.get("items", []) items = gh_data.get("items", [])
@ -334,7 +334,7 @@ async def install_skill(request: InstallSkillRequest, req: Request):
}, },
) )
items = gh_resp2.json().get("items", []) items = gh_resp2.json().get("items", [])
except Exception: except (httpx.HTTPError, OSError, ValueError, KeyError):
items = [] items = []
if not items: if not items:
@ -362,7 +362,7 @@ async def install_skill(request: InstallSkillRequest, req: Request):
resp = await client.get(raw_url) resp = await client.get(raw_url)
resp.raise_for_status() resp.raise_for_status()
yaml_content = resp.text yaml_content = resp.text
except Exception as e: except (httpx.HTTPError, OSError) as e:
raise HTTPException(status_code=400, detail=f"Failed to download skill: {e}") raise HTTPException(status_code=400, detail=f"Failed to download skill: {e}")
# Validate YAML content before writing to disk # Validate YAML content before writing to disk
@ -391,14 +391,14 @@ async def install_skill(request: InstallSkillRequest, req: Request):
) )
loader.load_from_file(file_path) loader.load_from_file(file_path)
registration_ok = True registration_ok = True
except Exception as e: except (ValueError, TypeError, KeyError, OSError, RuntimeError) as e:
logger.warning(f"Failed to register installed skill: {e}") logger.warning(f"Failed to register installed skill: {e}")
if not registration_ok: if not registration_ok:
# Remove the invalid YAML file and report error # Remove the invalid YAML file and report error
try: try:
os.remove(file_path) os.remove(file_path)
except Exception: except OSError:
pass pass
raise HTTPException(status_code=500, detail="Skill downloaded but registration failed") raise HTTPException(status_code=500, detail="Skill downloaded but registration failed")
@ -419,7 +419,7 @@ async def uninstall_skill(name: str, req: Request):
try: try:
skill_registry.get(validated_name) skill_registry.get(validated_name)
except Exception: except (KeyError, ValueError, RuntimeError):
raise HTTPException(status_code=404, detail=f"Skill '{name}' not found") raise HTTPException(status_code=404, detail=f"Skill '{name}' not found")
# Remove from registry # Remove from registry
@ -487,7 +487,7 @@ async def execute_pipeline(name: str, request: ExecutePipelineRequest, req: Requ
try: try:
result = await pipeline.execute(input_data=request.input_data) result = await pipeline.execute(input_data=request.input_data)
except Exception as e: except (ValueError, TypeError, KeyError, RuntimeError, OSError, ConnectionError) as e:
logger.error(f"Pipeline execution failed for '{name}': {e}", exc_info=True) logger.error(f"Pipeline execution failed for '{name}': {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Pipeline execution failed") raise HTTPException(status_code=500, detail="Pipeline execution failed")

View File

@ -24,7 +24,7 @@ def _read_meminfo() -> dict[str, int]:
key = parts[0].strip() key = parts[0].strip()
value = parts[1].strip().split()[0] value = parts[1].strip().split()[0]
values[key] = int(value) * 1024 # kB -> bytes values[key] = int(value) * 1024 # kB -> bytes
except Exception as exc: except (OSError, ValueError, FileNotFoundError) as exc:
logger.debug(f"Failed to read /proc/meminfo: {exc}") logger.debug(f"Failed to read /proc/meminfo: {exc}")
return values return values
@ -52,7 +52,7 @@ async def get_system_resources() -> dict[str, Any]:
if hasattr(os, "getloadavg"): if hasattr(os, "getloadavg"):
try: try:
loadavg = list(os.getloadavg()) loadavg = list(os.getloadavg())
except Exception as exc: except (OSError, AttributeError) as exc:
logger.debug(f"Failed to get loadavg: {exc}") logger.debug(f"Failed to get loadavg: {exc}")
meminfo = _read_meminfo() meminfo = _read_meminfo()
@ -68,7 +68,7 @@ async def get_system_resources() -> dict[str, Any]:
disk_total = du.total disk_total = du.total
disk_used = du.used disk_used = du.used
disk_free = du.free disk_free = du.free
except Exception as exc: except (OSError, FileNotFoundError) as exc:
logger.debug(f"Failed to get disk usage: {exc}") logger.debug(f"Failed to get disk usage: {exc}")
return { return {

View File

@ -83,7 +83,7 @@ async def submit_task(request: SubmitTaskRequest, req: Request):
elif request.skill_name: elif request.skill_name:
try: try:
skill = skill_registry.get(request.skill_name) skill = skill_registry.get(request.skill_name)
except Exception: except (KeyError, ValueError, AttributeError):
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"Skill '{request.skill_name}' not found", detail=f"Skill '{request.skill_name}' not found",
@ -145,7 +145,7 @@ async def submit_task(request: SubmitTaskRequest, req: Request):
quality_result = await quality_gate.validate( quality_result = await quality_gate.validate(
task_result.output_data or {}, skill, skill_context=skill_context task_result.output_data or {}, skill, skill_context=skill_context
) )
except Exception: except (RuntimeError, ValueError, KeyError, AttributeError, asyncio.TimeoutError):
pass # Quality gate failure shouldn't block the response pass # Quality gate failure shouldn't block the response
# 7. Standardize output if skill available # 7. Standardize output if skill available
@ -167,7 +167,7 @@ async def submit_task(request: SubmitTaskRequest, req: Request):
"task_id": task.task_id, "task_id": task.task_id,
"status": task_result.status, "status": task_result.status,
} }
except Exception: except (ValueError, KeyError, AttributeError, RuntimeError):
pass # Fall through to raw output pass # Fall through to raw output
# 8. Return raw result if no skill or standardization failed # 8. Return raw result if no skill or standardization failed
@ -307,7 +307,7 @@ async def resume_task(task_id: str, req: Request, plan_id: str | None = None):
finally: finally:
try: try:
await team.dissolve() await team.dissolve()
except Exception: except (RuntimeError, asyncio.TimeoutError, ConnectionError):
pass pass
return { return {
@ -343,7 +343,7 @@ async def stream_task(request: SubmitTaskRequest, req: Request):
elif request.skill_name: elif request.skill_name:
try: try:
skill_registry.get(request.skill_name) skill_registry.get(request.skill_name)
except Exception: except (KeyError, ValueError, AttributeError):
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"Skill '{request.skill_name}' not found", detail=f"Skill '{request.skill_name}' not found",

View File

@ -359,14 +359,14 @@ async def execute_command(
await check_pty.close() await check_pty.close()
if cwd_result.exit_code == 0: if cwd_result.exit_code == 0:
state.cwd = cwd_result.output.strip() state.cwd = cwd_result.output.strip()
except Exception: except (OSError, RuntimeError, asyncio.TimeoutError):
pass pass
except asyncio.TimeoutError: except asyncio.TimeoutError:
output = f"命令执行超时({request.timeout}s" output = f"命令执行超时({request.timeout}s"
exit_code = -1 exit_code = -1
duration_ms = int((time.monotonic() - start_time) * 1000) duration_ms = int((time.monotonic() - start_time) * 1000)
except Exception as e: except (OSError, RuntimeError, ValueError) as e:
output = str(e) output = str(e)
exit_code = -1 exit_code = -1
duration_ms = int((time.monotonic() - start_time) * 1000) duration_ms = int((time.monotonic() - start_time) * 1000)
@ -515,7 +515,7 @@ async def terminal_websocket(websocket: WebSocket) -> None:
}) })
await websocket.close(code=4003, reason="Permission denied") await websocket.close(code=4003, reason="Permission denied")
return return
except Exception: except (ValueError, KeyError, RuntimeError, OSError):
pass # Fall through to API key / dev mode pass # Fall through to API key / dev mode
# 2. API key via ?api_key= # 2. API key via ?api_key=
@ -721,7 +721,7 @@ async def terminal_websocket(websocket: WebSocket) -> None:
}) })
except asyncio.TimeoutError: except asyncio.TimeoutError:
continue continue
except Exception: except (OSError, RuntimeError):
break break
# Get final result # Get final result
@ -741,7 +741,7 @@ async def terminal_websocket(websocket: WebSocket) -> None:
await check_pty.close() await check_pty.close()
if cwd_result.exit_code == 0: if cwd_result.exit_code == 0:
state.cwd = cwd_result.output.strip() state.cwd = cwd_result.output.strip()
except Exception: except (OSError, RuntimeError, asyncio.TimeoutError):
pass pass
# Record history # Record history
@ -790,6 +790,8 @@ async def terminal_websocket(websocket: WebSocket) -> None:
"cwd": state.cwd, "cwd": state.cwd,
"duration_ms": duration_ms, "duration_ms": duration_ms,
}) })
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
duration_ms = int((time.monotonic() - start_time) * 1000) duration_ms = int((time.monotonic() - start_time) * 1000)
await websocket.send_json({ await websocket.send_json({
@ -800,7 +802,7 @@ async def terminal_websocket(websocket: WebSocket) -> None:
if active_pty and active_pty.is_running: if active_pty and active_pty.is_running:
try: try:
await active_pty.close() await active_pty.close()
except Exception: except (OSError, RuntimeError):
pass pass
active_pty = None active_pty = None
@ -836,11 +838,13 @@ async def terminal_websocket(websocket: WebSocket) -> None:
except WebSocketDisconnect: except WebSocketDisconnect:
logger.debug(f"Terminal WebSocket disconnected for session {state.session_id}") logger.debug(f"Terminal WebSocket disconnected for session {state.session_id}")
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
logger.error(f"Terminal WebSocket error for session {state.session_id}: {e}") logger.error(f"Terminal WebSocket error for session {state.session_id}: {e}")
try: try:
await websocket.send_json({"type": "error", "message": str(e)[:200]}) await websocket.send_json({"type": "error", "message": str(e)[:200]})
except Exception: except (ConnectionError, RuntimeError, asyncio.TimeoutError):
pass pass
finally: finally:
# Clean up pending confirmations # Clean up pending confirmations
@ -851,5 +855,5 @@ async def terminal_websocket(websocket: WebSocket) -> None:
if active_pty and active_pty.is_running: if active_pty and active_pty.is_running:
try: try:
await active_pty.close() await active_pty.close()
except Exception: except (OSError, RuntimeError):
pass pass

View File

@ -549,7 +549,7 @@ async def server_terminal_websocket(websocket: WebSocket) -> None:
}) })
await websocket.close(code=4003, reason="Permission denied") await websocket.close(code=4003, reason="Permission denied")
return return
except Exception: except (ValueError, KeyError, RuntimeError, OSError):
pass pass
if not auth_ok: if not auth_ok:
@ -585,7 +585,7 @@ async def server_terminal_websocket(websocket: WebSocket) -> None:
}) })
await websocket.close(code=4003, reason="Not authorized") await websocket.close(code=4003, reason="Not authorized")
return return
except Exception as e: except (aiosqlite.Error, OSError, ValueError, KeyError, RuntimeError) as e:
logger.warning(f"Failed to check server terminal authorization: {e}") logger.warning(f"Failed to check server terminal authorization: {e}")
# If DB check fails, deny access # If DB check fails, deny access
await websocket.send_json({ await websocket.send_json({
@ -822,7 +822,7 @@ async def server_terminal_websocket(websocket: WebSocket) -> None:
}) })
except asyncio.TimeoutError: except asyncio.TimeoutError:
continue continue
except Exception: except (OSError, RuntimeError):
break break
result = await run_task result = await run_task
@ -846,7 +846,7 @@ async def server_terminal_websocket(websocket: WebSocket) -> None:
else: else:
base = state.cwd or os.path.expanduser("~") base = state.cwd or os.path.expanduser("~")
state.cwd = os.path.normpath(os.path.join(base, cd_arg)) state.cwd = os.path.normpath(os.path.join(base, cd_arg))
except Exception: except (ValueError, OSError, RuntimeError):
pass # cd parsing failed — keep old cwd pass # cd parsing failed — keep old cwd
# Record history # Record history
@ -893,6 +893,8 @@ async def server_terminal_websocket(websocket: WebSocket) -> None:
"cwd": state.cwd, "cwd": state.cwd,
"duration_ms": duration_ms, "duration_ms": duration_ms,
}) })
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
duration_ms = int((time.monotonic() - start_time) * 1000) duration_ms = int((time.monotonic() - start_time) * 1000)
await websocket.send_json({ await websocket.send_json({
@ -903,7 +905,7 @@ async def server_terminal_websocket(websocket: WebSocket) -> None:
if active_pty and active_pty.is_running: if active_pty and active_pty.is_running:
try: try:
await active_pty.close() await active_pty.close()
except Exception: except (OSError, RuntimeError):
pass pass
active_pty = None active_pty = None
@ -942,17 +944,19 @@ async def server_terminal_websocket(websocket: WebSocket) -> None:
except WebSocketDisconnect: except WebSocketDisconnect:
logger.debug(f"Server terminal WebSocket disconnected for session {state.session_id}") logger.debug(f"Server terminal WebSocket disconnected for session {state.session_id}")
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
logger.error(f"Server terminal WebSocket error for session {state.session_id}: {e}") logger.error(f"Server terminal WebSocket error for session {state.session_id}: {e}")
try: try:
await websocket.send_json({"type": "error", "message": str(e)[:200]}) await websocket.send_json({"type": "error", "message": str(e)[:200]})
except Exception: except (ConnectionError, RuntimeError, asyncio.TimeoutError):
pass pass
finally: finally:
if active_pty and active_pty.is_running: if active_pty and active_pty.is_running:
try: try:
await active_pty.close() await active_pty.close()
except Exception: except (OSError, RuntimeError):
pass pass
# Clean up session from in-memory stores to prevent leaks # Clean up session from in-memory stores to prevent leaks
_server_sessions.pop(state.session_id, None) _server_sessions.pop(state.session_id, None)

View File

@ -433,6 +433,8 @@ async def _execute_workflow(
) )
result = await agent.handle_task(task) result = await agent.handle_task(task)
stage_result = {"output": result, "skill": stage.action} stage_result = {"output": result, "skill": stage.action}
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
stage_result = {"error": str(e), "skill": stage.action} stage_result = {"error": str(e), "skill": stage.action}
else: else:
@ -474,6 +476,8 @@ async def _execute_workflow(
) )
result = await agent.handle_task(task) result = await agent.handle_task(task)
return {"output": result, "skill": action} return {"output": result, "skill": action}
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
return {"error": str(e), "skill": action} return {"error": str(e), "skill": action}
return {"dry_run": True, "action": action} return {"dry_run": True, "action": action}
@ -515,6 +519,8 @@ async def _execute_workflow(
execution_id=execution.execution_id, execution_id=execution.execution_id,
) )
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
execution.stage_results[stage_name] = { execution.stage_results[stage_name] = {
"status": "failed", "status": "failed",
@ -633,7 +639,7 @@ async def _broadcast_ws(message: dict[str, Any], execution_id: str | None = None
for ws in targets: for ws in targets:
try: try:
await ws.send_json(message) await ws.send_json(message)
except Exception: except (ConnectionError, RuntimeError, asyncio.TimeoutError):
disconnected.append(ws) disconnected.append(ws)
if disconnected: if disconnected:
async with _ws_subscribers_lock: async with _ws_subscribers_lock:
@ -952,7 +958,7 @@ async def workflow_websocket(websocket: WebSocket):
# Keep connection alive - messages are primarily server-push # Keep connection alive - messages are primarily server-push
except WebSocketDisconnect: except WebSocketDisconnect:
logger.debug("Workflow WebSocket disconnected") logger.debug("Workflow WebSocket disconnected")
except Exception as e: except (RuntimeError, asyncio.TimeoutError, ConnectionError) as e:
logger.error(f"Workflow WebSocket error: {e}") logger.error(f"Workflow WebSocket error: {e}")
finally: finally:
if subscribed_execution_id: if subscribed_execution_id:

View File

@ -47,7 +47,7 @@ class ConnectionManager:
for ws, _ in conns: for ws, _ in conns:
try: try:
await ws.send_json(message) await ws.send_json(message)
except Exception: except (ConnectionError, RuntimeError, asyncio.TimeoutError):
stale.append(ws) stale.append(ws)
for ws in stale: for ws in stale:
self.remove(task_id, ws) self.remove(task_id, ws)
@ -153,6 +153,8 @@ async def task_websocket(websocket: WebSocket, task_id: str) -> None:
except WebSocketDisconnect: except WebSocketDisconnect:
logger.debug(f"WebSocket disconnected for task {task_id}") logger.debug(f"WebSocket disconnected for task {task_id}")
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
logger.error(f"WebSocket error for task {task_id}: {e}") logger.error(f"WebSocket error for task {task_id}: {e}")
try: try:
@ -162,7 +164,7 @@ async def task_websocket(websocket: WebSocket, task_id: str) -> None:
"data": {"message": str(e)}, "data": {"message": str(e)},
} }
) )
except Exception: except (ConnectionError, RuntimeError, asyncio.TimeoutError):
pass pass
finally: finally:
manager.remove(task_id, websocket) manager.remove(task_id, websocket)
@ -243,6 +245,8 @@ async def _run_react_and_stream(
}, },
) )
except asyncio.CancelledError:
raise
except Exception as e: except Exception as e:
await websocket.send_json( await websocket.send_json(
{ {