fix(security,reliability): resolve all P2 findings from code review

This commit is contained in:
chiguyong 2026-06-10 15:05:40 +08:00
parent 658e188939
commit 6852dfe892
10 changed files with 770 additions and 278 deletions

View File

@ -0,0 +1,287 @@
# fix: AgentKit P2 Security & Reliability Hardening
**Status:** active
**Created:** 2026-06-10
**Origin:** ce-code-review findings on main branch after 7-capabilities merge
---
## Summary
Fix all remaining P2 issues from the final code review of the `feat/agentkit-7-capabilities` merge. These 16 issues span security (SSRF, auth, injection), reliability (retry, concurrency, resource leaks), and architecture (auth consistency, config reload). All are non-blocking for current functionality but represent real risk in production.
## Problem Frame
After merging 7 core capabilities and fixing P0/P1 issues, the codebase has 16 P2 findings that fall into four categories:
1. **Security gaps** — SSRF DNS rebinding, API key timing attacks, WebSocket auth bypass, approval race conditions, shell pattern matching, .env injection
2. **Reliability gaps** — retry storms, unbounded memory growth, concurrent modification, config change loss
3. **Architecture gaps** — dual auth layer inconsistency, WebSocket broadcast without isolation
4. **Defensive gaps** — process kill race, lazy init race, httpx client double-creation
## Requirements
- R1: All P2 security findings must be resolved or explicitly deferred with documented rationale
- R2: Retry logic must use exponential backoff with jitter
- R3: All in-memory stores must have bounded growth with proper eviction
- R4: Concurrent access to shared mutable state must be protected
- R5: Authentication must be consistent across all endpoints and transports
- R6: No regression in existing test suite (764+ tests must pass)
## Key Technical Decisions
- **KTD1: DNS resolution for SSRF** — Use `asyncio.getaddrinfo` (not blocking `socket.getaddrinfo`) to resolve hostnames before IP validation. Fail-closed on DNS errors. Accept that TOCTOU remains between check and connect; full mitigation requires IP pinning at the httpx transport level (deferred).
- **KTD2: Auth consolidation strategy** — Fix per-endpoint auth first (hmac.compare_digest, WebSocket auth-before-accept), then consolidate into middleware in a separate pass. This avoids a big-bang refactor and allows incremental testing.
- **KTD3: WorkflowStore concurrency** — Add `asyncio.Lock` to WorkflowStore, convert sync methods to async. This is the most invasive change but necessary for correctness under concurrent load.
- **KTD4: Shell dangerous patterns** — Replace substring matching with token-based matching using the already-parsed shlex tokens. Separate binary-level and flag-level dangerous patterns.
- **KTD5: WebSocket isolation** — Replace flat subscriber list with `dict[str, set[WebSocket]]` keyed by execution_id. Clients subscribe to specific executions.
---
## Implementation Units
### U1. API Key Timing Attack & WebSocket Auth-Before-Accept
**Goal:** Eliminate timing side-channel in API key comparison and fix WebSocket accept-before-auth pattern.
**Requirements:** R5
**Dependencies:** None
**Files:**
- `src/agentkit/server/routes/portal.py`
- `src/agentkit/server/routes/workflows.py`
- `src/agentkit/server/middleware.py`
- `tests/unit/server/test_portal_routes.py`
- `tests/unit/server/test_workflow_routes.py`
**Approach:**
1. Add `import hmac` to portal.py, workflows.py, middleware.py
2. Replace all `!=` API key comparisons with `hmac.compare_digest()`
3. In both WebSocket handlers, move auth validation BEFORE `websocket.accept()`. On auth failure, call `await websocket.close(code=4001)` without accepting, then return
4. In middleware.py, replace `provided_key not in valid_keys` with `any(hmac.compare_digest(...) for k in valid_keys)`
5. Handle None/empty provided key by defaulting to empty string before comparison
**Test scenarios:**
- API key comparison with correct key returns True
- API key comparison with incorrect key returns False
- API key comparison with None provided defaults to empty string
- WebSocket connection without api_key is rejected (no 101 upgrade)
- WebSocket connection with wrong api_key is rejected
- WebSocket connection with correct api_key is accepted
**Verification:** All existing tests pass; new auth tests pass
---
### U2. Shell Dangerous Pattern Token-Based Matching
**Goal:** Replace substring-based dangerous pattern matching with precise token-based matching to eliminate false positives and bypasses.
**Requirements:** R1
**Dependencies:** None
**Files:**
- `src/agentkit/tools/shell.py`
- `tests/unit/test_shell_tool.py`
**Approach:**
1. Replace `_DANGEROUS_PATTERNS` tuple with two new structures:
- `_DANGEROUS_BINARIES`: set of binary names that are always dangerous (e.g., `rm`, `mkfs`, `dd`, `shutdown`)
- `_DANGEROUS_BINARY_FLAGS`: dict mapping binary name to set of dangerous flag combinations (e.g., `kill: {-9, -KILL}`, `chmod: {777}`)
2. In `_is_dangerous()`, after shlex parsing:
- Check `binary` against `_DANGEROUS_BINARIES` (exact match)
- Check `binary` + flags against `_DANGEROUS_BINARY_FLAGS`
- Keep regex-based arg patterns for cross-token matches (e.g., `> /dev/`, `drop table`)
3. Remove overly broad patterns like `format`, `erase` as substrings; add `format` as binary-only
4. Normalize whitespace in command before matching
**Test scenarios:**
- `rm -rf /` detected as dangerous (binary match)
- `echo 'performing rm operations'` NOT detected as dangerous (no binary match)
- `kill -9 1234` detected as dangerous (binary + flag match)
- `kill -15 1234` NOT detected as dangerous (flag not in dangerous set)
- `rm\t/tmp/file` detected as dangerous (tab normalization)
- `format C:` detected as dangerous (binary match)
- `informatica` NOT detected as dangerous (not a binary match)
**Verification:** Shell tool tests pass with new pattern logic
---
### U3. .env File Prefix Filtering & Retry Backoff
**Goal:** Prevent arbitrary environment variable injection from .env files and add exponential backoff to plan executor retries.
**Requirements:** R1, R2
**Dependencies:** None
**Files:**
- `src/agentkit/server/app.py`
- `src/agentkit/core/plan_executor.py`
- `tests/unit/test_plan_executor.py`
**Approach:**
1. In app.py .env loading, add an allowlist of prefixes: `AGENTKIT_`, `OPENAI_`, `ANTHROPIC_`, `GEMINI_`, `TAVILY_`, `SERPER_`, plus exact matches for `DATABASE_URL`, `REDIS_URL`
2. Log a warning when a .env variable is skipped due to prefix filtering
3. In plan_executor.py `_execute_step_with_retry()`, add exponential backoff with jitter after each failure:
- `base_retry_delay=1.0s`, `max_retry_delay=30.0s`
- `delay = min(base * 2^(retry-1), max) * (0.5 + random() * 0.5)`
- Skip delay on first attempt (retry_count == 0)
4. Add `base_retry_delay` and `max_retry_delay` as constructor parameters with defaults
**Test scenarios:**
- .env with `AGENTKIT_FOO=bar` is loaded
- .env with `PATH=/malicious` is skipped with warning
- .env with `OPENAI_API_KEY=sk-xxx` is loaded
- Retry backoff: first retry waits ~1s, second ~2s, third ~4s (with jitter)
- Max retry delay caps at 30s
- Plan executor with max_retries=0 still works (no backoff)
**Verification:** Unit tests pass; .env loading respects prefix filter
---
### U4. WorkflowStore Concurrency & Resource Management
**Goal:** Make WorkflowStore thread-safe, fix unbounded growth, add proper cleanup for approval events and running tasks.
**Requirements:** R3, R4
**Dependencies:** U1 (hmac.compare_digest in workflow routes)
**Files:**
- `src/agentkit/server/routes/workflows.py`
- `tests/unit/server/test_workflow_routes.py`
**Approach:**
1. Add `asyncio.Lock` to WorkflowStore; convert `save`, `delete`, `create_execution`, `update_execution` to async methods
2. Initialize `_running_tasks` in `__init__` instead of lazy `getattr`
3. Add `_evict_execution()` helper that removes execution AND its approval events (with `event.set()` to wake waiting coroutines)
4. Replace `_ws_subscribers: list` with `_ws_subscribers: dict[str, set[WebSocket]]` keyed by execution_id, protected by `asyncio.Lock`
5. Add `_subscribe(execution_id, ws)` and `_unsubscribe(execution_id, ws)` helpers
6. Modify `_broadcast_ws` to accept `execution_id` parameter and only send to relevant subscribers
7. Update all call sites from sync to async (all endpoint functions already async)
8. Add `_execution_locks: dict[str, asyncio.Lock]` for per-execution approve/cancel serialization
**Test scenarios:**
- Concurrent save operations don't cause KeyError or over-eviction
- Evicted execution's approval events are cleaned up
- WebSocket subscriber receives only messages for subscribed execution_id
- WebSocket disconnect properly cleans up from subscriber dict
- Concurrent approve/cancel for same execution serialized by lock
- _running_tasks initialized in __init__, no getattr fallback
**Verification:** Workflow route tests pass; no RuntimeError from concurrent modification
---
### U5. SSRF DNS Resolution & ComputerUse Client Safety
**Goal:** Add DNS resolution to SSRF protection and make ComputerUseTool's httpx client creation thread-safe.
**Requirements:** R1
**Dependencies:** None
**Files:**
- `src/agentkit/utils/security.py`
- `src/agentkit/tools/computer_use.py`
- `tests/unit/test_security.py`
- `tests/unit/tools/test_computer_use.py`
**Approach:**
1. In `is_safe_url()`, after the ValueError catch for domain names, add DNS resolution:
- Use `socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)` (sync, acceptable for tool init)
- Validate each resolved IP against `_is_unsafe_ip()`
- Fail-closed on DNS errors (return False)
- Add 3-second timeout on DNS resolution
2. Add `is_safe_url_async()` variant using `asyncio.getaddrinfo` for async callers
3. In ComputerUseTool, add `asyncio.Lock` to `_get_http_client()`, make it async with double-check locking
4. Update caller from `client = self._get_http_client()` to `client = await self._get_http_client()`
**Test scenarios:**
- is_safe_url with domain resolving to 127.0.0.1 returns False
- is_safe_url with domain resolving to public IP returns True
- is_safe_url with unresolvable domain returns False (fail-closed)
- is_safe_url with ::ffff:127.0.0.1 returns False
- is_safe_url with .internal TLD returns False
- ComputerUseTool concurrent _get_http_client returns same client instance
- ComputerUseTool close() properly cleans up client
**Verification:** Security and computer_use tests pass
---
### U6. Config Hot-Reload & Defensive Fixes
**Goal:** Fix config hot-reload lock race, add proc.kill() error handling, and initialize config reload lock eagerly.
**Requirements:** R4
**Dependencies:** None
**Files:**
- `src/agentkit/server/app.py`
- `src/agentkit/tools/shell.py`
**Approach:**
1. In app.py `_on_config_change()`:
- Initialize lock eagerly in `create_app()` (store on `app.state._config_reload_lock`)
- Replace `lock.locked()` skip with pending-flag pattern: set `app.state._config_reload_pending = True`, then in `_reload()` loop while pending flag is set
- This ensures no config change is silently dropped
2. In shell.py `_execute_standalone()`:
- Wrap `proc.kill()` in `try/except (ProcessLookupError, OSError)`
- Use `proc.returncode` if available instead of hardcoded -1
**Test scenarios:**
- Two rapid config changes both get applied (no silent drop)
- Config reload pending flag is cleared after reload completes
- Shell timeout with already-exited process returns clean error (no ProcessLookupError)
- Shell timeout output still shows timeout message
**Verification:** App startup and shell tool tests pass
---
## Scope Boundaries
### In Scope
- All 16 P2 findings from ce-code-review
- Test coverage for new security/reliability patterns
### Deferred to Follow-Up Work
- Full auth middleware consolidation (ASGI-level WebSocket middleware) — complex, separate PR
- IP pinning at httpx transport level for complete SSRF protection
- python-dotenv migration for robust .env parsing
- clients.yaml caching in middleware (per-request disk read)
- WorkflowStore persistence (currently in-memory only)
### Outside This Plan's Identity
- New features or capability additions
- Performance optimization beyond what's needed for reliability
- UI/UX changes
---
## Risks & Mitigations
| Risk | Impact | Mitigation |
|------|--------|------------|
| WorkflowStore async refactor breaks existing endpoints | High | Update all call sites; comprehensive test coverage |
| DNS resolution adds latency to URL validation | Low | 3s timeout; acceptable for tool initialization |
| Token-based shell matching misses new bypass patterns | Medium | Default-to-dangerous fallback; add patterns as discovered |
| .env prefix filtering breaks existing deployments | Medium | Comprehensive allowlist; log warnings for skipped vars |
| WebSocket isolation breaks existing WS clients | Medium | Maintain backward compat: accept without subscription, send all |
| Config reload pending loop runs indefinitely | Low | Coalescing naturally limits; only re-runs if new change arrived |
---
## System-Wide Impact
- **Security posture** significantly improved: timing attacks, SSRF, auth bypass all addressed
- **Reliability** improved: retry storms prevented, memory bounded, concurrent access safe
- **API compatibility**: WorkflowStore method signatures change from sync to async — all call sites updated
- **WebSocket protocol**: subscription message becomes recommended but not required for backward compat

