fix(security,reliability): resolve all P2 findings from code review

chiguyong 2026-06-10 15:05:40 +08:00
parent 658e188939
commit 6852dfe892
10 changed files with 770 additions and 278 deletions

@@ -0,0 +1,287 @@
# fix: AgentKit P2 Security & Reliability Hardening
**Status:** active
**Created:** 2026-06-10
**Origin:** ce-code-review findings on main branch after 7-capabilities merge
---
## Summary
Fix all remaining P2 issues from the final code review of the `feat/agentkit-7-capabilities` merge. These 16 issues span security (SSRF, auth, injection), reliability (retry, concurrency, resource leaks), and architecture (auth consistency, config reload). All are non-blocking for current functionality but represent real risk in production.
## Problem Frame
After merging 7 core capabilities and fixing P0/P1 issues, the codebase has 16 P2 findings that fall into four categories:
1. **Security gaps** — SSRF DNS rebinding, API key timing attacks, WebSocket auth bypass, approval race conditions, shell pattern matching, .env injection
2. **Reliability gaps** — retry storms, unbounded memory growth, concurrent modification, config change loss
3. **Architecture gaps** — dual auth layer inconsistency, WebSocket broadcast without isolation
4. **Defensive gaps** — process kill race, lazy init race, httpx client double-creation
## Requirements
- R1: All P2 security findings must be resolved or explicitly deferred with documented rationale
- R2: Retry logic must use exponential backoff with jitter
- R3: All in-memory stores must have bounded growth with proper eviction
- R4: Concurrent access to shared mutable state must be protected
- R5: Authentication must be consistent across all endpoints and transports
- R6: No regression in existing test suite (764+ tests must pass)
## Key Technical Decisions
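- **KTD1: DNS resolution for SSRF** — Resolve hostnames before IP validation. Async callers use the event loop's `loop.getaddrinfo()` (asyncio has no module-level `getaddrinfo`) rather than blocking `socket.getaddrinfo`; the sync `is_safe_url()` path in U5 still calls `socket.getaddrinfo` directly. Fail-closed on DNS errors. Accept that TOCTOU remains between check and connect; full mitigation requires IP pinning at the httpx transport level (deferred).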
- **KTD2: Auth consolidation strategy** — Fix per-endpoint auth first (hmac.compare_digest, WebSocket auth-before-accept), then consolidate into middleware in a separate pass. This avoids a big-bang refactor and allows incremental testing.
- **KTD3: WorkflowStore concurrency** — Add `asyncio.Lock` to WorkflowStore, convert sync methods to async. This is the most invasive change but necessary for correctness under concurrent load.
- **KTD4: Shell dangerous patterns** — Replace substring matching with token-based matching using the already-parsed shlex tokens. Separate binary-level and flag-level dangerous patterns.
- **KTD5: WebSocket isolation** — Replace flat subscriber list with `dict[str, set[WebSocket]]` keyed by execution_id. Clients subscribe to specific executions.
---
## Implementation Units
### U1. API Key Timing Attack & WebSocket Auth-Before-Accept
**Goal:** Eliminate timing side-channel in API key comparison and fix WebSocket accept-before-auth pattern.
**Requirements:** R5
**Dependencies:** None
**Files:**
- `src/agentkit/server/routes/portal.py`
- `src/agentkit/server/routes/workflows.py`
- `src/agentkit/server/middleware.py`
- `tests/unit/server/test_portal_routes.py`
- `tests/unit/server/test_workflow_routes.py`
**Approach:**
1. Add `import hmac` to portal.py, workflows.py, middleware.py
2. Replace all `!=` API key comparisons with `hmac.compare_digest()`
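3. In both WebSocket handlers, move auth validation BEFORE `websocket.accept()`. On auth failure, call `await websocket.close(code=4001)` without accepting, then return. Because the handshake never completes, the client sees a rejected upgrade (HTTP 403 under the ASGI spec) rather than close code 4001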
4. In middleware.py, replace `provided_key not in valid_keys` with `any(hmac.compare_digest(...) for k in valid_keys)`
5. Handle None/empty provided key by defaulting to empty string before comparison
**Test scenarios:**
- API key comparison with correct key returns True
- API key comparison with incorrect key returns False
- API key comparison with None provided defaults to empty string
- WebSocket connection without api_key is rejected (no 101 upgrade)
- WebSocket connection with wrong api_key is rejected
- WebSocket connection with correct api_key is accepted
**Verification:** All existing tests pass; new auth tests pass
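
The comparison rules in steps 2, 4 and 5 can be sketched as below. This is a minimal illustration, not the code in `portal.py` or `middleware.py`; the helper names `api_key_matches` and `any_key_matches` are hypothetical. It assumes keys are plain strings and encodes them to bytes, since `hmac.compare_digest` raises `TypeError` for non-ASCII `str` arguments.

```python
from __future__ import annotations

import hmac


def api_key_matches(provided: str | None, configured: str) -> bool:
    """Constant-time comparison; a missing key is treated as an empty string."""
    # Encoding to bytes avoids TypeError on non-ASCII str input.
    return hmac.compare_digest((provided or "").encode(), configured.encode())


def any_key_matches(provided: str | None, valid_keys: list[str]) -> bool:
    """Compare against every configured key instead of stopping at the first hit."""
    matched = False
    for key in valid_keys:
        # |= evaluates the right-hand side for every key (no short-circuit).
        matched |= api_key_matches(provided, key)
    return matched
```

One caveat worth keeping in mind: `api_key_matches(None, "")` is `True`, so callers still need a separate guard for an empty configured key, as the `if configured_api_key:` check in the WebSocket handler does.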
---
### U2. Shell Dangerous Pattern Token-Based Matching
**Goal:** Replace substring-based dangerous pattern matching with precise token-based matching to eliminate false positives and bypasses.
**Requirements:** R1
**Dependencies:** None
**Files:**
- `src/agentkit/tools/shell.py`
- `tests/unit/test_shell_tool.py`
**Approach:**
1. Replace `_DANGEROUS_PATTERNS` tuple with two new structures:
- `_DANGEROUS_BINARIES`: set of binary names that are always dangerous (e.g., `rm`, `mkfs`, `dd`, `shutdown`)
- `_DANGEROUS_BINARY_FLAGS`: dict mapping binary name to set of dangerous flag combinations (e.g., `kill: {-9, -KILL}`, `chmod: {777}`)
2. In `_is_dangerous()`, after shlex parsing:
- Check `binary` against `_DANGEROUS_BINARIES` (exact match)
- Check `binary` + flags against `_DANGEROUS_BINARY_FLAGS`
- Keep regex-based arg patterns for cross-token matches (e.g., `> /dev/`, `drop table`)
3. Remove overly broad patterns like `format`, `erase` as substrings; add `format` as binary-only
4. Normalize whitespace in command before matching
**Test scenarios:**
- `rm -rf /` detected as dangerous (binary match)
- `echo 'performing rm operations'` NOT detected as dangerous (no binary match)
- `kill -9 1234` detected as dangerous (binary + flag match)
- `kill -15 1234` NOT detected as dangerous (flag not in dangerous set)
- `rm\t/tmp/file` detected as dangerous (tab normalization)
- `format C:` detected as dangerous (binary match)
- `informatica` NOT detected as dangerous (not a binary match)
**Verification:** Shell tool tests pass with new pattern logic
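
A rough sketch of the token-based check described in the approach, exercised against the test scenarios above. The set contents, the regex list, and the function name `is_dangerous` are assumptions for illustration; the real `_is_dangerous()` in `shell.py` may differ. Unparseable input is treated as dangerous here, which is one possible reading of a default-to-dangerous fallback.

```python
import re
import shlex

# Hypothetical pattern tables; the real lists live in shell.py.
_DANGEROUS_BINARIES = {"rm", "mkfs", "dd", "shutdown", "format"}
_DANGEROUS_BINARY_FLAGS = {"kill": {"-9", "-KILL"}, "chmod": {"777"}}
_DANGEROUS_ARG_PATTERNS = (
    re.compile(r">\s*/dev/"),
    re.compile(r"drop\s+table", re.IGNORECASE),
)


def is_dangerous(command: str) -> bool:
    try:
        tokens = shlex.split(command)  # splits on any whitespace, including tabs
    except ValueError:
        return True  # e.g. unbalanced quotes: fail closed
    if not tokens:
        return False
    binary = tokens[0].rsplit("/", 1)[-1]  # "/bin/rm" -> "rm"
    if binary in _DANGEROUS_BINARIES:
        return True
    if set(tokens[1:]) & _DANGEROUS_BINARY_FLAGS.get(binary, set()):
        return True
    # Cross-token patterns run against the whitespace-normalized command.
    normalized = " ".join(command.split())
    return any(p.search(normalized) for p in _DANGEROUS_ARG_PATTERNS)
```

This sketch only inspects the first token, so wrappers and chains such as `sudo rm ...` or `true && rm ...` would need additional handling that is not shown.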
---
### U3. .env File Prefix Filtering & Retry Backoff
**Goal:** Prevent arbitrary environment variable injection from .env files and add exponential backoff to plan executor retries.
**Requirements:** R1, R2
**Dependencies:** None
**Files:**
- `src/agentkit/server/app.py`
- `src/agentkit/core/plan_executor.py`
- `tests/unit/test_plan_executor.py`
**Approach:**
1. In app.py .env loading, add an allowlist of prefixes: `AGENTKIT_`, `OPENAI_`, `ANTHROPIC_`, `GEMINI_`, `TAVILY_`, `SERPER_`, `DEEPSEEK_`, plus exact matches for `DATABASE_URL`, `REDIS_URL`
2. Log a warning when a .env variable is skipped due to prefix filtering
3. In plan_executor.py `_execute_step_with_retry()`, add exponential backoff with jitter after each failure:
- `base_retry_delay=1.0s`, `max_retry_delay=30.0s`
- `delay = min(base * 2 ** (retry - 1), max) * (0.5 + random() * 0.5)` (exponent, not XOR; the jitter factor lies in [0.5, 1.0))
- Skip delay on first attempt (retry_count == 0)
4. Add `base_retry_delay` and `max_retry_delay` as constructor parameters with defaults
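
The delay formula in step 3 can be written as a small pure function. This is a sketch: the function name `retry_delay` and the injectable `rng` parameter are assumptions added so the jitter can be pinned in tests; the executor itself calls `random.random()` inline.

```python
import random
from typing import Callable


def retry_delay(
    retry_count: int,
    base: float = 1.0,
    cap: float = 30.0,
    rng: Callable[[], float] = random.random,
) -> float:
    """Exponential backoff capped at `cap`, scaled by a jitter factor in [0.5, 1.0)."""
    if retry_count <= 0:
        return 0.0  # first attempt: no delay
    delay = min(base * 2 ** (retry_count - 1), cap)
    return delay * (0.5 + rng() * 0.5)
```

With the defaults, the upper bounds for retries 1, 2 and 3 are 1s, 2s and 4s, and every later retry is bounded by 30s; the jitter can shorten each wait by up to half.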
**Test scenarios:**
- .env with `AGENTKIT_FOO=bar` is loaded
- .env with `PATH=/malicious` is skipped with warning
- .env with `OPENAI_API_KEY=sk-xxx` is loaded
- Retry backoff: first retry waits 0.5–1s, second 1–2s, third 2–4s (base delay scaled by a jitter factor in [0.5, 1.0))
- Max retry delay caps at 30s
- Plan executor with max_retries=0 still works (no backoff)
**Verification:** Unit tests pass; .env loading respects prefix filter
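
The allowlist check for `.env` keys might look like the following sketch. The prefix and exact-match values mirror the constants added to `app.py` in this commit; the helper `filter_env` and its return shape are hypothetical and only exist to make the rule testable in isolation.

```python
from __future__ import annotations

_ALLOWED_ENV_PREFIXES = (
    "AGENTKIT_", "OPENAI_", "ANTHROPIC_", "GEMINI_",
    "TAVILY_", "SERPER_", "DEEPSEEK_",
)
_ALLOWED_ENV_EXACT = {"DATABASE_URL", "REDIS_URL"}


def filter_env(
    pairs: dict[str, str], existing: dict[str, str]
) -> tuple[dict[str, str], list[str]]:
    """Return (accepted, skipped) for key/value pairs parsed from a .env file."""
    accepted: dict[str, str] = {}
    skipped: list[str] = []
    for key, value in pairs.items():
        if not key or key in existing:
            continue  # never override variables that are already set
        # str.startswith accepts a tuple of prefixes.
        if key.startswith(_ALLOWED_ENV_PREFIXES) or key in _ALLOWED_ENV_EXACT:
            accepted[key] = value
        else:
            skipped.append(key)  # the caller logs a warning for these
    return accepted, skipped
```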
---
### U4. WorkflowStore Concurrency & Resource Management
**Goal:** Make WorkflowStore safe under concurrent coroutine access (an `asyncio.Lock` does not provide thread safety), fix unbounded growth, add proper cleanup for approval events and running tasks.
**Requirements:** R3, R4
**Dependencies:** U1 (hmac.compare_digest in workflow routes)
**Files:**
- `src/agentkit/server/routes/workflows.py`
- `tests/unit/server/test_workflow_routes.py`
**Approach:**
1. Add `asyncio.Lock` to WorkflowStore; convert `save`, `delete`, `create_execution`, `update_execution` to async methods
2. Initialize `_running_tasks` in `__init__` instead of lazy `getattr`
3. Add `_evict_execution()` helper that removes execution AND its approval events (with `event.set()` to wake waiting coroutines)
4. Replace `_ws_subscribers: list` with `_ws_subscribers: dict[str, set[WebSocket]]` keyed by execution_id, protected by `asyncio.Lock`
5. Add `_subscribe(execution_id, ws)` and `_unsubscribe(execution_id, ws)` helpers
6. Modify `_broadcast_ws` to accept `execution_id` parameter and only send to relevant subscribers
7. Update all call sites from sync to async (all endpoint functions already async)
8. Add `_execution_locks: dict[str, asyncio.Lock]` for per-execution approve/cancel serialization
**Test scenarios:**
- Concurrent save operations don't cause KeyError or over-eviction
- Evicted execution's approval events are cleaned up
- WebSocket subscriber receives only messages for subscribed execution_id
- WebSocket disconnect properly cleans up from subscriber dict
- Concurrent approve/cancel for same execution serialized by lock
- _running_tasks initialized in __init__, no getattr fallback
**Verification:** Workflow route tests pass; no RuntimeError from concurrent modification
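
A stripped-down sketch of the eviction behaviour, assuming a stand-in class `MiniStore` (not the real `WorkflowStore`) that tracks only execution start times. It shows the lock around mutation and the wake-on-evict step. Dropping the per-execution lock on eviction is an assumption of this sketch, added so that `_execution_locks` stays bounded as well.

```python
from __future__ import annotations

import asyncio


class MiniStore:
    """Illustrative stand-in for WorkflowStore; not the real class."""

    def __init__(self, max_executions: int = 2) -> None:
        self._executions: dict[str, str] = {}  # execution_id -> started_at (ISO)
        self._approval_events: dict[str, asyncio.Event] = {}
        self._execution_locks: dict[str, asyncio.Lock] = {}
        self._max_executions = max_executions
        self._lock = asyncio.Lock()

    def approval_event(self, execution_id: str, stage: str) -> asyncio.Event:
        key = f"{execution_id}:{stage}"
        if key not in self._approval_events:
            self._approval_events[key] = asyncio.Event()
        return self._approval_events[key]

    def _evict_execution(self, execution_id: str) -> None:
        self._executions.pop(execution_id, None)
        self._execution_locks.pop(execution_id, None)  # assumption: bound locks too
        prefix = f"{execution_id}:"
        for key in [k for k in self._approval_events if k.startswith(prefix)]:
            self._approval_events.pop(key).set()  # wake any waiting coroutine

    async def create_execution(self, execution_id: str, started_at: str) -> None:
        async with self._lock:
            self._executions[execution_id] = started_at
            while len(self._executions) > self._max_executions:
                oldest = min(self._executions, key=self._executions.__getitem__)
                self._evict_execution(oldest)
```

A coroutine woken by eviction cannot tell it apart from a real approval just by looking at the event, so the waiter would likely need to re-check that the execution still exists before proceeding.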
---
### U5. SSRF DNS Resolution & ComputerUse Client Safety
**Goal:** Add DNS resolution to SSRF protection and make ComputerUseTool's httpx client creation safe against concurrent coroutines.
**Requirements:** R1
**Dependencies:** None
**Files:**
- `src/agentkit/utils/security.py`
- `src/agentkit/tools/computer_use.py`
- `tests/unit/test_security.py`
- `tests/unit/tools/test_computer_use.py`
**Approach:**
1. In `is_safe_url()`, after the ValueError catch for domain names, add DNS resolution:
- Use `socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)` (sync, acceptable for tool init)
- Validate each resolved IP against `_is_unsafe_ip()`
- Fail-closed on DNS errors (return False)
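- Add 3-second timeout on DNS resolution (`socket.getaddrinfo` takes no timeout argument, so the call has to be bounded externally, e.g. by running it in a worker thread and waiting at most 3 seconds)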
2. Add `is_safe_url_async()` variant using `asyncio.get_running_loop().getaddrinfo()` for async callers
3. In ComputerUseTool, add `asyncio.Lock` to `_get_http_client()`, make it async with double-checked locking
4. Update caller from `client = self._get_http_client()` to `client = await self._get_http_client()`
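
Double-checked locking for a lazily created client can be sketched as follows. `LazyClientHolder` and its `factory` argument are hypothetical; a plain async factory stands in for the httpx client construction so the example has no third-party dependency.

```python
import asyncio
from typing import Any, Awaitable, Callable, Optional


class LazyClientHolder:
    """Creates the client once, even when many coroutines ask for it at the same time."""

    def __init__(self, factory: Callable[[], Awaitable[Any]]) -> None:
        self._factory = factory
        self._client: Optional[Any] = None
        self._lock = asyncio.Lock()

    async def get(self) -> Any:
        if self._client is None:  # fast path: no lock once initialized
            async with self._lock:
                if self._client is None:  # re-check after acquiring the lock
                    self._client = await self._factory()
        return self._client
```

The lock matters mainly when creation contains an `await`, since that is where another coroutine can run and start a second creation; a purely synchronous constructor call cannot be interleaved within a single event loop.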
**Test scenarios:**
- is_safe_url with domain resolving to 127.0.0.1 returns False
- is_safe_url with domain resolving to public IP returns True
- is_safe_url with unresolvable domain returns False (fail-closed)
- is_safe_url with ::ffff:127.0.0.1 returns False
- is_safe_url with .internal TLD returns False
- ComputerUseTool concurrent _get_http_client returns same client instance
- ComputerUseTool close() properly cleans up client
**Verification:** Security and computer_use tests pass
---
### U6. Config Hot-Reload & Defensive Fixes
**Goal:** Fix config hot-reload lock race, add proc.kill() error handling, and initialize config reload lock eagerly.
**Requirements:** R4
**Dependencies:** None
**Files:**
- `src/agentkit/server/app.py`
- `src/agentkit/tools/shell.py`
**Approach:**
1. In app.py `_on_config_change()`:
- Initialize lock eagerly in `create_app()` (store on `app.state._config_reload_lock`)
- Replace `lock.locked()` skip with pending-flag pattern: set `app.state._config_reload_pending = True`, then in `_reload()` loop while pending flag is set
- This ensures no config change is silently dropped
2. In shell.py `_execute_standalone()`:
- Wrap `proc.kill()` in `try/except (ProcessLookupError, OSError)`
- Use `proc.returncode` if available instead of hardcoded -1
**Test scenarios:**
- Two rapid config changes both get applied (no silent drop)
- Config reload pending flag is cleared after reload completes
- Shell timeout with already-exited process returns clean error (no ProcessLookupError)
- Shell timeout output still shows timeout message
**Verification:** App startup and shell tool tests pass
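
The pending-flag pattern from step 1 can be sketched in isolation. `CoalescingReloader` is a hypothetical helper, not the code in `app.py`, which keeps the flag and lock on `app.state`. This sketch also remembers the most recent config so that a coalesced re-run applies the latest value; that is a design assumption of the sketch.

```python
import asyncio
from typing import Any, Awaitable, Callable


class CoalescingReloader:
    """Runs at most one reload at a time; changes arriving meanwhile trigger a re-run."""

    def __init__(self, reload_fn: Callable[[Any], Awaitable[None]]) -> None:
        self._reload_fn = reload_fn
        self._lock = asyncio.Lock()
        self._pending = False
        self._latest: Any = None

    def request(self, config: Any) -> None:
        self._latest = config  # remember the newest config
        self._pending = True

    async def run(self) -> None:
        if self._lock.locked():
            return  # the running reload will notice the pending flag
        async with self._lock:
            while self._pending:
                self._pending = False
                await self._reload_fn(self._latest)
```

Because the flag is cleared before each reload starts, a change that lands during a reload causes exactly one more pass, and the loop ends once no new change has arrived.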
---
## Scope Boundaries
### In Scope
- All 16 P2 findings from ce-code-review
- Test coverage for new security/reliability patterns
### Deferred to Follow-Up Work
- Full auth middleware consolidation (ASGI-level WebSocket middleware) — complex, separate PR
- IP pinning at httpx transport level for complete SSRF protection
- python-dotenv migration for robust .env parsing
- clients.yaml caching in middleware (per-request disk read)
- WorkflowStore persistence (currently in-memory only)
### Outside This Plan's Identity
- New features or capability additions
- Performance optimization beyond what's needed for reliability
- UI/UX changes
---
## Risks & Mitigations
| Risk | Impact | Mitigation |
|------|--------|------------|
| WorkflowStore async refactor breaks existing endpoints | High | Update all call sites; comprehensive test coverage |
| DNS resolution adds latency to URL validation | Low | 3s timeout; acceptable for tool initialization |
| Token-based shell matching misses new bypass patterns | Medium | Default-to-dangerous fallback; add patterns as discovered |
| .env prefix filtering breaks existing deployments | Medium | Comprehensive allowlist; log warnings for skipped vars |
| WebSocket isolation breaks existing WS clients | Medium | Maintain backward compat: accept without subscription, send all |
| Config reload pending loop runs indefinitely | Low | Coalescing naturally limits; only re-runs if new change arrived |
---
## System-Wide Impact
- **Security posture** significantly improved: timing attacks, SSRF, auth bypass all addressed
- **Reliability** improved: retry storms prevented, memory bounded, concurrent access safe
- **API compatibility**: WorkflowStore method signatures change from sync to async — all call sites updated
- **WebSocket protocol**: subscription message becomes recommended but not required for backward compat

@@ -14,6 +14,7 @@ from __future__ import annotations
 import asyncio
 import logging
+import random
 import time
 from dataclasses import dataclass, field
 from enum import Enum
@@ -94,6 +95,8 @@ class PlanExecutor:
         max_retries: int = 2,
         step_timeout: float = 300.0,
         max_parallel: int = 5,
+        base_retry_delay: float = 1.0,
+        max_retry_delay: float = 30.0,
         on_step_complete: OnStepCompleteCallback | None = None,
         on_step_failed: OnStepFailedCallback | None = None,
         on_human_intervention: OnHumanInterventionCallback | None = None,
@@ -104,6 +107,8 @@ class PlanExecutor:
             max_retries: maximum number of retries after a step fails
             step_timeout: timeout for a single step
             max_parallel: maximum number of parallel steps
+            base_retry_delay: base retry delay
+            max_retry_delay: maximum retry delay
             on_step_complete: callback invoked when a step completes
             on_step_failed: callback invoked when a step fails; returns a FailureAction that decides what happens next
             on_human_intervention: human-intervention callback
@@ -112,6 +117,8 @@ class PlanExecutor:
         self._max_retries = max_retries
         self._step_timeout = step_timeout
         self._max_parallel = max_parallel
+        self._base_retry_delay = base_retry_delay
+        self._max_retry_delay = max_retry_delay
         self._on_step_complete = on_step_complete
         self._on_step_failed = on_step_failed
         self._on_human_intervention = on_human_intervention
@@ -250,6 +257,15 @@ class PlanExecutor:
             retry_count += 1
+            if retry_count <= self._max_retries:
+                delay = min(
+                    self._base_retry_delay * (2 ** (retry_count - 1)),
+                    self._max_retry_delay,
+                )
+                delay *= (0.5 + random.random() * 0.5)  # jitter
+                logger.info(f"Retrying step '{step.step_id}' in {delay:.1f}s (attempt {retry_count + 1}/{self._max_retries + 1})")
+                await asyncio.sleep(delay)
         # All retries exhausted
         step.status = PlanStepStatus.FAILED
         step.error = last_error

@@ -30,6 +30,12 @@ from agentkit.telemetry.setup import setup_telemetry
 logger = logging.getLogger(__name__)
+_ALLOWED_ENV_PREFIXES = (
+    'AGENTKIT_', 'OPENAI_', 'ANTHROPIC_', 'GEMINI_',
+    'TAVILY_', 'SERPER_', 'DEEPSEEK_',
+)
+_ALLOWED_ENV_EXACT = {'DATABASE_URL', 'REDIS_URL'}
 def _build_llm_gateway(config: ServerConfig) -> LLMGateway:
     """Build LLMGateway from ServerConfig, registering all providers."""
@@ -224,70 +230,69 @@ def _on_config_change(app: FastAPI, config: ServerConfig) -> None:
     Uses a lock to prevent concurrent config reloads from racing.
     """
-    lock: asyncio.Lock = getattr(app.state, "_config_reload_lock", None)
-    if lock is None:
-        lock = asyncio.Lock()
-        app.state._config_reload_lock = lock
-    if lock.locked():
-        logger.warning("Config reload already in progress, skipping")
-        return
+    lock: asyncio.Lock = app.state._config_reload_lock
+    app.state._config_reload_pending = True
     async def _reload():
+        if lock.locked():
+            return  # Another reload running; it will check pending flag
         async with lock:
-            # Increment config version for audit
-            current_version = getattr(app.state, "config_version", 0) + 1
-            app.state.config_version = current_version
-            logger.info(f"Config change detected (v{current_version}), reloading...")
-            # Rebuild LLMGateway if llm config changed
-            try:
-                new_gateway = _build_llm_gateway(config)
-                app.state.llm_gateway = new_gateway
-                # Also update the agent pool's gateway reference
-                if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
-                    app.state.agent_pool._llm_gateway = new_gateway
-                if hasattr(app.state, "intent_router") and app.state.intent_router is not None:
-                    app.state.intent_router._llm_gateway = new_gateway
-                logger.info(f"LLM Gateway reloaded (config v{current_version})")
-            except Exception as e:
-                logger.error(f"Failed to reload LLM Gateway: {e}")
-            # Reload skills if skill paths changed
-            try:
-                new_skill_registry = _build_skill_registry(config)
-                # Re-bind tools from the shared tool_registry so skills don't lose their bindings
-                tool_registry = getattr(app.state, "tool_registry", None)
-                if tool_registry:
-                    from agentkit.skills.loader import SkillLoader
-                    loader = SkillLoader(
-                        skill_registry=new_skill_registry,
-                        tool_registry=tool_registry,
-                    )
-                    for skill_path in (config.skill_paths or []):
-                        from pathlib import Path as _P
-                        p = _P(skill_path)
-                        if p.is_dir():
-                            loader.load_from_directory(str(p))
-                        elif p.is_file() and p.suffix in (".yaml", ".yml"):
-                            try:
-                                loader.load_from_file(str(p))
-                            except Exception:
-                                pass
-                app.state.skill_registry = new_skill_registry
-                if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
-                    app.state.agent_pool._skill_registry = new_skill_registry
-                logger.info(f"Skills reloaded (config v{current_version})")
-            except Exception as e:
-                logger.error(f"Failed to reload skills: {e}")
-            # Update config version on all agents
-            if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
-                for agent in app.state.agent_pool._agents.values():
-                    if hasattr(agent, "_config_version"):
-                        agent._config_version = current_version
-            logger.info(f"Config reload complete (v{current_version})")
+            while getattr(app.state, "_config_reload_pending", False):
+                app.state._config_reload_pending = False
+                # Increment config version for audit
+                current_version = getattr(app.state, "config_version", 0) + 1
+                app.state.config_version = current_version
+                logger.info(f"Config change detected (v{current_version}), reloading...")
+                # Rebuild LLMGateway if llm config changed
+                try:
+                    new_gateway = _build_llm_gateway(config)
+                    app.state.llm_gateway = new_gateway
+                    # Also update the agent pool's gateway reference
+                    if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
+                        app.state.agent_pool._llm_gateway = new_gateway
+                    if hasattr(app.state, "intent_router") and app.state.intent_router is not None:
+                        app.state.intent_router._llm_gateway = new_gateway
+                    logger.info(f"LLM Gateway reloaded (config v{current_version})")
+                except Exception as e:
+                    logger.error(f"Failed to reload LLM Gateway: {e}")
+                # Reload skills if skill paths changed
+                try:
+                    new_skill_registry = _build_skill_registry(config)
+                    # Re-bind tools from the shared tool_registry so skills don't lose their bindings
+                    tool_registry = getattr(app.state, "tool_registry", None)
+                    if tool_registry:
+                        from agentkit.skills.loader import SkillLoader
+                        loader = SkillLoader(
+                            skill_registry=new_skill_registry,
+                            tool_registry=tool_registry,
+                        )
+                        for skill_path in (config.skill_paths or []):
+                            from pathlib import Path as _P
+                            p = _P(skill_path)
+                            if p.is_dir():
+                                loader.load_from_directory(str(p))
+                            elif p.is_file() and p.suffix in (".yaml", ".yml"):
+                                try:
+                                    loader.load_from_file(str(p))
+                                except Exception:
+                                    pass
+                    app.state.skill_registry = new_skill_registry
+                    if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
+                        app.state.agent_pool._skill_registry = new_skill_registry
+                    logger.info(f"Skills reloaded (config v{current_version})")
+                except Exception as e:
+                    logger.error(f"Failed to reload skills: {e}")
+                # Update config version on all agents
+                if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
+                    for agent in app.state.agent_pool._agents.values():
+                        if hasattr(agent, "_config_version"):
+                            agent._config_version = current_version
+                logger.info(f"Config reload complete (v{current_version})")
     # Schedule the reload as a task (non-blocking for the watcher thread)
     try:
@@ -327,6 +332,10 @@ def create_app(
                 _key = _key.strip()
                 _val = _val.strip().strip("\"'")
                 if _key and _key not in os.environ:
+                    allowed = any(_key.startswith(p) for p in _ALLOWED_ENV_PREFIXES) or _key in _ALLOWED_ENV_EXACT
+                    if not allowed:
+                        logger.warning(f"Skipping .env variable '{_key}' (not in allowed prefixes)")
+                        continue
                     os.environ[_key] = _val
     server_config = ServerConfig.from_yaml(config_path)
     app = FastAPI(title="AgentKit Server", version="2.0.0", lifespan=lifespan)

@@ -1,5 +1,6 @@
 """Server middleware - Authentication and Rate Limiting"""
+import hmac
 import os
 import time
 from collections import defaultdict
@@ -75,7 +76,7 @@ class APIKeyAuthMiddleware(BaseHTTPMiddleware):
         # Check API key from header
         provided_key = request.headers.get("X-API-Key")
-        if not provided_key or provided_key not in valid_keys:
+        if not provided_key or not any(hmac.compare_digest(provided_key.encode(), k.encode()) for k in valid_keys):
             return JSONResponse(
                 status_code=401,
                 content={"error": "Unauthorized", "message": "Invalid or missing API key"},

@@ -3,6 +3,7 @@
 from __future__ import annotations
 import asyncio
+import hmac
 import json
 import logging
 import uuid
@@ -47,7 +48,7 @@ async def _verify_api_key(
         return
     provided = api_key_header or api_key_query
-    if provided != configured_api_key:
+    if not hmac.compare_digest((provided or "").encode(), configured_api_key.encode()):
         raise HTTPException(
             status_code=401,
             detail="Invalid or missing API key. Provide via X-API-Key header or api_key query parameter.",
@@ -475,11 +476,7 @@ async def portal_websocket(websocket: WebSocket):
     # Check api_key query param
     if configured_api_key:
         provided = websocket.query_params.get("api_key")
-        if provided != configured_api_key:
-            await websocket.accept()
-            await websocket.send_json(
-                {"type": "error", "data": {"message": "Invalid or missing api_key"}}
-            )
+        if not hmac.compare_digest((provided or "").encode(), configured_api_key.encode()):
             await websocket.close(code=4001, reason="Invalid or missing api_key")
             return

@@ -3,6 +3,7 @@
 from __future__ import annotations
 import asyncio
+import hmac
 import json
 import logging
 import re
@@ -52,7 +53,7 @@ async def _verify_api_key(
         return
     provided = api_key_header or api_key_query
-    if provided != configured_api_key:
+    if not hmac.compare_digest((provided or "").encode(), configured_api_key.encode()):
         raise HTTPException(
             status_code=401,
             detail="Invalid or missing API key. Provide via X-API-Key header or api_key query parameter.",
@@ -65,7 +66,7 @@ async def _verify_api_key(
 class WorkflowStore:
-    """In-memory workflow store."""
+    """In-memory workflow store with async-safe mutation methods."""
     def __init__(self, max_workflows: int = 500, max_executions: int = 1000):
         self._workflows: dict[str, WorkflowDefinition] = {}
@@ -73,17 +74,30 @@ class WorkflowStore:
         self._max_workflows = max_workflows
         self._max_executions = max_executions
         self._approval_events: dict[str, asyncio.Event] = {}  # key: f"{execution_id}:{stage_name}"
+        self._running_tasks: dict[str, asyncio.Task] = {}
+        self._execution_locks: dict[str, asyncio.Lock] = {}
+        self._lock = asyncio.Lock()
-    def save(self, workflow: WorkflowDefinition) -> WorkflowDefinition:
-        workflow.updated_at = datetime.now(timezone.utc).isoformat()
-        self._workflows[workflow.workflow_id] = workflow
-        # Evict oldest if over limit
-        if len(self._workflows) > self._max_workflows:
-            oldest_id = min(
-                self._workflows, key=lambda k: self._workflows[k].updated_at
-            )
-            del self._workflows[oldest_id]
-        return workflow
+    def _evict_execution(self, execution_id: str) -> None:
+        """Remove execution and its associated approval events."""
+        self._executions.pop(execution_id, None)
+        keys_to_remove = [k for k in self._approval_events if k.startswith(f"{execution_id}:")]
+        for k in keys_to_remove:
+            event = self._approval_events.pop(k, None)
+            if event is not None:
+                event.set()  # Wake any waiting coroutine
+    async def save(self, workflow: WorkflowDefinition) -> WorkflowDefinition:
+        async with self._lock:
+            workflow.updated_at = datetime.now(timezone.utc).isoformat()
+            self._workflows[workflow.workflow_id] = workflow
+            # Evict oldest if over limit
+            if len(self._workflows) > self._max_workflows:
+                oldest_id = min(
+                    self._workflows, key=lambda k: self._workflows[k].updated_at
+                )
+                del self._workflows[oldest_id]
+            return workflow
     def get(self, workflow_id: str) -> WorkflowDefinition | None:
         return self._workflows.get(workflow_id)
@@ -109,47 +123,72 @@ class WorkflowStore:
             )
         return summaries
-    def delete(self, workflow_id: str) -> bool:
-        if workflow_id in self._workflows:
-            del self._workflows[workflow_id]
-            return True
-        return False
+    async def delete(self, workflow_id: str) -> bool:
+        async with self._lock:
+            if workflow_id in self._workflows:
+                del self._workflows[workflow_id]
+                return True
+            return False
-    def create_execution(self, workflow_id: str) -> WorkflowExecution:
-        execution = WorkflowExecution(
-            execution_id=str(uuid.uuid4()),
-            workflow_id=workflow_id,
-            status="pending",
-            started_at=datetime.now(timezone.utc).isoformat(),
-        )
-        self._executions[execution.execution_id] = execution
-        # Evict oldest if over limit
-        if len(self._executions) > self._max_executions:
-            oldest_id = min(
-                self._executions,
-                key=lambda k: self._executions[k].started_at or "",
-            )
-            del self._executions[oldest_id]
-        return execution
+    async def create_execution(self, workflow_id: str) -> WorkflowExecution:
+        async with self._lock:
+            execution = WorkflowExecution(
+                execution_id=str(uuid.uuid4()),
+                workflow_id=workflow_id,
+                status="pending",
+                started_at=datetime.now(timezone.utc).isoformat(),
+            )
+            self._executions[execution.execution_id] = execution
+            # Evict oldest if over limit
+            if len(self._executions) > self._max_executions:
+                oldest_id = min(
+                    self._executions,
+                    key=lambda k: self._executions[k].started_at or "",
+                )
+                self._evict_execution(oldest_id)
+            return execution
     def get_execution(self, execution_id: str) -> WorkflowExecution | None:
         return self._executions.get(execution_id)
-    def update_execution(self, execution_id: str, **kwargs: Any) -> WorkflowExecution:
-        execution = self._executions.get(execution_id)
-        if execution is None:
-            raise KeyError(f"Execution '{execution_id}' not found")
-        for key, value in kwargs.items():
-            if hasattr(execution, key):
-                setattr(execution, key, value)
-        return execution
+    async def update_execution(self, execution_id: str, **kwargs: Any) -> WorkflowExecution:
+        async with self._lock:
+            execution = self._executions.get(execution_id)
+            if execution is None:
+                raise KeyError(f"Execution '{execution_id}' not found")
+            for key, value in kwargs.items():
+                if hasattr(execution, key):
+                    setattr(execution, key, value)
+            return execution
+    def get_execution_lock(self, execution_id: str) -> asyncio.Lock:
+        """Get or create a per-execution lock for approve/cancel serialization."""
+        if execution_id not in self._execution_locks:
+            self._execution_locks[execution_id] = asyncio.Lock()
+        return self._execution_locks[execution_id]
 # Module-level singleton
 _workflow_store = WorkflowStore()
-# WebSocket subscribers for real-time execution progress
-_ws_subscribers: list[WebSocket] = []
+# WebSocket subscribers for real-time execution progress (keyed by execution_id)
+_ws_subscribers: dict[str, set[WebSocket]] = {}
+_ws_subscribers_lock = asyncio.Lock()
+async def _ws_subscribe(execution_id: str, ws: WebSocket) -> None:
+    async with _ws_subscribers_lock:
+        if execution_id not in _ws_subscribers:
+            _ws_subscribers[execution_id] = set()
+        _ws_subscribers[execution_id].add(ws)
+async def _ws_unsubscribe(execution_id: str, ws: WebSocket) -> None:
+    async with _ws_subscribers_lock:
+        if execution_id in _ws_subscribers:
+            _ws_subscribers[execution_id].discard(ws)
+            if not _ws_subscribers[execution_id]:
+                del _ws_subscribers[execution_id]
 def _get_store(request: Request) -> WorkflowStore:
@@ -210,7 +249,7 @@ async def _execute_workflow(
     """Execute a workflow by running its stages in topological order."""
     _store = store or _workflow_store
     execution.status = "running"
-    _store.update_execution(execution.execution_id, status="running")
+    await _store.update_execution(execution.execution_id, status="running")
     # Topological sort
     stage_map = {s.name: s for s in workflow.stages}
@@ -229,7 +268,7 @@ async def _execute_workflow(
         execution.status = "failed"
         execution.error = "循环依赖"
         execution.completed_at = datetime.now(timezone.utc).isoformat()
-        _store.update_execution(
+        await _store.update_execution(
             execution.execution_id,
             status="failed",
             error="循环依赖",
@@ -246,7 +285,7 @@ async def _execute_workflow(
     for stage_name in ordered:
         stage = stage_map[stage_name]
         execution.current_stage = stage_name
-        _store.update_execution(
+        await _store.update_execution(
             execution.execution_id,
             current_stage=stage_name,
         )
@@ -256,7 +295,7 @@ async def _execute_workflow(
             "event": "stage_started",
             "execution_id": execution.execution_id,
             "stage": stage_name,
-        })
+        }, execution_id=execution.execution_id)
         try:
             if stage.type == "approval":
@@ -267,7 +306,7 @@ async def _execute_workflow(
                 execution.status = "paused"
                 execution.current_stage = stage_name
-                _store.update_execution(
+                await _store.update_execution(
                     execution.execution_id,
                     status="paused",
                     current_stage=stage_name,
@@ -276,7 +315,7 @@ async def _execute_workflow(
                     "event": "approval_required",
                     "execution_id": execution.execution_id,
                     "stage": stage_name,
-                })
+                }, execution_id=execution.execution_id)
                 # Wait for approval with timeout
                 try:
@@ -289,13 +328,13 @@ async def _execute_workflow(
                         "execution_id": execution.execution_id,
                         "stage": stage_name,
                         "error": "Approval rejected",
-                    })
+                    }, execution_id=execution.execution_id)
                     return
                 # Approval was granted — the /approve endpoint already set stage_results
                 # Only update status to running if not already set
if execution.status != "running": if execution.status != "running":
execution.status = "running" execution.status = "running"
_store.update_execution( await _store.update_execution(
execution.execution_id, execution.execution_id,
status="running", status="running",
) )
@ -308,7 +347,7 @@ async def _execute_workflow(
execution.status = "failed" execution.status = "failed"
execution.error = f"Approval timeout for stage {stage_name}" execution.error = f"Approval timeout for stage {stage_name}"
execution.completed_at = datetime.now(timezone.utc).isoformat() execution.completed_at = datetime.now(timezone.utc).isoformat()
_store.update_execution( await _store.update_execution(
execution.execution_id, execution.execution_id,
status="failed", status="failed",
error=execution.error, error=execution.error,
@ -320,7 +359,7 @@ async def _execute_workflow(
"execution_id": execution.execution_id, "execution_id": execution.execution_id,
"stage": stage_name, "stage": stage_name,
"error": "Approval timeout", "error": "Approval timeout",
}) }, execution_id=execution.execution_id)
return return
finally: finally:
_store._approval_events.pop(event_key, None) _store._approval_events.pop(event_key, None)
@ -332,7 +371,7 @@ async def _execute_workflow(
"status": "completed", "status": "completed",
"condition_result": result, "condition_result": result,
} }
_store.update_execution( await _store.update_execution(
execution.execution_id, execution.execution_id,
stage_results=execution.stage_results, stage_results=execution.stage_results,
) )
@ -342,7 +381,7 @@ async def _execute_workflow(
"status": "completed", "status": "completed",
"output": {"dry_run": True, "action": stage.action}, "output": {"dry_run": True, "action": stage.action},
} }
_store.update_execution( await _store.update_execution(
execution.execution_id, execution.execution_id,
stage_results=execution.stage_results, stage_results=execution.stage_results,
) )
@ -351,7 +390,7 @@ async def _execute_workflow(
"event": "stage_completed", "event": "stage_completed",
"execution_id": execution.execution_id, "execution_id": execution.execution_id,
"stage": stage_name, "stage": stage_name,
}) }, execution_id=execution.execution_id)
except Exception as e: except Exception as e:
execution.stage_results[stage_name] = { execution.stage_results[stage_name] = {
@ -361,7 +400,7 @@ async def _execute_workflow(
execution.status = "failed" execution.status = "failed"
execution.error = f"阶段 '{stage_name}' 执行失败: {e}" execution.error = f"阶段 '{stage_name}' 执行失败: {e}"
execution.completed_at = datetime.now(timezone.utc).isoformat() execution.completed_at = datetime.now(timezone.utc).isoformat()
_store.update_execution( await _store.update_execution(
execution.execution_id, execution.execution_id,
status="failed", status="failed",
error=execution.error, error=execution.error,
@ -373,13 +412,13 @@ async def _execute_workflow(
"execution_id": execution.execution_id, "execution_id": execution.execution_id,
"stage": stage_name, "stage": stage_name,
"error": str(e), "error": str(e),
}) }, execution_id=execution.execution_id)
return return
execution.status = "completed" execution.status = "completed"
execution.completed_at = datetime.now(timezone.utc).isoformat() execution.completed_at = datetime.now(timezone.utc).isoformat()
execution.current_stage = None execution.current_stage = None
_store.update_execution( await _store.update_execution(
execution.execution_id, execution.execution_id,
status="completed", status="completed",
completed_at=execution.completed_at, completed_at=execution.completed_at,
@ -388,7 +427,7 @@ async def _execute_workflow(
await _broadcast_ws({ await _broadcast_ws({
"event": "execution_completed", "event": "execution_completed",
"execution_id": execution.execution_id, "execution_id": execution.execution_id,
}) }, execution_id=execution.execution_id)
_SAFE_VAR_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$') _SAFE_VAR_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
@ -455,16 +494,25 @@ def _evaluate_condition(expression: str, variables: dict[str, Any]) -> bool:
raise ValueError(f"Invalid condition expression: {expression}") raise ValueError(f"Invalid condition expression: {expression}")
async def _broadcast_ws(message: dict[str, Any]) -> None: async def _broadcast_ws(message: dict[str, Any], execution_id: str | None = None) -> None:
"""Broadcast a message to all WebSocket subscribers.""" """Broadcast a message to WebSocket subscribers for a specific execution."""
async with _ws_subscribers_lock:
targets = set()
if execution_id and execution_id in _ws_subscribers:
targets = set(_ws_subscribers[execution_id]) # snapshot
disconnected = [] disconnected = []
for ws in _ws_subscribers: for ws in targets:
try: try:
await ws.send_json(message) await ws.send_json(message)
except Exception: except Exception:
disconnected.append(ws) disconnected.append(ws)
for ws in disconnected: if disconnected:
_ws_subscribers.remove(ws) async with _ws_subscribers_lock:
for ws in disconnected:
for eid in list(_ws_subscribers.keys()):
_ws_subscribers[eid].discard(ws)
if not _ws_subscribers[eid]:
del _ws_subscribers[eid]
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -494,7 +542,7 @@ async def create_workflow(request: Request, body: CreateWorkflowRequest, _auth:
variables_schema=body.variables_schema, variables_schema=body.variables_schema,
output_schema=body.output_schema, output_schema=body.output_schema,
) )
saved = store.save(workflow) saved = await store.save(workflow)
return saved.model_dump() return saved.model_dump()
@ -527,7 +575,7 @@ async def update_workflow(
existing.variables_schema = body.variables_schema existing.variables_schema = body.variables_schema
existing.output_schema = body.output_schema existing.output_schema = body.output_schema
existing.version += 1 existing.version += 1
saved = store.save(existing) saved = await store.save(existing)
return saved.model_dump() return saved.model_dump()
@ -535,7 +583,7 @@ async def update_workflow(
async def delete_workflow(request: Request, workflow_id: str, _auth: None = Depends(_verify_api_key)): async def delete_workflow(request: Request, workflow_id: str, _auth: None = Depends(_verify_api_key)):
"""Delete a workflow.""" """Delete a workflow."""
store = _get_store(request) store = _get_store(request)
deleted = store.delete(workflow_id) deleted = await store.delete(workflow_id)
if not deleted: if not deleted:
raise HTTPException(status_code=404, detail=f"工作流 '{workflow_id}' 不存在") raise HTTPException(status_code=404, detail=f"工作流 '{workflow_id}' 不存在")
return {"message": "已删除"} return {"message": "已删除"}
@ -551,14 +599,13 @@ async def execute_workflow(
if workflow is None: if workflow is None:
raise HTTPException(status_code=404, detail=f"工作流 '{workflow_id}' 不存在") raise HTTPException(status_code=404, detail=f"工作流 '{workflow_id}' 不存在")
execution = store.create_execution(workflow_id) execution = await store.create_execution(workflow_id)
execution.variables = body.variables execution.variables = body.variables
# Start execution in background # Start execution in background
task = asyncio.create_task( task = asyncio.create_task(
_execute_workflow(workflow, execution, body.variables, store=store) _execute_workflow(workflow, execution, body.variables, store=store)
) )
store._running_tasks = getattr(store, "_running_tasks", {})
store._running_tasks[execution.execution_id] = task store._running_tasks[execution.execution_id] = task
task.add_done_callback(lambda t: store._running_tasks.pop(execution.execution_id, None)) task.add_done_callback(lambda t: store._running_tasks.pop(execution.execution_id, None))
@ -593,51 +640,60 @@ async def approve_execution(
raise HTTPException( raise HTTPException(
status_code=404, detail=f"执行记录 '{execution_id}' 不存在" status_code=404, detail=f"执行记录 '{execution_id}' 不存在"
) )
if execution.status != "paused":
raise HTTPException(
status_code=400, detail="当前执行状态不是等待审批"
)
if body.approved: exec_lock = store.get_execution_lock(execution_id)
if execution.current_stage: async with exec_lock:
execution.stage_results[execution.current_stage] = { # Re-fetch execution after acquiring lock
"status": "approved", execution = store.get_execution(execution_id)
"approver": "user", if execution is None:
"comment": body.comment, raise HTTPException(
} status_code=404, detail=f"执行记录 '{execution_id}' 不存在"
execution.status = "running" )
store.update_execution( if execution.status != "paused":
execution.execution_id, raise HTTPException(
status="running", status_code=400, detail="当前执行状态不是等待审批"
stage_results=execution.stage_results, )
)
# Resume the waiting execution by setting the approval event if body.approved:
stage_name = execution.current_stage if execution.current_stage:
if stage_name: execution.stage_results[execution.current_stage] = {
event_key = f"{execution_id}:{stage_name}" "status": "approved",
if event_key in store._approval_events: "approver": "user",
store._approval_events[event_key].set() "comment": body.comment,
else: }
execution.status = "cancelled" execution.status = "running"
execution.completed_at = datetime.now(timezone.utc).isoformat() await store.update_execution(
if execution.current_stage: execution.execution_id,
execution.stage_results[execution.current_stage] = { status="running",
"status": "rejected", stage_results=execution.stage_results,
"approver": "user", )
"comment": body.comment, # Resume the waiting execution by setting the approval event
} stage_name = execution.current_stage
store.update_execution( if stage_name:
execution.execution_id, event_key = f"{execution_id}:{stage_name}"
status="cancelled", if event_key in store._approval_events:
completed_at=execution.completed_at, store._approval_events[event_key].set()
stage_results=execution.stage_results, else:
) execution.status = "cancelled"
# Set the approval event so the waiting coroutine can observe the cancelled state execution.completed_at = datetime.now(timezone.utc).isoformat()
stage_name = execution.current_stage if execution.current_stage:
if stage_name: execution.stage_results[execution.current_stage] = {
event_key = f"{execution_id}:{stage_name}" "status": "rejected",
if event_key in store._approval_events: "approver": "user",
store._approval_events[event_key].set() "comment": body.comment,
}
await store.update_execution(
execution.execution_id,
status="cancelled",
completed_at=execution.completed_at,
stage_results=execution.stage_results,
)
# Set the approval event so the waiting coroutine can observe the cancelled state
stage_name = execution.current_stage
if stage_name:
event_key = f"{execution_id}:{stage_name}"
if event_key in store._approval_events:
store._approval_events[event_key].set()
return execution.model_dump() return execution.model_dump()
@ -651,23 +707,33 @@ async def cancel_execution(request: Request, execution_id: str, _auth: None = De
raise HTTPException( raise HTTPException(
status_code=404, detail=f"执行记录 '{execution_id}' 不存在" status_code=404, detail=f"执行记录 '{execution_id}' 不存在"
) )
if execution.status not in ("running", "paused", "pending"):
raise HTTPException(
status_code=400, detail="当前执行状态无法取消"
)
execution.status = "cancelled" exec_lock = store.get_execution_lock(execution_id)
execution.completed_at = datetime.now(timezone.utc).isoformat() async with exec_lock:
store.update_execution( # Re-fetch execution after acquiring lock
execution.execution_id, execution = store.get_execution(execution_id)
status="cancelled", if execution is None:
completed_at=execution.completed_at, raise HTTPException(
) status_code=404, detail=f"执行记录 '{execution_id}' 不存在"
# Set any pending approval event so a paused workflow can observe the cancelled state )
if hasattr(execution, "current_stage") and execution.current_stage: if execution.status not in ("running", "paused", "pending"):
event_key = f"{execution_id}:{execution.current_stage}" raise HTTPException(
if event_key in store._approval_events: status_code=400, detail="当前执行状态无法取消"
store._approval_events[event_key].set() )
execution.status = "cancelled"
execution.completed_at = datetime.now(timezone.utc).isoformat()
await store.update_execution(
execution.execution_id,
status="cancelled",
completed_at=execution.completed_at,
)
# Set any pending approval event so a paused workflow can observe the cancelled state
if hasattr(execution, "current_stage") and execution.current_stage:
event_key = f"{execution_id}:{execution.current_stage}"
if event_key in store._approval_events:
store._approval_events[event_key].set()
return execution.model_dump() return execution.model_dump()
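The approve and cancel handlers above both follow a "lock, re-fetch, re-check, mutate" sequence. A minimal sketch of why the re-check under the lock matters, assuming a plain dict in place of `WorkflowStore` (the names `approve`, `get_lock`, and `store` are hypothetical and exist only for this illustration):

```python
import asyncio

# Hypothetical stand-ins for the store's per-execution locks and records.
_locks = {}
store = {"exec-1": {"status": "paused"}}


def get_lock(execution_id: str) -> asyncio.Lock:
    """Return the lock for one execution, creating it on first use."""
    return _locks.setdefault(execution_id, asyncio.Lock())


async def approve(execution_id: str) -> bool:
    """Approve a paused execution; return False if it is no longer paused."""
    async with get_lock(execution_id):
        record = store[execution_id]       # re-read state under the lock
        if record["status"] != "paused":   # a concurrent caller got here first
            return False
        await asyncio.sleep(0)             # stands in for an awaited store update
        record["status"] = "running"
        return True


async def main() -> list:
    # Two approvals race for the same execution; only one may succeed.
    return await asyncio.gather(approve("exec-1"), approve("exec-1"))
```

Without the lock, both callers could pass the status check before either one writes, because the awaited update yields control between the check and the write.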
@ -683,21 +749,37 @@ async def workflow_websocket(websocket: WebSocket):
if configured_api_key: if configured_api_key:
provided = websocket.query_params.get("api_key") provided = websocket.query_params.get("api_key")
if provided != configured_api_key: if not hmac.compare_digest((provided or "").encode(), configured_api_key.encode()):
await websocket.accept()
await websocket.send_json(
{"event": "error", "data": {"message": "Invalid or missing api_key"}}
)
await websocket.close(code=4001, reason="Invalid or missing api_key") await websocket.close(code=4001, reason="Invalid or missing api_key")
return return
await websocket.accept() await websocket.accept()
_ws_subscribers.append(websocket)
# Determine execution_id from query params; may be None for backward compat
execution_id = websocket.query_params.get("execution_id")
subscribed_execution_id: str | None = None
if execution_id:
await _ws_subscribe(execution_id, websocket)
subscribed_execution_id = execution_id
try: try:
while True: while True:
try: try:
raw = await asyncio.wait_for(websocket.receive_text(), timeout=120.0) raw = await asyncio.wait_for(websocket.receive_text(), timeout=120.0)
# Handle subscription messages
try:
msg = json.loads(raw)
if isinstance(msg, dict) and msg.get("type") == "subscribe":
new_eid = msg.get("execution_id")
if new_eid:
# Unsubscribe from previous if any
if subscribed_execution_id:
await _ws_unsubscribe(subscribed_execution_id, websocket)
await _ws_subscribe(new_eid, websocket)
subscribed_execution_id = new_eid
except (json.JSONDecodeError, KeyError):
pass
except asyncio.TimeoutError: except asyncio.TimeoutError:
await websocket.close(code=1000, reason="Heartbeat timeout") await websocket.close(code=1000, reason="Heartbeat timeout")
return return
@ -707,5 +789,5 @@ async def workflow_websocket(websocket: WebSocket):
except Exception as e: except Exception as e:
logger.error(f"Workflow WebSocket error: {e}") logger.error(f"Workflow WebSocket error: {e}")
finally: finally:
if websocket in _ws_subscribers: if subscribed_execution_id:
_ws_subscribers.remove(websocket) await _ws_unsubscribe(subscribed_execution_id, websocket)


@ -7,6 +7,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import base64 import base64
import logging import logging
from typing import Any, Callable, Awaitable from typing import Any, Callable, Awaitable
@ -95,11 +96,14 @@ class ComputerUseTool(Tool):
self._max_retries = max_retries self._max_retries = max_retries
self._request_timeout = request_timeout self._request_timeout = request_timeout
self._http_client: httpx.AsyncClient | None = None self._http_client: httpx.AsyncClient | None = None
self._client_lock = asyncio.Lock()
def _get_http_client(self) -> httpx.AsyncClient: async def _get_http_client(self) -> httpx.AsyncClient:
"""Get or create a persistent httpx.AsyncClient for connection reuse.""" """Get or create a persistent httpx.AsyncClient for connection reuse."""
if self._http_client is None or self._http_client.is_closed: if self._http_client is None or self._http_client.is_closed:
self._http_client = httpx.AsyncClient(timeout=self._request_timeout) async with self._client_lock:
if self._http_client is None or self._http_client.is_closed:
self._http_client = httpx.AsyncClient(timeout=self._request_timeout)
return self._http_client return self._http_client
async def close(self) -> None: async def close(self) -> None:
@ -391,7 +395,7 @@ class ComputerUseTool(Tool):
"content-type": "application/json", "content-type": "application/json",
} }
client = self._get_http_client() client = await self._get_http_client()
response = await client.post( response = await client.post(
self._api_base_url, self._api_base_url,
json=request_body, json=request_body,
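The `_get_http_client` change above is a double-checked lazy initialization guarded by an `asyncio.Lock`. A minimal sketch of the same pattern, assuming a generic factory instead of `httpx.AsyncClient` so that it runs without third-party packages (`LazyResource` is a hypothetical name):

```python
import asyncio


class LazyResource:
    """Create a shared resource at most once, even with concurrent callers."""

    def __init__(self, factory) -> None:
        self._factory = factory
        self._instance = None
        self._lock = asyncio.Lock()
        self.created = 0  # counts factory calls, for demonstration only

    async def get(self):
        if self._instance is None:              # fast path, no lock taken
            async with self._lock:
                if self._instance is None:      # re-check once the lock is held
                    await asyncio.sleep(0)      # simulates a creation step that awaits
                    self._instance = self._factory()
                    self.created += 1
        return self._instance


async def main():
    res = LazyResource(object)
    got = await asyncio.gather(*(res.get() for _ in range(5)))
    return res.created, got
```

In single-threaded asyncio the race only opens up when creation yields to the event loop between the check and the assignment. A purely synchronous constructor call cannot be interleaved, so in that case the lock is mainly a guard against future changes.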


@ -66,40 +66,33 @@ _SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
"docker images", "docker images",
) )
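# Dangerous command patterns: these commands require manual confirmation # Dangerous command detection: based on exact token matching, to avoid substring false positives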
_DANGEROUS_PATTERNS: tuple[str, ...] = (
"rm ", # Binaries that are always dangerous (regardless of arguments)
"rm -", _DANGEROUS_BINARIES: frozenset[str] = frozenset({
"rmdir", "rm", "rmdir", "mkfs", "dd", "format", "shutdown", "reboot",
"mkfs", "halt", "killall", "chown", "fdisk", "parted",
"dd ", })
"format",
"del ", # Binaries that are dangerous only with specific arguments: binary → set of dangerous flags/subcommands
"erase", _DANGEROUS_BINARY_FLAGS: dict[str, set[str]] = {
"> /dev/", "rm": {"-rf", "-fr", "-r", "-f"},
"shutdown", "kill": {"-9", "-kill"},
"reboot", "chmod": {"777", "000"},
"init 0", "git": {"push --force", "push -f", "reset --hard", "clean -f"},
"init 6", "pip": {"uninstall"},
"kill -9", "npm": {"uninstall"},
"killall", "docker": {"rm", "rmi", "system prune"},
"chmod 777", }
"chown",
"mv /", # Cross-token dangerous patterns (compiled regexes)
"pip uninstall", _DANGEROUS_ARG_PATTERNS: list[re.Pattern[str]] = [
"npm uninstall", re.compile(r">\s*/dev/", re.IGNORECASE),
"apt remove", re.compile(r">\s*/etc/", re.IGNORECASE),
"yum remove", re.compile(r"drop\s+table", re.IGNORECASE),
"brew uninstall", re.compile(r"drop\s+database", re.IGNORECASE),
"docker rm", re.compile(r"truncate\s+table", re.IGNORECASE),
"docker rmi", ]
"git push --force",
"git reset --hard",
"git clean -f",
"drop table",
"drop database",
"truncate",
)
_SHELL_OPERATORS = re.compile(r'[|;&]|&&|\|\||\$\(|\$\{|`|\$<|>|<|\n') _SHELL_OPERATORS = re.compile(r'[|;&]|&&|\|\||\$\(|\$\{|`|\$<|>|<|\n')
@ -300,10 +293,15 @@ class ShellTool(Tool):
timeout=timeout, timeout=timeout,
) )
except asyncio.TimeoutError: except asyncio.TimeoutError:
proc.kill() try:
proc.kill()
except ProcessLookupError:
logger.debug("Process already exited before kill()")
except OSError:
logger.debug("OSError killing process")
await proc.wait() await proc.wait()
output = f"命令执行超时({timeout}s" output = f"命令执行超时({timeout}s"
exit_code = -1 exit_code = proc.returncode if proc.returncode is not None else -1
else: else:
output = stdout.decode("utf-8", errors="replace") if stdout else "" output = stdout.decode("utf-8", errors="replace") if stdout else ""
exit_code = proc.returncode if proc.returncode is not None else 0 exit_code = proc.returncode if proc.returncode is not None else 0
@ -394,10 +392,24 @@ class ShellTool(Tool):
if binary.lower() == prefix_stripped: if binary.lower() == prefix_stripped:
return False return False
# Dangerous pattern check # Dangerous pattern check — token-based matching
binary_lower = binary.lower()
# 1. Binary is always dangerous regardless of flags
if binary_lower in _DANGEROUS_BINARIES:
return True
# 2. Binary is dangerous with specific flags/subcommands
if binary_lower in _DANGEROUS_BINARY_FLAGS:
cmd_str = " ".join(tokens).lower()
for flag_pattern in _DANGEROUS_BINARY_FLAGS[binary_lower]:
if flag_pattern in cmd_str:
return True
# 3. Cross-token dangerous patterns (regex)
command_lower = command_stripped.lower() command_lower = command_stripped.lower()
for pattern in _DANGEROUS_PATTERNS: for pattern in _DANGEROUS_ARG_PATTERNS:
if pattern in command_lower: if pattern.search(command_lower):
return True return True
return True # Unknown commands are dangerous by default return True # Unknown commands are dangerous by default
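The flag check above joins the tokens back into a string and tests `flag_pattern in cmd_str`, which is still a substring test: `"rm"` would match inside `docker run --rm`. One possible alternative, sketched under the assumption that each entry is stored as a tuple of tokens (`DANGEROUS_SEQUENCES` and `has_dangerous_sequence` are hypothetical names, not part of the commit), compares whole tokens instead:

```python
import shlex

# Hypothetical table with the same shape as _DANGEROUS_BINARY_FLAGS,
# but each entry is a tuple of whole tokens rather than a substring.
DANGEROUS_SEQUENCES = {
    "git": [("push", "--force"), ("push", "-f"), ("reset", "--hard")],
    "docker": [("rm",), ("rmi",), ("system", "prune")],
}


def has_dangerous_sequence(command: str) -> bool:
    """Return True if a dangerous token sequence follows a known binary."""
    try:
        tokens = [t.lower() for t in shlex.split(command)]
    except ValueError:
        return True  # unbalanced quotes: fail closed
    if not tokens:
        return False
    binary, args = tokens[0], tokens[1:]
    for seq in DANGEROUS_SEQUENCES.get(binary, []):
        n = len(seq)
        # Slide a window of n tokens across the arguments.
        if any(tuple(args[i:i + n]) == seq for i in range(len(args) - n + 1)):
            return True
    return False
```

The window only matches contiguous tokens, so `git push origin main --force` is not caught by this sketch. The substring version has the same gap.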


@ -1,8 +1,16 @@
"""Security utilities for URL validation.""" """Security utilities for URL validation."""
import ipaddress import ipaddress
import socket
from urllib.parse import urlparse from urllib.parse import urlparse
_blocked_hostnames = {
"localhost",
"metadata.google.internal",
"metadata.internal",
"metadata.azure.com",
}
def is_safe_url(url: str) -> bool: def is_safe_url(url: str) -> bool:
"""Check if URL is safe (not pointing to private/internal networks). """Check if URL is safe (not pointing to private/internal networks).
@ -20,13 +28,6 @@ def is_safe_url(url: str) -> bool:
hostname = parsed.hostname hostname = parsed.hostname
if not hostname: if not hostname:
return False return False
# Block known internal/metadata hostnames
_blocked_hostnames = {
"localhost",
"metadata.google.internal",
"metadata.internal",
"metadata.azure.com",
}
hostname_lower = hostname.lower() hostname_lower = hostname.lower()
if hostname_lower in _blocked_hostnames: if hostname_lower in _blocked_hostnames:
return False return False
@ -37,9 +38,15 @@ def is_safe_url(url: str) -> bool:
if _is_unsafe_ip(ip): if _is_unsafe_ip(ip):
return False return False
except ValueError: except ValueError:
# hostname is a domain, not a literal IP — DNS rebinding risk remains # hostname is a domain — resolve DNS and check IPs
# (would need DNS resolution to fully mitigate) try:
pass addr_infos = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
for family, type_, proto, canonname, sockaddr in addr_infos:
ip = ipaddress.ip_address(sockaddr[0])
if _is_unsafe_ip(ip):
return False
except (socket.gaierror, socket.timeout, OSError):
return False # fail-closed on DNS errors
return True return True
except Exception: except Exception:
return False return False
@ -58,3 +65,36 @@ def _is_unsafe_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
if mapped.is_private or mapped.is_loopback or mapped.is_reserved or mapped.is_link_local: if mapped.is_private or mapped.is_loopback or mapped.is_reserved or mapped.is_link_local:
return True return True
return False return False
async def is_safe_url_async(url: str) -> bool:
"""Async variant of is_safe_url with DNS resolution."""
try:
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
return False
hostname = parsed.hostname
if not hostname:
return False
hostname_lower = hostname.lower()
if hostname_lower in _blocked_hostnames:
return False
if hostname_lower.endswith((".internal", ".local", ".localhost")):
return False
try:
ip = ipaddress.ip_address(hostname)
if _is_unsafe_ip(ip):
return False
except ValueError:
try:
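import asyncio
# asyncio has no module-level getaddrinfo; use the running loop's method (family/type are keyword-only)
loop = asyncio.get_running_loop()
addr_infos = await loop.getaddrinfo(hostname, None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM)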
for family, type_, proto, canonname, sockaddr in addr_infos:
ip = ipaddress.ip_address(sockaddr[0])
if _is_unsafe_ip(ip):
return False
except (socket.gaierror, socket.timeout, OSError):
return False
return True
except Exception:
return False
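The async variant resolves the hostname and rejects the URL if any resolved address is internal. A standalone sketch of that check, assuming the event loop's `getaddrinfo` method for non-blocking resolution; the helper names are hypothetical, and the address test here is a simplified stand-in for `_is_unsafe_ip` that does not unwrap IPv4-mapped IPv6 addresses:

```python
import asyncio
import ipaddress
import socket


def _is_unsafe(ip_text: str) -> bool:
    """Return True for addresses a server-side fetcher should refuse."""
    # Drop an IPv6 zone id such as "%eth0" before parsing.
    ip = ipaddress.ip_address(ip_text.split("%", 1)[0])
    return (
        ip.is_private or ip.is_loopback or ip.is_link_local
        or ip.is_reserved or ip.is_multicast or ip.is_unspecified
    )


def all_addresses_safe(addr_infos) -> bool:
    """Check every resolved address; an empty result is treated as unsafe."""
    if not addr_infos:
        return False
    return not any(_is_unsafe(sockaddr[0]) for *_, sockaddr in addr_infos)


async def host_resolves_safely(hostname: str) -> bool:
    """Resolve without blocking the event loop; fail closed on DNS errors."""
    loop = asyncio.get_running_loop()
    try:
        infos = await loop.getaddrinfo(
            hostname, None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
        )
    except OSError:  # socket.gaierror is a subclass of OSError
        return False
    return all_addresses_safe(infos)
```

Checking every returned address, not just the first, matters because a hostname can resolve to a mix of public and internal records. The gap between this check and the later connection remains open unless the connection is pinned to the validated IP.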


@ -2,6 +2,8 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@ -119,7 +121,7 @@ class TestWorkflowStore:
store = WorkflowStore() store = WorkflowStore()
wf = WorkflowDefinition(workflow_id="test-1", name="Test") wf = WorkflowDefinition(workflow_id="test-1", name="Test")
store.save(wf) asyncio.run(store.save(wf))
result = store.get("test-1") result = store.get("test-1")
assert result is not None assert result is not None
assert result.name == "Test" assert result.name == "Test"
@ -131,36 +133,45 @@ class TestWorkflowStore:
def test_list(self): def test_list(self):
from agentkit.orchestrator.workflow_schema import WorkflowDefinition from agentkit.orchestrator.workflow_schema import WorkflowDefinition
store = WorkflowStore() async def _run():
for i in range(3): store = WorkflowStore()
store.save(WorkflowDefinition(workflow_id=f"wf-{i}", name=f"Workflow {i}")) for i in range(3):
summaries = store.list() await store.save(WorkflowDefinition(workflow_id=f"wf-{i}", name=f"Workflow {i}"))
assert len(summaries) == 3 summaries = store.list()
assert len(summaries) == 3
asyncio.run(_run())
def test_list_limit(self): def test_list_limit(self):
from agentkit.orchestrator.workflow_schema import WorkflowDefinition from agentkit.orchestrator.workflow_schema import WorkflowDefinition
store = WorkflowStore() async def _run():
for i in range(5): store = WorkflowStore()
store.save(WorkflowDefinition(workflow_id=f"wf-{i}", name=f"Workflow {i}")) for i in range(5):
summaries = store.list(limit=2) await store.save(WorkflowDefinition(workflow_id=f"wf-{i}", name=f"Workflow {i}"))
assert len(summaries) == 2 summaries = store.list(limit=2)
assert len(summaries) == 2
asyncio.run(_run())
def test_delete(self): def test_delete(self):
from agentkit.orchestrator.workflow_schema import WorkflowDefinition from agentkit.orchestrator.workflow_schema import WorkflowDefinition
store = WorkflowStore() async def _run():
store.save(WorkflowDefinition(workflow_id="del-1", name="Delete Me")) store = WorkflowStore()
assert store.delete("del-1") is True await store.save(WorkflowDefinition(workflow_id="del-1", name="Delete Me"))
assert store.get("del-1") is None assert await store.delete("del-1") is True
assert store.get("del-1") is None
asyncio.run(_run())
def test_delete_not_found(self): def test_delete_not_found(self):
store = WorkflowStore() store = WorkflowStore()
assert store.delete("nonexistent") is False assert asyncio.run(store.delete("nonexistent")) is False
def test_create_and_get_execution(self): def test_create_and_get_execution(self):
store = WorkflowStore() store = WorkflowStore()
execution = store.create_execution("wf-1") execution = asyncio.run(store.create_execution("wf-1"))
assert execution.workflow_id == "wf-1" assert execution.workflow_id == "wf-1"
assert execution.status == "pending" assert execution.status == "pending"
@ -173,18 +184,51 @@ class TestWorkflowStore:
assert store.get_execution("nonexistent") is None assert store.get_execution("nonexistent") is None
def test_update_execution(self): def test_update_execution(self):
store = WorkflowStore() async def _run():
execution = store.create_execution("wf-1") store = WorkflowStore()
updated = store.update_execution( execution = await store.create_execution("wf-1")
execution.execution_id, status="running", current_stage="step-1" updated = await store.update_execution(
) execution.execution_id, status="running", current_stage="step-1"
assert updated.status == "running" )
assert updated.current_stage == "step-1" assert updated.status == "running"
assert updated.current_stage == "step-1"
asyncio.run(_run())
def test_update_execution_not_found(self): def test_update_execution_not_found(self):
store = WorkflowStore() store = WorkflowStore()
with pytest.raises(KeyError): with pytest.raises(KeyError):
store.update_execution("nonexistent", status="running") asyncio.run(store.update_execution("nonexistent", status="running"))
def test_running_tasks_initialized(self):
store = WorkflowStore()
assert hasattr(store, "_running_tasks")
assert isinstance(store._running_tasks, dict)
assert len(store._running_tasks) == 0
def test_execution_locks_initialized(self):
store = WorkflowStore()
assert hasattr(store, "_execution_locks")
assert isinstance(store._execution_locks, dict)
def test_evict_execution_cleans_approval_events(self):
async def _run():
store = WorkflowStore(max_executions=2)
e1 = await store.create_execution("wf-1")
e2 = await store.create_execution("wf-2")
# Add an approval event for e1
event_key = f"{e1.execution_id}:stage1"
event = asyncio.Event()
store._approval_events[event_key] = event
# Create a third execution to trigger eviction of e1
e3 = await store.create_execution("wf-3")
# e1 should be evicted, and its approval event should be cleaned up
assert store.get_execution(e1.execution_id) is None
assert event_key not in store._approval_events
# The event should have been set (to wake any waiting coroutine)
assert event.is_set()
asyncio.run(_run())
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------