fischer-agentkit/src/agentkit/client/sync.py

339 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Client-side configuration sync engine.

Pulls skill and workflow configs from the server on startup and polls
for changes every 5 minutes (by default). Configs are cached in a local
SQLite DB so the client can operate offline.

Usage::

    from agentkit.client.sync import ConfigSync

    sync = ConfigSync(
        server_url="http://localhost:8001",
        token_provider=lambda: jwt_token,  # or None for dev mode
        cache_db_path="~/.agentkit/config_cache.db",
    )
    await sync.start()          # Initial full pull (falls back to the local cache)
    sync.start_polling()        # Background polling task (or: await sync.poll_loop())
    configs = sync.get_all()    # Returns cached configs
    await sync.stop()

Design (KTD3):

- Polling, not WebSocket push (config changes are weekly/monthly)
- Full pull on version change (not incremental diff)
- SQLite cache for offline operation
- 5-minute poll interval (configurable)
"""
from __future__ import annotations

import asyncio
import json
import logging
import os
import sqlite3
from contextlib import closing
from datetime import datetime, timezone
from pathlib import Path
from typing import Callable, TypeAlias

import httpx
logger = logging.getLogger(__name__)
# 缓存的配置项 — 技能/工作流配置为 JSON 反序列化后的字典,值为标量或嵌套结构。
# 服务器返回的 skills/workflows 列表元素是 dictmodel_dump/to_dict 输出),
# 其中可能包含 list/dict 等容器,因此使用 object 作为值类型。
SkillConfigDict: TypeAlias = dict[str, object]
WorkflowConfigDict: TypeAlias = dict[str, object]
SyncedConfigPayload: TypeAlias = dict[str, object]
# ── Defaults ──────────────────────────────────────────────────────────
DEFAULT_POLL_INTERVAL = 300 # 5 minutes
DEFAULT_TIMEOUT = 30.0
DEFAULT_CACHE_DB_PATH = Path(
os.environ.get("AGENTKIT_CONFIG_CACHE", str(Path.home() / ".agentkit" / "config_cache.db"))
)
# ── SQLite cache schema ───────────────────────────────────────────────
_CACHE_SCHEMA = """
CREATE TABLE IF NOT EXISTS config_cache (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
updated_at TEXT NOT NULL
);
"""
# ── ConfigSync ────────────────────────────────────────────────────────
class ConfigSync:
    """Client-side configuration sync engine.

    Pulls configs from the server, caches them locally in SQLite, and polls
    for version changes on a configurable interval.

    Attributes:
        server_url: Base URL of the AgentKit server (e.g. ``http://localhost:8001``),
            with any trailing ``/`` stripped.
        token_provider: Callable that returns the current JWT access token
            (or ``None`` if not authenticated). Called on each request.
        cache_db_path: Path to the local SQLite cache file; a leading ``~``
            is expanded to the user's home directory.
        poll_interval: Seconds between version polls (default 300 = 5 min).
        timeout: HTTP request timeout in seconds.
    """

    def __init__(
        self,
        *,
        server_url: str,
        token_provider: Callable[[], str | None] | None = None,
        cache_db_path: str | Path | None = None,
        poll_interval: int = DEFAULT_POLL_INTERVAL,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> None:
        self.server_url = server_url.rstrip("/")
        self.token_provider = token_provider
        # expanduser() so "~/.agentkit/config_cache.db" resolves to the home
        # directory instead of a literal "~" folder under the CWD.
        self.cache_db_path = (
            Path(cache_db_path) if cache_db_path else DEFAULT_CACHE_DB_PATH
        ).expanduser()
        self.poll_interval = poll_interval
        self.timeout = timeout
        self._client: httpx.AsyncClient | None = None
        self._poll_task: asyncio.Task[None] | None = None
        self._stopped = False
        # In-memory cache (mirrors the SQLite cache for fast access)
        self._version: str | None = None
        self._skills: list[SkillConfigDict] = []
        self._workflows: list[WorkflowConfigDict] = []
        self._last_synced_at: str | None = None

    # ── Lifecycle ─────────────────────────────────────────────────

    async def start(self) -> bool:
        """Perform an initial full sync.

        Tries to pull configs from the server. On failure, loads the
        local SQLite cache instead.

        Returns:
            ``True`` if the pull from the server succeeded, ``False`` if the
            client fell back to the local cache.
        """
        self._init_cache_db()
        # Allow start() after stop(): otherwise poll_loop() would exit at once.
        self._stopped = False
        # Reuse an existing client on repeated start() calls instead of
        # leaking the previous connection pool.
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=self.timeout)
        success = await self._pull_all()
        if not success:
            logger.info("Server unreachable, loading cached configs")
            self._load_from_cache()
        return success

    async def stop(self) -> None:
        """Stop the polling loop and close the HTTP client."""
        self._stopped = True
        task, self._poll_task = self._poll_task, None
        if task is not None and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        if self._client is not None:
            client, self._client = self._client, None
            await client.aclose()

    async def poll_loop(self) -> None:
        """Background polling loop.

        Every ``poll_interval`` seconds, checks the server's config
        version. If it differs from the cached version, re-pulls all
        configs. On errors, keeps the existing cache and retries on the
        next interval.

        Cancellation is propagated rather than swallowed, so callers that
        await this coroutine (and :meth:`stop`) observe it.
        """
        while not self._stopped:
            # CancelledError raised here (or below) is a BaseException and is
            # deliberately not caught by the ``except Exception`` handler.
            await asyncio.sleep(self.poll_interval)
            if self._stopped:
                break
            try:
                await self._check_and_sync()
            except Exception as e:
                # Best-effort: keep the current cache and retry next interval.
                logger.warning(f"Config poll error: {e}")

    def start_polling(self) -> asyncio.Task[None]:
        """Start the polling loop as a background task.

        If a polling task is already running it is returned as-is rather
        than starting a second, concurrent poller.
        """
        if self._poll_task is None or self._poll_task.done():
            self._poll_task = asyncio.create_task(self.poll_loop())
        return self._poll_task

    # ── Sync operations ───────────────────────────────────────────

    async def _check_and_sync(self) -> bool:
        """Check the server version; re-pull if changed.

        Returns ``True`` if configs were re-pulled, ``False`` if the
        version was unchanged, the server was unreachable, or the
        response was malformed.
        """
        if self._client is None:
            return False
        try:
            resp = await self._client.get(
                f"{self.server_url}/api/v1/config/version",
                headers=self._build_headers(),
            )
            if resp.status_code != 200:
                logger.warning(f"Version check failed: HTTP {resp.status_code}")
                return False
            body = resp.json()
        except (httpx.HTTPError, OSError) as e:
            logger.warning(f"Version check network error: {e}")
            return False
        except ValueError as e:
            # json.JSONDecodeError / UnicodeDecodeError are ValueError subclasses.
            logger.warning(f"Version check failed: invalid JSON ({e})")
            return False
        if not isinstance(body, dict):
            logger.warning("Version check failed: unexpected response body")
            return False
        server_version = body.get("version")
        if server_version == self._version:
            logger.debug("Config version unchanged, skipping sync")
            return False
        logger.info(f"Config version changed: {self._version} -> {server_version}")
        return await self._pull_all()

    async def _pull_all(self) -> bool:
        """Pull all configs from the server and update the cache.

        Returns ``True`` if the server returned a well-formed payload (even
        if writing the local SQLite cache then failed), ``False`` otherwise.
        """
        if self._client is None:
            return False
        try:
            resp = await self._client.get(
                f"{self.server_url}/api/v1/config/all",
                headers=self._build_headers(),
            )
            if resp.status_code != 200:
                logger.warning(f"Config pull failed: HTTP {resp.status_code}")
                return False
            data = resp.json()
        except (httpx.HTTPError, OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError for a malformed body.
            logger.warning(f"Config pull error: {e}")
            return False
        if not self._apply_payload(data):
            logger.warning("Config pull failed: malformed payload")
            return False
        try:
            self._save_to_cache(self.get_all())
        except (sqlite3.Error, OSError) as e:
            # In-memory configs are current; only the offline fallback is stale.
            logger.warning(f"Failed to write config cache: {e}")
        logger.info(
            f"Synced {len(self._skills)} skills, {len(self._workflows)} workflows "
            f"(version={self._version})"
        )
        return True

    def _apply_payload(self, data: object) -> bool:
        """Validate *data* and, if well-formed, replace the in-memory cache.

        *data* is either a decoded ``/api/v1/config/all`` response or a
        payload rebuilt from the SQLite cache. Missing or ``null``
        ``skills``/``workflows`` are treated as empty lists. Nothing is
        modified unless validation passes, so a malformed payload never
        leaves the cache half-updated.

        Returns:
            ``True`` if the payload was applied, ``False`` if it was rejected.
        """
        if not isinstance(data, dict):
            return False
        skills = data.get("skills")
        workflows = data.get("workflows")
        if skills is None:
            skills = []
        if workflows is None:
            workflows = []
        if not isinstance(skills, list) or not isinstance(workflows, list):
            return False
        self._version = data.get("version")
        self._skills = skills
        self._workflows = workflows
        self._last_synced_at = data.get("synced_at")
        return True

    # ── Cache access ──────────────────────────────────────────────

    def get_version(self) -> str | None:
        """Return the current cached config version hash."""
        return self._version

    def get_skills(self) -> list[SkillConfigDict]:
        """Return a shallow copy of the cached skill configs."""
        return list(self._skills)

    def get_workflows(self) -> list[WorkflowConfigDict]:
        """Return a shallow copy of the cached workflow configs."""
        return list(self._workflows)

    def get_all(self) -> SyncedConfigPayload:
        """Return all cached configs as a single dict."""
        return {
            "version": self._version,
            "skills": list(self._skills),
            "workflows": list(self._workflows),
            "synced_at": self._last_synced_at,
        }

    def get_skill(self, name: str) -> SkillConfigDict | None:
        """Return a single skill config by ``name``, or ``None``."""
        for skill in self._skills:
            if skill.get("name") == name:
                return skill
        return None

    def get_workflow(self, workflow_id: str) -> WorkflowConfigDict | None:
        """Return a single workflow config by ``workflow_id``, or ``None``."""
        for wf in self._workflows:
            if wf.get("workflow_id") == workflow_id:
                return wf
        return None

    # ── Internal helpers ──────────────────────────────────────────

    def _build_headers(self) -> dict[str, str]:
        """Build HTTP headers, adding a Bearer JWT if one is available."""
        headers: dict[str, str] = {"Accept": "application/json"}
        if self.token_provider:
            token = self.token_provider()
            if token:
                headers["Authorization"] = f"Bearer {token}"
        return headers

    def _init_cache_db(self) -> None:
        """Create the cache directory and SQLite table if missing."""
        self.cache_db_path.parent.mkdir(parents=True, exist_ok=True)
        # closing(): a sqlite3 connection used as a context manager only
        # commits/rolls back; it does not close the connection.
        with closing(sqlite3.connect(str(self.cache_db_path))) as conn:
            conn.executescript(_CACHE_SCHEMA)
            conn.commit()

    def _save_to_cache(self, data: SyncedConfigPayload) -> None:
        """Save the synced configs to the local SQLite cache.

        Each of the four payload fields is stored JSON-encoded under its
        own key. Raises ``sqlite3.Error``/``OSError`` on write failure.
        """
        now = datetime.now(timezone.utc).isoformat()
        entries = [
            ("version", json.dumps(data.get("version")), now),
            ("skills", json.dumps(data.get("skills", [])), now),
            ("workflows", json.dumps(data.get("workflows", [])), now),
            ("synced_at", json.dumps(data.get("synced_at")), now),
        ]
        with closing(sqlite3.connect(str(self.cache_db_path))) as conn:
            # The schema is idempotent; re-applying it presumably guards
            # against the cache file having been removed after start().
            conn.executescript(_CACHE_SCHEMA)
            with conn:  # single transaction: commit on success, rollback on error
                conn.executemany(
                    "INSERT OR REPLACE INTO config_cache (key, value, updated_at) VALUES (?, ?, ?)",
                    entries,
                )

    def _load_from_cache(self) -> bool:
        """Load configs from the local SQLite cache.

        Returns ``True`` if the cache held a well-formed payload that was
        loaded, ``False`` otherwise (in-memory state is then unchanged).
        """
        try:
            with closing(sqlite3.connect(str(self.cache_db_path))) as conn:
                conn.row_factory = sqlite3.Row
                cursor = conn.execute(
                    "SELECT key, value FROM config_cache WHERE key IN (?, ?, ?, ?)",
                    ("version", "skills", "workflows", "synced_at"),
                )
                rows = {row["key"]: row["value"] for row in cursor.fetchall()}
            if not rows:
                return False
            # Missing keys come back as None from .get() in _apply_payload,
            # i.e. version/synced_at -> None and skills/workflows -> [].
            data = {key: json.loads(value) for key, value in rows.items()}
        except (sqlite3.Error, OSError, ValueError, TypeError) as e:
            logger.warning(f"Failed to load config cache: {e}")
            return False
        if not self._apply_payload(data):
            logger.warning("Failed to load config cache: malformed payload")
            return False
        logger.info(
            f"Loaded {len(self._skills)} skills, {len(self._workflows)} workflows "
            f"from cache (version={self._version})"
        )
        return True