View File

@ -14,6 +14,7 @@ from __future__ import annotations
import asyncio
import logging
import random
import time
from dataclasses import dataclass, field
from enum import Enum
@ -94,6 +95,8 @@ class PlanExecutor:
max_retries: int = 2,
step_timeout: float = 300.0,
max_parallel: int = 5,
base_retry_delay: float = 1.0,
max_retry_delay: float = 30.0,
on_step_complete: OnStepCompleteCallback | None = None,
on_step_failed: OnStepFailedCallback | None = None,
on_human_intervention: OnHumanInterventionCallback | None = None,
@ -104,6 +107,8 @@ class PlanExecutor:
max_retries: 步骤失败后最大重试次数
step_timeout: 单个步骤超时时间
max_parallel: 最大并行步骤数
base_retry_delay: 重试基础延迟
max_retry_delay: 重试最大延迟
on_step_complete: 步骤完成回调
on_step_failed: 步骤失败回调返回 FailureAction 决定后续处理
on_human_intervention: 人工介入回调
@ -112,6 +117,8 @@ class PlanExecutor:
self._max_retries = max_retries
self._step_timeout = step_timeout
self._max_parallel = max_parallel
self._base_retry_delay = base_retry_delay
self._max_retry_delay = max_retry_delay
self._on_step_complete = on_step_complete
self._on_step_failed = on_step_failed
self._on_human_intervention = on_human_intervention
@ -250,6 +257,15 @@ class PlanExecutor:
retry_count += 1
if retry_count <= self._max_retries:
delay = min(
self._base_retry_delay * (2 ** (retry_count - 1)),
self._max_retry_delay,
)
delay *= (0.5 + random.random() * 0.5) # jitter
logger.info(f"Retrying step '{step.step_id}' in {delay:.1f}s (attempt {retry_count + 1}/{self._max_retries + 1})")
await asyncio.sleep(delay)
# 所有重试耗尽
step.status = PlanStepStatus.FAILED
step.error = last_error

