339 lines
13 KiB
Python
339 lines
13 KiB
Python
"""Client-side configuration sync engine.
|
||
|
||
Pulls skill/agent/workflow configs from the server on startup and polls
|
||
for changes every 5 minutes. Configs are cached in a local SQLite DB so
|
||
the client can operate offline.
|
||
|
||
Usage::
|
||
|
||
from agentkit.client.sync import ConfigSync
|
||
|
||
sync = ConfigSync(
|
||
server_url="http://localhost:8001",
|
||
token_provider=lambda: jwt_token, # or None for dev mode
|
||
cache_db_path="~/.agentkit/config_cache.db",
|
||
)
|
||
await sync.start() # Initial full pull
|
||
    sync.start_polling()  # Background polling task (``await sync.poll_loop()`` blocks until stopped)
|
||
    configs = sync.get_all()  # Returns cached configs
|
||
await sync.stop()
|
||
|
||
Design (KTD3):
|
||
- Polling, not WebSocket push (config changes are weekly/monthly)
|
||
- Full pull on version change (not incremental diff)
|
||
- SQLite cache for offline operation
|
||
- 5-minute poll interval (configurable)
|
||
"""
|
||
|
||
from __future__ import annotations

import asyncio
import json
import logging
import os
import sqlite3
from contextlib import closing
from datetime import datetime, timezone
from pathlib import Path
from typing import Callable, TypeAlias

