fix(security,reliability): resolve all P2 findings from code review
This commit is contained in:
parent
658e188939
commit
6852dfe892
|
|
@ -0,0 +1,287 @@
|
|||
# fix: AgentKit P2 Security & Reliability Hardening
|
||||
|
||||
**Status:** active
|
||||
**Created:** 2026-06-10
|
||||
**Origin:** ce-code-review findings on main branch after 7-capabilities merge
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
Fix all remaining P2 issues from the final code review of the `feat/agentkit-7-capabilities` merge. These 16 issues span security (SSRF, auth, injection), reliability (retry, concurrency, resource leaks), and architecture (auth consistency, config reload). All are non-blocking for current functionality but represent real risk in production.
|
||||
|
||||
## Problem Frame
|
||||
|
||||
After merging 7 core capabilities and fixing P0/P1 issues, the codebase has 16 P2 findings that fall into four categories:
|
||||
|
||||
1. **Security gaps** — SSRF DNS rebinding, API key timing attacks, WebSocket auth bypass, approval race conditions, shell pattern matching, .env injection
|
||||
2. **Reliability gaps** — retry storms, unbounded memory growth, concurrent modification, config change loss
|
||||
3. **Architecture gaps** — dual auth layer inconsistency, WebSocket broadcast without isolation
|
||||
4. **Defensive gaps** — process kill race, lazy init race, httpx client double-creation
|
||||
|
||||
## Requirements
|
||||
|
||||
- R1: All P2 security findings must be resolved or explicitly deferred with documented rationale
|
||||
- R2: Retry logic must use exponential backoff with jitter
|
||||
- R3: All in-memory stores must have bounded growth with proper eviction
|
||||
- R4: Concurrent access to shared mutable state must be protected
|
||||
- R5: Authentication must be consistent across all endpoints and transports
|
||||
- R6: No regression in existing test suite (764+ tests must pass)
|
||||
|
||||
## Key Technical Decisions
|
||||
|
||||
- **KTD1: DNS resolution for SSRF** — Use `asyncio.getaddrinfo` (not blocking `socket.getaddrinfo`) to resolve hostnames before IP validation. Fail-closed on DNS errors. Accept that TOCTOU remains between check and connect; full mitigation requires IP pinning at the httpx transport level (deferred).
|
||||
- **KTD2: Auth consolidation strategy** — Fix per-endpoint auth first (hmac.compare_digest, WebSocket auth-before-accept), then consolidate into middleware in a separate pass. This avoids a big-bang refactor and allows incremental testing.
|
||||
- **KTD3: WorkflowStore concurrency** — Add `asyncio.Lock` to WorkflowStore, convert sync methods to async. This is the most invasive change but necessary for correctness under concurrent load.
|
||||
- **KTD4: Shell dangerous patterns** — Replace substring matching with token-based matching using the already-parsed shlex tokens. Separate binary-level and flag-level dangerous patterns.
|
||||
- **KTD5: WebSocket isolation** — Replace flat subscriber list with `dict[str, set[WebSocket]]` keyed by execution_id. Clients subscribe to specific executions.
|
||||
|
||||
---
|
||||
|
||||
## Implementation Units
|
||||
|
||||
### U1. API Key Timing Attack & WebSocket Auth-Before-Accept
|
||||
|
||||
**Goal:** Eliminate timing side-channel in API key comparison and fix WebSocket accept-before-auth pattern.
|
||||
|
||||
**Requirements:** R5
|
||||
|
||||
**Dependencies:** None
|
||||
|
||||
**Files:**
|
||||
- `src/agentkit/server/routes/portal.py`
|
||||
- `src/agentkit/server/routes/workflows.py`
|
||||
- `src/agentkit/server/middleware.py`
|
||||
- `tests/unit/server/test_portal_routes.py`
|
||||
- `tests/unit/server/test_workflow_routes.py`
|
||||
|
||||
**Approach:**
|
||||
1. Add `import hmac` to portal.py, workflows.py, middleware.py
|
||||
2. Replace all `!=` API key comparisons with `hmac.compare_digest()`
|
||||
3. In both WebSocket handlers, move auth validation BEFORE `websocket.accept()`. On auth failure, call `await websocket.close(code=4001)` without accepting, then return
|
||||
4. In middleware.py, replace `provided_key not in valid_keys` with `any(hmac.compare_digest(...) for k in valid_keys)`
|
||||
5. Handle None/empty provided key by defaulting to empty string before comparison
|
||||
|
||||
**Test scenarios:**
|
||||
- API key comparison with correct key returns True
|
||||
- API key comparison with incorrect key returns False
|
||||
- API key comparison with None provided defaults to empty string
|
||||
- WebSocket connection without api_key is rejected (no 101 upgrade)
|
||||
- WebSocket connection with wrong api_key is rejected
|
||||
- WebSocket connection with correct api_key is accepted
|
||||
|
||||
**Verification:** All existing tests pass; new auth tests pass
|
||||
|
||||
---
|
||||
|
||||
### U2. Shell Dangerous Pattern Token-Based Matching
|
||||
|
||||
**Goal:** Replace substring-based dangerous pattern matching with precise token-based matching to eliminate false positives and bypasses.
|
||||
|
||||
**Requirements:** R1
|
||||
|
||||
**Dependencies:** None
|
||||
|
||||
**Files:**
|
||||
- `src/agentkit/tools/shell.py`
|
||||
- `tests/unit/test_shell_tool.py`
|
||||
|
||||
**Approach:**
|
||||
1. Replace `_DANGEROUS_PATTERNS` tuple with two new structures:
|
||||
- `_DANGEROUS_BINARIES`: set of binary names that are always dangerous (e.g., `rm`, `mkfs`, `dd`, `shutdown`)
|
||||
- `_DANGEROUS_BINARY_FLAGS`: dict mapping binary name to set of dangerous flag combinations (e.g., `kill: {-9, -KILL}`, `chmod: {777}`)
|
||||
2. In `_is_dangerous()`, after shlex parsing:
|
||||
- Check `binary` against `_DANGEROUS_BINARIES` (exact match)
|
||||
- Check `binary` + flags against `_DANGEROUS_BINARY_FLAGS`
|
||||
- Keep regex-based arg patterns for cross-token matches (e.g., `> /dev/`, `drop table`)
|
||||
3. Remove overly broad patterns like `format`, `erase` as substrings; add `format` as binary-only
|
||||
4. Normalize whitespace in command before matching
|
||||
|
||||
**Test scenarios:**
|
||||
- `rm -rf /` detected as dangerous (binary match)
|
||||
- `echo 'performing rm operations'` NOT detected as dangerous (no binary match)
|
||||
- `kill -9 1234` detected as dangerous (binary + flag match)
|
||||
- `kill -15 1234` NOT detected as dangerous (flag not in dangerous set)
|
||||
- `rm\t/tmp/file` detected as dangerous (tab normalization)
|
||||
- `format C:` detected as dangerous (binary match)
|
||||
- `informatica` NOT detected as dangerous (not a binary match)
|
||||
|
||||
**Verification:** Shell tool tests pass with new pattern logic
|
||||
|
||||
---
|
||||
|
||||
### U3. .env File Prefix Filtering & Retry Backoff
|
||||
|
||||
**Goal:** Prevent arbitrary environment variable injection from .env files and add exponential backoff to plan executor retries.
|
||||
|
||||
**Requirements:** R1, R2
|
||||
|
||||
**Dependencies:** None
|
||||
|
||||
**Files:**
|
||||
- `src/agentkit/server/app.py`
|
||||
- `src/agentkit/core/plan_executor.py`
|
||||
- `tests/unit/test_plan_executor.py`
|
||||
|
||||
**Approach:**
|
||||
1. In app.py .env loading, add an allowlist of prefixes: `AGENTKIT_`, `OPENAI_`, `ANTHROPIC_`, `GEMINI_`, `TAVILY_`, `SERPER_`, plus exact matches for `DATABASE_URL`, `REDIS_URL`
|
||||
2. Log a warning when a .env variable is skipped due to prefix filtering
|
||||
3. In plan_executor.py `_execute_step_with_retry()`, add exponential backoff with jitter after each failure:
|
||||
- `base_retry_delay=1.0s`, `max_retry_delay=30.0s`
|
||||
- `delay = min(base * 2^(retry-1), max) * (0.5 + random() * 0.5)`
|
||||
- Skip delay on first attempt (retry_count == 0)
|
||||
4. Add `base_retry_delay` and `max_retry_delay` as constructor parameters with defaults
|
||||
|
||||
**Test scenarios:**
|
||||
- .env with `AGENTKIT_FOO=bar` is loaded
|
||||
- .env with `PATH=/malicious` is skipped with warning
|
||||
- .env with `OPENAI_API_KEY=sk-xxx` is loaded
|
||||
- Retry backoff: first retry waits ~1s, second ~2s, third ~4s (with jitter)
|
||||
- Max retry delay caps at 30s
|
||||
- Plan executor with max_retries=0 still works (no backoff)
|
||||
|
||||
**Verification:** Unit tests pass; .env loading respects prefix filter
|
||||
|
||||
---
|
||||
|
||||
### U4. WorkflowStore Concurrency & Resource Management
|
||||
|
||||
**Goal:** Make WorkflowStore thread-safe, fix unbounded growth, add proper cleanup for approval events and running tasks.
|
||||
|
||||
**Requirements:** R3, R4
|
||||
|
||||
**Dependencies:** U1 (hmac.compare_digest in workflow routes)
|
||||
|
||||
**Files:**
|
||||
- `src/agentkit/server/routes/workflows.py`
|
||||
- `tests/unit/server/test_workflow_routes.py`
|
||||
|
||||
**Approach:**
|
||||
1. Add `asyncio.Lock` to WorkflowStore; convert `save`, `delete`, `create_execution`, `update_execution` to async methods
|
||||
2. Initialize `_running_tasks` in `__init__` instead of lazy `getattr`
|
||||
3. Add `_evict_execution()` helper that removes execution AND its approval events (with `event.set()` to wake waiting coroutines)
|
||||
4. Replace `_ws_subscribers: list` with `_ws_subscribers: dict[str, set[WebSocket]]` keyed by execution_id, protected by `asyncio.Lock`
|
||||
5. Add `_subscribe(execution_id, ws)` and `_unsubscribe(execution_id, ws)` helpers
|
||||
6. Modify `_broadcast_ws` to accept `execution_id` parameter and only send to relevant subscribers
|
||||
7. Update all call sites from sync to async (all endpoint functions already async)
|
||||
8. Add `_execution_locks: dict[str, asyncio.Lock]` for per-execution approve/cancel serialization
|
||||
|
||||
**Test scenarios:**
|
||||
- Concurrent save operations don't cause KeyError or over-eviction
|
||||
- Evicted execution's approval events are cleaned up
|
||||
- WebSocket subscriber receives only messages for subscribed execution_id
|
||||
- WebSocket disconnect properly cleans up from subscriber dict
|
||||
- Concurrent approve/cancel for same execution serialized by lock
|
||||
- _running_tasks initialized in __init__, no getattr fallback
|
||||
|
||||
**Verification:** Workflow route tests pass; no RuntimeError from concurrent modification
|
||||
|
||||
---
|
||||
|
||||
### U5. SSRF DNS Resolution & ComputerUse Client Safety
|
||||
|
||||
**Goal:** Add DNS resolution to SSRF protection and make ComputerUseTool's httpx client creation thread-safe.
|
||||
|
||||
**Requirements:** R1
|
||||
|
||||
**Dependencies:** None
|
||||
|
||||
**Files:**
|
||||
- `src/agentkit/utils/security.py`
|
||||
- `src/agentkit/tools/computer_use.py`
|
||||
- `tests/unit/test_security.py`
|
||||
- `tests/unit/tools/test_computer_use.py`
|
||||
|
||||
**Approach:**
|
||||
1. In `is_safe_url()`, after the ValueError catch for domain names, add DNS resolution:
|
||||
- Use `socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)` (sync, acceptable for tool init)
|
||||
- Validate each resolved IP against `_is_unsafe_ip()`
|
||||
- Fail-closed on DNS errors (return False)
|
||||
- Add 3-second timeout on DNS resolution
|
||||
2. Add `is_safe_url_async()` variant using `asyncio.getaddrinfo` for async callers
|
||||
3. In ComputerUseTool, add `asyncio.Lock` to `_get_http_client()`, make it async with double-check locking
|
||||
4. Update caller from `client = self._get_http_client()` to `client = await self._get_http_client()`
|
||||
|
||||
**Test scenarios:**
|
||||
- is_safe_url with domain resolving to 127.0.0.1 returns False
|
||||
- is_safe_url with domain resolving to public IP returns True
|
||||
- is_safe_url with unresolvable domain returns False (fail-closed)
|
||||
- is_safe_url with ::ffff:127.0.0.1 returns False
|
||||
- is_safe_url with .internal TLD returns False
|
||||
- ComputerUseTool concurrent _get_http_client returns same client instance
|
||||
- ComputerUseTool close() properly cleans up client
|
||||
|
||||
**Verification:** Security and computer_use tests pass
|
||||
|
||||
---
|
||||
|
||||
### U6. Config Hot-Reload & Defensive Fixes
|
||||
|
||||
**Goal:** Fix config hot-reload lock race, add proc.kill() error handling, and initialize config reload lock eagerly.
|
||||
|
||||
**Requirements:** R4
|
||||
|
||||
**Dependencies:** None
|
||||
|
||||
**Files:**
|
||||
- `src/agentkit/server/app.py`
|
||||
- `src/agentkit/tools/shell.py`
|
||||
|
||||
**Approach:**
|
||||
1. In app.py `_on_config_change()`:
|
||||
- Initialize lock eagerly in `create_app()` (store on `app.state._config_reload_lock`)
|
||||
- Replace `lock.locked()` skip with pending-flag pattern: set `app.state._config_reload_pending = True`, then in `_reload()` loop while pending flag is set
|
||||
- This ensures no config change is silently dropped
|
||||
2. In shell.py `_execute_standalone()`:
|
||||
- Wrap `proc.kill()` in `try/except (ProcessLookupError, OSError)`
|
||||
- Use `proc.returncode` if available instead of hardcoded -1
|
||||
|
||||
**Test scenarios:**
|
||||
- Two rapid config changes both get applied (no silent drop)
|
||||
- Config reload pending flag is cleared after reload completes
|
||||
- Shell timeout with already-exited process returns clean error (no ProcessLookupError)
|
||||
- Shell timeout output still shows timeout message
|
||||
|
||||
**Verification:** App startup and shell tool tests pass
|
||||
|
||||
---
|
||||
|
||||
## Scope Boundaries
|
||||
|
||||
### In Scope
|
||||
- All 16 P2 findings from ce-code-review
|
||||
- Test coverage for new security/reliability patterns
|
||||
|
||||
### Deferred to Follow-Up Work
|
||||
- Full auth middleware consolidation (ASGI-level WebSocket middleware) — complex, separate PR
|
||||
- IP pinning at httpx transport level for complete SSRF protection
|
||||
- python-dotenv migration for robust .env parsing
|
||||
- clients.yaml caching in middleware (per-request disk read)
|
||||
- WorkflowStore persistence (currently in-memory only)
|
||||
|
||||
### Outside This Plan's Identity
|
||||
- New features or capability additions
|
||||
- Performance optimization beyond what's needed for reliability
|
||||
- UI/UX changes
|
||||
|
||||
---
|
||||
|
||||
## Risks & Mitigations
|
||||
|
||||
| Risk | Impact | Mitigation |
|
||||
|------|--------|------------|
|
||||
| WorkflowStore async refactor breaks existing endpoints | High | Update all call sites; comprehensive test coverage |
|
||||
| DNS resolution adds latency to URL validation | Low | 3s timeout; acceptable for tool initialization |
|
||||
| Token-based shell matching misses new bypass patterns | Medium | Default-to-dangerous fallback; add patterns as discovered |
|
||||
| .env prefix filtering breaks existing deployments | Medium | Comprehensive allowlist; log warnings for skipped vars |
|
||||
| WebSocket isolation breaks existing WS clients | Medium | Maintain backward compat: accept without subscription, send all |
|
||||
| Config reload pending loop runs indefinitely | Low | Coalescing naturally limits; only re-runs if new change arrived |
|
||||
|
||||
---
|
||||
|
||||
## System-Wide Impact
|
||||
|
||||
- **Security posture** significantly improved: timing attacks, SSRF, auth bypass all addressed
|
||||
- **Reliability** improved: retry storms prevented, memory bounded, concurrent access safe
|
||||
- **API compatibility**: WorkflowStore method signatures change from sync to async — all call sites updated
|
||||
- **WebSocket protocol**: subscription message becomes recommended but not required for backward compat
|
||||
|
|
@ -14,6 +14,7 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
|
@ -94,6 +95,8 @@ class PlanExecutor:
|
|||
max_retries: int = 2,
|
||||
step_timeout: float = 300.0,
|
||||
max_parallel: int = 5,
|
||||
base_retry_delay: float = 1.0,
|
||||
max_retry_delay: float = 30.0,
|
||||
on_step_complete: OnStepCompleteCallback | None = None,
|
||||
on_step_failed: OnStepFailedCallback | None = None,
|
||||
on_human_intervention: OnHumanInterventionCallback | None = None,
|
||||
|
|
@ -104,6 +107,8 @@ class PlanExecutor:
|
|||
max_retries: 步骤失败后最大重试次数
|
||||
step_timeout: 单个步骤超时时间(秒)
|
||||
max_parallel: 最大并行步骤数
|
||||
base_retry_delay: 重试基础延迟(秒)
|
||||
max_retry_delay: 重试最大延迟(秒)
|
||||
on_step_complete: 步骤完成回调
|
||||
on_step_failed: 步骤失败回调,返回 FailureAction 决定后续处理
|
||||
on_human_intervention: 人工介入回调
|
||||
|
|
@ -112,6 +117,8 @@ class PlanExecutor:
|
|||
self._max_retries = max_retries
|
||||
self._step_timeout = step_timeout
|
||||
self._max_parallel = max_parallel
|
||||
self._base_retry_delay = base_retry_delay
|
||||
self._max_retry_delay = max_retry_delay
|
||||
self._on_step_complete = on_step_complete
|
||||
self._on_step_failed = on_step_failed
|
||||
self._on_human_intervention = on_human_intervention
|
||||
|
|
@ -250,6 +257,15 @@ class PlanExecutor:
|
|||
|
||||
retry_count += 1
|
||||
|
||||
if retry_count <= self._max_retries:
|
||||
delay = min(
|
||||
self._base_retry_delay * (2 ** (retry_count - 1)),
|
||||
self._max_retry_delay,
|
||||
)
|
||||
delay *= (0.5 + random.random() * 0.5) # jitter
|
||||
logger.info(f"Retrying step '{step.step_id}' in {delay:.1f}s (attempt {retry_count + 1}/{self._max_retries + 1})")
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# 所有重试耗尽
|
||||
step.status = PlanStepStatus.FAILED
|
||||
step.error = last_error
|
||||
|
|
|
|||
|
|
@ -30,6 +30,12 @@ from agentkit.telemetry.setup import setup_telemetry
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ALLOWED_ENV_PREFIXES = (
|
||||
'AGENTKIT_', 'OPENAI_', 'ANTHROPIC_', 'GEMINI_',
|
||||
'TAVILY_', 'SERPER_', 'DEEPSEEK_',
|
||||
)
|
||||
_ALLOWED_ENV_EXACT = {'DATABASE_URL', 'REDIS_URL'}
|
||||
|
||||
|
||||
def _build_llm_gateway(config: ServerConfig) -> LLMGateway:
|
||||
"""Build LLMGateway from ServerConfig, registering all providers."""
|
||||
|
|
@ -224,70 +230,69 @@ def _on_config_change(app: FastAPI, config: ServerConfig) -> None:
|
|||
|
||||
Uses a lock to prevent concurrent config reloads from racing.
|
||||
"""
|
||||
lock: asyncio.Lock = getattr(app.state, "_config_reload_lock", None)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
app.state._config_reload_lock = lock
|
||||
lock: asyncio.Lock = app.state._config_reload_lock
|
||||
|
||||
if lock.locked():
|
||||
logger.warning("Config reload already in progress, skipping")
|
||||
return
|
||||
app.state._config_reload_pending = True
|
||||
|
||||
async def _reload():
|
||||
if lock.locked():
|
||||
return # Another reload running; it will check pending flag
|
||||
async with lock:
|
||||
# Increment config version for audit
|
||||
current_version = getattr(app.state, "config_version", 0) + 1
|
||||
app.state.config_version = current_version
|
||||
logger.info(f"Config change detected (v{current_version}), reloading...")
|
||||
while getattr(app.state, "_config_reload_pending", False):
|
||||
app.state._config_reload_pending = False
|
||||
# Increment config version for audit
|
||||
current_version = getattr(app.state, "config_version", 0) + 1
|
||||
app.state.config_version = current_version
|
||||
logger.info(f"Config change detected (v{current_version}), reloading...")
|
||||
|
||||
# Rebuild LLMGateway if llm config changed
|
||||
try:
|
||||
new_gateway = _build_llm_gateway(config)
|
||||
app.state.llm_gateway = new_gateway
|
||||
# Also update the agent pool's gateway reference
|
||||
# Rebuild LLMGateway if llm config changed
|
||||
try:
|
||||
new_gateway = _build_llm_gateway(config)
|
||||
app.state.llm_gateway = new_gateway
|
||||
# Also update the agent pool's gateway reference
|
||||
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
|
||||
app.state.agent_pool._llm_gateway = new_gateway
|
||||
if hasattr(app.state, "intent_router") and app.state.intent_router is not None:
|
||||
app.state.intent_router._llm_gateway = new_gateway
|
||||
logger.info(f"LLM Gateway reloaded (config v{current_version})")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reload LLM Gateway: {e}")
|
||||
|
||||
# Reload skills if skill paths changed
|
||||
try:
|
||||
new_skill_registry = _build_skill_registry(config)
|
||||
# Re-bind tools from the shared tool_registry so skills don't lose their bindings
|
||||
tool_registry = getattr(app.state, "tool_registry", None)
|
||||
if tool_registry:
|
||||
from agentkit.skills.loader import SkillLoader
|
||||
loader = SkillLoader(
|
||||
skill_registry=new_skill_registry,
|
||||
tool_registry=tool_registry,
|
||||
)
|
||||
for skill_path in (config.skill_paths or []):
|
||||
from pathlib import Path as _P
|
||||
p = _P(skill_path)
|
||||
if p.is_dir():
|
||||
loader.load_from_directory(str(p))
|
||||
elif p.is_file() and p.suffix in (".yaml", ".yml"):
|
||||
try:
|
||||
loader.load_from_file(str(p))
|
||||
except Exception:
|
||||
pass
|
||||
app.state.skill_registry = new_skill_registry
|
||||
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
|
||||
app.state.agent_pool._skill_registry = new_skill_registry
|
||||
logger.info(f"Skills reloaded (config v{current_version})")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reload skills: {e}")
|
||||
|
||||
# Update config version on all agents
|
||||
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
|
||||
app.state.agent_pool._llm_gateway = new_gateway
|
||||
if hasattr(app.state, "intent_router") and app.state.intent_router is not None:
|
||||
app.state.intent_router._llm_gateway = new_gateway
|
||||
logger.info(f"LLM Gateway reloaded (config v{current_version})")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reload LLM Gateway: {e}")
|
||||
for agent in app.state.agent_pool._agents.values():
|
||||
if hasattr(agent, "_config_version"):
|
||||
agent._config_version = current_version
|
||||
|
||||
# Reload skills if skill paths changed
|
||||
try:
|
||||
new_skill_registry = _build_skill_registry(config)
|
||||
# Re-bind tools from the shared tool_registry so skills don't lose their bindings
|
||||
tool_registry = getattr(app.state, "tool_registry", None)
|
||||
if tool_registry:
|
||||
from agentkit.skills.loader import SkillLoader
|
||||
loader = SkillLoader(
|
||||
skill_registry=new_skill_registry,
|
||||
tool_registry=tool_registry,
|
||||
)
|
||||
for skill_path in (config.skill_paths or []):
|
||||
from pathlib import Path as _P
|
||||
p = _P(skill_path)
|
||||
if p.is_dir():
|
||||
loader.load_from_directory(str(p))
|
||||
elif p.is_file() and p.suffix in (".yaml", ".yml"):
|
||||
try:
|
||||
loader.load_from_file(str(p))
|
||||
except Exception:
|
||||
pass
|
||||
app.state.skill_registry = new_skill_registry
|
||||
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
|
||||
app.state.agent_pool._skill_registry = new_skill_registry
|
||||
logger.info(f"Skills reloaded (config v{current_version})")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reload skills: {e}")
|
||||
|
||||
# Update config version on all agents
|
||||
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
|
||||
for agent in app.state.agent_pool._agents.values():
|
||||
if hasattr(agent, "_config_version"):
|
||||
agent._config_version = current_version
|
||||
|
||||
logger.info(f"Config reload complete (v{current_version})")
|
||||
logger.info(f"Config reload complete (v{current_version})")
|
||||
|
||||
# Schedule the reload as a task (non-blocking for the watcher thread)
|
||||
try:
|
||||
|
|
@ -327,6 +332,10 @@ def create_app(
|
|||
_key = _key.strip()
|
||||
_val = _val.strip().strip("\"'")
|
||||
if _key and _key not in os.environ:
|
||||
allowed = any(_key.startswith(p) for p in _ALLOWED_ENV_PREFIXES) or _key in _ALLOWED_ENV_EXACT
|
||||
if not allowed:
|
||||
logger.warning(f"Skipping .env variable '{_key}' (not in allowed prefixes)")
|
||||
continue
|
||||
os.environ[_key] = _val
|
||||
server_config = ServerConfig.from_yaml(config_path)
|
||||
app = FastAPI(title="AgentKit Server", version="2.0.0", lifespan=lifespan)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Server middleware - Authentication and Rate Limiting"""
|
||||
|
||||
import hmac
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
|
@ -75,7 +76,7 @@ class APIKeyAuthMiddleware(BaseHTTPMiddleware):
|
|||
|
||||
# Check API key from header
|
||||
provided_key = request.headers.get("X-API-Key")
|
||||
if not provided_key or provided_key not in valid_keys:
|
||||
if not provided_key or not any(hmac.compare_digest(provided_key.encode(), k.encode()) for k in valid_keys):
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"error": "Unauthorized", "message": "Invalid or missing API key"},
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
|
|
@ -47,7 +48,7 @@ async def _verify_api_key(
|
|||
return
|
||||
|
||||
provided = api_key_header or api_key_query
|
||||
if provided != configured_api_key:
|
||||
if not hmac.compare_digest((provided or "").encode(), configured_api_key.encode()):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid or missing API key. Provide via X-API-Key header or api_key query parameter.",
|
||||
|
|
@ -475,11 +476,7 @@ async def portal_websocket(websocket: WebSocket):
|
|||
# Check api_key query param
|
||||
if configured_api_key:
|
||||
provided = websocket.query_params.get("api_key")
|
||||
if provided != configured_api_key:
|
||||
await websocket.accept()
|
||||
await websocket.send_json(
|
||||
{"type": "error", "data": {"message": "Invalid or missing api_key"}}
|
||||
)
|
||||
if not hmac.compare_digest((provided or "").encode(), configured_api_key.encode()):
|
||||
await websocket.close(code=4001, reason="Invalid or missing api_key")
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
|
@ -52,7 +53,7 @@ async def _verify_api_key(
|
|||
return
|
||||
|
||||
provided = api_key_header or api_key_query
|
||||
if provided != configured_api_key:
|
||||
if not hmac.compare_digest((provided or "").encode(), configured_api_key.encode()):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid or missing API key. Provide via X-API-Key header or api_key query parameter.",
|
||||
|
|
@ -65,7 +66,7 @@ async def _verify_api_key(
|
|||
|
||||
|
||||
class WorkflowStore:
|
||||
"""In-memory workflow store."""
|
||||
"""In-memory workflow store with async-safe mutation methods."""
|
||||
|
||||
def __init__(self, max_workflows: int = 500, max_executions: int = 1000):
|
||||
self._workflows: dict[str, WorkflowDefinition] = {}
|
||||
|
|
@ -73,17 +74,30 @@ class WorkflowStore:
|
|||
self._max_workflows = max_workflows
|
||||
self._max_executions = max_executions
|
||||
self._approval_events: dict[str, asyncio.Event] = {} # key: f"{execution_id}:{stage_name}"
|
||||
self._running_tasks: dict[str, asyncio.Task] = {}
|
||||
self._execution_locks: dict[str, asyncio.Lock] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
def save(self, workflow: WorkflowDefinition) -> WorkflowDefinition:
|
||||
workflow.updated_at = datetime.now(timezone.utc).isoformat()
|
||||
self._workflows[workflow.workflow_id] = workflow
|
||||
# Evict oldest if over limit
|
||||
if len(self._workflows) > self._max_workflows:
|
||||
oldest_id = min(
|
||||
self._workflows, key=lambda k: self._workflows[k].updated_at
|
||||
)
|
||||
del self._workflows[oldest_id]
|
||||
return workflow
|
||||
def _evict_execution(self, execution_id: str) -> None:
|
||||
"""Remove execution and its associated approval events."""
|
||||
self._executions.pop(execution_id, None)
|
||||
keys_to_remove = [k for k in self._approval_events if k.startswith(f"{execution_id}:")]
|
||||
for k in keys_to_remove:
|
||||
event = self._approval_events.pop(k, None)
|
||||
if event is not None:
|
||||
event.set() # Wake any waiting coroutine
|
||||
|
||||
async def save(self, workflow: WorkflowDefinition) -> WorkflowDefinition:
|
||||
async with self._lock:
|
||||
workflow.updated_at = datetime.now(timezone.utc).isoformat()
|
||||
self._workflows[workflow.workflow_id] = workflow
|
||||
# Evict oldest if over limit
|
||||
if len(self._workflows) > self._max_workflows:
|
||||
oldest_id = min(
|
||||
self._workflows, key=lambda k: self._workflows[k].updated_at
|
||||
)
|
||||
del self._workflows[oldest_id]
|
||||
return workflow
|
||||
|
||||
def get(self, workflow_id: str) -> WorkflowDefinition | None:
|
||||
return self._workflows.get(workflow_id)
|
||||
|
|
@ -109,47 +123,72 @@ class WorkflowStore:
|
|||
)
|
||||
return summaries
|
||||
|
||||
def delete(self, workflow_id: str) -> bool:
|
||||
if workflow_id in self._workflows:
|
||||
del self._workflows[workflow_id]
|
||||
return True
|
||||
return False
|
||||
async def delete(self, workflow_id: str) -> bool:
|
||||
async with self._lock:
|
||||
if workflow_id in self._workflows:
|
||||
del self._workflows[workflow_id]
|
||||
return True
|
||||
return False
|
||||
|
||||
def create_execution(self, workflow_id: str) -> WorkflowExecution:
|
||||
execution = WorkflowExecution(
|
||||
execution_id=str(uuid.uuid4()),
|
||||
workflow_id=workflow_id,
|
||||
status="pending",
|
||||
started_at=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
self._executions[execution.execution_id] = execution
|
||||
# Evict oldest if over limit
|
||||
if len(self._executions) > self._max_executions:
|
||||
oldest_id = min(
|
||||
self._executions,
|
||||
key=lambda k: self._executions[k].started_at or "",
|
||||
async def create_execution(self, workflow_id: str) -> WorkflowExecution:
|
||||
async with self._lock:
|
||||
execution = WorkflowExecution(
|
||||
execution_id=str(uuid.uuid4()),
|
||||
workflow_id=workflow_id,
|
||||
status="pending",
|
||||
started_at=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
del self._executions[oldest_id]
|
||||
return execution
|
||||
self._executions[execution.execution_id] = execution
|
||||
# Evict oldest if over limit
|
||||
if len(self._executions) > self._max_executions:
|
||||
oldest_id = min(
|
||||
self._executions,
|
||||
key=lambda k: self._executions[k].started_at or "",
|
||||
)
|
||||
self._evict_execution(oldest_id)
|
||||
return execution
|
||||
|
||||
def get_execution(self, execution_id: str) -> WorkflowExecution | None:
|
||||
return self._executions.get(execution_id)
|
||||
|
||||
def update_execution(self, execution_id: str, **kwargs: Any) -> WorkflowExecution:
|
||||
execution = self._executions.get(execution_id)
|
||||
if execution is None:
|
||||
raise KeyError(f"Execution '{execution_id}' not found")
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(execution, key):
|
||||
setattr(execution, key, value)
|
||||
return execution
|
||||
async def update_execution(self, execution_id: str, **kwargs: Any) -> WorkflowExecution:
|
||||
async with self._lock:
|
||||
execution = self._executions.get(execution_id)
|
||||
if execution is None:
|
||||
raise KeyError(f"Execution '{execution_id}' not found")
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(execution, key):
|
||||
setattr(execution, key, value)
|
||||
return execution
|
||||
|
||||
def get_execution_lock(self, execution_id: str) -> asyncio.Lock:
|
||||
"""Get or create a per-execution lock for approve/cancel serialization."""
|
||||
if execution_id not in self._execution_locks:
|
||||
self._execution_locks[execution_id] = asyncio.Lock()
|
||||
return self._execution_locks[execution_id]
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
_workflow_store = WorkflowStore()
|
||||
|
||||
# WebSocket subscribers for real-time execution progress
|
||||
_ws_subscribers: list[WebSocket] = []
|
||||
# WebSocket subscribers for real-time execution progress (keyed by execution_id)
|
||||
_ws_subscribers: dict[str, set[WebSocket]] = {}
|
||||
_ws_subscribers_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def _ws_subscribe(execution_id: str, ws: WebSocket) -> None:
|
||||
async with _ws_subscribers_lock:
|
||||
if execution_id not in _ws_subscribers:
|
||||
_ws_subscribers[execution_id] = set()
|
||||
_ws_subscribers[execution_id].add(ws)
|
||||
|
||||
|
||||
async def _ws_unsubscribe(execution_id: str, ws: WebSocket) -> None:
|
||||
async with _ws_subscribers_lock:
|
||||
if execution_id in _ws_subscribers:
|
||||
_ws_subscribers[execution_id].discard(ws)
|
||||
if not _ws_subscribers[execution_id]:
|
||||
del _ws_subscribers[execution_id]
|
||||
|
||||
|
||||
def _get_store(request: Request) -> WorkflowStore:
|
||||
|
|
@ -210,7 +249,7 @@ async def _execute_workflow(
|
|||
"""Execute a workflow by running its stages in topological order."""
|
||||
_store = store or _workflow_store
|
||||
execution.status = "running"
|
||||
_store.update_execution(execution.execution_id, status="running")
|
||||
await _store.update_execution(execution.execution_id, status="running")
|
||||
|
||||
# Topological sort
|
||||
stage_map = {s.name: s for s in workflow.stages}
|
||||
|
|
@ -229,7 +268,7 @@ async def _execute_workflow(
|
|||
execution.status = "failed"
|
||||
execution.error = "循环依赖"
|
||||
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
||||
_store.update_execution(
|
||||
await _store.update_execution(
|
||||
execution.execution_id,
|
||||
status="failed",
|
||||
error="循环依赖",
|
||||
|
|
@ -246,7 +285,7 @@ async def _execute_workflow(
|
|||
for stage_name in ordered:
|
||||
stage = stage_map[stage_name]
|
||||
execution.current_stage = stage_name
|
||||
_store.update_execution(
|
||||
await _store.update_execution(
|
||||
execution.execution_id,
|
||||
current_stage=stage_name,
|
||||
)
|
||||
|
|
@ -256,7 +295,7 @@ async def _execute_workflow(
|
|||
"event": "stage_started",
|
||||
"execution_id": execution.execution_id,
|
||||
"stage": stage_name,
|
||||
})
|
||||
}, execution_id=execution.execution_id)
|
||||
|
||||
try:
|
||||
if stage.type == "approval":
|
||||
|
|
@ -267,7 +306,7 @@ async def _execute_workflow(
|
|||
|
||||
execution.status = "paused"
|
||||
execution.current_stage = stage_name
|
||||
_store.update_execution(
|
||||
await _store.update_execution(
|
||||
execution.execution_id,
|
||||
status="paused",
|
||||
current_stage=stage_name,
|
||||
|
|
@ -276,7 +315,7 @@ async def _execute_workflow(
|
|||
"event": "approval_required",
|
||||
"execution_id": execution.execution_id,
|
||||
"stage": stage_name,
|
||||
})
|
||||
}, execution_id=execution.execution_id)
|
||||
|
||||
# Wait for approval with timeout
|
||||
try:
|
||||
|
|
@ -289,13 +328,13 @@ async def _execute_workflow(
|
|||
"execution_id": execution.execution_id,
|
||||
"stage": stage_name,
|
||||
"error": "Approval rejected",
|
||||
})
|
||||
}, execution_id=execution.execution_id)
|
||||
return
|
||||
# Approval was granted — the /approve endpoint already set stage_results
|
||||
# Only update status to running if not already set
|
||||
if execution.status != "running":
|
||||
execution.status = "running"
|
||||
_store.update_execution(
|
||||
await _store.update_execution(
|
||||
execution.execution_id,
|
||||
status="running",
|
||||
)
|
||||
|
|
@ -308,7 +347,7 @@ async def _execute_workflow(
|
|||
execution.status = "failed"
|
||||
execution.error = f"Approval timeout for stage {stage_name}"
|
||||
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
||||
_store.update_execution(
|
||||
await _store.update_execution(
|
||||
execution.execution_id,
|
||||
status="failed",
|
||||
error=execution.error,
|
||||
|
|
@ -320,7 +359,7 @@ async def _execute_workflow(
|
|||
"execution_id": execution.execution_id,
|
||||
"stage": stage_name,
|
||||
"error": "Approval timeout",
|
||||
})
|
||||
}, execution_id=execution.execution_id)
|
||||
return
|
||||
finally:
|
||||
_store._approval_events.pop(event_key, None)
|
||||
|
|
@ -332,7 +371,7 @@ async def _execute_workflow(
|
|||
"status": "completed",
|
||||
"condition_result": result,
|
||||
}
|
||||
_store.update_execution(
|
||||
await _store.update_execution(
|
||||
execution.execution_id,
|
||||
stage_results=execution.stage_results,
|
||||
)
|
||||
|
|
@ -342,7 +381,7 @@ async def _execute_workflow(
|
|||
"status": "completed",
|
||||
"output": {"dry_run": True, "action": stage.action},
|
||||
}
|
||||
_store.update_execution(
|
||||
await _store.update_execution(
|
||||
execution.execution_id,
|
||||
stage_results=execution.stage_results,
|
||||
)
|
||||
|
|
@ -351,7 +390,7 @@ async def _execute_workflow(
|
|||
"event": "stage_completed",
|
||||
"execution_id": execution.execution_id,
|
||||
"stage": stage_name,
|
||||
})
|
||||
}, execution_id=execution.execution_id)
|
||||
|
||||
except Exception as e:
|
||||
execution.stage_results[stage_name] = {
|
||||
|
|
@ -361,7 +400,7 @@ async def _execute_workflow(
|
|||
execution.status = "failed"
|
||||
execution.error = f"阶段 '{stage_name}' 执行失败: {e}"
|
||||
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
||||
_store.update_execution(
|
||||
await _store.update_execution(
|
||||
execution.execution_id,
|
||||
status="failed",
|
||||
error=execution.error,
|
||||
|
|
@ -373,13 +412,13 @@ async def _execute_workflow(
|
|||
"execution_id": execution.execution_id,
|
||||
"stage": stage_name,
|
||||
"error": str(e),
|
||||
})
|
||||
}, execution_id=execution.execution_id)
|
||||
return
|
||||
|
||||
execution.status = "completed"
|
||||
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
||||
execution.current_stage = None
|
||||
_store.update_execution(
|
||||
await _store.update_execution(
|
||||
execution.execution_id,
|
||||
status="completed",
|
||||
completed_at=execution.completed_at,
|
||||
|
|
@ -388,7 +427,7 @@ async def _execute_workflow(
|
|||
await _broadcast_ws({
|
||||
"event": "execution_completed",
|
||||
"execution_id": execution.execution_id,
|
||||
})
|
||||
}, execution_id=execution.execution_id)
|
||||
|
||||
|
||||
_SAFE_VAR_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
|
||||
|
|
@ -455,16 +494,25 @@ def _evaluate_condition(expression: str, variables: dict[str, Any]) -> bool:
|
|||
raise ValueError(f"Invalid condition expression: {expression}")
|
||||
|
||||
|
||||
async def _broadcast_ws(message: dict[str, Any]) -> None:
|
||||
"""Broadcast a message to all WebSocket subscribers."""
|
||||
async def _broadcast_ws(message: dict[str, Any], execution_id: str | None = None) -> None:
|
||||
"""Broadcast a message to WebSocket subscribers for a specific execution."""
|
||||
async with _ws_subscribers_lock:
|
||||
targets = set()
|
||||
if execution_id and execution_id in _ws_subscribers:
|
||||
targets = set(_ws_subscribers[execution_id]) # snapshot
|
||||
disconnected = []
|
||||
for ws in _ws_subscribers:
|
||||
for ws in targets:
|
||||
try:
|
||||
await ws.send_json(message)
|
||||
except Exception:
|
||||
disconnected.append(ws)
|
||||
for ws in disconnected:
|
||||
_ws_subscribers.remove(ws)
|
||||
if disconnected:
|
||||
async with _ws_subscribers_lock:
|
||||
for ws in disconnected:
|
||||
for eid in list(_ws_subscribers.keys()):
|
||||
_ws_subscribers[eid].discard(ws)
|
||||
if not _ws_subscribers[eid]:
|
||||
del _ws_subscribers[eid]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -494,7 +542,7 @@ async def create_workflow(request: Request, body: CreateWorkflowRequest, _auth:
|
|||
variables_schema=body.variables_schema,
|
||||
output_schema=body.output_schema,
|
||||
)
|
||||
saved = store.save(workflow)
|
||||
saved = await store.save(workflow)
|
||||
return saved.model_dump()
|
||||
|
||||
|
||||
|
|
@ -527,7 +575,7 @@ async def update_workflow(
|
|||
existing.variables_schema = body.variables_schema
|
||||
existing.output_schema = body.output_schema
|
||||
existing.version += 1
|
||||
saved = store.save(existing)
|
||||
saved = await store.save(existing)
|
||||
return saved.model_dump()
|
||||
|
||||
|
||||
|
|
@ -535,7 +583,7 @@ async def update_workflow(
|
|||
async def delete_workflow(request: Request, workflow_id: str, _auth: None = Depends(_verify_api_key)):
|
||||
"""Delete a workflow."""
|
||||
store = _get_store(request)
|
||||
deleted = store.delete(workflow_id)
|
||||
deleted = await store.delete(workflow_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail=f"工作流 '{workflow_id}' 不存在")
|
||||
return {"message": "已删除"}
|
||||
|
|
@ -551,14 +599,13 @@ async def execute_workflow(
|
|||
if workflow is None:
|
||||
raise HTTPException(status_code=404, detail=f"工作流 '{workflow_id}' 不存在")
|
||||
|
||||
execution = store.create_execution(workflow_id)
|
||||
execution = await store.create_execution(workflow_id)
|
||||
execution.variables = body.variables
|
||||
|
||||
# Start execution in background
|
||||
task = asyncio.create_task(
|
||||
_execute_workflow(workflow, execution, body.variables, store=store)
|
||||
)
|
||||
store._running_tasks = getattr(store, "_running_tasks", {})
|
||||
store._running_tasks[execution.execution_id] = task
|
||||
task.add_done_callback(lambda t: store._running_tasks.pop(execution.execution_id, None))
|
||||
|
||||
|
|
@ -593,51 +640,60 @@ async def approve_execution(
|
|||
raise HTTPException(
|
||||
status_code=404, detail=f"执行记录 '{execution_id}' 不存在"
|
||||
)
|
||||
if execution.status != "paused":
|
||||
raise HTTPException(
|
||||
status_code=400, detail="当前执行状态不是等待审批"
|
||||
)
|
||||
|
||||
if body.approved:
|
||||
if execution.current_stage:
|
||||
execution.stage_results[execution.current_stage] = {
|
||||
"status": "approved",
|
||||
"approver": "user",
|
||||
"comment": body.comment,
|
||||
}
|
||||
execution.status = "running"
|
||||
store.update_execution(
|
||||
execution.execution_id,
|
||||
status="running",
|
||||
stage_results=execution.stage_results,
|
||||
)
|
||||
# Resume the waiting execution by setting the approval event
|
||||
stage_name = execution.current_stage
|
||||
if stage_name:
|
||||
event_key = f"{execution_id}:{stage_name}"
|
||||
if event_key in store._approval_events:
|
||||
store._approval_events[event_key].set()
|
||||
else:
|
||||
execution.status = "cancelled"
|
||||
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
||||
if execution.current_stage:
|
||||
execution.stage_results[execution.current_stage] = {
|
||||
"status": "rejected",
|
||||
"approver": "user",
|
||||
"comment": body.comment,
|
||||
}
|
||||
store.update_execution(
|
||||
execution.execution_id,
|
||||
status="cancelled",
|
||||
completed_at=execution.completed_at,
|
||||
stage_results=execution.stage_results,
|
||||
)
|
||||
# Set the approval event so the waiting coroutine can observe the cancelled state
|
||||
stage_name = execution.current_stage
|
||||
if stage_name:
|
||||
event_key = f"{execution_id}:{stage_name}"
|
||||
if event_key in store._approval_events:
|
||||
store._approval_events[event_key].set()
|
||||
exec_lock = store.get_execution_lock(execution_id)
|
||||
async with exec_lock:
|
||||
# Re-fetch execution after acquiring lock
|
||||
execution = store.get_execution(execution_id)
|
||||
if execution is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"执行记录 '{execution_id}' 不存在"
|
||||
)
|
||||
if execution.status != "paused":
|
||||
raise HTTPException(
|
||||
status_code=400, detail="当前执行状态不是等待审批"
|
||||
)
|
||||
|
||||
if body.approved:
|
||||
if execution.current_stage:
|
||||
execution.stage_results[execution.current_stage] = {
|
||||
"status": "approved",
|
||||
"approver": "user",
|
||||
"comment": body.comment,
|
||||
}
|
||||
execution.status = "running"
|
||||
await store.update_execution(
|
||||
execution.execution_id,
|
||||
status="running",
|
||||
stage_results=execution.stage_results,
|
||||
)
|
||||
# Resume the waiting execution by setting the approval event
|
||||
stage_name = execution.current_stage
|
||||
if stage_name:
|
||||
event_key = f"{execution_id}:{stage_name}"
|
||||
if event_key in store._approval_events:
|
||||
store._approval_events[event_key].set()
|
||||
else:
|
||||
execution.status = "cancelled"
|
||||
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
||||
if execution.current_stage:
|
||||
execution.stage_results[execution.current_stage] = {
|
||||
"status": "rejected",
|
||||
"approver": "user",
|
||||
"comment": body.comment,
|
||||
}
|
||||
await store.update_execution(
|
||||
execution.execution_id,
|
||||
status="cancelled",
|
||||
completed_at=execution.completed_at,
|
||||
stage_results=execution.stage_results,
|
||||
)
|
||||
# Set the approval event so the waiting coroutine can observe the cancelled state
|
||||
stage_name = execution.current_stage
|
||||
if stage_name:
|
||||
event_key = f"{execution_id}:{stage_name}"
|
||||
if event_key in store._approval_events:
|
||||
store._approval_events[event_key].set()
|
||||
|
||||
return execution.model_dump()
|
||||
|
||||
|
|
@ -651,23 +707,33 @@ async def cancel_execution(request: Request, execution_id: str, _auth: None = De
|
|||
raise HTTPException(
|
||||
status_code=404, detail=f"执行记录 '{execution_id}' 不存在"
|
||||
)
|
||||
if execution.status not in ("running", "paused", "pending"):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="当前执行状态无法取消"
|
||||
)
|
||||
|
||||
execution.status = "cancelled"
|
||||
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
||||
store.update_execution(
|
||||
execution.execution_id,
|
||||
status="cancelled",
|
||||
completed_at=execution.completed_at,
|
||||
)
|
||||
# Set any pending approval event so a paused workflow can observe the cancelled state
|
||||
if hasattr(execution, "current_stage") and execution.current_stage:
|
||||
event_key = f"{execution_id}:{execution.current_stage}"
|
||||
if event_key in store._approval_events:
|
||||
store._approval_events[event_key].set()
|
||||
exec_lock = store.get_execution_lock(execution_id)
|
||||
async with exec_lock:
|
||||
# Re-fetch execution after acquiring lock
|
||||
execution = store.get_execution(execution_id)
|
||||
if execution is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"执行记录 '{execution_id}' 不存在"
|
||||
)
|
||||
if execution.status not in ("running", "paused", "pending"):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="当前执行状态无法取消"
|
||||
)
|
||||
|
||||
execution.status = "cancelled"
|
||||
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
||||
await store.update_execution(
|
||||
execution.execution_id,
|
||||
status="cancelled",
|
||||
completed_at=execution.completed_at,
|
||||
)
|
||||
# Set any pending approval event so a paused workflow can observe the cancelled state
|
||||
if hasattr(execution, "current_stage") and execution.current_stage:
|
||||
event_key = f"{execution_id}:{execution.current_stage}"
|
||||
if event_key in store._approval_events:
|
||||
store._approval_events[event_key].set()
|
||||
|
||||
return execution.model_dump()
|
||||
|
||||
|
||||
|
|
@ -683,21 +749,37 @@ async def workflow_websocket(websocket: WebSocket):
|
|||
|
||||
if configured_api_key:
|
||||
provided = websocket.query_params.get("api_key")
|
||||
if provided != configured_api_key:
|
||||
await websocket.accept()
|
||||
await websocket.send_json(
|
||||
{"event": "error", "data": {"message": "Invalid or missing api_key"}}
|
||||
)
|
||||
if not hmac.compare_digest((provided or "").encode(), configured_api_key.encode()):
|
||||
await websocket.close(code=4001, reason="Invalid or missing api_key")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
_ws_subscribers.append(websocket)
|
||||
|
||||
# Determine execution_id from query params; may be None for backward compat
|
||||
execution_id = websocket.query_params.get("execution_id")
|
||||
subscribed_execution_id: str | None = None
|
||||
|
||||
if execution_id:
|
||||
await _ws_subscribe(execution_id, websocket)
|
||||
subscribed_execution_id = execution_id
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
raw = await asyncio.wait_for(websocket.receive_text(), timeout=120.0)
|
||||
# Handle subscription messages
|
||||
try:
|
||||
msg = json.loads(raw)
|
||||
if isinstance(msg, dict) and msg.get("type") == "subscribe":
|
||||
new_eid = msg.get("execution_id")
|
||||
if new_eid:
|
||||
# Unsubscribe from previous if any
|
||||
if subscribed_execution_id:
|
||||
await _ws_unsubscribe(subscribed_execution_id, websocket)
|
||||
await _ws_subscribe(new_eid, websocket)
|
||||
subscribed_execution_id = new_eid
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
except asyncio.TimeoutError:
|
||||
await websocket.close(code=1000, reason="Heartbeat timeout")
|
||||
return
|
||||
|
|
@ -707,5 +789,5 @@ async def workflow_websocket(websocket: WebSocket):
|
|||
except Exception as e:
|
||||
logger.error(f"Workflow WebSocket error: {e}")
|
||||
finally:
|
||||
if websocket in _ws_subscribers:
|
||||
_ws_subscribers.remove(websocket)
|
||||
if subscribed_execution_id:
|
||||
await _ws_unsubscribe(subscribed_execution_id, websocket)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
from typing import Any, Callable, Awaitable
|
||||
|
|
@ -95,11 +96,14 @@ class ComputerUseTool(Tool):
|
|||
self._max_retries = max_retries
|
||||
self._request_timeout = request_timeout
|
||||
self._http_client: httpx.AsyncClient | None = None
|
||||
self._client_lock = asyncio.Lock()
|
||||
|
||||
def _get_http_client(self) -> httpx.AsyncClient:
|
||||
async def _get_http_client(self) -> httpx.AsyncClient:
|
||||
"""Get or create a persistent httpx.AsyncClient for connection reuse."""
|
||||
if self._http_client is None or self._http_client.is_closed:
|
||||
self._http_client = httpx.AsyncClient(timeout=self._request_timeout)
|
||||
async with self._client_lock:
|
||||
if self._http_client is None or self._http_client.is_closed:
|
||||
self._http_client = httpx.AsyncClient(timeout=self._request_timeout)
|
||||
return self._http_client
|
||||
|
||||
async def close(self) -> None:
|
||||
|
|
@ -391,7 +395,7 @@ class ComputerUseTool(Tool):
|
|||
"content-type": "application/json",
|
||||
}
|
||||
|
||||
client = self._get_http_client()
|
||||
client = await self._get_http_client()
|
||||
response = await client.post(
|
||||
self._api_base_url,
|
||||
json=request_body,
|
||||
|
|
|
|||
|
|
@ -66,40 +66,33 @@ _SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
|
|||
"docker images",
|
||||
)
|
||||
|
||||
# 危险命令模式:这些命令需要人工确认
|
||||
_DANGEROUS_PATTERNS: tuple[str, ...] = (
|
||||
"rm ",
|
||||
"rm -",
|
||||
"rmdir",
|
||||
"mkfs",
|
||||
"dd ",
|
||||
"format",
|
||||
"del ",
|
||||
"erase",
|
||||
"> /dev/",
|
||||
"shutdown",
|
||||
"reboot",
|
||||
"init 0",
|
||||
"init 6",
|
||||
"kill -9",
|
||||
"killall",
|
||||
"chmod 777",
|
||||
"chown",
|
||||
"mv /",
|
||||
"pip uninstall",
|
||||
"npm uninstall",
|
||||
"apt remove",
|
||||
"yum remove",
|
||||
"brew uninstall",
|
||||
"docker rm",
|
||||
"docker rmi",
|
||||
"git push --force",
|
||||
"git reset --hard",
|
||||
"git clean -f",
|
||||
"drop table",
|
||||
"drop database",
|
||||
"truncate",
|
||||
)
|
||||
# 危险命令检测 — 基于精确 token 匹配,避免子串误判
|
||||
|
||||
# 总是危险的二进制命令(无论参数)
|
||||
_DANGEROUS_BINARIES: frozenset[str] = frozenset({
|
||||
"rm", "rmdir", "mkfs", "dd", "format", "shutdown", "reboot",
|
||||
"halt", "killall", "chown", "fdisk", "parted",
|
||||
})
|
||||
|
||||
# 需要特定参数才危险的二进制命令:binary → 危险 flag/子命令集合
|
||||
_DANGEROUS_BINARY_FLAGS: dict[str, set[str]] = {
|
||||
"rm": {"-rf", "-fr", "-r", "-f"},
|
||||
"kill": {"-9", "-kill"},
|
||||
"chmod": {"777", "000"},
|
||||
"git": {"push --force", "push -f", "reset --hard", "clean -f"},
|
||||
"pip": {"uninstall"},
|
||||
"npm": {"uninstall"},
|
||||
"docker": {"rm", "rmi", "system prune"},
|
||||
}
|
||||
|
||||
# 跨 token 的危险模式(编译后的正则)
|
||||
_DANGEROUS_ARG_PATTERNS: list[re.Pattern[str]] = [
|
||||
re.compile(r">\s*/dev/", re.IGNORECASE),
|
||||
re.compile(r">\s*/etc/", re.IGNORECASE),
|
||||
re.compile(r"drop\s+table", re.IGNORECASE),
|
||||
re.compile(r"drop\s+database", re.IGNORECASE),
|
||||
re.compile(r"truncate\s+table", re.IGNORECASE),
|
||||
]
|
||||
|
||||
|
||||
_SHELL_OPERATORS = re.compile(r'[|;&]|&&|\|\||\$\(|\$\{|`|\$<|>|<|\n')
|
||||
|
|
@ -300,10 +293,15 @@ class ShellTool(Tool):
|
|||
timeout=timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
proc.kill()
|
||||
try:
|
||||
proc.kill()
|
||||
except ProcessLookupError:
|
||||
logger.debug("Process already exited before kill()")
|
||||
except OSError:
|
||||
logger.debug("OSError killing process")
|
||||
await proc.wait()
|
||||
output = f"命令执行超时({timeout}s)"
|
||||
exit_code = -1
|
||||
exit_code = proc.returncode if proc.returncode is not None else -1
|
||||
else:
|
||||
output = stdout.decode("utf-8", errors="replace") if stdout else ""
|
||||
exit_code = proc.returncode if proc.returncode is not None else 0
|
||||
|
|
@ -394,10 +392,24 @@ class ShellTool(Tool):
|
|||
if binary.lower() == prefix_stripped:
|
||||
return False
|
||||
|
||||
# Dangerous pattern check
|
||||
# Dangerous pattern check — token-based matching
|
||||
binary_lower = binary.lower()
|
||||
|
||||
# 1. Binary is always dangerous regardless of flags
|
||||
if binary_lower in _DANGEROUS_BINARIES:
|
||||
return True
|
||||
|
||||
# 2. Binary is dangerous with specific flags/subcommands
|
||||
if binary_lower in _DANGEROUS_BINARY_FLAGS:
|
||||
cmd_str = " ".join(tokens).lower()
|
||||
for flag_pattern in _DANGEROUS_BINARY_FLAGS[binary_lower]:
|
||||
if flag_pattern in cmd_str:
|
||||
return True
|
||||
|
||||
# 3. Cross-token dangerous patterns (regex)
|
||||
command_lower = command_stripped.lower()
|
||||
for pattern in _DANGEROUS_PATTERNS:
|
||||
if pattern in command_lower:
|
||||
for pattern in _DANGEROUS_ARG_PATTERNS:
|
||||
if pattern.search(command_lower):
|
||||
return True
|
||||
|
||||
return True # Unknown commands are dangerous by default
|
||||
|
|
|
|||
|
|
@ -1,8 +1,16 @@
|
|||
"""Security utilities for URL validation."""
|
||||
|
||||
import ipaddress
|
||||
import socket
|
||||
from urllib.parse import urlparse
|
||||
|
||||
_blocked_hostnames = {
|
||||
"localhost",
|
||||
"metadata.google.internal",
|
||||
"metadata.internal",
|
||||
"metadata.azure.com",
|
||||
}
|
||||
|
||||
|
||||
def is_safe_url(url: str) -> bool:
|
||||
"""Check if URL is safe (not pointing to private/internal networks).
|
||||
|
|
@ -20,13 +28,6 @@ def is_safe_url(url: str) -> bool:
|
|||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
return False
|
||||
# Block known internal/metadata hostnames
|
||||
_blocked_hostnames = {
|
||||
"localhost",
|
||||
"metadata.google.internal",
|
||||
"metadata.internal",
|
||||
"metadata.azure.com",
|
||||
}
|
||||
hostname_lower = hostname.lower()
|
||||
if hostname_lower in _blocked_hostnames:
|
||||
return False
|
||||
|
|
@ -37,9 +38,15 @@ def is_safe_url(url: str) -> bool:
|
|||
if _is_unsafe_ip(ip):
|
||||
return False
|
||||
except ValueError:
|
||||
# hostname is a domain, not a literal IP — DNS rebinding risk remains
|
||||
# (would need DNS resolution to fully mitigate)
|
||||
pass
|
||||
# hostname is a domain — resolve DNS and check IPs
|
||||
try:
|
||||
addr_infos = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
|
||||
for family, type_, proto, canonname, sockaddr in addr_infos:
|
||||
ip = ipaddress.ip_address(sockaddr[0])
|
||||
if _is_unsafe_ip(ip):
|
||||
return False
|
||||
except (socket.gaierror, socket.timeout, OSError):
|
||||
return False # fail-closed on DNS errors
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
|
@ -58,3 +65,36 @@ def _is_unsafe_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
|
|||
if mapped.is_private or mapped.is_loopback or mapped.is_reserved or mapped.is_link_local:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def is_safe_url_async(url: str) -> bool:
|
||||
"""Async variant of is_safe_url with DNS resolution."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return False
|
||||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
return False
|
||||
hostname_lower = hostname.lower()
|
||||
if hostname_lower in _blocked_hostnames:
|
||||
return False
|
||||
if hostname_lower.endswith((".internal", ".local", ".localhost")):
|
||||
return False
|
||||
try:
|
||||
ip = ipaddress.ip_address(hostname)
|
||||
if _is_unsafe_ip(ip):
|
||||
return False
|
||||
except ValueError:
|
||||
try:
|
||||
import asyncio
|
||||
addr_infos = await asyncio.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
|
||||
for family, type_, proto, canonname, sockaddr in addr_infos:
|
||||
ip = ipaddress.ip_address(sockaddr[0])
|
||||
if _is_unsafe_ip(ip):
|
||||
return False
|
||||
except (socket.gaierror, socket.timeout, OSError):
|
||||
return False
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
|
@ -119,7 +121,7 @@ class TestWorkflowStore:
|
|||
|
||||
store = WorkflowStore()
|
||||
wf = WorkflowDefinition(workflow_id="test-1", name="Test")
|
||||
store.save(wf)
|
||||
asyncio.run(store.save(wf))
|
||||
result = store.get("test-1")
|
||||
assert result is not None
|
||||
assert result.name == "Test"
|
||||
|
|
@ -131,36 +133,45 @@ class TestWorkflowStore:
|
|||
def test_list(self):
|
||||
from agentkit.orchestrator.workflow_schema import WorkflowDefinition
|
||||
|
||||
store = WorkflowStore()
|
||||
for i in range(3):
|
||||
store.save(WorkflowDefinition(workflow_id=f"wf-{i}", name=f"Workflow {i}"))
|
||||
summaries = store.list()
|
||||
assert len(summaries) == 3
|
||||
async def _run():
|
||||
store = WorkflowStore()
|
||||
for i in range(3):
|
||||
await store.save(WorkflowDefinition(workflow_id=f"wf-{i}", name=f"Workflow {i}"))
|
||||
summaries = store.list()
|
||||
assert len(summaries) == 3
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_list_limit(self):
|
||||
from agentkit.orchestrator.workflow_schema import WorkflowDefinition
|
||||
|
||||
store = WorkflowStore()
|
||||
for i in range(5):
|
||||
store.save(WorkflowDefinition(workflow_id=f"wf-{i}", name=f"Workflow {i}"))
|
||||
summaries = store.list(limit=2)
|
||||
assert len(summaries) == 2
|
||||
async def _run():
|
||||
store = WorkflowStore()
|
||||
for i in range(5):
|
||||
await store.save(WorkflowDefinition(workflow_id=f"wf-{i}", name=f"Workflow {i}"))
|
||||
summaries = store.list(limit=2)
|
||||
assert len(summaries) == 2
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_delete(self):
|
||||
from agentkit.orchestrator.workflow_schema import WorkflowDefinition
|
||||
|
||||
store = WorkflowStore()
|
||||
store.save(WorkflowDefinition(workflow_id="del-1", name="Delete Me"))
|
||||
assert store.delete("del-1") is True
|
||||
assert store.get("del-1") is None
|
||||
async def _run():
|
||||
store = WorkflowStore()
|
||||
await store.save(WorkflowDefinition(workflow_id="del-1", name="Delete Me"))
|
||||
assert await store.delete("del-1") is True
|
||||
assert store.get("del-1") is None
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_delete_not_found(self):
|
||||
store = WorkflowStore()
|
||||
assert store.delete("nonexistent") is False
|
||||
assert asyncio.run(store.delete("nonexistent")) is False
|
||||
|
||||
def test_create_and_get_execution(self):
|
||||
store = WorkflowStore()
|
||||
execution = store.create_execution("wf-1")
|
||||
execution = asyncio.run(store.create_execution("wf-1"))
|
||||
assert execution.workflow_id == "wf-1"
|
||||
assert execution.status == "pending"
|
||||
|
||||
|
|
@ -173,18 +184,51 @@ class TestWorkflowStore:
|
|||
assert store.get_execution("nonexistent") is None
|
||||
|
||||
def test_update_execution(self):
|
||||
store = WorkflowStore()
|
||||
execution = store.create_execution("wf-1")
|
||||
updated = store.update_execution(
|
||||
execution.execution_id, status="running", current_stage="step-1"
|
||||
)
|
||||
assert updated.status == "running"
|
||||
assert updated.current_stage == "step-1"
|
||||
async def _run():
|
||||
store = WorkflowStore()
|
||||
execution = await store.create_execution("wf-1")
|
||||
updated = await store.update_execution(
|
||||
execution.execution_id, status="running", current_stage="step-1"
|
||||
)
|
||||
assert updated.status == "running"
|
||||
assert updated.current_stage == "step-1"
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_update_execution_not_found(self):
|
||||
store = WorkflowStore()
|
||||
with pytest.raises(KeyError):
|
||||
store.update_execution("nonexistent", status="running")
|
||||
asyncio.run(store.update_execution("nonexistent", status="running"))
|
||||
|
||||
def test_running_tasks_initialized(self):
|
||||
store = WorkflowStore()
|
||||
assert hasattr(store, "_running_tasks")
|
||||
assert isinstance(store._running_tasks, dict)
|
||||
assert len(store._running_tasks) == 0
|
||||
|
||||
def test_execution_locks_initialized(self):
|
||||
store = WorkflowStore()
|
||||
assert hasattr(store, "_execution_locks")
|
||||
assert isinstance(store._execution_locks, dict)
|
||||
|
||||
def test_evict_execution_cleans_approval_events(self):
|
||||
async def _run():
|
||||
store = WorkflowStore(max_executions=2)
|
||||
e1 = await store.create_execution("wf-1")
|
||||
e2 = await store.create_execution("wf-2")
|
||||
# Add an approval event for e1
|
||||
event_key = f"{e1.execution_id}:stage1"
|
||||
event = asyncio.Event()
|
||||
store._approval_events[event_key] = event
|
||||
# Create a third execution to trigger eviction of e1
|
||||
e3 = await store.create_execution("wf-3")
|
||||
# e1 should be evicted, and its approval event should be cleaned up
|
||||
assert store.get_execution(e1.execution_id) is None
|
||||
assert event_key not in store._approval_events
|
||||
# The event should have been set (to wake any waiting coroutine)
|
||||
assert event.is_set()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
Loading…
Reference in New Issue