View File

@ -30,6 +30,12 @@ from agentkit.telemetry.setup import setup_telemetry
logger = logging.getLogger(__name__)
_ALLOWED_ENV_PREFIXES = (
'AGENTKIT_', 'OPENAI_', 'ANTHROPIC_', 'GEMINI_',
'TAVILY_', 'SERPER_', 'DEEPSEEK_',
)
_ALLOWED_ENV_EXACT = {'DATABASE_URL', 'REDIS_URL'}
def _build_llm_gateway(config: ServerConfig) -> LLMGateway:
"""Build LLMGateway from ServerConfig, registering all providers."""
@ -224,70 +230,69 @@ def _on_config_change(app: FastAPI, config: ServerConfig) -> None:
Uses a lock to prevent concurrent config reloads from racing.
"""
lock: asyncio.Lock = getattr(app.state, "_config_reload_lock", None)
if lock is None:
lock = asyncio.Lock()
app.state._config_reload_lock = lock
lock: asyncio.Lock = app.state._config_reload_lock
if lock.locked():
logger.warning("Config reload already in progress, skipping")
return
app.state._config_reload_pending = True
async def _reload():
if lock.locked():
return # Another reload running; it will check pending flag
async with lock:
# Increment config version for audit
current_version = getattr(app.state, "config_version", 0) + 1
app.state.config_version = current_version
logger.info(f"Config change detected (v{current_version}), reloading...")
while getattr(app.state, "_config_reload_pending", False):
app.state._config_reload_pending = False
# Increment config version for audit
current_version = getattr(app.state, "config_version", 0) + 1
app.state.config_version = current_version
logger.info(f"Config change detected (v{current_version}), reloading...")
# Rebuild LLMGateway if llm config changed
try:
new_gateway = _build_llm_gateway(config)
app.state.llm_gateway = new_gateway
# Also update the agent pool's gateway reference
# Rebuild LLMGateway if llm config changed
try:
new_gateway = _build_llm_gateway(config)
app.state.llm_gateway = new_gateway
# Also update the agent pool's gateway reference
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
app.state.agent_pool._llm_gateway = new_gateway
if hasattr(app.state, "intent_router") and app.state.intent_router is not None:
app.state.intent_router._llm_gateway = new_gateway
logger.info(f"LLM Gateway reloaded (config v{current_version})")
except Exception as e:
logger.error(f"Failed to reload LLM Gateway: {e}")
# Reload skills if skill paths changed
try:
new_skill_registry = _build_skill_registry(config)
# Re-bind tools from the shared tool_registry so skills don't lose their bindings
tool_registry = getattr(app.state, "tool_registry", None)
if tool_registry:
from agentkit.skills.loader import SkillLoader
loader = SkillLoader(
skill_registry=new_skill_registry,
tool_registry=tool_registry,
)
for skill_path in (config.skill_paths or []):
from pathlib import Path as _P
p = _P(skill_path)
if p.is_dir():
loader.load_from_directory(str(p))
elif p.is_file() and p.suffix in (".yaml", ".yml"):
try:
loader.load_from_file(str(p))
except Exception:
pass
app.state.skill_registry = new_skill_registry
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
app.state.agent_pool._skill_registry = new_skill_registry
logger.info(f"Skills reloaded (config v{current_version})")
except Exception as e:
logger.error(f"Failed to reload skills: {e}")
# Update config version on all agents
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
app.state.agent_pool._llm_gateway = new_gateway
if hasattr(app.state, "intent_router") and app.state.intent_router is not None:
app.state.intent_router._llm_gateway = new_gateway
logger.info(f"LLM Gateway reloaded (config v{current_version})")
except Exception as e:
logger.error(f"Failed to reload LLM Gateway: {e}")
for agent in app.state.agent_pool._agents.values():
if hasattr(agent, "_config_version"):
agent._config_version = current_version
# Reload skills if skill paths changed
try:
new_skill_registry = _build_skill_registry(config)
# Re-bind tools from the shared tool_registry so skills don't lose their bindings
tool_registry = getattr(app.state, "tool_registry", None)
if tool_registry:
from agentkit.skills.loader import SkillLoader
loader = SkillLoader(
skill_registry=new_skill_registry,
tool_registry=tool_registry,
)
for skill_path in (config.skill_paths or []):
from pathlib import Path as _P
p = _P(skill_path)
if p.is_dir():
loader.load_from_directory(str(p))
elif p.is_file() and p.suffix in (".yaml", ".yml"):
try:
loader.load_from_file(str(p))
except Exception:
pass
app.state.skill_registry = new_skill_registry
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
app.state.agent_pool._skill_registry = new_skill_registry
logger.info(f"Skills reloaded (config v{current_version})")
except Exception as e:
logger.error(f"Failed to reload skills: {e}")
# Update config version on all agents
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
for agent in app.state.agent_pool._agents.values():
if hasattr(agent, "_config_version"):
agent._config_version = current_version
logger.info(f"Config reload complete (v{current_version})")
logger.info(f"Config reload complete (v{current_version})")
# Schedule the reload as a task (non-blocking for the watcher thread)
try:
@ -327,6 +332,10 @@ def create_app(
_key = _key.strip()
_val = _val.strip().strip("\"'")
if _key and _key not in os.environ:
allowed = any(_key.startswith(p) for p in _ALLOWED_ENV_PREFIXES) or _key in _ALLOWED_ENV_EXACT
if not allowed:
logger.warning(f"Skipping .env variable '{_key}' (not in allowed prefixes)")
continue
os.environ[_key] = _val
server_config = ServerConfig.from_yaml(config_path)
app = FastAPI(title="AgentKit Server", version="2.0.0", lifespan=lifespan)

View File

@ -1,5 +1,6 @@
"""Server middleware - Authentication and Rate Limiting"""
import hmac
import os
import time
from collections import defaultdict
@ -75,7 +76,7 @@ class APIKeyAuthMiddleware(BaseHTTPMiddleware):
# Check API key from header
provided_key = request.headers.get("X-API-Key")
if not provided_key or provided_key not in valid_keys:
if not provided_key or not any(hmac.compare_digest(provided_key.encode(), k.encode()) for k in valid_keys):
return JSONResponse(
status_code=401,
content={"error": "Unauthorized", "message": "Invalid or missing API key"},

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
import hmac
import json
import logging
import uuid
@ -47,7 +48,7 @@ async def _verify_api_key(
return
provided = api_key_header or api_key_query
if provided != configured_api_key:
if not hmac.compare_digest((provided or "").encode(), configured_api_key.encode()):
raise HTTPException(
status_code=401,
detail="Invalid or missing API key. Provide via X-API-Key header or api_key query parameter.",
@ -475,11 +476,7 @@ async def portal_websocket(websocket: WebSocket):
# Check api_key query param
if configured_api_key:
provided = websocket.query_params.get("api_key")
if provided != configured_api_key:
await websocket.accept()
await websocket.send_json(
{"type": "error", "data": {"message": "Invalid or missing api_key"}}
)
if not hmac.compare_digest((provided or "").encode(), configured_api_key.encode()):
await websocket.close(code=4001, reason="Invalid or missing api_key")
return

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
import hmac
import json
import logging
import re
@ -52,7 +53,7 @@ async def _verify_api_key(
return
provided = api_key_header or api_key_query
if provided != configured_api_key:
if not hmac.compare_digest((provided or "").encode(), configured_api_key.encode()):
raise HTTPException(
status_code=401,
detail="Invalid or missing API key. Provide via X-API-Key header or api_key query parameter.",
@ -65,7 +66,7 @@ async def _verify_api_key(
class WorkflowStore:
"""In-memory workflow store."""
"""In-memory workflow store with async-safe mutation methods."""
def __init__(self, max_workflows: int = 500, max_executions: int = 1000):
self._workflows: dict[str, WorkflowDefinition] = {}
@ -73,17 +74,30 @@ class WorkflowStore:
self._max_workflows = max_workflows
self._max_executions = max_executions
self._approval_events: dict[str, asyncio.Event] = {} # key: f"{execution_id}:{stage_name}"
self._running_tasks: dict[str, asyncio.Task] = {}
self._execution_locks: dict[str, asyncio.Lock] = {}
self._lock = asyncio.Lock()
def save(self, workflow: WorkflowDefinition) -> WorkflowDefinition:
workflow.updated_at = datetime.now(timezone.utc).isoformat()
self._workflows[workflow.workflow_id] = workflow
# Evict oldest if over limit
if len(self._workflows) > self._max_workflows:
oldest_id = min(
self._workflows, key=lambda k: self._workflows[k].updated_at
)
del self._workflows[oldest_id]
return workflow
def _evict_execution(self, execution_id: str) -> None:
"""Remove execution and its associated approval events."""
self._executions.pop(execution_id, None)
keys_to_remove = [k for k in self._approval_events if k.startswith(f"{execution_id}:")]
for k in keys_to_remove:
event = self._approval_events.pop(k, None)
if event is not None:
event.set() # Wake any waiting coroutine
async def save(self, workflow: WorkflowDefinition) -> WorkflowDefinition:
async with self._lock:
workflow.updated_at = datetime.now(timezone.utc).isoformat()
self._workflows[workflow.workflow_id] = workflow
# Evict oldest if over limit
if len(self._workflows) > self._max_workflows:
oldest_id = min(
self._workflows, key=lambda k: self._workflows[k].updated_at
)
del self._workflows[oldest_id]
return workflow
def get(self, workflow_id: str) -> WorkflowDefinition | None:
return self._workflows.get(workflow_id)
@ -109,47 +123,72 @@ class WorkflowStore:
)
return summaries
def delete(self, workflow_id: str) -> bool:
if workflow_id in self._workflows:
del self._workflows[workflow_id]
return True
return False
async def delete(self, workflow_id: str) -> bool:
async with self._lock:
if workflow_id in self._workflows:
del self._workflows[workflow_id]
return True
return False
def create_execution(self, workflow_id: str) -> WorkflowExecution:
execution = WorkflowExecution(
execution_id=str(uuid.uuid4()),
workflow_id=workflow_id,
status="pending",
started_at=datetime.now(timezone.utc).isoformat(),
)
self._executions[execution.execution_id] = execution
# Evict oldest if over limit
if len(self._executions) > self._max_executions:
oldest_id = min(
self._executions,
key=lambda k: self._executions[k].started_at or "",
async def create_execution(self, workflow_id: str) -> WorkflowExecution:
async with self._lock:
execution = WorkflowExecution(
execution_id=str(uuid.uuid4()),
workflow_id=workflow_id,
status="pending",
started_at=datetime.now(timezone.utc).isoformat(),
)
del self._executions[oldest_id]
return execution
self._executions[execution.execution_id] = execution
# Evict oldest if over limit
if len(self._executions) > self._max_executions:
oldest_id = min(
self._executions,
key=lambda k: self._executions[k].started_at or "",
)
self._evict_execution(oldest_id)
return execution
def get_execution(self, execution_id: str) -> WorkflowExecution | None:
return self._executions.get(execution_id)
def update_execution(self, execution_id: str, **kwargs: Any) -> WorkflowExecution:
execution = self._executions.get(execution_id)
if execution is None:
raise KeyError(f"Execution '{execution_id}' not found")
for key, value in kwargs.items():
if hasattr(execution, key):
setattr(execution, key, value)
return execution
async def update_execution(self, execution_id: str, **kwargs: Any) -> WorkflowExecution:
async with self._lock:
execution = self._executions.get(execution_id)
if execution is None:
raise KeyError(f"Execution '{execution_id}' not found")
for key, value in kwargs.items():
if hasattr(execution, key):
setattr(execution, key, value)
return execution
def get_execution_lock(self, execution_id: str) -> asyncio.Lock:
"""Get or create a per-execution lock for approve/cancel serialization."""
if execution_id not in self._execution_locks:
self._execution_locks[execution_id] = asyncio.Lock()
return self._execution_locks[execution_id]
# Module-level singleton
_workflow_store = WorkflowStore()
# WebSocket subscribers for real-time execution progress
_ws_subscribers: list[WebSocket] = []
# WebSocket subscribers for real-time execution progress (keyed by execution_id)
_ws_subscribers: dict[str, set[WebSocket]] = {}
_ws_subscribers_lock = asyncio.Lock()
async def _ws_subscribe(execution_id: str, ws: WebSocket) -> None:
async with _ws_subscribers_lock:
if execution_id not in _ws_subscribers:
_ws_subscribers[execution_id] = set()
_ws_subscribers[execution_id].add(ws)
async def _ws_unsubscribe(execution_id: str, ws: WebSocket) -> None:
async with _ws_subscribers_lock:
if execution_id in _ws_subscribers:
_ws_subscribers[execution_id].discard(ws)
if not _ws_subscribers[execution_id]:
del _ws_subscribers[execution_id]
def _get_store(request: Request) -> WorkflowStore:
@ -210,7 +249,7 @@ async def _execute_workflow(
"""Execute a workflow by running its stages in topological order."""
_store = store or _workflow_store
execution.status = "running"
_store.update_execution(execution.execution_id, status="running")
await _store.update_execution(execution.execution_id, status="running")
# Topological sort
stage_map = {s.name: s for s in workflow.stages}
@ -229,7 +268,7 @@ async def _execute_workflow(
execution.status = "failed"
execution.error = "循环依赖"
execution.completed_at = datetime.now(timezone.utc).isoformat()
_store.update_execution(
await _store.update_execution(
execution.execution_id,
status="failed",
error="循环依赖",
@ -246,7 +285,7 @@ async def _execute_workflow(
for stage_name in ordered:
stage = stage_map[stage_name]
execution.current_stage = stage_name
_store.update_execution(
await _store.update_execution(
execution.execution_id,
current_stage=stage_name,
)
@ -256,7 +295,7 @@ async def _execute_workflow(
"event": "stage_started",
"execution_id": execution.execution_id,
"stage": stage_name,
})
}, execution_id=execution.execution_id)
try:
if stage.type == "approval":
@ -267,7 +306,7 @@ async def _execute_workflow(
execution.status = "paused"
execution.current_stage = stage_name
_store.update_execution(
await _store.update_execution(
execution.execution_id,
status="paused",
current_stage=stage_name,
@ -276,7 +315,7 @@ async def _execute_workflow(
"event": "approval_required",
"execution_id": execution.execution_id,
"stage": stage_name,
})
}, execution_id=execution.execution_id)
# Wait for approval with timeout
try:
@ -289,13 +328,13 @@ async def _execute_workflow(
"execution_id": execution.execution_id,
"stage": stage_name,
"error": "Approval rejected",
})
}, execution_id=execution.execution_id)
return
# Approval was granted — the /approve endpoint already set stage_results
# Only update status to running if not already set
if execution.status != "running":
execution.status = "running"
_store.update_execution(
await _store.update_execution(
execution.execution_id,
status="running",
)
@ -308,7 +347,7 @@ async def _execute_workflow(
execution.status = "failed"
execution.error = f"Approval timeout for stage {stage_name}"
execution.completed_at = datetime.now(timezone.utc).isoformat()
_store.update_execution(
await _store.update_execution(
execution.execution_id,
status="failed",
error=execution.error,
@ -320,7 +359,7 @@ async def _execute_workflow(
"execution_id": execution.execution_id,
"stage": stage_name,
"error": "Approval timeout",
})
}, execution_id=execution.execution_id)
return
finally:
_store._approval_events.pop(event_key, None)
@ -332,7 +371,7 @@ async def _execute_workflow(
"status": "completed",
"condition_result": result,
}
_store.update_execution(
await _store.update_execution(
execution.execution_id,
stage_results=execution.stage_results,
)
@ -342,7 +381,7 @@ async def _execute_workflow(
"status": "completed",
"output": {"dry_run": True, "action": stage.action},
}
_store.update_execution(
await _store.update_execution(
execution.execution_id,
stage_results=execution.stage_results,
)
@ -351,7 +390,7 @@ async def _execute_workflow(
"event": "stage_completed",
"execution_id": execution.execution_id,
"stage": stage_name,
})
}, execution_id=execution.execution_id)
except Exception as e:
execution.stage_results[stage_name] = {
@ -361,7 +400,7 @@ async def _execute_workflow(
execution.status = "failed"
execution.error = f"阶段 '{stage_name}' 执行失败: {e}"
execution.completed_at = datetime.now(timezone.utc).isoformat()
_store.update_execution(
await _store.update_execution(
execution.execution_id,
status="failed",
error=execution.error,
@ -373,13 +412,13 @@ async def _execute_workflow(
"execution_id": execution.execution_id,
"stage": stage_name,
"error": str(e),
})
}, execution_id=execution.execution_id)
return
execution.status = "completed"
execution.completed_at = datetime.now(timezone.utc).isoformat()
execution.current_stage = None
_store.update_execution(
await _store.update_execution(
execution.execution_id,
status="completed",
completed_at=execution.completed_at,
@ -388,7 +427,7 @@ async def _execute_workflow(
await _broadcast_ws({
"event": "execution_completed",
"execution_id": execution.execution_id,
})
}, execution_id=execution.execution_id)
_SAFE_VAR_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
@ -455,16 +494,25 @@ def _evaluate_condition(expression: str, variables: dict[str, Any]) -> bool:
raise ValueError(f"Invalid condition expression: {expression}")
async def _broadcast_ws(message: dict[str, Any]) -> None:
"""Broadcast a message to all WebSocket subscribers."""
async def _broadcast_ws(message: dict[str, Any], execution_id: str | None = None) -> None:
"""Broadcast a message to WebSocket subscribers for a specific execution."""
async with _ws_subscribers_lock:
targets = set()
if execution_id and execution_id in _ws_subscribers:
targets = set(_ws_subscribers[execution_id]) # snapshot
disconnected = []
for ws in _ws_subscribers:
for ws in targets:
try:
await ws.send_json(message)
except Exception:
disconnected.append(ws)
for ws in disconnected:
_ws_subscribers.remove(ws)
if disconnected:
async with _ws_subscribers_lock:
for ws in disconnected:
for eid in list(_ws_subscribers.keys()):
_ws_subscribers[eid].discard(ws)
if not _ws_subscribers[eid]:
del _ws_subscribers[eid]
# ---------------------------------------------------------------------------
@ -494,7 +542,7 @@ async def create_workflow(request: Request, body: CreateWorkflowRequest, _auth:
variables_schema=body.variables_schema,
output_schema=body.output_schema,
)
saved = store.save(workflow)
saved = await store.save(workflow)
return saved.model_dump()
@ -527,7 +575,7 @@ async def update_workflow(
existing.variables_schema = body.variables_schema
existing.output_schema = body.output_schema
existing.version += 1
saved = store.save(existing)
saved = await store.save(existing)
return saved.model_dump()
@ -535,7 +583,7 @@ async def update_workflow(
async def delete_workflow(request: Request, workflow_id: str, _auth: None = Depends(_verify_api_key)):
"""Delete a workflow."""
store = _get_store(request)
deleted = store.delete(workflow_id)
deleted = await store.delete(workflow_id)
if not deleted:
raise HTTPException(status_code=404, detail=f"工作流 '{workflow_id}' 不存在")
return {"message": "已删除"}
@ -551,14 +599,13 @@ async def execute_workflow(
if workflow is None:
raise HTTPException(status_code=404, detail=f"工作流 '{workflow_id}' 不存在")
execution = store.create_execution(workflow_id)
execution = await store.create_execution(workflow_id)
execution.variables = body.variables
# Start execution in background
task = asyncio.create_task(
_execute_workflow(workflow, execution, body.variables, store=store)
)
store._running_tasks = getattr(store, "_running_tasks", {})
store._running_tasks[execution.execution_id] = task
task.add_done_callback(lambda t: store._running_tasks.pop(execution.execution_id, None))
@ -593,51 +640,60 @@ async def approve_execution(
raise HTTPException(
status_code=404, detail=f"执行记录 '{execution_id}' 不存在"
)
if execution.status != "paused":
raise HTTPException(
status_code=400, detail="当前执行状态不是等待审批"
)
if body.approved:
if execution.current_stage:
execution.stage_results[execution.current_stage] = {
"status": "approved",
"approver": "user",
"comment": body.comment,
}
execution.status = "running"
store.update_execution(
execution.execution_id,
status="running",
stage_results=execution.stage_results,
)
# Resume the waiting execution by setting the approval event
stage_name = execution.current_stage
if stage_name:
event_key = f"{execution_id}:{stage_name}"
if event_key in store._approval_events:
store._approval_events[event_key].set()
else:
execution.status = "cancelled"
execution.completed_at = datetime.now(timezone.utc).isoformat()
if execution.current_stage:
execution.stage_results[execution.current_stage] = {
"status": "rejected",
"approver": "user",
"comment": body.comment,
}
store.update_execution(
execution.execution_id,
status="cancelled",
completed_at=execution.completed_at,
stage_results=execution.stage_results,
)
# Set the approval event so the waiting coroutine can observe the cancelled state
stage_name = execution.current_stage
if stage_name:
event_key = f"{execution_id}:{stage_name}"
if event_key in store._approval_events:
store._approval_events[event_key].set()
exec_lock = store.get_execution_lock(execution_id)
async with exec_lock:
# Re-fetch execution after acquiring lock
execution = store.get_execution(execution_id)
if execution is None:
raise HTTPException(
status_code=404, detail=f"执行记录 '{execution_id}' 不存在"
)
if execution.status != "paused":
raise HTTPException(
status_code=400, detail="当前执行状态不是等待审批"
)
if body.approved:
if execution.current_stage:
execution.stage_results[execution.current_stage] = {
"status": "approved",
"approver": "user",
"comment": body.comment,
}
execution.status = "running"
await store.update_execution(
execution.execution_id,
status="running",
stage_results=execution.stage_results,
)
# Resume the waiting execution by setting the approval event
stage_name = execution.current_stage
if stage_name:
event_key = f"{execution_id}:{stage_name}"
if event_key in store._approval_events:
store._approval_events[event_key].set()
else:
execution.status = "cancelled"
execution.completed_at = datetime.now(timezone.utc).isoformat()
if execution.current_stage:
execution.stage_results[execution.current_stage] = {
"status": "rejected",
"approver": "user",
"comment": body.comment,
}
await store.update_execution(
execution.execution_id,
status="cancelled",
completed_at=execution.completed_at,
stage_results=execution.stage_results,
)
# Set the approval event so the waiting coroutine can observe the cancelled state
stage_name = execution.current_stage
if stage_name:
event_key = f"{execution_id}:{stage_name}"
if event_key in store._approval_events:
store._approval_events[event_key].set()
return execution.model_dump()
@ -651,23 +707,33 @@ async def cancel_execution(request: Request, execution_id: str, _auth: None = De
raise HTTPException(
status_code=404, detail=f"执行记录 '{execution_id}' 不存在"
)
if execution.status not in ("running", "paused", "pending"):
raise HTTPException(
status_code=400, detail="当前执行状态无法取消"
)
execution.status = "cancelled"
execution.completed_at = datetime.now(timezone.utc).isoformat()
store.update_execution(
execution.execution_id,
status="cancelled",
completed_at=execution.completed_at,
)
# Set any pending approval event so a paused workflow can observe the cancelled state
if hasattr(execution, "current_stage") and execution.current_stage:
event_key = f"{execution_id}:{execution.current_stage}"
if event_key in store._approval_events:
store._approval_events[event_key].set()
exec_lock = store.get_execution_lock(execution_id)
async with exec_lock:
# Re-fetch execution after acquiring lock
execution = store.get_execution(execution_id)
if execution is None:
raise HTTPException(
status_code=404, detail=f"执行记录 '{execution_id}' 不存在"
)
if execution.status not in ("running", "paused", "pending"):
raise HTTPException(
status_code=400, detail="当前执行状态无法取消"
)
execution.status = "cancelled"
execution.completed_at = datetime.now(timezone.utc).isoformat()
await store.update_execution(
execution.execution_id,
status="cancelled",
completed_at=execution.completed_at,
)
# Set any pending approval event so a paused workflow can observe the cancelled state
if hasattr(execution, "current_stage") and execution.current_stage:
event_key = f"{execution_id}:{execution.current_stage}"
if event_key in store._approval_events:
store._approval_events[event_key].set()
return execution.model_dump()
@ -683,21 +749,37 @@ async def workflow_websocket(websocket: WebSocket):
if configured_api_key:
provided = websocket.query_params.get("api_key")
if provided != configured_api_key:
await websocket.accept()
await websocket.send_json(
{"event": "error", "data": {"message": "Invalid or missing api_key"}}
)
if not hmac.compare_digest((provided or "").encode(), configured_api_key.encode()):
await websocket.close(code=4001, reason="Invalid or missing api_key")
return
await websocket.accept()
_ws_subscribers.append(websocket)
# Determine execution_id from query params; may be None for backward compat
execution_id = websocket.query_params.get("execution_id")
subscribed_execution_id: str | None = None
if execution_id:
await _ws_subscribe(execution_id, websocket)
subscribed_execution_id = execution_id
try:
while True:
try:
raw = await asyncio.wait_for(websocket.receive_text(), timeout=120.0)
# Handle subscription messages
try:
msg = json.loads(raw)
if isinstance(msg, dict) and msg.get("type") == "subscribe":
new_eid = msg.get("execution_id")
if new_eid:
# Unsubscribe from previous if any
if subscribed_execution_id:
await _ws_unsubscribe(subscribed_execution_id, websocket)
await _ws_subscribe(new_eid, websocket)
subscribed_execution_id = new_eid
except (json.JSONDecodeError, KeyError):
pass
except asyncio.TimeoutError:
await websocket.close(code=1000, reason="Heartbeat timeout")
return
@ -707,5 +789,5 @@ async def workflow_websocket(websocket: WebSocket):
except Exception as e:
logger.error(f"Workflow WebSocket error: {e}")
finally:
if websocket in _ws_subscribers:
_ws_subscribers.remove(websocket)
if subscribed_execution_id:
await _ws_unsubscribe(subscribed_execution_id, websocket)

View File

@ -7,6 +7,7 @@
from __future__ import annotations
import asyncio
import base64
import logging
from typing import Any, Callable, Awaitable
@ -95,11 +96,14 @@ class ComputerUseTool(Tool):
self._max_retries = max_retries
self._request_timeout = request_timeout
self._http_client: httpx.AsyncClient | None = None
self._client_lock = asyncio.Lock()
def _get_http_client(self) -> httpx.AsyncClient:
async def _get_http_client(self) -> httpx.AsyncClient:
"""Get or create a persistent httpx.AsyncClient for connection reuse."""
if self._http_client is None or self._http_client.is_closed:
self._http_client = httpx.AsyncClient(timeout=self._request_timeout)
async with self._client_lock:
if self._http_client is None or self._http_client.is_closed:
self._http_client = httpx.AsyncClient(timeout=self._request_timeout)
return self._http_client
async def close(self) -> None:
@ -391,7 +395,7 @@ class ComputerUseTool(Tool):
"content-type": "application/json",
}
client = self._get_http_client()
client = await self._get_http_client()
response = await client.post(
self._api_base_url,
json=request_body,

View File

@ -66,40 +66,33 @@ _SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
"docker images",
)
# 危险命令模式:这些命令需要人工确认
_DANGEROUS_PATTERNS: tuple[str, ...] = (
"rm ",
"rm -",
"rmdir",
"mkfs",
"dd ",
"format",
"del ",
"erase",
"> /dev/",
"shutdown",
"reboot",
"init 0",
"init 6",
"kill -9",
"killall",
"chmod 777",
"chown",
"mv /",
"pip uninstall",
"npm uninstall",
"apt remove",
"yum remove",
"brew uninstall",
"docker rm",
"docker rmi",
"git push --force",
"git reset --hard",
"git clean -f",
"drop table",
"drop database",
"truncate",
)
# 危险命令检测 — 基于精确 token 匹配,避免子串误判
# 总是危险的二进制命令(无论参数)
_DANGEROUS_BINARIES: frozenset[str] = frozenset({
"rm", "rmdir", "mkfs", "dd", "format", "shutdown", "reboot",
"halt", "killall", "chown", "fdisk", "parted",
})
# 需要特定参数才危险的二进制命令binary → 危险 flag/子命令集合
_DANGEROUS_BINARY_FLAGS: dict[str, set[str]] = {
"rm": {"-rf", "-fr", "-r", "-f"},
"kill": {"-9", "-kill"},
"chmod": {"777", "000"},
"git": {"push --force", "push -f", "reset --hard", "clean -f"},
"pip": {"uninstall"},
"npm": {"uninstall"},
"docker": {"rm", "rmi", "system prune"},
}
# 跨 token 的危险模式(编译后的正则)
_DANGEROUS_ARG_PATTERNS: list[re.Pattern[str]] = [
re.compile(r">\s*/dev/", re.IGNORECASE),
re.compile(r">\s*/etc/", re.IGNORECASE),
re.compile(r"drop\s+table", re.IGNORECASE),
re.compile(r"drop\s+database", re.IGNORECASE),
re.compile(r"truncate\s+table", re.IGNORECASE),
]
_SHELL_OPERATORS = re.compile(r'[|;&]|&&|\|\||\$\(|\$\{|`|\$<|>|<|\n')
@ -300,10 +293,15 @@ class ShellTool(Tool):
timeout=timeout,
)
except asyncio.TimeoutError:
proc.kill()
try:
proc.kill()
except ProcessLookupError:
logger.debug("Process already exited before kill()")
except OSError:
logger.debug("OSError killing process")
await proc.wait()
output = f"命令执行超时({timeout}s"
exit_code = -1
exit_code = proc.returncode if proc.returncode is not None else -1
else:
output = stdout.decode("utf-8", errors="replace") if stdout else ""
exit_code = proc.returncode if proc.returncode is not None else 0
@ -394,10 +392,24 @@ class ShellTool(Tool):
if binary.lower() == prefix_stripped:
return False
# Dangerous pattern check
# Dangerous pattern check — token-based matching
binary_lower = binary.lower()
# 1. Binary is always dangerous regardless of flags
if binary_lower in _DANGEROUS_BINARIES:
return True
# 2. Binary is dangerous with specific flags/subcommands
if binary_lower in _DANGEROUS_BINARY_FLAGS:
cmd_str = " ".join(tokens).lower()
for flag_pattern in _DANGEROUS_BINARY_FLAGS[binary_lower]:
if flag_pattern in cmd_str:
return True
# 3. Cross-token dangerous patterns (regex)
command_lower = command_stripped.lower()
for pattern in _DANGEROUS_PATTERNS:
if pattern in command_lower:
for pattern in _DANGEROUS_ARG_PATTERNS:
if pattern.search(command_lower):
return True
return True # Unknown commands are dangerous by default

View File

@ -1,8 +1,16 @@
"""Security utilities for URL validation."""
import ipaddress
import socket
from urllib.parse import urlparse
_blocked_hostnames = {
"localhost",
"metadata.google.internal",
"metadata.internal",
"metadata.azure.com",
}
def is_safe_url(url: str) -> bool:
"""Check if URL is safe (not pointing to private/internal networks).
@ -20,13 +28,6 @@ def is_safe_url(url: str) -> bool:
hostname = parsed.hostname
if not hostname:
return False
# Block known internal/metadata hostnames
_blocked_hostnames = {
"localhost",
"metadata.google.internal",
"metadata.internal",
"metadata.azure.com",
}
hostname_lower = hostname.lower()
if hostname_lower in _blocked_hostnames:
return False
@ -37,9 +38,15 @@ def is_safe_url(url: str) -> bool:
if _is_unsafe_ip(ip):
return False
except ValueError:
# hostname is a domain, not a literal IP — DNS rebinding risk remains
# (would need DNS resolution to fully mitigate)
pass
# hostname is a domain — resolve DNS and check IPs
try:
addr_infos = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
for family, type_, proto, canonname, sockaddr in addr_infos:
ip = ipaddress.ip_address(sockaddr[0])
if _is_unsafe_ip(ip):
return False
except (socket.gaierror, socket.timeout, OSError):
return False # fail-closed on DNS errors
return True
except Exception:
return False
@ -58,3 +65,36 @@ def _is_unsafe_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
if mapped.is_private or mapped.is_loopback or mapped.is_reserved or mapped.is_link_local:
return True
return False
async def is_safe_url_async(url: str) -> bool:
"""Async variant of is_safe_url with DNS resolution."""
try:
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
return False
hostname = parsed.hostname
if not hostname:
return False
hostname_lower = hostname.lower()
if hostname_lower in _blocked_hostnames:
return False
if hostname_lower.endswith((".internal", ".local", ".localhost")):
return False
try:
ip = ipaddress.ip_address(hostname)
if _is_unsafe_ip(ip):
return False
except ValueError:
try:
import asyncio
addr_infos = await asyncio.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
for family, type_, proto, canonname, sockaddr in addr_infos:
ip = ipaddress.ip_address(sockaddr[0])
if _is_unsafe_ip(ip):
return False
except (socket.gaierror, socket.timeout, OSError):
return False
return True
except Exception:
return False

View File

@ -2,6 +2,8 @@
from __future__ import annotations
import asyncio
import pytest
from fastapi.testclient import TestClient
@ -119,7 +121,7 @@ class TestWorkflowStore:
store = WorkflowStore()
wf = WorkflowDefinition(workflow_id="test-1", name="Test")
store.save(wf)
asyncio.run(store.save(wf))
result = store.get("test-1")
assert result is not None
assert result.name == "Test"
@ -131,36 +133,45 @@ class TestWorkflowStore:
def test_list(self):
from agentkit.orchestrator.workflow_schema import WorkflowDefinition
store = WorkflowStore()
for i in range(3):
store.save(WorkflowDefinition(workflow_id=f"wf-{i}", name=f"Workflow {i}"))
summaries = store.list()
assert len(summaries) == 3
async def _run():
store = WorkflowStore()
for i in range(3):
await store.save(WorkflowDefinition(workflow_id=f"wf-{i}", name=f"Workflow {i}"))
summaries = store.list()
assert len(summaries) == 3
asyncio.run(_run())
def test_list_limit(self):
from agentkit.orchestrator.workflow_schema import WorkflowDefinition
store = WorkflowStore()
for i in range(5):
store.save(WorkflowDefinition(workflow_id=f"wf-{i}", name=f"Workflow {i}"))
summaries = store.list(limit=2)
assert len(summaries) == 2
async def _run():
store = WorkflowStore()
for i in range(5):
await store.save(WorkflowDefinition(workflow_id=f"wf-{i}", name=f"Workflow {i}"))
summaries = store.list(limit=2)
assert len(summaries) == 2
asyncio.run(_run())
def test_delete(self):
from agentkit.orchestrator.workflow_schema import WorkflowDefinition
store = WorkflowStore()
store.save(WorkflowDefinition(workflow_id="del-1", name="Delete Me"))
assert store.delete("del-1") is True
assert store.get("del-1") is None
async def _run():
store = WorkflowStore()
await store.save(WorkflowDefinition(workflow_id="del-1", name="Delete Me"))
assert await store.delete("del-1") is True
assert store.get("del-1") is None
asyncio.run(_run())
def test_delete_not_found(self):
store = WorkflowStore()
assert store.delete("nonexistent") is False
assert asyncio.run(store.delete("nonexistent")) is False
def test_create_and_get_execution(self):
store = WorkflowStore()
execution = store.create_execution("wf-1")
execution = asyncio.run(store.create_execution("wf-1"))
assert execution.workflow_id == "wf-1"
assert execution.status == "pending"
@ -173,18 +184,51 @@ class TestWorkflowStore:
assert store.get_execution("nonexistent") is None
def test_update_execution(self):
store = WorkflowStore()
execution = store.create_execution("wf-1")
updated = store.update_execution(
execution.execution_id, status="running", current_stage="step-1"
)
assert updated.status == "running"
assert updated.current_stage == "step-1"
async def _run():
store = WorkflowStore()
execution = await store.create_execution("wf-1")
updated = await store.update_execution(
execution.execution_id, status="running", current_stage="step-1"
)
assert updated.status == "running"
assert updated.current_stage == "step-1"
asyncio.run(_run())
def test_update_execution_not_found(self):
store = WorkflowStore()
with pytest.raises(KeyError):
store.update_execution("nonexistent", status="running")
asyncio.run(store.update_execution("nonexistent", status="running"))
def test_running_tasks_initialized(self):
store = WorkflowStore()
assert hasattr(store, "_running_tasks")
assert isinstance(store._running_tasks, dict)
assert len(store._running_tasks) == 0
def test_execution_locks_initialized(self):
store = WorkflowStore()
assert hasattr(store, "_execution_locks")
assert isinstance(store._execution_locks, dict)
def test_evict_execution_cleans_approval_events(self):
async def _run():
store = WorkflowStore(max_executions=2)
e1 = await store.create_execution("wf-1")
e2 = await store.create_execution("wf-2")
# Add an approval event for e1
event_key = f"{e1.execution_id}:stage1"
event = asyncio.Event()
store._approval_events[event_key] = event
# Create a third execution to trigger eviction of e1
e3 = await store.create_execution("wf-3")
# e1 should be evicted, and its approval event should be cleaned up
assert store.get_execution(e1.execution_id) is None
assert event_key not in store._approval_events
# The event should have been set (to wake any waiting coroutine)
assert event.is_set()
asyncio.run(_run())
# ---------------------------------------------------------------------------