import httpx
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Cached config entries — skill/workflow configs are dicts produced by JSON
# deserialization, whose values are scalars or nested structures.
# The elements of the skills/workflows lists returned by the server are dicts
# (model_dump/to_dict output) that may themselves contain containers such as
# list/dict, so ``object`` is used as the value type.
SkillConfigDict: TypeAlias = dict[str, object]
WorkflowConfigDict: TypeAlias = dict[str, object]
SyncedConfigPayload: TypeAlias = dict[str, object]
|
||
|
||
|
||
# ── Defaults ──────────────────────────────────────────────────────────
|
||
|
||
# Seconds between config-version polls when the caller does not override it.
DEFAULT_POLL_INTERVAL = 300  # 5 minutes
# HTTP request timeout, in seconds.
DEFAULT_TIMEOUT = 30.0
# Location of the SQLite cache file. ``AGENTKIT_CONFIG_CACHE`` overrides the
# built-in default. An empty value falls back to the default (``Path("")`` is
# not a usable database file), and ``~`` is expanded explicitly because
# environment values set outside a shell (.env files, service units) are not
# tilde-expanded for us.
DEFAULT_CACHE_DB_PATH = Path(
    os.environ.get("AGENTKIT_CONFIG_CACHE") or Path.home() / ".agentkit" / "config_cache.db"
).expanduser()
|
||
|
||
|
||
# ── SQLite cache schema ───────────────────────────────────────────────
|
||
|
||
_CACHE_SCHEMA = """
|
||
CREATE TABLE IF NOT EXISTS config_cache (
|
||
key TEXT PRIMARY KEY,
|
||
value TEXT NOT NULL,
|
||
updated_at TEXT NOT NULL
|
||
);
|
||
"""
|
||
|
||
|
||
# ── ConfigSync ────────────────────────────────────────────────────────
|
||
|
||
|
||
class ConfigSync:
    """Client-side configuration sync engine.

    Pulls configs from the server, caches them locally, and polls for
    changes on a configurable interval.

    Attributes:
        server_url: Base URL of the AgentKit server (e.g. ``http://localhost:8001``),
            stored without a trailing slash.
        token_provider: Callable that returns the current JWT access token
            (or ``None`` if not authenticated). Called on each request.
        cache_db_path: Path to the local SQLite cache file (``~`` is expanded).
        poll_interval: Seconds between version polls (default 300 = 5 min).
        timeout: HTTP request timeout in seconds.
    """

    def __init__(
        self,
        *,
        server_url: str,
        token_provider: Callable[[], str | None] | None = None,
        cache_db_path: str | Path | None = None,
        poll_interval: int = DEFAULT_POLL_INTERVAL,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> None:
        """Store the settings; no network or disk I/O happens until ``start()``."""
        self.server_url = server_url.rstrip("/")
        self.token_provider = token_provider
        # expanduser() so that "~/.agentkit/config_cache.db" resolves to the
        # home directory instead of creating a literal "~" directory in the CWD.
        chosen_path = Path(cache_db_path) if cache_db_path else DEFAULT_CACHE_DB_PATH
        self.cache_db_path = chosen_path.expanduser()
        self.poll_interval = poll_interval
        self.timeout = timeout

        self._client: httpx.AsyncClient | None = None
        self._poll_task: asyncio.Task[None] | None = None
        self._stopped = False

        # In-memory cache (mirrors the SQLite cache for fast access)
        self._version: str | None = None
        self._skills: list[SkillConfigDict] = []
        self._workflows: list[WorkflowConfigDict] = []
        self._last_synced_at: str | None = None

    # ── Lifecycle ─────────────────────────────────────────────────

    async def start(self) -> bool:
        """Perform an initial full sync.

        Tries to pull configs from the server. On failure, loads the
        local cache. Returns ``True`` if the server was reachable.

        Raises:
            OSError, sqlite3.Error: If the cache database cannot be created.
        """
        # Re-arm the polling flag so the instance is reusable after stop().
        self._stopped = False
        self._init_cache_db()

        # Close a client left over from an earlier start() instead of leaking it.
        previous_client, self._client = self._client, None
        if previous_client is not None:
            await previous_client.aclose()
        self._client = httpx.AsyncClient(timeout=self.timeout)

        # Try server sync first
        success = await self._pull_all()
        if not success:
            logger.info("Server unreachable, loading cached configs")
            self._load_from_cache()

        return success

    async def stop(self) -> None:
        """Stop the polling loop and close the HTTP client."""
        self._stopped = True

        task, self._poll_task = self._poll_task, None
        if task and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        # Detach the client before closing it so nothing reuses a closing client.
        client, self._client = self._client, None
        if client:
            await client.aclose()

    async def poll_loop(self) -> None:
        """Background polling loop.

        Every ``poll_interval`` seconds, checks the server's config
        version. If it differs from the cached version, re-pulls all
        configs. On network errors, keeps the existing cache and retries
        on the next interval.
        """
        while not self._stopped:
            try:
                await asyncio.sleep(self.poll_interval)
                if self._stopped:
                    break
                await self._check_and_sync()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Best-effort loop: one failed poll must not kill background sync.
                logger.warning("Config poll error: %s", e)

    def start_polling(self) -> asyncio.Task[None]:
        """Start the polling loop as a background task.

        If a polling task is already running it is returned as-is rather
        than spawning a second, orphaned loop.
        """
        if self._poll_task is not None and not self._poll_task.done():
            return self._poll_task
        self._poll_task = asyncio.create_task(self.poll_loop())
        return self._poll_task

    # ── Sync operations ───────────────────────────────────────────

    async def _check_and_sync(self) -> bool:
        """Check the server version; re-pull if changed.

        Returns ``True`` if configs were re-pulled, ``False`` if the
        version was unchanged, the response was malformed, or the server
        was unreachable.
        """
        if not self._client:
            return False

        try:
            resp = await self._client.get(
                f"{self.server_url}/api/v1/config/version",
                headers=self._build_headers(),
            )
            if resp.status_code != 200:
                logger.warning("Version check failed: HTTP %s", resp.status_code)
                return False
            payload = resp.json()
        except (httpx.HTTPError, OSError) as e:
            logger.warning("Version check network error: %s", e)
            return False
        except ValueError as e:
            # Body was not valid JSON (json.JSONDecodeError is a ValueError).
            logger.warning("Version check failed: invalid JSON (%s)", e)
            return False

        if not isinstance(payload, dict):
            logger.warning(
                "Version check failed: unexpected payload type %s", type(payload).__name__
            )
            return False

        server_version = payload.get("version")
        if server_version == self._version:
            logger.debug("Config version unchanged, skipping sync")
            return False

        logger.info("Config version changed: %s → %s", self._version, server_version)
        return await self._pull_all()

    async def _pull_all(self) -> bool:
        """Pull all configs from the server and update the cache.

        The in-memory state is updated first; persisting to SQLite is
        best-effort, so a failed cache write does not discard configs that
        were fetched successfully.

        Returns ``True`` on success, ``False`` on failure.
        """
        if not self._client:
            return False

        try:
            resp = await self._client.get(
                f"{self.server_url}/api/v1/config/all",
                headers=self._build_headers(),
            )
            if resp.status_code != 200:
                logger.warning("Config pull failed: HTTP %s", resp.status_code)
                return False
            data = resp.json()
        except (httpx.HTTPError, OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError from a non-JSON body.
            logger.warning("Config pull error: %s", e)
            return False

        if not isinstance(data, dict):
            logger.warning("Config pull failed: unexpected payload type %s", type(data).__name__)
            return False

        self._version = data.get("version")
        # ``or []`` also covers an explicit JSON null, not just a missing key.
        self._skills = data.get("skills") or []
        self._workflows = data.get("workflows") or []
        self._last_synced_at = data.get("synced_at")

        try:
            self._save_to_cache(data)
        except (sqlite3.Error, OSError) as e:
            # The fresh configs are already in memory; only offline use is affected.
            logger.warning("Failed to persist config cache: %s", e)

        logger.info(
            "Synced %d skills, %d workflows (version=%s)",
            len(self._skills),
            len(self._workflows),
            self._version,
        )
        return True

    # ── Cache access ──────────────────────────────────────────────

    def get_version(self) -> str | None:
        """Return the current cached config version hash."""
        return self._version

    def get_skills(self) -> list[SkillConfigDict]:
        """Return a shallow copy of the cached skill configs."""
        return list(self._skills)

    def get_workflows(self) -> list[WorkflowConfigDict]:
        """Return a shallow copy of the cached workflow configs."""
        return list(self._workflows)

    def get_all(self) -> SyncedConfigPayload:
        """Return all cached configs as a single dict."""
        return {
            "version": self._version,
            "skills": list(self._skills),
            "workflows": list(self._workflows),
            "synced_at": self._last_synced_at,
        }

    def get_skill(self, name: str) -> SkillConfigDict | None:
        """Return a single skill config by name, or ``None``."""
        return next((skill for skill in self._skills if skill.get("name") == name), None)

    def get_workflow(self, workflow_id: str) -> WorkflowConfigDict | None:
        """Return a single workflow config by ID, or ``None``."""
        return next(
            (wf for wf in self._workflows if wf.get("workflow_id") == workflow_id),
            None,
        )

    # ── Internal helpers ──────────────────────────────────────────

    def _build_headers(self) -> dict[str, str]:
        """Build HTTP headers with JWT if available."""
        headers: dict[str, str] = {"Accept": "application/json"}
        if self.token_provider:
            token = self.token_provider()
            if token:
                headers["Authorization"] = f"Bearer {token}"
        return headers

    def _init_cache_db(self) -> None:
        """Initialize the SQLite cache database, creating parent directories."""
        self.cache_db_path.parent.mkdir(parents=True, exist_ok=True)
        # A sqlite3 connection used as a context manager only commits or rolls
        # back; closing() is what actually releases the file handle.
        with closing(sqlite3.connect(str(self.cache_db_path))) as conn:
            conn.executescript(_CACHE_SCHEMA)
            conn.commit()

    def _save_to_cache(self, data: SyncedConfigPayload) -> None:
        """Save the synced configs to the local SQLite cache.

        Raises:
            sqlite3.Error, OSError: If the database cannot be written.
        """
        now = datetime.now(timezone.utc).isoformat()
        entries = [
            ("version", json.dumps(data.get("version")), now),
            ("skills", json.dumps(data.get("skills") or []), now),
            ("workflows", json.dumps(data.get("workflows") or []), now),
            ("synced_at", json.dumps(data.get("synced_at")), now),
        ]
        with closing(sqlite3.connect(str(self.cache_db_path))) as conn:
            # Re-run the schema in case the cache file was removed since start().
            conn.executescript(_CACHE_SCHEMA)
            with conn:  # one transaction: all four keys are replaced, or none
                conn.executemany(
                    "INSERT OR REPLACE INTO config_cache (key, value, updated_at) VALUES (?, ?, ?)",
                    entries,
                )

    def _load_from_cache(self) -> bool:
        """Load configs from the local SQLite cache.

        Values are parsed into locals first and only then assigned, so a
        corrupt row cannot leave the instance half-updated.

        Returns ``True`` if the cache had data, ``False`` otherwise.
        """
        try:
            with closing(sqlite3.connect(str(self.cache_db_path))) as conn:
                cursor = conn.execute(
                    "SELECT key, value FROM config_cache WHERE key IN (?, ?, ?, ?)",
                    ("version", "skills", "workflows", "synced_at"),
                )
                rows = dict(cursor.fetchall())

            if not rows:
                return False

            version = json.loads(rows.get("version", "null"))
            skills = json.loads(rows.get("skills", "[]")) or []
            workflows = json.loads(rows.get("workflows", "[]")) or []
            synced_at = json.loads(rows.get("synced_at", "null"))
            skill_count, workflow_count = len(skills), len(workflows)
        except Exception as e:
            # Deliberately broad: an unreadable cache must never break startup.
            logger.warning("Failed to load config cache: %s", e)
            return False

        self._version = version
        self._skills = skills
        self._workflows = workflows
        self._last_synced_at = synced_at

        logger.info(
            "Loaded %d skills, %d workflows from cache (version=%s)",
            skill_count,
            workflow_count,
            self._version,
        )
        return True
|