diff --git a/docs/plans/2026-06-10-018-fix-agentkit-p2-hardening-plan.md b/docs/plans/2026-06-10-018-fix-agentkit-p2-hardening-plan.md new file mode 100644 index 0000000..67b848e --- /dev/null +++ b/docs/plans/2026-06-10-018-fix-agentkit-p2-hardening-plan.md @@ -0,0 +1,287 @@ +# fix: AgentKit P2 Security & Reliability Hardening + +**Status:** active +**Created:** 2026-06-10 +**Origin:** ce-code-review findings on main branch after 7-capabilities merge + +--- + +## Summary + +Fix all remaining P2 issues from the final code review of the `feat/agentkit-7-capabilities` merge. These 16 issues span security (SSRF, auth, injection), reliability (retry, concurrency, resource leaks), and architecture (auth consistency, config reload). All are non-blocking for current functionality but represent real risk in production. + +## Problem Frame + +After merging 7 core capabilities and fixing P0/P1 issues, the codebase has 16 P2 findings that fall into four categories: + +1. **Security gaps** — SSRF DNS rebinding, API key timing attacks, WebSocket auth bypass, approval race conditions, shell pattern matching, .env injection +2. **Reliability gaps** — retry storms, unbounded memory growth, concurrent modification, config change loss +3. **Architecture gaps** — dual auth layer inconsistency, WebSocket broadcast without isolation +4. 
**Defensive gaps** — process kill race, lazy init race, httpx client double-creation + +## Requirements + +- R1: All P2 security findings must be resolved or explicitly deferred with documented rationale +- R2: Retry logic must use exponential backoff with jitter +- R3: All in-memory stores must have bounded growth with proper eviction +- R4: Concurrent access to shared mutable state must be protected +- R5: Authentication must be consistent across all endpoints and transports +- R6: No regression in existing test suite (764+ tests must pass) + +## Key Technical Decisions + +- **KTD1: DNS resolution for SSRF** — Use the event loop's `loop.getaddrinfo()` (asyncio has no module-level `getaddrinfo`) rather than blocking `socket.getaddrinfo` to resolve hostnames before IP validation. Fail-closed on DNS errors. Accept that TOCTOU remains between check and connect; full mitigation requires IP pinning at the httpx transport level (deferred). +- **KTD2: Auth consolidation strategy** — Fix per-endpoint auth first (hmac.compare_digest, WebSocket auth-before-accept), then consolidate into middleware in a separate pass. This avoids a big-bang refactor and allows incremental testing. +- **KTD3: WorkflowStore concurrency** — Add `asyncio.Lock` to WorkflowStore, convert sync methods to async. This is the most invasive change but necessary for correctness under concurrent load. +- **KTD4: Shell dangerous patterns** — Replace substring matching with token-based matching using the already-parsed shlex tokens. Separate binary-level and flag-level dangerous patterns. +- **KTD5: WebSocket isolation** — Replace flat subscriber list with `dict[str, set[WebSocket]]` keyed by execution_id. Clients subscribe to specific executions. + +--- + +## Implementation Units + +### U1. API Key Timing Attack & WebSocket Auth-Before-Accept + +**Goal:** Eliminate the timing side channel in API key comparison and fix the WebSocket accept-before-auth pattern. 
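A minimal sketch of the timing-safe comparison this goal refers to, assuming API keys are plain strings. The helper names `api_key_matches` and `any_key_matches` are hypothetical and are not taken from the AgentKit code; only `hmac.compare_digest` is the standard-library call the plan names. A missing key is compared as an empty string, so the caller is assumed to skip the check entirely when no key is configured.

```python
from __future__ import annotations

import hmac


def api_key_matches(provided: str | None, configured: str) -> bool:
    """Compare one provided key against one configured key in constant time."""
    # Encode to bytes: compare_digest raises TypeError for non-ASCII str input.
    return hmac.compare_digest((provided or "").encode(), configured.encode())


def any_key_matches(provided: str | None, valid_keys: list[str]) -> bool:
    """Compare against several keys, evaluating every comparison."""
    candidate = (provided or "").encode()
    matched = False
    for key in valid_keys:
        # "|=" does not short-circuit, so every key is always compared.
        matched |= hmac.compare_digest(candidate, key.encode())
    return matched
```

The multi-key helper is one possible design choice: it trades a few extra comparisons for not revealing, through timing, which position in the key list matched.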
+ +**Requirements:** R5 + +**Dependencies:** None + +**Files:** +- `src/agentkit/server/routes/portal.py` +- `src/agentkit/server/routes/workflows.py` +- `src/agentkit/server/middleware.py` +- `tests/unit/server/test_portal_routes.py` +- `tests/unit/server/test_workflow_routes.py` + +**Approach:** +1. Add `import hmac` to portal.py, workflows.py, middleware.py +2. Replace all `!=` API key comparisons with `hmac.compare_digest()` +3. In both WebSocket handlers, move auth validation BEFORE `websocket.accept()`. On auth failure, call `await websocket.close(code=4001)` without accepting, then return +4. In middleware.py, replace `provided_key not in valid_keys` with `any(hmac.compare_digest(...) for k in valid_keys)` +5. Handle None/empty provided key by defaulting to empty string before comparison + +**Test scenarios:** +- API key comparison with correct key returns True +- API key comparison with incorrect key returns False +- API key comparison with None provided defaults to empty string +- WebSocket connection without api_key is rejected (no 101 upgrade) +- WebSocket connection with wrong api_key is rejected +- WebSocket connection with correct api_key is accepted + +**Verification:** All existing tests pass; new auth tests pass + +--- + +### U2. Shell Dangerous Pattern Token-Based Matching + +**Goal:** Replace substring-based dangerous pattern matching with precise token-based matching to eliminate false positives and bypasses. + +**Requirements:** R1 + +**Dependencies:** None + +**Files:** +- `src/agentkit/tools/shell.py` +- `tests/unit/test_shell_tool.py` + +**Approach:** +1. Replace `_DANGEROUS_PATTERNS` tuple with two new structures: + - `_DANGEROUS_BINARIES`: set of binary names that are always dangerous (e.g., `rm`, `mkfs`, `dd`, `shutdown`) + - `_DANGEROUS_BINARY_FLAGS`: dict mapping binary name to set of dangerous flag combinations (e.g., `kill: {-9, -KILL}`, `chmod: {777}`) +2. 
In `_is_dangerous()`, after shlex parsing: + - Check `binary` against `_DANGEROUS_BINARIES` (exact match) + - Check `binary` + flags against `_DANGEROUS_BINARY_FLAGS` + - Keep regex-based arg patterns for cross-token matches (e.g., `> /dev/`, `drop table`) +3. Remove overly broad patterns like `format`, `erase` as substrings; add `format` as binary-only +4. Normalize whitespace in command before matching + +**Test scenarios:** +- `rm -rf /` detected as dangerous (binary match) +- `echo 'performing rm operations'` NOT detected as dangerous (no binary match) +- `kill -9 1234` detected as dangerous (binary + flag match) +- `kill -15 1234` NOT detected as dangerous (flag not in dangerous set) +- `rm\t/tmp/file` detected as dangerous (tab normalization) +- `format C:` detected as dangerous (binary match) +- `informatica` NOT detected as dangerous (not a binary match) + +**Verification:** Shell tool tests pass with new pattern logic + +--- + +### U3. .env File Prefix Filtering & Retry Backoff + +**Goal:** Prevent arbitrary environment variable injection from .env files and add exponential backoff to plan executor retries. + +**Requirements:** R1, R2 + +**Dependencies:** None + +**Files:** +- `src/agentkit/server/app.py` +- `src/agentkit/core/plan_executor.py` +- `tests/unit/test_plan_executor.py` + +**Approach:** +1. In app.py .env loading, add an allowlist of prefixes: `AGENTKIT_`, `OPENAI_`, `ANTHROPIC_`, `GEMINI_`, `TAVILY_`, `SERPER_`, plus exact matches for `DATABASE_URL`, `REDIS_URL` +2. Log a warning when a .env variable is skipped due to prefix filtering +3. In plan_executor.py `_execute_step_with_retry()`, add exponential backoff with jitter after each failure: + - `base_retry_delay=1.0s`, `max_retry_delay=30.0s` + - `delay = min(base * 2^(retry-1), max) * (0.5 + random() * 0.5)` + - Skip delay on first attempt (retry_count == 0) +4. 
Add `base_retry_delay` and `max_retry_delay` as constructor parameters with defaults + +**Test scenarios:** +- .env with `AGENTKIT_FOO=bar` is loaded +- .env with `PATH=/malicious` is skipped with warning +- .env with `OPENAI_API_KEY=sk-xxx` is loaded +- Retry backoff: first retry waits 0.5-1s, second 1-2s, third 2-4s (jitter scales each delay by a factor in [0.5, 1.0)) +- Max retry delay caps at 30s +- Plan executor with max_retries=0 still works (no backoff) + +**Verification:** Unit tests pass; .env loading respects prefix filter + +--- + +### U4. WorkflowStore Concurrency & Resource Management + +**Goal:** Make WorkflowStore safe under concurrent coroutine access, fix unbounded growth, add proper cleanup for approval events and running tasks. + +**Requirements:** R3, R4 + +**Dependencies:** U1 (hmac.compare_digest in workflow routes) + +**Files:** +- `src/agentkit/server/routes/workflows.py` +- `tests/unit/server/test_workflow_routes.py` + +**Approach:** +1. Add `asyncio.Lock` to WorkflowStore; convert `save`, `delete`, `create_execution`, `update_execution` to async methods +2. Initialize `_running_tasks` in `__init__` instead of lazy `getattr` +3. Add `_evict_execution()` helper that removes execution AND its approval events (with `event.set()` to wake waiting coroutines) +4. Replace `_ws_subscribers: list` with `_ws_subscribers: dict[str, set[WebSocket]]` keyed by execution_id, protected by `asyncio.Lock` +5. Add `_ws_subscribe(execution_id, ws)` and `_ws_unsubscribe(execution_id, ws)` helpers +6. Modify `_broadcast_ws` to accept `execution_id` parameter and only send to relevant subscribers +7. Update all call sites from sync to async (all endpoint functions already async) +8. 
Add `_execution_locks: dict[str, asyncio.Lock]` for per-execution approve/cancel serialization + +**Test scenarios:** +- Concurrent save operations don't cause KeyError or over-eviction +- Evicted execution's approval events are cleaned up +- WebSocket subscriber receives only messages for subscribed execution_id +- WebSocket disconnect properly cleans up from subscriber dict +- Concurrent approve/cancel for same execution serialized by lock +- _running_tasks initialized in __init__, no getattr fallback + +**Verification:** Workflow route tests pass; no RuntimeError from concurrent modification + +--- + +### U5. SSRF DNS Resolution & ComputerUse Client Safety + +**Goal:** Add DNS resolution to SSRF protection and make ComputerUseTool's httpx client creation safe under concurrent coroutine access. + +**Requirements:** R1 + +**Dependencies:** None + +**Files:** +- `src/agentkit/utils/security.py` +- `src/agentkit/tools/computer_use.py` +- `tests/unit/test_security.py` +- `tests/unit/tools/test_computer_use.py` + +**Approach:** +1. In `is_safe_url()`, after the ValueError catch for domain names, add DNS resolution: + - Use `socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)` (sync, acceptable for tool init) + - Validate each resolved IP against `_is_unsafe_ip()` + - Fail-closed on DNS errors (return False) + - Add 3-second timeout on DNS resolution +2. Add `is_safe_url_async()` variant using the event loop's `loop.getaddrinfo()` for async callers +3. In ComputerUseTool, add `asyncio.Lock` to `_get_http_client()`, make it async with double-checked locking +4. 
Update caller from `client = self._get_http_client()` to `client = await self._get_http_client()` + +**Test scenarios:** +- is_safe_url with domain resolving to 127.0.0.1 returns False +- is_safe_url with domain resolving to public IP returns True +- is_safe_url with unresolvable domain returns False (fail-closed) +- is_safe_url with ::ffff:127.0.0.1 returns False +- is_safe_url with .internal TLD returns False +- ComputerUseTool concurrent _get_http_client returns same client instance +- ComputerUseTool close() properly cleans up client + +**Verification:** Security and computer_use tests pass + +--- + +### U6. Config Hot-Reload & Defensive Fixes + +**Goal:** Fix config hot-reload lock race, add proc.kill() error handling, and initialize config reload lock eagerly. + +**Requirements:** R4 + +**Dependencies:** None + +**Files:** +- `src/agentkit/server/app.py` +- `src/agentkit/tools/shell.py` + +**Approach:** +1. In app.py `_on_config_change()`: + - Initialize lock eagerly in `create_app()` (store on `app.state._config_reload_lock`) + - Replace `lock.locked()` skip with pending-flag pattern: set `app.state._config_reload_pending = True`, then in `_reload()` loop while pending flag is set + - This ensures no config change is silently dropped +2. 
In shell.py `_execute_standalone()`: + - Wrap `proc.kill()` in `try/except (ProcessLookupError, OSError)` + - Use `proc.returncode` if available instead of hardcoded -1 + +**Test scenarios:** +- Two rapid config changes both get applied (no silent drop) +- Config reload pending flag is cleared after reload completes +- Shell timeout with already-exited process returns clean error (no ProcessLookupError) +- Shell timeout output still shows timeout message + +**Verification:** App startup and shell tool tests pass + +--- + +## Scope Boundaries + +### In Scope +- All 16 P2 findings from ce-code-review +- Test coverage for new security/reliability patterns + +### Deferred to Follow-Up Work +- Full auth middleware consolidation (ASGI-level WebSocket middleware) — complex, separate PR +- IP pinning at httpx transport level for complete SSRF protection +- python-dotenv migration for robust .env parsing +- clients.yaml caching in middleware (per-request disk read) +- WorkflowStore persistence (currently in-memory only) + +### Outside This Plan's Identity +- New features or capability additions +- Performance optimization beyond what's needed for reliability +- UI/UX changes + +--- + +## Risks & Mitigations + +| Risk | Impact | Mitigation | +|------|--------|------------| +| WorkflowStore async refactor breaks existing endpoints | High | Update all call sites; comprehensive test coverage | +| DNS resolution adds latency to URL validation | Low | 3s timeout; acceptable for tool initialization | +| Token-based shell matching misses new bypass patterns | Medium | Default-to-dangerous fallback; add patterns as discovered | +| .env prefix filtering breaks existing deployments | Medium | Comprehensive allowlist; log warnings for skipped vars | +| WebSocket isolation breaks existing WS clients | Medium | Maintain backward compat: accept without subscription, send all | +| Config reload pending loop runs indefinitely | Low | Coalescing naturally limits; only re-runs if new change 
arrived | + +--- + +## System-Wide Impact + +- **Security posture** significantly improved: timing attacks, SSRF, auth bypass all addressed +- **Reliability** improved: retry storms prevented, memory bounded, concurrent access safe +- **API compatibility**: WorkflowStore method signatures change from sync to async — all call sites updated +- **WebSocket protocol**: subscription message becomes recommended but not required for backward compat diff --git a/src/agentkit/core/plan_executor.py b/src/agentkit/core/plan_executor.py index 6d76680..4f736e1 100644 --- a/src/agentkit/core/plan_executor.py +++ b/src/agentkit/core/plan_executor.py @@ -14,6 +14,7 @@ from __future__ import annotations import asyncio import logging +import random import time from dataclasses import dataclass, field from enum import Enum @@ -94,6 +95,8 @@ class PlanExecutor: max_retries: int = 2, step_timeout: float = 300.0, max_parallel: int = 5, + base_retry_delay: float = 1.0, + max_retry_delay: float = 30.0, on_step_complete: OnStepCompleteCallback | None = None, on_step_failed: OnStepFailedCallback | None = None, on_human_intervention: OnHumanInterventionCallback | None = None, @@ -104,6 +107,8 @@ class PlanExecutor: max_retries: maximum number of retries after a step fails step_timeout: timeout for a single step (seconds) max_parallel: maximum number of parallel steps + base_retry_delay: base retry delay (seconds) + max_retry_delay: maximum retry delay (seconds) on_step_complete: callback invoked when a step completes on_step_failed: callback invoked when a step fails; returns a FailureAction that decides what happens next on_human_intervention: callback for human intervention @@ -112,6 +117,8 @@ class PlanExecutor: self._max_retries = max_retries self._step_timeout = step_timeout self._max_parallel = max_parallel + self._base_retry_delay = base_retry_delay + self._max_retry_delay = max_retry_delay self._on_step_complete = on_step_complete self._on_step_failed = on_step_failed self._on_human_intervention = on_human_intervention @@ -250,6 +257,15 @@ class PlanExecutor: retry_count += 1 + if retry_count <= self._max_retries: + delay = min( + self._base_retry_delay * (2 ** (retry_count - 1)), + self._max_retry_delay, + ) + 
delay *= (0.5 + random.random() * 0.5) # jitter + logger.info(f"Retrying step '{step.step_id}' in {delay:.1f}s (attempt {retry_count + 1}/{self._max_retries + 1})") + await asyncio.sleep(delay) + # 所有重试耗尽 step.status = PlanStepStatus.FAILED step.error = last_error diff --git a/src/agentkit/server/app.py b/src/agentkit/server/app.py index 03b82ed..9a55efe 100644 --- a/src/agentkit/server/app.py +++ b/src/agentkit/server/app.py @@ -30,6 +30,12 @@ from agentkit.telemetry.setup import setup_telemetry logger = logging.getLogger(__name__) +_ALLOWED_ENV_PREFIXES = ( + 'AGENTKIT_', 'OPENAI_', 'ANTHROPIC_', 'GEMINI_', + 'TAVILY_', 'SERPER_', 'DEEPSEEK_', +) +_ALLOWED_ENV_EXACT = {'DATABASE_URL', 'REDIS_URL'} + def _build_llm_gateway(config: ServerConfig) -> LLMGateway: """Build LLMGateway from ServerConfig, registering all providers.""" @@ -224,70 +230,69 @@ def _on_config_change(app: FastAPI, config: ServerConfig) -> None: Uses a lock to prevent concurrent config reloads from racing. """ - lock: asyncio.Lock = getattr(app.state, "_config_reload_lock", None) - if lock is None: - lock = asyncio.Lock() - app.state._config_reload_lock = lock + lock: asyncio.Lock = app.state._config_reload_lock - if lock.locked(): - logger.warning("Config reload already in progress, skipping") - return + app.state._config_reload_pending = True async def _reload(): + if lock.locked(): + return # Another reload running; it will check pending flag async with lock: - # Increment config version for audit - current_version = getattr(app.state, "config_version", 0) + 1 - app.state.config_version = current_version - logger.info(f"Config change detected (v{current_version}), reloading...") + while getattr(app.state, "_config_reload_pending", False): + app.state._config_reload_pending = False + # Increment config version for audit + current_version = getattr(app.state, "config_version", 0) + 1 + app.state.config_version = current_version + logger.info(f"Config change detected (v{current_version}), 
reloading...") - # Rebuild LLMGateway if llm config changed - try: - new_gateway = _build_llm_gateway(config) - app.state.llm_gateway = new_gateway - # Also update the agent pool's gateway reference + # Rebuild LLMGateway if llm config changed + try: + new_gateway = _build_llm_gateway(config) + app.state.llm_gateway = new_gateway + # Also update the agent pool's gateway reference + if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None: + app.state.agent_pool._llm_gateway = new_gateway + if hasattr(app.state, "intent_router") and app.state.intent_router is not None: + app.state.intent_router._llm_gateway = new_gateway + logger.info(f"LLM Gateway reloaded (config v{current_version})") + except Exception as e: + logger.error(f"Failed to reload LLM Gateway: {e}") + + # Reload skills if skill paths changed + try: + new_skill_registry = _build_skill_registry(config) + # Re-bind tools from the shared tool_registry so skills don't lose their bindings + tool_registry = getattr(app.state, "tool_registry", None) + if tool_registry: + from agentkit.skills.loader import SkillLoader + loader = SkillLoader( + skill_registry=new_skill_registry, + tool_registry=tool_registry, + ) + for skill_path in (config.skill_paths or []): + from pathlib import Path as _P + p = _P(skill_path) + if p.is_dir(): + loader.load_from_directory(str(p)) + elif p.is_file() and p.suffix in (".yaml", ".yml"): + try: + loader.load_from_file(str(p)) + except Exception: + pass + app.state.skill_registry = new_skill_registry + if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None: + app.state.agent_pool._skill_registry = new_skill_registry + logger.info(f"Skills reloaded (config v{current_version})") + except Exception as e: + logger.error(f"Failed to reload skills: {e}") + + # Update config version on all agents if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None: - app.state.agent_pool._llm_gateway = new_gateway - if hasattr(app.state, 
"intent_router") and app.state.intent_router is not None: - app.state.intent_router._llm_gateway = new_gateway - logger.info(f"LLM Gateway reloaded (config v{current_version})") - except Exception as e: - logger.error(f"Failed to reload LLM Gateway: {e}") + for agent in app.state.agent_pool._agents.values(): + if hasattr(agent, "_config_version"): + agent._config_version = current_version - # Reload skills if skill paths changed - try: - new_skill_registry = _build_skill_registry(config) - # Re-bind tools from the shared tool_registry so skills don't lose their bindings - tool_registry = getattr(app.state, "tool_registry", None) - if tool_registry: - from agentkit.skills.loader import SkillLoader - loader = SkillLoader( - skill_registry=new_skill_registry, - tool_registry=tool_registry, - ) - for skill_path in (config.skill_paths or []): - from pathlib import Path as _P - p = _P(skill_path) - if p.is_dir(): - loader.load_from_directory(str(p)) - elif p.is_file() and p.suffix in (".yaml", ".yml"): - try: - loader.load_from_file(str(p)) - except Exception: - pass - app.state.skill_registry = new_skill_registry - if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None: - app.state.agent_pool._skill_registry = new_skill_registry - logger.info(f"Skills reloaded (config v{current_version})") - except Exception as e: - logger.error(f"Failed to reload skills: {e}") - - # Update config version on all agents - if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None: - for agent in app.state.agent_pool._agents.values(): - if hasattr(agent, "_config_version"): - agent._config_version = current_version - - logger.info(f"Config reload complete (v{current_version})") + logger.info(f"Config reload complete (v{current_version})") # Schedule the reload as a task (non-blocking for the watcher thread) try: @@ -327,6 +332,10 @@ def create_app( _key = _key.strip() _val = _val.strip().strip("\"'") if _key and _key not in os.environ: + allowed = 
any(_key.startswith(p) for p in _ALLOWED_ENV_PREFIXES) or _key in _ALLOWED_ENV_EXACT + if not allowed: + logger.warning(f"Skipping .env variable '{_key}' (not in allowed prefixes)") + continue os.environ[_key] = _val server_config = ServerConfig.from_yaml(config_path) app = FastAPI(title="AgentKit Server", version="2.0.0", lifespan=lifespan) diff --git a/src/agentkit/server/middleware.py b/src/agentkit/server/middleware.py index 1e0b85d..7ea0836 100644 --- a/src/agentkit/server/middleware.py +++ b/src/agentkit/server/middleware.py @@ -1,5 +1,6 @@ """Server middleware - Authentication and Rate Limiting""" +import hmac import os import time from collections import defaultdict @@ -75,7 +76,7 @@ class APIKeyAuthMiddleware(BaseHTTPMiddleware): # Check API key from header provided_key = request.headers.get("X-API-Key") - if not provided_key or provided_key not in valid_keys: + if not provided_key or not any(hmac.compare_digest(provided_key.encode(), k.encode()) for k in valid_keys): return JSONResponse( status_code=401, content={"error": "Unauthorized", "message": "Invalid or missing API key"}, diff --git a/src/agentkit/server/routes/portal.py b/src/agentkit/server/routes/portal.py index e90c8db..7b5f160 100644 --- a/src/agentkit/server/routes/portal.py +++ b/src/agentkit/server/routes/portal.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import hmac import json import logging import uuid @@ -47,7 +48,7 @@ async def _verify_api_key( return provided = api_key_header or api_key_query - if provided != configured_api_key: + if not hmac.compare_digest((provided or "").encode(), configured_api_key.encode()): raise HTTPException( status_code=401, detail="Invalid or missing API key. 
Provide via X-API-Key header or api_key query parameter.", @@ -475,11 +476,7 @@ async def portal_websocket(websocket: WebSocket): # Check api_key query param if configured_api_key: provided = websocket.query_params.get("api_key") - if provided != configured_api_key: - await websocket.accept() - await websocket.send_json( - {"type": "error", "data": {"message": "Invalid or missing api_key"}} - ) + if not hmac.compare_digest((provided or "").encode(), configured_api_key.encode()): await websocket.close(code=4001, reason="Invalid or missing api_key") return diff --git a/src/agentkit/server/routes/workflows.py b/src/agentkit/server/routes/workflows.py index b292233..f250150 100644 --- a/src/agentkit/server/routes/workflows.py +++ b/src/agentkit/server/routes/workflows.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import hmac import json import logging import re @@ -52,7 +53,7 @@ async def _verify_api_key( return provided = api_key_header or api_key_query - if provided != configured_api_key: + if not hmac.compare_digest((provided or "").encode(), configured_api_key.encode()): raise HTTPException( status_code=401, detail="Invalid or missing API key. 
Provide via X-API-Key header or api_key query parameter.", @@ -65,7 +66,7 @@ async def _verify_api_key( class WorkflowStore: - """In-memory workflow store.""" + """In-memory workflow store with async-safe mutation methods.""" def __init__(self, max_workflows: int = 500, max_executions: int = 1000): self._workflows: dict[str, WorkflowDefinition] = {} @@ -73,17 +74,30 @@ class WorkflowStore: self._max_workflows = max_workflows self._max_executions = max_executions self._approval_events: dict[str, asyncio.Event] = {} # key: f"{execution_id}:{stage_name}" + self._running_tasks: dict[str, asyncio.Task] = {} + self._execution_locks: dict[str, asyncio.Lock] = {} + self._lock = asyncio.Lock() - def save(self, workflow: WorkflowDefinition) -> WorkflowDefinition: - workflow.updated_at = datetime.now(timezone.utc).isoformat() - self._workflows[workflow.workflow_id] = workflow - # Evict oldest if over limit - if len(self._workflows) > self._max_workflows: - oldest_id = min( - self._workflows, key=lambda k: self._workflows[k].updated_at - ) - del self._workflows[oldest_id] - return workflow + def _evict_execution(self, execution_id: str) -> None: + """Remove execution and its associated approval events.""" + self._executions.pop(execution_id, None) + keys_to_remove = [k for k in self._approval_events if k.startswith(f"{execution_id}:")] + for k in keys_to_remove: + event = self._approval_events.pop(k, None) + if event is not None: + event.set() # Wake any waiting coroutine + + async def save(self, workflow: WorkflowDefinition) -> WorkflowDefinition: + async with self._lock: + workflow.updated_at = datetime.now(timezone.utc).isoformat() + self._workflows[workflow.workflow_id] = workflow + # Evict oldest if over limit + if len(self._workflows) > self._max_workflows: + oldest_id = min( + self._workflows, key=lambda k: self._workflows[k].updated_at + ) + del self._workflows[oldest_id] + return workflow def get(self, workflow_id: str) -> WorkflowDefinition | None: return 
self._workflows.get(workflow_id) @@ -109,47 +123,72 @@ class WorkflowStore: ) return summaries - def delete(self, workflow_id: str) -> bool: - if workflow_id in self._workflows: - del self._workflows[workflow_id] - return True - return False + async def delete(self, workflow_id: str) -> bool: + async with self._lock: + if workflow_id in self._workflows: + del self._workflows[workflow_id] + return True + return False - def create_execution(self, workflow_id: str) -> WorkflowExecution: - execution = WorkflowExecution( - execution_id=str(uuid.uuid4()), - workflow_id=workflow_id, - status="pending", - started_at=datetime.now(timezone.utc).isoformat(), - ) - self._executions[execution.execution_id] = execution - # Evict oldest if over limit - if len(self._executions) > self._max_executions: - oldest_id = min( - self._executions, - key=lambda k: self._executions[k].started_at or "", + async def create_execution(self, workflow_id: str) -> WorkflowExecution: + async with self._lock: + execution = WorkflowExecution( + execution_id=str(uuid.uuid4()), + workflow_id=workflow_id, + status="pending", + started_at=datetime.now(timezone.utc).isoformat(), ) - del self._executions[oldest_id] - return execution + self._executions[execution.execution_id] = execution + # Evict oldest if over limit + if len(self._executions) > self._max_executions: + oldest_id = min( + self._executions, + key=lambda k: self._executions[k].started_at or "", + ) + self._evict_execution(oldest_id) + return execution def get_execution(self, execution_id: str) -> WorkflowExecution | None: return self._executions.get(execution_id) - def update_execution(self, execution_id: str, **kwargs: Any) -> WorkflowExecution: - execution = self._executions.get(execution_id) - if execution is None: - raise KeyError(f"Execution '{execution_id}' not found") - for key, value in kwargs.items(): - if hasattr(execution, key): - setattr(execution, key, value) - return execution + async def update_execution(self, execution_id: 
str, **kwargs: Any) -> WorkflowExecution: + async with self._lock: + execution = self._executions.get(execution_id) + if execution is None: + raise KeyError(f"Execution '{execution_id}' not found") + for key, value in kwargs.items(): + if hasattr(execution, key): + setattr(execution, key, value) + return execution + + def get_execution_lock(self, execution_id: str) -> asyncio.Lock: + """Get or create a per-execution lock for approve/cancel serialization.""" + if execution_id not in self._execution_locks: + self._execution_locks[execution_id] = asyncio.Lock() + return self._execution_locks[execution_id] # Module-level singleton _workflow_store = WorkflowStore() -# WebSocket subscribers for real-time execution progress -_ws_subscribers: list[WebSocket] = [] +# WebSocket subscribers for real-time execution progress (keyed by execution_id) +_ws_subscribers: dict[str, set[WebSocket]] = {} +_ws_subscribers_lock = asyncio.Lock() + + +async def _ws_subscribe(execution_id: str, ws: WebSocket) -> None: + async with _ws_subscribers_lock: + if execution_id not in _ws_subscribers: + _ws_subscribers[execution_id] = set() + _ws_subscribers[execution_id].add(ws) + + +async def _ws_unsubscribe(execution_id: str, ws: WebSocket) -> None: + async with _ws_subscribers_lock: + if execution_id in _ws_subscribers: + _ws_subscribers[execution_id].discard(ws) + if not _ws_subscribers[execution_id]: + del _ws_subscribers[execution_id] def _get_store(request: Request) -> WorkflowStore: @@ -210,7 +249,7 @@ async def _execute_workflow( """Execute a workflow by running its stages in topological order.""" _store = store or _workflow_store execution.status = "running" - _store.update_execution(execution.execution_id, status="running") + await _store.update_execution(execution.execution_id, status="running") # Topological sort stage_map = {s.name: s for s in workflow.stages} @@ -229,7 +268,7 @@ async def _execute_workflow( execution.status = "failed" execution.error = "循环依赖" execution.completed_at 
= datetime.now(timezone.utc).isoformat() - _store.update_execution( + await _store.update_execution( execution.execution_id, status="failed", error="循环依赖", @@ -246,7 +285,7 @@ async def _execute_workflow( for stage_name in ordered: stage = stage_map[stage_name] execution.current_stage = stage_name - _store.update_execution( + await _store.update_execution( execution.execution_id, current_stage=stage_name, ) @@ -256,7 +295,7 @@ async def _execute_workflow( "event": "stage_started", "execution_id": execution.execution_id, "stage": stage_name, - }) + }, execution_id=execution.execution_id) try: if stage.type == "approval": @@ -267,7 +306,7 @@ async def _execute_workflow( execution.status = "paused" execution.current_stage = stage_name - _store.update_execution( + await _store.update_execution( execution.execution_id, status="paused", current_stage=stage_name, @@ -276,7 +315,7 @@ async def _execute_workflow( "event": "approval_required", "execution_id": execution.execution_id, "stage": stage_name, - }) + }, execution_id=execution.execution_id) # Wait for approval with timeout try: @@ -289,13 +328,13 @@ async def _execute_workflow( "execution_id": execution.execution_id, "stage": stage_name, "error": "Approval rejected", - }) + }, execution_id=execution.execution_id) return # Approval was granted — the /approve endpoint already set stage_results # Only update status to running if not already set if execution.status != "running": execution.status = "running" - _store.update_execution( + await _store.update_execution( execution.execution_id, status="running", ) @@ -308,7 +347,7 @@ async def _execute_workflow( execution.status = "failed" execution.error = f"Approval timeout for stage {stage_name}" execution.completed_at = datetime.now(timezone.utc).isoformat() - _store.update_execution( + await _store.update_execution( execution.execution_id, status="failed", error=execution.error, @@ -320,7 +359,7 @@ async def _execute_workflow( "execution_id": execution.execution_id, 
"stage": stage_name, "error": "Approval timeout", - }) + }, execution_id=execution.execution_id) return finally: _store._approval_events.pop(event_key, None) @@ -332,7 +371,7 @@ async def _execute_workflow( "status": "completed", "condition_result": result, } - _store.update_execution( + await _store.update_execution( execution.execution_id, stage_results=execution.stage_results, ) @@ -342,7 +381,7 @@ async def _execute_workflow( "status": "completed", "output": {"dry_run": True, "action": stage.action}, } - _store.update_execution( + await _store.update_execution( execution.execution_id, stage_results=execution.stage_results, ) @@ -351,7 +390,7 @@ async def _execute_workflow( "event": "stage_completed", "execution_id": execution.execution_id, "stage": stage_name, - }) + }, execution_id=execution.execution_id) except Exception as e: execution.stage_results[stage_name] = { @@ -361,7 +400,7 @@ async def _execute_workflow( execution.status = "failed" execution.error = f"阶段 '{stage_name}' 执行失败: {e}" execution.completed_at = datetime.now(timezone.utc).isoformat() - _store.update_execution( + await _store.update_execution( execution.execution_id, status="failed", error=execution.error, @@ -373,13 +412,13 @@ async def _execute_workflow( "execution_id": execution.execution_id, "stage": stage_name, "error": str(e), - }) + }, execution_id=execution.execution_id) return execution.status = "completed" execution.completed_at = datetime.now(timezone.utc).isoformat() execution.current_stage = None - _store.update_execution( + await _store.update_execution( execution.execution_id, status="completed", completed_at=execution.completed_at, @@ -388,7 +427,7 @@ async def _execute_workflow( await _broadcast_ws({ "event": "execution_completed", "execution_id": execution.execution_id, - }) + }, execution_id=execution.execution_id) _SAFE_VAR_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$') @@ -455,16 +494,25 @@ def _evaluate_condition(expression: str, variables: dict[str, Any]) -> bool: 
raise ValueError(f"Invalid condition expression: {expression}") -async def _broadcast_ws(message: dict[str, Any]) -> None: - """Broadcast a message to all WebSocket subscribers.""" +async def _broadcast_ws(message: dict[str, Any], execution_id: str | None = None) -> None: + """Broadcast a message to WebSocket subscribers for a specific execution.""" + async with _ws_subscribers_lock: + targets = set() + if execution_id and execution_id in _ws_subscribers: + targets = set(_ws_subscribers[execution_id]) # snapshot disconnected = [] - for ws in _ws_subscribers: + for ws in targets: try: await ws.send_json(message) except Exception: disconnected.append(ws) - for ws in disconnected: - _ws_subscribers.remove(ws) + if disconnected: + async with _ws_subscribers_lock: + for ws in disconnected: + for eid in list(_ws_subscribers.keys()): + _ws_subscribers[eid].discard(ws) + if not _ws_subscribers[eid]: + del _ws_subscribers[eid] # --------------------------------------------------------------------------- @@ -494,7 +542,7 @@ async def create_workflow(request: Request, body: CreateWorkflowRequest, _auth: variables_schema=body.variables_schema, output_schema=body.output_schema, ) - saved = store.save(workflow) + saved = await store.save(workflow) return saved.model_dump() @@ -527,7 +575,7 @@ async def update_workflow( existing.variables_schema = body.variables_schema existing.output_schema = body.output_schema existing.version += 1 - saved = store.save(existing) + saved = await store.save(existing) return saved.model_dump() @@ -535,7 +583,7 @@ async def update_workflow( async def delete_workflow(request: Request, workflow_id: str, _auth: None = Depends(_verify_api_key)): """Delete a workflow.""" store = _get_store(request) - deleted = store.delete(workflow_id) + deleted = await store.delete(workflow_id) if not deleted: raise HTTPException(status_code=404, detail=f"工作流 '{workflow_id}' 不存在") return {"message": "已删除"} @@ -551,14 +599,13 @@ async def execute_workflow( if workflow 
is None: raise HTTPException(status_code=404, detail=f"工作流 '{workflow_id}' 不存在") - execution = store.create_execution(workflow_id) + execution = await store.create_execution(workflow_id) execution.variables = body.variables # Start execution in background task = asyncio.create_task( _execute_workflow(workflow, execution, body.variables, store=store) ) - store._running_tasks = getattr(store, "_running_tasks", {}) store._running_tasks[execution.execution_id] = task task.add_done_callback(lambda t: store._running_tasks.pop(execution.execution_id, None)) @@ -593,51 +640,60 @@ async def approve_execution( raise HTTPException( status_code=404, detail=f"执行记录 '{execution_id}' 不存在" ) - if execution.status != "paused": - raise HTTPException( - status_code=400, detail="当前执行状态不是等待审批" - ) - if body.approved: - if execution.current_stage: - execution.stage_results[execution.current_stage] = { - "status": "approved", - "approver": "user", - "comment": body.comment, - } - execution.status = "running" - store.update_execution( - execution.execution_id, - status="running", - stage_results=execution.stage_results, - ) - # Resume the waiting execution by setting the approval event - stage_name = execution.current_stage - if stage_name: - event_key = f"{execution_id}:{stage_name}" - if event_key in store._approval_events: - store._approval_events[event_key].set() - else: - execution.status = "cancelled" - execution.completed_at = datetime.now(timezone.utc).isoformat() - if execution.current_stage: - execution.stage_results[execution.current_stage] = { - "status": "rejected", - "approver": "user", - "comment": body.comment, - } - store.update_execution( - execution.execution_id, - status="cancelled", - completed_at=execution.completed_at, - stage_results=execution.stage_results, - ) - # Set the approval event so the waiting coroutine can observe the cancelled state - stage_name = execution.current_stage - if stage_name: - event_key = f"{execution_id}:{stage_name}" - if event_key in 
store._approval_events: - store._approval_events[event_key].set() + exec_lock = store.get_execution_lock(execution_id) + async with exec_lock: + # Re-fetch execution after acquiring lock + execution = store.get_execution(execution_id) + if execution is None: + raise HTTPException( + status_code=404, detail=f"执行记录 '{execution_id}' 不存在" + ) + if execution.status != "paused": + raise HTTPException( + status_code=400, detail="当前执行状态不是等待审批" + ) + + if body.approved: + if execution.current_stage: + execution.stage_results[execution.current_stage] = { + "status": "approved", + "approver": "user", + "comment": body.comment, + } + execution.status = "running" + await store.update_execution( + execution.execution_id, + status="running", + stage_results=execution.stage_results, + ) + # Resume the waiting execution by setting the approval event + stage_name = execution.current_stage + if stage_name: + event_key = f"{execution_id}:{stage_name}" + if event_key in store._approval_events: + store._approval_events[event_key].set() + else: + execution.status = "cancelled" + execution.completed_at = datetime.now(timezone.utc).isoformat() + if execution.current_stage: + execution.stage_results[execution.current_stage] = { + "status": "rejected", + "approver": "user", + "comment": body.comment, + } + await store.update_execution( + execution.execution_id, + status="cancelled", + completed_at=execution.completed_at, + stage_results=execution.stage_results, + ) + # Set the approval event so the waiting coroutine can observe the cancelled state + stage_name = execution.current_stage + if stage_name: + event_key = f"{execution_id}:{stage_name}" + if event_key in store._approval_events: + store._approval_events[event_key].set() return execution.model_dump() @@ -651,23 +707,33 @@ async def cancel_execution(request: Request, execution_id: str, _auth: None = De raise HTTPException( status_code=404, detail=f"执行记录 '{execution_id}' 不存在" ) - if execution.status not in ("running", "paused", 
"pending"): - raise HTTPException( - status_code=400, detail="当前执行状态无法取消" - ) - execution.status = "cancelled" - execution.completed_at = datetime.now(timezone.utc).isoformat() - store.update_execution( - execution.execution_id, - status="cancelled", - completed_at=execution.completed_at, - ) - # Set any pending approval event so a paused workflow can observe the cancelled state - if hasattr(execution, "current_stage") and execution.current_stage: - event_key = f"{execution_id}:{execution.current_stage}" - if event_key in store._approval_events: - store._approval_events[event_key].set() + exec_lock = store.get_execution_lock(execution_id) + async with exec_lock: + # Re-fetch execution after acquiring lock + execution = store.get_execution(execution_id) + if execution is None: + raise HTTPException( + status_code=404, detail=f"执行记录 '{execution_id}' 不存在" + ) + if execution.status not in ("running", "paused", "pending"): + raise HTTPException( + status_code=400, detail="当前执行状态无法取消" + ) + + execution.status = "cancelled" + execution.completed_at = datetime.now(timezone.utc).isoformat() + await store.update_execution( + execution.execution_id, + status="cancelled", + completed_at=execution.completed_at, + ) + # Set any pending approval event so a paused workflow can observe the cancelled state + if hasattr(execution, "current_stage") and execution.current_stage: + event_key = f"{execution_id}:{execution.current_stage}" + if event_key in store._approval_events: + store._approval_events[event_key].set() + return execution.model_dump() @@ -683,21 +749,37 @@ async def workflow_websocket(websocket: WebSocket): if configured_api_key: provided = websocket.query_params.get("api_key") - if provided != configured_api_key: - await websocket.accept() - await websocket.send_json( - {"event": "error", "data": {"message": "Invalid or missing api_key"}} - ) + if not hmac.compare_digest((provided or "").encode(), configured_api_key.encode()): await websocket.close(code=4001, 
reason="Invalid or missing api_key") return await websocket.accept() - _ws_subscribers.append(websocket) + + # Determine execution_id from query params; may be None for backward compat + execution_id = websocket.query_params.get("execution_id") + subscribed_execution_id: str | None = None + + if execution_id: + await _ws_subscribe(execution_id, websocket) + subscribed_execution_id = execution_id try: while True: try: raw = await asyncio.wait_for(websocket.receive_text(), timeout=120.0) + # Handle subscription messages + try: + msg = json.loads(raw) + if isinstance(msg, dict) and msg.get("type") == "subscribe": + new_eid = msg.get("execution_id") + if new_eid: + # Unsubscribe from previous if any + if subscribed_execution_id: + await _ws_unsubscribe(subscribed_execution_id, websocket) + await _ws_subscribe(new_eid, websocket) + subscribed_execution_id = new_eid + except (json.JSONDecodeError, KeyError): + pass except asyncio.TimeoutError: await websocket.close(code=1000, reason="Heartbeat timeout") return @@ -707,5 +789,5 @@ async def workflow_websocket(websocket: WebSocket): except Exception as e: logger.error(f"Workflow WebSocket error: {e}") finally: - if websocket in _ws_subscribers: - _ws_subscribers.remove(websocket) + if subscribed_execution_id: + await _ws_unsubscribe(subscribed_execution_id, websocket) diff --git a/src/agentkit/tools/computer_use.py b/src/agentkit/tools/computer_use.py index 2a08878..cacfc6c 100644 --- a/src/agentkit/tools/computer_use.py +++ b/src/agentkit/tools/computer_use.py @@ -7,6 +7,7 @@ from __future__ import annotations +import asyncio import base64 import logging from typing import Any, Callable, Awaitable @@ -95,11 +96,14 @@ class ComputerUseTool(Tool): self._max_retries = max_retries self._request_timeout = request_timeout self._http_client: httpx.AsyncClient | None = None + self._client_lock = asyncio.Lock() - def _get_http_client(self) -> httpx.AsyncClient: + async def _get_http_client(self) -> httpx.AsyncClient: """Get or 
create a persistent httpx.AsyncClient for connection reuse.""" if self._http_client is None or self._http_client.is_closed: - self._http_client = httpx.AsyncClient(timeout=self._request_timeout) + async with self._client_lock: + if self._http_client is None or self._http_client.is_closed: + self._http_client = httpx.AsyncClient(timeout=self._request_timeout) return self._http_client async def close(self) -> None: @@ -391,7 +395,7 @@ class ComputerUseTool(Tool): "content-type": "application/json", } - client = self._get_http_client() + client = await self._get_http_client() response = await client.post( self._api_base_url, json=request_body, diff --git a/src/agentkit/tools/shell.py b/src/agentkit/tools/shell.py index 0d9a970..97f35ad 100644 --- a/src/agentkit/tools/shell.py +++ b/src/agentkit/tools/shell.py @@ -66,40 +66,33 @@ _SAFE_COMMAND_PREFIXES: tuple[str, ...] = ( "docker images", ) -# 危险命令模式:这些命令需要人工确认 -_DANGEROUS_PATTERNS: tuple[str, ...] = ( - "rm ", - "rm -", - "rmdir", - "mkfs", - "dd ", - "format", - "del ", - "erase", - "> /dev/", - "shutdown", - "reboot", - "init 0", - "init 6", - "kill -9", - "killall", - "chmod 777", - "chown", - "mv /", - "pip uninstall", - "npm uninstall", - "apt remove", - "yum remove", - "brew uninstall", - "docker rm", - "docker rmi", - "git push --force", - "git reset --hard", - "git clean -f", - "drop table", - "drop database", - "truncate", -) +# 危险命令检测 — 基于精确 token 匹配,避免子串误判 + +# 总是危险的二进制命令(无论参数) +_DANGEROUS_BINARIES: frozenset[str] = frozenset({ + "rm", "rmdir", "mkfs", "dd", "format", "shutdown", "reboot", + "halt", "killall", "chown", "fdisk", "parted", +}) + +# 需要特定参数才危险的二进制命令:binary → 危险 flag/子命令集合 +_DANGEROUS_BINARY_FLAGS: dict[str, set[str]] = { + "rm": {"-rf", "-fr", "-r", "-f"}, + "kill": {"-9", "-kill"}, + "chmod": {"777", "000"}, + "git": {"push --force", "push -f", "reset --hard", "clean -f"}, + "pip": {"uninstall"}, + "npm": {"uninstall"}, + "docker": {"rm", "rmi", "system prune"}, +} + +# 跨 token 的危险模式(编译后的正则) 
+_DANGEROUS_ARG_PATTERNS: list[re.Pattern[str]] = [ + re.compile(r">\s*/dev/", re.IGNORECASE), + re.compile(r">\s*/etc/", re.IGNORECASE), + re.compile(r"drop\s+table", re.IGNORECASE), + re.compile(r"drop\s+database", re.IGNORECASE), + re.compile(r"truncate\s+table", re.IGNORECASE), +] _SHELL_OPERATORS = re.compile(r'[|;&]|&&|\|\||\$\(|\$\{|`|\$<|>|<|\n') @@ -300,10 +293,15 @@ class ShellTool(Tool): timeout=timeout, ) except asyncio.TimeoutError: - proc.kill() + try: + proc.kill() + except ProcessLookupError: + logger.debug("Process already exited before kill()") + except OSError: + logger.debug("OSError killing process") await proc.wait() output = f"命令执行超时({timeout}s)" - exit_code = -1 + exit_code = proc.returncode if proc.returncode is not None else -1 else: output = stdout.decode("utf-8", errors="replace") if stdout else "" exit_code = proc.returncode if proc.returncode is not None else 0 @@ -394,10 +392,24 @@ class ShellTool(Tool): if binary.lower() == prefix_stripped: return False - # Dangerous pattern check + # Dangerous pattern check — token-based matching + binary_lower = binary.lower() + + # 1. Binary is always dangerous regardless of flags + if binary_lower in _DANGEROUS_BINARIES: + return True + + # 2. Binary is dangerous with specific flags/subcommands + if binary_lower in _DANGEROUS_BINARY_FLAGS: + cmd_str = " ".join(tokens).lower() + for flag_pattern in _DANGEROUS_BINARY_FLAGS[binary_lower]: + if flag_pattern in cmd_str: + return True + + # 3. 
Cross-token dangerous patterns (regex) command_lower = command_stripped.lower() - for pattern in _DANGEROUS_PATTERNS: - if pattern in command_lower: + for pattern in _DANGEROUS_ARG_PATTERNS: + if pattern.search(command_lower): return True return True # Unknown commands are dangerous by default diff --git a/src/agentkit/utils/security.py b/src/agentkit/utils/security.py index c832859..8337e0f 100644 --- a/src/agentkit/utils/security.py +++ b/src/agentkit/utils/security.py @@ -1,8 +1,16 @@ """Security utilities for URL validation.""" import ipaddress +import socket from urllib.parse import urlparse +_blocked_hostnames = { + "localhost", + "metadata.google.internal", + "metadata.internal", + "metadata.azure.com", +} + def is_safe_url(url: str) -> bool: """Check if URL is safe (not pointing to private/internal networks). @@ -20,13 +28,6 @@ def is_safe_url(url: str) -> bool: hostname = parsed.hostname if not hostname: return False - # Block known internal/metadata hostnames - _blocked_hostnames = { - "localhost", - "metadata.google.internal", - "metadata.internal", - "metadata.azure.com", - } hostname_lower = hostname.lower() if hostname_lower in _blocked_hostnames: return False @@ -37,9 +38,15 @@ def is_safe_url(url: str) -> bool: if _is_unsafe_ip(ip): return False except ValueError: - # hostname is a domain, not a literal IP — DNS rebinding risk remains - # (would need DNS resolution to fully mitigate) - pass + # hostname is a domain — resolve DNS and check IPs + try: + addr_infos = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM) + for family, type_, proto, canonname, sockaddr in addr_infos: + ip = ipaddress.ip_address(sockaddr[0]) + if _is_unsafe_ip(ip): + return False + except (socket.gaierror, socket.timeout, OSError): + return False # fail-closed on DNS errors return True except Exception: return False @@ -58,3 +65,36 @@ def _is_unsafe_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool: if mapped.is_private or mapped.is_loopback 
or mapped.is_reserved or mapped.is_link_local: return True return False + + +async def is_safe_url_async(url: str) -> bool: + """Async variant of is_safe_url with DNS resolution.""" + try: + parsed = urlparse(url) + if parsed.scheme not in ("http", "https"): + return False + hostname = parsed.hostname + if not hostname: + return False + hostname_lower = hostname.lower() + if hostname_lower in _blocked_hostnames: + return False + if hostname_lower.endswith((".internal", ".local", ".localhost")): + return False + try: + ip = ipaddress.ip_address(hostname) + if _is_unsafe_ip(ip): + return False + except ValueError: + try: + import asyncio + addr_infos = await asyncio.get_running_loop().getaddrinfo(hostname, None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM) + for family, type_, proto, canonname, sockaddr in addr_infos: + ip = ipaddress.ip_address(sockaddr[0]) + if _is_unsafe_ip(ip): + return False + except (socket.gaierror, socket.timeout, OSError): + return False + return True + except Exception: + return False diff --git a/tests/unit/server/test_workflow_routes.py b/tests/unit/server/test_workflow_routes.py index 1231c96..99c9637 100644 --- a/tests/unit/server/test_workflow_routes.py +++ b/tests/unit/server/test_workflow_routes.py @@ -2,6 +2,8 @@ from __future__ import annotations +import asyncio + import pytest from fastapi.testclient import TestClient @@ -119,7 +121,7 @@ class TestWorkflowStore: store = WorkflowStore() wf = WorkflowDefinition(workflow_id="test-1", name="Test") - store.save(wf) + asyncio.run(store.save(wf)) result = store.get("test-1") assert result is not None assert result.name == "Test" @@ -131,36 +133,45 @@ class TestWorkflowStore: def test_list(self): from agentkit.orchestrator.workflow_schema import WorkflowDefinition - store = WorkflowStore() - for i in range(3): - store.save(WorkflowDefinition(workflow_id=f"wf-{i}", name=f"Workflow {i}")) - summaries = store.list() - assert len(summaries) == 3 + async def _run(): + store = WorkflowStore() + for i in range(3): +
await store.save(WorkflowDefinition(workflow_id=f"wf-{i}", name=f"Workflow {i}")) + summaries = store.list() + assert len(summaries) == 3 + + asyncio.run(_run()) def test_list_limit(self): from agentkit.orchestrator.workflow_schema import WorkflowDefinition - store = WorkflowStore() - for i in range(5): - store.save(WorkflowDefinition(workflow_id=f"wf-{i}", name=f"Workflow {i}")) - summaries = store.list(limit=2) - assert len(summaries) == 2 + async def _run(): + store = WorkflowStore() + for i in range(5): + await store.save(WorkflowDefinition(workflow_id=f"wf-{i}", name=f"Workflow {i}")) + summaries = store.list(limit=2) + assert len(summaries) == 2 + + asyncio.run(_run()) def test_delete(self): from agentkit.orchestrator.workflow_schema import WorkflowDefinition - store = WorkflowStore() - store.save(WorkflowDefinition(workflow_id="del-1", name="Delete Me")) - assert store.delete("del-1") is True - assert store.get("del-1") is None + async def _run(): + store = WorkflowStore() + await store.save(WorkflowDefinition(workflow_id="del-1", name="Delete Me")) + assert await store.delete("del-1") is True + assert store.get("del-1") is None + + asyncio.run(_run()) def test_delete_not_found(self): store = WorkflowStore() - assert store.delete("nonexistent") is False + assert asyncio.run(store.delete("nonexistent")) is False def test_create_and_get_execution(self): store = WorkflowStore() - execution = store.create_execution("wf-1") + execution = asyncio.run(store.create_execution("wf-1")) assert execution.workflow_id == "wf-1" assert execution.status == "pending" @@ -173,18 +184,51 @@ class TestWorkflowStore: assert store.get_execution("nonexistent") is None def test_update_execution(self): - store = WorkflowStore() - execution = store.create_execution("wf-1") - updated = store.update_execution( - execution.execution_id, status="running", current_stage="step-1" - ) - assert updated.status == "running" - assert updated.current_stage == "step-1" + async def _run(): + store 
= WorkflowStore() + execution = await store.create_execution("wf-1") + updated = await store.update_execution( + execution.execution_id, status="running", current_stage="step-1" + ) + assert updated.status == "running" + assert updated.current_stage == "step-1" + + asyncio.run(_run()) def test_update_execution_not_found(self): store = WorkflowStore() with pytest.raises(KeyError): - store.update_execution("nonexistent", status="running") + asyncio.run(store.update_execution("nonexistent", status="running")) + + def test_running_tasks_initialized(self): + store = WorkflowStore() + assert hasattr(store, "_running_tasks") + assert isinstance(store._running_tasks, dict) + assert len(store._running_tasks) == 0 + + def test_execution_locks_initialized(self): + store = WorkflowStore() + assert hasattr(store, "_execution_locks") + assert isinstance(store._execution_locks, dict) + + def test_evict_execution_cleans_approval_events(self): + async def _run(): + store = WorkflowStore(max_executions=2) + e1 = await store.create_execution("wf-1") + e2 = await store.create_execution("wf-2") + # Add an approval event for e1 + event_key = f"{e1.execution_id}:stage1" + event = asyncio.Event() + store._approval_events[event_key] = event + # Create a third execution to trigger eviction of e1 + e3 = await store.create_execution("wf-3") + # e1 should be evicted, and its approval event should be cleaned up + assert store.get_execution(e1.execution_id) is None + assert event_key not in store._approval_events + # The event should have been set (to wake any waiting coroutine) + assert event.is_set() + + asyncio.run(_run()) # ---------------------------------------------------------------------------