refactor: follow-up tech debt cleanup (except Exception + Any 治理) (#9)
This commit is contained in:
parent
cc531d0663
commit
838a05772e
|
|
@ -16,7 +16,6 @@ from __future__ import annotations
|
||||||
|
|
||||||
import ast
|
import ast
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from agentkit.bitable.formula.functions import AGGREGATE_FUNCTIONS, FUNCTION_REGISTRY
|
from agentkit.bitable.formula.functions import AGGREGATE_FUNCTIONS, FUNCTION_REGISTRY
|
||||||
from agentkit.bitable.formula.parser import (
|
from agentkit.bitable.formula.parser import (
|
||||||
|
|
@ -104,9 +103,9 @@ class FormulaEngine:
|
||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
field_id: str,
|
field_id: str,
|
||||||
row_values: dict[str, Any],
|
row_values: dict[str, object],
|
||||||
column_values: dict[str, list[Any]] | None = None,
|
column_values: dict[str, list[object]] | None = None,
|
||||||
) -> Any:
|
) -> object:
|
||||||
"""Evaluate a formula field for a specific record.
|
"""Evaluate a formula field for a specific record.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -130,7 +129,7 @@ class FormulaEngine:
|
||||||
|
|
||||||
# Build the field_values dict for the evaluator
|
# Build the field_values dict for the evaluator
|
||||||
# Aggregate refs get column values (lists), row refs get row values (scalars)
|
# Aggregate refs get column values (lists), row refs get row values (scalars)
|
||||||
eval_values: dict[str, Any] = {}
|
eval_values: dict[str, object] = {}
|
||||||
|
|
||||||
# Map real field IDs to safe names
|
# Map real field IDs to safe names
|
||||||
for safe_name, real_id in entry.field_mapping.items():
|
for safe_name, real_id in entry.field_mapping.items():
|
||||||
|
|
@ -143,16 +142,16 @@ class FormulaEngine:
|
||||||
|
|
||||||
def evaluate_all_for_record(
|
def evaluate_all_for_record(
|
||||||
self,
|
self,
|
||||||
row_values: dict[str, Any],
|
row_values: dict[str, object],
|
||||||
column_values: dict[str, list[Any]] | None = None,
|
column_values: dict[str, list[object]] | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, object]:
|
||||||
"""Evaluate all registered formulas for a record.
|
"""Evaluate all registered formulas for a record.
|
||||||
|
|
||||||
Returns a dict of field_id → computed value.
|
Returns a dict of field_id → computed value.
|
||||||
Formulas are evaluated in topological order so that formula-to-formula
|
Formulas are evaluated in topological order so that formula-to-formula
|
||||||
dependencies are resolved correctly.
|
dependencies are resolved correctly.
|
||||||
"""
|
"""
|
||||||
results: dict[str, Any] = {}
|
results: dict[str, object] = {}
|
||||||
column_values = column_values or {}
|
column_values = column_values or {}
|
||||||
|
|
||||||
for field_id in self.topological_order():
|
for field_id in self.topological_order():
|
||||||
|
|
|
||||||
|
|
@ -16,13 +16,11 @@ Usage::
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
def transform_records(
|
def transform_records(
|
||||||
records: list[dict[str, Any]],
|
records: list[dict[str, object]],
|
||||||
field_mapping: dict[str, str],
|
field_mapping: dict[str, str],
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, object]]:
|
||||||
"""Map source record keys to bitable field IDs via field_mapping.
|
"""Map source record keys to bitable field IDs via field_mapping.
|
||||||
|
|
||||||
Keys not in ``field_mapping`` are dropped. Values are passed through
|
Keys not in ``field_mapping`` are dropped. Values are passed through
|
||||||
|
|
@ -40,9 +38,9 @@ def transform_records(
|
||||||
if not field_mapping:
|
if not field_mapping:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
transformed: list[dict[str, Any]] = []
|
transformed: list[dict[str, object]] = []
|
||||||
for rec in records:
|
for rec in records:
|
||||||
out: dict[str, Any] = {}
|
out: dict[str, object] = {}
|
||||||
for src_key, field_id in field_mapping.items():
|
for src_key, field_id in field_mapping.items():
|
||||||
if src_key in rec:
|
if src_key in rec:
|
||||||
out[field_id] = rec[src_key]
|
out[field_id] = rec[src_key]
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,6 @@ Type mapping (KTD: DB → bitable):
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
BigInteger,
|
BigInteger,
|
||||||
|
|
@ -56,7 +55,7 @@ DB_TYPE_MAP: dict[type, str] = {
|
||||||
READ_BATCH = 1000
|
READ_BATCH = 1000
|
||||||
|
|
||||||
|
|
||||||
def infer_field_type(sqla_type: Any) -> str:
|
def infer_field_type(sqla_type: object) -> str:
|
||||||
"""Map a SQLAlchemy column type instance or class to a bitable field type.
|
"""Map a SQLAlchemy column type instance or class to a bitable field type.
|
||||||
|
|
||||||
Handles both type instances (``Integer()``) and type classes (``Integer``).
|
Handles both type instances (``Integer()``) and type classes (``Integer``).
|
||||||
|
|
@ -78,7 +77,7 @@ def import_table(
|
||||||
table_name: str,
|
table_name: str,
|
||||||
*,
|
*,
|
||||||
max_rows: int = 50_000,
|
max_rows: int = 50_000,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, object]:
|
||||||
"""Reflect a single table from an external DB.
|
"""Reflect a single table from an external DB.
|
||||||
|
|
||||||
Returns ``{"table_name": str, "fields": [...], "records": [...],
|
Returns ``{"table_name": str, "fields": [...], "records": [...],
|
||||||
|
|
@ -97,7 +96,7 @@ def import_table(
|
||||||
engine.dispose()
|
engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[str, Any]:
|
def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[str, object]:
|
||||||
"""Reflect one table and read its rows."""
|
"""Reflect one table and read its rows."""
|
||||||
insp = inspect(engine)
|
insp = inspect(engine)
|
||||||
|
|
||||||
|
|
@ -111,7 +110,7 @@ def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[st
|
||||||
table = Table(table_name, metadata, autoload_with=engine)
|
table = Table(table_name, metadata, autoload_with=engine)
|
||||||
|
|
||||||
# Build field definitions
|
# Build field definitions
|
||||||
fields: list[dict[str, Any]] = []
|
fields: list[dict[str, object]] = []
|
||||||
pk_columns = list(table.primary_key.columns)
|
pk_columns = list(table.primary_key.columns)
|
||||||
pk_name = pk_columns[0].name if pk_columns else None
|
pk_name = pk_columns[0].name if pk_columns else None
|
||||||
|
|
||||||
|
|
@ -131,14 +130,14 @@ def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[st
|
||||||
pk_name = "id"
|
pk_name = "id"
|
||||||
|
|
||||||
# Read rows
|
# Read rows
|
||||||
records: list[dict[str, Any]] = []
|
records: list[dict[str, object]] = []
|
||||||
with engine.connect() as conn:
|
with engine.connect() as conn:
|
||||||
result = conn.execute(select(table))
|
result = conn.execute(select(table))
|
||||||
for i, row in enumerate(result):
|
for i, row in enumerate(result):
|
||||||
if i >= max_rows:
|
if i >= max_rows:
|
||||||
logger.warning("Table %r truncated at %d rows during import", table_name, max_rows)
|
logger.warning("Table %r truncated at %d rows during import", table_name, max_rows)
|
||||||
break
|
break
|
||||||
rec: dict[str, Any] = {}
|
rec: dict[str, object] = {}
|
||||||
for col in table.columns:
|
for col in table.columns:
|
||||||
val = getattr(row, col.name, None)
|
val = getattr(row, col.name, None)
|
||||||
if val is not None:
|
if val is not None:
|
||||||
|
|
@ -155,7 +154,7 @@ def _reflect_and_read(engine: Engine, table_name: str, max_rows: int) -> dict[st
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _serialize(val: Any) -> Any:
|
def _serialize(val: object) -> object:
|
||||||
"""Serialize a DB value to JSON-safe form."""
|
"""Serialize a DB value to JSON-safe form."""
|
||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,6 @@ import logging
|
||||||
import socket
|
import socket
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
@ -36,7 +35,7 @@ class ParsedSheet:
|
||||||
name: str
|
name: str
|
||||||
columns: list[str] = field(default_factory=list)
|
columns: list[str] = field(default_factory=list)
|
||||||
field_types: list[str] = field(default_factory=list) # "text" | "number" | "date"
|
field_types: list[str] = field(default_factory=list) # "text" | "number" | "date"
|
||||||
records: list[dict[str, Any]] = field(default_factory=list)
|
records: list[dict[str, object]] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
def parse_excel(file_path: str | Path) -> list[ParsedSheet]:
|
def parse_excel(file_path: str | Path) -> list[ParsedSheet]:
|
||||||
|
|
@ -182,9 +181,9 @@ def _parse_worksheet(ws) -> ParsedSheet | None:
|
||||||
col_count = len(clean_headers)
|
col_count = len(clean_headers)
|
||||||
field_types = _infer_column_types(data_rows, col_count)
|
field_types = _infer_column_types(data_rows, col_count)
|
||||||
|
|
||||||
records: list[dict[str, Any]] = []
|
records: list[dict[str, object]] = []
|
||||||
for row in data_rows:
|
for row in data_rows:
|
||||||
rec: dict[str, Any] = {}
|
rec: dict[str, object] = {}
|
||||||
for i, col_name in enumerate(clean_headers):
|
for i, col_name in enumerate(clean_headers):
|
||||||
val = row[i] if i < len(row) else None
|
val = row[i] if i < len(row) else None
|
||||||
if val is not None:
|
if val is not None:
|
||||||
|
|
@ -237,7 +236,7 @@ def _infer_column_types(rows: list[tuple], col_count: int) -> list[str]:
|
||||||
return types
|
return types
|
||||||
|
|
||||||
|
|
||||||
def _coerce_value(val: Any, field_type: str) -> Any:
|
def _coerce_value(val: object, field_type: str) -> object:
|
||||||
"""Coerce a cell value to the inferred field type. Truncate long strings."""
|
"""Coerce a cell value to the inferred field type. Truncate long strings."""
|
||||||
if field_type == "date":
|
if field_type == "date":
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
|
||||||
|
|
@ -9,10 +9,14 @@ from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field as PydanticField
|
from pydantic import BaseModel, ConfigDict, Field as PydanticField
|
||||||
|
|
||||||
|
# ponytail: bitable JSONB columns hold arbitrary JSON. Using `object` instead of
|
||||||
|
# a recursive TypeAlias because Pydantic v2 cannot build a schema for recursive
|
||||||
|
# named aliases (RecursionError). `object` is the most permissive type and
|
||||||
|
# Pydantic v2 serializes dict/list/primitive values fine at runtime.
|
||||||
|
|
||||||
|
|
||||||
def _utcnow() -> datetime:
|
def _utcnow() -> datetime:
|
||||||
return datetime.now(timezone.utc)
|
return datetime.now(timezone.utc)
|
||||||
|
|
@ -97,7 +101,7 @@ class Table(BaseModel):
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
# Status select field options — labels and colors match Feishu Bitable defaults.
|
# Status select field options — labels and colors match Feishu Bitable defaults.
|
||||||
_STATUS_OPTIONS: list[dict[str, Any]] = [
|
_STATUS_OPTIONS: list[dict[str, object]] = [
|
||||||
{"label": "未开始", "value": "not_started", "color": "default"},
|
{"label": "未开始", "value": "not_started", "color": "default"},
|
||||||
{"label": "进行中", "value": "in_progress", "color": "processing"},
|
{"label": "进行中", "value": "in_progress", "color": "processing"},
|
||||||
{"label": "已完成", "value": "done", "color": "success"},
|
{"label": "已完成", "value": "done", "color": "success"},
|
||||||
|
|
@ -106,7 +110,7 @@ _STATUS_OPTIONS: list[dict[str, Any]] = [
|
||||||
#: Templates for the 5 default fields created on every new table (R2).
|
#: Templates for the 5 default fields created on every new table (R2).
|
||||||
#: agent-owned fields (创建人/创建时间) are auto-filled by the service layer
|
#: agent-owned fields (创建人/创建时间) are auto-filled by the service layer
|
||||||
#: on record creation; user-owned fields are user-editable.
|
#: on record creation; user-owned fields are user-editable.
|
||||||
DEFAULT_FIELD_TEMPLATES: list[dict[str, Any]] = [
|
DEFAULT_FIELD_TEMPLATES: list[dict[str, object]] = [
|
||||||
{
|
{
|
||||||
"name": "标题",
|
"name": "标题",
|
||||||
"field_type": FieldType.text,
|
"field_type": FieldType.text,
|
||||||
|
|
@ -155,7 +159,7 @@ class Field(BaseModel):
|
||||||
table_id: str
|
table_id: str
|
||||||
name: str
|
name: str
|
||||||
field_type: FieldType
|
field_type: FieldType
|
||||||
config: dict[str, Any] = PydanticField(default_factory=dict)
|
config: dict[str, object] = PydanticField(default_factory=dict)
|
||||||
owner: FieldOwner = FieldOwner.user
|
owner: FieldOwner = FieldOwner.user
|
||||||
created_at: datetime = PydanticField(default_factory=_utcnow)
|
created_at: datetime = PydanticField(default_factory=_utcnow)
|
||||||
|
|
||||||
|
|
@ -167,7 +171,7 @@ class Record(BaseModel):
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
table_id: str
|
table_id: str
|
||||||
values: dict[str, Any] = PydanticField(default_factory=dict)
|
values: dict[str, object] = PydanticField(default_factory=dict)
|
||||||
created_at: datetime = PydanticField(default_factory=_utcnow)
|
created_at: datetime = PydanticField(default_factory=_utcnow)
|
||||||
updated_at: datetime = PydanticField(default_factory=_utcnow)
|
updated_at: datetime = PydanticField(default_factory=_utcnow)
|
||||||
|
|
||||||
|
|
@ -181,7 +185,7 @@ class View(BaseModel):
|
||||||
table_id: str
|
table_id: str
|
||||||
name: str
|
name: str
|
||||||
view_type: ViewType = ViewType.grid
|
view_type: ViewType = ViewType.grid
|
||||||
config: dict[str, Any] = PydanticField(default_factory=dict)
|
config: dict[str, object] = PydanticField(default_factory=dict)
|
||||||
created_at: datetime = PydanticField(default_factory=_utcnow)
|
created_at: datetime = PydanticField(default_factory=_utcnow)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,11 +18,10 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from agentkit.bitable.db import BitableDB
|
from agentkit.bitable.db import BitableDB
|
||||||
from agentkit.bitable.formula.engine import FormulaEngine
|
from agentkit.bitable.formula.engine import FormulaEngine
|
||||||
from agentkit.bitable.models import FieldType, RecalcStatus
|
from agentkit.bitable.models import FieldType, RecalcStatus, RecalcTask
|
||||||
from agentkit.bitable.repository import BitableRepository
|
from agentkit.bitable.repository import BitableRepository
|
||||||
from agentkit.bitable.service import BitableService
|
from agentkit.bitable.service import BitableService
|
||||||
|
|
||||||
|
|
@ -124,7 +123,7 @@ class RecalcWorker:
|
||||||
logger.exception("RecalcWorker error in main loop")
|
logger.exception("RecalcWorker error in main loop")
|
||||||
await asyncio.sleep(self._poll_interval)
|
await asyncio.sleep(self._poll_interval)
|
||||||
|
|
||||||
async def _sort_by_topological_order(self, tasks: list[Any]) -> list[Any]:
|
async def _sort_by_topological_order(self, tasks: list[RecalcTask]) -> list[RecalcTask]:
|
||||||
"""Sort claimed tasks so dependencies are processed first (P1 #7).
|
"""Sort claimed tasks so dependencies are processed first (P1 #7).
|
||||||
|
|
||||||
Groups tasks by table_id, builds (or reuses) the engine to get the
|
Groups tasks by table_id, builds (or reuses) the engine to get the
|
||||||
|
|
@ -146,7 +145,7 @@ class RecalcWorker:
|
||||||
order = engine.topological_order()
|
order = engine.topological_order()
|
||||||
topo_index[tid] = {fid: i for i, fid in enumerate(order)}
|
topo_index[tid] = {fid: i for i, fid in enumerate(order)}
|
||||||
|
|
||||||
def _key(t: Any) -> tuple[str, str, int]:
|
def _key(t: RecalcTask) -> tuple[str, str, int]:
|
||||||
idx = topo_index.get(t.table_id, {}).get(t.field_id, 1 << 30)
|
idx = topo_index.get(t.table_id, {}).get(t.field_id, 1 << 30)
|
||||||
return (t.table_id, t.record_id, idx)
|
return (t.table_id, t.record_id, idx)
|
||||||
|
|
||||||
|
|
@ -175,7 +174,7 @@ class RecalcWorker:
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("RecalcWorker reaper error")
|
logger.exception("RecalcWorker reaper error")
|
||||||
|
|
||||||
async def process_task(self, task: Any) -> None:
|
async def process_task(self, task: RecalcTask) -> None:
|
||||||
"""Process a single recalc task: evaluate formula → write result.
|
"""Process a single recalc task: evaluate formula → write result.
|
||||||
|
|
||||||
The task is expected to already be in ``calculating`` status when
|
The task is expected to already be in ``calculating`` status when
|
||||||
|
|
@ -216,7 +215,7 @@ class RecalcWorker:
|
||||||
return
|
return
|
||||||
|
|
||||||
deps = engine.get_dependencies(task.field_id)
|
deps = engine.get_dependencies(task.field_id)
|
||||||
column_values: dict[str, list[Any]] = {}
|
column_values: dict[str, list[object]] = {}
|
||||||
for dep_field_id in deps:
|
for dep_field_id in deps:
|
||||||
column_values[dep_field_id] = await self._repo.get_column_values(
|
column_values[dep_field_id] = await self._repo.get_column_values(
|
||||||
task.table_id, dep_field_id
|
task.table_id, dep_field_id
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,6 @@ from __future__ import annotations
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from sqlalchemy import delete, func, insert, select, text, update
|
from sqlalchemy import delete, func, insert, select, text, update
|
||||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
|
|
@ -102,7 +101,7 @@ class BitableRepository:
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
return [BitableFile.model_validate(e) for e in result.scalars().all()]
|
return [BitableFile.model_validate(e) for e in result.scalars().all()]
|
||||||
|
|
||||||
async def update_file(self, file_id: str, **kwargs: Any) -> BitableFile | None:
|
async def update_file(self, file_id: str, **kwargs: object) -> BitableFile | None:
|
||||||
"""Update a file's attributes."""
|
"""Update a file's attributes."""
|
||||||
async with self._session_factory() as session:
|
async with self._session_factory() as session:
|
||||||
stmt = (
|
stmt = (
|
||||||
|
|
@ -181,7 +180,7 @@ class BitableRepository:
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
return [Table.model_validate(e) for e in result.scalars().all()]
|
return [Table.model_validate(e) for e in result.scalars().all()]
|
||||||
|
|
||||||
async def update_table(self, table_id: str, **kwargs: Any) -> Table | None:
|
async def update_table(self, table_id: str, **kwargs: object) -> Table | None:
|
||||||
"""Update a table's attributes."""
|
"""Update a table's attributes."""
|
||||||
async with self._session_factory() as session:
|
async with self._session_factory() as session:
|
||||||
stmt = (
|
stmt = (
|
||||||
|
|
@ -236,7 +235,7 @@ class BitableRepository:
|
||||||
table_id: str,
|
table_id: str,
|
||||||
name: str,
|
name: str,
|
||||||
field_type: FieldType,
|
field_type: FieldType,
|
||||||
config: dict[str, Any] | None = None,
|
config: dict[str, object] | None = None,
|
||||||
owner: FieldOwner = FieldOwner.user,
|
owner: FieldOwner = FieldOwner.user,
|
||||||
) -> Field:
|
) -> Field:
|
||||||
"""Create a new field in a table."""
|
"""Create a new field in a table."""
|
||||||
|
|
@ -277,7 +276,7 @@ class BitableRepository:
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
return [Field.model_validate(e) for e in result.scalars().all()]
|
return [Field.model_validate(e) for e in result.scalars().all()]
|
||||||
|
|
||||||
async def update_field(self, field_id: str, **kwargs: Any) -> Field | None:
|
async def update_field(self, field_id: str, **kwargs: object) -> Field | None:
|
||||||
"""Update a field's attributes."""
|
"""Update a field's attributes."""
|
||||||
async with self._session_factory() as session:
|
async with self._session_factory() as session:
|
||||||
stmt = (
|
stmt = (
|
||||||
|
|
@ -300,7 +299,7 @@ class BitableRepository:
|
||||||
|
|
||||||
# ── Records ─────────────────────────────────────────────
|
# ── Records ─────────────────────────────────────────────
|
||||||
|
|
||||||
async def create_record(self, table_id: str, values: dict[str, Any] | None = None) -> Record:
|
async def create_record(self, table_id: str, values: dict[str, object] | None = None) -> Record:
|
||||||
"""Create a new record."""
|
"""Create a new record."""
|
||||||
async with self._session_factory() as session:
|
async with self._session_factory() as session:
|
||||||
stmt = (
|
stmt = (
|
||||||
|
|
@ -318,7 +317,7 @@ class BitableRepository:
|
||||||
return Record.model_validate(entity)
|
return Record.model_validate(entity)
|
||||||
|
|
||||||
async def create_records_batch(
|
async def create_records_batch(
|
||||||
self, table_id: str, records_values: list[dict[str, Any]]
|
self, table_id: str, records_values: list[dict[str, object]]
|
||||||
) -> list[Record]:
|
) -> list[Record]:
|
||||||
"""Batch-insert multiple records (P2 #19: eliminates per-record INSERT).
|
"""Batch-insert multiple records (P2 #19: eliminates per-record INSERT).
|
||||||
|
|
||||||
|
|
@ -376,7 +375,7 @@ class BitableRepository:
|
||||||
|
|
||||||
return [Record.model_validate(e) for e in entities], next_cursor
|
return [Record.model_validate(e) for e in entities], next_cursor
|
||||||
|
|
||||||
async def update_record_values(self, record_id: str, values: dict[str, Any]) -> Record | None:
|
async def update_record_values(self, record_id: str, values: dict[str, object]) -> Record | None:
|
||||||
"""Update a record's values (full replace)."""
|
"""Update a record's values (full replace)."""
|
||||||
async with self._session_factory() as session:
|
async with self._session_factory() as session:
|
||||||
stmt = (
|
stmt = (
|
||||||
|
|
@ -413,7 +412,7 @@ class BitableRepository:
|
||||||
table_id: str,
|
table_id: str,
|
||||||
name: str,
|
name: str,
|
||||||
view_type: ViewType = ViewType.grid,
|
view_type: ViewType = ViewType.grid,
|
||||||
config: dict[str, Any] | None = None,
|
config: dict[str, object] | None = None,
|
||||||
) -> View:
|
) -> View:
|
||||||
"""Create a new view."""
|
"""Create a new view."""
|
||||||
async with self._session_factory() as session:
|
async with self._session_factory() as session:
|
||||||
|
|
@ -451,7 +450,7 @@ class BitableRepository:
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
return [View.model_validate(e) for e in result.scalars().all()]
|
return [View.model_validate(e) for e in result.scalars().all()]
|
||||||
|
|
||||||
async def update_view(self, view_id: str, **kwargs: Any) -> View | None:
|
async def update_view(self, view_id: str, **kwargs: object) -> View | None:
|
||||||
"""Update a view's attributes."""
|
"""Update a view's attributes."""
|
||||||
async with self._session_factory() as session:
|
async with self._session_factory() as session:
|
||||||
stmt = (
|
stmt = (
|
||||||
|
|
@ -543,7 +542,7 @@ class BitableRepository:
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update a recalc task's status."""
|
"""Update a recalc task's status."""
|
||||||
async with self._session_factory() as session:
|
async with self._session_factory() as session:
|
||||||
kwargs: dict[str, Any] = {"status": status.value}
|
kwargs: dict[str, object] = {"status": status.value}
|
||||||
if error_message is not None:
|
if error_message is not None:
|
||||||
kwargs["error_message"] = error_message
|
kwargs["error_message"] = error_message
|
||||||
if status in (RecalcStatus.done, RecalcStatus.error):
|
if status in (RecalcStatus.done, RecalcStatus.error):
|
||||||
|
|
@ -630,7 +629,7 @@ class BitableRepository:
|
||||||
return result_map
|
return result_map
|
||||||
|
|
||||||
async def upsert_record_agent_fields(
|
async def upsert_record_agent_fields(
|
||||||
self, record_id: str, agent_field_values: dict[str, Any]
|
self, record_id: str, agent_field_values: dict[str, object]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update agent-owned fields using jsonb_set (KTD8).
|
"""Update agent-owned fields using jsonb_set (KTD8).
|
||||||
|
|
||||||
|
|
@ -646,7 +645,7 @@ class BitableRepository:
|
||||||
# Use CAST(:param AS jsonb) instead of :param::jsonb — asyncpg dialect
|
# Use CAST(:param AS jsonb) instead of :param::jsonb — asyncpg dialect
|
||||||
# misparses the `::` as part of the param name.
|
# misparses the `::` as part of the param name.
|
||||||
inner = "values"
|
inner = "values"
|
||||||
params: dict[str, Any] = {"record_id": record_id}
|
params: dict[str, object] = {"record_id": record_id}
|
||||||
for i, (field_id, value) in enumerate(agent_field_values.items()):
|
for i, (field_id, value) in enumerate(agent_field_values.items()):
|
||||||
param_key = f"v{i}"
|
param_key = f"v{i}"
|
||||||
inner = f"jsonb_set({inner}, '{{{field_id}}}', CAST(:{param_key} AS jsonb), true)"
|
inner = f"jsonb_set({inner}, '{{{field_id}}}', CAST(:{param_key} AS jsonb), true)"
|
||||||
|
|
@ -660,8 +659,8 @@ class BitableRepository:
|
||||||
async def list_records_filtered(
|
async def list_records_filtered(
|
||||||
self,
|
self,
|
||||||
table_id: str,
|
table_id: str,
|
||||||
filters: list[dict[str, Any]] | None = None,
|
filters: list[dict[str, object]] | None = None,
|
||||||
sorts: list[dict[str, Any]] | None = None,
|
sorts: list[dict[str, object]] | None = None,
|
||||||
cursor: str | None = None,
|
cursor: str | None = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
) -> tuple[list[Record], str | None]:
|
) -> tuple[list[Record], str | None]:
|
||||||
|
|
@ -682,7 +681,7 @@ class BitableRepository:
|
||||||
# Build raw SQL with JSONB filter/sort translation.
|
# Build raw SQL with JSONB filter/sort translation.
|
||||||
# ponytail: field_ids in filters/sorts are system UUIDs (validated by service layer).
|
# ponytail: field_ids in filters/sorts are system UUIDs (validated by service layer).
|
||||||
where_clauses = ["table_id = :table_id"]
|
where_clauses = ["table_id = :table_id"]
|
||||||
params: dict[str, Any] = {"table_id": table_id}
|
params: dict[str, object] = {"table_id": table_id}
|
||||||
|
|
||||||
if filters:
|
if filters:
|
||||||
for i, f in enumerate(filters):
|
for i, f in enumerate(filters):
|
||||||
|
|
@ -783,7 +782,7 @@ class BitableRepository:
|
||||||
last_mapping = last_row._mapping
|
last_mapping = last_row._mapping
|
||||||
# Build composite cursor from sort values + id.
|
# Build composite cursor from sort values + id.
|
||||||
# Sort values are extracted as text to match `values->>'fid'` expressions.
|
# Sort values are extracted as text to match `values->>'fid'` expressions.
|
||||||
sv: list[Any] = []
|
sv: list[object] = []
|
||||||
last_values = last_mapping.get("values")
|
last_values = last_mapping.get("values")
|
||||||
if isinstance(last_values, str):
|
if isinstance(last_values, str):
|
||||||
# asyncpg may return JSONB as str in raw text() queries.
|
# asyncpg may return JSONB as str in raw text() queries.
|
||||||
|
|
@ -852,7 +851,7 @@ class BitableRepository:
|
||||||
|
|
||||||
# ── Recalc support (U3) ────────────────────────────────
|
# ── Recalc support (U3) ────────────────────────────────
|
||||||
|
|
||||||
async def get_column_values(self, table_id: str, field_id: str) -> list[Any]:
|
async def get_column_values(self, table_id: str, field_id: str) -> list[object]:
|
||||||
"""Get all values for a field across all records in a table (for aggregates).
|
"""Get all values for a field across all records in a table (for aggregates).
|
||||||
|
|
||||||
Returns a list of values (preserving order by record id). Missing values
|
Returns a list of values (preserving order by record id). Missing values
|
||||||
|
|
@ -866,7 +865,7 @@ class BitableRepository:
|
||||||
result = await session.execute(sql, {"field_id": field_id, "table_id": table_id})
|
result = await session.execute(sql, {"field_id": field_id, "table_id": table_id})
|
||||||
return [row[0] for row in result.fetchall()]
|
return [row[0] for row in result.fetchall()]
|
||||||
|
|
||||||
async def set_formula_value(self, record_id: str, field_id: str, value: Any) -> None:
|
async def set_formula_value(self, record_id: str, field_id: str, value: object) -> None:
|
||||||
"""Set a single formula field value in a record's JSONB (jsonb_set)."""
|
"""Set a single formula field value in a record's JSONB (jsonb_set)."""
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -34,12 +34,19 @@ import os
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable
|
from typing import Callable, TypeAlias
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 缓存的配置项 — 技能/工作流配置为 JSON 反序列化后的字典,值为标量或嵌套结构。
|
||||||
|
# 服务器返回的 skills/workflows 列表元素是 dict(model_dump/to_dict 输出),
|
||||||
|
# 其中可能包含 list/dict 等容器,因此使用 object 作为值类型。
|
||||||
|
SkillConfigDict: TypeAlias = dict[str, object]
|
||||||
|
WorkflowConfigDict: TypeAlias = dict[str, object]
|
||||||
|
SyncedConfigPayload: TypeAlias = dict[str, object]
|
||||||
|
|
||||||
|
|
||||||
# ── Defaults ──────────────────────────────────────────────────────────
|
# ── Defaults ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
@ -100,8 +107,8 @@ class ConfigSync:
|
||||||
|
|
||||||
# In-memory cache (mirrors the SQLite cache for fast access)
|
# In-memory cache (mirrors the SQLite cache for fast access)
|
||||||
self._version: str | None = None
|
self._version: str | None = None
|
||||||
self._skills: list[dict[str, Any]] = []
|
self._skills: list[SkillConfigDict] = []
|
||||||
self._workflows: list[dict[str, Any]] = []
|
self._workflows: list[WorkflowConfigDict] = []
|
||||||
self._last_synced_at: str | None = None
|
self._last_synced_at: str | None = None
|
||||||
|
|
||||||
# ── Lifecycle ─────────────────────────────────────────────────
|
# ── Lifecycle ─────────────────────────────────────────────────
|
||||||
|
|
@ -232,15 +239,15 @@ class ConfigSync:
|
||||||
"""Return the current cached config version hash."""
|
"""Return the current cached config version hash."""
|
||||||
return self._version
|
return self._version
|
||||||
|
|
||||||
def get_skills(self) -> list[dict[str, Any]]:
|
def get_skills(self) -> list[SkillConfigDict]:
|
||||||
"""Return the cached skill configs."""
|
"""Return the cached skill configs."""
|
||||||
return list(self._skills)
|
return list(self._skills)
|
||||||
|
|
||||||
def get_workflows(self) -> list[dict[str, Any]]:
|
def get_workflows(self) -> list[WorkflowConfigDict]:
|
||||||
"""Return the cached workflow configs."""
|
"""Return the cached workflow configs."""
|
||||||
return list(self._workflows)
|
return list(self._workflows)
|
||||||
|
|
||||||
def get_all(self) -> dict[str, Any]:
|
def get_all(self) -> SyncedConfigPayload:
|
||||||
"""Return all cached configs as a single dict."""
|
"""Return all cached configs as a single dict."""
|
||||||
return {
|
return {
|
||||||
"version": self._version,
|
"version": self._version,
|
||||||
|
|
@ -249,14 +256,14 @@ class ConfigSync:
|
||||||
"synced_at": self._last_synced_at,
|
"synced_at": self._last_synced_at,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_skill(self, name: str) -> dict[str, Any] | None:
|
def get_skill(self, name: str) -> SkillConfigDict | None:
|
||||||
"""Return a single skill config by name, or ``None``."""
|
"""Return a single skill config by name, or ``None``."""
|
||||||
for skill in self._skills:
|
for skill in self._skills:
|
||||||
if skill.get("name") == name:
|
if skill.get("name") == name:
|
||||||
return skill
|
return skill
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_workflow(self, workflow_id: str) -> dict[str, Any] | None:
|
def get_workflow(self, workflow_id: str) -> WorkflowConfigDict | None:
|
||||||
"""Return a single workflow config by ID, or ``None``."""
|
"""Return a single workflow config by ID, or ``None``."""
|
||||||
for wf in self._workflows:
|
for wf in self._workflows:
|
||||||
if wf.get("workflow_id") == workflow_id:
|
if wf.get("workflow_id") == workflow_id:
|
||||||
|
|
@ -281,7 +288,7 @@ class ConfigSync:
|
||||||
conn.executescript(_CACHE_SCHEMA)
|
conn.executescript(_CACHE_SCHEMA)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
def _save_to_cache(self, data: dict[str, Any]) -> None:
|
def _save_to_cache(self, data: SyncedConfigPayload) -> None:
|
||||||
"""Save the synced configs to the local SQLite cache."""
|
"""Save the synced configs to the local SQLite cache."""
|
||||||
now = datetime.now(timezone.utc).isoformat()
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
with sqlite3.connect(str(self.cache_db_path)) as conn:
|
with sqlite3.connect(str(self.cache_db_path)) as conn:
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ import logging
|
||||||
import time
|
import time
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
||||||
|
|
||||||
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
|
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
|
||||||
from agentkit.utils.vector_math import compute_cosine_similarity
|
from agentkit.utils.vector_math import compute_cosine_similarity
|
||||||
|
|
@ -25,6 +25,52 @@ if TYPE_CHECKING:
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# TYPE_CHECKING Protocols — 避免 Any,描述运行时 lazy import 的第三方对象
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
|
||||||
|
class _RedisLike(Protocol):
|
||||||
|
"""Redis 客户端最小契约(仅覆盖本模块用到的方法)。"""
|
||||||
|
|
||||||
|
async def get(self, key: str) -> bytes | str | None: ...
|
||||||
|
|
||||||
|
async def mget(self, keys: list[str]) -> list[bytes | str | None]: ...
|
||||||
|
|
||||||
|
async def set(self, key: str, value: bytes | str, ex: int | None = None) -> None: ...
|
||||||
|
|
||||||
|
async def smembers(self, key: str) -> set[bytes | str]: ...
|
||||||
|
|
||||||
|
async def sadd(self, name: str, *values: str) -> int: ...
|
||||||
|
|
||||||
|
async def srem(self, name: str, *values: str) -> int: ...
|
||||||
|
|
||||||
|
async def scard(self, name: str) -> int: ...
|
||||||
|
|
||||||
|
async def delete(self, *names: str) -> int: ...
|
||||||
|
|
||||||
|
def pipeline(self) -> "_RedisPipelineLike": ...
|
||||||
|
|
||||||
|
class _RedisPipelineLike(Protocol):
|
||||||
|
"""Redis pipeline 最小契约。"""
|
||||||
|
|
||||||
|
def get(self, key: str) -> "_RedisPipelineLike": ...
|
||||||
|
|
||||||
|
def set(
|
||||||
|
self, key: str, value: bytes | str, ex: int | None = None
|
||||||
|
) -> "_RedisPipelineLike": ...
|
||||||
|
|
||||||
|
def delete(self, *names: str) -> "_RedisPipelineLike": ...
|
||||||
|
|
||||||
|
def sadd(self, name: str, *values: str) -> "_RedisPipelineLike": ...
|
||||||
|
|
||||||
|
def srem(self, name: str, *values: str) -> "_RedisPipelineLike": ...
|
||||||
|
|
||||||
|
async def execute(self) -> list[object]: ...
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Data Classes
|
# Data Classes
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -328,7 +374,7 @@ class RedisLLMCache:
|
||||||
self._semantic_ttl = semantic_ttl
|
self._semantic_ttl = semantic_ttl
|
||||||
self._similarity_threshold = similarity_threshold
|
self._similarity_threshold = similarity_threshold
|
||||||
self._max_entries_to_scan = max_entries_to_scan
|
self._max_entries_to_scan = max_entries_to_scan
|
||||||
self._redis: Any = None
|
self._redis: _RedisLike | None = None
|
||||||
self._fallback: InMemoryLLMCache | None = fallback # For auto-degradation
|
self._fallback: InMemoryLLMCache | None = fallback # For auto-degradation
|
||||||
self._degraded = False # True if Redis is unreachable
|
self._degraded = False # True if Redis is unreachable
|
||||||
|
|
||||||
|
|
@ -691,7 +737,7 @@ class LitellmCacheManager:
|
||||||
|
|
||||||
def __init__(self, config: LitellmCacheConfig):
|
def __init__(self, config: LitellmCacheConfig):
|
||||||
self._config = config
|
self._config = config
|
||||||
self._cache_instance: Any = None # litellm.caching.Cache 实例
|
self._cache_instance: object | None = None # litellm.caching.Cache 实例
|
||||||
self._hits = 0
|
self._hits = 0
|
||||||
self._misses = 0
|
self._misses = 0
|
||||||
|
|
||||||
|
|
@ -709,7 +755,7 @@ class LitellmCacheManager:
|
||||||
litellm.cache = None
|
litellm.cache = None
|
||||||
self._cache_instance = None
|
self._cache_instance = None
|
||||||
|
|
||||||
def _create_cache_instance(self) -> Any:
|
def _create_cache_instance(self) -> object:
|
||||||
"""根据 backend 配置创建 LiteLLM Cache 实例。
|
"""根据 backend 配置创建 LiteLLM Cache 实例。
|
||||||
|
|
||||||
auto 模式按优先级尝试:RedisSemanticCache → RedisCache → InMemoryCache。
|
auto 模式按优先级尝试:RedisSemanticCache → RedisCache → InMemoryCache。
|
||||||
|
|
|
||||||
|
|
@ -2,14 +2,13 @@
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
def generate_cache_key(
|
def generate_cache_key(
|
||||||
model: str,
|
model: str,
|
||||||
messages: list[dict[str, str]],
|
messages: list[dict[str, str]],
|
||||||
temperature: float,
|
temperature: float,
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list[dict[str, object]] | None = None,
|
||||||
tool_choice: str = "auto",
|
tool_choice: str = "auto",
|
||||||
max_tokens: int = 2000,
|
max_tokens: int = 2000,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,12 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from agentkit.llm.retry import CircuitBreakerConfig, RetryConfig
|
from agentkit.llm.retry import CircuitBreakerConfig, RetryConfig
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from agentkit.channels.secrets import SecretsStore
|
from agentkit.channels.secrets import SecretEntry, SecretsStore
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -56,7 +56,7 @@ class ProviderConfig:
|
||||||
|
|
||||||
api_key: str
|
api_key: str
|
||||||
base_url: str
|
base_url: str
|
||||||
models: dict[str, dict[str, Any]] = field(default_factory=dict)
|
models: dict[str, dict[str, object]] = field(default_factory=dict)
|
||||||
type: str = "openai" # "openai" | "anthropic" | "gemini"
|
type: str = "openai" # "openai" | "anthropic" | "gemini"
|
||||||
max_tokens: int = 4096 # Anthropic: default max_tokens
|
max_tokens: int = 4096 # Anthropic: default max_tokens
|
||||||
timeout: float = 120.0 # Anthropic: request timeout
|
timeout: float = 120.0 # Anthropic: request timeout
|
||||||
|
|
@ -168,18 +168,18 @@ class ProviderConfig:
|
||||||
return f"llm:provider:{self.type}:api_key"
|
return f"llm:provider:{self.type}:api_key"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _encode_secret_entry(entry: Any, key: str) -> str:
|
def _encode_secret_entry(entry: object, key: str) -> str:
|
||||||
"""把 SecretEntry 编码为 JSON 字符串(含 key 字段)。"""
|
"""把 SecretEntry 编码为 JSON 字符串(含 key 字段)。"""
|
||||||
# entry 是 SecretEntry pydantic 模型,有 model_dump()
|
# entry 是 SecretEntry pydantic 模型,有 model_dump()
|
||||||
if hasattr(entry, "model_dump"):
|
if hasattr(entry, "model_dump"):
|
||||||
data = entry.model_dump()
|
data = entry.model_dump() # type: ignore[attr-defined]
|
||||||
else:
|
else:
|
||||||
data = dict(entry)
|
data = dict(entry) # type: ignore[call-overload]
|
||||||
data["key"] = key
|
data["key"] = key
|
||||||
return json.dumps(data)
|
return json.dumps(data)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _decode_secret_entry(encoded: str) -> Any:
|
def _decode_secret_entry(encoded: str) -> "SecretEntry":
|
||||||
"""从 JSON 字符串解码 SecretEntry。返回带 .key 属性的对象。"""
|
"""从 JSON 字符串解码 SecretEntry。返回带 .key 属性的对象。"""
|
||||||
from agentkit.channels.secrets import SecretEntry
|
from agentkit.channels.secrets import SecretEntry
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import logging
|
||||||
import time
|
import time
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Protocol
|
||||||
|
|
||||||
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
|
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
|
||||||
from agentkit.llm.config import LLMConfig
|
from agentkit.llm.config import LLMConfig
|
||||||
|
|
@ -14,9 +14,40 @@ from agentkit.llm.providers.tracker import UsageSummary, UsageTracker
|
||||||
from agentkit.telemetry.tracing import get_tracer, _OTEL_AVAILABLE
|
from agentkit.telemetry.tracing import get_tracer, _OTEL_AVAILABLE
|
||||||
from agentkit.telemetry.metrics import llm_token_histogram
|
from agentkit.telemetry.metrics import llm_token_histogram
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from agentkit.llm.cache import LitellmCacheManager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# TYPE_CHECKING Protocols — 避免 Any,描述运行时 lazy import 的对象
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
|
||||||
|
class _QuotaServiceLike(Protocol):
|
||||||
|
"""Quota service 最小契约(仅覆盖 gateway._enforce_quota 用到的方法)。"""
|
||||||
|
|
||||||
|
async def is_model_allowed(
|
||||||
|
self, db: Path, department_id: str, model: str
|
||||||
|
) -> tuple[bool, str]: ...
|
||||||
|
|
||||||
|
async def check_quota(
|
||||||
|
self,
|
||||||
|
db: Path,
|
||||||
|
department_id: str,
|
||||||
|
quota_type: str,
|
||||||
|
period: str,
|
||||||
|
current: float,
|
||||||
|
) -> tuple[bool, str]: ...
|
||||||
|
|
||||||
|
async def get_quota(
|
||||||
|
self, db: Path, department_id: str, quota_type: str, period: str
|
||||||
|
) -> dict[str, object] | None: ...
|
||||||
|
|
||||||
|
|
||||||
class QuotaExceededError(Exception):
|
class QuotaExceededError(Exception):
|
||||||
"""Raised when a department's LLM quota is exceeded.
|
"""Raised when a department's LLM quota is exceeded.
|
||||||
|
|
||||||
|
|
@ -29,8 +60,8 @@ class QuotaExceededError(Exception):
|
||||||
department_id: str,
|
department_id: str,
|
||||||
quota_type: str,
|
quota_type: str,
|
||||||
period: str,
|
period: str,
|
||||||
limit: Any,
|
limit: object,
|
||||||
current: Any,
|
current: object,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.department_id = department_id
|
self.department_id = department_id
|
||||||
self.quota_type = quota_type
|
self.quota_type = quota_type
|
||||||
|
|
@ -46,13 +77,13 @@ class QuotaExceededError(Exception):
|
||||||
class LLMGateway:
|
class LLMGateway:
|
||||||
"""LLM 网关 - Provider 注册、模型别名解析、Fallback、Usage 追踪、Cache"""
|
"""LLM 网关 - Provider 注册、模型别名解析、Fallback、Usage 追踪、Cache"""
|
||||||
|
|
||||||
def __init__(self, config: LLMConfig | None = None, usage_store: Any = None):
|
def __init__(self, config: LLMConfig | None = None, usage_store: object | None = None):
|
||||||
self._providers: dict[str, LLMProvider] = {}
|
self._providers: dict[str, LLMProvider] = {}
|
||||||
self._usage_tracker = UsageTracker(store=usage_store) if usage_store else UsageTracker()
|
self._usage_tracker = UsageTracker(store=usage_store) if usage_store else UsageTracker()
|
||||||
self._config = config or LLMConfig()
|
self._config = config or LLMConfig()
|
||||||
|
|
||||||
# Cache (U17 — LiteLLM 缓存管理器,opt-in,默认禁用)
|
# Cache (U17 — LiteLLM 缓存管理器,opt-in,默认禁用)
|
||||||
self._cache_manager: Any = None # LitellmCacheManager | None
|
self._cache_manager: "LitellmCacheManager | None" = None
|
||||||
if self._config.cache and self._config.cache.enabled:
|
if self._config.cache and self._config.cache.enabled:
|
||||||
from agentkit.llm.cache import LitellmCacheConfig, LitellmCacheManager
|
from agentkit.llm.cache import LitellmCacheConfig, LitellmCacheManager
|
||||||
|
|
||||||
|
|
@ -601,7 +632,7 @@ class LLMGateway:
|
||||||
|
|
||||||
async def _check_quota_value(
|
async def _check_quota_value(
|
||||||
self,
|
self,
|
||||||
quota_service: Any,
|
quota_service: _QuotaServiceLike,
|
||||||
db: Path,
|
db: Path,
|
||||||
dept_id: str,
|
dept_id: str,
|
||||||
period: str,
|
period: str,
|
||||||
|
|
|
||||||
|
|
@ -27,12 +27,11 @@ from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str, Any]]:
|
def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str, object]]:
|
||||||
"""把 agentkit.yaml 中的 plaintext API Key 迁移到 SecretsStore。
|
"""把 agentkit.yaml 中的 plaintext API Key 迁移到 SecretsStore。
|
||||||
|
|
||||||
流程:
|
流程:
|
||||||
|
|
@ -63,8 +62,8 @@ def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str,
|
||||||
|
|
||||||
store = SecretsStore() # master key 从 env 加载
|
store = SecretsStore() # master key 从 env 加载
|
||||||
|
|
||||||
async def _run() -> dict[str, dict[str, Any]]:
|
async def _run() -> dict[str, dict[str, object]]:
|
||||||
report: dict[str, dict[str, Any]] = {}
|
report: dict[str, dict[str, object]] = {}
|
||||||
for name, pconf in llm_config.providers.items():
|
for name, pconf in llm_config.providers.items():
|
||||||
if pconf.api_key_source == "secrets_store" and not pconf.api_key:
|
if pconf.api_key_source == "secrets_store" and not pconf.api_key:
|
||||||
report[name] = {"status": "skipped", "source": pconf.api_key_source}
|
report[name] = {"status": "skipped", "source": pconf.api_key_source}
|
||||||
|
|
@ -93,9 +92,9 @@ def migrate_api_keys_to_secrets(config_path: Path | str) -> dict[str, dict[str,
|
||||||
report = asyncio.run(_run())
|
report = asyncio.run(_run())
|
||||||
|
|
||||||
# 写回 YAML:更新 llm.providers 段,保留其它段
|
# 写回 YAML:更新 llm.providers 段,保留其它段
|
||||||
providers_out: dict[str, dict[str, Any]] = {}
|
providers_out: dict[str, dict[str, object]] = {}
|
||||||
for name, pconf in llm_config.providers.items():
|
for name, pconf in llm_config.providers.items():
|
||||||
entry: dict[str, Any] = {
|
entry: dict[str, object] = {
|
||||||
"type": pconf.type,
|
"type": pconf.type,
|
||||||
"base_url": pconf.base_url,
|
"base_url": pconf.base_url,
|
||||||
"models": pconf.models,
|
"models": pconf.models,
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -23,7 +22,7 @@ class ToolCall:
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
arguments: dict[str, Any]
|
arguments: dict[str, object]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -32,7 +31,7 @@ class LLMRequest:
|
||||||
|
|
||||||
messages: list[dict[str, str]]
|
messages: list[dict[str, str]]
|
||||||
model: str
|
model: str
|
||||||
tools: list[dict[str, Any]] | None = None
|
tools: list[dict[str, object]] | None = None
|
||||||
tool_choice: str = "auto"
|
tool_choice: str = "auto"
|
||||||
temperature: float = 0.7
|
temperature: float = 0.7
|
||||||
max_tokens: int = 2000
|
max_tokens: int = 2000
|
||||||
|
|
@ -42,13 +41,13 @@ class LLMRequest:
|
||||||
self,
|
self,
|
||||||
messages: list[dict[str, str]],
|
messages: list[dict[str, str]],
|
||||||
model: str,
|
model: str,
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list[dict[str, object]] | None = None,
|
||||||
tool_choice: str = "auto",
|
tool_choice: str = "auto",
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
max_tokens: int = 2000,
|
max_tokens: int = 2000,
|
||||||
timeout: float | None = None,
|
timeout: float | None = None,
|
||||||
cache: dict[str, Any] | None = None,
|
cache: dict[str, object] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: object,
|
||||||
):
|
):
|
||||||
self.messages = messages
|
self.messages = messages
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
@ -59,7 +58,7 @@ class LLMRequest:
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self._extra = kwargs
|
self._extra = kwargs
|
||||||
# U17 — LiteLLM cache 参数(cache_key 或 no-cache),透传到 litellm.acompletion
|
# U17 — LiteLLM cache 参数(cache_key 或 no-cache),透传到 litellm.acompletion
|
||||||
self._cache: dict[str, Any] | None = cache
|
self._cache: dict[str, object] | None = cache
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
|
@ -99,7 +98,9 @@ class AnthropicProvider(LLMProvider):
|
||||||
"content-type": "application/json",
|
"content-type": "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
def _convert_messages(self, messages: list[dict[str, str]]) -> tuple[str | list[dict[str, Any]] | None, list[dict[str, Any]]]:
|
def _convert_messages(
|
||||||
|
self, messages: list[dict[str, str]]
|
||||||
|
) -> tuple[str | list[dict[str, object]] | None, list[dict[str, object]]]:
|
||||||
"""将 OpenAI 风格消息转换为 Anthropic 格式
|
"""将 OpenAI 风格消息转换为 Anthropic 格式
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
@ -110,8 +111,8 @@ class AnthropicProvider(LLMProvider):
|
||||||
- list[dict]: Anthropic content blocks(支持 cache_control,U2/G2)
|
- list[dict]: Anthropic content blocks(支持 cache_control,U2/G2)
|
||||||
- None: 无 system 消息
|
- None: 无 system 消息
|
||||||
"""
|
"""
|
||||||
system_prompt: str | list[dict[str, Any]] | None = None
|
system_prompt: str | list[dict[str, object]] | None = None
|
||||||
anthropic_messages: list[dict[str, Any]] = []
|
anthropic_messages: list[dict[str, object]] = []
|
||||||
|
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
role = msg.get("role", "")
|
role = msg.get("role", "")
|
||||||
|
|
@ -127,7 +128,7 @@ class AnthropicProvider(LLMProvider):
|
||||||
# 检查是否有 tool_calls (OpenAI 格式)
|
# 检查是否有 tool_calls (OpenAI 格式)
|
||||||
tool_calls = msg.get("tool_calls")
|
tool_calls = msg.get("tool_calls")
|
||||||
if tool_calls:
|
if tool_calls:
|
||||||
blocks: list[dict[str, Any]] = []
|
blocks: list[dict[str, object]] = []
|
||||||
# 如果有文本内容,先添加文本块
|
# 如果有文本内容,先添加文本块
|
||||||
if content:
|
if content:
|
||||||
blocks.append({"type": "text", "text": content})
|
blocks.append({"type": "text", "text": content})
|
||||||
|
|
@ -139,25 +140,29 @@ class AnthropicProvider(LLMProvider):
|
||||||
arguments = json.loads(arguments)
|
arguments = json.loads(arguments)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
arguments = {"raw": arguments}
|
arguments = {"raw": arguments}
|
||||||
blocks.append({
|
blocks.append(
|
||||||
"type": "tool_use",
|
{
|
||||||
"id": tc.get("id", ""),
|
"type": "tool_use",
|
||||||
"name": func.get("name", ""),
|
"id": tc.get("id", ""),
|
||||||
"input": arguments,
|
"name": func.get("name", ""),
|
||||||
})
|
"input": arguments,
|
||||||
|
}
|
||||||
|
)
|
||||||
anthropic_messages.append({"role": "assistant", "content": blocks})
|
anthropic_messages.append({"role": "assistant", "content": blocks})
|
||||||
else:
|
else:
|
||||||
anthropic_messages.append({
|
anthropic_messages.append(
|
||||||
"role": "assistant",
|
{
|
||||||
"content": [{"type": "text", "text": content}],
|
"role": "assistant",
|
||||||
})
|
"content": [{"type": "text", "text": content}],
|
||||||
|
}
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if role == "user":
|
if role == "user":
|
||||||
# 检查是否是 tool_result 消息 (OpenAI 格式中 tool 角色的结果)
|
# 检查是否是 tool_result 消息 (OpenAI 格式中 tool 角色的结果)
|
||||||
# OpenAI 格式: {"role": "tool", "tool_call_id": "...", "content": "..."}
|
# OpenAI 格式: {"role": "tool", "tool_call_id": "...", "content": "..."}
|
||||||
if msg.get("tool_call_id"):
|
if msg.get("tool_call_id"):
|
||||||
tool_result_blocks: list[dict[str, Any]] = []
|
tool_result_blocks: list[dict[str, object]] = []
|
||||||
tool_content = msg.get("content", "")
|
tool_content = msg.get("content", "")
|
||||||
# tool_result 的 content 可以是字符串或内容块列表
|
# tool_result 的 content 可以是字符串或内容块列表
|
||||||
if isinstance(tool_content, str):
|
if isinstance(tool_content, str):
|
||||||
|
|
@ -167,56 +172,72 @@ class AnthropicProvider(LLMProvider):
|
||||||
else:
|
else:
|
||||||
tool_result_blocks.append({"type": "text", "text": str(tool_content)})
|
tool_result_blocks.append({"type": "text", "text": str(tool_content)})
|
||||||
|
|
||||||
anthropic_messages.append({
|
anthropic_messages.append(
|
||||||
"role": "user",
|
{
|
||||||
"content": [{
|
"role": "user",
|
||||||
"type": "tool_result",
|
"content": [
|
||||||
"tool_use_id": msg.get("tool_call_id", ""),
|
{
|
||||||
"content": tool_result_blocks,
|
"type": "tool_result",
|
||||||
}],
|
"tool_use_id": msg.get("tool_call_id", ""),
|
||||||
})
|
"content": tool_result_blocks,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
anthropic_messages.append({
|
anthropic_messages.append(
|
||||||
"role": "user",
|
{
|
||||||
"content": [{"type": "text", "text": content}],
|
"role": "user",
|
||||||
})
|
"content": [{"type": "text", "text": content}],
|
||||||
|
}
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if role == "tool":
|
if role == "tool":
|
||||||
# OpenAI 格式中独立的 tool 消息
|
# OpenAI 格式中独立的 tool 消息
|
||||||
tool_content = msg.get("content", "")
|
tool_content = msg.get("content", "")
|
||||||
if isinstance(tool_content, str):
|
if isinstance(tool_content, str):
|
||||||
result_content: list[dict[str, Any]] | str = [{"type": "text", "text": tool_content}]
|
result_content: list[dict[str, object]] | str = [
|
||||||
|
{"type": "text", "text": tool_content}
|
||||||
|
]
|
||||||
elif isinstance(tool_content, list):
|
elif isinstance(tool_content, list):
|
||||||
result_content = tool_content
|
result_content = tool_content
|
||||||
else:
|
else:
|
||||||
result_content = [{"type": "text", "text": str(tool_content)}]
|
result_content = [{"type": "text", "text": str(tool_content)}]
|
||||||
|
|
||||||
anthropic_messages.append({
|
anthropic_messages.append(
|
||||||
"role": "user",
|
{
|
||||||
"content": [{
|
"role": "user",
|
||||||
"type": "tool_result",
|
"content": [
|
||||||
"tool_use_id": msg.get("tool_call_id", ""),
|
{
|
||||||
"content": result_content,
|
"type": "tool_result",
|
||||||
}],
|
"tool_use_id": msg.get("tool_call_id", ""),
|
||||||
})
|
"content": result_content,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return system_prompt, anthropic_messages
|
return system_prompt, anthropic_messages
|
||||||
|
|
||||||
def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
def _convert_tools(self, tools: list[dict[str, object]]) -> list[dict[str, object]]:
|
||||||
"""将 OpenAI function 格式转换为 Anthropic tool 格式"""
|
"""将 OpenAI function 格式转换为 Anthropic tool 格式"""
|
||||||
anthropic_tools = []
|
anthropic_tools = []
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
if tool.get("type") == "function":
|
if tool.get("type") == "function":
|
||||||
func = tool.get("function", {})
|
func = tool.get("function", {})
|
||||||
anthropic_tools.append({
|
anthropic_tools.append(
|
||||||
"name": func.get("name", ""),
|
{
|
||||||
"description": func.get("description", ""),
|
"name": func.get("name", ""),
|
||||||
"input_schema": func.get("parameters", {"type": "object", "properties": {}}),
|
"description": func.get("description", ""),
|
||||||
})
|
"input_schema": func.get(
|
||||||
|
"parameters", {"type": "object", "properties": {}}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
return anthropic_tools
|
return anthropic_tools
|
||||||
|
|
||||||
def _convert_tool_choice(self, tool_choice: str) -> dict[str, Any] | None:
|
def _convert_tool_choice(self, tool_choice: str) -> dict[str, object] | None:
|
||||||
"""将 OpenAI tool_choice 格式转换为 Anthropic 格式"""
|
"""将 OpenAI tool_choice 格式转换为 Anthropic 格式"""
|
||||||
if tool_choice == "auto":
|
if tool_choice == "auto":
|
||||||
return {"type": "auto"}
|
return {"type": "auto"}
|
||||||
|
|
@ -227,7 +248,7 @@ class AnthropicProvider(LLMProvider):
|
||||||
return {"type": "tool", "name": tool_choice}
|
return {"type": "tool", "name": tool_choice}
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _parse_response(self, data: dict[str, Any], model: str) -> LLMResponse:
|
def _parse_response(self, data: dict[str, object], model: str) -> LLMResponse:
|
||||||
"""将 Anthropic 响应转换为 LLMResponse"""
|
"""将 Anthropic 响应转换为 LLMResponse"""
|
||||||
content_blocks = data.get("content", [])
|
content_blocks = data.get("content", [])
|
||||||
text_parts: list[str] = []
|
text_parts: list[str] = []
|
||||||
|
|
@ -238,11 +259,13 @@ class AnthropicProvider(LLMProvider):
|
||||||
if block_type == "text":
|
if block_type == "text":
|
||||||
text_parts.append(block.get("text", ""))
|
text_parts.append(block.get("text", ""))
|
||||||
elif block_type == "tool_use":
|
elif block_type == "tool_use":
|
||||||
tool_calls.append(ToolCall(
|
tool_calls.append(
|
||||||
id=block.get("id", ""),
|
ToolCall(
|
||||||
name=block.get("name", ""),
|
id=block.get("id", ""),
|
||||||
arguments=block.get("input", {}),
|
name=block.get("name", ""),
|
||||||
))
|
arguments=block.get("input", {}),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
usage_data = data.get("usage", {})
|
usage_data = data.get("usage", {})
|
||||||
usage = TokenUsage(
|
usage = TokenUsage(
|
||||||
|
|
@ -287,7 +310,7 @@ class AnthropicProvider(LLMProvider):
|
||||||
|
|
||||||
system_prompt, anthropic_messages = self._convert_messages(request.messages)
|
system_prompt, anthropic_messages = self._convert_messages(request.messages)
|
||||||
|
|
||||||
payload: dict[str, Any] = {
|
payload: dict[str, object] = {
|
||||||
"model": request.model,
|
"model": request.model,
|
||||||
"max_tokens": request.max_tokens or self._max_tokens,
|
"max_tokens": request.max_tokens or self._max_tokens,
|
||||||
"messages": anthropic_messages,
|
"messages": anthropic_messages,
|
||||||
|
|
@ -346,7 +369,7 @@ class AnthropicProvider(LLMProvider):
|
||||||
|
|
||||||
system_prompt, anthropic_messages = self._convert_messages(request.messages)
|
system_prompt, anthropic_messages = self._convert_messages(request.messages)
|
||||||
|
|
||||||
payload: dict[str, Any] = {
|
payload: dict[str, object] = {
|
||||||
"model": request.model,
|
"model": request.model,
|
||||||
"max_tokens": request.max_tokens or self._max_tokens,
|
"max_tokens": request.max_tokens or self._max_tokens,
|
||||||
"messages": anthropic_messages,
|
"messages": anthropic_messages,
|
||||||
|
|
@ -375,7 +398,7 @@ class AnthropicProvider(LLMProvider):
|
||||||
async def _iterate_stream(self, response, request: LLMRequest):
|
async def _iterate_stream(self, response, request: LLMRequest):
|
||||||
"""Iterate over an already-open SSE stream and yield StreamChunks."""
|
"""Iterate over an already-open SSE stream and yield StreamChunks."""
|
||||||
# Accumulated tool calls: tool_use_id -> {id, name, input_json_str}
|
# Accumulated tool calls: tool_use_id -> {id, name, input_json_str}
|
||||||
accumulated_tool_calls: dict[str, dict[str, Any]] = {}
|
accumulated_tool_calls: dict[str, dict[str, object]] = {}
|
||||||
current_tool_id: str | None = None
|
current_tool_id: str | None = None
|
||||||
current_tool_name: str | None = None
|
current_tool_name: str | None = None
|
||||||
current_tool_input_json: str = ""
|
current_tool_input_json: str = ""
|
||||||
|
|
@ -433,7 +456,9 @@ class AnthropicProvider(LLMProvider):
|
||||||
# Finalize current tool call if any
|
# Finalize current tool call if any
|
||||||
if current_tool_id is not None:
|
if current_tool_id is not None:
|
||||||
try:
|
try:
|
||||||
arguments = json.loads(current_tool_input_json) if current_tool_input_json else {}
|
arguments = (
|
||||||
|
json.loads(current_tool_input_json) if current_tool_input_json else {}
|
||||||
|
)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
arguments = {"raw": current_tool_input_json}
|
arguments = {"raw": current_tool_input_json}
|
||||||
|
|
||||||
|
|
@ -510,7 +535,7 @@ class AnthropicProvider(LLMProvider):
|
||||||
error_msg = error_info.get("message", "Stream error")
|
error_msg = error_info.get("message", "Stream error")
|
||||||
raise LLMProviderError("anthropic", error_msg)
|
raise LLMProviderError("anthropic", error_msg)
|
||||||
|
|
||||||
def get_model_info(self) -> dict[str, Any]:
|
def get_model_info(self) -> dict[str, object]:
|
||||||
"""返回 Provider 和模型信息"""
|
"""返回 Provider 和模型信息"""
|
||||||
return {
|
return {
|
||||||
"provider": "anthropic",
|
"provider": "anthropic",
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,6 @@ API:火山引擎 OpenAI 兼容接口
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||||||
|
|
||||||
|
|
@ -48,7 +47,7 @@ class DoubaoProvider(OpenAICompatibleProvider):
|
||||||
api_key: str,
|
api_key: str,
|
||||||
base_url: str = DOUBAO_DEFAULT_BASE_URL,
|
base_url: str = DOUBAO_DEFAULT_BASE_URL,
|
||||||
default_model: str = "doubao-pro-32k",
|
default_model: str = "doubao-pro-32k",
|
||||||
**kwargs: Any,
|
**kwargs: object,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
|
@ -90,14 +89,14 @@ class GeminiProvider(LLMProvider):
|
||||||
|
|
||||||
def _convert_messages(
|
def _convert_messages(
|
||||||
self, messages: list[dict[str, str]]
|
self, messages: list[dict[str, str]]
|
||||||
) -> tuple[dict[str, Any] | None, list[dict[str, Any]]]:
|
) -> tuple[dict[str, object] | None, list[dict[str, object]]]:
|
||||||
"""将 OpenAI 风格消息转换为 Gemini 格式
|
"""将 OpenAI 风格消息转换为 Gemini 格式
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(system_instruction, contents)
|
(system_instruction, contents)
|
||||||
"""
|
"""
|
||||||
system_instruction: dict[str, Any] | None = None
|
system_instruction: dict[str, object] | None = None
|
||||||
contents: list[dict[str, Any]] = []
|
contents: list[dict[str, object]] = []
|
||||||
|
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
role = msg.get("role", "")
|
role = msg.get("role", "")
|
||||||
|
|
@ -119,28 +118,34 @@ class GeminiProvider(LLMProvider):
|
||||||
tool_name = parsed.get("name", "")
|
tool_name = parsed.get("name", "")
|
||||||
except (json.JSONDecodeError, AttributeError):
|
except (json.JSONDecodeError, AttributeError):
|
||||||
pass
|
pass
|
||||||
contents.append({
|
contents.append(
|
||||||
"role": "user",
|
{
|
||||||
"parts": [{
|
"role": "user",
|
||||||
"functionResponse": {
|
"parts": [
|
||||||
"name": tool_name,
|
{
|
||||||
"response": {
|
"functionResponse": {
|
||||||
"content": content,
|
"name": tool_name,
|
||||||
},
|
"response": {
|
||||||
},
|
"content": content,
|
||||||
}],
|
},
|
||||||
})
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
contents.append({
|
contents.append(
|
||||||
"role": "user",
|
{
|
||||||
"parts": [{"text": content}],
|
"role": "user",
|
||||||
})
|
"parts": [{"text": content}],
|
||||||
|
}
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if role == "assistant":
|
if role == "assistant":
|
||||||
tool_calls = msg.get("tool_calls")
|
tool_calls = msg.get("tool_calls")
|
||||||
if tool_calls:
|
if tool_calls:
|
||||||
parts: list[dict[str, Any]] = []
|
parts: list[dict[str, object]] = []
|
||||||
if content:
|
if content:
|
||||||
parts.append({"text": content})
|
parts.append({"text": content})
|
||||||
for tc in tool_calls:
|
for tc in tool_calls:
|
||||||
|
|
@ -151,54 +156,64 @@ class GeminiProvider(LLMProvider):
|
||||||
arguments = json.loads(arguments)
|
arguments = json.loads(arguments)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
arguments = {"raw": arguments}
|
arguments = {"raw": arguments}
|
||||||
parts.append({
|
parts.append(
|
||||||
"functionCall": {
|
{
|
||||||
"name": func.get("name", ""),
|
"functionCall": {
|
||||||
"args": arguments,
|
"name": func.get("name", ""),
|
||||||
},
|
"args": arguments,
|
||||||
})
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
contents.append({"role": "model", "parts": parts})
|
contents.append({"role": "model", "parts": parts})
|
||||||
else:
|
else:
|
||||||
contents.append({
|
contents.append(
|
||||||
"role": "model",
|
{
|
||||||
"parts": [{"text": content}],
|
"role": "model",
|
||||||
})
|
"parts": [{"text": content}],
|
||||||
|
}
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if role == "tool":
|
if role == "tool":
|
||||||
# OpenAI format: {"role": "tool", "tool_call_id": "...", "content": "..."}
|
# OpenAI format: {"role": "tool", "tool_call_id": "...", "content": "..."}
|
||||||
tool_name = msg.get("name", "")
|
tool_name = msg.get("name", "")
|
||||||
tool_content = msg.get("content", "")
|
tool_content = msg.get("content", "")
|
||||||
contents.append({
|
contents.append(
|
||||||
"role": "user",
|
{
|
||||||
"parts": [{
|
"role": "user",
|
||||||
"functionResponse": {
|
"parts": [
|
||||||
"name": tool_name,
|
{
|
||||||
"response": {
|
"functionResponse": {
|
||||||
"content": tool_content,
|
"name": tool_name,
|
||||||
},
|
"response": {
|
||||||
},
|
"content": tool_content,
|
||||||
}],
|
},
|
||||||
})
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return system_instruction, contents
|
return system_instruction, contents
|
||||||
|
|
||||||
def _convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
def _convert_tools(self, tools: list[dict[str, object]]) -> list[dict[str, object]]:
|
||||||
"""将 OpenAI function 格式转换为 Gemini functionDeclarations"""
|
"""将 OpenAI function 格式转换为 Gemini functionDeclarations"""
|
||||||
declarations = []
|
declarations = []
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
if tool.get("type") == "function":
|
if tool.get("type") == "function":
|
||||||
func = tool.get("function", {})
|
func = tool.get("function", {})
|
||||||
declarations.append({
|
declarations.append(
|
||||||
"name": func.get("name", ""),
|
{
|
||||||
"description": func.get("description", ""),
|
"name": func.get("name", ""),
|
||||||
"parameters": func.get("parameters", {"type": "object", "properties": {}}),
|
"description": func.get("description", ""),
|
||||||
})
|
"parameters": func.get("parameters", {"type": "object", "properties": {}}),
|
||||||
|
}
|
||||||
|
)
|
||||||
if not declarations:
|
if not declarations:
|
||||||
return []
|
return []
|
||||||
return [{"functionDeclarations": declarations}]
|
return [{"functionDeclarations": declarations}]
|
||||||
|
|
||||||
def _convert_tool_choice(self, tool_choice: str) -> dict[str, Any] | None:
|
def _convert_tool_choice(self, tool_choice: str) -> dict[str, object] | None:
|
||||||
"""将 OpenAI tool_choice 格式转换为 Gemini toolConfig"""
|
"""将 OpenAI tool_choice 格式转换为 Gemini toolConfig"""
|
||||||
if tool_choice == "auto":
|
if tool_choice == "auto":
|
||||||
return {"functionCallingConfig": {"mode": "AUTO"}}
|
return {"functionCallingConfig": {"mode": "AUTO"}}
|
||||||
|
|
@ -210,7 +225,7 @@ class GeminiProvider(LLMProvider):
|
||||||
return {"functionCallingConfig": {"mode": "NONE"}}
|
return {"functionCallingConfig": {"mode": "NONE"}}
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _parse_response(self, data: dict[str, Any], model: str) -> LLMResponse:
|
def _parse_response(self, data: dict[str, object], model: str) -> LLMResponse:
|
||||||
"""将 Gemini 响应转换为 LLMResponse"""
|
"""将 Gemini 响应转换为 LLMResponse"""
|
||||||
candidates = data.get("candidates", [])
|
candidates = data.get("candidates", [])
|
||||||
text_parts: list[str] = []
|
text_parts: list[str] = []
|
||||||
|
|
@ -225,11 +240,13 @@ class GeminiProvider(LLMProvider):
|
||||||
text_parts.append(part["text"])
|
text_parts.append(part["text"])
|
||||||
elif "functionCall" in part:
|
elif "functionCall" in part:
|
||||||
fc = part["functionCall"]
|
fc = part["functionCall"]
|
||||||
tool_calls.append(ToolCall(
|
tool_calls.append(
|
||||||
id=f"call_{tool_call_index}",
|
ToolCall(
|
||||||
name=fc.get("name", ""),
|
id=f"call_{tool_call_index}",
|
||||||
arguments=fc.get("args", {}),
|
name=fc.get("name", ""),
|
||||||
))
|
arguments=fc.get("args", {}),
|
||||||
|
)
|
||||||
|
)
|
||||||
tool_call_index += 1
|
tool_call_index += 1
|
||||||
|
|
||||||
usage_metadata = data.get("usageMetadata", {})
|
usage_metadata = data.get("usageMetadata", {})
|
||||||
|
|
@ -275,7 +292,7 @@ class GeminiProvider(LLMProvider):
|
||||||
|
|
||||||
system_instruction, contents = self._convert_messages(request.messages)
|
system_instruction, contents = self._convert_messages(request.messages)
|
||||||
|
|
||||||
payload: dict[str, Any] = {
|
payload: dict[str, object] = {
|
||||||
"contents": contents,
|
"contents": contents,
|
||||||
"generationConfig": {
|
"generationConfig": {
|
||||||
"temperature": request.temperature,
|
"temperature": request.temperature,
|
||||||
|
|
@ -340,7 +357,7 @@ class GeminiProvider(LLMProvider):
|
||||||
|
|
||||||
system_instruction, contents = self._convert_messages(request.messages)
|
system_instruction, contents = self._convert_messages(request.messages)
|
||||||
|
|
||||||
payload: dict[str, Any] = {
|
payload: dict[str, object] = {
|
||||||
"contents": contents,
|
"contents": contents,
|
||||||
"generationConfig": {
|
"generationConfig": {
|
||||||
"temperature": request.temperature,
|
"temperature": request.temperature,
|
||||||
|
|
@ -374,7 +391,7 @@ class GeminiProvider(LLMProvider):
|
||||||
|
|
||||||
async def _iterate_stream(self, response, request: LLMRequest):
|
async def _iterate_stream(self, response, request: LLMRequest):
|
||||||
"""Iterate over an already-open SSE stream and yield StreamChunks."""
|
"""Iterate over an already-open SSE stream and yield StreamChunks."""
|
||||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
accumulated_tool_calls: list[dict[str, object]] = []
|
||||||
model = request.model or self._model
|
model = request.model or self._model
|
||||||
|
|
||||||
async for line in response.aiter_lines():
|
async for line in response.aiter_lines():
|
||||||
|
|
@ -436,11 +453,13 @@ class GeminiProvider(LLMProvider):
|
||||||
)
|
)
|
||||||
elif "functionCall" in part:
|
elif "functionCall" in part:
|
||||||
fc = part["functionCall"]
|
fc = part["functionCall"]
|
||||||
accumulated_tool_calls.append({
|
accumulated_tool_calls.append(
|
||||||
"id": f"call_{len(accumulated_tool_calls)}",
|
{
|
||||||
"name": fc.get("name", ""),
|
"id": f"call_{len(accumulated_tool_calls)}",
|
||||||
"arguments": fc.get("args", {}),
|
"name": fc.get("name", ""),
|
||||||
})
|
"arguments": fc.get("args", {}),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Check for finish reason
|
# Check for finish reason
|
||||||
finish_reason = candidates[0].get("finishReason", "")
|
finish_reason = candidates[0].get("finishReason", "")
|
||||||
|
|
@ -461,7 +480,7 @@ class GeminiProvider(LLMProvider):
|
||||||
)
|
)
|
||||||
accumulated_tool_calls = []
|
accumulated_tool_calls = []
|
||||||
|
|
||||||
def get_model_info(self) -> dict[str, Any]:
|
def get_model_info(self) -> dict[str, object]:
|
||||||
"""返回 Provider 和模型信息"""
|
"""返回 Provider 和模型信息"""
|
||||||
return {
|
return {
|
||||||
"provider": "gemini",
|
"provider": "gemini",
|
||||||
|
|
|
||||||
|
|
@ -26,8 +26,7 @@ import inspect
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator, Iterable
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from agentkit.core.exceptions import LLMProviderError
|
from agentkit.core.exceptions import LLMProviderError
|
||||||
from agentkit.llm.protocol import (
|
from agentkit.llm.protocol import (
|
||||||
|
|
@ -81,13 +80,13 @@ class LitellmProvider(LLMProvider):
|
||||||
api_key: str,
|
api_key: str,
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
provider_type: str = "openai",
|
provider_type: str = "openai",
|
||||||
**default_kwargs: Any,
|
**default_kwargs: object,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._model_prefix = model_prefix
|
self._model_prefix = model_prefix
|
||||||
self._api_key = api_key
|
self._api_key = api_key
|
||||||
self._base_url = base_url or None # 空字符串视作未设置
|
self._base_url = base_url or None # 空字符串视作未设置
|
||||||
self._provider_type = provider_type
|
self._provider_type = provider_type
|
||||||
self._default_kwargs: dict[str, Any] = dict(default_kwargs)
|
self._default_kwargs: dict[str, object] = dict(default_kwargs)
|
||||||
|
|
||||||
async def chat(self, request: LLMRequest) -> LLMResponse:
|
async def chat(self, request: LLMRequest) -> LLMResponse:
|
||||||
"""非流式 chat — 调用 ``litellm.acompletion`` 并翻译响应。"""
|
"""非流式 chat — 调用 ``litellm.acompletion`` 并翻译响应。"""
|
||||||
|
|
@ -116,7 +115,7 @@ class LitellmProvider(LLMProvider):
|
||||||
|
|
||||||
kwargs = self._build_kwargs(request, stream=True)
|
kwargs = self._build_kwargs(request, stream=True)
|
||||||
|
|
||||||
accumulated_tool_calls: dict[int, dict[str, Any]] = {}
|
accumulated_tool_calls: dict[int, dict[str, object]] = {}
|
||||||
final_usage: TokenUsage | None = None
|
final_usage: TokenUsage | None = None
|
||||||
final_model: str = request.model
|
final_model: str = request.model
|
||||||
|
|
||||||
|
|
@ -158,9 +157,9 @@ class LitellmProvider(LLMProvider):
|
||||||
# 内部辅助
|
# 内部辅助
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def _build_kwargs(self, request: LLMRequest, *, stream: bool) -> dict[str, Any]:
|
def _build_kwargs(self, request: LLMRequest, *, stream: bool) -> dict[str, object]:
|
||||||
"""从 LLMRequest 构造 litellm.acompletion kwargs。"""
|
"""从 LLMRequest 构造 litellm.acompletion kwargs。"""
|
||||||
kwargs: dict[str, Any] = {
|
kwargs: dict[str, object] = {
|
||||||
"model": f"{self._model_prefix}{request.model}",
|
"model": f"{self._model_prefix}{request.model}",
|
||||||
"messages": request.messages,
|
"messages": request.messages,
|
||||||
"temperature": request.temperature,
|
"temperature": request.temperature,
|
||||||
|
|
@ -187,7 +186,7 @@ class LitellmProvider(LLMProvider):
|
||||||
|
|
||||||
def _parse_response(
|
def _parse_response(
|
||||||
self,
|
self,
|
||||||
response: Any,
|
response: object,
|
||||||
request_model: str,
|
request_model: str,
|
||||||
latency_ms: float,
|
latency_ms: float,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
|
|
@ -229,9 +228,9 @@ class LitellmProvider(LLMProvider):
|
||||||
|
|
||||||
def _parse_stream_chunk(
|
def _parse_stream_chunk(
|
||||||
self,
|
self,
|
||||||
chunk: Any,
|
chunk: object,
|
||||||
request_model: str,
|
request_model: str,
|
||||||
accumulated_tool_calls: dict[int, dict[str, Any]],
|
accumulated_tool_calls: dict[int, dict[str, object]],
|
||||||
) -> StreamChunk:
|
) -> StreamChunk:
|
||||||
"""解析单个流式 chunk(非 final)。累计 tool_calls 到传入字典。"""
|
"""解析单个流式 chunk(非 final)。累计 tool_calls 到传入字典。"""
|
||||||
choices = getattr(chunk, "choices", None) or []
|
choices = getattr(chunk, "choices", None) or []
|
||||||
|
|
@ -262,7 +261,7 @@ class LitellmProvider(LLMProvider):
|
||||||
|
|
||||||
def _finalize_tool_calls(
|
def _finalize_tool_calls(
|
||||||
self,
|
self,
|
||||||
accumulated: dict[int, dict[str, Any]],
|
accumulated: dict[int, dict[str, object]],
|
||||||
) -> list[ToolCall]:
|
) -> list[ToolCall]:
|
||||||
"""把累计的流式 tool_calls 字典转成 ToolCall 列表。"""
|
"""把累计的流式 tool_calls 字典转成 ToolCall 列表。"""
|
||||||
tool_calls: list[ToolCall] = []
|
tool_calls: list[ToolCall] = []
|
||||||
|
|
@ -288,7 +287,7 @@ class LitellmProvider(LLMProvider):
|
||||||
# ----------------------------------------------------------------------
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _parse_tool_calls(raw_tool_calls: Any) -> list[ToolCall]:
|
def _parse_tool_calls(raw_tool_calls: Iterable[object]) -> list[ToolCall]:
|
||||||
"""解析非流式响应的 tool_calls(OpenAI 格式 list[ChoiceMessageToolCall])。"""
|
"""解析非流式响应的 tool_calls(OpenAI 格式 list[ChoiceMessageToolCall])。"""
|
||||||
result: list[ToolCall] = []
|
result: list[ToolCall] = []
|
||||||
for tc in raw_tool_calls:
|
for tc in raw_tool_calls:
|
||||||
|
|
@ -312,7 +311,7 @@ def _parse_tool_calls(raw_tool_calls: Any) -> list[ToolCall]:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _parse_usage(usage_obj: Any) -> TokenUsage:
|
def _parse_usage(usage_obj: object) -> TokenUsage:
|
||||||
"""解析 usage 对象(OpenAI CompletionUsage 或 dict)。"""
|
"""解析 usage 对象(OpenAI CompletionUsage 或 dict)。"""
|
||||||
prompt = getattr(usage_obj, "prompt_tokens", None)
|
prompt = getattr(usage_obj, "prompt_tokens", None)
|
||||||
completion = getattr(usage_obj, "completion_tokens", None)
|
completion = getattr(usage_obj, "completion_tokens", None)
|
||||||
|
|
@ -326,8 +325,8 @@ def _parse_usage(usage_obj: Any) -> TokenUsage:
|
||||||
|
|
||||||
|
|
||||||
def _accumulate_stream_tool_calls(
|
def _accumulate_stream_tool_calls(
|
||||||
raw_tool_calls: Any,
|
raw_tool_calls: Iterable[object],
|
||||||
accumulated: dict[int, dict[str, Any]],
|
accumulated: dict[int, dict[str, object]],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""累计流式 chunk 里的 tool_calls 片段(OpenAI delta.tool_calls 格式)。
|
"""累计流式 chunk 里的 tool_calls 片段(OpenAI delta.tool_calls 格式)。
|
||||||
|
|
||||||
|
|
@ -364,7 +363,7 @@ def create_litellm_provider(
|
||||||
provider_type: str,
|
provider_type: str,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: object,
|
||||||
) -> LitellmProvider:
|
) -> LitellmProvider:
|
||||||
"""根据 provider_type 创建 LitellmProvider 实例。
|
"""根据 provider_type 创建 LitellmProvider 实例。
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any, Protocol, runtime_checkable
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
from agentkit.llm.protocol import TokenUsage
|
from agentkit.llm.protocol import TokenUsage
|
||||||
|
|
||||||
|
|
@ -294,8 +294,8 @@ class RedisUsageStore:
|
||||||
|
|
||||||
def __init__(self, redis_url: str = "redis://localhost:6379"):
|
def __init__(self, redis_url: str = "redis://localhost:6379"):
|
||||||
self._redis_url = redis_url
|
self._redis_url = redis_url
|
||||||
self._redis: Any = None
|
self._redis: object | None = None
|
||||||
self._sync_redis: Any = None
|
self._sync_redis: object | None = None
|
||||||
self._fallback: InMemoryUsageStore | None = None
|
self._fallback: InMemoryUsageStore | None = None
|
||||||
self._degraded = False
|
self._degraded = False
|
||||||
self._health_check_task: asyncio.Task[None] | None = None
|
self._health_check_task: asyncio.Task[None] | None = None
|
||||||
|
|
@ -687,7 +687,7 @@ class RedisUsageStore:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _read_list(
|
def _read_list(
|
||||||
r: Any,
|
r: object,
|
||||||
list_key: str,
|
list_key: str,
|
||||||
start: datetime,
|
start: datetime,
|
||||||
end: datetime,
|
end: datetime,
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,6 @@ from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||||||
from agentkit.llm.protocol import LLMRequest, LLMResponse
|
from agentkit.llm.protocol import LLMRequest, LLMResponse
|
||||||
|
|
@ -51,7 +50,7 @@ class WenxinProvider(OpenAICompatibleProvider):
|
||||||
secret_key: str | None = None,
|
secret_key: str | None = None,
|
||||||
base_url: str = WENXIN_DEFAULT_BASE_URL,
|
base_url: str = WENXIN_DEFAULT_BASE_URL,
|
||||||
default_model: str = "ernie-4.5-turbo-128k",
|
default_model: str = "ernie-4.5-turbo-128k",
|
||||||
**kwargs: Any,
|
**kwargs: object,
|
||||||
):
|
):
|
||||||
# If AK/SK provided, use token-based auth
|
# If AK/SK provided, use token-based auth
|
||||||
self._access_key = access_key
|
self._access_key = access_key
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,6 @@ API:腾讯云 OpenAI 兼容接口
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||||||
from agentkit.llm.protocol import LLMRequest, LLMResponse
|
from agentkit.llm.protocol import LLMRequest, LLMResponse
|
||||||
|
|
@ -48,7 +47,7 @@ class YuanbaoProvider(OpenAICompatibleProvider):
|
||||||
base_url: str = YUANBAO_DEFAULT_BASE_URL,
|
base_url: str = YUANBAO_DEFAULT_BASE_URL,
|
||||||
default_model: str = "hunyuan-turbos-latest",
|
default_model: str = "hunyuan-turbos-latest",
|
||||||
enable_enhancement: bool = False,
|
enable_enhancement: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: object,
|
||||||
):
|
):
|
||||||
self._enable_enhancement = enable_enhancement
|
self._enable_enhancement = enable_enhancement
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import AsyncIterator, Awaitable, Callable
|
from collections.abc import AsyncIterator, Awaitable, Callable
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
|
@ -66,7 +65,7 @@ class RemoteLLMProvider(LLMProvider):
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
def _build_payload(self, request: LLMRequest) -> dict[str, Any]:
|
def _build_payload(self, request: LLMRequest) -> dict[str, object]:
|
||||||
"""Convert LLMRequest to server API payload."""
|
"""Convert LLMRequest to server API payload."""
|
||||||
return {
|
return {
|
||||||
"messages": request.messages,
|
"messages": request.messages,
|
||||||
|
|
@ -91,7 +90,7 @@ class RemoteLLMProvider(LLMProvider):
|
||||||
return str(body["error"])
|
return str(body["error"])
|
||||||
return str(body)
|
return str(body)
|
||||||
|
|
||||||
def _parse_response(self, data: dict[str, Any], request: LLMRequest) -> LLMResponse:
|
def _parse_response(self, data: dict[str, object], request: LLMRequest) -> LLMResponse:
|
||||||
"""Parse server response JSON into an LLMResponse."""
|
"""Parse server response JSON into an LLMResponse."""
|
||||||
usage_data = data.get("usage") or {}
|
usage_data = data.get("usage") or {}
|
||||||
usage = TokenUsage(
|
usage = TokenUsage(
|
||||||
|
|
@ -115,7 +114,7 @@ class RemoteLLMProvider(LLMProvider):
|
||||||
latency_ms=data.get("latency_ms", 0.0),
|
latency_ms=data.get("latency_ms", 0.0),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_chunk(self, data: dict[str, Any], request: LLMRequest) -> StreamChunk:
|
def _parse_chunk(self, data: dict[str, object], request: LLMRequest) -> StreamChunk:
|
||||||
"""Parse a single SSE data payload into a StreamChunk."""
|
"""Parse a single SSE data payload into a StreamChunk."""
|
||||||
usage: TokenUsage | None = None
|
usage: TokenUsage | None = None
|
||||||
usage_data = data.get("usage")
|
usage_data = data.get("usage")
|
||||||
|
|
@ -218,9 +217,7 @@ class RemoteLLMProvider(LLMProvider):
|
||||||
if response.status_code == 502:
|
if response.status_code == 502:
|
||||||
await response.aread()
|
await response.aread()
|
||||||
detail = self._extract_error_detail(response)
|
detail = self._extract_error_detail(response)
|
||||||
raise LLMProviderError(
|
raise LLMProviderError("remote", f"Server LLM gateway error: {detail}")
|
||||||
"remote", f"Server LLM gateway error: {detail}"
|
|
||||||
)
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
await response.aread()
|
await response.aread()
|
||||||
raise LLMProviderError(
|
raise LLMProviderError(
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import logging
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Callable
|
from typing import Callable
|
||||||
|
|
||||||
from agentkit.core.exceptions import LLMProviderError
|
from agentkit.core.exceptions import LLMProviderError
|
||||||
|
|
||||||
|
|
@ -20,9 +20,7 @@ class RetryConfig:
|
||||||
base_delay: float = 1.0
|
base_delay: float = 1.0
|
||||||
max_delay: float = 30.0
|
max_delay: float = 30.0
|
||||||
exponential_base: float = 2.0
|
exponential_base: float = 2.0
|
||||||
retryable_status_codes: set[int] = field(
|
retryable_status_codes: set[int] = field(default_factory=lambda: {429, 500, 502, 503, 529})
|
||||||
default_factory=lambda: {429, 500, 502, 503, 529}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CircuitState(Enum):
|
class CircuitState(Enum):
|
||||||
|
|
@ -69,7 +67,7 @@ class RetryPolicy:
|
||||||
def __init__(self, config: RetryConfig | None = None):
|
def __init__(self, config: RetryConfig | None = None):
|
||||||
self._config = config or RetryConfig()
|
self._config = config or RetryConfig()
|
||||||
|
|
||||||
async def execute(self, fn: Callable, *args: Any, **kwargs: Any) -> Any:
|
async def execute(self, fn: Callable, *args: object, **kwargs: object) -> object:
|
||||||
"""Execute fn with retry on retryable errors."""
|
"""Execute fn with retry on retryable errors."""
|
||||||
last_error: Exception | None = None
|
last_error: Exception | None = None
|
||||||
|
|
||||||
|
|
@ -84,7 +82,7 @@ class RetryPolicy:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
delay = min(
|
delay = min(
|
||||||
self._config.base_delay * (self._config.exponential_base ** attempt),
|
self._config.base_delay * (self._config.exponential_base**attempt),
|
||||||
self._config.max_delay,
|
self._config.max_delay,
|
||||||
)
|
)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
@ -142,7 +140,7 @@ class CircuitBreaker:
|
||||||
f"after {self._failure_count} failures"
|
f"after {self._failure_count} failures"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def execute(self, fn: Callable, *args: Any, **kwargs: Any) -> Any:
|
async def execute(self, fn: Callable, *args: object, **kwargs: object) -> object:
|
||||||
"""Execute fn through the circuit breaker."""
|
"""Execute fn through the circuit breaker."""
|
||||||
current_state = self.state
|
current_state = self.state
|
||||||
|
|
||||||
|
|
@ -158,6 +156,6 @@ class CircuitBreaker:
|
||||||
result = await fn(*args, **kwargs)
|
result = await fn(*args, **kwargs)
|
||||||
self._on_success()
|
self._on_success()
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception:
|
||||||
self._on_failure()
|
self._on_failure()
|
||||||
raise
|
raise
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,6 @@ from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
|
@ -95,8 +94,7 @@ class KBAdapter(ABC):
|
||||||
async def delete_by_id(self, id: str) -> bool:
|
async def delete_by_id(self, id: str) -> bool:
|
||||||
"""按文档 ID 删除(子类可覆盖)"""
|
"""按文档 ID 删除(子类可覆盖)"""
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"{self.__class__.__name__} does not support delete_by_id; "
|
f"{self.__class__.__name__} does not support delete_by_id; id '{id}' skipped"
|
||||||
f"id '{id}' skipped"
|
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -127,8 +125,7 @@ class KBAdapter(ABC):
|
||||||
async def get_document(self, doc_id: str) -> Document | None:
|
async def get_document(self, doc_id: str) -> Document | None:
|
||||||
"""按 ID 获取单个文档(子类可覆盖)"""
|
"""按 ID 获取单个文档(子类可覆盖)"""
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"{self.__class__.__name__} does not support get_document; "
|
f"{self.__class__.__name__} does not support get_document; doc_id '{doc_id}' not found"
|
||||||
f"doc_id '{doc_id}' not found"
|
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -156,5 +153,5 @@ class KBAdapter(ABC):
|
||||||
async def __aenter__(self) -> KBAdapter:
|
async def __aenter__(self) -> KBAdapter:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, *args: Any) -> None:
|
async def __aexit__(self, *args: object) -> None:
|
||||||
await self.close()
|
await self.close()
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
|
@ -56,7 +55,9 @@ class ConfluenceAdapter(KBAdapter):
|
||||||
)
|
)
|
||||||
self._base_url = base_url.rstrip("/")
|
self._base_url = base_url.rstrip("/")
|
||||||
if not is_safe_url(self._base_url):
|
if not is_safe_url(self._base_url):
|
||||||
raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.")
|
raise ValueError(
|
||||||
|
f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed."
|
||||||
|
)
|
||||||
self._username = username
|
self._username = username
|
||||||
self._api_token = api_token
|
self._api_token = api_token
|
||||||
self._space_keys = space_keys or []
|
self._space_keys = space_keys or []
|
||||||
|
|
@ -65,9 +66,7 @@ class ConfluenceAdapter(KBAdapter):
|
||||||
"""创建 Confluence API HTTP 客户端"""
|
"""创建 Confluence API HTTP 客户端"""
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
credentials = base64.b64encode(
|
credentials = base64.b64encode(f"{self._username}:{self._api_token}".encode()).decode()
|
||||||
f"{self._username}:{self._api_token}".encode()
|
|
||||||
).decode()
|
|
||||||
return httpx.AsyncClient(
|
return httpx.AsyncClient(
|
||||||
base_url=self._base_url,
|
base_url=self._base_url,
|
||||||
headers={
|
headers={
|
||||||
|
|
@ -101,7 +100,7 @@ class ConfluenceAdapter(KBAdapter):
|
||||||
space_filter = " OR ".join(
|
space_filter = " OR ".join(
|
||||||
f'space = "{_escape_cql(key)}"' for key in self._space_keys
|
f'space = "{_escape_cql(key)}"' for key in self._space_keys
|
||||||
)
|
)
|
||||||
cql = f'{cql} AND ({space_filter})'
|
cql = f"{cql} AND ({space_filter})"
|
||||||
|
|
||||||
resp = await client.get(
|
resp = await client.get(
|
||||||
"/rest/api/content/search",
|
"/rest/api/content/search",
|
||||||
|
|
@ -115,6 +114,7 @@ class ConfluenceAdapter(KBAdapter):
|
||||||
body = page.get("body", {}).get("storage", {}).get("value", "")
|
body = page.get("body", {}).get("storage", {}).get("value", "")
|
||||||
# Strip HTML tags for plain text content
|
# Strip HTML tags for plain text content
|
||||||
import re
|
import re
|
||||||
|
|
||||||
content = re.sub(r"<[^>]+>", "", body) if body else page.get("title", "")
|
content = re.sub(r"<[^>]+>", "", body) if body else page.get("title", "")
|
||||||
|
|
||||||
results.append(
|
results.append(
|
||||||
|
|
@ -136,8 +136,7 @@ class ConfluenceAdapter(KBAdapter):
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Confluence search HTTP error: {e.response.status_code} — "
|
f"Confluence search HTTP error: {e.response.status_code} — {e.response.text[:200]}"
|
||||||
f"{e.response.text[:200]}"
|
|
||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -157,6 +156,7 @@ class ConfluenceAdapter(KBAdapter):
|
||||||
|
|
||||||
body = page.get("body", {}).get("storage", {}).get("value", "")
|
body = page.get("body", {}).get("storage", {}).get("value", "")
|
||||||
import re
|
import re
|
||||||
|
|
||||||
content = re.sub(r"<[^>]+>", "", body) if body else ""
|
content = re.sub(r"<[^>]+>", "", body) if body else ""
|
||||||
|
|
||||||
return Document(
|
return Document(
|
||||||
|
|
@ -191,13 +191,17 @@ class ConfluenceAdapter(KBAdapter):
|
||||||
source_type="confluence",
|
source_type="confluence",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return sources if sources else [
|
return (
|
||||||
SourceInfo(
|
sources
|
||||||
source_id=self._source_id,
|
if sources
|
||||||
source_name=self._source_name,
|
else [
|
||||||
source_type=self._source_type,
|
SourceInfo(
|
||||||
)
|
source_id=self._source_id,
|
||||||
]
|
source_name=self._source_name,
|
||||||
|
source_type=self._source_type,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Confluence list_sources error: {e}")
|
logger.error(f"Confluence list_sources error: {e}")
|
||||||
return [
|
return [
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import TypeAlias
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
|
@ -18,6 +18,9 @@ from agentkit.utils.security import is_safe_url
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 飞书搜索请求 payload:search_key/page_size/wiki_space_ids — 值为 str|int|list[str]。
|
||||||
|
FeishuSearchPayload: TypeAlias = dict[str, object]
|
||||||
|
|
||||||
|
|
||||||
class FeishuKBAdapter(KBAdapter):
|
class FeishuKBAdapter(KBAdapter):
|
||||||
"""飞书知识库适配器
|
"""飞书知识库适配器
|
||||||
|
|
@ -54,7 +57,9 @@ class FeishuKBAdapter(KBAdapter):
|
||||||
self._app_secret = app_secret
|
self._app_secret = app_secret
|
||||||
self._base_url = base_url.rstrip("/")
|
self._base_url = base_url.rstrip("/")
|
||||||
if not is_safe_url(self._base_url):
|
if not is_safe_url(self._base_url):
|
||||||
raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.")
|
raise ValueError(
|
||||||
|
f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed."
|
||||||
|
)
|
||||||
self._space_ids = space_ids or []
|
self._space_ids = space_ids or []
|
||||||
self._access_token: str | None = None
|
self._access_token: str | None = None
|
||||||
self._token_expiry: float = 0.0
|
self._token_expiry: float = 0.0
|
||||||
|
|
@ -94,10 +99,7 @@ class FeishuKBAdapter(KBAdapter):
|
||||||
self._client = None
|
self._client = None
|
||||||
return self._access_token
|
return self._access_token
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(f"Feishu auth failed: code={data.get('code')}, msg={data.get('msg')}")
|
||||||
f"Feishu auth failed: code={data.get('code')}, "
|
|
||||||
f"msg={data.get('msg')}"
|
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Feishu auth error: {e}")
|
logger.error(f"Feishu auth error: {e}")
|
||||||
|
|
@ -121,7 +123,7 @@ class FeishuKBAdapter(KBAdapter):
|
||||||
|
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
try:
|
try:
|
||||||
payload: dict[str, Any] = {
|
payload: FeishuSearchPayload = {
|
||||||
"search_key": query,
|
"search_key": query,
|
||||||
"page_size": top_k,
|
"page_size": top_k,
|
||||||
}
|
}
|
||||||
|
|
@ -137,8 +139,7 @@ class FeishuKBAdapter(KBAdapter):
|
||||||
|
|
||||||
if data.get("code") != 0:
|
if data.get("code") != 0:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Feishu search failed: code={data.get('code')}, "
|
f"Feishu search failed: code={data.get('code')}, msg={data.get('msg')}"
|
||||||
f"msg={data.get('msg')}"
|
|
||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -162,8 +163,7 @@ class FeishuKBAdapter(KBAdapter):
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Feishu search HTTP error: {e.response.status_code} — "
|
f"Feishu search HTTP error: {e.response.status_code} — {e.response.text[:200]}"
|
||||||
f"{e.response.text[:200]}"
|
|
||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -179,7 +179,7 @@ class FeishuKBAdapter(KBAdapter):
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
try:
|
try:
|
||||||
resp = await client.get(
|
resp = await client.get(
|
||||||
f"/wiki/v2/spaces/get_node",
|
"/wiki/v2/spaces/get_node",
|
||||||
params={"token": doc_id},
|
params={"token": doc_id},
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
|
|
@ -230,13 +230,17 @@ class FeishuKBAdapter(KBAdapter):
|
||||||
source_type="feishu",
|
source_type="feishu",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return sources if sources else [
|
return (
|
||||||
SourceInfo(
|
sources
|
||||||
source_id=self._source_id,
|
if sources
|
||||||
source_name=self._source_name,
|
else [
|
||||||
source_type=self._source_type,
|
SourceInfo(
|
||||||
)
|
source_id=self._source_id,
|
||||||
]
|
source_name=self._source_name,
|
||||||
|
source_type=self._source_type,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Feishu list_sources error: {e}")
|
logger.error(f"Feishu list_sources error: {e}")
|
||||||
return [
|
return [
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
|
@ -55,7 +54,9 @@ class GenericHTTPAdapter(KBAdapter):
|
||||||
)
|
)
|
||||||
self._endpoint_url = endpoint_url.rstrip("/")
|
self._endpoint_url = endpoint_url.rstrip("/")
|
||||||
if not is_safe_url(self._endpoint_url):
|
if not is_safe_url(self._endpoint_url):
|
||||||
raise ValueError(f"Unsafe endpoint_url: {self._endpoint_url}. Private/internal URLs are not allowed.")
|
raise ValueError(
|
||||||
|
f"Unsafe endpoint_url: {self._endpoint_url}. Private/internal URLs are not allowed."
|
||||||
|
)
|
||||||
self._auth_config = auth_config or {}
|
self._auth_config = auth_config or {}
|
||||||
self._extra_headers = headers or {}
|
self._extra_headers = headers or {}
|
||||||
|
|
||||||
|
|
@ -74,12 +75,11 @@ class GenericHTTPAdapter(KBAdapter):
|
||||||
headers["Authorization"] = f"Bearer {token}"
|
headers["Authorization"] = f"Bearer {token}"
|
||||||
elif auth_type == "basic":
|
elif auth_type == "basic":
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
username = self._auth_config.get("username", "")
|
username = self._auth_config.get("username", "")
|
||||||
password = self._auth_config.get("password", "")
|
password = self._auth_config.get("password", "")
|
||||||
if username and password:
|
if username and password:
|
||||||
credentials = base64.b64encode(
|
credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
|
||||||
f"{username}:{password}".encode()
|
|
||||||
).decode()
|
|
||||||
headers["Authorization"] = f"Basic {credentials}"
|
headers["Authorization"] = f"Basic {credentials}"
|
||||||
elif auth_type == "api_key":
|
elif auth_type == "api_key":
|
||||||
key_name = self._auth_config.get("header_name", "X-API-Key")
|
key_name = self._auth_config.get("header_name", "X-API-Key")
|
||||||
|
|
@ -135,8 +135,7 @@ class GenericHTTPAdapter(KBAdapter):
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"GenericHTTP search HTTP error: {e.response.status_code} — "
|
f"GenericHTTP search HTTP error: {e.response.status_code} — {e.response.text[:200]}"
|
||||||
f"{e.response.text[:200]}"
|
|
||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -177,8 +176,7 @@ class GenericHTTPAdapter(KBAdapter):
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"GenericHTTP ingest HTTP error: {e.response.status_code} — "
|
f"GenericHTTP ingest HTTP error: {e.response.status_code} — {e.response.text[:200]}"
|
||||||
f"{e.response.text[:200]}"
|
|
||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -245,13 +243,17 @@ class GenericHTTPAdapter(KBAdapter):
|
||||||
document_count=item.get("document_count", 0),
|
document_count=item.get("document_count", 0),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return sources if sources else [
|
return (
|
||||||
SourceInfo(
|
sources
|
||||||
source_id=self._source_id,
|
if sources
|
||||||
source_name=self._source_name,
|
else [
|
||||||
source_type=self._source_type,
|
SourceInfo(
|
||||||
)
|
source_id=self._source_id,
|
||||||
]
|
source_name=self._source_name,
|
||||||
|
source_type=self._source_type,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"GenericHTTP list_sources error (endpoint may not exist): {e}")
|
logger.debug(f"GenericHTTP list_sources error (endpoint may not exist): {e}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,19 +3,29 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import TypeAlias
|
||||||
|
|
||||||
|
# 共享类型别名 — 跨 memory 子系统复用,避免 `Any` 残留。
|
||||||
|
# MetadataValue 覆盖 metadata dict 中实际出现的原始类型;
|
||||||
|
# MemoryValue 额外允许 dict/list 容器以容纳结构化负载(如 episodic 经验字典)。
|
||||||
|
MetadataValue: TypeAlias = str | int | float | bool | None
|
||||||
|
MetadataDict: TypeAlias = dict[str, MetadataValue]
|
||||||
|
MemoryValue: TypeAlias = (
|
||||||
|
str | int | float | bool | None | dict[str, MetadataValue] | list[MetadataValue]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MemoryItem:
|
class MemoryItem:
|
||||||
"""记忆条目"""
|
"""记忆条目"""
|
||||||
|
|
||||||
key: str
|
key: str
|
||||||
value: Any
|
value: object
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: MetadataDict = field(default_factory=dict)
|
||||||
score: float = 1.0
|
score: float = 1.0
|
||||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict[str, object]:
|
||||||
return {
|
return {
|
||||||
"key": self.key,
|
"key": self.key,
|
||||||
"value": self.value,
|
"value": self.value,
|
||||||
|
|
@ -35,7 +45,7 @@ class Memory(ABC):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
|
async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
|
||||||
"""存储记忆"""
|
"""存储记忆"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
@ -45,7 +55,9 @@ class Memory(ABC):
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]:
|
async def search(
|
||||||
|
self, query: str, top_k: int = 5, filters: MetadataDict | None = None
|
||||||
|
) -> list[MemoryItem]:
|
||||||
"""语义检索"""
|
"""语义检索"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
@ -54,7 +66,7 @@ class Memory(ABC):
|
||||||
"""删除记忆"""
|
"""删除记忆"""
|
||||||
...
|
...
|
||||||
|
|
||||||
async def store_batch(self, items: list[tuple[str, Any, dict | None]]) -> None:
|
async def store_batch(self, items: list[tuple[str, object, MetadataDict | None]]) -> None:
|
||||||
"""批量存储"""
|
"""批量存储"""
|
||||||
for key, value, metadata in items:
|
for key, value, metadata in items:
|
||||||
await self.store(key, value, metadata)
|
await self.store(key, value, metadata)
|
||||||
|
|
|
||||||
|
|
@ -11,10 +11,18 @@ import logging
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import TypeAlias
|
||||||
|
|
||||||
|
from agentkit.memory.base import MetadataDict
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 分块元数据:source_doc/position/char_count/chunking_strategy/heading/heading_level
|
||||||
|
# — 全部为原始标量(str/int)。
|
||||||
|
ChunkMetadata: TypeAlias = MetadataDict
|
||||||
|
# _split_by_headings 返回的节段结构。
|
||||||
|
SectionInfo: TypeAlias = dict[str, str | int]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Chunk:
|
class Chunk:
|
||||||
|
|
@ -22,7 +30,7 @@ class Chunk:
|
||||||
|
|
||||||
chunk_id: str
|
chunk_id: str
|
||||||
content: str
|
content: str
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: ChunkMetadata = field(default_factory=dict)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if "source_doc" not in self.metadata:
|
if "source_doc" not in self.metadata:
|
||||||
|
|
@ -30,7 +38,7 @@ class Chunk:
|
||||||
if "position" not in self.metadata:
|
if "position" not in self.metadata:
|
||||||
self.metadata["position"] = 0
|
self.metadata["position"] = 0
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, object]:
|
||||||
return {
|
return {
|
||||||
"chunk_id": self.chunk_id,
|
"chunk_id": self.chunk_id,
|
||||||
"content": self.content,
|
"content": self.content,
|
||||||
|
|
@ -57,7 +65,9 @@ class TextChunker:
|
||||||
separator: 优先分割符
|
separator: 优先分割符
|
||||||
"""
|
"""
|
||||||
if chunk_overlap >= chunk_size:
|
if chunk_overlap >= chunk_size:
|
||||||
raise ValueError(f"chunk_overlap ({chunk_overlap}) must be less than chunk_size ({chunk_size})")
|
raise ValueError(
|
||||||
|
f"chunk_overlap ({chunk_overlap}) must be less than chunk_size ({chunk_size})"
|
||||||
|
)
|
||||||
self._chunk_size = chunk_size
|
self._chunk_size = chunk_size
|
||||||
self._chunk_overlap = chunk_overlap
|
self._chunk_overlap = chunk_overlap
|
||||||
self._separator = separator
|
self._separator = separator
|
||||||
|
|
@ -66,7 +76,7 @@ class TextChunker:
|
||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
source_doc_id: str = "",
|
source_doc_id: str = "",
|
||||||
metadata: dict[str, Any] | None = None,
|
metadata: ChunkMetadata | None = None,
|
||||||
) -> list[Chunk]:
|
) -> list[Chunk]:
|
||||||
"""将文本分块
|
"""将文本分块
|
||||||
|
|
||||||
|
|
@ -96,11 +106,13 @@ class TextChunker:
|
||||||
chunk_meta = dict(base_meta)
|
chunk_meta = dict(base_meta)
|
||||||
chunk_meta["position"] = i
|
chunk_meta["position"] = i
|
||||||
chunk_meta["char_count"] = len(chunk_text)
|
chunk_meta["char_count"] = len(chunk_text)
|
||||||
chunks.append(Chunk(
|
chunks.append(
|
||||||
chunk_id=str(uuid.uuid4()),
|
Chunk(
|
||||||
content=chunk_text,
|
chunk_id=str(uuid.uuid4()),
|
||||||
metadata=chunk_meta,
|
content=chunk_text,
|
||||||
))
|
metadata=chunk_meta,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
|
@ -142,7 +154,9 @@ class TextChunker:
|
||||||
overlap_text[overlap_start:], segments
|
overlap_text[overlap_start:], segments
|
||||||
)
|
)
|
||||||
current = overlap_segments
|
current = overlap_segments
|
||||||
current_len = sum(len(s) for s in current) + len(self._separator) * max(0, len(current) - 1)
|
current_len = sum(len(s) for s in current) + len(self._separator) * max(
|
||||||
|
0, len(current) - 1
|
||||||
|
)
|
||||||
|
|
||||||
current.append(segment)
|
current.append(segment)
|
||||||
current_len += seg_len + len(self._separator)
|
current_len += seg_len + len(self._separator)
|
||||||
|
|
@ -214,7 +228,7 @@ class StructuralChunker:
|
||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
source_doc_id: str = "",
|
source_doc_id: str = "",
|
||||||
metadata: dict[str, Any] | None = None,
|
metadata: ChunkMetadata | None = None,
|
||||||
) -> list[Chunk]:
|
) -> list[Chunk]:
|
||||||
"""将文本按结构分块
|
"""将文本按结构分块
|
||||||
|
|
||||||
|
|
@ -266,23 +280,25 @@ class StructuralChunker:
|
||||||
chunk_meta["heading"] = heading
|
chunk_meta["heading"] = heading
|
||||||
chunk_meta["heading_level"] = level
|
chunk_meta["heading_level"] = level
|
||||||
chunk_meta["char_count"] = len(content)
|
chunk_meta["char_count"] = len(content)
|
||||||
chunks.append(Chunk(
|
chunks.append(
|
||||||
chunk_id=str(uuid.uuid4()),
|
Chunk(
|
||||||
content=content,
|
chunk_id=str(uuid.uuid4()),
|
||||||
metadata=chunk_meta,
|
content=content,
|
||||||
))
|
metadata=chunk_meta,
|
||||||
|
)
|
||||||
|
)
|
||||||
position += 1
|
position += 1
|
||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
def _split_by_headings(self, text: str) -> list[dict[str, Any]]:
|
def _split_by_headings(self, text: str) -> list[SectionInfo]:
|
||||||
"""按标题分割 Markdown 文本
|
"""按标题分割 Markdown 文本
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
列表,每项包含 heading, content, level
|
列表,每项包含 heading, content, level
|
||||||
"""
|
"""
|
||||||
lines = text.split("\n")
|
lines = text.split("\n")
|
||||||
sections: list[dict[str, Any]] = []
|
sections: list[SectionInfo] = []
|
||||||
current_heading = ""
|
current_heading = ""
|
||||||
current_level = 0
|
current_level = 0
|
||||||
current_lines: list[str] = []
|
current_lines: list[str] = []
|
||||||
|
|
@ -296,11 +312,13 @@ class StructuralChunker:
|
||||||
if current_lines:
|
if current_lines:
|
||||||
content = "\n".join(current_lines).strip()
|
content = "\n".join(current_lines).strip()
|
||||||
if content:
|
if content:
|
||||||
sections.append({
|
sections.append(
|
||||||
"heading": current_heading,
|
{
|
||||||
"content": content,
|
"heading": current_heading,
|
||||||
"level": current_level,
|
"content": content,
|
||||||
})
|
"level": current_level,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# 开始新节
|
# 开始新节
|
||||||
current_heading = match.group(2).strip()
|
current_heading = match.group(2).strip()
|
||||||
|
|
@ -313,18 +331,22 @@ class StructuralChunker:
|
||||||
if current_lines:
|
if current_lines:
|
||||||
content = "\n".join(current_lines).strip()
|
content = "\n".join(current_lines).strip()
|
||||||
if content:
|
if content:
|
||||||
sections.append({
|
sections.append(
|
||||||
"heading": current_heading,
|
{
|
||||||
"content": content,
|
"heading": current_heading,
|
||||||
"level": current_level,
|
"content": content,
|
||||||
})
|
"level": current_level,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# 如果没有标题结构,整体作为一个块
|
# 如果没有标题结构,整体作为一个块
|
||||||
if not sections:
|
if not sections:
|
||||||
sections.append({
|
sections.append(
|
||||||
"heading": "",
|
{
|
||||||
"content": text.strip(),
|
"heading": "",
|
||||||
"level": 0,
|
"content": text.strip(),
|
||||||
})
|
"level": 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return sections
|
return sections
|
||||||
|
|
|
||||||
|
|
@ -9,10 +9,14 @@ from __future__ import annotations
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from agentkit.memory.base import MetadataDict
|
||||||
from agentkit.memory.embedder import EmbeddingCache
|
from agentkit.memory.embedder import EmbeddingCache
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from agentkit.llm.gateway import LLMGateway
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -24,7 +28,7 @@ class ContextualChunk:
|
||||||
context_prefix: str
|
context_prefix: str
|
||||||
enhanced_content: str
|
enhanced_content: str
|
||||||
chunk_index: int
|
chunk_index: int
|
||||||
metadata: dict[str, Any]
|
metadata: MetadataDict
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def content(self) -> str:
|
def content(self) -> str:
|
||||||
|
|
@ -65,7 +69,7 @@ class ContextualChunker:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
llm_gateway: Any = None,
|
llm_gateway: LLMGateway | None = None,
|
||||||
cache: EmbeddingCache | None = None,
|
cache: EmbeddingCache | None = None,
|
||||||
batch_size: int = 8,
|
batch_size: int = 8,
|
||||||
max_context_length: int = 200,
|
max_context_length: int = 200,
|
||||||
|
|
@ -90,7 +94,7 @@ class ContextualChunker:
|
||||||
self,
|
self,
|
||||||
document: str,
|
document: str,
|
||||||
chunks: list[str],
|
chunks: list[str],
|
||||||
metadata: dict[str, Any] | None = None,
|
metadata: MetadataDict | None = None,
|
||||||
) -> list[ContextualChunk]:
|
) -> list[ContextualChunk]:
|
||||||
"""为文档块添加上下文前缀
|
"""为文档块添加上下文前缀
|
||||||
|
|
||||||
|
|
@ -134,7 +138,7 @@ class ContextualChunker:
|
||||||
document: str,
|
document: str,
|
||||||
chunks: list[str],
|
chunks: list[str],
|
||||||
start_index: int,
|
start_index: int,
|
||||||
metadata: dict[str, Any] | None,
|
metadata: MetadataDict | None,
|
||||||
) -> list[ContextualChunk]:
|
) -> list[ContextualChunk]:
|
||||||
"""处理一批文档块"""
|
"""处理一批文档块"""
|
||||||
results: list[ContextualChunk] = []
|
results: list[ContextualChunk] = []
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,9 @@ import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import TypeAlias
|
||||||
|
|
||||||
|
from agentkit.memory.base import MetadataDict
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -23,6 +25,11 @@ MAX_CONTENT_SIZE = 100 * 1024 * 1024 # 100MB
|
||||||
MAX_ROWS_PER_SHEET = 10_000
|
MAX_ROWS_PER_SHEET = 10_000
|
||||||
MAX_CELL_CHARS = 10_000
|
MAX_CELL_CHARS = 10_000
|
||||||
|
|
||||||
|
# 文档元数据:source/format/parser/page_count/table_count/sheet_count/row_count/
|
||||||
|
# heading_count/created_at/title/truncated — 全部为原始标量。
|
||||||
|
DocumentMetadata: TypeAlias = MetadataDict
|
||||||
|
ParseResult: TypeAlias = tuple[str, DocumentMetadata]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Document:
|
class Document:
|
||||||
|
|
@ -31,7 +38,7 @@ class Document:
|
||||||
doc_id: str
|
doc_id: str
|
||||||
title: str
|
title: str
|
||||||
content: str
|
content: str
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: DocumentMetadata = field(default_factory=dict)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if "source" not in self.metadata:
|
if "source" not in self.metadata:
|
||||||
|
|
@ -43,7 +50,7 @@ class Document:
|
||||||
if "created_at" not in self.metadata:
|
if "created_at" not in self.metadata:
|
||||||
self.metadata["created_at"] = datetime.now(timezone.utc).isoformat()
|
self.metadata["created_at"] = datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, object]:
|
||||||
return {
|
return {
|
||||||
"doc_id": self.doc_id,
|
"doc_id": self.doc_id,
|
||||||
"title": self.title,
|
"title": self.title,
|
||||||
|
|
@ -136,12 +143,14 @@ class DocumentLoader:
|
||||||
|
|
||||||
parser = parsers.get(doc_format)
|
parser = parsers.get(doc_format)
|
||||||
if parser is None:
|
if parser is None:
|
||||||
logger.warning(f"Unsupported format '{doc_format}' for {filename}, falling back to text")
|
logger.warning(
|
||||||
|
f"Unsupported format '{doc_format}' for {filename}, falling back to text"
|
||||||
|
)
|
||||||
parser = self._parse_text
|
parser = self._parse_text
|
||||||
|
|
||||||
text, extra_meta = parser(content, filename)
|
text, extra_meta = parser(content, filename)
|
||||||
|
|
||||||
metadata: dict[str, Any] = {
|
metadata: DocumentMetadata = {
|
||||||
"source": filename,
|
"source": filename,
|
||||||
"format": doc_format,
|
"format": doc_format,
|
||||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||||
|
|
@ -159,7 +168,7 @@ class DocumentLoader:
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_pdf(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]:
|
def _parse_pdf(self, content: bytes, filename: str) -> ParseResult:
|
||||||
"""解析 PDF 文件
|
"""解析 PDF 文件
|
||||||
|
|
||||||
优先使用 PyMuPDF (fitz),回退到 pdfplumber,最终回退到纯文本。
|
优先使用 PyMuPDF (fitz),回退到 pdfplumber,最终回退到纯文本。
|
||||||
|
|
@ -215,7 +224,7 @@ class DocumentLoader:
|
||||||
logger.warning(f"No PDF parser available for {filename}, falling back to text extraction")
|
logger.warning(f"No PDF parser available for {filename}, falling back to text extraction")
|
||||||
return self._parse_text(content, filename)
|
return self._parse_text(content, filename)
|
||||||
|
|
||||||
def _parse_docx(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]:
|
def _parse_docx(self, content: bytes, filename: str) -> ParseResult:
|
||||||
"""解析 Word 文件
|
"""解析 Word 文件
|
||||||
|
|
||||||
使用 python-docx,回退到纯文本。
|
使用 python-docx,回退到纯文本。
|
||||||
|
|
@ -259,7 +268,7 @@ class DocumentLoader:
|
||||||
logger.warning(f"python-docx parsing failed for {filename}: {e}")
|
logger.warning(f"python-docx parsing failed for {filename}: {e}")
|
||||||
return self._parse_text(content, filename)
|
return self._parse_text(content, filename)
|
||||||
|
|
||||||
def _parse_xlsx(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]:
|
def _parse_xlsx(self, content: bytes, filename: str) -> ParseResult:
|
||||||
"""解析 Excel 文件
|
"""解析 Excel 文件
|
||||||
|
|
||||||
使用 openpyxl,回退到纯文本。每个 sheet 转为 Markdown 表格,
|
使用 openpyxl,回退到纯文本。每个 sheet 转为 Markdown 表格,
|
||||||
|
|
@ -313,7 +322,7 @@ class DocumentLoader:
|
||||||
finally:
|
finally:
|
||||||
wb.close()
|
wb.close()
|
||||||
text = "\n".join(sections).strip()
|
text = "\n".join(sections).strip()
|
||||||
meta: dict[str, Any] = {
|
meta: DocumentMetadata = {
|
||||||
"parser": "openpyxl",
|
"parser": "openpyxl",
|
||||||
"sheet_count": sheet_count,
|
"sheet_count": sheet_count,
|
||||||
"row_count": total_rows,
|
"row_count": total_rows,
|
||||||
|
|
@ -328,7 +337,7 @@ class DocumentLoader:
|
||||||
logger.warning(f"openpyxl parsing failed for {filename}: {e}")
|
logger.warning(f"openpyxl parsing failed for {filename}: {e}")
|
||||||
return self._parse_text(content, filename)
|
return self._parse_text(content, filename)
|
||||||
|
|
||||||
def _parse_markdown(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]:
|
def _parse_markdown(self, content: bytes, filename: str) -> ParseResult:
|
||||||
"""解析 Markdown 文件
|
"""解析 Markdown 文件
|
||||||
|
|
||||||
使用 mistune(如果可用),否则直接读取文本。
|
使用 mistune(如果可用),否则直接读取文本。
|
||||||
|
|
@ -347,7 +356,7 @@ class DocumentLoader:
|
||||||
title = line_stripped.lstrip("#").strip()
|
title = line_stripped.lstrip("#").strip()
|
||||||
break
|
break
|
||||||
|
|
||||||
meta: dict[str, Any] = {
|
meta: DocumentMetadata = {
|
||||||
"parser": "markdown",
|
"parser": "markdown",
|
||||||
}
|
}
|
||||||
if title:
|
if title:
|
||||||
|
|
@ -362,7 +371,7 @@ class DocumentLoader:
|
||||||
|
|
||||||
return text, meta
|
return text, meta
|
||||||
|
|
||||||
def _parse_html(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]:
|
def _parse_html(self, content: bytes, filename: str) -> ParseResult:
|
||||||
"""解析 HTML 文件
|
"""解析 HTML 文件
|
||||||
|
|
||||||
使用 BeautifulSoup 提取文本,回退到纯文本。
|
使用 BeautifulSoup 提取文本,回退到纯文本。
|
||||||
|
|
@ -388,7 +397,7 @@ class DocumentLoader:
|
||||||
if soup.title and soup.title.string:
|
if soup.title and soup.title.string:
|
||||||
title = soup.title.string.strip()
|
title = soup.title.string.strip()
|
||||||
|
|
||||||
meta: dict[str, Any] = {
|
meta: DocumentMetadata = {
|
||||||
"parser": "beautifulsoup",
|
"parser": "beautifulsoup",
|
||||||
}
|
}
|
||||||
if title:
|
if title:
|
||||||
|
|
@ -402,7 +411,7 @@ class DocumentLoader:
|
||||||
logger.warning(f"BeautifulSoup parsing failed for {filename}: {e}")
|
logger.warning(f"BeautifulSoup parsing failed for {filename}: {e}")
|
||||||
return self._parse_text(content, filename)
|
return self._parse_text(content, filename)
|
||||||
|
|
||||||
def _parse_text(self, content: bytes, filename: str) -> tuple[str, dict[str, Any]]:
|
def _parse_text(self, content: bytes, filename: str) -> ParseResult:
|
||||||
"""解析纯文本文件"""
|
"""解析纯文本文件"""
|
||||||
try:
|
try:
|
||||||
text = content.decode("utf-8")
|
text = content.decode("utf-8")
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,17 @@
|
||||||
"""Embedder 接口与实现 - 文本向量化"""
|
"""Embedder 接口与实现 - 文本向量化"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import httpx
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -97,13 +102,14 @@ class OpenAIEmbedder(Embedder):
|
||||||
self._model = model
|
self._model = model
|
||||||
self._base_url = base_url
|
self._base_url = base_url
|
||||||
self._dimension = 1536 # text-embedding-3-small 默认维度
|
self._dimension = 1536 # text-embedding-3-small 默认维度
|
||||||
self._client: Any = None
|
self._client: httpx.AsyncClient | None = None
|
||||||
self._cache = cache
|
self._cache = cache
|
||||||
|
|
||||||
def _get_client(self):
|
def _get_client(self) -> httpx.AsyncClient:
|
||||||
"""Lazily create and reuse a single httpx.AsyncClient."""
|
"""Lazily create and reuse a single httpx.AsyncClient."""
|
||||||
if self._client is None:
|
if self._client is None:
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
self._client = httpx.AsyncClient(timeout=30.0)
|
self._client = httpx.AsyncClient(timeout=30.0)
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,22 @@
|
||||||
"""Episodic Memory - 基于 pgvector + PostgreSQL 的任务经验记忆"""
|
"""Episodic Memory - 基于 pgvector + PostgreSQL 的任务经验记忆"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
|
|
||||||
from agentkit.memory.base import Memory, MemoryItem
|
from agentkit.memory.base import Memory, MemoryItem, MetadataDict
|
||||||
from agentkit.memory.embedder import Embedder
|
from agentkit.memory.embedder import Embedder
|
||||||
from agentkit.utils.vector_math import compute_cosine_similarity
|
from agentkit.utils.vector_math import compute_cosine_similarity
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -28,8 +33,8 @@ class EpisodicMemory(Memory):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
session_factory: Any,
|
session_factory: object,
|
||||||
episodic_model: Any,
|
episodic_model: object,
|
||||||
embedder: Embedder | None = None,
|
embedder: Embedder | None = None,
|
||||||
decay_rate: float = 0.01,
|
decay_rate: float = 0.01,
|
||||||
alpha: float = 0.7,
|
alpha: float = 0.7,
|
||||||
|
|
@ -57,7 +62,7 @@ class EpisodicMemory(Memory):
|
||||||
self._pgvector_enabled = pgvector_enabled
|
self._pgvector_enabled = pgvector_enabled
|
||||||
self._table_name = table_name
|
self._table_name = table_name
|
||||||
|
|
||||||
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
|
async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
|
||||||
"""存储任务经验"""
|
"""存储任务经验"""
|
||||||
async with self._session_factory() as db:
|
async with self._session_factory() as db:
|
||||||
try:
|
try:
|
||||||
|
|
@ -68,7 +73,11 @@ class EpisodicMemory(Memory):
|
||||||
embedding = None
|
embedding = None
|
||||||
if self._embedder:
|
if self._embedder:
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
text = value.get("output_summary", "") or value.get("input_summary", "") or json.dumps(value, ensure_ascii=False)[:500]
|
text = (
|
||||||
|
value.get("output_summary", "")
|
||||||
|
or value.get("input_summary", "")
|
||||||
|
or json.dumps(value, ensure_ascii=False)[:500]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
text = str(value)
|
text = str(value)
|
||||||
embedding = await self._embedder.embed(text)
|
embedding = await self._embedder.embed(text)
|
||||||
|
|
@ -106,13 +115,11 @@ class EpisodicMemory(Memory):
|
||||||
logger.error(f"Failed to retrieve episodic memory: {e}")
|
logger.error(f"Failed to retrieve episodic memory: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _retrieve_pgvector(self, db: Any, query_embedding: list[float]) -> MemoryItem | None:
|
async def _retrieve_pgvector(
|
||||||
|
self, db: AsyncSession, query_embedding: list[float]
|
||||||
|
) -> MemoryItem | None:
|
||||||
"""使用 pgvector ``<=>`` 算符检索最相似条目"""
|
"""使用 pgvector ``<=>`` 算符检索最相似条目"""
|
||||||
sql = text(
|
sql = text(f"SELECT * FROM {self._table_name} ORDER BY embedding <=> :query_vec LIMIT :lim")
|
||||||
f"SELECT * FROM {self._table_name} "
|
|
||||||
f"ORDER BY embedding <=> :query_vec "
|
|
||||||
f"LIMIT :lim"
|
|
||||||
)
|
|
||||||
result = await db.execute(sql, {"query_vec": str(query_embedding), "lim": 1})
|
result = await db.execute(sql, {"query_vec": str(query_embedding), "lim": 1})
|
||||||
row = result.mappings().first()
|
row = result.mappings().first()
|
||||||
|
|
||||||
|
|
@ -147,7 +154,9 @@ class EpisodicMemory(Memory):
|
||||||
created_at=row.get("created_at") or datetime.now(timezone.utc),
|
created_at=row.get("created_at") or datetime.now(timezone.utc),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _retrieve_client_side(self, db: Any, query_embedding: list[float]) -> MemoryItem | None:
|
async def _retrieve_client_side(
|
||||||
|
self, db: AsyncSession, query_embedding: list[float]
|
||||||
|
) -> MemoryItem | None:
|
||||||
"""客户端 O(N) cosine similarity 检索(回退路径)"""
|
"""客户端 O(N) cosine similarity 检索(回退路径)"""
|
||||||
Model = self._episodic_model
|
Model = self._episodic_model
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
@ -193,7 +202,13 @@ class EpisodicMemory(Memory):
|
||||||
created_at=best_item.created_at or datetime.now(timezone.utc),
|
created_at=best_item.created_at or datetime.now(timezone.utc),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None, search_multiplier: int = 5) -> list[MemoryItem]:
|
async def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
top_k: int = 5,
|
||||||
|
filters: MetadataDict | None = None,
|
||||||
|
search_multiplier: int = 5,
|
||||||
|
) -> list[MemoryItem]:
|
||||||
"""语义检索相似历史案例
|
"""语义检索相似历史案例
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -214,10 +229,10 @@ class EpisodicMemory(Memory):
|
||||||
|
|
||||||
async def _search_pgvector(
|
async def _search_pgvector(
|
||||||
self,
|
self,
|
||||||
db: Any,
|
db: AsyncSession,
|
||||||
query: str,
|
query: str,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
filters: dict[str, Any] | None,
|
filters: MetadataDict | None,
|
||||||
search_multiplier: int,
|
search_multiplier: int,
|
||||||
) -> list[MemoryItem]:
|
) -> list[MemoryItem]:
|
||||||
"""使用 pgvector ``<=>`` 算符检索,再 Python 侧 time_decay 重排"""
|
"""使用 pgvector ``<=>`` 算符检索,再 Python 侧 time_decay 重排"""
|
||||||
|
|
@ -225,7 +240,7 @@ class EpisodicMemory(Memory):
|
||||||
fetch_limit = top_k * search_multiplier
|
fetch_limit = top_k * search_multiplier
|
||||||
|
|
||||||
where_clauses = []
|
where_clauses = []
|
||||||
params: dict[str, Any] = {"query_vec": str(query_embedding), "lim": fetch_limit}
|
params: dict[str, object] = {"query_vec": str(query_embedding), "lim": fetch_limit}
|
||||||
|
|
||||||
filters = filters or {}
|
filters = filters or {}
|
||||||
if filters.get("agent_name"):
|
if filters.get("agent_name"):
|
||||||
|
|
@ -256,7 +271,11 @@ class EpisodicMemory(Memory):
|
||||||
items = []
|
items = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
row_embedding = row.get("embedding")
|
row_embedding = row.get("embedding")
|
||||||
age_hours = (datetime.now(timezone.utc) - row["created_at"]).total_seconds() / 3600 if row.get("created_at") else 0
|
age_hours = (
|
||||||
|
(datetime.now(timezone.utc) - row["created_at"]).total_seconds() / 3600
|
||||||
|
if row.get("created_at")
|
||||||
|
else 0
|
||||||
|
)
|
||||||
decay = math.exp(-self._decay_rate * age_hours)
|
decay = math.exp(-self._decay_rate * age_hours)
|
||||||
time_decay_score = (row.get("quality_score") or 0.5) * decay
|
time_decay_score = (row.get("quality_score") or 0.5) * decay
|
||||||
|
|
||||||
|
|
@ -266,33 +285,37 @@ class EpisodicMemory(Memory):
|
||||||
else:
|
else:
|
||||||
score = time_decay_score
|
score = time_decay_score
|
||||||
|
|
||||||
items.append(MemoryItem(
|
items.append(
|
||||||
key=str(row.get("id", "")),
|
MemoryItem(
|
||||||
value={
|
key=str(row.get("id", "")),
|
||||||
"input_summary": row.get("input_summary", ""),
|
value={
|
||||||
"output_summary": row.get("output_summary", ""),
|
"input_summary": row.get("input_summary", ""),
|
||||||
"outcome": row.get("outcome", "success"),
|
"output_summary": row.get("output_summary", ""),
|
||||||
"quality_score": row.get("quality_score", 0.5),
|
"outcome": row.get("outcome", "success"),
|
||||||
"reflection": row.get("reflection", ""),
|
"quality_score": row.get("quality_score", 0.5),
|
||||||
},
|
"reflection": row.get("reflection", ""),
|
||||||
metadata={
|
},
|
||||||
"agent_name": row.get("agent_name", ""),
|
metadata={
|
||||||
"task_type": row.get("task_type", ""),
|
"agent_name": row.get("agent_name", ""),
|
||||||
"created_at": row["created_at"].isoformat() if row.get("created_at") else None,
|
"task_type": row.get("task_type", ""),
|
||||||
},
|
"created_at": row["created_at"].isoformat()
|
||||||
score=score,
|
if row.get("created_at")
|
||||||
created_at=row.get("created_at") or datetime.now(timezone.utc),
|
else None,
|
||||||
))
|
},
|
||||||
|
score=score,
|
||||||
|
created_at=row.get("created_at") or datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
items.sort(key=lambda x: x.score, reverse=True)
|
items.sort(key=lambda x: x.score, reverse=True)
|
||||||
return items[:top_k]
|
return items[:top_k]
|
||||||
|
|
||||||
async def _search_client_side(
|
async def _search_client_side(
|
||||||
self,
|
self,
|
||||||
db: Any,
|
db: AsyncSession,
|
||||||
query: str,
|
query: str,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
filters: dict[str, Any] | None,
|
filters: MetadataDict | None,
|
||||||
search_multiplier: int,
|
search_multiplier: int,
|
||||||
) -> list[MemoryItem]:
|
) -> list[MemoryItem]:
|
||||||
"""客户端 O(N) cosine similarity 检索(回退路径)"""
|
"""客户端 O(N) cosine similarity 检索(回退路径)"""
|
||||||
|
|
@ -300,6 +323,7 @@ class EpisodicMemory(Memory):
|
||||||
filters = filters or {}
|
filters = filters or {}
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
stmt = select(Model)
|
stmt = select(Model)
|
||||||
|
|
||||||
if filters.get("agent_name"):
|
if filters.get("agent_name"):
|
||||||
|
|
@ -322,7 +346,11 @@ class EpisodicMemory(Memory):
|
||||||
# 计算得分并构建 MemoryItem
|
# 计算得分并构建 MemoryItem
|
||||||
items = []
|
items = []
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
age_hours = (datetime.now(timezone.utc) - entry.created_at).total_seconds() / 3600 if entry.created_at else 0
|
age_hours = (
|
||||||
|
(datetime.now(timezone.utc) - entry.created_at).total_seconds() / 3600
|
||||||
|
if entry.created_at
|
||||||
|
else 0
|
||||||
|
)
|
||||||
decay = math.exp(-self._decay_rate * age_hours)
|
decay = math.exp(-self._decay_rate * age_hours)
|
||||||
time_decay_score = (entry.quality_score or 0.5) * decay
|
time_decay_score = (entry.quality_score or 0.5) * decay
|
||||||
|
|
||||||
|
|
@ -333,30 +361,34 @@ class EpisodicMemory(Memory):
|
||||||
else:
|
else:
|
||||||
score = time_decay_score
|
score = time_decay_score
|
||||||
|
|
||||||
items.append(MemoryItem(
|
items.append(
|
||||||
key=str(entry.id),
|
MemoryItem(
|
||||||
value={
|
key=str(entry.id),
|
||||||
"input_summary": entry.input_summary,
|
value={
|
||||||
"output_summary": entry.output_summary,
|
"input_summary": entry.input_summary,
|
||||||
"outcome": entry.outcome,
|
"output_summary": entry.output_summary,
|
||||||
"quality_score": entry.quality_score,
|
"outcome": entry.outcome,
|
||||||
"reflection": entry.reflection,
|
"quality_score": entry.quality_score,
|
||||||
},
|
"reflection": entry.reflection,
|
||||||
metadata={
|
},
|
||||||
"agent_name": entry.agent_name,
|
metadata={
|
||||||
"task_type": entry.task_type,
|
"agent_name": entry.agent_name,
|
||||||
"created_at": entry.created_at.isoformat() if entry.created_at else None,
|
"task_type": entry.task_type,
|
||||||
},
|
"created_at": entry.created_at.isoformat() if entry.created_at else None,
|
||||||
score=score,
|
},
|
||||||
created_at=entry.created_at or datetime.now(timezone.utc),
|
score=score,
|
||||||
))
|
created_at=entry.created_at or datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
items.sort(key=lambda x: x.score, reverse=True)
|
items.sort(key=lambda x: x.score, reverse=True)
|
||||||
if len(items) < top_k:
|
if len(items) < top_k:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"EpisodicMemory.search returned %d results after scoring (top_k=%d). "
|
"EpisodicMemory.search returned %d results after scoring (top_k=%d). "
|
||||||
"Consider increasing search_multiplier (current=%d) to avoid missing relevant entries.",
|
"Consider increasing search_multiplier (current=%d) to avoid missing relevant entries.",
|
||||||
len(items), top_k, search_multiplier,
|
len(items),
|
||||||
|
top_k,
|
||||||
|
search_multiplier,
|
||||||
)
|
)
|
||||||
return items[:top_k]
|
return items[:top_k]
|
||||||
|
|
||||||
|
|
@ -364,8 +396,9 @@ class EpisodicMemory(Memory):
|
||||||
"""删除指定经验"""
|
"""删除指定经验"""
|
||||||
async with self._session_factory() as db:
|
async with self._session_factory() as db:
|
||||||
try:
|
try:
|
||||||
from sqlalchemy import select, delete as sql_delete
|
from sqlalchemy import delete as sql_delete
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
Model = self._episodic_model
|
Model = self._episodic_model
|
||||||
|
|
||||||
stmt = sql_delete(Model).where(Model.id == uuid.UUID(key))
|
stmt = sql_delete(Model).where(Model.id == uuid.UUID(key))
|
||||||
|
|
|
||||||
|
|
@ -3,13 +3,26 @@
|
||||||
配置驱动,不直接依赖业务系统代码,通过 base_url + api_key 连接。
|
配置驱动,不直接依赖业务系统代码,通过 base_url + api_key 连接。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, TypeAlias
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
from agentkit.memory.base import MetadataDict
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from agentkit.llm.gateway import LLMGateway
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 标准化检索结果:id/content/score/source/document_id/document_title/
|
||||||
|
# knowledge_base_id/metadata — 值为原始标量或嵌套 dict。
|
||||||
|
RAGSearchResult: TypeAlias = dict[str, object]
|
||||||
|
# ingest() 写入的文档负载:title/content/source_type/metadata。
|
||||||
|
RAGIngestPayload: TypeAlias = dict[str, object]
|
||||||
|
|
||||||
|
|
||||||
class HttpRAGService:
|
class HttpRAGService:
|
||||||
"""HTTP 客户端,调用业务系统的知识库检索 API
|
"""HTTP 客户端,调用业务系统的知识库检索 API
|
||||||
|
|
@ -39,7 +52,7 @@ class HttpRAGService:
|
||||||
knowledge_base_ids: list[str] | None = None,
|
knowledge_base_ids: list[str] | None = None,
|
||||||
timeout: int = 30,
|
timeout: int = 30,
|
||||||
contextual_chunking: bool = False,
|
contextual_chunking: bool = False,
|
||||||
llm_gateway: Any = None,
|
llm_gateway: LLMGateway | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -74,7 +87,7 @@ class HttpRAGService:
|
||||||
query: str,
|
query: str,
|
||||||
knowledge_base_ids: list[str] | None = None,
|
knowledge_base_ids: list[str] | None = None,
|
||||||
top_k: int = 5,
|
top_k: int = 5,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[RAGSearchResult]:
|
||||||
"""语义检索知识库
|
"""语义检索知识库
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -113,19 +126,23 @@ class HttpRAGService:
|
||||||
normalized = []
|
normalized = []
|
||||||
for r in results:
|
for r in results:
|
||||||
if isinstance(r, dict):
|
if isinstance(r, dict):
|
||||||
normalized.append({
|
normalized.append(
|
||||||
"id": r.get("chunk_id", r.get("id", "")),
|
{
|
||||||
"content": r.get("content", ""),
|
"id": r.get("chunk_id", r.get("id", "")),
|
||||||
"score": float(r.get("score", 0.0)),
|
"content": r.get("content", ""),
|
||||||
"source": r.get("source", "rag"),
|
"score": float(r.get("score", 0.0)),
|
||||||
"document_id": r.get("document_id", ""),
|
"source": r.get("source", "rag"),
|
||||||
"document_title": r.get("document_title", ""),
|
"document_id": r.get("document_id", ""),
|
||||||
"metadata": r.get("metadata", {}),
|
"document_title": r.get("document_title", ""),
|
||||||
})
|
"metadata": r.get("metadata", {}),
|
||||||
|
}
|
||||||
|
)
|
||||||
return normalized
|
return normalized
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
logger.error(f"RAG search HTTP error: {e.response.status_code} — {e.response.text[:200]}")
|
logger.error(
|
||||||
|
f"RAG search HTTP error: {e.response.status_code} — {e.response.text[:200]}"
|
||||||
|
)
|
||||||
return []
|
return []
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
logger.error(f"RAG search request error: {e}")
|
logger.error(f"RAG search request error: {e}")
|
||||||
|
|
@ -141,7 +158,7 @@ class HttpRAGService:
|
||||||
top_k: int = 5,
|
top_k: int = 5,
|
||||||
use_rerank: bool = True,
|
use_rerank: bool = True,
|
||||||
use_compression: bool = False,
|
use_compression: bool = False,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[RAGSearchResult]:
|
||||||
"""增强语义检索知识库(支持 rerank 和 compression)
|
"""增强语义检索知识库(支持 rerank 和 compression)
|
||||||
|
|
||||||
对每个知识库分别调用 /bases/{kb_id}/retrieve 接口,
|
对每个知识库分别调用 /bases/{kb_id}/retrieve 接口,
|
||||||
|
|
@ -169,7 +186,7 @@ class HttpRAGService:
|
||||||
}
|
}
|
||||||
|
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
all_results: list[dict[str, Any]] = []
|
all_results: list[RAGSearchResult] = []
|
||||||
|
|
||||||
for kb_id in kb_ids:
|
for kb_id in kb_ids:
|
||||||
try:
|
try:
|
||||||
|
|
@ -189,28 +206,27 @@ class HttpRAGService:
|
||||||
# 标准化
|
# 标准化
|
||||||
for r in results:
|
for r in results:
|
||||||
if isinstance(r, dict):
|
if isinstance(r, dict):
|
||||||
all_results.append({
|
all_results.append(
|
||||||
"id": r.get("chunk_id", r.get("id", "")),
|
{
|
||||||
"content": r.get("content", ""),
|
"id": r.get("chunk_id", r.get("id", "")),
|
||||||
"score": float(r.get("score", 0.0)),
|
"content": r.get("content", ""),
|
||||||
"source": r.get("source", "rag"),
|
"score": float(r.get("score", 0.0)),
|
||||||
"document_id": r.get("document_id", ""),
|
"source": r.get("source", "rag"),
|
||||||
"document_title": r.get("document_title", ""),
|
"document_id": r.get("document_id", ""),
|
||||||
"knowledge_base_id": kb_id,
|
"document_title": r.get("document_title", ""),
|
||||||
"metadata": r.get("metadata", {}),
|
"knowledge_base_id": kb_id,
|
||||||
})
|
"metadata": r.get("metadata", {}),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
if e.response.status_code == 404:
|
if e.response.status_code == 404:
|
||||||
# This KB doesn't support enhanced search — fall back to
|
# This KB doesn't support enhanced search — fall back to
|
||||||
# standard search for THIS KB only, not all KBs.
|
# standard search for THIS KB only, not all KBs.
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Enhanced search not available for KB {kb_id}, "
|
f"Enhanced search not available for KB {kb_id}, using standard search"
|
||||||
f"using standard search"
|
|
||||||
)
|
|
||||||
std_result = await self.search(
|
|
||||||
query, knowledge_base_ids=[kb_id], top_k=top_k
|
|
||||||
)
|
)
|
||||||
|
std_result = await self.search(query, knowledge_base_ids=[kb_id], top_k=top_k)
|
||||||
all_results.extend(std_result)
|
all_results.extend(std_result)
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|
@ -232,9 +248,9 @@ class HttpRAGService:
|
||||||
async def ingest(
|
async def ingest(
|
||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
value: Any,
|
value: object,
|
||||||
metadata: dict[str, Any] | None = None,
|
metadata: MetadataDict | None = None,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, object] | None:
|
||||||
"""写入文档到知识库(可选操作)
|
"""写入文档到知识库(可选操作)
|
||||||
|
|
||||||
When contextual_chunking is enabled and llm_gateway is configured,
|
When contextual_chunking is enabled and llm_gateway is configured,
|
||||||
|
|
@ -308,5 +324,5 @@ class HttpRAGService:
|
||||||
async def __aenter__(self) -> "HttpRAGService":
|
async def __aenter__(self) -> "HttpRAGService":
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, *args: Any) -> None:
|
async def __aexit__(self, *args: object) -> None:
|
||||||
await self.close()
|
await self.close()
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,11 @@ from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Protocol, runtime_checkable
|
from typing import Protocol, TypeAlias, runtime_checkable
|
||||||
|
|
||||||
|
from agentkit.memory.base import MetadataDict
|
||||||
|
|
||||||
|
KBMetadata: TypeAlias = MetadataDict
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -22,7 +26,7 @@ class Document:
|
||||||
content: str
|
content: str
|
||||||
title: str = ""
|
title: str = ""
|
||||||
source_id: str = ""
|
source_id: str = ""
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: KBMetadata = field(default_factory=dict)
|
||||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -34,7 +38,7 @@ class QueryResult:
|
||||||
source_id: str
|
source_id: str
|
||||||
source_name: str
|
source_name: str
|
||||||
score: float
|
score: float
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: KBMetadata = field(default_factory=dict)
|
||||||
doc_id: str = ""
|
doc_id: str = ""
|
||||||
title: str = ""
|
title: str = ""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,25 +11,32 @@ from __future__ import annotations
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import uuid
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, TypeAlias
|
||||||
|
|
||||||
_SAFE_TABLE_NAME_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
|
|
||||||
|
|
||||||
from agentkit.memory.chunking import Chunk, StructuralChunker, TextChunker
|
from agentkit.memory.chunking import Chunk, StructuralChunker, TextChunker
|
||||||
from agentkit.memory.document_loader import Document as LoaderDocument
|
from agentkit.memory.document_loader import Document as LoaderDocument
|
||||||
from agentkit.memory.embedder import Embedder
|
from agentkit.memory.embedder import Embedder
|
||||||
from agentkit.memory.knowledge_base import (
|
from agentkit.memory.knowledge_base import (
|
||||||
Document,
|
Document,
|
||||||
KnowledgeBase,
|
|
||||||
QueryResult,
|
QueryResult,
|
||||||
SourceInfo,
|
SourceInfo,
|
||||||
)
|
)
|
||||||
from agentkit.utils.vector_math import compute_cosine_similarity
|
from agentkit.utils.vector_math import compute_cosine_similarity
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
_SAFE_TABLE_NAME_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# InMemoryLocalRAGService 内部存储的文档元信息结构。
|
||||||
|
# 字段:title/source_id/format/chunk_ids/metadata/created_at — 值为标量或 list[str]。
|
||||||
|
InMemoryDocInfo: TypeAlias = dict[str, object]
|
||||||
|
# 内部 chunk 存储结构:content/embedding/metadata/source_doc_id。
|
||||||
|
InMemoryChunkInfo: TypeAlias = dict[str, object]
|
||||||
|
|
||||||
|
|
||||||
def _loader_doc_to_kb_doc(loader_doc: LoaderDocument) -> Document:
|
def _loader_doc_to_kb_doc(loader_doc: LoaderDocument) -> Document:
|
||||||
"""将 document_loader.Document 转换为 knowledge_base.Document"""
|
"""将 document_loader.Document 转换为 knowledge_base.Document"""
|
||||||
|
|
@ -53,7 +60,7 @@ class LocalRAGService:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
session_factory: Any,
|
session_factory: object,
|
||||||
embedder: Embedder,
|
embedder: Embedder,
|
||||||
chunk_size: int = 1000,
|
chunk_size: int = 1000,
|
||||||
chunk_overlap: int = 200,
|
chunk_overlap: int = 200,
|
||||||
|
|
@ -75,10 +82,14 @@ class LocalRAGService:
|
||||||
self._chunk_overlap = chunk_overlap
|
self._chunk_overlap = chunk_overlap
|
||||||
self._table_name = table_name
|
self._table_name = table_name
|
||||||
if not _SAFE_TABLE_NAME_PATTERN.match(self._table_name):
|
if not _SAFE_TABLE_NAME_PATTERN.match(self._table_name):
|
||||||
raise ValueError(f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*")
|
raise ValueError(
|
||||||
|
f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*"
|
||||||
|
)
|
||||||
self._pgvector_enabled = pgvector_enabled
|
self._pgvector_enabled = pgvector_enabled
|
||||||
self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||||
self._structural_chunker = StructuralChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
self._structural_chunker = StructuralChunker(
|
||||||
|
chunk_size=chunk_size, chunk_overlap=chunk_overlap
|
||||||
|
)
|
||||||
|
|
||||||
async def ingest(self, documents: list[Document]) -> list[str]:
|
async def ingest(self, documents: list[Document]) -> list[str]:
|
||||||
"""摄取文档列表
|
"""摄取文档列表
|
||||||
|
|
@ -136,9 +147,7 @@ class LocalRAGService:
|
||||||
try:
|
try:
|
||||||
from sqlalchemy import text as sql_text
|
from sqlalchemy import text as sql_text
|
||||||
|
|
||||||
sql = sql_text(
|
sql = sql_text(f"DELETE FROM {self._table_name} WHERE source_doc_id = :doc_id")
|
||||||
f"DELETE FROM {self._table_name} WHERE source_doc_id = :doc_id"
|
|
||||||
)
|
|
||||||
await db.execute(sql, {"doc_id": id})
|
await db.execute(sql, {"doc_id": id})
|
||||||
await db.commit()
|
await db.commit()
|
||||||
return True
|
return True
|
||||||
|
|
@ -171,20 +180,15 @@ class LocalRAGService:
|
||||||
|
|
||||||
sources = []
|
sources = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
meta = {}
|
sources.append(
|
||||||
if row.get("doc_metadata"):
|
SourceInfo(
|
||||||
try:
|
source_id=row["source_doc_id"],
|
||||||
meta = json.loads(row["doc_metadata"])
|
source_name=row.get("source_title", ""),
|
||||||
except (json.JSONDecodeError, TypeError):
|
source_type=row.get("doc_format", "local"),
|
||||||
pass
|
document_count=row.get("chunk_count", 0),
|
||||||
|
last_updated=row["created_at"] if row.get("created_at") else None,
|
||||||
sources.append(SourceInfo(
|
)
|
||||||
source_id=row["source_doc_id"],
|
)
|
||||||
source_name=row.get("source_title", ""),
|
|
||||||
source_type=row.get("doc_format", "local"),
|
|
||||||
document_count=row.get("chunk_count", 0),
|
|
||||||
last_updated=row["created_at"] if row.get("created_at") else None,
|
|
||||||
))
|
|
||||||
return sources
|
return sources
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to list sources: {e}")
|
logger.error(f"Failed to list sources: {e}")
|
||||||
|
|
@ -271,7 +275,7 @@ class LocalRAGService:
|
||||||
|
|
||||||
async def _query_pgvector(
|
async def _query_pgvector(
|
||||||
self,
|
self,
|
||||||
db: Any,
|
db: AsyncSession,
|
||||||
query_embedding: list[float],
|
query_embedding: list[float],
|
||||||
top_k: int,
|
top_k: int,
|
||||||
) -> list[QueryResult]:
|
) -> list[QueryResult]:
|
||||||
|
|
@ -286,10 +290,13 @@ class LocalRAGService:
|
||||||
f"LIMIT :lim"
|
f"LIMIT :lim"
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await db.execute(sql, {
|
result = await db.execute(
|
||||||
"query_vec": str(query_embedding),
|
sql,
|
||||||
"lim": top_k,
|
{
|
||||||
})
|
"query_vec": str(query_embedding),
|
||||||
|
"lim": top_k,
|
||||||
|
},
|
||||||
|
)
|
||||||
rows = result.mappings().all()
|
rows = result.mappings().all()
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|
@ -306,21 +313,23 @@ class LocalRAGService:
|
||||||
except (json.JSONDecodeError, TypeError):
|
except (json.JSONDecodeError, TypeError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
results.append(QueryResult(
|
results.append(
|
||||||
content=row["content"],
|
QueryResult(
|
||||||
source_id=row["source_doc_id"],
|
content=row["content"],
|
||||||
source_name=row.get("source_title", ""),
|
source_id=row["source_doc_id"],
|
||||||
score=cosine,
|
source_name=row.get("source_title", ""),
|
||||||
metadata=chunk_meta,
|
score=cosine,
|
||||||
doc_id=row["source_doc_id"],
|
metadata=chunk_meta,
|
||||||
title=row.get("source_title", ""),
|
doc_id=row["source_doc_id"],
|
||||||
))
|
title=row.get("source_title", ""),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def _query_client_side(
|
async def _query_client_side(
|
||||||
self,
|
self,
|
||||||
db: Any,
|
db: AsyncSession,
|
||||||
query_embedding: list[float],
|
query_embedding: list[float],
|
||||||
top_k: int,
|
top_k: int,
|
||||||
) -> list[QueryResult]:
|
) -> list[QueryResult]:
|
||||||
|
|
@ -363,15 +372,17 @@ class LocalRAGService:
|
||||||
except (json.JSONDecodeError, TypeError):
|
except (json.JSONDecodeError, TypeError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
candidates.append(QueryResult(
|
candidates.append(
|
||||||
content=row["content"],
|
QueryResult(
|
||||||
source_id=row["source_doc_id"],
|
content=row["content"],
|
||||||
source_name=row.get("source_title", ""),
|
source_id=row["source_doc_id"],
|
||||||
score=cosine,
|
source_name=row.get("source_title", ""),
|
||||||
metadata=chunk_meta,
|
score=cosine,
|
||||||
doc_id=row["source_doc_id"],
|
metadata=chunk_meta,
|
||||||
title=row.get("source_title", ""),
|
doc_id=row["source_doc_id"],
|
||||||
))
|
title=row.get("source_title", ""),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
candidates.sort(key=lambda x: x.score, reverse=True)
|
candidates.sort(key=lambda x: x.score, reverse=True)
|
||||||
return candidates[:top_k]
|
return candidates[:top_k]
|
||||||
|
|
@ -398,11 +409,15 @@ class InMemoryLocalRAGService:
|
||||||
"""
|
"""
|
||||||
self._embedder = embedder
|
self._embedder = embedder
|
||||||
self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||||
self._structural_chunker = StructuralChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
self._structural_chunker = StructuralChunker(
|
||||||
|
chunk_size=chunk_size, chunk_overlap=chunk_overlap
|
||||||
|
)
|
||||||
|
|
||||||
# 内存存储
|
# 内存存储
|
||||||
self._chunks: dict[str, dict[str, Any]] = {} # chunk_id → {content, embedding, metadata}
|
self._chunks: dict[str, InMemoryChunkInfo] = {} # chunk_id → {content, embedding, metadata}
|
||||||
self._documents: dict[str, dict[str, Any]] = {} # doc_id → {title, format, chunk_ids, metadata, created_at}
|
self._documents: dict[
|
||||||
|
str, InMemoryDocInfo
|
||||||
|
] = {} # doc_id → {title, format, chunk_ids, metadata, created_at}
|
||||||
|
|
||||||
async def ingest(self, documents: list[Document]) -> list[str]:
|
async def ingest(self, documents: list[Document]) -> list[str]:
|
||||||
"""摄取文档列表
|
"""摄取文档列表
|
||||||
|
|
@ -459,15 +474,17 @@ class InMemoryLocalRAGService:
|
||||||
source_doc_id = chunk_data["source_doc_id"]
|
source_doc_id = chunk_data["source_doc_id"]
|
||||||
doc_info = self._documents.get(source_doc_id, {})
|
doc_info = self._documents.get(source_doc_id, {})
|
||||||
|
|
||||||
candidates.append(QueryResult(
|
candidates.append(
|
||||||
content=chunk_data["content"],
|
QueryResult(
|
||||||
source_id=source_doc_id,
|
content=chunk_data["content"],
|
||||||
source_name=doc_info.get("title", ""),
|
source_id=source_doc_id,
|
||||||
score=cosine,
|
source_name=doc_info.get("title", ""),
|
||||||
metadata=chunk_data.get("metadata", {}),
|
score=cosine,
|
||||||
doc_id=source_doc_id,
|
metadata=chunk_data.get("metadata", {}),
|
||||||
title=doc_info.get("title", ""),
|
doc_id=source_doc_id,
|
||||||
))
|
title=doc_info.get("title", ""),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
candidates.sort(key=lambda x: x.score, reverse=True)
|
candidates.sort(key=lambda x: x.score, reverse=True)
|
||||||
return candidates[:top_k]
|
return candidates[:top_k]
|
||||||
|
|
@ -488,13 +505,15 @@ class InMemoryLocalRAGService:
|
||||||
"""列出已摄取的文档"""
|
"""列出已摄取的文档"""
|
||||||
sources = []
|
sources = []
|
||||||
for doc_id, doc_info in self._documents.items():
|
for doc_id, doc_info in self._documents.items():
|
||||||
sources.append(SourceInfo(
|
sources.append(
|
||||||
source_id=doc_id,
|
SourceInfo(
|
||||||
source_name=doc_info["title"],
|
source_id=doc_id,
|
||||||
source_type=doc_info.get("format", "local"),
|
source_name=doc_info["title"],
|
||||||
document_count=len(doc_info.get("chunk_ids", [])),
|
source_type=doc_info.get("format", "local"),
|
||||||
last_updated=doc_info.get("created_at"),
|
document_count=len(doc_info.get("chunk_ids", [])),
|
||||||
))
|
last_updated=doc_info.get("created_at"),
|
||||||
|
)
|
||||||
|
)
|
||||||
return sources
|
return sources
|
||||||
|
|
||||||
async def health_check(self) -> bool:
|
async def health_check(self) -> bool:
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,6 @@ import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from agentkit.memory.knowledge_base import KnowledgeBase, QueryResult, SourceInfo
|
from agentkit.memory.knowledge_base import KnowledgeBase, QueryResult, SourceInfo
|
||||||
|
|
||||||
|
|
@ -186,15 +185,13 @@ class MultiSourceRetriever:
|
||||||
Returns:
|
Returns:
|
||||||
所有源的检索结果列表(已应用权重)
|
所有源的检索结果列表(已应用权重)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def _query_one(name: str, kb: KnowledgeBase) -> list[QueryResult]:
|
async def _query_one(name: str, kb: KnowledgeBase) -> list[QueryResult]:
|
||||||
try:
|
try:
|
||||||
results = await kb.query(query, top_k=top_k)
|
results = await kb.query(query, top_k=top_k)
|
||||||
# 应用权重
|
# 应用权重
|
||||||
weight = (weights or {}).get(name, 1.0)
|
weight = (weights or {}).get(name, 1.0)
|
||||||
return [
|
return [replace(r, score=r.score * weight, source_name=name) for r in results]
|
||||||
replace(r, score=r.score * weight, source_name=name)
|
|
||||||
for r in results
|
|
||||||
]
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Query failed for source '{name}': {e}")
|
logger.error(f"Query failed for source '{name}': {e}")
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
|
|
@ -7,10 +7,10 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable
|
from typing import Callable
|
||||||
|
|
||||||
|
|
||||||
class MemoryFile:
|
class MemoryFile:
|
||||||
|
|
@ -26,8 +26,9 @@ class MemoryFile:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, path: Path, char_budget: int | None = None,
|
def __init__(
|
||||||
protected_sections: set[str] | None = None):
|
self, path: Path, char_budget: int | None = None, protected_sections: set[str] | None = None
|
||||||
|
):
|
||||||
self.path = Path(path)
|
self.path = Path(path)
|
||||||
self.char_budget = char_budget
|
self.char_budget = char_budget
|
||||||
self._protected_sections = protected_sections or set()
|
self._protected_sections = protected_sections or set()
|
||||||
|
|
@ -138,7 +139,7 @@ class MemoryFile:
|
||||||
for match in re.finditer(r"^## (.+)$", content, re.MULTILINE):
|
for match in re.finditer(r"^## (.+)$", content, re.MULTILINE):
|
||||||
name = match.group(1).strip()
|
name = match.group(1).strip()
|
||||||
start = match.start()
|
start = match.start()
|
||||||
next_match = re.search(r"^## ", content[match.end():], re.MULTILINE)
|
next_match = re.search(r"^## ", content[match.end() :], re.MULTILINE)
|
||||||
if next_match:
|
if next_match:
|
||||||
end = match.end() + next_match.start()
|
end = match.end() + next_match.start()
|
||||||
else:
|
else:
|
||||||
|
|
@ -146,7 +147,7 @@ class MemoryFile:
|
||||||
sections.append((name, start, end))
|
sections.append((name, start, end))
|
||||||
|
|
||||||
if not sections:
|
if not sections:
|
||||||
return content[:self.char_budget]
|
return content[: self.char_budget]
|
||||||
|
|
||||||
# 保持原始顺序,标记每个 section 是否受保护
|
# 保持原始顺序,标记每个 section 是否受保护
|
||||||
ordered: list[tuple[str, str, bool]] = [] # (name, text, is_protected)
|
ordered: list[tuple[str, str, bool]] = [] # (name, text, is_protected)
|
||||||
|
|
@ -222,8 +223,9 @@ class MemoryStore:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, base_dir: Path | str | None = None,
|
def __init__(
|
||||||
on_change: Callable[[str], None] | None = None):
|
self, base_dir: Path | str | None = None, on_change: Callable[[str], None] | None = None
|
||||||
|
):
|
||||||
if base_dir is None:
|
if base_dir is None:
|
||||||
base_dir = Path.home() / ".agentkit"
|
base_dir = Path.home() / ".agentkit"
|
||||||
self.base_dir = Path(base_dir)
|
self.base_dir = Path(base_dir)
|
||||||
|
|
@ -238,7 +240,9 @@ class MemoryStore:
|
||||||
protected_sections={"版本", "更新历史"},
|
protected_sections={"版本", "更新历史"},
|
||||||
)
|
)
|
||||||
self._user = MemoryFile(self.base_dir / "memories" / "USER.md", char_budget=USER_BUDGET)
|
self._user = MemoryFile(self.base_dir / "memories" / "USER.md", char_budget=USER_BUDGET)
|
||||||
self._memory = MemoryFile(self.base_dir / "memories" / "MEMORY.md", char_budget=MEMORY_BUDGET)
|
self._memory = MemoryFile(
|
||||||
|
self.base_dir / "memories" / "MEMORY.md", char_budget=MEMORY_BUDGET
|
||||||
|
)
|
||||||
self._daily_dir = self.base_dir / "memories" / "daily"
|
self._daily_dir = self.base_dir / "memories" / "daily"
|
||||||
self._daily_dir.mkdir(parents=True, exist_ok=True)
|
self._daily_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
@ -376,4 +380,5 @@ class MemoryStore:
|
||||||
self._on_change(new_prompt)
|
self._on_change(new_prompt)
|
||||||
except Exception:
|
except Exception:
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logging.getLogger(__name__).warning("Memory notify_change failed", exc_info=True)
|
logging.getLogger(__name__).warning("Memory notify_change failed", exc_info=True)
|
||||||
|
|
|
||||||
|
|
@ -87,10 +87,22 @@ class RuleQueryTransformer(QueryTransformerBase):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_FILLER_WORDS_CN: list[str] = [
|
_FILLER_WORDS_CN: list[str] = [
|
||||||
"帮我", "请", "一下", "分析", "看看", "告诉我", "想知道", "请问",
|
"帮我",
|
||||||
|
"请",
|
||||||
|
"一下",
|
||||||
|
"分析",
|
||||||
|
"看看",
|
||||||
|
"告诉我",
|
||||||
|
"想知道",
|
||||||
|
"请问",
|
||||||
]
|
]
|
||||||
_FILLER_WORDS_EN: list[str] = [
|
_FILLER_WORDS_EN: list[str] = [
|
||||||
"please", "can you", "help me", "could you", "i want to", "i need to",
|
"please",
|
||||||
|
"can you",
|
||||||
|
"help me",
|
||||||
|
"could you",
|
||||||
|
"i want to",
|
||||||
|
"i need to",
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -101,9 +113,7 @@ class RuleQueryTransformer(QueryTransformerBase):
|
||||||
self._synonyms = synonyms or {}
|
self._synonyms = synonyms or {}
|
||||||
self._max_sub_queries = max_sub_queries
|
self._max_sub_queries = max_sub_queries
|
||||||
# Pre-compile filler patterns
|
# Pre-compile filler patterns
|
||||||
self._filler_patterns_cn = [
|
self._filler_patterns_cn = [re.compile(re.escape(w)) for w in self._FILLER_WORDS_CN]
|
||||||
re.compile(re.escape(w)) for w in self._FILLER_WORDS_CN
|
|
||||||
]
|
|
||||||
self._filler_patterns_en = [
|
self._filler_patterns_en = [
|
||||||
re.compile(re.escape(w), re.IGNORECASE) for w in self._FILLER_WORDS_EN
|
re.compile(re.escape(w), re.IGNORECASE) for w in self._FILLER_WORDS_EN
|
||||||
]
|
]
|
||||||
|
|
@ -166,7 +176,9 @@ def create_query_transformer(
|
||||||
"""工厂函数:根据策略创建查询改写器"""
|
"""工厂函数:根据策略创建查询改写器"""
|
||||||
if strategy == "llm":
|
if strategy == "llm":
|
||||||
if llm_gateway is None:
|
if llm_gateway is None:
|
||||||
logger.warning("LLM strategy requested but no llm_gateway provided, falling back to NoOp")
|
logger.warning(
|
||||||
|
"LLM strategy requested but no llm_gateway provided, falling back to NoOp"
|
||||||
|
)
|
||||||
return NoOpQueryTransformer()
|
return NoOpQueryTransformer()
|
||||||
return LLMQueryTransformer(llm_gateway, max_sub_queries=max_sub_queries)
|
return LLMQueryTransformer(llm_gateway, max_sub_queries=max_sub_queries)
|
||||||
elif strategy == "rule":
|
elif strategy == "rule":
|
||||||
|
|
|
||||||
|
|
@ -7,11 +7,11 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from agentkit.memory.base import MemoryItem
|
from agentkit.memory.base import MemoryItem, MetadataDict
|
||||||
from agentkit.memory.query_transformer import QueryTransformerBase, NoOpQueryTransformer
|
from agentkit.memory.query_transformer import QueryTransformerBase, NoOpQueryTransformer
|
||||||
from agentkit.memory.relevance_scorer import (
|
from agentkit.memory.relevance_scorer import (
|
||||||
RelevanceScorer,
|
RelevanceScorer,
|
||||||
|
|
@ -19,6 +19,10 @@ from agentkit.memory.relevance_scorer import (
|
||||||
RetrievalEvaluation,
|
RetrievalEvaluation,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
# 避免与 retriever.py 形成运行时循环导入。
|
||||||
|
from agentkit.memory.retriever import MemoryRetriever
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -70,7 +74,7 @@ class RAGSelfCorrectionLoop:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
retriever: Any, # MemoryRetriever
|
retriever: MemoryRetriever,
|
||||||
scorer: RelevanceScorer | None = None,
|
scorer: RelevanceScorer | None = None,
|
||||||
query_transformer: QueryTransformerBase | None = None,
|
query_transformer: QueryTransformerBase | None = None,
|
||||||
max_retries: int = 3,
|
max_retries: int = 3,
|
||||||
|
|
@ -87,7 +91,7 @@ class RAGSelfCorrectionLoop:
|
||||||
query: str,
|
query: str,
|
||||||
top_k: int = 5,
|
top_k: int = 5,
|
||||||
token_budget: int = 3000,
|
token_budget: int = 3000,
|
||||||
filters: dict[str, Any] | None = None,
|
filters: MetadataDict | None = None,
|
||||||
) -> RAGLoopResult:
|
) -> RAGLoopResult:
|
||||||
"""执行带自纠正的检索
|
"""执行带自纠正的检索
|
||||||
|
|
||||||
|
|
@ -107,8 +111,11 @@ class RAGSelfCorrectionLoop:
|
||||||
while retry_count <= self._max_retries:
|
while retry_count <= self._max_retries:
|
||||||
# RETRIEVE
|
# RETRIEVE
|
||||||
items = await self._retriever.retrieve(
|
items = await self._retriever.retrieve(
|
||||||
current_query, top_k=top_k, token_budget=token_budget,
|
current_query,
|
||||||
filters=filters, _skip_correction=True,
|
top_k=top_k,
|
||||||
|
token_budget=token_budget,
|
||||||
|
filters=filters,
|
||||||
|
_skip_correction=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# EVALUATE
|
# EVALUATE
|
||||||
|
|
@ -144,9 +151,7 @@ class RAGSelfCorrectionLoop:
|
||||||
# CORRECT — rewrite query and retry
|
# CORRECT — rewrite query and retry
|
||||||
retry_count += 1
|
retry_count += 1
|
||||||
if retry_count <= self._max_retries:
|
if retry_count <= self._max_retries:
|
||||||
current_query = await self._rewrite_query(
|
current_query = await self._rewrite_query(query, current_query, evaluation)
|
||||||
query, current_query, evaluation
|
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# DEGRADE — exceeded max retries
|
# DEGRADE — exceeded max retries
|
||||||
|
|
@ -154,9 +159,7 @@ class RAGSelfCorrectionLoop:
|
||||||
|
|
||||||
# Degraded result: filter to relevant items and mark low confidence
|
# Degraded result: filter to relevant items and mark low confidence
|
||||||
relevant_items = [
|
relevant_items = [
|
||||||
s.item
|
s.item for s in evaluation.scores if s.verdict != RelevanceVerdict.INCORRECT
|
||||||
for s in evaluation.scores
|
|
||||||
if s.verdict != RelevanceVerdict.INCORRECT
|
|
||||||
]
|
]
|
||||||
result_items = relevant_items if relevant_items else items
|
result_items = relevant_items if relevant_items else items
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,11 +6,9 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from agentkit.memory.base import MemoryItem
|
from agentkit.memory.base import MemoryItem
|
||||||
|
|
||||||
|
|
@ -120,9 +118,7 @@ class RelevanceScorer:
|
||||||
reason=reason,
|
reason=reason,
|
||||||
)
|
)
|
||||||
|
|
||||||
def evaluate(
|
def evaluate(self, query: str, items: list[MemoryItem]) -> RetrievalEvaluation:
|
||||||
self, query: str, items: list[MemoryItem]
|
|
||||||
) -> RetrievalEvaluation:
|
|
||||||
"""评估一次检索的整体质量"""
|
"""评估一次检索的整体质量"""
|
||||||
if not items:
|
if not items:
|
||||||
return RetrievalEvaluation(
|
return RetrievalEvaluation(
|
||||||
|
|
@ -134,9 +130,7 @@ class RelevanceScorer:
|
||||||
)
|
)
|
||||||
|
|
||||||
scores = [self.score_item(query, item) for item in items]
|
scores = [self.score_item(query, item) for item in items]
|
||||||
relevant_count = sum(
|
relevant_count = sum(1 for s in scores if s.verdict != RelevanceVerdict.INCORRECT)
|
||||||
1 for s in scores if s.verdict != RelevanceVerdict.INCORRECT
|
|
||||||
)
|
|
||||||
avg_score = sum(s.score for s in scores) / len(scores)
|
avg_score = sum(s.score for s in scores) / len(scores)
|
||||||
|
|
||||||
# Overall verdict based on average score and relevant ratio
|
# Overall verdict based on average score and relevant ratio
|
||||||
|
|
|
||||||
|
|
@ -7,19 +7,16 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import math
|
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from agentkit.memory.base import Memory, MemoryItem
|
from agentkit.memory.base import MemoryItem, MetadataDict
|
||||||
from agentkit.memory.working import WorkingMemory
|
from agentkit.memory.working import WorkingMemory
|
||||||
from agentkit.memory.episodic import EpisodicMemory
|
from agentkit.memory.episodic import EpisodicMemory
|
||||||
from agentkit.memory.semantic import SemanticMemory
|
from agentkit.memory.semantic import SemanticMemory
|
||||||
from agentkit.memory.query_transformer import QueryTransformerBase
|
from agentkit.memory.query_transformer import QueryTransformerBase
|
||||||
from agentkit.memory.rag_loop import RAGSelfCorrectionLoop
|
from agentkit.memory.rag_loop import RAGSelfCorrectionLoop
|
||||||
from agentkit.memory.relevance_scorer import RelevanceScorer
|
from agentkit.memory.relevance_scorer import RelevanceScorer
|
||||||
from agentkit.memory.knowledge_base import KnowledgeBase, QueryResult
|
from agentkit.memory.knowledge_base import KnowledgeBase
|
||||||
from agentkit.memory.multi_source_retriever import MultiSourceRetriever
|
from agentkit.memory.multi_source_retriever import MultiSourceRetriever
|
||||||
from agentkit.tools.base import Tool
|
from agentkit.tools.base import Tool
|
||||||
|
|
||||||
|
|
@ -32,11 +29,11 @@ def _estimate_tokens(text: str) -> int:
|
||||||
Chinese characters typically use 1-2 tokens each.
|
Chinese characters typically use 1-2 tokens each.
|
||||||
English words typically use 1 token each.
|
English words typically use 1 token each.
|
||||||
"""
|
"""
|
||||||
cjk_count = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
|
cjk_count = sum(1 for c in text if "\u4e00" <= c <= "\u9fff")
|
||||||
non_cjk = text
|
non_cjk = text
|
||||||
for c in text:
|
for c in text:
|
||||||
if '\u4e00' <= c <= '\u9fff':
|
if "\u4e00" <= c <= "\u9fff":
|
||||||
non_cjk = non_cjk.replace(c, ' ')
|
non_cjk = non_cjk.replace(c, " ")
|
||||||
word_count = len(non_cjk.split())
|
word_count = len(non_cjk.split())
|
||||||
return cjk_count * 2 + word_count
|
return cjk_count * 2 + word_count
|
||||||
|
|
||||||
|
|
@ -89,7 +86,7 @@ class MemoryRetriever:
|
||||||
query: str,
|
query: str,
|
||||||
top_k: int = 5,
|
top_k: int = 5,
|
||||||
token_budget: int = 3000,
|
token_budget: int = 3000,
|
||||||
filters: dict[str, Any] | None = None,
|
filters: MetadataDict | None = None,
|
||||||
_skip_correction: bool = False,
|
_skip_correction: bool = False,
|
||||||
sources: list[str] | None = None,
|
sources: list[str] | None = None,
|
||||||
source_weights: dict[str, float] | None = None,
|
source_weights: dict[str, float] | None = None,
|
||||||
|
|
@ -121,9 +118,7 @@ class MemoryRetriever:
|
||||||
query, top_k=top_k, token_budget=token_budget, filters=filters
|
query, top_k=top_k, token_budget=token_budget, filters=filters
|
||||||
)
|
)
|
||||||
if result.degraded:
|
if result.degraded:
|
||||||
logger.warning(
|
logger.warning(f"RAG self-correction degraded after {result.total_retries} retries")
|
||||||
f"RAG self-correction degraded after {result.total_retries} retries"
|
|
||||||
)
|
|
||||||
return result.items
|
return result.items
|
||||||
# Query transformation
|
# Query transformation
|
||||||
if self._query_transformer is not None:
|
if self._query_transformer is not None:
|
||||||
|
|
@ -139,9 +134,7 @@ class MemoryRetriever:
|
||||||
|
|
||||||
# Sub-query search in parallel
|
# Sub-query search in parallel
|
||||||
if sub_queries:
|
if sub_queries:
|
||||||
sub_tasks = [
|
sub_tasks = [self._search_layers(sq, top_k, filters) for sq in sub_queries]
|
||||||
self._search_layers(sq, top_k, filters) for sq in sub_queries
|
|
||||||
]
|
|
||||||
sub_results = await asyncio.gather(*sub_tasks, return_exceptions=True)
|
sub_results = await asyncio.gather(*sub_tasks, return_exceptions=True)
|
||||||
for result in sub_results:
|
for result in sub_results:
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
|
|
@ -178,7 +171,7 @@ class MemoryRetriever:
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
top_k: int = 5,
|
top_k: int = 5,
|
||||||
filters: dict[str, Any] | None = None,
|
filters: MetadataDict | None = None,
|
||||||
) -> list[MemoryItem]:
|
) -> list[MemoryItem]:
|
||||||
"""Search all configured memory layers with a single query"""
|
"""Search all configured memory layers with a single query"""
|
||||||
tasks = []
|
tasks = []
|
||||||
|
|
@ -237,18 +230,20 @@ class MemoryRetriever:
|
||||||
# QueryResult → MemoryItem
|
# QueryResult → MemoryItem
|
||||||
items = []
|
items = []
|
||||||
for r in kb_results:
|
for r in kb_results:
|
||||||
items.append(MemoryItem(
|
items.append(
|
||||||
key=r.source_id,
|
MemoryItem(
|
||||||
value=r.content,
|
key=r.source_id,
|
||||||
metadata={
|
value=r.content,
|
||||||
**r.metadata,
|
metadata={
|
||||||
"source": "rag",
|
**r.metadata,
|
||||||
"source_name": r.source_name,
|
"source": "rag",
|
||||||
"doc_id": r.doc_id,
|
"source_name": r.source_name,
|
||||||
"document_title": r.title,
|
"doc_id": r.doc_id,
|
||||||
},
|
"document_title": r.title,
|
||||||
score=r.score,
|
},
|
||||||
))
|
score=r.score,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Token 预算管理
|
# Token 预算管理
|
||||||
selected = []
|
selected = []
|
||||||
|
|
@ -318,7 +313,9 @@ class MemoryRetriever:
|
||||||
if source == "rag":
|
if source == "rag":
|
||||||
kb_type = item.metadata.get("kb_type", "知识库")
|
kb_type = item.metadata.get("kb_type", "知识库")
|
||||||
document_title = item.metadata.get("document_title", "未知文档")
|
document_title = item.metadata.get("document_title", "未知文档")
|
||||||
return f"### 知识库参考 [来源: {kb_type} | 相关度: {score:.2f} | 文档: {document_title}]"
|
return (
|
||||||
|
f"### 知识库参考 [来源: {kb_type} | 相关度: {score:.2f} | 文档: {document_title}]"
|
||||||
|
)
|
||||||
elif source == "graph":
|
elif source == "graph":
|
||||||
return f"### 知识图谱 [实体: {item.key} | 相关度: {score:.2f}]"
|
return f"### 知识图谱 [实体: {item.key} | 相关度: {score:.2f}]"
|
||||||
elif source == "episodic":
|
elif source == "episodic":
|
||||||
|
|
@ -330,7 +327,7 @@ class MemoryRetriever:
|
||||||
return f"### 参考 [来源: {source} | 相关度: {score:.2f}]"
|
return f"### 参考 [来源: {source} | 相关度: {score:.2f}]"
|
||||||
|
|
||||||
async def store_episode(
|
async def store_episode(
|
||||||
self, key: str, value: Any, metadata: dict[str, Any] | None = None
|
self, key: str, value: object, metadata: MetadataDict | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Store an episode into episodic memory if available.
|
"""Store an episode into episodic memory if available.
|
||||||
|
|
||||||
|
|
@ -386,12 +383,14 @@ class RetrieveKnowledgeTool(Tool):
|
||||||
items = await self._retriever.retrieve(query, top_k=5)
|
items = await self._retriever.retrieve(query, top_k=5)
|
||||||
results = []
|
results = []
|
||||||
for item in items:
|
for item in items:
|
||||||
results.append({
|
results.append(
|
||||||
"content": item.value,
|
{
|
||||||
"score": item.score,
|
"content": item.value,
|
||||||
"source": item.metadata.get("source", "unknown"),
|
"score": item.score,
|
||||||
"document_title": item.metadata.get("document_title", ""),
|
"source": item.metadata.get("source", "unknown"),
|
||||||
})
|
"document_title": item.metadata.get("document_title", ""),
|
||||||
|
}
|
||||||
|
)
|
||||||
return {"query": query, "results": results, "call_count": self._call_count}
|
return {"query": query, "results": results, "call_count": self._call_count}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"error": str(e), "results": []}
|
return {"error": str(e), "results": []}
|
||||||
|
|
|
||||||
|
|
@ -3,14 +3,45 @@
|
||||||
适配器模式,对接外部 RAG 服务和知识图谱。
|
适配器模式,对接外部 RAG 服务和知识图谱。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
from __future__ import annotations
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from agentkit.memory.base import Memory, MemoryItem
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Protocol
|
||||||
|
|
||||||
|
from agentkit.memory.base import Memory, MemoryItem, MetadataDict
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from agentkit.memory.http_rag import RAGSearchResult
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class _RAGServiceLike(Protocol):
|
||||||
|
"""RAG 检索服务最小接口契约(duck-typed)。"""
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
knowledge_base_ids: list[str] | None = ...,
|
||||||
|
top_k: int = ...,
|
||||||
|
) -> list[RAGSearchResult]: ...
|
||||||
|
|
||||||
|
async def enhanced_search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
knowledge_base_ids: list[str] | None = ...,
|
||||||
|
top_k: int = ...,
|
||||||
|
use_rerank: bool = ...,
|
||||||
|
use_compression: bool = ...,
|
||||||
|
) -> list[RAGSearchResult]: ...
|
||||||
|
|
||||||
|
|
||||||
|
class _GraphServiceLike(Protocol):
|
||||||
|
"""知识图谱服务最小接口契约(duck-typed)。"""
|
||||||
|
|
||||||
|
async def query(self, query: str, depth: int = ...) -> list[dict[str, object]]: ...
|
||||||
|
|
||||||
|
|
||||||
class SemanticMemory(Memory):
|
class SemanticMemory(Memory):
|
||||||
"""Semantic Memory - 知识库检索
|
"""Semantic Memory - 知识库检索
|
||||||
|
|
||||||
|
|
@ -19,8 +50,8 @@ class SemanticMemory(Memory):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
rag_service: Any = None,
|
rag_service: _RAGServiceLike | None = None,
|
||||||
graph_service: Any = None,
|
graph_service: _GraphServiceLike | None = None,
|
||||||
knowledge_base_ids: list[str] | None = None,
|
knowledge_base_ids: list[str] | None = None,
|
||||||
search_mode: str = "standard",
|
search_mode: str = "standard",
|
||||||
use_rerank: bool = True,
|
use_rerank: bool = True,
|
||||||
|
|
@ -45,9 +76,9 @@ class SemanticMemory(Memory):
|
||||||
self._use_compression = use_compression
|
self._use_compression = use_compression
|
||||||
self._kb_weights = kb_weights
|
self._kb_weights = kb_weights
|
||||||
|
|
||||||
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
|
async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
|
||||||
"""Semantic Memory 通常只读,写入委托给 RAG 服务的 ingest 方法"""
|
"""Semantic Memory 通常只读,写入委托给 RAG 服务的 ingest 方法"""
|
||||||
if self._rag_service and hasattr(self._rag_service, 'ingest'):
|
if self._rag_service and hasattr(self._rag_service, "ingest"):
|
||||||
await self._rag_service.ingest(key, value, metadata)
|
await self._rag_service.ingest(key, value, metadata)
|
||||||
else:
|
else:
|
||||||
logger.warning("SemanticMemory.store: no RAG service configured for writing")
|
logger.warning("SemanticMemory.store: no RAG service configured for writing")
|
||||||
|
|
@ -56,7 +87,9 @@ class SemanticMemory(Memory):
|
||||||
"""按 key 精确检索(Semantic Memory 通常不按 key 检索)"""
|
"""按 key 精确检索(Semantic Memory 通常不按 key 检索)"""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]:
|
async def search(
|
||||||
|
self, query: str, top_k: int = 5, filters: MetadataDict | None = None
|
||||||
|
) -> list[MemoryItem]:
|
||||||
"""语义检索知识库"""
|
"""语义检索知识库"""
|
||||||
items = []
|
items = []
|
||||||
|
|
||||||
|
|
@ -64,7 +97,9 @@ class SemanticMemory(Memory):
|
||||||
if self._rag_service:
|
if self._rag_service:
|
||||||
try:
|
try:
|
||||||
kb_ids = (filters or {}).get("knowledge_base_ids", self._knowledge_base_ids)
|
kb_ids = (filters or {}).get("knowledge_base_ids", self._knowledge_base_ids)
|
||||||
if self._search_mode == "enhanced" and hasattr(self._rag_service, "enhanced_search"):
|
if self._search_mode == "enhanced" and hasattr(
|
||||||
|
self._rag_service, "enhanced_search"
|
||||||
|
):
|
||||||
results = await self._rag_service.enhanced_search(
|
results = await self._rag_service.enhanced_search(
|
||||||
query,
|
query,
|
||||||
knowledge_base_ids=kb_ids,
|
knowledge_base_ids=kb_ids,
|
||||||
|
|
@ -73,24 +108,28 @@ class SemanticMemory(Memory):
|
||||||
use_compression=self._use_compression,
|
use_compression=self._use_compression,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
results = await self._rag_service.search(query, knowledge_base_ids=kb_ids, top_k=top_k)
|
results = await self._rag_service.search(
|
||||||
|
query, knowledge_base_ids=kb_ids, top_k=top_k
|
||||||
|
)
|
||||||
for r in results:
|
for r in results:
|
||||||
kb_id = r.get("knowledge_base_id", "")
|
kb_id = r.get("knowledge_base_id", "")
|
||||||
score = r.get("score", 0.0)
|
score = r.get("score", 0.0)
|
||||||
# Apply per-KB weights
|
# Apply per-KB weights
|
||||||
if self._kb_weights and kb_id in self._kb_weights:
|
if self._kb_weights and kb_id in self._kb_weights:
|
||||||
score *= self._kb_weights[kb_id]
|
score *= self._kb_weights[kb_id]
|
||||||
items.append(MemoryItem(
|
items.append(
|
||||||
key=r.get("id", ""),
|
MemoryItem(
|
||||||
value=r.get("content", ""),
|
key=r.get("id", ""),
|
||||||
metadata={
|
value=r.get("content", ""),
|
||||||
"source": r.get("source", "rag"),
|
metadata={
|
||||||
"score": score,
|
"source": r.get("source", "rag"),
|
||||||
"document_id": r.get("document_id"),
|
"score": score,
|
||||||
"knowledge_base_id": kb_id,
|
"document_id": r.get("document_id"),
|
||||||
},
|
"knowledge_base_id": kb_id,
|
||||||
score=score,
|
},
|
||||||
))
|
score=score,
|
||||||
|
)
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"RAG search failed: {e}")
|
logger.error(f"RAG search failed: {e}")
|
||||||
|
|
||||||
|
|
@ -99,16 +138,18 @@ class SemanticMemory(Memory):
|
||||||
try:
|
try:
|
||||||
graph_results = await self._graph_service.query(query, depth=2)
|
graph_results = await self._graph_service.query(query, depth=2)
|
||||||
for r in graph_results[:top_k]:
|
for r in graph_results[:top_k]:
|
||||||
items.append(MemoryItem(
|
items.append(
|
||||||
key=r.get("id", ""),
|
MemoryItem(
|
||||||
value=r.get("content", ""),
|
key=r.get("id", ""),
|
||||||
metadata={
|
value=r.get("content", ""),
|
||||||
"source": "graph",
|
metadata={
|
||||||
"entities": r.get("entities", []),
|
"source": "graph",
|
||||||
"relations": r.get("relations", []),
|
"entities": r.get("entities", []),
|
||||||
},
|
"relations": r.get("relations", []),
|
||||||
score=r.get("score", 0.0),
|
},
|
||||||
))
|
score=r.get("score", 0.0),
|
||||||
|
)
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Graph search failed: {e}")
|
logger.error(f"Graph search failed: {e}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,11 +3,10 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import redis.asyncio as aioredis
|
import redis.asyncio as aioredis
|
||||||
|
|
||||||
from agentkit.memory.base import Memory, MemoryItem
|
from agentkit.memory.base import Memory, MemoryItem, MetadataDict
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -32,7 +31,7 @@ class WorkingMemory(Memory):
|
||||||
def _make_key(self, key: str) -> str:
|
def _make_key(self, key: str) -> str:
|
||||||
return f"{self._key_prefix}:{key}"
|
return f"{self._key_prefix}:{key}"
|
||||||
|
|
||||||
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
|
async def store(self, key: str, value: object, metadata: MetadataDict | None = None) -> None:
|
||||||
redis_key = self._make_key(key)
|
redis_key = self._make_key(key)
|
||||||
item = MemoryItem(
|
item = MemoryItem(
|
||||||
key=key,
|
key=key,
|
||||||
|
|
@ -57,10 +56,14 @@ class WorkingMemory(Memory):
|
||||||
value=item_dict["value"],
|
value=item_dict["value"],
|
||||||
metadata=item_dict.get("metadata", {}),
|
metadata=item_dict.get("metadata", {}),
|
||||||
score=item_dict.get("score", 1.0),
|
score=item_dict.get("score", 1.0),
|
||||||
created_at=datetime.fromisoformat(item_dict["created_at"]) if item_dict.get("created_at") else datetime.now(timezone.utc),
|
created_at=datetime.fromisoformat(item_dict["created_at"])
|
||||||
|
if item_dict.get("created_at")
|
||||||
|
else datetime.now(timezone.utc),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]:
|
async def search(
|
||||||
|
self, query: str, top_k: int = 5, filters: MetadataDict | None = None
|
||||||
|
) -> list[MemoryItem]:
|
||||||
"""Working Memory 不支持语义检索,按 key 前缀匹配"""
|
"""Working Memory 不支持语义检索,按 key 前缀匹配"""
|
||||||
pattern = self._make_key(f"{query}*")
|
pattern = self._make_key(f"{query}*")
|
||||||
keys = []
|
keys = []
|
||||||
|
|
@ -74,13 +77,15 @@ class WorkingMemory(Memory):
|
||||||
data = await self._redis.get(key)
|
data = await self._redis.get(key)
|
||||||
if data:
|
if data:
|
||||||
item_dict = json.loads(data)
|
item_dict = json.loads(data)
|
||||||
items.append(MemoryItem(
|
items.append(
|
||||||
key=item_dict["key"],
|
MemoryItem(
|
||||||
value=item_dict["value"],
|
key=item_dict["key"],
|
||||||
metadata=item_dict.get("metadata", {}),
|
value=item_dict["value"],
|
||||||
score=1.0,
|
metadata=item_dict.get("metadata", {}),
|
||||||
created_at=datetime.now(timezone.utc),
|
score=1.0,
|
||||||
))
|
created_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
)
|
||||||
return items
|
return items
|
||||||
|
|
||||||
async def delete(self, key: str) -> bool:
|
async def delete(self, key: str) -> bool:
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ import logging
|
||||||
import time
|
import time
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -25,6 +25,28 @@ _TTL_SECONDS = 7 * 24 * 3600 # 7 days
|
||||||
_KEY_PREFIX = "agentkit:pipeline:checkpoint"
|
_KEY_PREFIX = "agentkit:pipeline:checkpoint"
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
class _RedisPipelineLike(Protocol):
|
||||||
|
def set(self, key: str, value: str, ex: int | None = None) -> object: ...
|
||||||
|
def zadd(self, name: str, mapping: dict[str, float]) -> object: ...
|
||||||
|
def get(self, key: str) -> object: ...
|
||||||
|
def delete(self, *keys: str) -> object: ...
|
||||||
|
async def execute(self) -> list[object]: ...
|
||||||
|
|
||||||
|
class _RedisLike(Protocol):
|
||||||
|
async def set(self, key: str, value: str, ex: int | None = None) -> object: ...
|
||||||
|
async def get(self, key: str) -> object: ...
|
||||||
|
def pipeline(self) -> _RedisPipelineLike: ...
|
||||||
|
async def zrange(self, name: str, start: int, stop: int) -> list[object]: ...
|
||||||
|
|
||||||
|
class _PlanLike(Protocol):
|
||||||
|
@property
|
||||||
|
def id(self) -> str: ...
|
||||||
|
def to_dict(self) -> dict[str, object]: ...
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CheckpointData:
|
class CheckpointData:
|
||||||
"""单个阶段的 checkpoint 数据。"""
|
"""单个阶段的 checkpoint 数据。"""
|
||||||
|
|
@ -33,15 +55,15 @@ class CheckpointData:
|
||||||
phase_id: str
|
phase_id: str
|
||||||
phase_name: str
|
phase_name: str
|
||||||
phase_status: str
|
phase_status: str
|
||||||
phase_result: dict[str, Any] | None = None
|
phase_result: dict[str, object] | None = None
|
||||||
plan_status: str = ""
|
plan_status: str = ""
|
||||||
saved_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
saved_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, object]:
|
||||||
return asdict(self)
|
return asdict(self)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict[str, Any]) -> CheckpointData:
|
def from_dict(cls, data: dict[str, object]) -> CheckpointData:
|
||||||
return cls(
|
return cls(
|
||||||
plan_id=data.get("plan_id", ""),
|
plan_id=data.get("plan_id", ""),
|
||||||
phase_id=data.get("phase_id", ""),
|
phase_id=data.get("phase_id", ""),
|
||||||
|
|
@ -67,7 +89,7 @@ class PipelineCheckpoint:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
redis_client: Any = None,
|
redis_client: _RedisLike | None = None,
|
||||||
prefix: str = _KEY_PREFIX,
|
prefix: str = _KEY_PREFIX,
|
||||||
ttl_seconds: int = _TTL_SECONDS,
|
ttl_seconds: int = _TTL_SECONDS,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
@ -78,7 +100,7 @@ class PipelineCheckpoint:
|
||||||
# P1 #6: 改用 dict keyed by phase_id,避免重复 append
|
# P1 #6: 改用 dict keyed by phase_id,避免重复 append
|
||||||
self._memory: dict[str, dict[str, CheckpointData]] = {}
|
self._memory: dict[str, dict[str, CheckpointData]] = {}
|
||||||
# 内存降级存储:plan_id → (plan_dict, saved_timestamp)
|
# 内存降级存储:plan_id → (plan_dict, saved_timestamp)
|
||||||
self._memory_plans: dict[str, tuple[dict[str, Any], float]] = {}
|
self._memory_plans: dict[str, tuple[dict[str, object], float]] = {}
|
||||||
|
|
||||||
def _is_expired(self, saved_at: str) -> bool:
|
def _is_expired(self, saved_at: str) -> bool:
|
||||||
"""检查 checkpoint 是否已过期(内存模式 TTL)。"""
|
"""检查 checkpoint 是否已过期(内存模式 TTL)。"""
|
||||||
|
|
@ -102,7 +124,7 @@ class PipelineCheckpoint:
|
||||||
"""完整 plan JSON 的存储键。"""
|
"""完整 plan JSON 的存储键。"""
|
||||||
return f"{self._prefix}:plan:{plan_id}"
|
return f"{self._prefix}:plan:{plan_id}"
|
||||||
|
|
||||||
async def save_plan(self, plan: Any) -> None:
|
async def save_plan(self, plan: _PlanLike) -> None:
|
||||||
"""保存完整 TeamPlan(用于 resume 重建)。
|
"""保存完整 TeamPlan(用于 resume 重建)。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -121,7 +143,7 @@ class PipelineCheckpoint:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}")
|
logger.warning(f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}")
|
||||||
|
|
||||||
async def load_plan(self, plan_id: str) -> dict[str, Any] | None:
|
async def load_plan(self, plan_id: str) -> dict[str, object] | None:
|
||||||
"""加载完整 plan JSON。"""
|
"""加载完整 plan JSON。"""
|
||||||
# 优先 Redis
|
# 优先 Redis
|
||||||
if self._redis is not None:
|
if self._redis is not None:
|
||||||
|
|
@ -142,7 +164,7 @@ class PipelineCheckpoint:
|
||||||
return None
|
return None
|
||||||
return plan_dict
|
return plan_dict
|
||||||
|
|
||||||
async def save(self, plan_id: str, phase: Any, plan_status: str) -> None:
|
async def save(self, plan_id: str, phase: object, plan_status: str) -> None:
|
||||||
"""保存阶段 checkpoint。
|
"""保存阶段 checkpoint。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -212,7 +234,8 @@ class PipelineCheckpoint:
|
||||||
if not phase_ids:
|
if not phase_ids:
|
||||||
# Redis 无数据,检查内存(过滤过期)
|
# Redis 无数据,检查内存(过滤过期)
|
||||||
return [
|
return [
|
||||||
c for c in self._memory.get(plan_id, {}).values()
|
c
|
||||||
|
for c in self._memory.get(plan_id, {}).values()
|
||||||
if not self._is_expired(c.saved_at)
|
if not self._is_expired(c.saved_at)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -236,8 +259,7 @@ class PipelineCheckpoint:
|
||||||
|
|
||||||
# 内存降级(过滤过期 checkpoint)
|
# 内存降级(过滤过期 checkpoint)
|
||||||
return [
|
return [
|
||||||
c for c in self._memory.get(plan_id, {}).values()
|
c for c in self._memory.get(plan_id, {}).values() if not self._is_expired(c.saved_at)
|
||||||
if not self._is_expired(c.saved_at)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
async def clear(self, plan_id: str) -> None:
|
async def clear(self, plan_id: str) -> None:
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
"""Saga compensation pattern for Pipeline execution"""
|
"""Saga compensation pattern for Pipeline execution"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from typing import Any, Awaitable, Callable
|
from typing import Awaitable, Callable
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -12,7 +12,7 @@ class CompletedStep:
|
||||||
"""Record of a completed step with its compensation"""
|
"""Record of a completed step with its compensation"""
|
||||||
|
|
||||||
step_name: str
|
step_name: str
|
||||||
result: Any
|
result: object
|
||||||
compensate_action: str | None = None
|
compensate_action: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -28,9 +28,7 @@ class CompensationResult:
|
||||||
class SagaOrchestrator:
|
class SagaOrchestrator:
|
||||||
"""Orchestrates LIFO compensation for failed pipelines"""
|
"""Orchestrates LIFO compensation for failed pipelines"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, execute_skill_func: Callable[..., Awaitable[object]] | None = None):
|
||||||
self, execute_skill_func: Callable[..., Awaitable[Any]] | None = None
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
execute_skill_func: Async function to execute a skill by name
|
execute_skill_func: Async function to execute a skill by name
|
||||||
|
|
@ -42,7 +40,7 @@ class SagaOrchestrator:
|
||||||
def record_completed(
|
def record_completed(
|
||||||
self,
|
self,
|
||||||
step_name: str,
|
step_name: str,
|
||||||
result: Any,
|
result: object,
|
||||||
compensate_action: str | None = None,
|
compensate_action: str | None = None,
|
||||||
):
|
):
|
||||||
"""Record a completed step for potential compensation"""
|
"""Record a completed step for potential compensation"""
|
||||||
|
|
@ -59,9 +57,7 @@ class SagaOrchestrator:
|
||||||
results: list[CompensationResult] = []
|
results: list[CompensationResult] = []
|
||||||
for step in reversed(self._completed_steps):
|
for step in reversed(self._completed_steps):
|
||||||
if step.compensate_action is None:
|
if step.compensate_action is None:
|
||||||
logger.info(
|
logger.info(f"No compensation for step '{step.step_name}', skipping")
|
||||||
f"No compensation for step '{step.step_name}', skipping"
|
|
||||||
)
|
|
||||||
results.append(
|
results.append(
|
||||||
CompensationResult(
|
CompensationResult(
|
||||||
step_name=step.step_name,
|
step_name=step.step_name,
|
||||||
|
|
@ -82,9 +78,7 @@ class SagaOrchestrator:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"Compensation for step '{step.step_name}' failed: {e}")
|
||||||
f"Compensation for step '{step.step_name}' failed: {e}"
|
|
||||||
)
|
|
||||||
results.append(
|
results.append(
|
||||||
CompensationResult(
|
CompensationResult(
|
||||||
step_name=step.step_name,
|
step_name=step.step_name,
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from agentkit.orchestrator.pipeline_engine import PipelineEngine
|
from agentkit.orchestrator.pipeline_engine import PipelineEngine
|
||||||
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineResult, StageStatus
|
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineResult, StageStatus
|
||||||
|
|
@ -15,7 +14,7 @@ logger = logging.getLogger(__name__)
|
||||||
class DynamicPipeline:
|
class DynamicPipeline:
|
||||||
"""动态 Pipeline 组合器"""
|
"""动态 Pipeline 组合器"""
|
||||||
|
|
||||||
def __init__(self, engine: PipelineEngine, loader: Any = None):
|
def __init__(self, engine: PipelineEngine, loader: object | None = None):
|
||||||
self._engine = engine
|
self._engine = engine
|
||||||
self._loader = loader
|
self._loader = loader
|
||||||
|
|
||||||
|
|
@ -23,7 +22,7 @@ class DynamicPipeline:
|
||||||
self,
|
self,
|
||||||
pipelines: dict[str, Pipeline],
|
pipelines: dict[str, Pipeline],
|
||||||
condition_key: str,
|
condition_key: str,
|
||||||
context: dict[str, Any] | None = None,
|
context: dict[str, object] | None = None,
|
||||||
) -> PipelineResult:
|
) -> PipelineResult:
|
||||||
"""根据条件选择子 Pipeline 执行"""
|
"""根据条件选择子 Pipeline 执行"""
|
||||||
context = context or {}
|
context = context or {}
|
||||||
|
|
@ -37,14 +36,16 @@ class DynamicPipeline:
|
||||||
)
|
)
|
||||||
|
|
||||||
selected = pipelines[condition_value]
|
selected = pipelines[condition_value]
|
||||||
logger.info(f"DynamicPipeline selected '{selected.name}' for {condition_key}={condition_value}")
|
logger.info(
|
||||||
|
f"DynamicPipeline selected '{selected.name}' for {condition_key}={condition_value}"
|
||||||
|
)
|
||||||
return await self._engine.execute(selected, context)
|
return await self._engine.execute(selected, context)
|
||||||
|
|
||||||
async def execute_nested(
|
async def execute_nested(
|
||||||
self,
|
self,
|
||||||
parent: Pipeline,
|
parent: Pipeline,
|
||||||
sub_pipeline_map: dict[str, Pipeline],
|
sub_pipeline_map: dict[str, Pipeline],
|
||||||
context: dict[str, Any] | None = None,
|
context: dict[str, object] | None = None,
|
||||||
) -> PipelineResult:
|
) -> PipelineResult:
|
||||||
"""执行嵌套 Pipeline"""
|
"""执行嵌套 Pipeline"""
|
||||||
# 先执行父 Pipeline
|
# 先执行父 Pipeline
|
||||||
|
|
@ -52,7 +53,7 @@ class DynamicPipeline:
|
||||||
|
|
||||||
# 根据父 Pipeline 结果选择子 Pipeline
|
# 根据父 Pipeline 结果选择子 Pipeline
|
||||||
for stage_name, stage_result in parent_result.stage_results.items():
|
for stage_name, stage_result in parent_result.stage_results.items():
|
||||||
if hasattr(stage_result, 'output_data') and stage_result.output_data:
|
if hasattr(stage_result, "output_data") and stage_result.output_data:
|
||||||
sub_pipeline_name = stage_result.output_data.get("sub_pipeline")
|
sub_pipeline_name = stage_result.output_data.get("sub_pipeline")
|
||||||
if sub_pipeline_name and sub_pipeline_name in sub_pipeline_map:
|
if sub_pipeline_name and sub_pipeline_name in sub_pipeline_map:
|
||||||
sub = sub_pipeline_map[sub_pipeline_name]
|
sub = sub_pipeline_map[sub_pipeline_name]
|
||||||
|
|
@ -66,7 +67,7 @@ class DynamicPipeline:
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
max_iterations: int = 5,
|
max_iterations: int = 5,
|
||||||
exit_condition: str = "done",
|
exit_condition: str = "done",
|
||||||
context: dict[str, Any] | None = None,
|
context: dict[str, object] | None = None,
|
||||||
) -> PipelineResult:
|
) -> PipelineResult:
|
||||||
"""循环执行 Pipeline 直到条件满足"""
|
"""循环执行 Pipeline 直到条件满足"""
|
||||||
current_context = context or {}
|
current_context = context or {}
|
||||||
|
|
|
||||||
|
|
@ -3,25 +3,42 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Awaitable, Callable, Protocol
|
||||||
|
|
||||||
from agentkit.core.protocol import HandoffMessage
|
from agentkit.core.protocol import HandoffMessage
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class _RedisPubSubLike(Protocol):
|
||||||
|
"""Structural type for Redis pubsub object."""
|
||||||
|
|
||||||
|
async def subscribe(self, channel: str) -> None: ...
|
||||||
|
async def unsubscribe(self, channel: str) -> None: ...
|
||||||
|
def listen(self) -> object: ...
|
||||||
|
|
||||||
|
|
||||||
|
class _RedisLike(Protocol):
|
||||||
|
"""Structural type for async Redis client."""
|
||||||
|
|
||||||
|
async def publish(self, channel: str, message: str) -> int: ...
|
||||||
|
def pubsub(self) -> _RedisPubSubLike: ...
|
||||||
|
|
||||||
|
|
||||||
class HandoffManager:
|
class HandoffManager:
|
||||||
"""Handoff 管理器
|
"""Handoff 管理器
|
||||||
|
|
||||||
通过 Redis Pub/Sub 管理 Agent 间的任务转交。
|
通过 Redis Pub/Sub 管理 Agent 间的任务转交。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, redis: Any = None, dispatcher: Any = None):
|
def __init__(self, redis: _RedisLike | None = None, dispatcher: object | None = None):
|
||||||
self._redis = redis
|
self._redis = redis
|
||||||
self._dispatcher = dispatcher
|
self._dispatcher = dispatcher
|
||||||
self._handlers: dict[str, list[Any]] = {}
|
self._handlers: dict[str, list[Callable[[HandoffMessage], Awaitable[None]]]] = {}
|
||||||
|
|
||||||
def register_handler(self, agent_name: str, handler: Any) -> None:
|
def register_handler(
|
||||||
|
self, agent_name: str, handler: Callable[[HandoffMessage], Awaitable[None]]
|
||||||
|
) -> None:
|
||||||
"""注册 Handoff 处理器"""
|
"""注册 Handoff 处理器"""
|
||||||
if agent_name not in self._handlers:
|
if agent_name not in self._handlers:
|
||||||
self._handlers[agent_name] = []
|
self._handlers[agent_name] = []
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,12 @@
|
||||||
"""Pipeline Engine - DAG + 并行执行 + 步骤重试 + Saga 补偿"""
|
"""Pipeline Engine - DAG + 并行执行 + 步骤重试 + Saga 补偿"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from agentkit.orchestrator.compensation import SagaOrchestrator
|
from agentkit.orchestrator.compensation import SagaOrchestrator
|
||||||
from agentkit.orchestrator.pipeline_schema import (
|
from agentkit.orchestrator.pipeline_schema import (
|
||||||
|
|
@ -25,6 +27,23 @@ from agentkit.orchestrator.retry import execute_with_retry
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
class _DispatcherLike(Protocol):
|
||||||
|
async def dispatch(self, task: object) -> None: ...
|
||||||
|
async def get_task_status(self, task_id: str) -> dict[str, object]: ...
|
||||||
|
|
||||||
|
class _StateManagerLike(Protocol):
|
||||||
|
async def create_execution(self, **kwargs: object) -> str: ...
|
||||||
|
async def update_step(self, **kwargs: object) -> None: ...
|
||||||
|
async def complete_execution(self, **kwargs: object) -> None: ...
|
||||||
|
async def fail_execution(self, **kwargs: object) -> None: ...
|
||||||
|
|
||||||
|
class _LLMGatewayLike(Protocol):
|
||||||
|
async def chat(self, **kwargs: object) -> object: ...
|
||||||
|
|
||||||
|
|
||||||
class PipelineEngine:
|
class PipelineEngine:
|
||||||
"""Pipeline 执行引擎
|
"""Pipeline 执行引擎
|
||||||
|
|
||||||
|
|
@ -38,7 +57,12 @@ class PipelineEngine:
|
||||||
- 状态持久化(可选)
|
- 状态持久化(可选)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dispatcher: Any = None, state_manager: Any = None, llm_gateway: Any = None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
dispatcher: _DispatcherLike | None = None,
|
||||||
|
state_manager: _StateManagerLike | None = None,
|
||||||
|
llm_gateway: _LLMGatewayLike | None = None,
|
||||||
|
):
|
||||||
self._dispatcher = dispatcher
|
self._dispatcher = dispatcher
|
||||||
self._state_manager = state_manager
|
self._state_manager = state_manager
|
||||||
self._llm_gateway = llm_gateway
|
self._llm_gateway = llm_gateway
|
||||||
|
|
@ -46,7 +70,7 @@ class PipelineEngine:
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self,
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
context: dict[str, Any] | None = None,
|
context: dict[str, object] | None = None,
|
||||||
adaptive_config: AdaptiveConfig | None = None,
|
adaptive_config: AdaptiveConfig | None = None,
|
||||||
) -> PipelineResult:
|
) -> PipelineResult:
|
||||||
"""执行 Pipeline
|
"""执行 Pipeline
|
||||||
|
|
@ -68,7 +92,7 @@ class PipelineEngine:
|
||||||
async def _adaptive_loop(
|
async def _adaptive_loop(
|
||||||
self,
|
self,
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
context: dict[str, Any] | None,
|
context: dict[str, object] | None,
|
||||||
failed_result: PipelineResult,
|
failed_result: PipelineResult,
|
||||||
adaptive_config: AdaptiveConfig,
|
adaptive_config: AdaptiveConfig,
|
||||||
) -> PipelineResult:
|
) -> PipelineResult:
|
||||||
|
|
@ -92,34 +116,30 @@ class PipelineEngine:
|
||||||
|
|
||||||
# Replan
|
# Replan
|
||||||
new_pipeline = await replanner.replan(current_pipeline, current_result, report)
|
new_pipeline = await replanner.replan(current_pipeline, current_result, report)
|
||||||
logger.info(f"Pipeline replanned: {new_pipeline.name} ({len(new_pipeline.stages)} stages)")
|
logger.info(
|
||||||
|
f"Pipeline replanned: {new_pipeline.name} ({len(new_pipeline.stages)} stages)"
|
||||||
|
)
|
||||||
|
|
||||||
# Re-execute
|
# Re-execute
|
||||||
current_result = await self._execute_pipeline(new_pipeline, context)
|
current_result = await self._execute_pipeline(new_pipeline, context)
|
||||||
current_pipeline = new_pipeline
|
current_pipeline = new_pipeline
|
||||||
|
|
||||||
# Record reflection in metadata
|
# Record reflection in metadata
|
||||||
current_result.metadata["reflections"] = [
|
current_result.metadata["reflections"] = [r.model_dump() for r in reflections]
|
||||||
r.model_dump() for r in reflections
|
|
||||||
]
|
|
||||||
|
|
||||||
if current_result.status == StageStatus.COMPLETED:
|
if current_result.status == StageStatus.COMPLETED:
|
||||||
logger.info(f"Pipeline succeeded after {reflection_num} reflection(s)")
|
logger.info(f"Pipeline succeeded after {reflection_num} reflection(s)")
|
||||||
return current_result
|
return current_result
|
||||||
|
|
||||||
# Exhausted reflections
|
# Exhausted reflections
|
||||||
logger.warning(
|
logger.warning(f"Pipeline failed after {adaptive_config.max_reflections} reflection(s)")
|
||||||
f"Pipeline failed after {adaptive_config.max_reflections} reflection(s)"
|
current_result.metadata["reflections"] = [r.model_dump() for r in reflections]
|
||||||
)
|
|
||||||
current_result.metadata["reflections"] = [
|
|
||||||
r.model_dump() for r in reflections
|
|
||||||
]
|
|
||||||
return current_result
|
return current_result
|
||||||
|
|
||||||
async def _execute_pipeline(
|
async def _execute_pipeline(
|
||||||
self,
|
self,
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
context: dict[str, Any] | None = None,
|
context: dict[str, object] | None = None,
|
||||||
) -> PipelineResult:
|
) -> PipelineResult:
|
||||||
"""执行 Pipeline 的核心逻辑(不含反思-重规划)。"""
|
"""执行 Pipeline 的核心逻辑(不含反思-重规划)。"""
|
||||||
result = PipelineResult(pipeline_name=pipeline.name)
|
result = PipelineResult(pipeline_name=pipeline.name)
|
||||||
|
|
@ -151,7 +171,9 @@ class PipelineEngine:
|
||||||
|
|
||||||
# 逐层执行
|
# 逐层执行
|
||||||
for level, stages in enumerate(level_groups):
|
for level, stages in enumerate(level_groups):
|
||||||
logger.info(f"Pipeline '{pipeline.name}' executing level {level} with {len(stages)} stage(s)")
|
logger.info(
|
||||||
|
f"Pipeline '{pipeline.name}' executing level {level} with {len(stages)} stage(s)"
|
||||||
|
)
|
||||||
|
|
||||||
# 并行执行同层 stages
|
# 并行执行同层 stages
|
||||||
tasks = []
|
tasks = []
|
||||||
|
|
@ -173,9 +195,11 @@ class PipelineEngine:
|
||||||
# Update step state
|
# Update step state
|
||||||
if self._state_manager is not None and execution_id is not None:
|
if self._state_manager is not None and execution_id is not None:
|
||||||
try:
|
try:
|
||||||
step_status = "completed" if sr.status == StageStatus.COMPLETED else sr.status.value
|
step_status = (
|
||||||
step_output = sr.output_data if hasattr(sr, 'output_data') else None
|
"completed" if sr.status == StageStatus.COMPLETED else sr.status.value
|
||||||
step_error = sr.error_message if hasattr(sr, 'error_message') else None
|
)
|
||||||
|
step_output = sr.output_data if hasattr(sr, "output_data") else None
|
||||||
|
step_error = sr.error_message if hasattr(sr, "error_message") else None
|
||||||
await self._state_manager.update_step(
|
await self._state_manager.update_step(
|
||||||
execution_id=execution_id,
|
execution_id=execution_id,
|
||||||
step_name=stage.name,
|
step_name=stage.name,
|
||||||
|
|
@ -189,19 +213,21 @@ class PipelineEngine:
|
||||||
# 收集输出变量
|
# 收集输出变量
|
||||||
if sr.output_data and isinstance(sr, dict):
|
if sr.output_data and isinstance(sr, dict):
|
||||||
pass
|
pass
|
||||||
elif hasattr(sr, 'output_data') and sr.output_data:
|
elif hasattr(sr, "output_data") and sr.output_data:
|
||||||
for output_key in stage.outputs:
|
for output_key in stage.outputs:
|
||||||
if output_key in sr.output_data:
|
if output_key in sr.output_data:
|
||||||
result.variables[output_key] = sr.output_data[output_key]
|
result.variables[output_key] = sr.output_data[output_key]
|
||||||
|
|
||||||
# 检查是否需要中止
|
# 检查是否需要中止
|
||||||
if hasattr(sr, 'status') and sr.status == StageStatus.FAILED:
|
if hasattr(sr, "status") and sr.status == StageStatus.FAILED:
|
||||||
if not stage.continue_on_failure:
|
if not stage.continue_on_failure:
|
||||||
# Execute Saga compensation for completed steps
|
# Execute Saga compensation for completed steps
|
||||||
compensation_results = await saga.compensate()
|
compensation_results = await saga.compensate()
|
||||||
if compensation_results:
|
if compensation_results:
|
||||||
failed_compensations = [
|
failed_compensations = [
|
||||||
cr for cr in compensation_results if not cr.success and cr.error != "no_compensation_needed"
|
cr
|
||||||
|
for cr in compensation_results
|
||||||
|
if not cr.success and cr.error != "no_compensation_needed"
|
||||||
]
|
]
|
||||||
if failed_compensations:
|
if failed_compensations:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
@ -219,7 +245,12 @@ class PipelineEngine:
|
||||||
step_name=stage.name,
|
step_name=stage.name,
|
||||||
error=result.error_message,
|
error=result.error_message,
|
||||||
)
|
)
|
||||||
except (asyncio.TimeoutError, ConnectionError, RuntimeError, ValueError) as exc:
|
except (
|
||||||
|
asyncio.TimeoutError,
|
||||||
|
ConnectionError,
|
||||||
|
RuntimeError,
|
||||||
|
ValueError,
|
||||||
|
) as exc:
|
||||||
logger.warning(f"Failed to persist failure state: {exc}")
|
logger.warning(f"Failed to persist failure state: {exc}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
@ -252,7 +283,9 @@ class PipelineEngine:
|
||||||
started_at = datetime.now(timezone.utc).isoformat()
|
started_at = datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
# 条件检查
|
# 条件检查
|
||||||
if stage.condition and not self._evaluate_condition(stage.condition, pipeline_result.variables):
|
if stage.condition and not self._evaluate_condition(
|
||||||
|
stage.condition, pipeline_result.variables
|
||||||
|
):
|
||||||
return StageResult(
|
return StageResult(
|
||||||
stage_name=stage.name,
|
stage_name=stage.name,
|
||||||
status=StageStatus.SKIPPED,
|
status=StageStatus.SKIPPED,
|
||||||
|
|
@ -312,7 +345,9 @@ class PipelineEngine:
|
||||||
if status["status"] in ("completed", "failed", "cancelled"):
|
if status["status"] in ("completed", "failed", "cancelled"):
|
||||||
return StageResult(
|
return StageResult(
|
||||||
stage_name=stage.name,
|
stage_name=stage.name,
|
||||||
status=StageStatus.COMPLETED if status["status"] == "completed" else StageStatus.FAILED,
|
status=StageStatus.COMPLETED
|
||||||
|
if status["status"] == "completed"
|
||||||
|
else StageStatus.FAILED,
|
||||||
output_data=status.get("output_data"),
|
output_data=status.get("output_data"),
|
||||||
error_message=status.get("error_message"),
|
error_message=status.get("error_message"),
|
||||||
started_at=started_at,
|
started_at=started_at,
|
||||||
|
|
@ -406,7 +441,7 @@ class PipelineEngine:
|
||||||
return resolved
|
return resolved
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_nested(data: dict, path: str) -> Any:
|
def _get_nested(data: dict, path: str) -> object:
|
||||||
keys = path.split(".")
|
keys = path.split(".")
|
||||||
current = data
|
current = data
|
||||||
for key in keys:
|
for key in keys:
|
||||||
|
|
@ -497,9 +532,7 @@ class PipelineEngine:
|
||||||
|
|
||||||
if verifier_feedback.passed:
|
if verifier_feedback.passed:
|
||||||
# 审查通过,返回成功结果
|
# 审查通过,返回成功结果
|
||||||
logger.info(
|
logger.info(f"Stage '{stage.name}' passed review in round {round_num}")
|
||||||
f"Stage '{stage.name}' passed review in round {round_num}"
|
|
||||||
)
|
|
||||||
worker_result.output_data = worker_result.output_data or {}
|
worker_result.output_data = worker_result.output_data or {}
|
||||||
worker_result.output_data["adversarial_metadata"] = {
|
worker_result.output_data["adversarial_metadata"] = {
|
||||||
"passed_round": round_num,
|
"passed_round": round_num,
|
||||||
|
|
@ -553,7 +586,7 @@ class PipelineEngine:
|
||||||
self,
|
self,
|
||||||
agent_name: str,
|
agent_name: str,
|
||||||
action: str,
|
action: str,
|
||||||
input_data: dict[str, Any],
|
input_data: dict[str, object],
|
||||||
stage: PipelineStage,
|
stage: PipelineStage,
|
||||||
started_at: str,
|
started_at: str,
|
||||||
timeout_seconds: int | None = None,
|
timeout_seconds: int | None = None,
|
||||||
|
|
@ -568,7 +601,9 @@ class PipelineEngine:
|
||||||
started_at: 开始时间
|
started_at: 开始时间
|
||||||
timeout_seconds: 独立超时时间,不传则使用 stage.timeout_seconds
|
timeout_seconds: 独立超时时间,不传则使用 stage.timeout_seconds
|
||||||
"""
|
"""
|
||||||
effective_timeout = timeout_seconds if timeout_seconds is not None else stage.timeout_seconds
|
effective_timeout = (
|
||||||
|
timeout_seconds if timeout_seconds is not None else stage.timeout_seconds
|
||||||
|
)
|
||||||
if self._dispatcher is None:
|
if self._dispatcher is None:
|
||||||
# Dry-run 模式
|
# Dry-run 模式
|
||||||
return StageResult(
|
return StageResult(
|
||||||
|
|
@ -602,7 +637,9 @@ class PipelineEngine:
|
||||||
if status["status"] in ("completed", "failed", "cancelled"):
|
if status["status"] in ("completed", "failed", "cancelled"):
|
||||||
return StageResult(
|
return StageResult(
|
||||||
stage_name=stage.name,
|
stage_name=stage.name,
|
||||||
status=StageStatus.COMPLETED if status["status"] == "completed" else StageStatus.FAILED,
|
status=StageStatus.COMPLETED
|
||||||
|
if status["status"] == "completed"
|
||||||
|
else StageStatus.FAILED,
|
||||||
output_data=status.get("output_data"),
|
output_data=status.get("output_data"),
|
||||||
error_message=status.get("error_message"),
|
error_message=status.get("error_message"),
|
||||||
started_at=started_at,
|
started_at=started_at,
|
||||||
|
|
@ -639,7 +676,7 @@ class PipelineEngine:
|
||||||
async def _execute_verifier(
|
async def _execute_verifier(
|
||||||
self,
|
self,
|
||||||
verifier_name: str,
|
verifier_name: str,
|
||||||
worker_output: dict[str, Any],
|
worker_output: dict[str, object],
|
||||||
stage: PipelineStage,
|
stage: PipelineStage,
|
||||||
started_at: str,
|
started_at: str,
|
||||||
) -> ReviewFeedback:
|
) -> ReviewFeedback:
|
||||||
|
|
@ -679,10 +716,7 @@ class PipelineEngine:
|
||||||
try:
|
try:
|
||||||
feedback = ReviewFeedback(
|
feedback = ReviewFeedback(
|
||||||
passed=output_data.get("passed", False),
|
passed=output_data.get("passed", False),
|
||||||
issues=[
|
issues=[ReviewIssue(**issue) for issue in output_data.get("issues", [])],
|
||||||
ReviewIssue(**issue)
|
|
||||||
for issue in output_data.get("issues", [])
|
|
||||||
],
|
|
||||||
summary=output_data.get("summary", "No summary provided"),
|
summary=output_data.get("summary", "No summary provided"),
|
||||||
score=output_data.get("score", 0.0),
|
score=output_data.get("score", 0.0),
|
||||||
)
|
)
|
||||||
|
|
@ -699,7 +733,7 @@ class PipelineEngine:
|
||||||
self,
|
self,
|
||||||
feedback: ReviewFeedback,
|
feedback: ReviewFeedback,
|
||||||
feedback_mode: str = "structured+natural",
|
feedback_mode: str = "structured+natural",
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, object]:
|
||||||
"""构建反馈上下文,让 Worker Agent 理解审查反馈并定向修复
|
"""构建反馈上下文,让 Worker Agent 理解审查反馈并定向修复
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -720,7 +754,7 @@ class PipelineEngine:
|
||||||
for issue in feedback.issues
|
for issue in feedback.issues
|
||||||
]
|
]
|
||||||
|
|
||||||
feedback_context: dict[str, Any] = {
|
feedback_context: dict[str, object] = {
|
||||||
"previous_attempt_failed": True,
|
"previous_attempt_failed": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -756,7 +790,9 @@ class PipelineEngine:
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# 未知模式,fallback 到 structured+natural
|
# 未知模式,fallback 到 structured+natural
|
||||||
logger.warning(f"Unknown feedback_mode '{feedback_mode}', falling back to structured+natural")
|
logger.warning(
|
||||||
|
f"Unknown feedback_mode '{feedback_mode}', falling back to structured+natural"
|
||||||
|
)
|
||||||
feedback_context["review_feedback"] = {
|
feedback_context["review_feedback"] = {
|
||||||
"summary": feedback.summary,
|
"summary": feedback.summary,
|
||||||
"issues": issues_list,
|
"issues": issues_list,
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
@ -23,7 +22,9 @@ class PipelineLoader:
|
||||||
if not yaml_path.exists():
|
if not yaml_path.exists():
|
||||||
yaml_path = self._pipelines_dir / f"{pipeline_name}.yml"
|
yaml_path = self._pipelines_dir / f"{pipeline_name}.yml"
|
||||||
if not yaml_path.exists():
|
if not yaml_path.exists():
|
||||||
raise FileNotFoundError(f"Pipeline '{pipeline_name}' not found in {self._pipelines_dir}")
|
raise FileNotFoundError(
|
||||||
|
f"Pipeline '{pipeline_name}' not found in {self._pipelines_dir}"
|
||||||
|
)
|
||||||
|
|
||||||
content = yaml_path.read_text(encoding="utf-8")
|
content = yaml_path.read_text(encoding="utf-8")
|
||||||
return self.load_from_yaml(content, pipeline_name)
|
return self.load_from_yaml(content, pipeline_name)
|
||||||
|
|
|
||||||
|
|
@ -35,9 +35,7 @@ class PipelineExecutionModel(Base):
|
||||||
)
|
)
|
||||||
completed_at = Column(DateTime)
|
completed_at = Column(DateTime)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (Index("ix_pipeline_status_created", "status", "created_at"),)
|
||||||
Index("ix_pipeline_status_created", "status", "created_at"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PipelineStepHistoryModel(Base):
|
class PipelineStepHistoryModel(Base):
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
"""Pipeline 数据模型"""
|
"""Pipeline 数据模型"""
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Literal
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
@ -18,8 +18,11 @@ class StageStatus(str, Enum):
|
||||||
|
|
||||||
class ReviewIssue(BaseModel):
|
class ReviewIssue(BaseModel):
|
||||||
"""单条审查问题"""
|
"""单条审查问题"""
|
||||||
|
|
||||||
severity: Literal["critical", "major", "minor"] = Field(description="问题严重程度")
|
severity: Literal["critical", "major", "minor"] = Field(description="问题严重程度")
|
||||||
category: Literal["logic_error", "security", "style", "test_failure", "architecture"] = Field(description="问题类别")
|
category: Literal["logic_error", "security", "style", "test_failure", "architecture"] = Field(
|
||||||
|
description="问题类别"
|
||||||
|
)
|
||||||
description: str = Field(min_length=1, description="问题描述")
|
description: str = Field(min_length=1, description="问题描述")
|
||||||
location: str | None = Field(default=None, description="文件路径/行号")
|
location: str | None = Field(default=None, description="文件路径/行号")
|
||||||
suggestion: str | None = Field(default=None, description="修复建议")
|
suggestion: str | None = Field(default=None, description="修复建议")
|
||||||
|
|
@ -27,6 +30,7 @@ class ReviewIssue(BaseModel):
|
||||||
|
|
||||||
class ReviewFeedback(BaseModel):
|
class ReviewFeedback(BaseModel):
|
||||||
"""Verifier 返回的结构化审查反馈"""
|
"""Verifier 返回的结构化审查反馈"""
|
||||||
|
|
||||||
passed: bool = Field(description="是否通过审查")
|
passed: bool = Field(description="是否通过审查")
|
||||||
issues: list[ReviewIssue] = Field(default_factory=list, description="问题列表")
|
issues: list[ReviewIssue] = Field(default_factory=list, description="问题列表")
|
||||||
summary: str = Field(min_length=1, description="自然语言审查报告")
|
summary: str = Field(min_length=1, description="自然语言审查报告")
|
||||||
|
|
@ -35,6 +39,7 @@ class ReviewFeedback(BaseModel):
|
||||||
|
|
||||||
class AdversarialState(BaseModel):
|
class AdversarialState(BaseModel):
|
||||||
"""对抗轮次状态追踪"""
|
"""对抗轮次状态追踪"""
|
||||||
|
|
||||||
current_round: int = Field(default=0, description="当前对抗轮次")
|
current_round: int = Field(default=0, description="当前对抗轮次")
|
||||||
max_rounds: int = Field(default=3, description="最大对抗轮次")
|
max_rounds: int = Field(default=3, description="最大对抗轮次")
|
||||||
feedback_history: list[ReviewFeedback] = Field(default_factory=list, description="反馈历史")
|
feedback_history: list[ReviewFeedback] = Field(default_factory=list, description="反馈历史")
|
||||||
|
|
@ -46,7 +51,7 @@ class PipelineStage(BaseModel):
|
||||||
agent: str
|
agent: str
|
||||||
action: str
|
action: str
|
||||||
depends_on: list[str] = []
|
depends_on: list[str] = []
|
||||||
inputs: dict[str, Any] = {}
|
inputs: dict[str, object] = {}
|
||||||
outputs: list[str] = []
|
outputs: list[str] = []
|
||||||
timeout_seconds: int = 300
|
timeout_seconds: int = 300
|
||||||
retry_count: int = 0
|
retry_count: int = 0
|
||||||
|
|
@ -54,12 +59,19 @@ class PipelineStage(BaseModel):
|
||||||
condition: str | None = None
|
condition: str | None = None
|
||||||
retry_policy: StepRetryPolicy | None = None
|
retry_policy: StepRetryPolicy | None = None
|
||||||
compensate: str | None = None
|
compensate: str | None = None
|
||||||
|
|
||||||
# 对抗闭环相关字段
|
# 对抗闭环相关字段
|
||||||
verifier: str | None = Field(default=None, description="Verifier Agent 名称,配置后启用对抗模式")
|
verifier: str | None = Field(
|
||||||
|
default=None, description="Verifier Agent 名称,配置后启用对抗模式"
|
||||||
|
)
|
||||||
max_adversarial_rounds: int = Field(default=3, description="最大对抗轮次")
|
max_adversarial_rounds: int = Field(default=3, description="最大对抗轮次")
|
||||||
verifier_timeout_seconds: int = Field(default=120, description="Verifier Agent 独立超时时间(秒),避免与 Worker 共享 timeout_seconds")
|
verifier_timeout_seconds: int = Field(
|
||||||
feedback_mode: Literal["structured+natural", "structured", "natural"] = Field(default="structured+natural", description="反馈模式")
|
default=120,
|
||||||
|
description="Verifier Agent 独立超时时间(秒),避免与 Worker 共享 timeout_seconds",
|
||||||
|
)
|
||||||
|
feedback_mode: Literal["structured+natural", "structured", "natural"] = Field(
|
||||||
|
default="structured+natural", description="反馈模式"
|
||||||
|
)
|
||||||
escalate_on_exhaust: str | None = Field(default=None, description="对抗轮次耗尽后的升级目标")
|
escalate_on_exhaust: str | None = Field(default=None, description="对抗轮次耗尽后的升级目标")
|
||||||
|
|
||||||
model_config = {"arbitrary_types_allowed": True}
|
model_config = {"arbitrary_types_allowed": True}
|
||||||
|
|
@ -70,13 +82,13 @@ class Pipeline(BaseModel):
|
||||||
version: str
|
version: str
|
||||||
description: str
|
description: str
|
||||||
stages: list[PipelineStage]
|
stages: list[PipelineStage]
|
||||||
variables: dict[str, Any] = {}
|
variables: dict[str, object] = {}
|
||||||
|
|
||||||
|
|
||||||
class StageResult(BaseModel):
|
class StageResult(BaseModel):
|
||||||
stage_name: str
|
stage_name: str
|
||||||
status: StageStatus = StageStatus.PENDING
|
status: StageStatus = StageStatus.PENDING
|
||||||
output_data: dict[str, Any] | None = None
|
output_data: dict[str, object] | None = None
|
||||||
error_message: str | None = None
|
error_message: str | None = None
|
||||||
started_at: str | None = None
|
started_at: str | None = None
|
||||||
completed_at: str | None = None
|
completed_at: str | None = None
|
||||||
|
|
@ -86,9 +98,9 @@ class PipelineResult(BaseModel):
|
||||||
pipeline_name: str
|
pipeline_name: str
|
||||||
status: StageStatus = StageStatus.PENDING
|
status: StageStatus = StageStatus.PENDING
|
||||||
stage_results: dict[str, StageResult] = {}
|
stage_results: dict[str, StageResult] = {}
|
||||||
variables: dict[str, Any] = {}
|
variables: dict[str, object] = {}
|
||||||
error_message: str | None = None
|
error_message: str | None = None
|
||||||
metadata: dict[str, Any] = {}
|
metadata: dict[str, object] = {}
|
||||||
|
|
||||||
|
|
||||||
class AdaptiveConfig(BaseModel):
|
class AdaptiveConfig(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Callable, Coroutine
|
from typing import Callable, Coroutine
|
||||||
|
|
||||||
from agentkit.orchestrator.pipeline_models import (
|
from agentkit.orchestrator.pipeline_models import (
|
||||||
PipelineExecutionModel,
|
PipelineExecutionModel,
|
||||||
|
|
@ -183,7 +183,7 @@ class PipelineStateRedis:
|
||||||
return self._redis
|
return self._redis
|
||||||
|
|
||||||
async def _safe_redis_call(
|
async def _safe_redis_call(
|
||||||
self, fn: Callable[..., Coroutine[Any, Any, Any]], *args: object, **kwargs: object
|
self, fn: Callable[..., Coroutine[object, object, object]], *args: object, **kwargs: object
|
||||||
) -> object | None:
|
) -> object | None:
|
||||||
"""Execute a Redis call, falling back to memory on failure.
|
"""Execute a Redis call, falling back to memory on failure.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,22 +4,30 @@
|
||||||
生成修正后的 Pipeline 重新执行。
|
生成修正后的 Pipeline 重新执行。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from agentkit.orchestrator.pipeline_schema import (
|
from agentkit.orchestrator.pipeline_schema import (
|
||||||
Pipeline,
|
Pipeline,
|
||||||
PipelineResult,
|
PipelineResult,
|
||||||
PipelineStage,
|
PipelineStage,
|
||||||
ReflectionReport,
|
ReflectionReport,
|
||||||
StageResult,
|
|
||||||
StageStatus,
|
StageStatus,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
class _LLMGatewayLike(Protocol):
|
||||||
|
async def chat(self, **kwargs: object) -> object: ...
|
||||||
|
|
||||||
|
|
||||||
class PipelineReflector:
|
class PipelineReflector:
|
||||||
"""分析 Pipeline 执行失败原因,生成结构化反思报告。
|
"""分析 Pipeline 执行失败原因,生成结构化反思报告。
|
||||||
|
|
||||||
|
|
@ -27,7 +35,7 @@ class PipelineReflector:
|
||||||
输出 ReflectionReport 包含 failure_type、root_cause 和 suggested_fix。
|
输出 ReflectionReport 包含 failure_type、root_cause 和 suggested_fix。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, llm_gateway: Any = None):
|
def __init__(self, llm_gateway: _LLMGatewayLike | None = None):
|
||||||
self._llm_gateway = llm_gateway
|
self._llm_gateway = llm_gateway
|
||||||
|
|
||||||
async def reflect(
|
async def reflect(
|
||||||
|
|
@ -54,19 +62,25 @@ class PipelineReflector:
|
||||||
if self._llm_gateway is not None:
|
if self._llm_gateway is not None:
|
||||||
try:
|
try:
|
||||||
return await self._llm_reflect(
|
return await self._llm_reflect(
|
||||||
pipeline, failed_stage, error_message,
|
pipeline,
|
||||||
completed_outputs, reflection_number,
|
failed_stage,
|
||||||
|
error_message,
|
||||||
|
completed_outputs,
|
||||||
|
reflection_number,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"LLM reflection failed, falling back to rule-based: {e}")
|
logger.warning(f"LLM reflection failed, falling back to rule-based: {e}")
|
||||||
|
|
||||||
# 规则兜底:基于错误信息分类
|
# 规则兜底:基于错误信息分类
|
||||||
return self._rule_based_reflect(
|
return self._rule_based_reflect(
|
||||||
failed_stage, error_message, reflection_number,
|
failed_stage,
|
||||||
|
error_message,
|
||||||
|
reflection_number,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _find_failure(
|
def _find_failure(
|
||||||
self, result: PipelineResult,
|
self,
|
||||||
|
result: PipelineResult,
|
||||||
) -> tuple[str, str]:
|
) -> tuple[str, str]:
|
||||||
"""找到第一个失败的 stage 及其错误信息。"""
|
"""找到第一个失败的 stage 及其错误信息。"""
|
||||||
for name, sr in result.stage_results.items():
|
for name, sr in result.stage_results.items():
|
||||||
|
|
@ -75,8 +89,9 @@ class PipelineReflector:
|
||||||
return "", "no failed stage found"
|
return "", "no failed stage found"
|
||||||
|
|
||||||
def _collect_completed_outputs(
|
def _collect_completed_outputs(
|
||||||
self, result: PipelineResult,
|
self,
|
||||||
) -> dict[str, Any]:
|
result: PipelineResult,
|
||||||
|
) -> dict[str, object]:
|
||||||
"""收集已完成步骤的输出。"""
|
"""收集已完成步骤的输出。"""
|
||||||
outputs = {}
|
outputs = {}
|
||||||
for name, sr in result.stage_results.items():
|
for name, sr in result.stage_results.items():
|
||||||
|
|
@ -89,13 +104,16 @@ class PipelineReflector:
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
failed_stage: str,
|
failed_stage: str,
|
||||||
error_message: str,
|
error_message: str,
|
||||||
completed_outputs: dict[str, Any],
|
completed_outputs: dict[str, object],
|
||||||
reflection_number: int,
|
reflection_number: int,
|
||||||
) -> ReflectionReport:
|
) -> ReflectionReport:
|
||||||
"""使用 LLM 分析失败原因。"""
|
"""使用 LLM 分析失败原因。"""
|
||||||
prompt = self._build_reflection_prompt(
|
prompt = self._build_reflection_prompt(
|
||||||
pipeline, failed_stage, error_message,
|
pipeline,
|
||||||
completed_outputs, reflection_number,
|
failed_stage,
|
||||||
|
error_message,
|
||||||
|
completed_outputs,
|
||||||
|
reflection_number,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await self._llm_gateway.chat(
|
response = await self._llm_gateway.chat(
|
||||||
|
|
@ -106,7 +124,9 @@ class PipelineReflector:
|
||||||
# 解析 LLM 返回的 JSON
|
# 解析 LLM 返回的 JSON
|
||||||
content = response.content if hasattr(response, "content") else str(response)
|
content = response.content if hasattr(response, "content") else str(response)
|
||||||
return self._parse_reflection_response(
|
return self._parse_reflection_response(
|
||||||
content, failed_stage, reflection_number,
|
content,
|
||||||
|
failed_stage,
|
||||||
|
reflection_number,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _build_reflection_prompt(
|
def _build_reflection_prompt(
|
||||||
|
|
@ -114,15 +134,14 @@ class PipelineReflector:
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
failed_stage: str,
|
failed_stage: str,
|
||||||
error_message: str,
|
error_message: str,
|
||||||
completed_outputs: dict[str, Any],
|
completed_outputs: dict[str, object],
|
||||||
reflection_number: int,
|
reflection_number: int,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""构建反思提示词。"""
|
"""构建反思提示词。"""
|
||||||
stage_descriptions = []
|
stage_descriptions = []
|
||||||
for s in pipeline.stages:
|
for s in pipeline.stages:
|
||||||
stage_descriptions.append(
|
stage_descriptions.append(
|
||||||
f" - {s.name}: agent={s.agent}, action={s.action}, "
|
f" - {s.name}: agent={s.agent}, action={s.action}, depends_on={s.depends_on}"
|
||||||
f"depends_on={s.depends_on}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
completed_summary = json.dumps(
|
completed_summary = json.dumps(
|
||||||
|
|
@ -174,7 +193,9 @@ JSON response:"""
|
||||||
except (json.JSONDecodeError, KeyError) as e:
|
except (json.JSONDecodeError, KeyError) as e:
|
||||||
logger.warning(f"Failed to parse LLM reflection response: {e}")
|
logger.warning(f"Failed to parse LLM reflection response: {e}")
|
||||||
return self._rule_based_reflect(
|
return self._rule_based_reflect(
|
||||||
failed_stage, content, reflection_number,
|
failed_stage,
|
||||||
|
content,
|
||||||
|
reflection_number,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _rule_based_reflect(
|
def _rule_based_reflect(
|
||||||
|
|
@ -218,7 +239,7 @@ class PipelineReplanner:
|
||||||
保留已完成步骤的结果,仅重新规划失败及后续步骤。
|
保留已完成步骤的结果,仅重新规划失败及后续步骤。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, llm_gateway: Any = None):
|
def __init__(self, llm_gateway: _LLMGatewayLike | None = None):
|
||||||
self._llm_gateway = llm_gateway
|
self._llm_gateway = llm_gateway
|
||||||
|
|
||||||
async def replan(
|
async def replan(
|
||||||
|
|
@ -255,8 +276,7 @@ class PipelineReplanner:
|
||||||
) -> Pipeline:
|
) -> Pipeline:
|
||||||
"""使用 LLM 生成修正后的 Pipeline。"""
|
"""使用 LLM 生成修正后的 Pipeline。"""
|
||||||
completed_stages = [
|
completed_stages = [
|
||||||
name for name, sr in result.stage_results.items()
|
name for name, sr in result.stage_results.items() if sr.status == StageStatus.COMPLETED
|
||||||
if sr.status == StageStatus.COMPLETED
|
|
||||||
]
|
]
|
||||||
|
|
||||||
prompt = f"""Based on the reflection report, generate a corrected pipeline.
|
prompt = f"""Based on the reflection report, generate a corrected pipeline.
|
||||||
|
|
@ -284,7 +304,9 @@ JSON pipeline:"""
|
||||||
return self._parse_pipeline_response(content, pipeline)
|
return self._parse_pipeline_response(content, pipeline)
|
||||||
|
|
||||||
def _parse_pipeline_response(
|
def _parse_pipeline_response(
|
||||||
self, content: str, original: Pipeline,
|
self,
|
||||||
|
content: str,
|
||||||
|
original: Pipeline,
|
||||||
) -> Pipeline:
|
) -> Pipeline:
|
||||||
"""解析 LLM 返回的 Pipeline JSON。"""
|
"""解析 LLM 返回的 Pipeline JSON。"""
|
||||||
try:
|
try:
|
||||||
|
|
@ -294,9 +316,7 @@ JSON pipeline:"""
|
||||||
text = "\n".join(lines[1:-1])
|
text = "\n".join(lines[1:-1])
|
||||||
|
|
||||||
data = json.loads(text)
|
data = json.loads(text)
|
||||||
stages = [
|
stages = [PipelineStage(**s) for s in data.get("stages", [])]
|
||||||
PipelineStage(**s) for s in data.get("stages", [])
|
|
||||||
]
|
|
||||||
return Pipeline(
|
return Pipeline(
|
||||||
name=data.get("name", original.name),
|
name=data.get("name", original.name),
|
||||||
version=data.get("version", original.version),
|
version=data.get("version", original.version),
|
||||||
|
|
@ -316,8 +336,7 @@ JSON pipeline:"""
|
||||||
) -> Pipeline:
|
) -> Pipeline:
|
||||||
"""基于规则的兜底重规划。"""
|
"""基于规则的兜底重规划。"""
|
||||||
completed_stages = {
|
completed_stages = {
|
||||||
name for name, sr in result.stage_results.items()
|
name for name, sr in result.stage_results.items() if sr.status == StageStatus.COMPLETED
|
||||||
if sr.status == StageStatus.COMPLETED
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# 构建修正后的 stages 列表
|
# 构建修正后的 stages 列表
|
||||||
|
|
@ -345,17 +364,21 @@ JSON pipeline:"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def _adjust_failed_stage(
|
def _adjust_failed_stage(
|
||||||
self, stage: PipelineStage, report: ReflectionReport,
|
self,
|
||||||
|
stage: PipelineStage,
|
||||||
|
report: ReflectionReport,
|
||||||
) -> PipelineStage:
|
) -> PipelineStage:
|
||||||
"""根据反思报告调整失败的步骤。"""
|
"""根据反思报告调整失败的步骤。"""
|
||||||
adjustments: dict[str, Any] = {}
|
adjustments: dict[str, object] = {}
|
||||||
|
|
||||||
if report.failure_type == "timeout":
|
if report.failure_type == "timeout":
|
||||||
adjustments["timeout_seconds"] = min(
|
adjustments["timeout_seconds"] = min(
|
||||||
stage.timeout_seconds * 2, 3600,
|
stage.timeout_seconds * 2,
|
||||||
|
3600,
|
||||||
)
|
)
|
||||||
if stage.retry_policy is None:
|
if stage.retry_policy is None:
|
||||||
from agentkit.orchestrator.retry import StepRetryPolicy
|
from agentkit.orchestrator.retry import StepRetryPolicy
|
||||||
|
|
||||||
adjustments["retry_policy"] = StepRetryPolicy(max_attempts=2)
|
adjustments["retry_policy"] = StepRetryPolicy(max_attempts=2)
|
||||||
|
|
||||||
elif report.failure_type == "resource_error":
|
elif report.failure_type == "resource_error":
|
||||||
|
|
@ -365,6 +388,7 @@ JSON pipeline:"""
|
||||||
# 添加重试策略,可能输入在后续可用
|
# 添加重试策略,可能输入在后续可用
|
||||||
if stage.retry_policy is None:
|
if stage.retry_policy is None:
|
||||||
from agentkit.orchestrator.retry import StepRetryPolicy
|
from agentkit.orchestrator.retry import StepRetryPolicy
|
||||||
|
|
||||||
adjustments["retry_policy"] = StepRetryPolicy(max_attempts=2)
|
adjustments["retry_policy"] = StepRetryPolicy(max_attempts=2)
|
||||||
|
|
||||||
return stage.model_copy(update=adjustments)
|
return stage.model_copy(update=adjustments)
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Awaitable, Callable
|
from typing import Awaitable, Callable
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -27,7 +27,7 @@ class StepRetryPolicy:
|
||||||
def calculate_delay(self, attempt: int) -> float:
|
def calculate_delay(self, attempt: int) -> float:
|
||||||
"""Calculate delay for given attempt number (0-based)"""
|
"""Calculate delay for given attempt number (0-based)"""
|
||||||
delay = min(
|
delay = min(
|
||||||
self.base_delay * (self.exponential_base ** attempt),
|
self.base_delay * (self.exponential_base**attempt),
|
||||||
self.max_delay,
|
self.max_delay,
|
||||||
)
|
)
|
||||||
if self.jitter:
|
if self.jitter:
|
||||||
|
|
@ -36,10 +36,10 @@ class StepRetryPolicy:
|
||||||
|
|
||||||
|
|
||||||
async def execute_with_retry(
|
async def execute_with_retry(
|
||||||
func: Callable[..., Awaitable[Any]],
|
func: Callable[..., Awaitable[object]],
|
||||||
retry_policy: StepRetryPolicy | None = None,
|
retry_policy: StepRetryPolicy | None = None,
|
||||||
step_name: str = "",
|
step_name: str = "",
|
||||||
) -> Any:
|
) -> object:
|
||||||
"""Execute a function with retry policy"""
|
"""Execute a function with retry policy"""
|
||||||
if retry_policy is None:
|
if retry_policy is None:
|
||||||
return await func()
|
return await func()
|
||||||
|
|
|
||||||
|
|
@ -3,18 +3,17 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage
|
from agentkit.orchestrator.pipeline_schema import PipelineStage
|
||||||
|
|
||||||
|
|
||||||
class WorkflowStage(PipelineStage):
|
class WorkflowStage(PipelineStage):
|
||||||
"""A workflow stage extending PipelineStage with type and config."""
|
"""A workflow stage extending PipelineStage with type and config."""
|
||||||
|
|
||||||
type: str = "skill" # "skill" | "condition" | "approval" | "parallel"
|
type: str = "skill" # "skill" | "condition" | "approval" | "parallel"
|
||||||
config: dict[str, Any] = Field(default_factory=dict)
|
config: dict[str, object] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowDefinition(BaseModel):
|
class WorkflowDefinition(BaseModel):
|
||||||
|
|
@ -24,9 +23,9 @@ class WorkflowDefinition(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
version: int = 1
|
version: int = 1
|
||||||
stages: list[WorkflowStage] = Field(default_factory=list)
|
stages: list[WorkflowStage] = Field(default_factory=list)
|
||||||
triggers: list[dict[str, Any]] = Field(default_factory=list)
|
triggers: list[dict[str, object]] = Field(default_factory=list)
|
||||||
variables_schema: dict[str, Any] = Field(default_factory=dict)
|
variables_schema: dict[str, object] = Field(default_factory=dict)
|
||||||
output_schema: dict[str, Any] = Field(default_factory=dict)
|
output_schema: dict[str, object] = Field(default_factory=dict)
|
||||||
created_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
created_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||||
updated_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
updated_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||||
|
|
||||||
|
|
@ -38,11 +37,11 @@ class WorkflowExecution(BaseModel):
|
||||||
workflow_id: str = ""
|
workflow_id: str = ""
|
||||||
status: str = "pending" # pending|running|paused|completed|failed|cancelled
|
status: str = "pending" # pending|running|paused|completed|failed|cancelled
|
||||||
current_stage: str | None = None
|
current_stage: str | None = None
|
||||||
stage_results: dict[str, Any] = Field(default_factory=dict)
|
stage_results: dict[str, object] = Field(default_factory=dict)
|
||||||
started_at: str | None = None
|
started_at: str | None = None
|
||||||
completed_at: str | None = None
|
completed_at: str | None = None
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
variables: dict[str, Any] = Field(default_factory=dict)
|
variables: dict[str, object] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowSummary(BaseModel):
|
class WorkflowSummary(BaseModel):
|
||||||
|
|
@ -62,15 +61,15 @@ class CreateWorkflowRequest(BaseModel):
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
stages: list[WorkflowStage] = Field(default_factory=list)
|
stages: list[WorkflowStage] = Field(default_factory=list)
|
||||||
triggers: list[dict[str, Any]] = Field(default_factory=list)
|
triggers: list[dict[str, object]] = Field(default_factory=list)
|
||||||
variables_schema: dict[str, Any] = Field(default_factory=dict)
|
variables_schema: dict[str, object] = Field(default_factory=dict)
|
||||||
output_schema: dict[str, Any] = Field(default_factory=dict)
|
output_schema: dict[str, object] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class ExecuteWorkflowRequest(BaseModel):
|
class ExecuteWorkflowRequest(BaseModel):
|
||||||
"""Request body for executing a workflow."""
|
"""Request body for executing a workflow."""
|
||||||
|
|
||||||
variables: dict[str, Any] = Field(default_factory=dict)
|
variables: dict[str, object] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class ApproveRequest(BaseModel):
|
class ApproveRequest(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,7 @@ async def create_agent(request: CreateAgentRequest, req: Request):
|
||||||
config_dict = request.config
|
config_dict = request.config
|
||||||
try:
|
try:
|
||||||
config = SkillConfig.from_dict(config_dict)
|
config = SkillConfig.from_dict(config_dict)
|
||||||
except Exception:
|
except (ValueError, KeyError, TypeError):
|
||||||
config = AgentConfig.from_dict(config_dict)
|
config = AgentConfig.from_dict(config_dict)
|
||||||
agent = await pool.create_agent(config)
|
agent = await pool.create_agent(config)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -53,6 +53,8 @@ from agentkit.server.auth.session_service import (
|
||||||
REVOKE_REASON_PASSWORD_CHANGED,
|
REVOKE_REASON_PASSWORD_CHANGED,
|
||||||
REVOKE_REASON_USER_TERMINATED,
|
REVOKE_REASON_USER_TERMINATED,
|
||||||
SessionCreate,
|
SessionCreate,
|
||||||
|
SessionNotFound,
|
||||||
|
SessionReuseDetected,
|
||||||
SessionService,
|
SessionService,
|
||||||
get_session_service,
|
get_session_service,
|
||||||
)
|
)
|
||||||
|
|
@ -253,7 +255,7 @@ def _is_legacy_client(request: Request) -> bool:
|
||||||
if client_v is None or cutoff_v is None:
|
if client_v is None or cutoff_v is None:
|
||||||
return False
|
return False
|
||||||
return client_v < cutoff_v
|
return client_v < cutoff_v
|
||||||
except Exception: # noqa: BLE001
|
except (ValueError, TypeError, AttributeError): # noqa: BLE001
|
||||||
logger.debug("Failed to parse X-Client-Version %r", raw)
|
logger.debug("Failed to parse X-Client-Version %r", raw)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -492,7 +494,7 @@ async def refresh(payload: RefreshRequest, request: Request) -> TokenResponse:
|
||||||
# 1. Verify signature + type
|
# 1. Verify signature + type
|
||||||
try:
|
try:
|
||||||
refresh_payload = verify_token(payload.refresh_token, secret, expected_type="refresh")
|
refresh_payload = verify_token(payload.refresh_token, secret, expected_type="refresh")
|
||||||
except Exception as exc: # noqa: BLE001
|
except (jwt.PyJWTError, ValueError, KeyError) as exc: # noqa: BLE001
|
||||||
raise HTTPException(status_code=401, detail="Invalid refresh token") from exc
|
raise HTTPException(status_code=401, detail="Invalid refresh token") from exc
|
||||||
|
|
||||||
# 2-3. Validate the session (also handles reuse detection)
|
# 2-3. Validate the session (also handles reuse detection)
|
||||||
|
|
@ -510,7 +512,7 @@ async def refresh(payload: RefreshRequest, request: Request) -> TokenResponse:
|
||||||
new_refresh_token=new_pair.refresh_token,
|
new_refresh_token=new_pair.refresh_token,
|
||||||
new_ttl_seconds=int(REFRESH_TOKEN_TTL.total_seconds()),
|
new_ttl_seconds=int(REFRESH_TOKEN_TTL.total_seconds()),
|
||||||
)
|
)
|
||||||
except Exception as exc: # noqa: BLE001 — SessionReuseDetected / SessionNotFound
|
except (SessionReuseDetected, SessionNotFound, ValueError, KeyError, RuntimeError) as exc: # noqa: BLE001 — SessionReuseDetected / SessionNotFound
|
||||||
logger.info("Refresh rejected: %s", exc)
|
logger.info("Refresh rejected: %s", exc)
|
||||||
raise HTTPException(status_code=401, detail="Invalid refresh token") from exc
|
raise HTTPException(status_code=401, detail="Invalid refresh token") from exc
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -483,7 +483,7 @@ async def validate_formula(
|
||||||
parse_formula(body.formula)
|
parse_formula(body.formula)
|
||||||
except (FormulaParseError, FormulaSecurityError, UnknownFunctionError) as e:
|
except (FormulaParseError, FormulaSecurityError, UnknownFunctionError) as e:
|
||||||
return {"valid": False, "error": str(e)}
|
return {"valid": False, "error": str(e)}
|
||||||
except Exception as e: # pragma: no cover — defensive
|
except (ValueError, TypeError, KeyError, AttributeError) as e: # pragma: no cover — defensive
|
||||||
return {"valid": False, "error": f"Unexpected error: {e}"}
|
return {"valid": False, "error": f"Unexpected error: {e}"}
|
||||||
return {"valid": True}
|
return {"valid": True}
|
||||||
|
|
||||||
|
|
@ -750,7 +750,7 @@ async def upload_file(
|
||||||
f.write(chunk)
|
f.write(chunk)
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as exc:
|
except (OSError, RuntimeError) as exc:
|
||||||
file_path.unlink(missing_ok=True)
|
file_path.unlink(missing_ok=True)
|
||||||
logger.error(f"Failed to save uploaded bitable file: {exc}")
|
logger.error(f"Failed to save uploaded bitable file: {exc}")
|
||||||
raise HTTPException(status_code=500, detail="Failed to save file") from exc
|
raise HTTPException(status_code=500, detail="Failed to save file") from exc
|
||||||
|
|
|
||||||
|
|
@ -553,7 +553,7 @@ async def _invalidate_adapter_cache(channel_id: str) -> None:
|
||||||
if old is not None:
|
if old is not None:
|
||||||
try:
|
try:
|
||||||
await old.close()
|
await old.close()
|
||||||
except Exception: # noqa: BLE001 — 关闭异常不应阻塞配置变更
|
except (ConnectionError, RuntimeError, OSError, asyncio.TimeoutError): # noqa: BLE001 — 关闭异常不应阻塞配置变更
|
||||||
logger.debug("关闭旧适配器异常已忽略: channel_id=%s", channel_id)
|
logger.debug("关闭旧适配器异常已忽略: channel_id=%s", channel_id)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -562,7 +562,7 @@ async def close_all_adapters() -> None:
|
||||||
for channel_id, adapter in list(_adapter_cache.items()):
|
for channel_id, adapter in list(_adapter_cache.items()):
|
||||||
try:
|
try:
|
||||||
await adapter.close()
|
await adapter.close()
|
||||||
except Exception: # noqa: BLE001
|
except (ConnectionError, RuntimeError, OSError, asyncio.TimeoutError): # noqa: BLE001
|
||||||
logger.debug("关闭适配器异常已忽略: channel_id=%s", channel_id)
|
logger.debug("关闭适配器异常已忽略: channel_id=%s", channel_id)
|
||||||
_adapter_cache.clear()
|
_adapter_cache.clear()
|
||||||
|
|
||||||
|
|
@ -614,6 +614,8 @@ async def _process_inbound_message(app_state: Any, adapter: MessageAdapter, mess
|
||||||
model=routing.model or "default",
|
model=routing.model or "default",
|
||||||
)
|
)
|
||||||
final_content = getattr(result, "content", "") or ""
|
final_content = getattr(result, "content", "") or ""
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as exc: # noqa: BLE001 — 回退路径需捕获全部异常
|
except Exception as exc: # noqa: BLE001 — 回退路径需捕获全部异常
|
||||||
logger.warning("ReActEngine 执行失败,回退到 DIRECT_CHAT: %s", exc)
|
logger.warning("ReActEngine 执行失败,回退到 DIRECT_CHAT: %s", exc)
|
||||||
final_content = await _direct_chat(llm_gateway, routing)
|
final_content = await _direct_chat(llm_gateway, routing)
|
||||||
|
|
@ -628,6 +630,8 @@ async def _process_inbound_message(app_state: Any, adapter: MessageAdapter, mess
|
||||||
content=final_content,
|
content=final_content,
|
||||||
)
|
)
|
||||||
await adapter.send_message(outgoing)
|
await adapter.send_message(outgoing)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as exc: # noqa: BLE001 — webhook 必须保持响应能力
|
except Exception as exc: # noqa: BLE001 — webhook 必须保持响应能力
|
||||||
logger.exception("处理入站消息失败: %s", exc)
|
logger.exception("处理入站消息失败: %s", exc)
|
||||||
|
|
||||||
|
|
@ -703,7 +707,7 @@ async def channel_webhook(channel_id: str, request: Request) -> Any:
|
||||||
except WeComURLVerification as e:
|
except WeComURLVerification as e:
|
||||||
# 企微 URL 验证 — 返回 XML 响应
|
# 企微 URL 验证 — 返回 XML 响应
|
||||||
return Response(content=e.response_xml, media_type="application/xml")
|
return Response(content=e.response_xml, media_type="application/xml")
|
||||||
except Exception as exc: # noqa: BLE001 — 防止 receive_message 异常导致 500 触发平台重试风暴
|
except (ValueError, KeyError, RuntimeError, AttributeError, OSError) as exc: # noqa: BLE001 — 防止 receive_message 异常导致 500 触发平台重试风暴
|
||||||
logger.warning("receive_message 解析失败 channel=%s: %s", channel_id, exc)
|
logger.warning("receive_message 解析失败 channel=%s: %s", channel_id, exc)
|
||||||
return {"code": 0, "msg": "invalid_payload"}
|
return {"code": 0, "msg": "invalid_payload"}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -108,7 +108,7 @@ class ChatConnectionManager:
|
||||||
for ws, _ in conns:
|
for ws, _ in conns:
|
||||||
try:
|
try:
|
||||||
await ws.send_json(message)
|
await ws.send_json(message)
|
||||||
except Exception:
|
except (ConnectionError, RuntimeError, asyncio.TimeoutError):
|
||||||
stale.append(ws)
|
stale.append(ws)
|
||||||
for ws in stale:
|
for ws in stale:
|
||||||
self.remove(session_id, ws)
|
self.remove(session_id, ws)
|
||||||
|
|
@ -295,6 +295,8 @@ async def _execute_board_meeting(
|
||||||
await team.create_board(topic=routing_result.topic, expert_configs=expert_configs)
|
await team.create_board(topic=routing_result.topic, expert_configs=expert_configs)
|
||||||
orchestrator = BoardOrchestrator(team=team)
|
orchestrator = BoardOrchestrator(team=team)
|
||||||
result = await orchestrator.execute(routing_result.topic)
|
result = await orchestrator.execute(routing_result.topic)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Board meeting failed for session {session_id}: {e}", exc_info=True)
|
logger.error(f"Board meeting failed for session {session_id}: {e}", exc_info=True)
|
||||||
await websocket.send_json(
|
await websocket.send_json(
|
||||||
|
|
@ -302,7 +304,7 @@ async def _execute_board_meeting(
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
await team.dissolve()
|
await team.dissolve()
|
||||||
except Exception:
|
except (RuntimeError, asyncio.TimeoutError, ConnectionError):
|
||||||
pass
|
pass
|
||||||
return True
|
return True
|
||||||
finally:
|
finally:
|
||||||
|
|
@ -348,7 +350,7 @@ async def _execute_board_meeting(
|
||||||
# Dissolve the team to release expert agents
|
# Dissolve the team to release expert agents
|
||||||
try:
|
try:
|
||||||
await team.dissolve()
|
await team.dissolve()
|
||||||
except Exception as e:
|
except (RuntimeError, asyncio.TimeoutError, ConnectionError) as e:
|
||||||
logger.warning(f"Board team dissolve failed: {e}")
|
logger.warning(f"Board team dissolve failed: {e}")
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
@ -467,7 +469,7 @@ async def _execute_team_collab(
|
||||||
# Always dissolve the team and remove handler to avoid leaks
|
# Always dissolve the team and remove handler to avoid leaks
|
||||||
try:
|
try:
|
||||||
await team.dissolve()
|
await team.dissolve()
|
||||||
except Exception as e:
|
except (RuntimeError, asyncio.TimeoutError, ConnectionError) as e:
|
||||||
logger.warning(f"Team dissolve failed: {e}")
|
logger.warning(f"Team dissolve failed: {e}")
|
||||||
# dissolve() already clears handlers via handoff_transport.close()
|
# dissolve() already clears handlers via handoff_transport.close()
|
||||||
|
|
||||||
|
|
@ -585,7 +587,7 @@ def _build_phase_engine(
|
||||||
if phase_policy is None:
|
if phase_policy is None:
|
||||||
# Empty config (no `plan_exec:` section) → use KTD5 defaults.
|
# Empty config (no `plan_exec:` section) → use KTD5 defaults.
|
||||||
phase_policy = default_policy()
|
phase_policy = default_policy()
|
||||||
except Exception as e:
|
except (ValueError, TypeError, KeyError) as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
"PLAN_EXEC phase policy construction failed for session %s: %s",
|
"PLAN_EXEC phase policy construction failed for session %s: %s",
|
||||||
session_id,
|
session_id,
|
||||||
|
|
@ -695,6 +697,8 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
|
||||||
agent_name=agent.name,
|
agent_name=agent.name,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
)
|
)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"PLAN_EXEC execution error for session {session_id}: {e}")
|
logger.error(f"PLAN_EXEC execution error for session {session_id}: {e}")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
@ -773,6 +777,8 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
|
||||||
return response_dict
|
return response_dict
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Chat execution error for session {session_id}: {e}")
|
logger.error(f"Chat execution error for session {session_id}: {e}")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
@ -829,7 +835,7 @@ async def _resolve_ws_dept_context(
|
||||||
db_path_resolved = Path(db_path)
|
db_path_resolved = Path(db_path)
|
||||||
try:
|
try:
|
||||||
department_ids = await _fetch_user_department_ids(db_path_resolved, user_id)
|
department_ids = await _fetch_user_department_ids(db_path_resolved, user_id)
|
||||||
except Exception:
|
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Failed to fetch department ids for WebSocket user %s — fail-closed",
|
"Failed to fetch department ids for WebSocket user %s — fail-closed",
|
||||||
user_id,
|
user_id,
|
||||||
|
|
@ -946,7 +952,7 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
|
||||||
"data": {"content": content},
|
"data": {"content": content},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except (asyncio.QueueFull, RuntimeError, ConnectionError) as e:
|
||||||
logger.warning(f"Failed to enqueue intervention: {e}")
|
logger.warning(f"Failed to enqueue intervention: {e}")
|
||||||
await websocket.send_json(
|
await websocket.send_json(
|
||||||
{
|
{
|
||||||
|
|
@ -1022,11 +1028,13 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
logger.debug(f"Chat WebSocket disconnected for session {session_id}")
|
logger.debug(f"Chat WebSocket disconnected for session {session_id}")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Chat WebSocket error for session {session_id}: {e}")
|
logger.error(f"Chat WebSocket error for session {session_id}: {e}")
|
||||||
try:
|
try:
|
||||||
await websocket.send_json({"type": "error", "data": {"message": str(e)}})
|
await websocket.send_json({"type": "error", "data": {"message": str(e)}})
|
||||||
except Exception:
|
except (ConnectionError, RuntimeError, asyncio.TimeoutError):
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
# Clean up pending futures
|
# Clean up pending futures
|
||||||
|
|
@ -1174,6 +1182,8 @@ async def _handle_chat_message(
|
||||||
content=final_content,
|
content=final_content,
|
||||||
agent_name=agent.name,
|
agent_name=agent.name,
|
||||||
)
|
)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Check if this is a QuotaExceededError (U4: WebSocket quota).
|
# Check if this is a QuotaExceededError (U4: WebSocket quota).
|
||||||
from agentkit.llm.gateway import QuotaExceededError
|
from agentkit.llm.gateway import QuotaExceededError
|
||||||
|
|
@ -1422,6 +1432,8 @@ async def _handle_chat_message(
|
||||||
agent_name=agent.name,
|
agent_name=agent.name,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Chat execution error for session {session_id}: {e}")
|
logger.error(f"Chat execution error for session {session_id}: {e}")
|
||||||
# Show meaningful error to user, but avoid leaking full stack traces
|
# Show meaningful error to user, but avoid leaking full stack traces
|
||||||
|
|
@ -1473,6 +1485,8 @@ async def upload_chat_file(file: UploadFile = File(...)) -> dict[str, Any]:
|
||||||
file_path.write_bytes(contents)
|
file_path.write_bytes(contents)
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error(f"Failed to save uploaded file: {exc}")
|
logger.error(f"Failed to save uploaded file: {exc}")
|
||||||
raise HTTPException(status_code=500, detail="Failed to save file") from exc
|
raise HTTPException(status_code=500, detail="Failed to save file") from exc
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,7 @@ def _collect_skill_configs(request: Request) -> list[dict[str, Any]]:
|
||||||
"version": getattr(skill.config, "version", "1.0.0"),
|
"version": getattr(skill.config, "version", "1.0.0"),
|
||||||
"config": config_dict,
|
"config": config_dict,
|
||||||
})
|
})
|
||||||
except Exception as e:
|
except (AttributeError, ValueError, KeyError, RuntimeError) as e:
|
||||||
logger.warning(f"Failed to collect skill configs: {e}")
|
logger.warning(f"Failed to collect skill configs: {e}")
|
||||||
return configs
|
return configs
|
||||||
|
|
||||||
|
|
@ -81,7 +81,7 @@ def _collect_workflow_configs(request: Request) -> list[dict[str, Any]]:
|
||||||
from agentkit.server.routes.workflows import _workflow_store
|
from agentkit.server.routes.workflows import _workflow_store
|
||||||
|
|
||||||
workflow_store = _workflow_store
|
workflow_store = _workflow_store
|
||||||
except Exception:
|
except (ImportError, AttributeError):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
configs: list[dict[str, Any]] = []
|
configs: list[dict[str, Any]] = []
|
||||||
|
|
@ -94,7 +94,7 @@ def _collect_workflow_configs(request: Request) -> list[dict[str, Any]]:
|
||||||
configs.append(wf.to_dict())
|
configs.append(wf.to_dict())
|
||||||
else:
|
else:
|
||||||
configs.append(dict(wf))
|
configs.append(dict(wf))
|
||||||
except Exception as e:
|
except (RuntimeError, AttributeError, ValueError) as e:
|
||||||
logger.warning(f"Failed to collect workflow configs: {e}")
|
logger.warning(f"Failed to collect workflow configs: {e}")
|
||||||
return configs
|
return configs
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ Endpoints:
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import hmac
|
import hmac
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
|
@ -155,6 +156,8 @@ async def create_document(
|
||||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
raise HTTPException(status_code=404, detail=str(e)) from e
|
raise HTTPException(status_code=404, detail=str(e)) from e
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Document creation failed: {e}")
|
logger.error(f"Document creation failed: {e}")
|
||||||
raise HTTPException(status_code=500, detail="Document creation failed") from e
|
raise HTTPException(status_code=500, detail="Document creation failed") from e
|
||||||
|
|
@ -187,7 +190,7 @@ async def upload_template(
|
||||||
file_path.write_bytes(contents)
|
file_path.write_bytes(contents)
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as exc:
|
except (OSError, RuntimeError) as exc:
|
||||||
logger.error(f"Failed to save template: {exc}")
|
logger.error(f"Failed to save template: {exc}")
|
||||||
raise HTTPException(status_code=500, detail="Failed to save template") from exc
|
raise HTTPException(status_code=500, detail="Failed to save template") from exc
|
||||||
finally:
|
finally:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
"""Evolution API routes"""
|
"""Evolution API routes"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
|
@ -42,6 +43,8 @@ async def list_evolution_events(
|
||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
change_type=event_type,
|
change_type=event_type,
|
||||||
)
|
)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to list evolution events: {e}")
|
logger.error(f"Failed to list evolution events: {e}")
|
||||||
raise HTTPException(status_code=500, detail="Failed to list evolution events")
|
raise HTTPException(status_code=500, detail="Failed to list evolution events")
|
||||||
|
|
@ -63,6 +66,8 @@ async def get_skill_versions(skill_name: str, req: Request = None):
|
||||||
store = _get_evolution_store(req)
|
store = _get_evolution_store(req)
|
||||||
try:
|
try:
|
||||||
versions = await store.list_skill_versions(skill_name)
|
versions = await store.list_skill_versions(skill_name)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get skill versions for '{skill_name}': {e}")
|
logger.error(f"Failed to get skill versions for '{skill_name}': {e}")
|
||||||
raise HTTPException(status_code=500, detail="Failed to get skill versions")
|
raise HTTPException(status_code=500, detail="Failed to get skill versions")
|
||||||
|
|
@ -103,6 +108,8 @@ async def trigger_evolution(request: TriggerEvolutionRequest, req: Request = Non
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
event_id = await store.record(event)
|
event_id = await store.record(event)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to record trigger event: {e}")
|
logger.error(f"Failed to record trigger event: {e}")
|
||||||
raise HTTPException(status_code=500, detail="Failed to trigger evolution")
|
raise HTTPException(status_code=500, detail="Failed to trigger evolution")
|
||||||
|
|
@ -153,6 +160,8 @@ async def list_ab_tests(
|
||||||
for e in entries
|
for e in entries
|
||||||
]
|
]
|
||||||
return {"items": results[:limit], "total": len(results)}
|
return {"items": results[:limit], "total": len(results)}
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to list A/B tests from persistent store: {e}")
|
logger.error(f"Failed to list A/B tests from persistent store: {e}")
|
||||||
raise HTTPException(status_code=500, detail="Failed to list A/B tests")
|
raise HTTPException(status_code=500, detail="Failed to list A/B tests")
|
||||||
|
|
|
||||||
|
|
@ -194,7 +194,7 @@ async def list_experiences(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return {"experiences": experiences, "total": len(experiences)}
|
return {"experiences": experiences, "total": len(experiences)}
|
||||||
except Exception as e:
|
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError) as e:
|
||||||
logger.error(f"Failed to list experiences from store: {e}")
|
logger.error(f"Failed to list experiences from store: {e}")
|
||||||
|
|
||||||
# Fallback to in-memory store
|
# Fallback to in-memory store
|
||||||
|
|
@ -324,7 +324,7 @@ async def get_metrics(
|
||||||
|
|
||||||
# Generate daily trends from the metrics
|
# Generate daily trends from the metrics
|
||||||
trends = _generate_trends(metrics_list, period)
|
trends = _generate_trends(metrics_list, period)
|
||||||
except Exception as e:
|
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError) as e:
|
||||||
logger.error(f"Failed to get metrics from store: {e}")
|
logger.error(f"Failed to get metrics from store: {e}")
|
||||||
else:
|
else:
|
||||||
# Generate from in-memory experiences
|
# Generate from in-memory experiences
|
||||||
|
|
@ -501,7 +501,7 @@ async def get_usage(
|
||||||
"errors": 0,
|
"errors": 0,
|
||||||
"avg_latency_ms": round(d["latency"] / max(d["requests"], 1), 1),
|
"avg_latency_ms": round(d["latency"] / max(d["requests"], 1), 1),
|
||||||
})
|
})
|
||||||
except Exception as e:
|
except (ConnectionError, OSError, ValueError, KeyError, RuntimeError, AttributeError) as e:
|
||||||
logger.error(f"Failed to get usage from LLMGateway: {e}")
|
logger.error(f"Failed to get usage from LLMGateway: {e}")
|
||||||
|
|
||||||
# Fill in missing dates with zero
|
# Fill in missing dates with zero
|
||||||
|
|
@ -587,7 +587,7 @@ async def check_pitfalls(
|
||||||
})
|
})
|
||||||
|
|
||||||
return {"warnings": warnings_data}
|
return {"warnings": warnings_data}
|
||||||
except Exception as e:
|
except (RuntimeError, ValueError, KeyError, AttributeError, asyncio.TimeoutError, ConnectionError) as e:
|
||||||
logger.error(f"Failed to check pitfalls: {e}")
|
logger.error(f"Failed to check pitfalls: {e}")
|
||||||
return {"warnings": []}
|
return {"warnings": []}
|
||||||
|
|
||||||
|
|
@ -642,7 +642,7 @@ async def list_path_optimizations(
|
||||||
else None,
|
else None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except (ValueError, KeyError, RuntimeError, AttributeError) as e:
|
||||||
logger.error(f"Failed to get path optimizations: {e}")
|
logger.error(f"Failed to get path optimizations: {e}")
|
||||||
|
|
||||||
# Also include in-memory optimizations
|
# Also include in-memory optimizations
|
||||||
|
|
@ -767,7 +767,7 @@ async def evolution_dashboard_ws(websocket: WebSocket):
|
||||||
)
|
)
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
logger.debug("Evolution dashboard WebSocket disconnected")
|
logger.debug("Evolution dashboard WebSocket disconnected")
|
||||||
except Exception as e:
|
except (RuntimeError, asyncio.TimeoutError, ConnectionError) as e:
|
||||||
logger.error(f"Evolution dashboard WebSocket error: {e}")
|
logger.error(f"Evolution dashboard WebSocket error: {e}")
|
||||||
finally:
|
finally:
|
||||||
if websocket in _ws_connections:
|
if websocket in _ws_connections:
|
||||||
|
|
@ -781,7 +781,7 @@ async def _broadcast_event(event_type: str, data: dict):
|
||||||
for ws in _ws_connections:
|
for ws in _ws_connections:
|
||||||
try:
|
try:
|
||||||
await ws.send_json(message)
|
await ws.send_json(message)
|
||||||
except Exception:
|
except (ConnectionError, RuntimeError, asyncio.TimeoutError):
|
||||||
disconnected.append(ws)
|
disconnected.append(ws)
|
||||||
for ws in disconnected:
|
for ws in disconnected:
|
||||||
if ws in _ws_connections:
|
if ws in _ws_connections:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
"""Health check route"""
|
"""Health check route"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from fastapi import APIRouter, Request
|
from fastapi import APIRouter, Request
|
||||||
|
|
||||||
router = APIRouter(tags=["health"])
|
router = APIRouter(tags=["health"])
|
||||||
|
|
@ -23,14 +25,14 @@ async def health_check(request: Request):
|
||||||
redis_client = await task_store._get_redis()
|
redis_client = await task_store._get_redis()
|
||||||
await redis_client.ping()
|
await redis_client.ping()
|
||||||
redis_status = "available"
|
redis_status = "available"
|
||||||
except Exception as ping_exc:
|
except (ConnectionError, OSError, asyncio.TimeoutError, RuntimeError) as ping_exc:
|
||||||
redis_status = f"error: {str(ping_exc)[:100]}"
|
redis_status = f"error: {str(ping_exc)[:100]}"
|
||||||
overall_status = "degraded"
|
overall_status = "degraded"
|
||||||
else:
|
else:
|
||||||
redis_status = "not_configured"
|
redis_status = "not_configured"
|
||||||
else:
|
else:
|
||||||
redis_status = "not_configured"
|
redis_status = "not_configured"
|
||||||
except Exception as exc:
|
except (ConnectionError, OSError, asyncio.TimeoutError, RuntimeError, AttributeError) as exc:
|
||||||
redis_status = f"error: {str(exc)[:100]}"
|
redis_status = f"error: {str(exc)[:100]}"
|
||||||
overall_status = "degraded"
|
overall_status = "degraded"
|
||||||
checks["redis"] = redis_status
|
checks["redis"] = redis_status
|
||||||
|
|
@ -42,7 +44,7 @@ async def health_check(request: Request):
|
||||||
try:
|
try:
|
||||||
agents = agent_pool.list_agents()
|
agents = agent_pool.list_agents()
|
||||||
pool_size = len(agents)
|
pool_size = len(agents)
|
||||||
except Exception:
|
except (RuntimeError, AttributeError):
|
||||||
pass
|
pass
|
||||||
checks["agent_pool"] = {"status": "available", "size": pool_size}
|
checks["agent_pool"] = {"status": "available", "size": pool_size}
|
||||||
|
|
||||||
|
|
@ -57,7 +59,7 @@ async def health_check(request: Request):
|
||||||
else:
|
else:
|
||||||
llm_status = "no_providers"
|
llm_status = "no_providers"
|
||||||
overall_status = "degraded"
|
overall_status = "degraded"
|
||||||
except Exception:
|
except (RuntimeError, AttributeError, ValueError):
|
||||||
llm_status = "error"
|
llm_status = "error"
|
||||||
overall_status = "degraded"
|
overall_status = "degraded"
|
||||||
checks["llm_gateway"] = llm_status
|
checks["llm_gateway"] = llm_status
|
||||||
|
|
@ -68,7 +70,7 @@ async def health_check(request: Request):
|
||||||
if skill_registry:
|
if skill_registry:
|
||||||
try:
|
try:
|
||||||
skill_count = len(skill_registry.list_skills())
|
skill_count = len(skill_registry.list_skills())
|
||||||
except Exception:
|
except (RuntimeError, AttributeError):
|
||||||
pass
|
pass
|
||||||
checks["skill_registry"] = {
|
checks["skill_registry"] = {
|
||||||
"status": "available" if skill_registry else "not_configured",
|
"status": "available" if skill_registry else "not_configured",
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import hmac
|
import hmac
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
@ -252,7 +253,7 @@ async def list_sources(
|
||||||
visible_ids = await filter_kb_sources_by_department(
|
visible_ids = await filter_kb_sources_by_department(
|
||||||
db_path, dept_ctx.department_ids, all_ids
|
db_path, dept_ctx.department_ids, all_ids
|
||||||
)
|
)
|
||||||
except Exception: # noqa: BLE001 — never block listing on DB errors
|
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError): # noqa: BLE001 — never block listing on DB errors
|
||||||
logger.exception("Department KB filtering failed — returning empty list")
|
logger.exception("Department KB filtering failed — returning empty list")
|
||||||
return {"sources": []}
|
return {"sources": []}
|
||||||
visible_set = set(visible_ids)
|
visible_set = set(visible_ids)
|
||||||
|
|
@ -398,7 +399,7 @@ async def list_documents(
|
||||||
visible_ids = await filter_kb_sources_by_department(
|
visible_ids = await filter_kb_sources_by_department(
|
||||||
db_path, dept_ctx.department_ids, all_source_ids
|
db_path, dept_ctx.department_ids, all_source_ids
|
||||||
)
|
)
|
||||||
except Exception: # noqa: BLE001 — never block listing on DB errors
|
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError): # noqa: BLE001 — never block listing on DB errors
|
||||||
logger.exception("Department KB filtering failed — returning empty list")
|
logger.exception("Department KB filtering failed — returning empty list")
|
||||||
return {"documents": []}
|
return {"documents": []}
|
||||||
visible_set = set(visible_ids)
|
visible_set = set(visible_ids)
|
||||||
|
|
@ -484,7 +485,7 @@ async def upload_document(
|
||||||
try:
|
try:
|
||||||
text = processor.parse(tmp_path, file_type)
|
text = processor.parse(tmp_path, file_type)
|
||||||
chunks = processor.segment(text)
|
chunks = processor.segment(text)
|
||||||
except Exception as e:
|
except (ValueError, OSError, RuntimeError, UnicodeDecodeError) as e:
|
||||||
logger.warning("Document parsing failed: %s", e)
|
logger.warning("Document parsing failed: %s", e)
|
||||||
raise HTTPException(status_code=422, detail=f"Document parsing failed: {e}") from e
|
raise HTTPException(status_code=422, detail=f"Document parsing failed: {e}") from e
|
||||||
|
|
||||||
|
|
@ -567,7 +568,7 @@ async def preview_document(
|
||||||
chunk_size=chunk_size,
|
chunk_size=chunk_size,
|
||||||
chunk_overlap=chunk_overlap,
|
chunk_overlap=chunk_overlap,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except (ValueError, OSError, RuntimeError, UnicodeDecodeError) as e:
|
||||||
logger.warning("Document preview failed: %s", e)
|
logger.warning("Document preview failed: %s", e)
|
||||||
raise HTTPException(status_code=422, detail=f"Document preview failed: {e}") from e
|
raise HTTPException(status_code=422, detail=f"Document preview failed: {e}") from e
|
||||||
|
|
||||||
|
|
@ -628,7 +629,7 @@ async def search_knowledge(
|
||||||
for r in results
|
for r in results
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError) as e:
|
||||||
logger.warning(f"Semantic search failed: {e}")
|
logger.warning(f"Semantic search failed: {e}")
|
||||||
|
|
||||||
# Fallback: return empty results with a hint
|
# Fallback: return empty results with a hint
|
||||||
|
|
|
||||||
|
|
@ -115,7 +115,7 @@ async def publish_skill(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
skill = skill_registry.get(skill_name)
|
skill = skill_registry.get(skill_name)
|
||||||
except Exception:
|
except (KeyError, ValueError, AttributeError):
|
||||||
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
"""Memory API routes"""
|
"""Memory API routes"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
|
@ -40,6 +41,8 @@ async def search_episodic_memory(
|
||||||
if agent_name:
|
if agent_name:
|
||||||
filters["agent_name"] = agent_name
|
filters["agent_name"] = agent_name
|
||||||
items = await retriever._episodic.search(query, top_k=top_k, filters=filters or None)
|
items = await retriever._episodic.search(query, top_k=top_k, filters=filters or None)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to search episodic memory: {e}")
|
logger.error(f"Failed to search episodic memory: {e}")
|
||||||
raise HTTPException(status_code=500, detail="Failed to search episodic memory")
|
raise HTTPException(status_code=500, detail="Failed to search episodic memory")
|
||||||
|
|
@ -76,6 +79,8 @@ async def search_semantic_memory(
|
||||||
if knowledge_base_ids:
|
if knowledge_base_ids:
|
||||||
filters["knowledge_base_ids"] = [kid.strip() for kid in knowledge_base_ids.split(",")]
|
filters["knowledge_base_ids"] = [kid.strip() for kid in knowledge_base_ids.split(",")]
|
||||||
items = await retriever._semantic.search(query, top_k=top_k, filters=filters or None)
|
items = await retriever._semantic.search(query, top_k=top_k, filters=filters or None)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to search semantic memory: {e}")
|
logger.error(f"Failed to search semantic memory: {e}")
|
||||||
raise HTTPException(status_code=500, detail="Failed to search semantic memory")
|
raise HTTPException(status_code=500, detail="Failed to search semantic memory")
|
||||||
|
|
@ -104,6 +109,8 @@ async def delete_episodic_memory(key: str, req: Request = None):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
deleted = await retriever._episodic.delete(key)
|
deleted = await retriever._episodic.delete(key)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to delete episodic memory '{key}': {e}")
|
logger.error(f"Failed to delete episodic memory '{key}': {e}")
|
||||||
raise HTTPException(status_code=500, detail="Failed to delete episodic memory")
|
raise HTTPException(status_code=500, detail="Failed to delete episodic memory")
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
"""Metrics route — /api/v1/metrics"""
|
"""Metrics route — /api/v1/metrics"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from fastapi import APIRouter, Request
|
from fastapi import APIRouter, Request
|
||||||
|
|
@ -29,7 +30,7 @@ async def get_metrics(request: Request):
|
||||||
task_metrics["completed_tasks"] = counts.get("completed", 0)
|
task_metrics["completed_tasks"] = counts.get("completed", 0)
|
||||||
task_metrics["failed_tasks"] = counts.get("failed", 0)
|
task_metrics["failed_tasks"] = counts.get("failed", 0)
|
||||||
task_metrics["pending_tasks"] = counts.get("pending", 0)
|
task_metrics["pending_tasks"] = counts.get("pending", 0)
|
||||||
except Exception as e:
|
except (RuntimeError, AttributeError, ConnectionError, asyncio.TimeoutError) as e:
|
||||||
logger.warning(f"Failed to collect task metrics: {e}")
|
logger.warning(f"Failed to collect task metrics: {e}")
|
||||||
|
|
||||||
# Agent pool metrics
|
# Agent pool metrics
|
||||||
|
|
@ -41,7 +42,7 @@ async def get_metrics(request: Request):
|
||||||
try:
|
try:
|
||||||
agents = agent_pool.list_agents()
|
agents = agent_pool.list_agents()
|
||||||
agent_metrics["total_agents"] = len(agents)
|
agent_metrics["total_agents"] = len(agents)
|
||||||
except Exception as e:
|
except (RuntimeError, AttributeError) as e:
|
||||||
logger.warning(f"Failed to collect agent metrics: {e}")
|
logger.warning(f"Failed to collect agent metrics: {e}")
|
||||||
|
|
||||||
# Skill registry metrics
|
# Skill registry metrics
|
||||||
|
|
@ -53,7 +54,7 @@ async def get_metrics(request: Request):
|
||||||
try:
|
try:
|
||||||
skills = skill_registry.list_skills()
|
skills = skill_registry.list_skills()
|
||||||
skill_metrics["total_skills"] = len(skills)
|
skill_metrics["total_skills"] = len(skills)
|
||||||
except Exception as e:
|
except (RuntimeError, AttributeError) as e:
|
||||||
logger.warning(f"Failed to collect skill metrics: {e}")
|
logger.warning(f"Failed to collect skill metrics: {e}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -94,7 +94,7 @@ async def _emit_event_safe(
|
||||||
data=data or {},
|
data=data or {},
|
||||||
)
|
)
|
||||||
await event_queue.emit(event)
|
await event_queue.emit(event)
|
||||||
except Exception as e:
|
except (asyncio.QueueFull, RuntimeError, ConnectionError) as e:
|
||||||
logger.warning(f"EventQueue emit failed (type={event_type}): {e}", exc_info=True)
|
logger.warning(f"EventQueue emit failed (type={event_type}): {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -211,7 +211,7 @@ class PortalConnectionManager:
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
asyncio.create_task(oldest.close(code=1008, reason="Connection limit exceeded"))
|
asyncio.create_task(oldest.close(code=1008, reason="Connection limit exceeded"))
|
||||||
except Exception:
|
except (ConnectionError, RuntimeError):
|
||||||
pass
|
pass
|
||||||
conns.append(ws)
|
conns.append(ws)
|
||||||
|
|
||||||
|
|
@ -235,7 +235,7 @@ class PortalConnectionManager:
|
||||||
for ws in conns:
|
for ws in conns:
|
||||||
try:
|
try:
|
||||||
await ws.send_json(message)
|
await ws.send_json(message)
|
||||||
except Exception as e:
|
except (ConnectionError, RuntimeError, asyncio.TimeoutError) as e:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Portal WS send failed for user %s (marking stale): %s", user_id, e
|
"Portal WS send failed for user %s (marking stale): %s", user_id, e
|
||||||
)
|
)
|
||||||
|
|
@ -285,7 +285,7 @@ async def _build_history_messages(
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
history = await _conversation_store.get_history(conv_id, limit=limit)
|
history = await _conversation_store.get_history(conv_id, limit=limit)
|
||||||
except Exception:
|
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# The last message in history is the current user message (just added),
|
# The last message in history is the current user message (just added),
|
||||||
|
|
@ -553,6 +553,8 @@ async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify
|
||||||
):
|
):
|
||||||
if event.event_type == "final_answer":
|
if event.event_type == "final_answer":
|
||||||
collected_output.append(event.data.get("output", ""))
|
collected_output.append(event.data.get("output", ""))
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
response_text = f"执行出错: {e}"
|
response_text = f"执行出错: {e}"
|
||||||
else:
|
else:
|
||||||
|
|
@ -682,6 +684,8 @@ async def chat_stream(request: ChatRequest, req: Request, _auth: None = Depends(
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield {
|
yield {
|
||||||
"event": "error",
|
"event": "error",
|
||||||
|
|
@ -862,7 +866,7 @@ async def _execute_react_background(
|
||||||
await _task_store_update_status(
|
await _task_store_update_status(
|
||||||
task_store, task_id, TaskStatus.RUNNING, started_at=datetime.now(timezone.utc)
|
task_store, task_id, TaskStatus.RUNNING, started_at=datetime.now(timezone.utc)
|
||||||
)
|
)
|
||||||
except Exception:
|
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
|
||||||
logger.warning("Failed to update TaskStore RUNNING", exc_info=True)
|
logger.warning("Failed to update TaskStore RUNNING", exc_info=True)
|
||||||
|
|
||||||
async for event in react_engine.execute_stream(
|
async for event in react_engine.execute_stream(
|
||||||
|
|
@ -909,7 +913,7 @@ async def _execute_react_background(
|
||||||
progress=1.0,
|
progress=1.0,
|
||||||
progress_message="Completed",
|
progress_message="Completed",
|
||||||
)
|
)
|
||||||
except Exception:
|
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
|
||||||
logger.warning("Failed to update TaskStore COMPLETED", exc_info=True)
|
logger.warning("Failed to update TaskStore COMPLETED", exc_info=True)
|
||||||
|
|
||||||
# Emit task.completed so subscribers know the task is done
|
# Emit task.completed so subscribers know the task is done
|
||||||
|
|
@ -932,7 +936,15 @@ async def _execute_react_background(
|
||||||
partial = _ensure_non_empty("".join(collected_output))
|
partial = _ensure_non_empty("".join(collected_output))
|
||||||
try:
|
try:
|
||||||
await asyncio.shield(conversation_store.add_message(conv_id, "assistant", partial))
|
await asyncio.shield(conversation_store.add_message(conv_id, "assistant", partial))
|
||||||
except (Exception, asyncio.CancelledError):
|
except (
|
||||||
|
asyncio.CancelledError,
|
||||||
|
ConnectionError,
|
||||||
|
OSError,
|
||||||
|
asyncio.TimeoutError,
|
||||||
|
ValueError,
|
||||||
|
KeyError,
|
||||||
|
RuntimeError,
|
||||||
|
):
|
||||||
logger.warning("Failed to persist partial output on cancel")
|
logger.warning("Failed to persist partial output on cancel")
|
||||||
if task_store is not None:
|
if task_store is not None:
|
||||||
try:
|
try:
|
||||||
|
|
@ -945,7 +957,15 @@ async def _execute_react_background(
|
||||||
completed_at=datetime.now(timezone.utc),
|
completed_at=datetime.now(timezone.utc),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except (Exception, asyncio.CancelledError):
|
except (
|
||||||
|
asyncio.CancelledError,
|
||||||
|
ConnectionError,
|
||||||
|
OSError,
|
||||||
|
asyncio.TimeoutError,
|
||||||
|
ValueError,
|
||||||
|
KeyError,
|
||||||
|
RuntimeError,
|
||||||
|
):
|
||||||
logger.warning("Failed to update TaskStore on cancel", exc_info=True)
|
logger.warning("Failed to update TaskStore on cancel", exc_info=True)
|
||||||
# P0 #2 fix: _emit_event_safe is async (it awaits event_queue.emit).
|
# P0 #2 fix: _emit_event_safe is async (it awaits event_queue.emit).
|
||||||
# Shield it so a re-entrant CancelledError doesn't kill the emit
|
# Shield it so a re-entrant CancelledError doesn't kill the emit
|
||||||
|
|
@ -963,7 +983,7 @@ async def _execute_react_background(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except (Exception, asyncio.CancelledError):
|
except (asyncio.CancelledError, asyncio.QueueFull, RuntimeError, ConnectionError):
|
||||||
logger.warning("Failed to emit TASK_FAILED on cancel")
|
logger.warning("Failed to emit TASK_FAILED on cancel")
|
||||||
raise # Propagate cancellation
|
raise # Propagate cancellation
|
||||||
|
|
||||||
|
|
@ -973,7 +993,7 @@ async def _execute_react_background(
|
||||||
partial = _ensure_non_empty("".join(collected_output))
|
partial = _ensure_non_empty("".join(collected_output))
|
||||||
try:
|
try:
|
||||||
await conversation_store.add_message(conv_id, "assistant", partial)
|
await conversation_store.add_message(conv_id, "assistant", partial)
|
||||||
except Exception:
|
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
|
||||||
logger.warning("Failed to persist partial output in background task")
|
logger.warning("Failed to persist partial output in background task")
|
||||||
|
|
||||||
if task_store is not None:
|
if task_store is not None:
|
||||||
|
|
@ -985,7 +1005,7 @@ async def _execute_react_background(
|
||||||
error_message=str(e),
|
error_message=str(e),
|
||||||
completed_at=datetime.now(timezone.utc),
|
completed_at=datetime.now(timezone.utc),
|
||||||
)
|
)
|
||||||
except Exception:
|
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
|
||||||
logger.warning("Failed to update TaskStore FAILED", exc_info=True)
|
logger.warning("Failed to update TaskStore FAILED", exc_info=True)
|
||||||
|
|
||||||
# Emit task.failed so subscribers know the task failed
|
# Emit task.failed so subscribers know the task failed
|
||||||
|
|
@ -1120,7 +1140,7 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
record = await _task_store_get(resume_task_store, resume_task_id)
|
record = await _task_store_get(resume_task_store, resume_task_id)
|
||||||
except Exception:
|
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
|
||||||
logger.warning("TaskStore.get failed during resume", exc_info=True)
|
logger.warning("TaskStore.get failed during resume", exc_info=True)
|
||||||
record = None
|
record = None
|
||||||
if record is not None:
|
if record is not None:
|
||||||
|
|
@ -1333,7 +1353,7 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
await _broadcast_dashboard_event("metrics_updated", {"period": "7d"})
|
await _broadcast_dashboard_event("metrics_updated", {"period": "7d"})
|
||||||
except Exception as e:
|
except (asyncio.QueueFull, RuntimeError, ConnectionError, ValueError, KeyError) as e:
|
||||||
logger.warning(f"Failed to record experience: {e}")
|
logger.warning(f"Failed to record experience: {e}")
|
||||||
|
|
||||||
# Unified preprocessing via RequestPreprocessor (minimal: @skill prefix + greeting regex + REACT)
|
# Unified preprocessing via RequestPreprocessor (minimal: @skill prefix + greeting regex + REACT)
|
||||||
|
|
@ -1414,7 +1434,7 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
TaskStatus.PENDING,
|
TaskStatus.PENDING,
|
||||||
metadata={"conversation_id": conv.id},
|
metadata={"conversation_id": conv.id},
|
||||||
)
|
)
|
||||||
except Exception:
|
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
|
||||||
logger.warning("Failed to register task in TaskStore", exc_info=True)
|
logger.warning("Failed to register task in TaskStore", exc_info=True)
|
||||||
|
|
||||||
# Execute based on routing result's execution_mode
|
# Execute based on routing result's execution_mode
|
||||||
|
|
@ -1455,7 +1475,7 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
progress=1.0,
|
progress=1.0,
|
||||||
progress_message="Completed",
|
progress_message="Completed",
|
||||||
)
|
)
|
||||||
except Exception:
|
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
|
||||||
logger.warning("Failed to update TaskStore for DIRECT_CHAT", exc_info=True)
|
logger.warning("Failed to update TaskStore for DIRECT_CHAT", exc_info=True)
|
||||||
|
|
||||||
# Emit turn.final_answer and task.completed to EQ
|
# Emit turn.final_answer and task.completed to EQ
|
||||||
|
|
@ -1526,7 +1546,7 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
chat_messages.insert(
|
chat_messages.insert(
|
||||||
-1, {"role": hist_msg.role, "content": hist_msg.content}
|
-1, {"role": hist_msg.role, "content": hist_msg.content}
|
||||||
)
|
)
|
||||||
except Exception:
|
except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError):
|
||||||
pass
|
pass
|
||||||
response = await llm_gateway.chat(
|
response = await llm_gateway.chat(
|
||||||
messages=chat_messages,
|
messages=chat_messages,
|
||||||
|
|
@ -1627,7 +1647,7 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
logger.warning("EventQueue not configured; awaiting background task directly")
|
logger.warning("EventQueue not configured; awaiting background task directly")
|
||||||
try:
|
try:
|
||||||
await bg_task
|
await bg_task
|
||||||
except Exception:
|
except (RuntimeError, ConnectionError, asyncio.TimeoutError):
|
||||||
pass # errors handled inside _execute_react_background
|
pass # errors handled inside _execute_react_background
|
||||||
active_bg_task = None
|
active_bg_task = None
|
||||||
continue
|
continue
|
||||||
|
|
@ -1734,6 +1754,8 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
# kill the task, lose the full output, and mark it FAILED —
|
# kill the task, lose the full output, and mark it FAILED —
|
||||||
# defeating layers 2 and 3. The task is only cancelled on explicit
|
# defeating layers 2 and 3. The task is only cancelled on explicit
|
||||||
# user cancel (msg_type == 'cancel') or application shutdown.
|
# user cancel (msg_type == 'cancel') or application shutdown.
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Portal WebSocket error: {e}")
|
logger.error(f"Portal WebSocket error: {e}")
|
||||||
# P1 #6 fix: Do NOT cancel the background task on connection-level
|
# P1 #6 fix: Do NOT cancel the background task on connection-level
|
||||||
|
|
@ -1758,7 +1780,7 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
await websocket.send_json({"type": "error", "data": {"message": str(e)}})
|
await websocket.send_json({"type": "error", "data": {"message": str(e)}})
|
||||||
except Exception:
|
except (ConnectionError, RuntimeError, asyncio.TimeoutError):
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
# Remove from user-scoped push tracking on any disconnect/error/return.
|
# Remove from user-scoped push tracking on any disconnect/error/return.
|
||||||
|
|
|
||||||
|
|
@ -119,7 +119,7 @@ def _skill_to_detail(skill: Any) -> dict[str, Any]:
|
||||||
if hasattr(skill, "config"):
|
if hasattr(skill, "config"):
|
||||||
try:
|
try:
|
||||||
config = skill.config.to_dict() if hasattr(skill.config, "to_dict") else {}
|
config = skill.config.to_dict() if hasattr(skill.config, "to_dict") else {}
|
||||||
except Exception:
|
except (AttributeError, ValueError, TypeError):
|
||||||
config = {}
|
config = {}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -174,7 +174,7 @@ async def get_skill_detail(skill_name: str, req: Request):
|
||||||
skill_registry = req.app.state.skill_registry
|
skill_registry = req.app.state.skill_registry
|
||||||
try:
|
try:
|
||||||
skill = skill_registry.get(skill_name)
|
skill = skill_registry.get(skill_name)
|
||||||
except Exception:
|
except (KeyError, ValueError, AttributeError):
|
||||||
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
||||||
|
|
||||||
return _skill_to_detail(skill)
|
return _skill_to_detail(skill)
|
||||||
|
|
@ -186,7 +186,7 @@ async def check_skill_health(skill_name: str, req: Request):
|
||||||
skill_registry = req.app.state.skill_registry
|
skill_registry = req.app.state.skill_registry
|
||||||
try:
|
try:
|
||||||
skill_registry.get(skill_name)
|
skill_registry.get(skill_name)
|
||||||
except Exception:
|
except (KeyError, ValueError, AttributeError):
|
||||||
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
||||||
|
|
||||||
# Basic health check - skill exists and is registered
|
# Basic health check - skill exists and is registered
|
||||||
|
|
@ -243,7 +243,7 @@ async def reload_skill(skill_name: str, req: Request):
|
||||||
# Verify the skill is currently registered (404 if not).
|
# Verify the skill is currently registered (404 if not).
|
||||||
try:
|
try:
|
||||||
skill_registry.get(skill_name)
|
skill_registry.get(skill_name)
|
||||||
except Exception:
|
except (KeyError, ValueError, AttributeError):
|
||||||
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
||||||
|
|
||||||
# Resolve the skills directory (mirrors routes.skills._get_skills_dir).
|
# Resolve the skills directory (mirrors routes.skills._get_skills_dir).
|
||||||
|
|
|
||||||
|
|
@ -136,7 +136,7 @@ async def register_skill(request: RegisterSkillRequest, req: Request):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
config = SkillConfig.from_dict(request.config)
|
config = SkillConfig.from_dict(request.config)
|
||||||
except Exception as e:
|
except (ValueError, TypeError, KeyError) as e:
|
||||||
raise HTTPException(status_code=422, detail=f"Invalid skill config: {e}")
|
raise HTTPException(status_code=422, detail=f"Invalid skill config: {e}")
|
||||||
|
|
||||||
skill = Skill(config=config)
|
skill = Skill(config=config)
|
||||||
|
|
@ -279,7 +279,7 @@ async def install_skill(request: InstallSkillRequest, req: Request):
|
||||||
resp = await client.get(source)
|
resp = await client.get(source)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
yaml_content = resp.text
|
yaml_content = resp.text
|
||||||
except Exception as e:
|
except (httpx.HTTPError, OSError) as e:
|
||||||
raise HTTPException(status_code=400, detail=f"Failed to download from source: {e}")
|
raise HTTPException(status_code=400, detail=f"Failed to download from source: {e}")
|
||||||
elif source and source.startswith("file://"):
|
elif source and source.startswith("file://"):
|
||||||
# Read from local file path
|
# Read from local file path
|
||||||
|
|
@ -295,7 +295,7 @@ async def install_skill(request: InstallSkillRequest, req: Request):
|
||||||
try:
|
try:
|
||||||
with open(local_path, encoding="utf-8") as f:
|
with open(local_path, encoding="utf-8") as f:
|
||||||
yaml_content = f.read()
|
yaml_content = f.read()
|
||||||
except Exception as e:
|
except OSError as e:
|
||||||
raise HTTPException(status_code=400, detail=f"Failed to read local file: {e}")
|
raise HTTPException(status_code=400, detail=f"Failed to read local file: {e}")
|
||||||
else:
|
else:
|
||||||
# Search GitHub for skills (YAML config files)
|
# Search GitHub for skills (YAML config files)
|
||||||
|
|
@ -313,7 +313,7 @@ async def install_skill(request: InstallSkillRequest, req: Request):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
gh_data = gh_resp.json()
|
gh_data = gh_resp.json()
|
||||||
except Exception as e:
|
except (httpx.HTTPError, OSError, ValueError) as e:
|
||||||
raise HTTPException(status_code=502, detail=f"GitHub search failed: {e}")
|
raise HTTPException(status_code=502, detail=f"GitHub search failed: {e}")
|
||||||
|
|
||||||
items = gh_data.get("items", [])
|
items = gh_data.get("items", [])
|
||||||
|
|
@ -334,7 +334,7 @@ async def install_skill(request: InstallSkillRequest, req: Request):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
items = gh_resp2.json().get("items", [])
|
items = gh_resp2.json().get("items", [])
|
||||||
except Exception:
|
except (httpx.HTTPError, OSError, ValueError, KeyError):
|
||||||
items = []
|
items = []
|
||||||
|
|
||||||
if not items:
|
if not items:
|
||||||
|
|
@ -362,7 +362,7 @@ async def install_skill(request: InstallSkillRequest, req: Request):
|
||||||
resp = await client.get(raw_url)
|
resp = await client.get(raw_url)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
yaml_content = resp.text
|
yaml_content = resp.text
|
||||||
except Exception as e:
|
except (httpx.HTTPError, OSError) as e:
|
||||||
raise HTTPException(status_code=400, detail=f"Failed to download skill: {e}")
|
raise HTTPException(status_code=400, detail=f"Failed to download skill: {e}")
|
||||||
|
|
||||||
# Validate YAML content before writing to disk
|
# Validate YAML content before writing to disk
|
||||||
|
|
@ -391,14 +391,14 @@ async def install_skill(request: InstallSkillRequest, req: Request):
|
||||||
)
|
)
|
||||||
loader.load_from_file(file_path)
|
loader.load_from_file(file_path)
|
||||||
registration_ok = True
|
registration_ok = True
|
||||||
except Exception as e:
|
except (ValueError, TypeError, KeyError, OSError, RuntimeError) as e:
|
||||||
logger.warning(f"Failed to register installed skill: {e}")
|
logger.warning(f"Failed to register installed skill: {e}")
|
||||||
|
|
||||||
if not registration_ok:
|
if not registration_ok:
|
||||||
# Remove the invalid YAML file and report error
|
# Remove the invalid YAML file and report error
|
||||||
try:
|
try:
|
||||||
os.remove(file_path)
|
os.remove(file_path)
|
||||||
except Exception:
|
except OSError:
|
||||||
pass
|
pass
|
||||||
raise HTTPException(status_code=500, detail="Skill downloaded but registration failed")
|
raise HTTPException(status_code=500, detail="Skill downloaded but registration failed")
|
||||||
|
|
||||||
|
|
@ -419,7 +419,7 @@ async def uninstall_skill(name: str, req: Request):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
skill_registry.get(validated_name)
|
skill_registry.get(validated_name)
|
||||||
except Exception:
|
except (KeyError, ValueError, RuntimeError):
|
||||||
raise HTTPException(status_code=404, detail=f"Skill '{name}' not found")
|
raise HTTPException(status_code=404, detail=f"Skill '{name}' not found")
|
||||||
|
|
||||||
# Remove from registry
|
# Remove from registry
|
||||||
|
|
@ -487,7 +487,7 @@ async def execute_pipeline(name: str, request: ExecutePipelineRequest, req: Requ
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await pipeline.execute(input_data=request.input_data)
|
result = await pipeline.execute(input_data=request.input_data)
|
||||||
except Exception as e:
|
except (ValueError, TypeError, KeyError, RuntimeError, OSError, ConnectionError) as e:
|
||||||
logger.error(f"Pipeline execution failed for '{name}': {e}", exc_info=True)
|
logger.error(f"Pipeline execution failed for '{name}': {e}", exc_info=True)
|
||||||
raise HTTPException(status_code=500, detail="Pipeline execution failed")
|
raise HTTPException(status_code=500, detail="Pipeline execution failed")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ def _read_meminfo() -> dict[str, int]:
|
||||||
key = parts[0].strip()
|
key = parts[0].strip()
|
||||||
value = parts[1].strip().split()[0]
|
value = parts[1].strip().split()[0]
|
||||||
values[key] = int(value) * 1024 # kB -> bytes
|
values[key] = int(value) * 1024 # kB -> bytes
|
||||||
except Exception as exc:
|
except (OSError, ValueError, FileNotFoundError) as exc:
|
||||||
logger.debug(f"Failed to read /proc/meminfo: {exc}")
|
logger.debug(f"Failed to read /proc/meminfo: {exc}")
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
@ -52,7 +52,7 @@ async def get_system_resources() -> dict[str, Any]:
|
||||||
if hasattr(os, "getloadavg"):
|
if hasattr(os, "getloadavg"):
|
||||||
try:
|
try:
|
||||||
loadavg = list(os.getloadavg())
|
loadavg = list(os.getloadavg())
|
||||||
except Exception as exc:
|
except (OSError, AttributeError) as exc:
|
||||||
logger.debug(f"Failed to get loadavg: {exc}")
|
logger.debug(f"Failed to get loadavg: {exc}")
|
||||||
|
|
||||||
meminfo = _read_meminfo()
|
meminfo = _read_meminfo()
|
||||||
|
|
@ -68,7 +68,7 @@ async def get_system_resources() -> dict[str, Any]:
|
||||||
disk_total = du.total
|
disk_total = du.total
|
||||||
disk_used = du.used
|
disk_used = du.used
|
||||||
disk_free = du.free
|
disk_free = du.free
|
||||||
except Exception as exc:
|
except (OSError, FileNotFoundError) as exc:
|
||||||
logger.debug(f"Failed to get disk usage: {exc}")
|
logger.debug(f"Failed to get disk usage: {exc}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -83,7 +83,7 @@ async def submit_task(request: SubmitTaskRequest, req: Request):
|
||||||
elif request.skill_name:
|
elif request.skill_name:
|
||||||
try:
|
try:
|
||||||
skill = skill_registry.get(request.skill_name)
|
skill = skill_registry.get(request.skill_name)
|
||||||
except Exception:
|
except (KeyError, ValueError, AttributeError):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=404,
|
||||||
detail=f"Skill '{request.skill_name}' not found",
|
detail=f"Skill '{request.skill_name}' not found",
|
||||||
|
|
@ -145,7 +145,7 @@ async def submit_task(request: SubmitTaskRequest, req: Request):
|
||||||
quality_result = await quality_gate.validate(
|
quality_result = await quality_gate.validate(
|
||||||
task_result.output_data or {}, skill, skill_context=skill_context
|
task_result.output_data or {}, skill, skill_context=skill_context
|
||||||
)
|
)
|
||||||
except Exception:
|
except (RuntimeError, ValueError, KeyError, AttributeError, asyncio.TimeoutError):
|
||||||
pass # Quality gate failure shouldn't block the response
|
pass # Quality gate failure shouldn't block the response
|
||||||
|
|
||||||
# 7. Standardize output if skill available
|
# 7. Standardize output if skill available
|
||||||
|
|
@ -167,7 +167,7 @@ async def submit_task(request: SubmitTaskRequest, req: Request):
|
||||||
"task_id": task.task_id,
|
"task_id": task.task_id,
|
||||||
"status": task_result.status,
|
"status": task_result.status,
|
||||||
}
|
}
|
||||||
except Exception:
|
except (ValueError, KeyError, AttributeError, RuntimeError):
|
||||||
pass # Fall through to raw output
|
pass # Fall through to raw output
|
||||||
|
|
||||||
# 8. Return raw result if no skill or standardization failed
|
# 8. Return raw result if no skill or standardization failed
|
||||||
|
|
@ -307,7 +307,7 @@ async def resume_task(task_id: str, req: Request, plan_id: str | None = None):
|
||||||
finally:
|
finally:
|
||||||
try:
|
try:
|
||||||
await team.dissolve()
|
await team.dissolve()
|
||||||
except Exception:
|
except (RuntimeError, asyncio.TimeoutError, ConnectionError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -343,7 +343,7 @@ async def stream_task(request: SubmitTaskRequest, req: Request):
|
||||||
elif request.skill_name:
|
elif request.skill_name:
|
||||||
try:
|
try:
|
||||||
skill_registry.get(request.skill_name)
|
skill_registry.get(request.skill_name)
|
||||||
except Exception:
|
except (KeyError, ValueError, AttributeError):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=404,
|
||||||
detail=f"Skill '{request.skill_name}' not found",
|
detail=f"Skill '{request.skill_name}' not found",
|
||||||
|
|
|
||||||
|
|
@ -359,14 +359,14 @@ async def execute_command(
|
||||||
await check_pty.close()
|
await check_pty.close()
|
||||||
if cwd_result.exit_code == 0:
|
if cwd_result.exit_code == 0:
|
||||||
state.cwd = cwd_result.output.strip()
|
state.cwd = cwd_result.output.strip()
|
||||||
except Exception:
|
except (OSError, RuntimeError, asyncio.TimeoutError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
output = f"命令执行超时({request.timeout}s)"
|
output = f"命令执行超时({request.timeout}s)"
|
||||||
exit_code = -1
|
exit_code = -1
|
||||||
duration_ms = int((time.monotonic() - start_time) * 1000)
|
duration_ms = int((time.monotonic() - start_time) * 1000)
|
||||||
except Exception as e:
|
except (OSError, RuntimeError, ValueError) as e:
|
||||||
output = str(e)
|
output = str(e)
|
||||||
exit_code = -1
|
exit_code = -1
|
||||||
duration_ms = int((time.monotonic() - start_time) * 1000)
|
duration_ms = int((time.monotonic() - start_time) * 1000)
|
||||||
|
|
@ -515,7 +515,7 @@ async def terminal_websocket(websocket: WebSocket) -> None:
|
||||||
})
|
})
|
||||||
await websocket.close(code=4003, reason="Permission denied")
|
await websocket.close(code=4003, reason="Permission denied")
|
||||||
return
|
return
|
||||||
except Exception:
|
except (ValueError, KeyError, RuntimeError, OSError):
|
||||||
pass # Fall through to API key / dev mode
|
pass # Fall through to API key / dev mode
|
||||||
|
|
||||||
# 2. API key via ?api_key=
|
# 2. API key via ?api_key=
|
||||||
|
|
@ -721,7 +721,7 @@ async def terminal_websocket(websocket: WebSocket) -> None:
|
||||||
})
|
})
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
continue
|
continue
|
||||||
except Exception:
|
except (OSError, RuntimeError):
|
||||||
break
|
break
|
||||||
|
|
||||||
# Get final result
|
# Get final result
|
||||||
|
|
@ -741,7 +741,7 @@ async def terminal_websocket(websocket: WebSocket) -> None:
|
||||||
await check_pty.close()
|
await check_pty.close()
|
||||||
if cwd_result.exit_code == 0:
|
if cwd_result.exit_code == 0:
|
||||||
state.cwd = cwd_result.output.strip()
|
state.cwd = cwd_result.output.strip()
|
||||||
except Exception:
|
except (OSError, RuntimeError, asyncio.TimeoutError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Record history
|
# Record history
|
||||||
|
|
@ -790,6 +790,8 @@ async def terminal_websocket(websocket: WebSocket) -> None:
|
||||||
"cwd": state.cwd,
|
"cwd": state.cwd,
|
||||||
"duration_ms": duration_ms,
|
"duration_ms": duration_ms,
|
||||||
})
|
})
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
duration_ms = int((time.monotonic() - start_time) * 1000)
|
duration_ms = int((time.monotonic() - start_time) * 1000)
|
||||||
await websocket.send_json({
|
await websocket.send_json({
|
||||||
|
|
@ -800,7 +802,7 @@ async def terminal_websocket(websocket: WebSocket) -> None:
|
||||||
if active_pty and active_pty.is_running:
|
if active_pty and active_pty.is_running:
|
||||||
try:
|
try:
|
||||||
await active_pty.close()
|
await active_pty.close()
|
||||||
except Exception:
|
except (OSError, RuntimeError):
|
||||||
pass
|
pass
|
||||||
active_pty = None
|
active_pty = None
|
||||||
|
|
||||||
|
|
@ -836,11 +838,13 @@ async def terminal_websocket(websocket: WebSocket) -> None:
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
logger.debug(f"Terminal WebSocket disconnected for session {state.session_id}")
|
logger.debug(f"Terminal WebSocket disconnected for session {state.session_id}")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Terminal WebSocket error for session {state.session_id}: {e}")
|
logger.error(f"Terminal WebSocket error for session {state.session_id}: {e}")
|
||||||
try:
|
try:
|
||||||
await websocket.send_json({"type": "error", "message": str(e)[:200]})
|
await websocket.send_json({"type": "error", "message": str(e)[:200]})
|
||||||
except Exception:
|
except (ConnectionError, RuntimeError, asyncio.TimeoutError):
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
# Clean up pending confirmations
|
# Clean up pending confirmations
|
||||||
|
|
@ -851,5 +855,5 @@ async def terminal_websocket(websocket: WebSocket) -> None:
|
||||||
if active_pty and active_pty.is_running:
|
if active_pty and active_pty.is_running:
|
||||||
try:
|
try:
|
||||||
await active_pty.close()
|
await active_pty.close()
|
||||||
except Exception:
|
except (OSError, RuntimeError):
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -549,7 +549,7 @@ async def server_terminal_websocket(websocket: WebSocket) -> None:
|
||||||
})
|
})
|
||||||
await websocket.close(code=4003, reason="Permission denied")
|
await websocket.close(code=4003, reason="Permission denied")
|
||||||
return
|
return
|
||||||
except Exception:
|
except (ValueError, KeyError, RuntimeError, OSError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if not auth_ok:
|
if not auth_ok:
|
||||||
|
|
@ -585,7 +585,7 @@ async def server_terminal_websocket(websocket: WebSocket) -> None:
|
||||||
})
|
})
|
||||||
await websocket.close(code=4003, reason="Not authorized")
|
await websocket.close(code=4003, reason="Not authorized")
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except (aiosqlite.Error, OSError, ValueError, KeyError, RuntimeError) as e:
|
||||||
logger.warning(f"Failed to check server terminal authorization: {e}")
|
logger.warning(f"Failed to check server terminal authorization: {e}")
|
||||||
# If DB check fails, deny access
|
# If DB check fails, deny access
|
||||||
await websocket.send_json({
|
await websocket.send_json({
|
||||||
|
|
@ -822,7 +822,7 @@ async def server_terminal_websocket(websocket: WebSocket) -> None:
|
||||||
})
|
})
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
continue
|
continue
|
||||||
except Exception:
|
except (OSError, RuntimeError):
|
||||||
break
|
break
|
||||||
|
|
||||||
result = await run_task
|
result = await run_task
|
||||||
|
|
@ -846,7 +846,7 @@ async def server_terminal_websocket(websocket: WebSocket) -> None:
|
||||||
else:
|
else:
|
||||||
base = state.cwd or os.path.expanduser("~")
|
base = state.cwd or os.path.expanduser("~")
|
||||||
state.cwd = os.path.normpath(os.path.join(base, cd_arg))
|
state.cwd = os.path.normpath(os.path.join(base, cd_arg))
|
||||||
except Exception:
|
except (ValueError, OSError, RuntimeError):
|
||||||
pass # cd parsing failed — keep old cwd
|
pass # cd parsing failed — keep old cwd
|
||||||
|
|
||||||
# Record history
|
# Record history
|
||||||
|
|
@ -893,6 +893,8 @@ async def server_terminal_websocket(websocket: WebSocket) -> None:
|
||||||
"cwd": state.cwd,
|
"cwd": state.cwd,
|
||||||
"duration_ms": duration_ms,
|
"duration_ms": duration_ms,
|
||||||
})
|
})
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
duration_ms = int((time.monotonic() - start_time) * 1000)
|
duration_ms = int((time.monotonic() - start_time) * 1000)
|
||||||
await websocket.send_json({
|
await websocket.send_json({
|
||||||
|
|
@ -903,7 +905,7 @@ async def server_terminal_websocket(websocket: WebSocket) -> None:
|
||||||
if active_pty and active_pty.is_running:
|
if active_pty and active_pty.is_running:
|
||||||
try:
|
try:
|
||||||
await active_pty.close()
|
await active_pty.close()
|
||||||
except Exception:
|
except (OSError, RuntimeError):
|
||||||
pass
|
pass
|
||||||
active_pty = None
|
active_pty = None
|
||||||
|
|
||||||
|
|
@ -942,17 +944,19 @@ async def server_terminal_websocket(websocket: WebSocket) -> None:
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
logger.debug(f"Server terminal WebSocket disconnected for session {state.session_id}")
|
logger.debug(f"Server terminal WebSocket disconnected for session {state.session_id}")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Server terminal WebSocket error for session {state.session_id}: {e}")
|
logger.error(f"Server terminal WebSocket error for session {state.session_id}: {e}")
|
||||||
try:
|
try:
|
||||||
await websocket.send_json({"type": "error", "message": str(e)[:200]})
|
await websocket.send_json({"type": "error", "message": str(e)[:200]})
|
||||||
except Exception:
|
except (ConnectionError, RuntimeError, asyncio.TimeoutError):
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
if active_pty and active_pty.is_running:
|
if active_pty and active_pty.is_running:
|
||||||
try:
|
try:
|
||||||
await active_pty.close()
|
await active_pty.close()
|
||||||
except Exception:
|
except (OSError, RuntimeError):
|
||||||
pass
|
pass
|
||||||
# Clean up session from in-memory stores to prevent leaks
|
# Clean up session from in-memory stores to prevent leaks
|
||||||
_server_sessions.pop(state.session_id, None)
|
_server_sessions.pop(state.session_id, None)
|
||||||
|
|
|
||||||
|
|
@ -433,6 +433,8 @@ async def _execute_workflow(
|
||||||
)
|
)
|
||||||
result = await agent.handle_task(task)
|
result = await agent.handle_task(task)
|
||||||
stage_result = {"output": result, "skill": stage.action}
|
stage_result = {"output": result, "skill": stage.action}
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
stage_result = {"error": str(e), "skill": stage.action}
|
stage_result = {"error": str(e), "skill": stage.action}
|
||||||
else:
|
else:
|
||||||
|
|
@ -474,6 +476,8 @@ async def _execute_workflow(
|
||||||
)
|
)
|
||||||
result = await agent.handle_task(task)
|
result = await agent.handle_task(task)
|
||||||
return {"output": result, "skill": action}
|
return {"output": result, "skill": action}
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"error": str(e), "skill": action}
|
return {"error": str(e), "skill": action}
|
||||||
return {"dry_run": True, "action": action}
|
return {"dry_run": True, "action": action}
|
||||||
|
|
@ -515,6 +519,8 @@ async def _execute_workflow(
|
||||||
execution_id=execution.execution_id,
|
execution_id=execution.execution_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
execution.stage_results[stage_name] = {
|
execution.stage_results[stage_name] = {
|
||||||
"status": "failed",
|
"status": "failed",
|
||||||
|
|
@ -633,7 +639,7 @@ async def _broadcast_ws(message: dict[str, Any], execution_id: str | None = None
|
||||||
for ws in targets:
|
for ws in targets:
|
||||||
try:
|
try:
|
||||||
await ws.send_json(message)
|
await ws.send_json(message)
|
||||||
except Exception:
|
except (ConnectionError, RuntimeError, asyncio.TimeoutError):
|
||||||
disconnected.append(ws)
|
disconnected.append(ws)
|
||||||
if disconnected:
|
if disconnected:
|
||||||
async with _ws_subscribers_lock:
|
async with _ws_subscribers_lock:
|
||||||
|
|
@ -952,7 +958,7 @@ async def workflow_websocket(websocket: WebSocket):
|
||||||
# Keep connection alive - messages are primarily server-push
|
# Keep connection alive - messages are primarily server-push
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
logger.debug("Workflow WebSocket disconnected")
|
logger.debug("Workflow WebSocket disconnected")
|
||||||
except Exception as e:
|
except (RuntimeError, asyncio.TimeoutError, ConnectionError) as e:
|
||||||
logger.error(f"Workflow WebSocket error: {e}")
|
logger.error(f"Workflow WebSocket error: {e}")
|
||||||
finally:
|
finally:
|
||||||
if subscribed_execution_id:
|
if subscribed_execution_id:
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,7 @@ class ConnectionManager:
|
||||||
for ws, _ in conns:
|
for ws, _ in conns:
|
||||||
try:
|
try:
|
||||||
await ws.send_json(message)
|
await ws.send_json(message)
|
||||||
except Exception:
|
except (ConnectionError, RuntimeError, asyncio.TimeoutError):
|
||||||
stale.append(ws)
|
stale.append(ws)
|
||||||
for ws in stale:
|
for ws in stale:
|
||||||
self.remove(task_id, ws)
|
self.remove(task_id, ws)
|
||||||
|
|
@ -153,6 +153,8 @@ async def task_websocket(websocket: WebSocket, task_id: str) -> None:
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
logger.debug(f"WebSocket disconnected for task {task_id}")
|
logger.debug(f"WebSocket disconnected for task {task_id}")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"WebSocket error for task {task_id}: {e}")
|
logger.error(f"WebSocket error for task {task_id}: {e}")
|
||||||
try:
|
try:
|
||||||
|
|
@ -162,7 +164,7 @@ async def task_websocket(websocket: WebSocket, task_id: str) -> None:
|
||||||
"data": {"message": str(e)},
|
"data": {"message": str(e)},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
except Exception:
|
except (ConnectionError, RuntimeError, asyncio.TimeoutError):
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
manager.remove(task_id, websocket)
|
manager.remove(task_id, websocket)
|
||||||
|
|
@ -243,6 +245,8 @@ async def _run_react_and_stream(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await websocket.send_json(
|
await websocket.send_json(
|
||||||
{
|
{
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue