refactor: remove Any from Wave 4 modules (250 sites)
This commit is contained in:
parent
57f4ee9ac0
commit
1f4f54b073
|
|
@ -8,7 +8,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Awaitable
|
from typing import Callable, Awaitable
|
||||||
|
|
||||||
from agentkit.bus.message import AgentMessage
|
from agentkit.bus.message import AgentMessage
|
||||||
|
|
||||||
|
|
@ -20,8 +20,8 @@ class InMemoryMessageBus:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
cascade_detector: Any = None,
|
cascade_detector: object = None,
|
||||||
alignment_guard: Any = None,
|
alignment_guard: object = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._subscribers: dict[str, list[Callable[[AgentMessage], Awaitable[None]]]] = {}
|
self._subscribers: dict[str, list[Callable[[AgentMessage], Awaitable[None]]]] = {}
|
||||||
self._pending_requests: dict[str, asyncio.Future[AgentMessage]] = {}
|
self._pending_requests: dict[str, asyncio.Future[AgentMessage]] = {}
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ from __future__ import annotations
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -20,11 +19,11 @@ class AgentMessage:
|
||||||
sender: str = ""
|
sender: str = ""
|
||||||
recipient: str | None = None # None = broadcast
|
recipient: str | None = None # None = broadcast
|
||||||
topic: str = ""
|
topic: str = ""
|
||||||
payload: dict[str, Any] = field(default_factory=dict)
|
payload: dict[str, object] = field(default_factory=dict)
|
||||||
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
correlation_id: str | None = None # 请求-响应关联
|
correlation_id: str | None = None # 请求-响应关联
|
||||||
# --- 新增字段 ---
|
# --- 新增字段 ---
|
||||||
content: Any = None # 消息内容(与 payload 互补,payload 为 dict,content 可为任意类型)
|
content: object = None # 消息内容(与 payload 互补,payload 为 dict,content 可为任意类型)
|
||||||
msg_type: str = "notify" # "request" | "response" | "notify" | "negotiate"
|
msg_type: str = "notify" # "request" | "response" | "notify" | "negotiate"
|
||||||
ttl_seconds: int = 300 # 消息存活时间(秒)
|
ttl_seconds: int = 300 # 消息存活时间(秒)
|
||||||
|
|
||||||
|
|
@ -39,7 +38,7 @@ class AgentMessage:
|
||||||
def is_broadcast(self) -> bool:
|
def is_broadcast(self) -> bool:
|
||||||
return self.recipient is None
|
return self.recipient is None
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, object]:
|
||||||
ts = self.timestamp
|
ts = self.timestamp
|
||||||
if isinstance(ts, datetime):
|
if isinstance(ts, datetime):
|
||||||
ts = ts.isoformat()
|
ts = ts.isoformat()
|
||||||
|
|
@ -57,7 +56,7 @@ class AgentMessage:
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict[str, Any]) -> AgentMessage:
|
def from_dict(cls, data: dict[str, object]) -> AgentMessage:
|
||||||
ts = data.get("timestamp", "")
|
ts = data.get("timestamp", "")
|
||||||
if isinstance(ts, str) and ts:
|
if isinstance(ts, str) and ts:
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Callable, Awaitable, Protocol as TypingProtocol, runtime_checkable
|
from typing import Callable, Awaitable, Protocol as TypingProtocol, runtime_checkable
|
||||||
|
|
||||||
from agentkit.bus.message import AgentMessage
|
from agentkit.bus.message import AgentMessage
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Awaitable
|
from typing import Callable, Awaitable
|
||||||
|
|
||||||
from agentkit.bus.message import AgentMessage
|
from agentkit.bus.message import AgentMessage
|
||||||
from agentkit.bus.memory_bus import InMemoryMessageBus
|
from agentkit.bus.memory_bus import InMemoryMessageBus
|
||||||
|
|
@ -32,12 +32,12 @@ class RedisMessageBus:
|
||||||
self._redis_url = redis_url
|
self._redis_url = redis_url
|
||||||
self._consumer_group = consumer_group
|
self._consumer_group = consumer_group
|
||||||
self._max_retries = max_retries
|
self._max_retries = max_retries
|
||||||
self._redis: Any = None
|
self._redis: object = None
|
||||||
self._subscribers: dict[str, list[Callable[[AgentMessage], Awaitable[None]]]] = {}
|
self._subscribers: dict[str, list[Callable[[AgentMessage], Awaitable[None]]]] = {}
|
||||||
self._pending_requests: dict[str, asyncio.Future[AgentMessage]] = {}
|
self._pending_requests: dict[str, asyncio.Future[AgentMessage]] = {}
|
||||||
self._consumer_tasks: dict[str, asyncio.Task] = {}
|
self._consumer_tasks: dict[str, asyncio.Task] = {}
|
||||||
|
|
||||||
async def _get_redis(self) -> Any:
|
async def _get_redis(self) -> object:
|
||||||
"""获取 Redis 连接(懒初始化)。"""
|
"""获取 Redis 连接(懒初始化)。"""
|
||||||
if self._redis is None:
|
if self._redis is None:
|
||||||
import redis.asyncio as aioredis
|
import redis.asyncio as aioredis
|
||||||
|
|
@ -151,7 +151,7 @@ class RedisMessageBus:
|
||||||
|
|
||||||
async def _handle_failed_message(
|
async def _handle_failed_message(
|
||||||
self,
|
self,
|
||||||
redis: Any,
|
redis: object,
|
||||||
stream_key: str,
|
stream_key: str,
|
||||||
msg_id: str,
|
msg_id: str,
|
||||||
fields: dict,
|
fields: dict,
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ import uuid
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from datetime import date, datetime, timedelta, timezone
|
from datetime import date, datetime, timedelta, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import caldav
|
import caldav
|
||||||
from icalendar import Calendar, Event
|
from icalendar import Calendar, Event
|
||||||
|
|
@ -33,7 +32,7 @@ from agentkit.calendar.sync.base import AbstractSyncProvider
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Async callback signature: (event_type: str, payload: dict) -> None
|
# Async callback signature: (event_type: str, payload: dict) -> None
|
||||||
NotifyCallback = Callable[[str, dict[str, Any]], Awaitable[None]]
|
NotifyCallback = Callable[[str, dict[str, object]], Awaitable[None]]
|
||||||
|
|
||||||
|
|
||||||
def _parse_iso(dt_str: str) -> datetime:
|
def _parse_iso(dt_str: str) -> datetime:
|
||||||
|
|
@ -51,7 +50,7 @@ def _to_iso_utc(dt: datetime) -> str:
|
||||||
return dt.astimezone(timezone.utc).isoformat()
|
return dt.astimezone(timezone.utc).isoformat()
|
||||||
|
|
||||||
|
|
||||||
def _extract_dt(component: Any, key: str) -> tuple[str, bool]:
|
def _extract_dt(component: object, key: str) -> tuple[str, bool]:
|
||||||
"""Extract date/datetime from icalendar component. Returns (iso, is_all_day)."""
|
"""Extract date/datetime from icalendar component. Returns (iso, is_all_day)."""
|
||||||
prop = component.get(key)
|
prop = component.get(key)
|
||||||
if prop is None:
|
if prop is None:
|
||||||
|
|
@ -73,7 +72,7 @@ class CalDAVSyncProvider(AbstractSyncProvider):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
db_path: str | Path | None = None,
|
db_path: str | Path | None = None,
|
||||||
client_factory: Callable[[ExternalCalendarConfig], Any] | None = None,
|
client_factory: Callable[[ExternalCalendarConfig], object] | None = None,
|
||||||
notify_callback: NotifyCallback | None = None,
|
notify_callback: NotifyCallback | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.db_path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
|
self.db_path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
|
||||||
|
|
@ -87,7 +86,7 @@ class CalDAVSyncProvider(AbstractSyncProvider):
|
||||||
# Client construction
|
# Client construction
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def _make_client(self, config: ExternalCalendarConfig) -> Any:
|
def _make_client(self, config: ExternalCalendarConfig) -> object:
|
||||||
"""Build a caldav.DAVClient from config credentials."""
|
"""Build a caldav.DAVClient from config credentials."""
|
||||||
if self._client_factory is not None:
|
if self._client_factory is not None:
|
||||||
return self._client_factory(config)
|
return self._client_factory(config)
|
||||||
|
|
@ -103,7 +102,7 @@ class CalDAVSyncProvider(AbstractSyncProvider):
|
||||||
timeout=30,
|
timeout=30,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_calendar(self, config: ExternalCalendarConfig) -> Any:
|
def _get_calendar(self, config: ExternalCalendarConfig) -> object:
|
||||||
"""Connect and return the first calendar from the principal."""
|
"""Connect and return the first calendar from the principal."""
|
||||||
client = self._make_client(config)
|
client = self._make_client(config)
|
||||||
principal = client.principal()
|
principal = client.principal()
|
||||||
|
|
@ -166,7 +165,7 @@ class CalDAVSyncProvider(AbstractSyncProvider):
|
||||||
events.append(parsed)
|
events.append(parsed)
|
||||||
return events
|
return events
|
||||||
|
|
||||||
def _parse_caldav_event(self, caldav_event: Any, user_id: str) -> CalendarEvent | None:
|
def _parse_caldav_event(self, caldav_event: object, user_id: str) -> CalendarEvent | None:
|
||||||
"""Convert a caldav.Event to a CalendarEvent dataclass."""
|
"""Convert a caldav.Event to a CalendarEvent dataclass."""
|
||||||
try:
|
try:
|
||||||
ical_data = caldav_event.data
|
ical_data = caldav_event.data
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,6 @@ from __future__ import annotations
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import date, datetime, timezone
|
from datetime import date, datetime, timezone
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from icalendar import Calendar, Event
|
from icalendar import Calendar, Event
|
||||||
from icalendar.prop import vRecur
|
from icalendar.prop import vRecur
|
||||||
|
|
@ -38,7 +37,7 @@ def _parse_iso(dt_str: str) -> datetime:
|
||||||
return dt.astimezone(timezone.utc)
|
return dt.astimezone(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
def _extract_dt(component: Any, key: str) -> tuple[str, bool]:
|
def _extract_dt(component: object, key: str) -> tuple[str, bool]:
|
||||||
"""Extract a date/datetime property from an icalendar component.
|
"""Extract a date/datetime property from an icalendar component.
|
||||||
|
|
||||||
Returns ``(iso_string, is_all_day)``. ``is_all_day`` is True when the
|
Returns ``(iso_string, is_all_day)``. ``is_all_day`` is True when the
|
||||||
|
|
@ -59,7 +58,7 @@ class ICSProvider:
|
||||||
def __init__(self, service: CalendarService) -> None:
|
def __init__(self, service: CalendarService) -> None:
|
||||||
self.service = service
|
self.service = service
|
||||||
|
|
||||||
async def import_ics(self, ics_bytes: bytes, user_id: str) -> dict[str, Any]:
|
async def import_ics(self, ics_bytes: bytes, user_id: str) -> dict[str, object]:
|
||||||
"""Parse ICS bytes and create events for *user_id*.
|
"""Parse ICS bytes and create events for *user_id*.
|
||||||
|
|
||||||
Returns ``{"imported": N, "skipped": M, "errors": [...]}``.
|
Returns ``{"imported": N, "skipped": M, "errors": [...]}``.
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,6 @@ import logging
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from agentkit.calendar.db import (
|
from agentkit.calendar.db import (
|
||||||
DEFAULT_CALENDAR_DB_PATH,
|
DEFAULT_CALENDAR_DB_PATH,
|
||||||
|
|
@ -31,7 +30,7 @@ from agentkit.calendar.sync.caldav_provider import CalDAVSyncProvider
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Async callback signature: (event_type: str, payload: dict) -> None
|
# Async callback signature: (event_type: str, payload: dict) -> None
|
||||||
NotifyCallback = Callable[[str, dict[str, Any]], Awaitable[None]]
|
NotifyCallback = Callable[[str, dict[str, object]], Awaitable[None]]
|
||||||
|
|
||||||
|
|
||||||
class SyncManager:
|
class SyncManager:
|
||||||
|
|
@ -64,7 +63,7 @@ class SyncManager:
|
||||||
# Sync orchestration
|
# Sync orchestration
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
async def sync_all(self, user_id: str) -> dict[str, Any]:
|
async def sync_all(self, user_id: str) -> dict[str, object]:
|
||||||
"""Sync all external calendar configs for a user.
|
"""Sync all external calendar configs for a user.
|
||||||
|
|
||||||
Returns ``{"synced": N, "errors": [...]}``.
|
Returns ``{"synced": N, "errors": [...]}``.
|
||||||
|
|
@ -81,7 +80,7 @@ class SyncManager:
|
||||||
logger.warning("Sync failed for config %s: %s", config.id, e)
|
logger.warning("Sync failed for config %s: %s", config.id, e)
|
||||||
return {"synced": synced, "errors": errors}
|
return {"synced": synced, "errors": errors}
|
||||||
|
|
||||||
async def sync_provider(self, config_id: str) -> dict[str, Any]:
|
async def sync_provider(self, config_id: str) -> dict[str, object]:
|
||||||
"""Sync a single external calendar config by ID.
|
"""Sync a single external calendar config by ID.
|
||||||
|
|
||||||
Pulls remote changes, then pushes local changes modified since
|
Pulls remote changes, then pushes local changes modified since
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,6 @@ import uuid
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
|
||||||
from urllib.parse import parse_qs, quote, urlparse
|
from urllib.parse import parse_qs, quote, urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
@ -37,7 +36,7 @@ from agentkit.calendar.sync.base import AbstractSyncProvider
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Async callback signature: (event_type: str, payload: dict) -> None
|
# Async callback signature: (event_type: str, payload: dict) -> None
|
||||||
NotifyCallback = Callable[[str, dict[str, Any]], Awaitable[None]]
|
NotifyCallback = Callable[[str, dict[str, object]], Awaitable[None]]
|
||||||
|
|
||||||
GRAPH_BASE = "https://graph.microsoft.com/v1.0"
|
GRAPH_BASE = "https://graph.microsoft.com/v1.0"
|
||||||
TOKEN_URL = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
|
TOKEN_URL = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
|
||||||
|
|
@ -70,7 +69,7 @@ def _parse_iso(dt_str: str) -> datetime:
|
||||||
return dt.astimezone(timezone.utc)
|
return dt.astimezone(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
def _outlook_dt_to_iso(dt_obj: dict[str, Any]) -> str:
|
def _outlook_dt_to_iso(dt_obj: dict[str, object]) -> str:
|
||||||
"""Convert Outlook dateTimeTimeZone to ISO 8601 UTC.
|
"""Convert Outlook dateTimeTimeZone to ISO 8601 UTC.
|
||||||
|
|
||||||
ponytail: assumes Graph returns UTC (no ``Prefer: outlook.timezone`` header
|
ponytail: assumes Graph returns UTC (no ``Prefer: outlook.timezone`` header
|
||||||
|
|
@ -102,7 +101,7 @@ def _iso_to_outlook_dt(iso_str: str, is_all_day: bool) -> dict[str, str]:
|
||||||
return {"dateTime": dt.strftime("%Y-%m-%dT%H:%M:%S"), "timeZone": "UTC"}
|
return {"dateTime": dt.strftime("%Y-%m-%dT%H:%M:%S"), "timeZone": "UTC"}
|
||||||
|
|
||||||
|
|
||||||
def _outlook_recurrence_to_rrule(recurrence: dict[str, Any] | None) -> str | None:
|
def _outlook_recurrence_to_rrule(recurrence: dict[str, object] | None) -> str | None:
|
||||||
"""Convert Outlook recurrence pattern to RRULE string."""
|
"""Convert Outlook recurrence pattern to RRULE string."""
|
||||||
if not recurrence:
|
if not recurrence:
|
||||||
return None
|
return None
|
||||||
|
|
@ -134,7 +133,7 @@ def _outlook_recurrence_to_rrule(recurrence: dict[str, Any] | None) -> str | Non
|
||||||
return ";".join(parts) if parts else None
|
return ";".join(parts) if parts else None
|
||||||
|
|
||||||
|
|
||||||
def _rrule_to_outlook_recurrence(rrule: str, start_date: str) -> dict[str, Any] | None:
|
def _rrule_to_outlook_recurrence(rrule: str, start_date: str) -> dict[str, object] | None:
|
||||||
"""Convert RRULE string to Outlook recurrence pattern.
|
"""Convert RRULE string to Outlook recurrence pattern.
|
||||||
|
|
||||||
``start_date`` is the event's start date in ``YYYY-MM-DD`` format (required
|
``start_date`` is the event's start date in ``YYYY-MM-DD`` format (required
|
||||||
|
|
@ -151,7 +150,7 @@ def _rrule_to_outlook_recurrence(rrule: str, start_date: str) -> dict[str, Any]
|
||||||
if not pattern_type:
|
if not pattern_type:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
pattern: dict[str, Any] = {"type": pattern_type}
|
pattern: dict[str, object] = {"type": pattern_type}
|
||||||
|
|
||||||
interval = parts.get("INTERVAL")
|
interval = parts.get("INTERVAL")
|
||||||
pattern["interval"] = int(interval) if interval else 1
|
pattern["interval"] = int(interval) if interval else 1
|
||||||
|
|
@ -167,7 +166,7 @@ def _rrule_to_outlook_recurrence(rrule: str, start_date: str) -> dict[str, Any]
|
||||||
|
|
||||||
if count:
|
if count:
|
||||||
pattern["numberOfOccurrences"] = int(count)
|
pattern["numberOfOccurrences"] = int(count)
|
||||||
range_obj: dict[str, Any] = {
|
range_obj: dict[str, object] = {
|
||||||
"type": "numbered",
|
"type": "numbered",
|
||||||
"startDate": start_date,
|
"startDate": start_date,
|
||||||
"numberOfOccurrences": int(count),
|
"numberOfOccurrences": int(count),
|
||||||
|
|
@ -202,7 +201,7 @@ class OutlookSyncProvider(AbstractSyncProvider):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
db_path: str | Path | None = None,
|
db_path: str | Path | None = None,
|
||||||
client_factory: Callable[[], Any] | None = None,
|
client_factory: Callable[[], object] | None = None,
|
||||||
notify_callback: NotifyCallback | None = None,
|
notify_callback: NotifyCallback | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.db_path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
|
self.db_path = Path(db_path) if db_path is not None else DEFAULT_CALENDAR_DB_PATH
|
||||||
|
|
@ -216,19 +215,19 @@ class OutlookSyncProvider(AbstractSyncProvider):
|
||||||
# Client / auth
|
# Client / auth
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def _get_client(self) -> Any:
|
def _get_client(self) -> object:
|
||||||
"""Return an httpx.AsyncClient (real or mock from factory)."""
|
"""Return an httpx.AsyncClient (real or mock from factory)."""
|
||||||
if self._client_factory is not None:
|
if self._client_factory is not None:
|
||||||
return self._client_factory()
|
return self._client_factory()
|
||||||
return httpx.AsyncClient(timeout=30.0)
|
return httpx.AsyncClient(timeout=30.0)
|
||||||
|
|
||||||
def _load_creds(self, config: ExternalCalendarConfig) -> dict[str, Any]:
|
def _load_creds(self, config: ExternalCalendarConfig) -> dict[str, object]:
|
||||||
return json.loads(config.credentials) if config.credentials else {}
|
return json.loads(config.credentials) if config.credentials else {}
|
||||||
|
|
||||||
def _save_creds(self, config: ExternalCalendarConfig, creds: dict[str, Any]) -> None:
|
def _save_creds(self, config: ExternalCalendarConfig, creds: dict[str, object]) -> None:
|
||||||
config.credentials = json.dumps(creds)
|
config.credentials = json.dumps(creds)
|
||||||
|
|
||||||
async def _refresh_token(self, client: Any, config: ExternalCalendarConfig) -> dict[str, Any]:
|
async def _refresh_token(self, client: object, config: ExternalCalendarConfig) -> dict[str, object]:
|
||||||
"""Refresh the access token using the refresh_token grant.
|
"""Refresh the access token using the refresh_token grant.
|
||||||
|
|
||||||
Posts to the Azure AD token endpoint, updates ``config.credentials``
|
Posts to the Azure AD token endpoint, updates ``config.credentials``
|
||||||
|
|
@ -271,13 +270,13 @@ class OutlookSyncProvider(AbstractSyncProvider):
|
||||||
|
|
||||||
async def _request(
|
async def _request(
|
||||||
self,
|
self,
|
||||||
client: Any,
|
client: object,
|
||||||
config: ExternalCalendarConfig,
|
config: ExternalCalendarConfig,
|
||||||
method: str,
|
method: str,
|
||||||
url: str,
|
url: str,
|
||||||
*,
|
*,
|
||||||
json_body: dict[str, Any] | None = None,
|
json_body: dict[str, object] | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, object]:
|
||||||
"""Make an authenticated Graph API request with 401 auto-refresh + retry."""
|
"""Make an authenticated Graph API request with 401 auto-refresh + retry."""
|
||||||
creds = self._load_creds(config)
|
creds = self._load_creds(config)
|
||||||
headers = {"Authorization": f"Bearer {creds.get('access_token', '')}"}
|
headers = {"Authorization": f"Bearer {creds.get('access_token', '')}"}
|
||||||
|
|
@ -330,7 +329,7 @@ class OutlookSyncProvider(AbstractSyncProvider):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def _pull_delta(
|
async def _pull_delta(
|
||||||
self, client: Any, config: ExternalCalendarConfig
|
self, client: object, config: ExternalCalendarConfig
|
||||||
) -> tuple[list[CalendarEvent], str | None]:
|
) -> tuple[list[CalendarEvent], str | None]:
|
||||||
"""Call /me/calendarView/delta. Returns (events, delta_token)."""
|
"""Call /me/calendarView/delta. Returns (events, delta_token)."""
|
||||||
url = self._build_delta_url(config)
|
url = self._build_delta_url(config)
|
||||||
|
|
@ -356,7 +355,7 @@ class OutlookSyncProvider(AbstractSyncProvider):
|
||||||
end = (datetime.now(timezone.utc) + timedelta(days=90)).strftime("%Y-%m-%dT%H:%M:%SZ")
|
end = (datetime.now(timezone.utc) + timedelta(days=90)).strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||||
return f"{GRAPH_BASE}/me/calendarView/delta?startDateTime={start}&endDateTime={end}"
|
return f"{GRAPH_BASE}/me/calendarView/delta?startDateTime={start}&endDateTime={end}"
|
||||||
|
|
||||||
def _parse_outlook_event(self, raw: dict[str, Any], user_id: str) -> CalendarEvent | None:
|
def _parse_outlook_event(self, raw: dict[str, object], user_id: str) -> CalendarEvent | None:
|
||||||
"""Convert a Graph event JSON to a CalendarEvent dataclass."""
|
"""Convert a Graph event JSON to a CalendarEvent dataclass."""
|
||||||
eid = raw.get("id")
|
eid = raw.get("id")
|
||||||
if not eid:
|
if not eid:
|
||||||
|
|
@ -459,7 +458,7 @@ class OutlookSyncProvider(AbstractSyncProvider):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def _push_single(
|
async def _push_single(
|
||||||
self, client: Any, config: ExternalCalendarConfig, event: CalendarEvent
|
self, client: object, config: ExternalCalendarConfig, event: CalendarEvent
|
||||||
) -> CalendarEvent:
|
) -> CalendarEvent:
|
||||||
"""Push a single event to Outlook, return event with external_id set."""
|
"""Push a single event to Outlook, return event with external_id set."""
|
||||||
body = self._event_to_outlook(event)
|
body = self._event_to_outlook(event)
|
||||||
|
|
@ -485,9 +484,9 @@ class OutlookSyncProvider(AbstractSyncProvider):
|
||||||
event.external_provider = "outlook"
|
event.external_provider = "outlook"
|
||||||
return event
|
return event
|
||||||
|
|
||||||
def _event_to_outlook(self, event: CalendarEvent) -> dict[str, Any]:
|
def _event_to_outlook(self, event: CalendarEvent) -> dict[str, object]:
|
||||||
"""Convert CalendarEvent to Outlook Graph event JSON."""
|
"""Convert CalendarEvent to Outlook Graph event JSON."""
|
||||||
body: dict[str, Any] = {
|
body: dict[str, object] = {
|
||||||
"subject": event.title,
|
"subject": event.title,
|
||||||
"start": _iso_to_outlook_dt(event.start_time, event.is_all_day),
|
"start": _iso_to_outlook_dt(event.start_time, event.is_all_day),
|
||||||
"end": _iso_to_outlook_dt(event.end_time, event.is_all_day),
|
"end": _iso_to_outlook_dt(event.end_time, event.is_all_day),
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,6 @@ from __future__ import annotations
|
||||||
import abc
|
import abc
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
|
@ -39,7 +38,7 @@ class IncomingMessage:
|
||||||
user_id: str # 平台用户 ID
|
user_id: str # 平台用户 ID
|
||||||
chat_id: str # 群组/会话 ID
|
chat_id: str # 群组/会话 ID
|
||||||
content: str # 消息文本
|
content: str # 消息文本
|
||||||
raw_event: dict[str, Any] = field(default_factory=dict) # 原始事件
|
raw_event: dict[str, object] = field(default_factory=dict) # 原始事件
|
||||||
timestamp: str = ""
|
timestamp: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,6 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
|
@ -147,7 +146,7 @@ class DingTalkMessageAdapter(MessageAdapter):
|
||||||
ValueError: 事件 body 不是合法 JSON。
|
ValueError: 事件 body 不是合法 JSON。
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
data: dict[str, Any] = json.loads(body)
|
data: dict[str, object] = json.loads(body)
|
||||||
except json.JSONDecodeError as exc:
|
except json.JSONDecodeError as exc:
|
||||||
raise ValueError(f"钉钉事件 body 不是合法 JSON: {exc}") from exc
|
raise ValueError(f"钉钉事件 body 不是合法 JSON: {exc}") from exc
|
||||||
|
|
||||||
|
|
@ -169,7 +168,7 @@ class DingTalkMessageAdapter(MessageAdapter):
|
||||||
timestamp=timestamp,
|
timestamp=timestamp,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _extract_content(self, data: dict[str, Any]) -> str:
|
def _extract_content(self, data: dict[str, object]) -> str:
|
||||||
"""从钉钉事件提取文本内容。
|
"""从钉钉事件提取文本内容。
|
||||||
|
|
||||||
- text 类型:解析 ``text.content``,剥离 @ 机器人前缀。
|
- text 类型:解析 ``text.content``,剥离 @ 机器人前缀。
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,6 @@ import logging
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
|
@ -142,7 +141,7 @@ class FeishuMessageAdapter(MessageAdapter):
|
||||||
ValueError: 事件结构无法解析。
|
ValueError: 事件结构无法解析。
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
data: dict[str, Any] = json.loads(body)
|
data: dict[str, object] = json.loads(body)
|
||||||
except json.JSONDecodeError as exc:
|
except json.JSONDecodeError as exc:
|
||||||
raise ValueError(f"飞书事件 body 不是合法 JSON: {exc}") from exc
|
raise ValueError(f"飞书事件 body 不是合法 JSON: {exc}") from exc
|
||||||
|
|
||||||
|
|
@ -184,7 +183,7 @@ class FeishuMessageAdapter(MessageAdapter):
|
||||||
timestamp=timestamp,
|
timestamp=timestamp,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _decrypt_event(self, encrypt_b64: str) -> dict[str, Any]:
|
def _decrypt_event(self, encrypt_b64: str) -> dict[str, object]:
|
||||||
"""AES-256-CBC 解密飞书加密事件。
|
"""AES-256-CBC 解密飞书加密事件。
|
||||||
|
|
||||||
飞书协议:
|
飞书协议:
|
||||||
|
|
@ -216,7 +215,7 @@ class FeishuMessageAdapter(MessageAdapter):
|
||||||
|
|
||||||
return json.loads(plaintext.decode("utf-8"))
|
return json.loads(plaintext.decode("utf-8"))
|
||||||
|
|
||||||
def _extract_content(self, message: dict[str, Any]) -> str:
|
def _extract_content(self, message: dict[str, object]) -> str:
|
||||||
"""从飞书 message 字段提取文本内容。
|
"""从飞书 message 字段提取文本内容。
|
||||||
|
|
||||||
- text 类型:解析 ``content`` JSON 中的 ``text`` 字段,剥离 @ 提及标记。
|
- text 类型:解析 ``content`` JSON 中的 ``text`` 字段,剥离 @ 提及标记。
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,6 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
|
||||||
from urllib.parse import parse_qs
|
from urllib.parse import parse_qs
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
|
@ -134,7 +133,7 @@ class SlackMessageAdapter(MessageAdapter):
|
||||||
"""
|
"""
|
||||||
# 优先尝试 JSON 解析(Events API / URL 验证)
|
# 优先尝试 JSON 解析(Events API / URL 验证)
|
||||||
try:
|
try:
|
||||||
data: dict[str, Any] = json.loads(body)
|
data: dict[str, object] = json.loads(body)
|
||||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||||
data = {}
|
data = {}
|
||||||
|
|
||||||
|
|
@ -144,7 +143,7 @@ class SlackMessageAdapter(MessageAdapter):
|
||||||
# 非 JSON — 视为 Slash Command(form-encoded)
|
# 非 JSON — 视为 Slash Command(form-encoded)
|
||||||
return self._parse_slash_command(body)
|
return self._parse_slash_command(body)
|
||||||
|
|
||||||
def _parse_event(self, data: dict[str, Any]) -> IncomingMessage:
|
def _parse_event(self, data: dict[str, object]) -> IncomingMessage:
|
||||||
"""解析 Events API 事件。"""
|
"""解析 Events API 事件。"""
|
||||||
# URL 验证流程
|
# URL 验证流程
|
||||||
if data.get("type") == "url_verification":
|
if data.get("type") == "url_verification":
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import json as _json
|
import json as _json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import typer
|
import typer
|
||||||
|
|
@ -87,13 +87,13 @@ def _build_client(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _emit_json(data: Any) -> None:
|
def _emit_json(data: object) -> None:
|
||||||
rprint(_json.dumps(data, indent=2, ensure_ascii=False, default=str))
|
rprint(_json.dumps(data, indent=2, ensure_ascii=False, default=str))
|
||||||
|
|
||||||
|
|
||||||
def _safe_get(
|
def _safe_get(
|
||||||
client: AdminHttpClient, path: str, server_url: str, params: dict | None = None
|
client: AdminHttpClient, path: str, server_url: str, params: dict | None = None
|
||||||
) -> Any:
|
) -> object:
|
||||||
try:
|
try:
|
||||||
return client.get(path, params=params)
|
return client.get(path, params=params)
|
||||||
except httpx.ConnectError:
|
except httpx.ConnectError:
|
||||||
|
|
@ -104,7 +104,7 @@ def _safe_get(
|
||||||
|
|
||||||
def _safe_post(
|
def _safe_post(
|
||||||
client: AdminHttpClient, path: str, server_url: str, body: dict | None = None
|
client: AdminHttpClient, path: str, server_url: str, body: dict | None = None
|
||||||
) -> Any:
|
) -> object:
|
||||||
try:
|
try:
|
||||||
return client.post(path, json=body)
|
return client.post(path, json=body)
|
||||||
except httpx.ConnectError:
|
except httpx.ConnectError:
|
||||||
|
|
@ -115,7 +115,7 @@ def _safe_post(
|
||||||
|
|
||||||
def _safe_patch(
|
def _safe_patch(
|
||||||
client: AdminHttpClient, path: str, server_url: str, body: dict | None = None
|
client: AdminHttpClient, path: str, server_url: str, body: dict | None = None
|
||||||
) -> Any:
|
) -> object:
|
||||||
try:
|
try:
|
||||||
return client.patch(path, json=body)
|
return client.patch(path, json=body)
|
||||||
except httpx.ConnectError:
|
except httpx.ConnectError:
|
||||||
|
|
@ -124,7 +124,7 @@ def _safe_patch(
|
||||||
_handle_http_error(e, server_url)
|
_handle_http_error(e, server_url)
|
||||||
|
|
||||||
|
|
||||||
def _safe_put(client: AdminHttpClient, path: str, server_url: str, body: dict | None = None) -> Any:
|
def _safe_put(client: AdminHttpClient, path: str, server_url: str, body: dict | None = None) -> object:
|
||||||
try:
|
try:
|
||||||
return client.put(path, json=body)
|
return client.put(path, json=body)
|
||||||
except httpx.ConnectError:
|
except httpx.ConnectError:
|
||||||
|
|
@ -135,7 +135,7 @@ def _safe_put(client: AdminHttpClient, path: str, server_url: str, body: dict |
|
||||||
|
|
||||||
def _safe_delete(
|
def _safe_delete(
|
||||||
client: AdminHttpClient, path: str, server_url: str, params: dict | None = None
|
client: AdminHttpClient, path: str, server_url: str, params: dict | None = None
|
||||||
) -> Any:
|
) -> object:
|
||||||
try:
|
try:
|
||||||
return client.delete(path, params=params)
|
return client.delete(path, params=params)
|
||||||
except httpx.ConnectError:
|
except httpx.ConnectError:
|
||||||
|
|
@ -252,7 +252,7 @@ def dept_update(
|
||||||
json_output: bool = JsonFlag,
|
json_output: bool = JsonFlag,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update a department's name or description."""
|
"""Update a department's name or description."""
|
||||||
body: dict[str, Any] = {}
|
body: dict[str, object] = {}
|
||||||
if name is not None:
|
if name is not None:
|
||||||
body["name"] = name
|
body["name"] = name
|
||||||
if description is not None:
|
if description is not None:
|
||||||
|
|
@ -576,7 +576,7 @@ def user_list(
|
||||||
) -> None:
|
) -> None:
|
||||||
"""List users, optionally filtered by department."""
|
"""List users, optionally filtered by department."""
|
||||||
client = _build_client(server_url, token, api_key)
|
client = _build_client(server_url, token, api_key)
|
||||||
params: dict[str, Any] = {}
|
params: dict[str, object] = {}
|
||||||
if department_id:
|
if department_id:
|
||||||
params["department_id"] = department_id
|
params["department_id"] = department_id
|
||||||
users = _safe_get(client, "/api/v1/admin/users", client.base_url, params=params or None)
|
users = _safe_get(client, "/api/v1/admin/users", client.base_url, params=params or None)
|
||||||
|
|
@ -620,7 +620,7 @@ def user_create(
|
||||||
json_output: bool = JsonFlag,
|
json_output: bool = JsonFlag,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create a new user."""
|
"""Create a new user."""
|
||||||
body: dict[str, Any] = {
|
body: dict[str, object] = {
|
||||||
"username": username,
|
"username": username,
|
||||||
"email": email,
|
"email": email,
|
||||||
"password": password,
|
"password": password,
|
||||||
|
|
@ -653,7 +653,7 @@ def user_update(
|
||||||
json_output: bool = JsonFlag,
|
json_output: bool = JsonFlag,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update a user's role or active flag."""
|
"""Update a user's role or active flag."""
|
||||||
body: dict[str, Any] = {}
|
body: dict[str, object] = {}
|
||||||
if role is not None:
|
if role is not None:
|
||||||
body["role"] = role
|
body["role"] = role
|
||||||
if is_active is not None:
|
if is_active is not None:
|
||||||
|
|
@ -818,7 +818,7 @@ def llm_add_provider(
|
||||||
json_output: bool = JsonFlag,
|
json_output: bool = JsonFlag,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create a new LLM provider."""
|
"""Create a new LLM provider."""
|
||||||
body: dict[str, Any] = {
|
body: dict[str, object] = {
|
||||||
"name": name,
|
"name": name,
|
||||||
"type": type,
|
"type": type,
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
|
|
@ -850,7 +850,7 @@ def llm_update_provider(
|
||||||
json_output: bool = JsonFlag,
|
json_output: bool = JsonFlag,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update an LLM provider's configuration."""
|
"""Update an LLM provider's configuration."""
|
||||||
body: dict[str, Any] = {}
|
body: dict[str, object] = {}
|
||||||
if type is not None:
|
if type is not None:
|
||||||
body["type"] = type
|
body["type"] = type
|
||||||
if api_key is not None:
|
if api_key is not None:
|
||||||
|
|
@ -1123,7 +1123,7 @@ def kb_list_documents(
|
||||||
) -> None:
|
) -> None:
|
||||||
"""List KB documents."""
|
"""List KB documents."""
|
||||||
client = _build_client(server_url, token, api_key)
|
client = _build_client(server_url, token, api_key)
|
||||||
params: dict[str, Any] = {}
|
params: dict[str, object] = {}
|
||||||
if source_id:
|
if source_id:
|
||||||
params["source_id"] = source_id
|
params["source_id"] = source_id
|
||||||
if department_id:
|
if department_id:
|
||||||
|
|
@ -1176,7 +1176,7 @@ def kb_upload(
|
||||||
rprint(f"[red]Error: File not found: {content_file}[/red]")
|
rprint(f"[red]Error: File not found: {content_file}[/red]")
|
||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
content = path.read_text(encoding="utf-8")
|
content = path.read_text(encoding="utf-8")
|
||||||
body: dict[str, Any] = {
|
body: dict[str, object] = {
|
||||||
"filename": filename,
|
"filename": filename,
|
||||||
"content": content,
|
"content": content,
|
||||||
"source_id": source_id,
|
"source_id": source_id,
|
||||||
|
|
@ -1255,8 +1255,8 @@ def _usage_params(
|
||||||
user_id: Optional[str],
|
user_id: Optional[str],
|
||||||
start: Optional[str],
|
start: Optional[str],
|
||||||
end: Optional[str],
|
end: Optional[str],
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, object]:
|
||||||
params: dict[str, Any] = {}
|
params: dict[str, object] = {}
|
||||||
if department_id:
|
if department_id:
|
||||||
params["department_id"] = department_id
|
params["department_id"] = department_id
|
||||||
if user_id:
|
if user_id:
|
||||||
|
|
@ -1386,7 +1386,7 @@ def usage_top_users(
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Show top-N users by token usage."""
|
"""Show top-N users by token usage."""
|
||||||
client = _build_client(server_url, token, api_key)
|
client = _build_client(server_url, token, api_key)
|
||||||
params: dict[str, Any] = {"limit": limit}
|
params: dict[str, object] = {"limit": limit}
|
||||||
if department_id:
|
if department_id:
|
||||||
params["department_id"] = department_id
|
params["department_id"] = department_id
|
||||||
if start:
|
if start:
|
||||||
|
|
|
||||||
|
|
@ -11,8 +11,6 @@ from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
@ -62,7 +60,7 @@ class AdminHttpClient:
|
||||||
4. Hard-coded defaults (server URL only)
|
4. Hard-coded defaults (server URL only)
|
||||||
"""
|
"""
|
||||||
path = Path(config_path) if config_path else DEFAULT_CONFIG_PATH
|
path = Path(config_path) if config_path else DEFAULT_CONFIG_PATH
|
||||||
file_cfg: dict[str, Any] = {}
|
file_cfg: dict[str, object] = {}
|
||||||
if path.exists():
|
if path.exists():
|
||||||
try:
|
try:
|
||||||
with path.open(encoding="utf-8") as f:
|
with path.open(encoding="utf-8") as f:
|
||||||
|
|
@ -106,8 +104,8 @@ class AdminHttpClient:
|
||||||
method: str,
|
method: str,
|
||||||
path: str,
|
path: str,
|
||||||
*,
|
*,
|
||||||
params: dict[str, Any] | None = None,
|
params: dict[str, object] | None = None,
|
||||||
json: dict[str, Any] | None = None,
|
json: dict[str, object] | None = None,
|
||||||
timeout: float = DEFAULT_TIMEOUT,
|
timeout: float = DEFAULT_TIMEOUT,
|
||||||
) -> httpx.Response:
|
) -> httpx.Response:
|
||||||
url = f"{self._base_url}{path}"
|
url = f"{self._base_url}{path}"
|
||||||
|
|
@ -130,28 +128,28 @@ class AdminHttpClient:
|
||||||
def base_url(self) -> str:
|
def base_url(self) -> str:
|
||||||
return self._base_url
|
return self._base_url
|
||||||
|
|
||||||
def get(self, path: str, params: dict[str, Any] | None = None) -> Any:
|
def get(self, path: str, params: dict[str, object] | None = None) -> object:
|
||||||
resp = self._request("GET", path, params=params)
|
resp = self._request("GET", path, params=params)
|
||||||
return resp.json()
|
return resp.json()
|
||||||
|
|
||||||
def get_text(self, path: str, params: dict[str, Any] | None = None) -> str:
|
def get_text(self, path: str, params: dict[str, object] | None = None) -> str:
|
||||||
"""GET returning response text (for CSV exports)."""
|
"""GET returning response text (for CSV exports)."""
|
||||||
resp = self._request("GET", path, params=params)
|
resp = self._request("GET", path, params=params)
|
||||||
return resp.text
|
return resp.text
|
||||||
|
|
||||||
def post(self, path: str, json: dict[str, Any] | None = None) -> Any:
|
def post(self, path: str, json: dict[str, object] | None = None) -> object:
|
||||||
resp = self._request("POST", path, json=json)
|
resp = self._request("POST", path, json=json)
|
||||||
return resp.json()
|
return resp.json()
|
||||||
|
|
||||||
def put(self, path: str, json: dict[str, Any] | None = None) -> Any:
|
def put(self, path: str, json: dict[str, object] | None = None) -> object:
|
||||||
resp = self._request("PUT", path, json=json)
|
resp = self._request("PUT", path, json=json)
|
||||||
return resp.json()
|
return resp.json()
|
||||||
|
|
||||||
def patch(self, path: str, json: dict[str, Any] | None = None) -> Any:
|
def patch(self, path: str, json: dict[str, object] | None = None) -> object:
|
||||||
resp = self._request("PATCH", path, json=json)
|
resp = self._request("PATCH", path, json=json)
|
||||||
return resp.json()
|
return resp.json()
|
||||||
|
|
||||||
def delete(self, path: str, params: dict[str, Any] | None = None) -> Any:
|
def delete(self, path: str, params: dict[str, object] | None = None) -> object:
|
||||||
resp = self._request("DELETE", path, params=params)
|
resp = self._request("DELETE", path, params=params)
|
||||||
return resp.json()
|
return resp.json()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,6 @@ When no agentkit.yaml exists, this wizard guides the user through:
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
|
|
@ -22,7 +21,7 @@ from agentkit.server.config import find_config_path
|
||||||
|
|
||||||
# ── Provider presets ──────────────────────────────────────────────────
|
# ── Provider presets ──────────────────────────────────────────────────
|
||||||
|
|
||||||
PROVIDER_PRESETS: dict[str, dict[str, Any]] = {
|
PROVIDER_PRESETS: dict[str, dict[str, object]] = {
|
||||||
"deepseek": {
|
"deepseek": {
|
||||||
"name": "DeepSeek",
|
"name": "DeepSeek",
|
||||||
"env_key": "DEEPSEEK_API_KEY",
|
"env_key": "DEEPSEEK_API_KEY",
|
||||||
|
|
@ -144,7 +143,7 @@ def run_onboarding(
|
||||||
output_path.mkdir(parents=True, exist_ok=True)
|
output_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
existing_config_path = find_config_path(config_arg)
|
existing_config_path = find_config_path(config_arg)
|
||||||
existing_config: dict[str, Any] | None = None
|
existing_config: dict[str, object] | None = None
|
||||||
if existing_config_path:
|
if existing_config_path:
|
||||||
with open(existing_config_path, encoding="utf-8") as f:
|
with open(existing_config_path, encoding="utf-8") as f:
|
||||||
existing_config = yaml.safe_load(f) or {}
|
existing_config = yaml.safe_load(f) or {}
|
||||||
|
|
@ -220,7 +219,7 @@ def run_onboarding(
|
||||||
)
|
)
|
||||||
selected_model = available_models[int(model_choice) - 1]
|
selected_model = available_models[int(model_choice) - 1]
|
||||||
# Rebuild models dict: selected model gets "default" alias
|
# Rebuild models dict: selected model gets "default" alias
|
||||||
updated_models: dict[str, Any] = {}
|
updated_models: dict[str, object] = {}
|
||||||
for model, conf in preset["models"].items():
|
for model, conf in preset["models"].items():
|
||||||
if model == selected_model:
|
if model == selected_model:
|
||||||
updated_models[model] = {**conf, "alias": "default"}
|
updated_models[model] = {**conf, "alias": "default"}
|
||||||
|
|
@ -236,7 +235,7 @@ def run_onboarding(
|
||||||
|
|
||||||
# ── Step 3: Optional — add a second provider ─────────────────
|
# ── Step 3: Optional — add a second provider ─────────────────
|
||||||
env_vars: dict[str, str] = {preset["env_key"]: api_key.strip()}
|
env_vars: dict[str, str] = {preset["env_key"]: api_key.strip()}
|
||||||
providers_config: dict[str, Any] = {
|
providers_config: dict[str, object] = {
|
||||||
selected_key: {
|
selected_key: {
|
||||||
"api_key": f"${{{preset['env_key']}}}",
|
"api_key": f"${{{preset['env_key']}}}",
|
||||||
"base_url": preset["base_url"],
|
"base_url": preset["base_url"],
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,6 @@ from __future__ import annotations
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from openpyxl import Workbook
|
from openpyxl import Workbook
|
||||||
|
|
||||||
|
|
@ -43,7 +42,7 @@ class ExcelRenderer:
|
||||||
|
|
||||||
return self._render_markdown(markdown_content, output_path)
|
return self._render_markdown(markdown_content, output_path)
|
||||||
|
|
||||||
def _render_json(self, data: dict[str, list[list[Any]]], output_path: Path) -> Path:
|
def _render_json(self, data: dict[str, list[list[object]]], output_path: Path) -> Path:
|
||||||
"""Render JSON dict {sheet_name: rows} into a multi-sheet workbook."""
|
"""Render JSON dict {sheet_name: rows} into a multi-sheet workbook."""
|
||||||
wb = Workbook()
|
wb = Workbook()
|
||||||
# Remove the default sheet — we'll create named ones
|
# Remove the default sheet — we'll create named ones
|
||||||
|
|
@ -108,7 +107,7 @@ class ExcelRenderer:
|
||||||
wb.save(str(output_path))
|
wb.save(str(output_path))
|
||||||
return output_path
|
return output_path
|
||||||
|
|
||||||
def _fill_sheet_from_table(self, ws: Any, table_lines: list[str]) -> None:
|
def _fill_sheet_from_table(self, ws: object, table_lines: list[str]) -> None:
|
||||||
"""Parse GFM table lines and write rows into a worksheet."""
|
"""Parse GFM table lines and write rows into a worksheet."""
|
||||||
for idx, line in enumerate(table_lines):
|
for idx, line in enumerate(table_lines):
|
||||||
if idx == 1:
|
if idx == 1:
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,6 @@ from __future__ import annotations
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from reportlab.lib import colors
|
from reportlab.lib import colors
|
||||||
from reportlab.lib.pagesizes import A4
|
from reportlab.lib.pagesizes import A4
|
||||||
|
|
@ -116,7 +115,7 @@ class PDFRenderer:
|
||||||
)
|
)
|
||||||
|
|
||||||
styles = self._build_styles()
|
styles = self._build_styles()
|
||||||
flowables: list[Any] = []
|
flowables: list[object] = []
|
||||||
lines = markdown_content.splitlines()
|
lines = markdown_content.splitlines()
|
||||||
i = 0
|
i = 0
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,6 @@ from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from docxtpl import DocxTemplate
|
from docxtpl import DocxTemplate
|
||||||
from jinja2.sandbox import SandboxedEnvironment
|
from jinja2.sandbox import SandboxedEnvironment
|
||||||
|
|
@ -40,7 +39,7 @@ class TemplateRenderer:
|
||||||
)
|
)
|
||||||
|
|
||||||
def render_template(
|
def render_template(
|
||||||
self, template_path: str | Path, data: dict[str, Any], output_path: Path
|
self, template_path: str | Path, data: dict[str, object], output_path: Path
|
||||||
) -> Path:
|
) -> Path:
|
||||||
"""Fill a .docx template with data using Jinja2 sandbox.
|
"""Fill a .docx template with data using Jinja2 sandbox.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ from __future__ import annotations
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from agentkit.mcp.transport import HTTPTransport, SSETransport, StdioTransport, Transport
|
from agentkit.mcp.transport import HTTPTransport, SSETransport, StdioTransport, Transport
|
||||||
|
|
@ -65,7 +65,7 @@ class MCPClient:
|
||||||
self._transport = transport
|
self._transport = transport
|
||||||
# U10 — 懒构造并缓存的 langchain client,避免每次 list_tools/call_tool
|
# U10 — 懒构造并缓存的 langchain client,避免每次 list_tools/call_tool
|
||||||
# 都新建 MultiServerMCPClient(stdio 传输下会反复 spawn 子进程)。
|
# 都新建 MultiServerMCPClient(stdio 传输下会反复 spawn 子进程)。
|
||||||
self._lc_client: Any = None
|
self._lc_client: "MultiServerMCPClient | None" = None
|
||||||
|
|
||||||
if transport is not None:
|
if transport is not None:
|
||||||
# 旧 Transport 路径 — 发出 DeprecationWarning,但保持原有行为
|
# 旧 Transport 路径 — 发出 DeprecationWarning,但保持原有行为
|
||||||
|
|
@ -76,13 +76,13 @@ class MCPClient:
|
||||||
DeprecationWarning,
|
DeprecationWarning,
|
||||||
stacklevel=2,
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
self._langchain_config: dict[str, Any] | None = None
|
self._langchain_config: dict[str, object] | None = None
|
||||||
else:
|
else:
|
||||||
# 新 langchain 路径 — 解析 URL scheme 构建连接配置
|
# 新 langchain 路径 — 解析 URL scheme 构建连接配置
|
||||||
self._langchain_config = self._build_langchain_config(self._server_url, timeout)
|
self._langchain_config = self._build_langchain_config(self._server_url, timeout)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_langchain_config(server_url: str, timeout: float) -> dict[str, Any]:
|
def _build_langchain_config(server_url: str, timeout: float) -> dict[str, object]:
|
||||||
"""根据 URL scheme 构建 langchain-mcp-adapters 连接配置。
|
"""根据 URL scheme 构建 langchain-mcp-adapters 连接配置。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -150,7 +150,7 @@ class MCPClient:
|
||||||
server_url = ""
|
server_url = ""
|
||||||
return cls(server_url=server_url, transport=transport)
|
return cls(server_url=server_url, transport=transport)
|
||||||
|
|
||||||
async def _get_lc_client(self) -> Any:
|
async def _get_lc_client(self) -> "MultiServerMCPClient":
|
||||||
"""懒构造并缓存 langchain ``MultiServerMCPClient`` 实例。
|
"""懒构造并缓存 langchain ``MultiServerMCPClient`` 实例。
|
||||||
|
|
||||||
首次调用时创建,后续返回缓存,避免每次 list_tools/call_tool 都新建
|
首次调用时创建,后续返回缓存,避免每次 list_tools/call_tool 都新建
|
||||||
|
|
@ -202,7 +202,7 @@ class MCPClient:
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_schema(tool: Any) -> dict[str, Any]:
|
def _extract_schema(tool: object) -> dict[str, object]:
|
||||||
"""从 LangChain Tool 提取 inputSchema(JSON Schema 格式)。
|
"""从 LangChain Tool 提取 inputSchema(JSON Schema 格式)。
|
||||||
|
|
||||||
LangChain 工具通常有 args_schema(pydantic model),回退到 tool.args dict。
|
LangChain 工具通常有 args_schema(pydantic model),回退到 tool.args dict。
|
||||||
|
|
@ -273,8 +273,8 @@ class MCPTool(Tool):
|
||||||
name: str,
|
name: str,
|
||||||
description: str,
|
description: str,
|
||||||
client: MCPClient,
|
client: MCPClient,
|
||||||
input_schema: dict[str, Any] | None = None,
|
input_schema: dict[str, object] | None = None,
|
||||||
output_schema: dict[str, Any] | None = None,
|
output_schema: dict[str, object] | None = None,
|
||||||
version: str = "1.0.0",
|
version: str = "1.0.0",
|
||||||
tags: list[str] | None = None,
|
tags: list[str] | None = None,
|
||||||
):
|
):
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ U14: 支持把已注册的 Skill 或专家团队封装成 ``Tool``,注册到
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
|
|
||||||
from agentkit.skills.base import Skill
|
from agentkit.skills.base import Skill
|
||||||
from agentkit.tools.base import Tool
|
from agentkit.tools.base import Tool
|
||||||
|
|
@ -28,7 +28,7 @@ _DANGEROUS_TOOL_NAMES: frozenset[str] = frozenset(
|
||||||
)
|
)
|
||||||
|
|
||||||
# 执行器签名:(skill_or_team_name, input_text) -> 结果 dict
|
# 执行器签名:(skill_or_team_name, input_text) -> 结果 dict
|
||||||
SkillExecutor = Callable[[str, str], Awaitable[dict[str, Any]]]
|
SkillExecutor = Callable[[str, str], Awaitable[dict[str, object]]]
|
||||||
|
|
||||||
|
|
||||||
class PublisherRegistry:
|
class PublisherRegistry:
|
||||||
|
|
@ -90,7 +90,7 @@ class SkillMCPAdapter(Tool):
|
||||||
self._skill = skill
|
self._skill = skill
|
||||||
self._executor = executor
|
self._executor = executor
|
||||||
|
|
||||||
async def execute(self, **kwargs: Any) -> dict[str, Any]:
|
async def execute(self, **kwargs: object) -> dict[str, object]:
|
||||||
"""调用 executor 执行技能;未配置或异常时返回错误 dict。"""
|
"""调用 executor 执行技能;未配置或异常时返回错误 dict。"""
|
||||||
input_text = kwargs.get("input", "")
|
input_text = kwargs.get("input", "")
|
||||||
if self._executor is None:
|
if self._executor is None:
|
||||||
|
|
@ -126,7 +126,7 @@ class TeamMCPAdapter(Tool):
|
||||||
self._team_name = team_name
|
self._team_name = team_name
|
||||||
self._executor = executor
|
self._executor = executor
|
||||||
|
|
||||||
async def execute(self, **kwargs: Any) -> dict[str, Any]:
|
async def execute(self, **kwargs: object) -> dict[str, object]:
|
||||||
"""调用 executor 执行团队任务;未配置或异常时返回错误 dict。"""
|
"""调用 executor 执行团队任务;未配置或异常时返回错误 dict。"""
|
||||||
input_text = kwargs.get("input", "")
|
input_text = kwargs.get("input", "")
|
||||||
if self._executor is None:
|
if self._executor is None:
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ U13: 重构为路由工厂 ``create_mcp_router()``,可挂载到主 FastAPI app
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from collections.abc import Callable
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
|
|
||||||
|
|
@ -41,7 +41,7 @@ _MCP_BLOCKED_TOOLS: frozenset[str] = frozenset(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _serialize_tool(tool: Tool) -> dict[str, Any]:
|
def _serialize_tool(tool: Tool) -> dict[str, object]:
|
||||||
"""将 Tool 序列化为 MCP 协议响应字典。"""
|
"""将 Tool 序列化为 MCP 协议响应字典。"""
|
||||||
return {
|
return {
|
||||||
"name": tool.name,
|
"name": tool.name,
|
||||||
|
|
@ -51,8 +51,8 @@ def _serialize_tool(tool: Tool) -> dict[str, Any]:
|
||||||
|
|
||||||
|
|
||||||
def create_mcp_router(
|
def create_mcp_router(
|
||||||
tool_registry: Any = None,
|
tool_registry: object | None = None,
|
||||||
published_tools_getter: Any = None,
|
published_tools_getter: Callable[[], list[Tool]] | None = None,
|
||||||
) -> APIRouter:
|
) -> APIRouter:
|
||||||
"""构造 MCP 路由,挂载到主 app 的 ``/api/v1/mcp/`` 前缀下。
|
"""构造 MCP 路由,挂载到主 app 的 ``/api/v1/mcp/`` 前缀下。
|
||||||
|
|
||||||
|
|
@ -105,7 +105,7 @@ def create_mcp_router(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@router.get("/tools/list")
|
@router.get("/tools/list")
|
||||||
async def list_tools(_user: dict = Depends(_mcp_member_auth)) -> dict[str, Any]:
|
async def list_tools(_user: dict = Depends(_mcp_member_auth)) -> dict[str, object]:
|
||||||
"""列出所有可用的 MCP 工具。"""
|
"""列出所有可用的 MCP 工具。"""
|
||||||
tools = _all_tools()
|
tools = _all_tools()
|
||||||
return {"tools": [_serialize_tool(t) for t in tools]}
|
return {"tools": [_serialize_tool(t) for t in tools]}
|
||||||
|
|
@ -114,7 +114,7 @@ def create_mcp_router(
|
||||||
async def call_tool(
|
async def call_tool(
|
||||||
request: dict,
|
request: dict,
|
||||||
_user: dict = Depends(_mcp_member_auth),
|
_user: dict = Depends(_mcp_member_auth),
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, object]:
|
||||||
"""调用指定的 MCP 工具。"""
|
"""调用指定的 MCP 工具。"""
|
||||||
tool_name = request.get("name")
|
tool_name = request.get("name")
|
||||||
arguments = request.get("arguments", {})
|
arguments = request.get("arguments", {})
|
||||||
|
|
@ -145,7 +145,7 @@ def create_mcp_router(
|
||||||
async def jsonrpc_endpoint(
|
async def jsonrpc_endpoint(
|
||||||
request: Request,
|
request: Request,
|
||||||
_user: dict = Depends(_mcp_member_auth),
|
_user: dict = Depends(_mcp_member_auth),
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, object]:
|
||||||
"""JSON-RPC 2.0 端点 — MCP 协议兼容。
|
"""JSON-RPC 2.0 端点 — MCP 协议兼容。
|
||||||
|
|
||||||
支持 methods: initialize, tools/list, tools/call。
|
支持 methods: initialize, tools/list, tools/call。
|
||||||
|
|
@ -225,7 +225,7 @@ class MCPServer:
|
||||||
的 ``/api/v1/mcp/`` 前缀下,复用主 app 的 JWT + API Key 认证。
|
的 ``/api/v1/mcp/`` 前缀下,复用主 app 的 JWT + API Key 认证。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, tool_registry: Any = None, host: str = "0.0.0.0", port: int = 8080):
|
def __init__(self, tool_registry: object | None = None, host: str = "0.0.0.0", port: int = 8080):
|
||||||
self._tool_registry = tool_registry
|
self._tool_registry = tool_registry
|
||||||
self._host = host
|
self._host = host
|
||||||
self._port = port
|
self._port = port
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,6 @@ import logging
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
|
@ -52,7 +51,7 @@ class Transport(ABC):
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def send_request(self, method: str, params: dict[str, Any] | None = None) -> Any:
|
async def send_request(self, method: str, params: dict[str, object] | None = None) -> object:
|
||||||
"""发送 JSON-RPC 请求
|
"""发送 JSON-RPC 请求
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -65,7 +64,7 @@ class Transport(ABC):
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def receive_response(self) -> dict[str, Any]:
|
async def receive_response(self) -> dict[str, object]:
|
||||||
"""接收响应
|
"""接收响应
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
@ -92,7 +91,7 @@ class HTTPTransport(Transport):
|
||||||
self._timeout = timeout
|
self._timeout = timeout
|
||||||
self._client: httpx.AsyncClient | None = None
|
self._client: httpx.AsyncClient | None = None
|
||||||
self._request_id = 0
|
self._request_id = 0
|
||||||
self._pending: dict[int, asyncio.Future[dict[str, Any]]] = {}
|
self._pending: dict[int, asyncio.Future[dict[str, object]]] = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
|
|
@ -131,7 +130,7 @@ class HTTPTransport(Transport):
|
||||||
self._request_id += 1
|
self._request_id += 1
|
||||||
return self._request_id
|
return self._request_id
|
||||||
|
|
||||||
async def send_request(self, method: str, params: dict[str, Any] | None = None) -> Any:
|
async def send_request(self, method: str, params: dict[str, object] | None = None) -> object:
|
||||||
"""发送 JSON-RPC 请求并等待响应
|
"""发送 JSON-RPC 请求并等待响应
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -148,7 +147,7 @@ class HTTPTransport(Transport):
|
||||||
raise TransportError("Transport not connected")
|
raise TransportError("Transport not connected")
|
||||||
|
|
||||||
request_id = self._next_request_id()
|
request_id = self._next_request_id()
|
||||||
message: dict[str, Any] = {
|
message: dict[str, object] = {
|
||||||
"jsonrpc": "2.0",
|
"jsonrpc": "2.0",
|
||||||
"id": request_id,
|
"id": request_id,
|
||||||
"method": method,
|
"method": method,
|
||||||
|
|
@ -179,7 +178,7 @@ class HTTPTransport(Transport):
|
||||||
|
|
||||||
return data.get("result")
|
return data.get("result")
|
||||||
|
|
||||||
async def receive_response(self) -> dict[str, Any]:
|
async def receive_response(self) -> dict[str, object]:
|
||||||
"""接收响应
|
"""接收响应
|
||||||
|
|
||||||
对于 HTTPTransport,响应在 send_request 中同步返回。
|
对于 HTTPTransport,响应在 send_request 中同步返回。
|
||||||
|
|
@ -218,7 +217,7 @@ class SSETransport(Transport):
|
||||||
self._client: httpx.AsyncClient | None = None
|
self._client: httpx.AsyncClient | None = None
|
||||||
self._request_id = 0
|
self._request_id = 0
|
||||||
self._sse_task: asyncio.Task[None] | None = None
|
self._sse_task: asyncio.Task[None] | None = None
|
||||||
self._response_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
|
self._response_queue: asyncio.Queue[dict[str, object]] = asyncio.Queue()
|
||||||
self._connected = False
|
self._connected = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -304,7 +303,7 @@ class SSETransport(Transport):
|
||||||
self._request_id += 1
|
self._request_id += 1
|
||||||
return self._request_id
|
return self._request_id
|
||||||
|
|
||||||
async def send_request(self, method: str, params: dict[str, Any] | None = None) -> Any:
|
async def send_request(self, method: str, params: dict[str, object] | None = None) -> object:
|
||||||
"""通过 HTTP POST 发送 JSON-RPC 请求
|
"""通过 HTTP POST 发送 JSON-RPC 请求
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -321,7 +320,7 @@ class SSETransport(Transport):
|
||||||
raise TransportError("Transport not connected")
|
raise TransportError("Transport not connected")
|
||||||
|
|
||||||
request_id = self._next_request_id()
|
request_id = self._next_request_id()
|
||||||
message: dict[str, Any] = {
|
message: dict[str, object] = {
|
||||||
"jsonrpc": "2.0",
|
"jsonrpc": "2.0",
|
||||||
"id": request_id,
|
"id": request_id,
|
||||||
"method": method,
|
"method": method,
|
||||||
|
|
@ -352,7 +351,7 @@ class SSETransport(Transport):
|
||||||
|
|
||||||
return data.get("result")
|
return data.get("result")
|
||||||
|
|
||||||
async def receive_response(self) -> dict[str, Any]:
|
async def receive_response(self) -> dict[str, object]:
|
||||||
"""从 SSE 事件流接收响应
|
"""从 SSE 事件流接收响应
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
@ -392,11 +391,11 @@ class StdioTransport(Transport):
|
||||||
self._timeout = timeout
|
self._timeout = timeout
|
||||||
self._process: asyncio.subprocess.Process | None = None
|
self._process: asyncio.subprocess.Process | None = None
|
||||||
self._request_id = 0
|
self._request_id = 0
|
||||||
self._pending: dict[int, asyncio.Future[Any]] = {}
|
self._pending: dict[int, asyncio.Future[object]] = {}
|
||||||
self._reader_task: asyncio.Task[None] | None = None
|
self._reader_task: asyncio.Task[None] | None = None
|
||||||
self._stderr_task: asyncio.Task[None] | None = None
|
self._stderr_task: asyncio.Task[None] | None = None
|
||||||
self._connected = False
|
self._connected = False
|
||||||
self._notifications: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
|
self._notifications: asyncio.Queue[dict[str, object]] = asyncio.Queue()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
|
|
@ -513,7 +512,7 @@ class StdioTransport(Transport):
|
||||||
|
|
||||||
logger.info("StdioTransport disconnected")
|
logger.info("StdioTransport disconnected")
|
||||||
|
|
||||||
async def send_request(self, method: str, params: dict[str, Any] | None = None) -> Any:
|
async def send_request(self, method: str, params: dict[str, object] | None = None) -> object:
|
||||||
"""发送 JSON-RPC 请求并等待响应
|
"""发送 JSON-RPC 请求并等待响应
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -531,11 +530,11 @@ class StdioTransport(Transport):
|
||||||
return await self._send_request_internal(method, params)
|
return await self._send_request_internal(method, params)
|
||||||
|
|
||||||
async def _send_request_internal(
|
async def _send_request_internal(
|
||||||
self, method: str, params: dict[str, Any] | None = None
|
self, method: str, params: dict[str, object] | None = None
|
||||||
) -> Any:
|
) -> object:
|
||||||
"""内部请求发送方法(connect 时也可调用)"""
|
"""内部请求发送方法(connect 时也可调用)"""
|
||||||
request_id = self._next_request_id()
|
request_id = self._next_request_id()
|
||||||
message: dict[str, Any] = {
|
message: dict[str, object] = {
|
||||||
"jsonrpc": "2.0",
|
"jsonrpc": "2.0",
|
||||||
"id": request_id,
|
"id": request_id,
|
||||||
"method": method,
|
"method": method,
|
||||||
|
|
@ -546,7 +545,7 @@ class StdioTransport(Transport):
|
||||||
await self._write_message(message)
|
await self._write_message(message)
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
future: asyncio.Future[Any] = loop.create_future()
|
future: asyncio.Future[object] = loop.create_future()
|
||||||
self._pending[request_id] = future
|
self._pending[request_id] = future
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -558,9 +557,9 @@ class StdioTransport(Transport):
|
||||||
self._pending.pop(request_id, None)
|
self._pending.pop(request_id, None)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _send_notification(self, method: str, params: dict[str, Any] | None = None) -> None:
|
async def _send_notification(self, method: str, params: dict[str, object] | None = None) -> None:
|
||||||
"""发送 JSON-RPC 通知(无 id,不期待响应)"""
|
"""发送 JSON-RPC 通知(无 id,不期待响应)"""
|
||||||
message: dict[str, Any] = {
|
message: dict[str, object] = {
|
||||||
"jsonrpc": "2.0",
|
"jsonrpc": "2.0",
|
||||||
"method": method,
|
"method": method,
|
||||||
}
|
}
|
||||||
|
|
@ -568,7 +567,7 @@ class StdioTransport(Transport):
|
||||||
message["params"] = params
|
message["params"] = params
|
||||||
await self._write_message(message)
|
await self._write_message(message)
|
||||||
|
|
||||||
async def _write_message(self, message: dict[str, Any]) -> None:
|
async def _write_message(self, message: dict[str, object]) -> None:
|
||||||
"""将 JSON-RPC 消息写入子进程 stdin"""
|
"""将 JSON-RPC 消息写入子进程 stdin"""
|
||||||
if self._process is None or self._process.stdin is None:
|
if self._process is None or self._process.stdin is None:
|
||||||
raise TransportError("Process stdin not available")
|
raise TransportError("Process stdin not available")
|
||||||
|
|
@ -576,7 +575,7 @@ class StdioTransport(Transport):
|
||||||
self._process.stdin.write(data)
|
self._process.stdin.write(data)
|
||||||
await self._process.stdin.drain()
|
await self._process.stdin.drain()
|
||||||
|
|
||||||
async def receive_response(self) -> dict[str, Any]:
|
async def receive_response(self) -> dict[str, object]:
|
||||||
"""接收通知消息
|
"""接收通知消息
|
||||||
|
|
||||||
对于 StdioTransport,请求响应通过 _pending Future 异步返回。
|
对于 StdioTransport,请求响应通过 _pending Future 异步返回。
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from agentkit.telemetry.tracer import get_tracer
|
from agentkit.telemetry.tracer import get_tracer
|
||||||
|
|
||||||
|
|
@ -23,7 +22,7 @@ class AlignmentConfig:
|
||||||
audit_sample_rate: float = 1.0 # 审计采样率 0.0-1.0,1.0=每次都审计
|
audit_sample_rate: float = 1.0 # 审计采样率 0.0-1.0,1.0=每次都审计
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict[str, Any]) -> "AlignmentConfig":
|
def from_dict(cls, data: dict[str, object]) -> "AlignmentConfig":
|
||||||
"""从字典创建,忽略未知键"""
|
"""从字典创建,忽略未知键"""
|
||||||
known_fields = {f.name for f in cls.__dataclass_fields__.values()}
|
known_fields = {f.name for f in cls.__dataclass_fields__.values()}
|
||||||
filtered = {k: v for k, v in data.items() if k in known_fields}
|
filtered = {k: v for k, v in data.items() if k in known_fields}
|
||||||
|
|
@ -56,7 +55,7 @@ class ConstraintInjector:
|
||||||
def __init__(self, config: AlignmentConfig):
|
def __init__(self, config: AlignmentConfig):
|
||||||
self._config = config
|
self._config = config
|
||||||
|
|
||||||
def inject(self, input_data: dict[str, Any]) -> dict[str, Any]:
|
def inject(self, input_data: dict[str, object]) -> dict[str, object]:
|
||||||
"""注入约束指令到 input_data
|
"""注入约束指令到 input_data
|
||||||
|
|
||||||
在 input_data 中添加 'alignment_constraints' 键,值为约束列表。
|
在 input_data 中添加 'alignment_constraints' 键,值为约束列表。
|
||||||
|
|
@ -76,13 +75,13 @@ class AlignmentGuard:
|
||||||
self._interaction_counts: dict[str, int] = {}
|
self._interaction_counts: dict[str, int] = {}
|
||||||
self._loop_depths: dict[str, int] = {}
|
self._loop_depths: dict[str, int] = {}
|
||||||
|
|
||||||
def inject_constraints(self, input_data: dict[str, Any]) -> dict[str, Any]:
|
def inject_constraints(self, input_data: dict[str, object]) -> dict[str, object]:
|
||||||
"""委托给 ConstraintInjector"""
|
"""委托给 ConstraintInjector"""
|
||||||
return self._injector.inject(input_data)
|
return self._injector.inject(input_data)
|
||||||
|
|
||||||
async def check_output(
|
async def check_output(
|
||||||
self,
|
self,
|
||||||
output: dict[str, Any],
|
output: dict[str, object],
|
||||||
constraints: list[str] | None = None,
|
constraints: list[str] | None = None,
|
||||||
) -> AlignmentCheckResult:
|
) -> AlignmentCheckResult:
|
||||||
"""检查输出是否符合约束
|
"""检查输出是否符合约束
|
||||||
|
|
@ -127,7 +126,7 @@ class AlignmentGuard:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _rule_check(
|
def _rule_check(
|
||||||
self, output: dict[str, Any], constraints: list[str]
|
self, output: dict[str, object], constraints: list[str]
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""基于规则的约束检查:方向性判断,区分'禁止X'和'提及X'
|
"""基于规则的约束检查:方向性判断,区分'禁止X'和'提及X'
|
||||||
|
|
||||||
|
|
@ -206,7 +205,7 @@ class AlignmentGuard:
|
||||||
start = idx + len(keyword)
|
start = idx + len(keyword)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_text(output: dict[str, Any]) -> str:
|
def _extract_text(output: dict[str, object]) -> str:
|
||||||
"""从 output dict 中提取所有文本内容"""
|
"""从 output dict 中提取所有文本内容"""
|
||||||
parts: list[str] = []
|
parts: list[str] = []
|
||||||
for value in output.values():
|
for value in output.values():
|
||||||
|
|
@ -217,7 +216,7 @@ class AlignmentGuard:
|
||||||
return " ".join(parts)
|
return " ".join(parts)
|
||||||
|
|
||||||
async def _llm_check(
|
async def _llm_check(
|
||||||
self, output: dict[str, Any], constraints: list[str]
|
self, output: dict[str, object], constraints: list[str]
|
||||||
) -> AlignmentCheckResult:
|
) -> AlignmentCheckResult:
|
||||||
"""LLM 语义检查"""
|
"""LLM 语义检查"""
|
||||||
content = self._extract_text(output)
|
content = self._extract_text(output)
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ Key schema (Redis):
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any, Protocol, runtime_checkable
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
# redis 可选依赖;未安装时回退为 Exception 以保留原 catch-all 语义(降级到 fallback)
|
# redis 可选依赖;未安装时回退为 Exception 以保留原 catch-all 语义(降级到 fallback)
|
||||||
try:
|
try:
|
||||||
|
|
@ -132,7 +132,7 @@ class RedisCascadeStateStore:
|
||||||
def __init__(self, redis_url: str = "redis://localhost:6379", session_ttl: int = 86400):
|
def __init__(self, redis_url: str = "redis://localhost:6379", session_ttl: int = 86400):
|
||||||
self._redis_url = redis_url
|
self._redis_url = redis_url
|
||||||
self._session_ttl = session_ttl
|
self._session_ttl = session_ttl
|
||||||
self._sync_redis: Any = None
|
self._sync_redis: object = None
|
||||||
self._fallback: InMemoryCascadeStateStore | None = None
|
self._fallback: InMemoryCascadeStateStore | None = None
|
||||||
self._degraded = False
|
self._degraded = False
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable
|
from typing import Callable
|
||||||
|
|
||||||
from agentkit.skills.base import Skill
|
from agentkit.skills.base import Skill
|
||||||
|
|
||||||
|
|
@ -36,9 +36,9 @@ class QualityGate:
|
||||||
|
|
||||||
async def validate(
|
async def validate(
|
||||||
self,
|
self,
|
||||||
output: dict[str, Any],
|
output: dict[str, object],
|
||||||
skill: Skill,
|
skill: Skill,
|
||||||
skill_context: dict[str, Any] | None = None,
|
skill_context: dict[str, object] | None = None,
|
||||||
) -> QualityResult:
|
) -> QualityResult:
|
||||||
"""对产出执行多维度质量检查
|
"""对产出执行多维度质量检查
|
||||||
|
|
||||||
|
|
@ -150,8 +150,8 @@ class QualityGate:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _check_skill_match(
|
def _check_skill_match(
|
||||||
output: dict[str, Any],
|
output: dict[str, object],
|
||||||
skill_context: dict[str, Any] | None,
|
skill_context: dict[str, object] | None,
|
||||||
) -> QualityCheck | None:
|
) -> QualityCheck | None:
|
||||||
"""第五维度:技能匹配验证
|
"""第五维度:技能匹配验证
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ Schema 验证、字段类型归一化、元数据附加。
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from agentkit.quality.gate import QualityResult
|
from agentkit.quality.gate import QualityResult
|
||||||
from agentkit.skills.base import Skill
|
from agentkit.skills.base import Skill
|
||||||
|
|
@ -28,7 +27,7 @@ class StandardOutput:
|
||||||
"""标准化输出"""
|
"""标准化输出"""
|
||||||
|
|
||||||
skill_name: str
|
skill_name: str
|
||||||
data: dict[str, Any]
|
data: dict[str, object]
|
||||||
metadata: OutputMetadata
|
metadata: OutputMetadata
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -37,7 +36,7 @@ class OutputStandardizer:
|
||||||
|
|
||||||
async def standardize(
|
async def standardize(
|
||||||
self,
|
self,
|
||||||
raw_output: dict[str, Any],
|
raw_output: dict[str, object],
|
||||||
skill: Skill,
|
skill: Skill,
|
||||||
quality_result: QualityResult | None = None,
|
quality_result: QualityResult | None = None,
|
||||||
) -> StandardOutput:
|
) -> StandardOutput:
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ PGVectorStore(向量化存储)。
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from agentkit.memory.document_loader import DocumentLoader
|
from agentkit.memory.document_loader import DocumentLoader
|
||||||
from agentkit.rag_platform.models import DocumentStatus
|
from agentkit.rag_platform.models import DocumentStatus
|
||||||
|
|
@ -106,7 +106,7 @@ class DocumentProcessor:
|
||||||
file_type: str,
|
file_type: str,
|
||||||
chunk_size: int | None = None,
|
chunk_size: int | None = None,
|
||||||
chunk_overlap: int | None = None,
|
chunk_overlap: int | None = None,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, object]]:
|
||||||
"""解析 + 分段,返回 chunk 用于只读预览(不向量化)。
|
"""解析 + 分段,返回 chunk 用于只读预览(不向量化)。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -128,7 +128,7 @@ class DocumentProcessor:
|
||||||
|
|
||||||
async def vectorize(
|
async def vectorize(
|
||||||
self,
|
self,
|
||||||
chunks: list[str] | list[dict[str, Any]],
|
chunks: list[str] | list[dict[str, object]],
|
||||||
kb_id: str,
|
kb_id: str,
|
||||||
document_id: str,
|
document_id: str,
|
||||||
vector_store: "PGVectorStore",
|
vector_store: "PGVectorStore",
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,6 @@ from __future__ import annotations
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
|
@ -56,7 +55,7 @@ class HitProcessor:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
llm_gateway: Any = None,
|
llm_gateway: object | None = None,
|
||||||
cache_enabled: bool = True,
|
cache_enabled: bool = True,
|
||||||
model: str = _DEFAULT_LLM_MODEL,
|
model: str = _DEFAULT_LLM_MODEL,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ U3 将扩展为完整 IngestionPipeline(含解析、预览、净化)。
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from agentkit.rag_platform.indexing import KB_CHUNKS_TABLE
|
from agentkit.rag_platform.indexing import KB_CHUNKS_TABLE
|
||||||
from agentkit.rag_platform.models import QueryResult
|
from agentkit.rag_platform.models import QueryResult
|
||||||
|
|
@ -64,7 +64,7 @@ class RAGPipeline:
|
||||||
KB_CHUNKS_TABLE,
|
KB_CHUNKS_TABLE,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def ingest(self, text: str, metadata: dict[str, Any] | None = None) -> list["TextNode"]:
|
async def ingest(self, text: str, metadata: dict[str, object] | None = None) -> list["TextNode"]:
|
||||||
"""将文本摄入向量存储。
|
"""将文本摄入向量存储。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
||||||
|
|
@ -5,8 +5,6 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from agentkit.rag_platform.document_processor import (
|
from agentkit.rag_platform.document_processor import (
|
||||||
|
|
@ -23,7 +21,7 @@ class PreviewChunk(BaseModel):
|
||||||
|
|
||||||
index: int
|
index: int
|
||||||
content: str
|
content: str
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
metadata: dict[str, object] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class PreviewResult(BaseModel):
|
class PreviewResult(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,6 @@ from __future__ import annotations
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
|
@ -63,7 +62,7 @@ class QuestionGenerator:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
llm_gateway: Any = None,
|
llm_gateway: object | None = None,
|
||||||
max_questions_per_chunk: int = 3,
|
max_questions_per_chunk: int = 3,
|
||||||
model: str = "default",
|
model: str = "default",
|
||||||
cache: bool = True,
|
cache: bool = True,
|
||||||
|
|
@ -76,7 +75,7 @@ class QuestionGenerator:
|
||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
chunks: list[dict[str, Any]],
|
chunks: list[dict[str, object]],
|
||||||
document_context: str = "",
|
document_context: str = "",
|
||||||
) -> list[GeneratedQuestion]:
|
) -> list[GeneratedQuestion]:
|
||||||
"""为每个 chunk 生成相关问题。
|
"""为每个 chunk 生成相关问题。
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
|
@ -50,9 +50,9 @@ class Reranker:
|
||||||
|
|
||||||
def __init__(self, config: RerankConfig) -> None:
|
def __init__(self, config: RerankConfig) -> None:
|
||||||
self._config = config
|
self._config = config
|
||||||
self._reranker: Any = None # 延迟初始化,避免 import 失败
|
self._reranker: object | None = None # 延迟初始化,避免 import 失败
|
||||||
|
|
||||||
def _get_reranker(self) -> Any:
|
def _get_reranker(self) -> object | None:
|
||||||
"""延迟加载 reranker 实例 — 避免在 import 时失败。"""
|
"""延迟加载 reranker 实例 — 避免在 import 时失败。"""
|
||||||
if self._reranker is not None:
|
if self._reranker is not None:
|
||||||
return self._reranker
|
return self._reranker
|
||||||
|
|
@ -69,7 +69,7 @@ class Reranker:
|
||||||
|
|
||||||
return self._reranker
|
return self._reranker
|
||||||
|
|
||||||
def _build_cohere_reranker(self, cfg: RerankConfig) -> Any:
|
def _build_cohere_reranker(self, cfg: RerankConfig) -> object:
|
||||||
"""构建 CohereRerank — 数据出境,需 api_key。"""
|
"""构建 CohereRerank — 数据出境,需 api_key。"""
|
||||||
if not cfg.api_key:
|
if not cfg.api_key:
|
||||||
raise ValueError("Cohere rerank requires api_key")
|
raise ValueError("Cohere rerank requires api_key")
|
||||||
|
|
@ -81,7 +81,7 @@ class Reranker:
|
||||||
"Install: pip install llama-index-postprocessor-cohere-rerank"
|
"Install: pip install llama-index-postprocessor-cohere-rerank"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
kwargs: dict[str, Any] = {
|
kwargs: dict[str, object] = {
|
||||||
"api_key": cfg.api_key,
|
"api_key": cfg.api_key,
|
||||||
"top_n": cfg.top_n,
|
"top_n": cfg.top_n,
|
||||||
}
|
}
|
||||||
|
|
@ -89,7 +89,7 @@ class Reranker:
|
||||||
kwargs["model"] = cfg.model_name
|
kwargs["model"] = cfg.model_name
|
||||||
return CohereRerank(**kwargs)
|
return CohereRerank(**kwargs)
|
||||||
|
|
||||||
def _build_bge_reranker(self, cfg: RerankConfig) -> Any:
|
def _build_bge_reranker(self, cfg: RerankConfig) -> object:
|
||||||
"""构建 BGE-Reranker via Xinference(本地部署,无数据出境)。
|
"""构建 BGE-Reranker via Xinference(本地部署,无数据出境)。
|
||||||
|
|
||||||
使用 SentenceTransformerRerank 作为本地 BGE-Reranker 的封装。
|
使用 SentenceTransformerRerank 作为本地 BGE-Reranker 的封装。
|
||||||
|
|
@ -107,7 +107,7 @@ class Reranker:
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
model = cfg.model_name or "BAAI/bge-reranker-base"
|
model = cfg.model_name or "BAAI/bge-reranker-base"
|
||||||
kwargs: dict[str, Any] = {
|
kwargs: dict[str, object] = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"top_n": cfg.top_n,
|
"top_n": cfg.top_n,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
|
@ -184,7 +184,7 @@ class RetrievalEngine:
|
||||||
out: list[QueryResult] = []
|
out: list[QueryResult] = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
chunk_id, content, metadata, document_id, score = row
|
chunk_id, content, metadata, document_id, score = row
|
||||||
meta: dict[str, Any] = metadata if isinstance(metadata, dict) else {}
|
meta: dict[str, object] = metadata if isinstance(metadata, dict) else {}
|
||||||
kb_id = meta.get("kb_id", "")
|
kb_id = meta.get("kb_id", "")
|
||||||
out.append(
|
out.append(
|
||||||
QueryResult(
|
QueryResult(
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ import re
|
||||||
import struct
|
import struct
|
||||||
import zipfile
|
import zipfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from xml.etree.ElementTree import Element
|
||||||
|
|
||||||
# 文件类型白名单:扩展名 → MIME 类型
|
# 文件类型白名单:扩展名 → MIME 类型
|
||||||
ALLOWED_FILE_TYPES: dict[str, str] = {
|
ALLOWED_FILE_TYPES: dict[str, str] = {
|
||||||
|
|
@ -311,7 +311,7 @@ def sanitize_content(content: str, file_type: str) -> str:
|
||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
||||||
def parse_xml_safe(content: bytes) -> Any:
|
def parse_xml_safe(content: bytes) -> Element:
|
||||||
"""安全解析 XML — 禁止 DTD/实体以防止 XXE 攻击。
|
"""安全解析 XML — 禁止 DTD/实体以防止 XXE 攻击。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ import logging
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import TYPE_CHECKING, Any, Protocol
|
from typing import TYPE_CHECKING, Protocol
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
|
|
||||||
|
|
@ -54,7 +54,7 @@ TASKIQ_REDIS_DB = 1
|
||||||
TASKIQ_KEY_PREFIX = "taskiq:"
|
TASKIQ_KEY_PREFIX = "taskiq:"
|
||||||
|
|
||||||
|
|
||||||
async def _maybe_await(result: Any) -> Any:
|
async def _maybe_await(result: object) -> object:
|
||||||
"""统一处理 sync/async 调用结果 — InMemoryTaskStore 方法为 sync,RedisTaskStore 为 async。"""
|
"""统一处理 sync/async 调用结果 — InMemoryTaskStore 方法为 sync,RedisTaskStore 为 async。"""
|
||||||
if inspect.isawaitable(result):
|
if inspect.isawaitable(result):
|
||||||
return await result
|
return await result
|
||||||
|
|
@ -81,7 +81,7 @@ class VectorizeTaskParams(BaseModel):
|
||||||
|
|
||||||
@field_validator("chunk_overlap")
|
@field_validator("chunk_overlap")
|
||||||
@classmethod
|
@classmethod
|
||||||
def _overlap_less_than_size(cls, v: int, info: Any) -> int:
|
def _overlap_less_than_size(cls, v: int, info: object) -> int:
|
||||||
"""chunk_overlap 必须小于 chunk_size。"""
|
"""chunk_overlap 必须小于 chunk_size。"""
|
||||||
size = info.data.get("chunk_size", DEFAULT_CHUNK_SIZE)
|
size = info.data.get("chunk_size", DEFAULT_CHUNK_SIZE)
|
||||||
if v >= size:
|
if v >= size:
|
||||||
|
|
@ -161,7 +161,7 @@ class TaskStoreProtocol(Protocol):
|
||||||
def get(self, task_id: str) -> TaskRecord | None: ...
|
def get(self, task_id: str) -> TaskRecord | None: ...
|
||||||
|
|
||||||
async def update_status(
|
async def update_status(
|
||||||
self, task_id: str, status: TaskStatus, **kwargs: Any
|
self, task_id: str, status: TaskStatus, **kwargs: object
|
||||||
) -> TaskRecord: ...
|
) -> TaskRecord: ...
|
||||||
|
|
||||||
def list_tasks(
|
def list_tasks(
|
||||||
|
|
@ -265,7 +265,7 @@ class TaskManager:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
broker: Any = None,
|
broker: object | None = None,
|
||||||
task_store: TaskStoreProtocol | None = None,
|
task_store: TaskStoreProtocol | None = None,
|
||||||
max_concurrent_per_user: int = DEFAULT_MAX_CONCURRENT_PER_USER,
|
max_concurrent_per_user: int = DEFAULT_MAX_CONCURRENT_PER_USER,
|
||||||
task_ttl_seconds: int = DEFAULT_TASK_TTL_SECONDS,
|
task_ttl_seconds: int = DEFAULT_TASK_TTL_SECONDS,
|
||||||
|
|
@ -284,7 +284,7 @@ class TaskManager:
|
||||||
return self._store
|
return self._store
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def broker(self) -> Any:
|
def broker(self) -> object:
|
||||||
"""暴露底层 broker(供 startup/shutdown 集成)。"""
|
"""暴露底层 broker(供 startup/shutdown 集成)。"""
|
||||||
return self._broker
|
return self._broker
|
||||||
|
|
||||||
|
|
@ -293,7 +293,7 @@ class TaskManager:
|
||||||
async def submit_vectorize(
|
async def submit_vectorize(
|
||||||
self,
|
self,
|
||||||
params: VectorizeTaskParams,
|
params: VectorizeTaskParams,
|
||||||
dependencies: dict[str, Any] | None = None,
|
dependencies: dict[str, object] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""提交向量化任务,返回 task_id。
|
"""提交向量化任务,返回 task_id。
|
||||||
|
|
||||||
|
|
@ -344,7 +344,7 @@ class TaskManager:
|
||||||
async def submit_batch_index(
|
async def submit_batch_index(
|
||||||
self,
|
self,
|
||||||
params: BatchIndexTaskParams,
|
params: BatchIndexTaskParams,
|
||||||
dependencies: dict[str, Any] | None = None,
|
dependencies: dict[str, object] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""提交批量索引任务,返回 task_id。"""
|
"""提交批量索引任务,返回 task_id。"""
|
||||||
await self._check_concurrency(params.user_id)
|
await self._check_concurrency(params.user_id)
|
||||||
|
|
@ -466,7 +466,7 @@ class TaskManager:
|
||||||
self,
|
self,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
params: VectorizeTaskParams,
|
params: VectorizeTaskParams,
|
||||||
deps: dict[str, Any],
|
deps: dict[str, object],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""降级模式 — 同步执行向量化任务(在 asyncio 任务中)。"""
|
"""降级模式 — 同步执行向量化任务(在 asyncio 任务中)。"""
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
|
|
@ -504,7 +504,7 @@ class TaskManager:
|
||||||
self,
|
self,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
params: BatchIndexTaskParams,
|
params: BatchIndexTaskParams,
|
||||||
deps: dict[str, Any],
|
deps: dict[str, object],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""降级模式 — 同步执行批量索引任务。"""
|
"""降级模式 — 同步执行批量索引任务。"""
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
|
|
@ -655,7 +655,7 @@ async def run_batch_index_task(
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def create_broker(redis_url: str) -> Any:
|
def create_broker(redis_url: str) -> object:
|
||||||
"""创建 TaskIQ Redis broker — 独立 db=1 隔离。
|
"""创建 TaskIQ Redis broker — 独立 db=1 隔离。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ These dependencies read the ``current_user`` payload that
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable
|
from typing import Callable
|
||||||
|
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Request
|
||||||
|
|
@ -25,7 +25,7 @@ from agentkit.server.auth.permissions import Permission, has_permission, is_dev_
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(request: Request) -> dict[str, Any] | None:
|
async def get_current_user(request: Request) -> dict[str, object] | None:
|
||||||
"""Return the current user payload, or ``None`` in dev mode.
|
"""Return the current user payload, or ``None`` in dev mode.
|
||||||
|
|
||||||
The payload is set by :class:`AuthMiddleware` and contains
|
The payload is set by :class:`AuthMiddleware` and contains
|
||||||
|
|
@ -43,7 +43,7 @@ async def get_current_user(request: Request) -> dict[str, Any] | None:
|
||||||
return getattr(request.state, "current_user", None)
|
return getattr(request.state, "current_user", None)
|
||||||
|
|
||||||
|
|
||||||
async def require_authenticated(request: Request) -> dict[str, Any]:
|
async def require_authenticated(request: Request) -> dict[str, object]:
|
||||||
"""Require an authenticated user. Raises 401 if not authenticated.
|
"""Require an authenticated user. Raises 401 if not authenticated.
|
||||||
|
|
||||||
Use this as a FastAPI dependency for endpoints that need *any*
|
Use this as a FastAPI dependency for endpoints that need *any*
|
||||||
|
|
@ -67,7 +67,7 @@ async def require_authenticated(request: Request) -> dict[str, Any]:
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
def require_permission(*permissions: Permission) -> Callable[..., Any]:
|
def require_permission(*permissions: Permission) -> Callable[..., object]:
|
||||||
"""Build a FastAPI dependency that requires the given permissions.
|
"""Build a FastAPI dependency that requires the given permissions.
|
||||||
|
|
||||||
The user must have *all* of the given permissions (AND semantics).
|
The user must have *all* of the given permissions (AND semantics).
|
||||||
|
|
@ -96,7 +96,7 @@ def require_permission(*permissions: Permission) -> Callable[..., Any]:
|
||||||
A FastAPI dependency function.
|
A FastAPI dependency function.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def _dependency(request: Request) -> dict[str, Any]:
|
async def _dependency(request: Request) -> dict[str, object]:
|
||||||
user = await get_current_user(request)
|
user = await get_current_user(request)
|
||||||
|
|
||||||
# Dev mode: no authenticated user
|
# Dev mode: no authenticated user
|
||||||
|
|
@ -132,13 +132,13 @@ def require_permission(*permissions: Permission) -> Callable[..., Any]:
|
||||||
return _dependency
|
return _dependency
|
||||||
|
|
||||||
|
|
||||||
def require_any_permission(*permissions: Permission) -> Callable[..., Any]:
|
def require_any_permission(*permissions: Permission) -> Callable[..., object]:
|
||||||
"""Build a FastAPI dependency that requires at least one of the permissions.
|
"""Build a FastAPI dependency that requires at least one of the permissions.
|
||||||
|
|
||||||
Similar to :func:`require_permission` but uses OR semantics.
|
Similar to :func:`require_permission` but uses OR semantics.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def _dependency(request: Request) -> dict[str, Any]:
|
async def _dependency(request: Request) -> dict[str, object]:
|
||||||
user = await get_current_user(request)
|
user = await get_current_user(request)
|
||||||
|
|
||||||
if is_dev_mode(user):
|
if is_dev_mode(user):
|
||||||
|
|
@ -181,7 +181,7 @@ async def _resolve_db_path(request: Request):
|
||||||
return DEFAULT_AUTH_DB_PATH
|
return DEFAULT_AUTH_DB_PATH
|
||||||
|
|
||||||
|
|
||||||
async def require_terminal_authorized(request: Request) -> dict[str, Any]:
|
async def require_terminal_authorized(request: Request) -> dict[str, object]:
|
||||||
"""Require a user authorized to use the local terminal.
|
"""Require a user authorized to use the local terminal.
|
||||||
|
|
||||||
This checks both:
|
This checks both:
|
||||||
|
|
@ -224,7 +224,7 @@ async def require_terminal_authorized(request: Request) -> dict[str, Any]:
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
async def require_server_terminal_authorized(request: Request) -> dict[str, Any]:
|
async def require_server_terminal_authorized(request: Request) -> dict[str, object]:
|
||||||
"""Require a user authorized to use the server terminal.
|
"""Require a user authorized to use the server terminal.
|
||||||
|
|
||||||
Checks:
|
Checks:
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,6 @@ import secrets
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
|
|
||||||
|
|
@ -165,7 +164,7 @@ def create_token_pair(
|
||||||
# the JWT.
|
# the JWT.
|
||||||
jti = str(uuid.uuid4()) if effective_session_id else None
|
jti = str(uuid.uuid4()) if effective_session_id else None
|
||||||
|
|
||||||
access_payload: dict[str, Any] = {
|
access_payload: dict[str, object] = {
|
||||||
"sub": user_id,
|
"sub": user_id,
|
||||||
"username": username,
|
"username": username,
|
||||||
"role": role,
|
"role": role,
|
||||||
|
|
@ -173,7 +172,7 @@ def create_token_pair(
|
||||||
"iat": int(issued_at.timestamp()),
|
"iat": int(issued_at.timestamp()),
|
||||||
"exp": int(access_exp.timestamp()),
|
"exp": int(access_exp.timestamp()),
|
||||||
}
|
}
|
||||||
refresh_payload: dict[str, Any] = {
|
refresh_payload: dict[str, object] = {
|
||||||
"sub": user_id,
|
"sub": user_id,
|
||||||
"username": username,
|
"username": username,
|
||||||
"role": role,
|
"role": role,
|
||||||
|
|
@ -238,7 +237,7 @@ def create_access_token(
|
||||||
access_exp = issued_at + ACCESS_TOKEN_TTL
|
access_exp = issued_at + ACCESS_TOKEN_TTL
|
||||||
jti = str(uuid.uuid4()) if session_id else None
|
jti = str(uuid.uuid4()) if session_id else None
|
||||||
|
|
||||||
access_payload: dict[str, Any] = {
|
access_payload: dict[str, object] = {
|
||||||
"sub": user_id,
|
"sub": user_id,
|
||||||
"username": username,
|
"username": username,
|
||||||
"role": role,
|
"role": role,
|
||||||
|
|
@ -261,7 +260,7 @@ def verify_token(
|
||||||
secret: str,
|
secret: str,
|
||||||
*,
|
*,
|
||||||
expected_type: str | None = None,
|
expected_type: str | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, object]:
|
||||||
"""Verify a JWT and return its payload.
|
"""Verify a JWT and return its payload.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,6 @@ from __future__ import annotations
|
||||||
|
|
||||||
import hmac
|
import hmac
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
|
@ -93,7 +92,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||||
"""Dev mode = no JWT secret, no global API key, no client keys."""
|
"""Dev mode = no JWT secret, no global API key, no client keys."""
|
||||||
return not self._jwt_secret and self._api_key is None and not self._client_keys
|
return not self._jwt_secret and self._api_key is None and not self._client_keys
|
||||||
|
|
||||||
def _verify_jwt(self, token: str) -> dict[str, Any] | None:
|
def _verify_jwt(self, token: str) -> dict[str, object] | None:
|
||||||
"""Verify a JWT bearer token. Returns payload or None.
|
"""Verify a JWT bearer token. Returns payload or None.
|
||||||
|
|
||||||
V2 tokens carry a ``sid`` claim. The middleware does NOT
|
V2 tokens carry a ``sid`` claim. The middleware does NOT
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,6 @@ import os
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
from sqlalchemy import String
|
from sqlalchemy import String
|
||||||
|
|
@ -774,7 +773,7 @@ async def init_auth_db(db_path: str | Path | None = None) -> Path:
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def user_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[str, Any]:
|
def user_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[str, object]:
|
||||||
"""Convert a ``users`` row into a JSON-safe dict."""
|
"""Convert a ``users`` row into a JSON-safe dict."""
|
||||||
return {
|
return {
|
||||||
"id": row["id"],
|
"id": row["id"],
|
||||||
|
|
@ -791,7 +790,7 @@ def user_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[str, Any
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def auth_session_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[str, Any]:
|
def auth_session_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[str, object]:
|
||||||
"""Convert an ``auth_sessions`` row into a JSON-safe dict.
|
"""Convert an ``auth_sessions`` row into a JSON-safe dict.
|
||||||
|
|
||||||
The ``revoked`` field is normalized to a Python ``bool`` (the DB stores
|
The ``revoked`` field is normalized to a Python ``bool`` (the DB stores
|
||||||
|
|
@ -816,7 +815,7 @@ def auth_session_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def department_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[str, Any]:
|
def department_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[str, object]:
|
||||||
"""Convert a ``departments`` row into a JSON-safe dict.
|
"""Convert a ``departments`` row into a JSON-safe dict.
|
||||||
|
|
||||||
The ``is_active`` field is normalized to a Python ``bool`` (the DB
|
The ``is_active`` field is normalized to a Python ``bool`` (the DB
|
||||||
|
|
@ -831,7 +830,7 @@ def department_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[st
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def user_department_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[str, Any]:
|
def user_department_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[str, object]:
|
||||||
"""Convert a ``user_departments`` row into a JSON-safe dict."""
|
"""Convert a ``user_departments`` row into a JSON-safe dict."""
|
||||||
return {
|
return {
|
||||||
"user_id": row["user_id"],
|
"user_id": row["user_id"],
|
||||||
|
|
@ -840,7 +839,7 @@ def user_department_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> di
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def skill_state_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[str, Any]:
|
def skill_state_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[str, object]:
|
||||||
"""Convert a ``skill_states`` row into a JSON-safe dict.
|
"""Convert a ``skill_states`` row into a JSON-safe dict.
|
||||||
|
|
||||||
The ``is_disabled`` field is normalized to a Python ``bool`` (the DB
|
The ``is_disabled`` field is normalized to a Python ``bool`` (the DB
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,6 @@ terminal access on a per-user basis without changing the user's role.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
class Permission(str, Enum):
|
class Permission(str, Enum):
|
||||||
|
|
@ -108,7 +107,7 @@ def get_role_permissions(role: str | None) -> frozenset[Permission]:
|
||||||
return ROLE_PERMISSIONS.get(role, frozenset())
|
return ROLE_PERMISSIONS.get(role, frozenset())
|
||||||
|
|
||||||
|
|
||||||
def has_permission(user: dict[str, Any] | None, permission: Permission) -> bool:
|
def has_permission(user: dict[str, object] | None, permission: Permission) -> bool:
|
||||||
"""Check if a user payload has a specific permission.
|
"""Check if a user payload has a specific permission.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -132,7 +131,7 @@ def has_permission(user: dict[str, Any] | None, permission: Permission) -> bool:
|
||||||
return permission in get_role_permissions(role)
|
return permission in get_role_permissions(role)
|
||||||
|
|
||||||
|
|
||||||
def is_dev_mode(user: dict[str, Any] | None) -> bool:
|
def is_dev_mode(user: dict[str, object] | None) -> bool:
|
||||||
"""Return True if the request is in dev mode (no authenticated user).
|
"""Return True if the request is in dev mode (no authenticated user).
|
||||||
|
|
||||||
Dev mode is when ``AuthMiddleware`` passes through requests without
|
Dev mode is when ``AuthMiddleware`` passes through requests without
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,6 @@ import logging
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
|
|
||||||
|
|
@ -149,7 +148,7 @@ class LocalAuthProvider:
|
||||||
is_terminal_authorized: bool = False,
|
is_terminal_authorized: bool = False,
|
||||||
is_server_terminal_authorized: bool = False,
|
is_server_terminal_authorized: bool = False,
|
||||||
created_by: str | None = None,
|
created_by: str | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, object]:
|
||||||
"""Create a new user in the local ``users`` table.
|
"""Create a new user in the local ``users`` table.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,6 @@ import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
|
|
||||||
|
|
@ -179,7 +178,7 @@ class SessionService:
|
||||||
admin "list all sessions" view passes ``include_revoked=True``.
|
admin "list all sessions" view passes ``include_revoked=True``.
|
||||||
"""
|
"""
|
||||||
sql = "SELECT * FROM auth_sessions WHERE user_id = ?"
|
sql = "SELECT * FROM auth_sessions WHERE user_id = ?"
|
||||||
args: tuple[Any, ...] = (user_id,)
|
args: tuple[object, ...] = (user_id,)
|
||||||
if not include_revoked:
|
if not include_revoked:
|
||||||
sql += " AND revoked = 0"
|
sql += " AND revoked = 0"
|
||||||
sql += " ORDER BY created_at DESC"
|
sql += " ORDER BY created_at DESC"
|
||||||
|
|
@ -437,7 +436,7 @@ class SessionService:
|
||||||
"SET revoked = 1, revoked_reason = ? "
|
"SET revoked = 1, revoked_reason = ? "
|
||||||
"WHERE user_id = ? AND revoked = 0"
|
"WHERE user_id = ? AND revoked = 0"
|
||||||
)
|
)
|
||||||
args: list[Any] = [reason, user_id]
|
args: list[object] = [reason, user_id]
|
||||||
if except_sid is not None:
|
if except_sid is not None:
|
||||||
sql += " AND id != ?"
|
sql += " AND id != ?"
|
||||||
args.append(except_sid)
|
args.append(except_sid)
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ import asyncio
|
||||||
import datetime
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from agentkit.session.models import Message, MessageRole, Session, SessionStatus
|
from agentkit.session.models import Message, MessageRole, Session, SessionStatus
|
||||||
from agentkit.session.store import InMemorySessionStore, SessionStore
|
from agentkit.session.store import InMemorySessionStore, SessionStore
|
||||||
|
|
@ -155,7 +154,7 @@ class SessionManager:
|
||||||
async def create_session(
|
async def create_session(
|
||||||
self,
|
self,
|
||||||
agent_name: str,
|
agent_name: str,
|
||||||
metadata: dict[str, Any] | None = None,
|
metadata: dict[str, object] | None = None,
|
||||||
) -> Session:
|
) -> Session:
|
||||||
"""Create a new conversation session bound to an Agent.
|
"""Create a new conversation session bound to an Agent.
|
||||||
|
|
||||||
|
|
@ -216,7 +215,7 @@ class SessionManager:
|
||||||
content: str,
|
content: str,
|
||||||
tool_call_id: str | None = None,
|
tool_call_id: str | None = None,
|
||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
metadata: dict[str, Any] | None = None,
|
metadata: dict[str, object] | None = None,
|
||||||
) -> Message:
|
) -> Message:
|
||||||
"""Append a message to a session.
|
"""Append a message to a session.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
class SessionStatus(str, Enum):
|
class SessionStatus(str, Enum):
|
||||||
|
|
@ -40,9 +39,9 @@ class Message:
|
||||||
tool_call_id: str | None = None
|
tool_call_id: str | None = None
|
||||||
agent_name: str | None = None
|
agent_name: str | None = None
|
||||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: dict[str, object] = field(default_factory=dict)
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, object]:
|
||||||
return {
|
return {
|
||||||
"message_id": self.message_id,
|
"message_id": self.message_id,
|
||||||
"session_id": self.session_id,
|
"session_id": self.session_id,
|
||||||
|
|
@ -55,7 +54,7 @@ class Message:
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict[str, Any]) -> Message:
|
def from_dict(cls, data: dict[str, object]) -> Message:
|
||||||
return cls(
|
return cls(
|
||||||
message_id=data["message_id"],
|
message_id=data["message_id"],
|
||||||
session_id=data["session_id"],
|
session_id=data["session_id"],
|
||||||
|
|
@ -89,11 +88,11 @@ class Session:
|
||||||
session_id: str
|
session_id: str
|
||||||
agent_name: str
|
agent_name: str
|
||||||
status: SessionStatus = SessionStatus.ACTIVE
|
status: SessionStatus = SessionStatus.ACTIVE
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: dict[str, object] = field(default_factory=dict)
|
||||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, object]:
|
||||||
return {
|
return {
|
||||||
"session_id": self.session_id,
|
"session_id": self.session_id,
|
||||||
"agent_name": self.agent_name,
|
"agent_name": self.agent_name,
|
||||||
|
|
@ -104,7 +103,7 @@ class Session:
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict[str, Any]) -> Session:
|
def from_dict(cls, data: dict[str, object]) -> Session:
|
||||||
return cls(
|
return cls(
|
||||||
session_id=data["session_id"],
|
session_id=data["session_id"],
|
||||||
agent_name=data["agent_name"],
|
agent_name=data["agent_name"],
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, Protocol, runtime_checkable
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
# redis 可选依赖;未安装时回退为 Exception 以保留原 catch-all 语义
|
# redis 可选依赖;未安装时回退为 Exception 以保留原 catch-all 语义
|
||||||
try:
|
try:
|
||||||
|
|
@ -119,7 +119,7 @@ class RedisSessionStore:
|
||||||
def __init__(self, redis_url: str = "redis://localhost:6379/0", ttl_seconds: int = 86400):
|
def __init__(self, redis_url: str = "redis://localhost:6379/0", ttl_seconds: int = 86400):
|
||||||
self._redis_url = redis_url
|
self._redis_url = redis_url
|
||||||
self._ttl_seconds = ttl_seconds
|
self._ttl_seconds = ttl_seconds
|
||||||
self._redis: Any = None
|
self._redis: object = None
|
||||||
|
|
||||||
async def _get_redis(self):
|
async def _get_redis(self):
|
||||||
if self._redis is None:
|
if self._redis is None:
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,8 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from typing import Any, ContextManager
|
from typing import ContextManager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -27,10 +27,10 @@ class NoOpSpan:
|
||||||
def __exit__(self, *args):
|
def __exit__(self, *args):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def set_attribute(self, key: str, value: Any) -> None:
|
def set_attribute(self, key: str, value: object) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def add_event(self, name: str, attributes: dict[str, Any] | None = None) -> None:
|
def add_event(self, name: str, attributes: dict[str, object] | None = None) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def record_exception(self, exception: Exception) -> None:
|
def record_exception(self, exception: Exception) -> None:
|
||||||
|
|
@ -43,17 +43,17 @@ class NoOpSpan:
|
||||||
class NoOpTracer:
|
class NoOpTracer:
|
||||||
"""No-op tracer when telemetry is disabled."""
|
"""No-op tracer when telemetry is disabled."""
|
||||||
|
|
||||||
def start_span(self, name: str, attributes: dict[str, Any] | None = None) -> ContextManager[NoOpSpan]:
|
def start_span(self, name: str, attributes: dict[str, object] | None = None) -> ContextManager[NoOpSpan]:
|
||||||
return NoOpSpan()
|
return NoOpSpan()
|
||||||
|
|
||||||
def start_as_current_span(self, name: str, attributes: dict[str, Any] | None = None) -> ContextManager[NoOpSpan]:
|
def start_as_current_span(self, name: str, attributes: dict[str, object] | None = None) -> ContextManager[NoOpSpan]:
|
||||||
return NoOpSpan()
|
return NoOpSpan()
|
||||||
|
|
||||||
|
|
||||||
class OTelSpan:
|
class OTelSpan:
|
||||||
"""Wrapper around OpenTelemetry Span."""
|
"""Wrapper around OpenTelemetry Span."""
|
||||||
|
|
||||||
def __init__(self, span: Any):
|
def __init__(self, span: object):
|
||||||
self._span = span
|
self._span = span
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
|
|
@ -63,10 +63,10 @@ class OTelSpan:
|
||||||
def __exit__(self, *args):
|
def __exit__(self, *args):
|
||||||
self._span.__exit__(*args)
|
self._span.__exit__(*args)
|
||||||
|
|
||||||
def set_attribute(self, key: str, value: Any) -> None:
|
def set_attribute(self, key: str, value: object) -> None:
|
||||||
self._span.set_attribute(key, value)
|
self._span.set_attribute(key, value)
|
||||||
|
|
||||||
def add_event(self, name: str, attributes: dict[str, Any] | None = None) -> None:
|
def add_event(self, name: str, attributes: dict[str, object] | None = None) -> None:
|
||||||
self._span.add_event(name, attributes or {})
|
self._span.add_event(name, attributes or {})
|
||||||
|
|
||||||
def record_exception(self, exception: Exception) -> None:
|
def record_exception(self, exception: Exception) -> None:
|
||||||
|
|
@ -79,14 +79,14 @@ class OTelSpan:
|
||||||
class OTelTracer:
|
class OTelTracer:
|
||||||
"""Wrapper around OpenTelemetry Tracer."""
|
"""Wrapper around OpenTelemetry Tracer."""
|
||||||
|
|
||||||
def __init__(self, tracer: Any):
|
def __init__(self, tracer: object):
|
||||||
self._tracer = tracer
|
self._tracer = tracer
|
||||||
|
|
||||||
def start_span(self, name: str, attributes: dict[str, Any] | None = None) -> OTelSpan:
|
def start_span(self, name: str, attributes: dict[str, object] | None = None) -> OTelSpan:
|
||||||
span = self._tracer.start_span(name, attributes=attributes)
|
span = self._tracer.start_span(name, attributes=attributes)
|
||||||
return OTelSpan(span)
|
return OTelSpan(span)
|
||||||
|
|
||||||
def start_as_current_span(self, name: str, attributes: dict[str, Any] | None = None) -> OTelSpan:
|
def start_as_current_span(self, name: str, attributes: dict[str, object] | None = None) -> OTelSpan:
|
||||||
span = self._tracer.start_as_current_span(name, attributes=attributes)
|
span = self._tracer.start_as_current_span(name, attributes=attributes)
|
||||||
return OTelSpan(span)
|
return OTelSpan(span)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Callable
|
from typing import Callable
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -67,7 +67,7 @@ def get_tracer(name: str = "fischer.agentkit"):
|
||||||
|
|
||||||
def start_span(
|
def start_span(
|
||||||
name: str,
|
name: str,
|
||||||
kind: Any = None,
|
kind: object = None,
|
||||||
attributes: dict | None = None,
|
attributes: dict | None = None,
|
||||||
):
|
):
|
||||||
"""Start a span — returns no-op span if OTel not installed.
|
"""Start a span — returns no-op span if OTel not installed.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue