fix(security,reliability): resolve all P2 findings from code review
parent 658e188939
commit 6852dfe892
@ -0,0 +1,287 @@
# fix: AgentKit P2 Security & Reliability Hardening

**Status:** active
**Created:** 2026-06-10
**Origin:** ce-code-review findings on main branch after 7-capabilities merge

---

## Summary

Fix all remaining P2 issues from the final code review of the `feat/agentkit-7-capabilities` merge. These 16 issues span security (SSRF, auth, injection), reliability (retry, concurrency, resource leaks), and architecture (auth consistency, config reload). None blocks current functionality, but each represents real risk in production.

## Problem Frame

After merging 7 core capabilities and fixing P0/P1 issues, the codebase has 16 P2 findings that fall into four categories:

1. **Security gaps** - SSRF DNS rebinding, API key timing attacks, WebSocket auth bypass, approval race conditions, shell pattern matching, .env injection
2. **Reliability gaps** - retry storms, unbounded memory growth, concurrent modification, config change loss
3. **Architecture gaps** - dual auth layer inconsistency, WebSocket broadcast without isolation
4. **Defensive gaps** - process kill race, lazy init race, httpx client double-creation

## Requirements

- R1: All P2 security findings must be resolved or explicitly deferred with documented rationale
- R2: Retry logic must use exponential backoff with jitter
- R3: All in-memory stores must have bounded growth with proper eviction
- R4: Concurrent access to shared mutable state must be protected
- R5: Authentication must be consistent across all endpoints and transports
- R6: No regression in existing test suite (764+ tests must pass)

## Key Technical Decisions

- **KTD1: DNS resolution for SSRF** - Resolve hostnames before IP validation. Async callers use the event loop's `loop.getaddrinfo()` rather than the blocking `socket.getaddrinfo` (asyncio has no module-level `getaddrinfo`). Fail closed on DNS errors. Accept that a TOCTOU window remains between check and connect; full mitigation requires IP pinning at the httpx transport level (deferred).
- **KTD2: Auth consolidation strategy** - Fix per-endpoint auth first (`hmac.compare_digest`, WebSocket auth-before-accept), then consolidate into middleware in a separate pass. This avoids a big-bang refactor and allows incremental testing.
- **KTD3: WorkflowStore concurrency** - Add `asyncio.Lock` to WorkflowStore and convert its sync mutation methods to async. This is the most invasive change but necessary for correctness under concurrent load.
- **KTD4: Shell dangerous patterns** - Replace substring matching with token-based matching using the already-parsed shlex tokens. Separate binary-level and flag-level dangerous patterns.
- **KTD5: WebSocket isolation** - Replace flat subscriber list with `dict[str, set[WebSocket]]` keyed by execution_id. Clients subscribe to specific executions.

---

## Implementation Units

### U1. API Key Timing Attack & WebSocket Auth-Before-Accept

**Goal:** Eliminate timing side-channel in API key comparison and fix WebSocket accept-before-auth pattern.

**Requirements:** R5

**Dependencies:** None

**Files:**
- `src/agentkit/server/routes/portal.py`
- `src/agentkit/server/routes/workflows.py`
- `src/agentkit/server/middleware.py`
- `tests/unit/server/test_portal_routes.py`
- `tests/unit/server/test_workflow_routes.py`

**Approach:**
1. Add `import hmac` to portal.py, workflows.py, middleware.py
2. Replace all `!=` API key comparisons with `hmac.compare_digest()`
3. In both WebSocket handlers, move auth validation BEFORE `websocket.accept()`. On auth failure, call `await websocket.close(code=4001)` without accepting, then return. The client then sees a rejected handshake (no 101 upgrade) instead of an open socket that is closed afterwards
4. In middleware.py, replace `provided_key not in valid_keys` with `any(hmac.compare_digest(...) for k in valid_keys)`
5. Handle a None/empty provided key by defaulting to an empty string before comparison, and encode both operands to bytes: `hmac.compare_digest` rejects `None` and non-ASCII `str` values

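The comparison steps above can be sketched roughly as follows. This is an illustration only: `api_key_matches` is a hypothetical helper name, not a function from the codebase, and the choice to compare against every configured key without an early exit is an assumption that goes slightly beyond the `any(...)` form in step 4.

```python
from __future__ import annotations

import hmac


def api_key_matches(provided: str | None, valid_keys: list[str]) -> bool:
    """Hypothetical helper: constant-time check of `provided` against all keys."""
    # Default a missing key to "" and encode to bytes: compare_digest
    # rejects None and raises TypeError on non-ASCII str input.
    candidate = (provided or "").encode()
    matched = False
    for key in valid_keys:
        # No early exit: every configured key is compared on every call.
        if hmac.compare_digest(candidate, key.encode()):
            matched = True
    # An empty candidate never authenticates, even against an empty configured key.
    return matched and bool(candidate)
```

A handler would call this before `websocket.accept()` or before dispatching the HTTP request, and reject the caller when it returns `False`.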

**Test scenarios:**
- API key comparison with correct key returns True
- API key comparison with incorrect key returns False
- API key comparison with no key provided (None) falls back to an empty string and returns False
- WebSocket connection without api_key is rejected (no 101 upgrade)
- WebSocket connection with wrong api_key is rejected
- WebSocket connection with correct api_key is accepted

**Verification:** All existing tests pass; new auth tests pass

---

### U2. Shell Dangerous Pattern Token-Based Matching

**Goal:** Replace substring-based dangerous pattern matching with precise token-based matching to eliminate false positives and bypasses.

**Requirements:** R1

**Dependencies:** None

**Files:**
- `src/agentkit/tools/shell.py`
- `tests/unit/test_shell_tool.py`

**Approach:**
1. Replace `_DANGEROUS_PATTERNS` tuple with two new structures:
   - `_DANGEROUS_BINARIES`: set of binary names that are always dangerous (e.g., `rm`, `mkfs`, `dd`, `shutdown`)
   - `_DANGEROUS_BINARY_FLAGS`: dict mapping binary name to set of dangerous flag combinations (e.g., `kill: {-9, -KILL}`, `chmod: {777}`)
2. In `_is_dangerous()`, after shlex parsing:
   - Check `binary` against `_DANGEROUS_BINARIES` (exact match)
   - Check `binary` + flags against `_DANGEROUS_BINARY_FLAGS`
   - Keep regex-based arg patterns for cross-token matches (e.g., `> /dev/`, `drop table`)
3. Remove overly broad patterns like `format`, `erase` as substrings; add `format` as binary-only
4. Normalize whitespace in command before matching

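A minimal sketch of the matching order described above, under stated assumptions: the set contents are taken from the examples in this plan and are not the real contents of `shell.py`, the names are written without the leading underscore to mark them as illustrative, and treating unparseable input as dangerous is an assumption borrowed from the default-to-dangerous mitigation.

```python
import re
import shlex

# Illustrative contents only; the real structures live in shell.py and may differ.
DANGEROUS_BINARIES = {"rm", "mkfs", "dd", "shutdown", "format"}
DANGEROUS_BINARY_FLAGS = {"kill": {"-9", "-KILL"}, "chmod": {"777"}}
DANGEROUS_ARG_PATTERNS = [
    re.compile(r">\s*/dev/"),
    re.compile(r"drop\s+table", re.IGNORECASE),
]


def is_dangerous(command: str) -> bool:
    try:
        # shlex splits on any whitespace, so tabs are normalized here.
        tokens = shlex.split(command)
    except ValueError:
        return True  # unbalanced quotes etc.: fail closed (assumption)
    if not tokens:
        return False
    binary = tokens[0].rsplit("/", 1)[-1]  # "/bin/rm" -> "rm"
    if binary in DANGEROUS_BINARIES:
        return True
    # Flag-level check: any dangerous flag for this binary among the arguments.
    if DANGEROUS_BINARY_FLAGS.get(binary, set()) & set(tokens[1:]):
        return True
    # Cross-token patterns run against the whitespace-normalized command.
    normalized = " ".join(tokens)
    return any(p.search(normalized) for p in DANGEROUS_ARG_PATTERNS)
```

This sketch inspects only the first token as the binary, so chained commands (`;`, `&&`, pipes) would need to be split into separate commands before this check for it to be meaningful.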

**Test scenarios:**
- `rm -rf /` detected as dangerous (binary match)
- `echo 'performing rm operations'` NOT detected as dangerous (no binary match)
- `kill -9 1234` detected as dangerous (binary + flag match)
- `kill -15 1234` NOT detected as dangerous (flag not in dangerous set)
- `rm\t/tmp/file` detected as dangerous (tab normalization)
- `format C:` detected as dangerous (binary match)
- `informatica` NOT detected as dangerous (not a binary match)

**Verification:** Shell tool tests pass with new pattern logic

---

### U3. .env File Prefix Filtering & Retry Backoff

**Goal:** Prevent arbitrary environment variable injection from .env files and add exponential backoff to plan executor retries.

**Requirements:** R1, R2

**Dependencies:** None

**Files:**
- `src/agentkit/server/app.py`
- `src/agentkit/core/plan_executor.py`
- `tests/unit/test_plan_executor.py`

**Approach:**
1. In app.py .env loading, add an allowlist of prefixes: `AGENTKIT_`, `OPENAI_`, `ANTHROPIC_`, `GEMINI_`, `TAVILY_`, `SERPER_`, `DEEPSEEK_`, plus exact matches for `DATABASE_URL`, `REDIS_URL`
2. Log a warning when a .env variable is skipped due to prefix filtering
3. In plan_executor.py `_execute_step_with_retry()`, add exponential backoff with jitter after each failure:
   - `base_retry_delay=1.0s`, `max_retry_delay=30.0s`
   - `delay = min(base * 2 ** (retry - 1), max) * (0.5 + random() * 0.5)`
   - Skip delay on first attempt (retry_count == 0)
4. Add `base_retry_delay` and `max_retry_delay` as constructor parameters with defaults

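The delay formula in step 3 can be worked through with a small standalone function. `retry_delay` is a hypothetical name used only for this illustration; the defaults mirror `base_retry_delay` and `max_retry_delay` above.

```python
import random


def retry_delay(retry_count: int, base: float = 1.0, cap: float = 30.0) -> float:
    """Seconds to wait before retry number `retry_count` (1-based)."""
    if retry_count <= 0:
        return 0.0  # first attempt: no wait
    # Exponential growth: base, 2*base, 4*base, ... capped at `cap`.
    backoff = min(base * 2 ** (retry_count - 1), cap)
    # Jitter factor in [0.5, 1.0): the wait is between half and all of `backoff`.
    return backoff * (0.5 + random.random() * 0.5)
```

With the defaults, the first retry waits between 0.5s and 1s, the second between 1s and 2s, the third between 2s and 4s, and from the sixth retry onward the wait stays between 15s and 30s.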

**Test scenarios:**
- .env with `AGENTKIT_FOO=bar` is loaded
- .env with `PATH=/malicious` is skipped with warning
- .env with `OPENAI_API_KEY=sk-xxx` is loaded
- Retry backoff: first retry waits 0.5-1s, second 1-2s, third 2-4s (nominal 1s/2s/4s scaled by a jitter factor in [0.5, 1.0))
- Max retry delay caps at 30s
- Plan executor with max_retries=0 still works (no backoff)

**Verification:** Unit tests pass; .env loading respects prefix filter

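The .env prefix filter in this unit can be sketched as below. The allowlist values follow `_ALLOWED_ENV_PREFIXES` and `_ALLOWED_ENV_EXACT` from the app.py change in this commit; `split_env` is a hypothetical helper name, and returning the skipped keys (instead of logging them) is a simplification made for the example.

```python
ALLOWED_ENV_PREFIXES = (
    "AGENTKIT_", "OPENAI_", "ANTHROPIC_", "GEMINI_",
    "TAVILY_", "SERPER_", "DEEPSEEK_",
)
ALLOWED_ENV_EXACT = {"DATABASE_URL", "REDIS_URL"}


def split_env(pairs: dict) -> tuple:
    """Hypothetical helper: return (variables to load, keys that were skipped)."""
    loaded, skipped = {}, []
    for key, value in pairs.items():
        # str.startswith accepts a tuple and matches any of its prefixes.
        if key.startswith(ALLOWED_ENV_PREFIXES) or key in ALLOWED_ENV_EXACT:
            loaded[key] = value
        else:
            skipped.append(key)  # the real loader logs a warning here
    return loaded, skipped
```

Matching is by prefix, not substring, so a key such as `MY_OPENAI_KEY` would be skipped.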
---

### U4. WorkflowStore Concurrency & Resource Management

**Goal:** Make WorkflowStore safe under concurrent coroutine access (an `asyncio.Lock` serializes coroutines on one event loop, not OS threads), fix unbounded growth, and add proper cleanup for approval events and running tasks.

**Requirements:** R3, R4

**Dependencies:** U1 (hmac.compare_digest in workflow routes)

**Files:**
- `src/agentkit/server/routes/workflows.py`
- `tests/unit/server/test_workflow_routes.py`

**Approach:**
1. Add `asyncio.Lock` to WorkflowStore; convert `save`, `delete`, `create_execution`, `update_execution` to async methods
2. Initialize `_running_tasks` in `__init__` instead of lazy `getattr`
3. Add `_evict_execution()` helper that removes execution AND its approval events (with `event.set()` to wake waiting coroutines)
4. Replace `_ws_subscribers: list` with `_ws_subscribers: dict[str, set[WebSocket]]` keyed by execution_id, protected by `asyncio.Lock`
5. Add `_subscribe(execution_id, ws)` and `_unsubscribe(execution_id, ws)` helpers
6. Modify `_broadcast_ws` to accept `execution_id` parameter and only send to relevant subscribers
7. Update all call sites from sync to async (all endpoint functions already async)
8. Add `_execution_locks: dict[str, asyncio.Lock]` for per-execution approve/cancel serialization

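Steps 1 and 3 can be illustrated with a reduced store. `ExecutionStore` is a hypothetical stand-in for the execution bookkeeping in WorkflowStore; evicting in insertion order and the `approval_event` accessor are assumptions made to keep the example self-contained.

```python
import asyncio


class ExecutionStore:
    """Hypothetical, reduced stand-in for WorkflowStore's execution handling."""

    def __init__(self, max_executions: int = 1000) -> None:
        self._max_executions = max_executions
        self._executions: dict = {}
        self._approval_events: dict = {}  # key: f"{execution_id}:{stage_name}"
        self._lock = asyncio.Lock()

    def _evict_execution(self, execution_id: str) -> None:
        # Caller must hold self._lock.
        self._executions.pop(execution_id, None)
        prefix = f"{execution_id}:"
        for key in [k for k in self._approval_events if k.startswith(prefix)]:
            # Set the event so a coroutine waiting for approval is not stuck forever.
            self._approval_events.pop(key).set()

    async def create_execution(self, execution_id: str, record: dict) -> None:
        async with self._lock:
            self._executions[execution_id] = record
            while len(self._executions) > self._max_executions:
                oldest = next(iter(self._executions))  # dicts keep insertion order
                self._evict_execution(oldest)

    async def approval_event(self, execution_id: str, stage: str) -> asyncio.Event:
        async with self._lock:
            key = f"{execution_id}:{stage}"
            return self._approval_events.setdefault(key, asyncio.Event())
```

A coroutine woken this way cannot tell approval from eviction by the event alone, so after waking it would need to check whether its execution still exists.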

**Test scenarios:**
- Concurrent save operations don't cause KeyError or over-eviction
- Evicted execution's approval events are cleaned up
- WebSocket subscriber receives only messages for subscribed execution_id
- WebSocket disconnect properly cleans up from subscriber dict
- Concurrent approve/cancel for same execution serialized by lock
- _running_tasks initialized in __init__, no getattr fallback

**Verification:** Workflow route tests pass; no RuntimeError from concurrent modification

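The per-execution subscriber map from this unit might look like the sketch below. `SubscriberRegistry` and its method names are hypothetical; the only method assumed on the socket object is `send_json`, which the existing handlers already call, and dropping a subscriber whose send fails is a design assumption.

```python
import asyncio
from collections import defaultdict


class SubscriberRegistry:
    """Hypothetical registry: execution_id -> set of subscribed sockets."""

    def __init__(self) -> None:
        self._subs = defaultdict(set)
        self._lock = asyncio.Lock()

    async def subscribe(self, execution_id: str, ws) -> None:
        async with self._lock:
            self._subs[execution_id].add(ws)

    async def unsubscribe(self, execution_id: str, ws) -> None:
        async with self._lock:
            subs = self._subs.get(execution_id)
            if subs is None:
                return
            subs.discard(ws)
            if not subs:
                del self._subs[execution_id]  # do not leak empty sets

    async def broadcast(self, execution_id: str, message: dict) -> int:
        # Snapshot under the lock, send outside it so a slow client
        # does not block subscribe/unsubscribe calls.
        async with self._lock:
            targets = list(self._subs.get(execution_id, ()))
        delivered = 0
        for ws in targets:
            try:
                await ws.send_json(message)
                delivered += 1
            except Exception:
                await self.unsubscribe(execution_id, ws)
        return delivered
```

Sending from a snapshot means a socket that unsubscribes mid-broadcast may still receive that one message.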
---

### U5. SSRF DNS Resolution & ComputerUse Client Safety

**Goal:** Add DNS resolution to SSRF protection and make ComputerUseTool's httpx client creation safe against concurrent first use.

**Requirements:** R1

**Dependencies:** None

**Files:**
- `src/agentkit/utils/security.py`
- `src/agentkit/tools/computer_use.py`
- `tests/unit/test_security.py`
- `tests/unit/tools/test_computer_use.py`

**Approach:**
1. In `is_safe_url()`, after the ValueError catch for domain names, add DNS resolution:
   - Use `socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)` (sync, acceptable for tool init)
   - Validate each resolved IP against `_is_unsafe_ip()`
   - Fail-closed on DNS errors (return False)
   - Bound DNS resolution to 3 seconds (`socket.getaddrinfo` takes no timeout argument, so the caller must enforce the limit)
2. Add `is_safe_url_async()` variant using the running event loop's `loop.getaddrinfo()` for async callers
3. In ComputerUseTool, add `asyncio.Lock` to `_get_http_client()`, make it async with double-check locking
4. Update caller from `client = self._get_http_client()` to `client = await self._get_http_client()`

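The async resolution path in step 2 could be sketched as follows. `host_resolves_safely` and `is_unsafe_ip` are hypothetical names (the real helpers are `is_safe_url_async()` and `_is_unsafe_ip()`, whose exact rules are not shown in this plan), and the set of address classes treated as unsafe is an assumption.

```python
import asyncio
import ipaddress
import socket


def is_unsafe_ip(ip: str) -> bool:
    """Assumed policy: reject loopback, private, link-local and similar ranges."""
    try:
        addr = ipaddress.ip_address(ip.split("%", 1)[0])  # drop IPv6 scope id
    except ValueError:
        return True  # unparseable address: fail closed
    mapped = getattr(addr, "ipv4_mapped", None)
    if mapped is not None:
        addr = mapped  # ::ffff:127.0.0.1 is checked as 127.0.0.1
    return (
        addr.is_private or addr.is_loopback or addr.is_link_local
        or addr.is_multicast or addr.is_reserved or addr.is_unspecified
    )


async def host_resolves_safely(hostname: str, timeout: float = 3.0) -> bool:
    loop = asyncio.get_running_loop()
    try:
        infos = await asyncio.wait_for(
            loop.getaddrinfo(
                hostname, None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
            ),
            timeout,
        )
    except (OSError, asyncio.TimeoutError):
        return False  # DNS error or timeout: fail closed
    addresses = {info[4][0] for info in infos}
    # Every resolved address must be safe; a single unsafe record rejects the host.
    return bool(addresses) and not any(is_unsafe_ip(a) for a in addresses)
```

As KTD1 notes, this check does not close the TOCTOU window: the HTTP client resolves the name again when it connects.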

**Test scenarios:**
- is_safe_url with domain resolving to 127.0.0.1 returns False
- is_safe_url with domain resolving to public IP returns True
- is_safe_url with unresolvable domain returns False (fail-closed)
- is_safe_url with ::ffff:127.0.0.1 returns False
- is_safe_url with .internal TLD returns False
- ComputerUseTool concurrent _get_http_client returns same client instance
- ComputerUseTool close() properly cleans up client

**Verification:** Security and computer_use tests pass

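The double-check locking used for the client in this unit can be shown in isolation. `LazyClient` is a hypothetical holder and the factory is a placeholder: in ComputerUseTool it would construct the httpx client, which is left out here so the example needs only the standard library.

```python
import asyncio
from typing import Any, Awaitable, Callable, Optional


class LazyClient:
    """Hypothetical holder that creates its client at most once."""

    def __init__(self, factory: Callable[[], Awaitable[Any]]) -> None:
        self._factory = factory
        self._client: Optional[Any] = None
        self._lock = asyncio.Lock()

    async def get(self) -> Any:
        if self._client is None:  # fast path: no lock once initialized
            async with self._lock:
                if self._client is None:  # re-check after acquiring the lock
                    self._client = await self._factory()
        return self._client
```

The second check matters because several coroutines can pass the first check while the winner is still awaiting the factory; they then queue on the lock and must not create another client when they get it.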
---

### U6. Config Hot-Reload & Defensive Fixes

**Goal:** Fix config hot-reload lock race, add proc.kill() error handling, and initialize config reload lock eagerly.

**Requirements:** R4

**Dependencies:** None

**Files:**
- `src/agentkit/server/app.py`
- `src/agentkit/tools/shell.py`

**Approach:**
1. In app.py `_on_config_change()`:
   - Initialize lock eagerly in `create_app()` (store on `app.state._config_reload_lock`)
   - Replace `lock.locked()` skip with pending-flag pattern: set `app.state._config_reload_pending = True`, then in `_reload()` loop while pending flag is set
   - This ensures no config change is silently dropped
2. In shell.py `_execute_standalone()`:
   - Wrap `proc.kill()` in `try/except (ProcessLookupError, OSError)`
   - Use `proc.returncode` if available instead of hardcoded -1

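The pending-flag pattern from step 1 is sketched below outside of FastAPI. `ReloadCoalescer` is a hypothetical class; the real code keeps the lock and flag on `app.state` and schedules `_reload()` as a task.

```python
import asyncio


class ReloadCoalescer:
    """Hypothetical wrapper showing the pending-flag reload loop."""

    def __init__(self, reload_fn) -> None:
        self._reload_fn = reload_fn
        self._lock = asyncio.Lock()
        self._pending = False

    async def request(self) -> None:
        self._pending = True
        if self._lock.locked():
            # A reload is running; it will see the flag and loop once more.
            return
        async with self._lock:
            while self._pending:
                # Clear the flag before reloading so a change that arrives
                # during the reload sets it again and triggers another pass.
                self._pending = False
                await self._reload_fn()
```

Changes that arrive while a reload is running are coalesced into a single follow-up reload, which is why the loop terminates once changes stop arriving.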

**Test scenarios:**
- Two rapid config changes both get applied (no silent drop)
- Config reload pending flag is cleared after reload completes
- Shell timeout with already-exited process returns clean error (no ProcessLookupError)
- Shell timeout output still shows timeout message

**Verification:** App startup and shell tool tests pass

---

## Scope Boundaries

### In Scope
- All 16 P2 findings from ce-code-review
- Test coverage for new security/reliability patterns

### Deferred to Follow-Up Work
- Full auth middleware consolidation (ASGI-level WebSocket middleware): complex, separate PR
- IP pinning at httpx transport level for complete SSRF protection
- python-dotenv migration for robust .env parsing
- clients.yaml caching in middleware (per-request disk read)
- WorkflowStore persistence (currently in-memory only)

### Outside This Plan's Identity
- New features or capability additions
- Performance optimization beyond what's needed for reliability
- UI/UX changes

---

## Risks & Mitigations

| Risk | Impact | Mitigation |
|------|--------|------------|
| WorkflowStore async refactor breaks existing endpoints | High | Update all call sites; comprehensive test coverage |
| DNS resolution adds latency to URL validation | Low | 3s timeout; acceptable for tool initialization |
| Token-based shell matching misses new bypass patterns | Medium | Default-to-dangerous fallback; add patterns as discovered |
| .env prefix filtering breaks existing deployments | Medium | Comprehensive allowlist; log warnings for skipped vars |
| WebSocket isolation breaks existing WS clients | Medium | Maintain backward compat: accept without subscription, send all |
| Config reload pending loop runs indefinitely | Low | Coalescing naturally limits; only re-runs if new change arrived |

---

## System-Wide Impact

- **Security posture** significantly improved: timing attacks, SSRF, auth bypass all addressed
- **Reliability** improved: retry storms prevented, memory bounded, concurrent access safe
- **API compatibility**: WorkflowStore method signatures change from sync to async; all call sites updated
- **WebSocket protocol**: subscription message becomes recommended but not required for backward compat

||||||
|
|
@ -14,6 +14,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import random
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
@ -94,6 +95,8 @@ class PlanExecutor:
|
||||||
max_retries: int = 2,
|
max_retries: int = 2,
|
||||||
step_timeout: float = 300.0,
|
step_timeout: float = 300.0,
|
||||||
max_parallel: int = 5,
|
max_parallel: int = 5,
|
||||||
|
base_retry_delay: float = 1.0,
|
||||||
|
max_retry_delay: float = 30.0,
|
||||||
on_step_complete: OnStepCompleteCallback | None = None,
|
on_step_complete: OnStepCompleteCallback | None = None,
|
||||||
on_step_failed: OnStepFailedCallback | None = None,
|
on_step_failed: OnStepFailedCallback | None = None,
|
||||||
on_human_intervention: OnHumanInterventionCallback | None = None,
|
on_human_intervention: OnHumanInterventionCallback | None = None,
|
||||||
|
|
@ -104,6 +107,8 @@ class PlanExecutor:
|
||||||
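max_retries: Maximum number of retries after a step fails
|
max_retries: Maximum number of retries after a step fails
|
||||||
step_timeout: Timeout for a single step (seconds)
|
step_timeout: Timeout for a single step (seconds)
|
||||||
max_parallel: Maximum number of parallel steps
|
max_parallel: Maximum number of parallel steps
|
||||||
|
base_retry_delay: Base retry delay (seconds)
|
||||||
|
max_retry_delay: Maximum retry delay (seconds)
|
||||||
on_step_complete: Callback invoked when a step completes
|
on_step_complete: Callback invoked when a step completes
|
||||||
on_step_failed: Callback invoked when a step fails; returns a FailureAction that decides what happens next
|
on_step_failed: Callback invoked when a step fails; returns a FailureAction that decides what happens next
|
||||||
on_human_intervention: Callback for human intervention
|
on_human_intervention: Callback for human intervention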
|
||||||
|
|
@ -112,6 +117,8 @@ class PlanExecutor:
|
||||||
self._max_retries = max_retries
|
self._max_retries = max_retries
|
||||||
self._step_timeout = step_timeout
|
self._step_timeout = step_timeout
|
||||||
self._max_parallel = max_parallel
|
self._max_parallel = max_parallel
|
||||||
|
self._base_retry_delay = base_retry_delay
|
||||||
|
self._max_retry_delay = max_retry_delay
|
||||||
self._on_step_complete = on_step_complete
|
self._on_step_complete = on_step_complete
|
||||||
self._on_step_failed = on_step_failed
|
self._on_step_failed = on_step_failed
|
||||||
self._on_human_intervention = on_human_intervention
|
self._on_human_intervention = on_human_intervention
|
||||||
|
|
@ -250,6 +257,15 @@ class PlanExecutor:
|
||||||
|
|
||||||
retry_count += 1
|
retry_count += 1
|
||||||
|
|
||||||
|
if retry_count <= self._max_retries:
|
||||||
|
delay = min(
|
||||||
|
self._base_retry_delay * (2 ** (retry_count - 1)),
|
||||||
|
self._max_retry_delay,
|
||||||
|
)
|
||||||
|
delay *= (0.5 + random.random() * 0.5) # jitter
|
||||||
|
logger.info(f"Retrying step '{step.step_id}' in {delay:.1f}s (attempt {retry_count + 1}/{self._max_retries + 1})")
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
# All retries exhausted
|
# All retries exhausted
|
||||||
step.status = PlanStepStatus.FAILED
|
step.status = PlanStepStatus.FAILED
|
||||||
step.error = last_error
|
step.error = last_error
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,12 @@ from agentkit.telemetry.setup import setup_telemetry
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_ALLOWED_ENV_PREFIXES = (
|
||||||
|
'AGENTKIT_', 'OPENAI_', 'ANTHROPIC_', 'GEMINI_',
|
||||||
|
'TAVILY_', 'SERPER_', 'DEEPSEEK_',
|
||||||
|
)
|
||||||
|
_ALLOWED_ENV_EXACT = {'DATABASE_URL', 'REDIS_URL'}
|
||||||
|
|
||||||
|
|
||||||
def _build_llm_gateway(config: ServerConfig) -> LLMGateway:
|
def _build_llm_gateway(config: ServerConfig) -> LLMGateway:
|
||||||
"""Build LLMGateway from ServerConfig, registering all providers."""
|
"""Build LLMGateway from ServerConfig, registering all providers."""
|
||||||
|
|
@ -224,70 +230,69 @@ def _on_config_change(app: FastAPI, config: ServerConfig) -> None:
|
||||||
|
|
||||||
Uses a lock to prevent concurrent config reloads from racing.
|
Uses a lock to prevent concurrent config reloads from racing.
|
||||||
"""
|
"""
|
||||||
lock: asyncio.Lock = getattr(app.state, "_config_reload_lock", None)
|
lock: asyncio.Lock = app.state._config_reload_lock
|
||||||
if lock is None:
|
|
||||||
lock = asyncio.Lock()
|
|
||||||
app.state._config_reload_lock = lock
|
|
||||||
|
|
||||||
if lock.locked():
|
app.state._config_reload_pending = True
|
||||||
logger.warning("Config reload already in progress, skipping")
|
|
||||||
return
|
|
||||||
|
|
||||||
async def _reload():
|
async def _reload():
|
||||||
|
if lock.locked():
|
||||||
|
return # Another reload running; it will check pending flag
|
||||||
async with lock:
|
async with lock:
|
||||||
# Increment config version for audit
|
while getattr(app.state, "_config_reload_pending", False):
|
||||||
current_version = getattr(app.state, "config_version", 0) + 1
|
app.state._config_reload_pending = False
|
||||||
app.state.config_version = current_version
|
# Increment config version for audit
|
||||||
logger.info(f"Config change detected (v{current_version}), reloading...")
|
current_version = getattr(app.state, "config_version", 0) + 1
|
||||||
|
app.state.config_version = current_version
|
||||||
|
logger.info(f"Config change detected (v{current_version}), reloading...")
|
||||||
|
|
||||||
# Rebuild LLMGateway if llm config changed
|
# Rebuild LLMGateway if llm config changed
|
||||||
try:
|
try:
|
||||||
new_gateway = _build_llm_gateway(config)
|
new_gateway = _build_llm_gateway(config)
|
||||||
app.state.llm_gateway = new_gateway
|
app.state.llm_gateway = new_gateway
|
||||||
# Also update the agent pool's gateway reference
|
# Also update the agent pool's gateway reference
|
||||||
|
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
|
||||||
|
app.state.agent_pool._llm_gateway = new_gateway
|
||||||
|
if hasattr(app.state, "intent_router") and app.state.intent_router is not None:
|
||||||
|
app.state.intent_router._llm_gateway = new_gateway
|
||||||
|
logger.info(f"LLM Gateway reloaded (config v{current_version})")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to reload LLM Gateway: {e}")
|
||||||
|
|
||||||
|
# Reload skills if skill paths changed
|
||||||
|
try:
|
||||||
|
new_skill_registry = _build_skill_registry(config)
|
||||||
|
# Re-bind tools from the shared tool_registry so skills don't lose their bindings
|
||||||
|
tool_registry = getattr(app.state, "tool_registry", None)
|
||||||
|
if tool_registry:
|
||||||
|
from agentkit.skills.loader import SkillLoader
|
||||||
|
loader = SkillLoader(
|
||||||
|
skill_registry=new_skill_registry,
|
||||||
|
tool_registry=tool_registry,
|
||||||
|
)
|
||||||
|
for skill_path in (config.skill_paths or []):
|
||||||
|
from pathlib import Path as _P
|
||||||
|
p = _P(skill_path)
|
||||||
|
if p.is_dir():
|
||||||
|
loader.load_from_directory(str(p))
|
||||||
|
elif p.is_file() and p.suffix in (".yaml", ".yml"):
|
||||||
|
try:
|
||||||
|
loader.load_from_file(str(p))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
app.state.skill_registry = new_skill_registry
|
||||||
|
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
|
||||||
|
app.state.agent_pool._skill_registry = new_skill_registry
|
||||||
|
logger.info(f"Skills reloaded (config v{current_version})")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to reload skills: {e}")
|
||||||
|
|
||||||
|
# Update config version on all agents
|
||||||
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
|
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
|
||||||
app.state.agent_pool._llm_gateway = new_gateway
|
for agent in app.state.agent_pool._agents.values():
|
||||||
if hasattr(app.state, "intent_router") and app.state.intent_router is not None:
|
if hasattr(agent, "_config_version"):
|
||||||
app.state.intent_router._llm_gateway = new_gateway
|
agent._config_version = current_version
|
||||||
logger.info(f"LLM Gateway reloaded (config v{current_version})")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to reload LLM Gateway: {e}")
|
|
||||||
|
|
||||||
# Reload skills if skill paths changed
|
logger.info(f"Config reload complete (v{current_version})")
|
||||||
try:
|
|
||||||
new_skill_registry = _build_skill_registry(config)
|
|
||||||
# Re-bind tools from the shared tool_registry so skills don't lose their bindings
|
|
||||||
tool_registry = getattr(app.state, "tool_registry", None)
|
|
||||||
if tool_registry:
|
|
||||||
from agentkit.skills.loader import SkillLoader
|
|
||||||
loader = SkillLoader(
|
|
||||||
skill_registry=new_skill_registry,
|
|
||||||
tool_registry=tool_registry,
|
|
||||||
)
|
|
||||||
for skill_path in (config.skill_paths or []):
|
|
||||||
from pathlib import Path as _P
|
|
||||||
p = _P(skill_path)
|
|
||||||
if p.is_dir():
|
|
||||||
loader.load_from_directory(str(p))
|
|
||||||
elif p.is_file() and p.suffix in (".yaml", ".yml"):
|
|
||||||
try:
|
|
||||||
loader.load_from_file(str(p))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
app.state.skill_registry = new_skill_registry
|
|
||||||
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
|
|
||||||
app.state.agent_pool._skill_registry = new_skill_registry
|
|
||||||
logger.info(f"Skills reloaded (config v{current_version})")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to reload skills: {e}")
|
|
||||||
|
|
||||||
# Update config version on all agents
|
|
||||||
if hasattr(app.state, "agent_pool") and app.state.agent_pool is not None:
|
|
||||||
for agent in app.state.agent_pool._agents.values():
|
|
||||||
if hasattr(agent, "_config_version"):
|
|
||||||
agent._config_version = current_version
|
|
||||||
|
|
||||||
logger.info(f"Config reload complete (v{current_version})")
|
|
||||||
|
|
||||||
# Schedule the reload as a task (non-blocking for the watcher thread)
|
# Schedule the reload as a task (non-blocking for the watcher thread)
|
||||||
try:
|
try:
|
||||||
|
|
@ -327,6 +332,10 @@ def create_app(
|
||||||
_key = _key.strip()
|
_key = _key.strip()
|
||||||
_val = _val.strip().strip("\"'")
|
_val = _val.strip().strip("\"'")
|
||||||
if _key and _key not in os.environ:
|
if _key and _key not in os.environ:
|
||||||
|
allowed = any(_key.startswith(p) for p in _ALLOWED_ENV_PREFIXES) or _key in _ALLOWED_ENV_EXACT
|
||||||
|
if not allowed:
|
||||||
|
logger.warning(f"Skipping .env variable '{_key}' (not in allowed prefixes)")
|
||||||
|
continue
|
||||||
os.environ[_key] = _val
|
os.environ[_key] = _val
|
||||||
server_config = ServerConfig.from_yaml(config_path)
|
server_config = ServerConfig.from_yaml(config_path)
|
||||||
app = FastAPI(title="AgentKit Server", version="2.0.0", lifespan=lifespan)
|
app = FastAPI(title="AgentKit Server", version="2.0.0", lifespan=lifespan)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
"""Server middleware - Authentication and Rate Limiting"""
|
"""Server middleware - Authentication and Rate Limiting"""
|
||||||
|
|
||||||
|
import hmac
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
@ -75,7 +76,7 @@ class APIKeyAuthMiddleware(BaseHTTPMiddleware):
|
||||||
|
|
||||||
# Check API key from header
|
# Check API key from header
|
||||||
provided_key = request.headers.get("X-API-Key")
|
provided_key = request.headers.get("X-API-Key")
|
||||||
if not provided_key or provided_key not in valid_keys:
|
if not provided_key or not any(hmac.compare_digest(provided_key.encode(), k.encode()) for k in valid_keys):
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=401,
|
status_code=401,
|
||||||
content={"error": "Unauthorized", "message": "Invalid or missing API key"},
|
content={"error": "Unauthorized", "message": "Invalid or missing API key"},
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import hmac
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
|
@ -47,7 +48,7 @@ async def _verify_api_key(
|
||||||
return
|
return
|
||||||
|
|
||||||
provided = api_key_header or api_key_query
|
provided = api_key_header or api_key_query
|
||||||
if provided != configured_api_key:
|
if not hmac.compare_digest((provided or "").encode(), configured_api_key.encode()):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=401,
|
status_code=401,
|
||||||
detail="Invalid or missing API key. Provide via X-API-Key header or api_key query parameter.",
|
detail="Invalid or missing API key. Provide via X-API-Key header or api_key query parameter.",
|
||||||
|
|
@ -475,11 +476,7 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
# Check api_key query param
|
# Check api_key query param
|
||||||
if configured_api_key:
|
if configured_api_key:
|
||||||
provided = websocket.query_params.get("api_key")
|
provided = websocket.query_params.get("api_key")
|
||||||
if provided != configured_api_key:
|
if not hmac.compare_digest((provided or "").encode(), configured_api_key.encode()):
|
||||||
await websocket.accept()
|
|
||||||
await websocket.send_json(
|
|
||||||
{"type": "error", "data": {"message": "Invalid or missing api_key"}}
|
|
||||||
)
|
|
||||||
await websocket.close(code=4001, reason="Invalid or missing api_key")
|
await websocket.close(code=4001, reason="Invalid or missing api_key")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import hmac
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
|
@ -52,7 +53,7 @@ async def _verify_api_key(
|
||||||
return
|
return
|
||||||
|
|
||||||
provided = api_key_header or api_key_query
|
provided = api_key_header or api_key_query
|
||||||
if provided != configured_api_key:
|
if not hmac.compare_digest((provided or "").encode(), configured_api_key.encode()):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=401,
|
status_code=401,
|
||||||
detail="Invalid or missing API key. Provide via X-API-Key header or api_key query parameter.",
|
detail="Invalid or missing API key. Provide via X-API-Key header or api_key query parameter.",
|
||||||
|
|
@ -65,7 +66,7 @@ async def _verify_api_key(
|
||||||
|
|
||||||
|
|
||||||
class WorkflowStore:
|
class WorkflowStore:
|
||||||
"""In-memory workflow store."""
|
"""In-memory workflow store with async-safe mutation methods."""
|
||||||
|
|
||||||
def __init__(self, max_workflows: int = 500, max_executions: int = 1000):
|
def __init__(self, max_workflows: int = 500, max_executions: int = 1000):
|
||||||
self._workflows: dict[str, WorkflowDefinition] = {}
|
self._workflows: dict[str, WorkflowDefinition] = {}
|
||||||
|
|
@ -73,17 +74,30 @@ class WorkflowStore:
|
||||||
self._max_workflows = max_workflows
|
self._max_workflows = max_workflows
|
||||||
self._max_executions = max_executions
|
self._max_executions = max_executions
|
||||||
self._approval_events: dict[str, asyncio.Event] = {} # key: f"{execution_id}:{stage_name}"
|
self._approval_events: dict[str, asyncio.Event] = {} # key: f"{execution_id}:{stage_name}"
|
||||||
|
self._running_tasks: dict[str, asyncio.Task] = {}
|
||||||
|
self._execution_locks: dict[str, asyncio.Lock] = {}
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
def save(self, workflow: WorkflowDefinition) -> WorkflowDefinition:
|
def _evict_execution(self, execution_id: str) -> None:
|
||||||
workflow.updated_at = datetime.now(timezone.utc).isoformat()
|
"""Remove execution and its associated approval events."""
|
||||||
self._workflows[workflow.workflow_id] = workflow
|
self._executions.pop(execution_id, None)
|
||||||
# Evict oldest if over limit
|
keys_to_remove = [k for k in self._approval_events if k.startswith(f"{execution_id}:")]
|
||||||
if len(self._workflows) > self._max_workflows:
|
for k in keys_to_remove:
|
||||||
oldest_id = min(
|
event = self._approval_events.pop(k, None)
|
||||||
self._workflows, key=lambda k: self._workflows[k].updated_at
|
if event is not None:
|
||||||
)
|
event.set() # Wake any waiting coroutine
|
||||||
del self._workflows[oldest_id]
|
|
||||||
return workflow
|
async def save(self, workflow: WorkflowDefinition) -> WorkflowDefinition:
|
||||||
|
async with self._lock:
|
||||||
|
workflow.updated_at = datetime.now(timezone.utc).isoformat()
|
||||||
|
self._workflows[workflow.workflow_id] = workflow
|
||||||
|
# Evict oldest if over limit
|
||||||
|
if len(self._workflows) > self._max_workflows:
|
||||||
|
oldest_id = min(
|
||||||
|
self._workflows, key=lambda k: self._workflows[k].updated_at
|
||||||
|
)
|
||||||
|
del self._workflows[oldest_id]
|
||||||
|
return workflow
|
||||||
|
|
||||||
def get(self, workflow_id: str) -> WorkflowDefinition | None:
|
def get(self, workflow_id: str) -> WorkflowDefinition | None:
|
||||||
return self._workflows.get(workflow_id)
|
return self._workflows.get(workflow_id)
|
||||||
|
|
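Review note: the `save` change above serializes mutation behind one `asyncio.Lock` and evicts the entry with the oldest `updated_at` once the cap is exceeded. The sketch below shows the same pattern in isolation. `Item` and `BoundedStore` are hypothetical stand-ins, and a counter replaces the ISO timestamp so the eviction order is deterministic.

```python
from __future__ import annotations

import asyncio
from dataclasses import dataclass


@dataclass
class Item:
    """Hypothetical stand-in for WorkflowDefinition."""
    item_id: str
    updated_at: float = 0.0


class BoundedStore:
    """Keeps at most `max_items` entries, evicting the least recently saved."""

    def __init__(self, max_items: int = 3) -> None:
        self._items: dict[str, Item] = {}
        self._max_items = max_items
        self._lock = asyncio.Lock()
        self._clock = 0.0  # assumption: monotonic counter instead of wall time

    async def save(self, item: Item) -> Item:
        async with self._lock:
            self._clock += 1.0
            item.updated_at = self._clock
            self._items[item.item_id] = item
            # Evict after inserting, so the store never stays above the cap.
            if len(self._items) > self._max_items:
                oldest = min(self._items, key=lambda k: self._items[k].updated_at)
                del self._items[oldest]
            return item

    def ids(self) -> list[str]:
        return sorted(self._items)
```

Reads such as `get` stay synchronous in the diff. That is safe on a single event loop because a plain dict lookup contains no await point.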
@ -109,47 +123,72 @@ class WorkflowStore:
|
||||||
)
|
)
|
||||||
return summaries
|
return summaries
|
||||||
|
|
||||||
def delete(self, workflow_id: str) -> bool:
|
async def delete(self, workflow_id: str) -> bool:
|
||||||
if workflow_id in self._workflows:
|
async with self._lock:
|
||||||
del self._workflows[workflow_id]
|
if workflow_id in self._workflows:
|
||||||
return True
|
del self._workflows[workflow_id]
|
||||||
return False
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def create_execution(self, workflow_id: str) -> WorkflowExecution:
|
async def create_execution(self, workflow_id: str) -> WorkflowExecution:
|
||||||
execution = WorkflowExecution(
|
async with self._lock:
|
||||||
execution_id=str(uuid.uuid4()),
|
execution = WorkflowExecution(
|
||||||
workflow_id=workflow_id,
|
execution_id=str(uuid.uuid4()),
|
||||||
status="pending",
|
workflow_id=workflow_id,
|
||||||
started_at=datetime.now(timezone.utc).isoformat(),
|
status="pending",
|
||||||
)
|
started_at=datetime.now(timezone.utc).isoformat(),
|
||||||
self._executions[execution.execution_id] = execution
|
|
||||||
# Evict oldest if over limit
|
|
||||||
if len(self._executions) > self._max_executions:
|
|
||||||
oldest_id = min(
|
|
||||||
self._executions,
|
|
||||||
key=lambda k: self._executions[k].started_at or "",
|
|
||||||
)
|
)
|
||||||
del self._executions[oldest_id]
|
self._executions[execution.execution_id] = execution
|
||||||
return execution
|
# Evict oldest if over limit
|
||||||
|
if len(self._executions) > self._max_executions:
|
||||||
|
oldest_id = min(
|
||||||
|
self._executions,
|
||||||
|
key=lambda k: self._executions[k].started_at or "",
|
||||||
|
)
|
||||||
|
self._evict_execution(oldest_id)
|
||||||
|
return execution
|
||||||
|
|
||||||
def get_execution(self, execution_id: str) -> WorkflowExecution | None:
|
def get_execution(self, execution_id: str) -> WorkflowExecution | None:
|
||||||
return self._executions.get(execution_id)
|
return self._executions.get(execution_id)
|
||||||
|
|
||||||
def update_execution(self, execution_id: str, **kwargs: Any) -> WorkflowExecution:
|
async def update_execution(self, execution_id: str, **kwargs: Any) -> WorkflowExecution:
|
||||||
execution = self._executions.get(execution_id)
|
async with self._lock:
|
||||||
if execution is None:
|
execution = self._executions.get(execution_id)
|
||||||
raise KeyError(f"Execution '{execution_id}' not found")
|
if execution is None:
|
||||||
for key, value in kwargs.items():
|
raise KeyError(f"Execution '{execution_id}' not found")
|
||||||
if hasattr(execution, key):
|
for key, value in kwargs.items():
|
||||||
setattr(execution, key, value)
|
if hasattr(execution, key):
|
||||||
return execution
|
setattr(execution, key, value)
|
||||||
|
return execution
|
||||||
|
|
||||||
|
def get_execution_lock(self, execution_id: str) -> asyncio.Lock:
|
||||||
|
"""Get or create a per-execution lock for approve/cancel serialization."""
|
||||||
|
if execution_id not in self._execution_locks:
|
||||||
|
self._execution_locks[execution_id] = asyncio.Lock()
|
||||||
|
return self._execution_locks[execution_id]
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton
|
# Module-level singleton
|
||||||
_workflow_store = WorkflowStore()
|
_workflow_store = WorkflowStore()
|
||||||
|
|
||||||
# WebSocket subscribers for real-time execution progress
|
# WebSocket subscribers for real-time execution progress (keyed by execution_id)
|
||||||
_ws_subscribers: list[WebSocket] = []
|
_ws_subscribers: dict[str, set[WebSocket]] = {}
|
||||||
|
_ws_subscribers_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
async def _ws_subscribe(execution_id: str, ws: WebSocket) -> None:
|
||||||
|
async with _ws_subscribers_lock:
|
||||||
|
if execution_id not in _ws_subscribers:
|
||||||
|
_ws_subscribers[execution_id] = set()
|
||||||
|
_ws_subscribers[execution_id].add(ws)
|
||||||
|
|
||||||
|
|
||||||
|
async def _ws_unsubscribe(execution_id: str, ws: WebSocket) -> None:
|
||||||
|
async with _ws_subscribers_lock:
|
||||||
|
if execution_id in _ws_subscribers:
|
||||||
|
_ws_subscribers[execution_id].discard(ws)
|
||||||
|
if not _ws_subscribers[execution_id]:
|
||||||
|
del _ws_subscribers[execution_id]
|
||||||
|
|
||||||
|
|
||||||
def _get_store(request: Request) -> WorkflowStore:
|
def _get_store(request: Request) -> WorkflowStore:
|
||||||
|
|
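Review note: `_ws_subscribe` and `_ws_unsubscribe` above turn the flat subscriber list into a registry keyed by `execution_id`. A self-contained sketch of that registry follows, using any hashable object in place of a real `WebSocket`. The class name `SubscriberRegistry` and its `snapshot` method are assumptions made for illustration.

```python
from __future__ import annotations

import asyncio
from typing import Hashable


class SubscriberRegistry:
    """Maps an execution id to the set of connections subscribed to it."""

    def __init__(self) -> None:
        self._subs: dict[str, set[Hashable]] = {}
        self._lock = asyncio.Lock()

    async def subscribe(self, execution_id: str, conn: Hashable) -> None:
        async with self._lock:
            self._subs.setdefault(execution_id, set()).add(conn)

    async def unsubscribe(self, execution_id: str, conn: Hashable) -> None:
        async with self._lock:
            conns = self._subs.get(execution_id)
            if conns is None:
                return
            conns.discard(conn)
            # Drop empty buckets so the registry does not grow without bound.
            if not conns:
                del self._subs[execution_id]

    async def snapshot(self, execution_id: str) -> set[Hashable]:
        """Return a copy that is safe to iterate after the lock is released."""
        async with self._lock:
            return set(self._subs.get(execution_id, ()))
```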
@ -210,7 +249,7 @@ async def _execute_workflow(
|
||||||
"""Execute a workflow by running its stages in topological order."""
|
"""Execute a workflow by running its stages in topological order."""
|
||||||
_store = store or _workflow_store
|
_store = store or _workflow_store
|
||||||
execution.status = "running"
|
execution.status = "running"
|
||||||
_store.update_execution(execution.execution_id, status="running")
|
await _store.update_execution(execution.execution_id, status="running")
|
||||||
|
|
||||||
# Topological sort
|
# Topological sort
|
||||||
stage_map = {s.name: s for s in workflow.stages}
|
stage_map = {s.name: s for s in workflow.stages}
|
||||||
|
|
@ -229,7 +268,7 @@ async def _execute_workflow(
|
||||||
execution.status = "failed"
|
execution.status = "failed"
|
||||||
execution.error = "循环依赖"
|
execution.error = "循环依赖"
|
||||||
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
||||||
_store.update_execution(
|
await _store.update_execution(
|
||||||
execution.execution_id,
|
execution.execution_id,
|
||||||
status="failed",
|
status="failed",
|
||||||
error="循环依赖",
|
error="循环依赖",
|
||||||
|
|
@ -246,7 +285,7 @@ async def _execute_workflow(
|
||||||
for stage_name in ordered:
|
for stage_name in ordered:
|
||||||
stage = stage_map[stage_name]
|
stage = stage_map[stage_name]
|
||||||
execution.current_stage = stage_name
|
execution.current_stage = stage_name
|
||||||
_store.update_execution(
|
await _store.update_execution(
|
||||||
execution.execution_id,
|
execution.execution_id,
|
||||||
current_stage=stage_name,
|
current_stage=stage_name,
|
||||||
)
|
)
|
||||||
|
|
@ -256,7 +295,7 @@ async def _execute_workflow(
|
||||||
"event": "stage_started",
|
"event": "stage_started",
|
||||||
"execution_id": execution.execution_id,
|
"execution_id": execution.execution_id,
|
||||||
"stage": stage_name,
|
"stage": stage_name,
|
||||||
})
|
}, execution_id=execution.execution_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if stage.type == "approval":
|
if stage.type == "approval":
|
||||||
|
|
@ -267,7 +306,7 @@ async def _execute_workflow(
|
||||||
|
|
||||||
execution.status = "paused"
|
execution.status = "paused"
|
||||||
execution.current_stage = stage_name
|
execution.current_stage = stage_name
|
||||||
_store.update_execution(
|
await _store.update_execution(
|
||||||
execution.execution_id,
|
execution.execution_id,
|
||||||
status="paused",
|
status="paused",
|
||||||
current_stage=stage_name,
|
current_stage=stage_name,
|
||||||
|
|
@ -276,7 +315,7 @@ async def _execute_workflow(
|
||||||
"event": "approval_required",
|
"event": "approval_required",
|
||||||
"execution_id": execution.execution_id,
|
"execution_id": execution.execution_id,
|
||||||
"stage": stage_name,
|
"stage": stage_name,
|
||||||
})
|
}, execution_id=execution.execution_id)
|
||||||
|
|
||||||
# Wait for approval with timeout
|
# Wait for approval with timeout
|
||||||
try:
|
try:
|
||||||
|
|
@ -289,13 +328,13 @@ async def _execute_workflow(
|
||||||
"execution_id": execution.execution_id,
|
"execution_id": execution.execution_id,
|
||||||
"stage": stage_name,
|
"stage": stage_name,
|
||||||
"error": "Approval rejected",
|
"error": "Approval rejected",
|
||||||
})
|
}, execution_id=execution.execution_id)
|
||||||
return
|
return
|
||||||
# Approval was granted — the /approve endpoint already set stage_results
|
# Approval was granted — the /approve endpoint already set stage_results
|
||||||
# Only update status to running if not already set
|
# Only update status to running if not already set
|
||||||
if execution.status != "running":
|
if execution.status != "running":
|
||||||
execution.status = "running"
|
execution.status = "running"
|
||||||
_store.update_execution(
|
await _store.update_execution(
|
||||||
execution.execution_id,
|
execution.execution_id,
|
||||||
status="running",
|
status="running",
|
||||||
)
|
)
|
||||||
|
|
@ -308,7 +347,7 @@ async def _execute_workflow(
|
||||||
execution.status = "failed"
|
execution.status = "failed"
|
||||||
execution.error = f"Approval timeout for stage {stage_name}"
|
execution.error = f"Approval timeout for stage {stage_name}"
|
||||||
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
||||||
_store.update_execution(
|
await _store.update_execution(
|
||||||
execution.execution_id,
|
execution.execution_id,
|
||||||
status="failed",
|
status="failed",
|
||||||
error=execution.error,
|
error=execution.error,
|
||||||
|
|
@ -320,7 +359,7 @@ async def _execute_workflow(
|
||||||
"execution_id": execution.execution_id,
|
"execution_id": execution.execution_id,
|
||||||
"stage": stage_name,
|
"stage": stage_name,
|
||||||
"error": "Approval timeout",
|
"error": "Approval timeout",
|
||||||
})
|
}, execution_id=execution.execution_id)
|
||||||
return
|
return
|
||||||
finally:
|
finally:
|
||||||
_store._approval_events.pop(event_key, None)
|
_store._approval_events.pop(event_key, None)
|
||||||
|
|
@ -332,7 +371,7 @@ async def _execute_workflow(
|
||||||
"status": "completed",
|
"status": "completed",
|
||||||
"condition_result": result,
|
"condition_result": result,
|
||||||
}
|
}
|
||||||
_store.update_execution(
|
await _store.update_execution(
|
||||||
execution.execution_id,
|
execution.execution_id,
|
||||||
stage_results=execution.stage_results,
|
stage_results=execution.stage_results,
|
||||||
)
|
)
|
||||||
|
|
@ -342,7 +381,7 @@ async def _execute_workflow(
|
||||||
"status": "completed",
|
"status": "completed",
|
||||||
"output": {"dry_run": True, "action": stage.action},
|
"output": {"dry_run": True, "action": stage.action},
|
||||||
}
|
}
|
||||||
_store.update_execution(
|
await _store.update_execution(
|
||||||
execution.execution_id,
|
execution.execution_id,
|
||||||
stage_results=execution.stage_results,
|
stage_results=execution.stage_results,
|
||||||
)
|
)
|
||||||
|
|
@ -351,7 +390,7 @@ async def _execute_workflow(
|
||||||
"event": "stage_completed",
|
"event": "stage_completed",
|
||||||
"execution_id": execution.execution_id,
|
"execution_id": execution.execution_id,
|
||||||
"stage": stage_name,
|
"stage": stage_name,
|
||||||
})
|
}, execution_id=execution.execution_id)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
execution.stage_results[stage_name] = {
|
execution.stage_results[stage_name] = {
|
||||||
|
|
@ -361,7 +400,7 @@ async def _execute_workflow(
|
||||||
execution.status = "failed"
|
execution.status = "failed"
|
||||||
execution.error = f"阶段 '{stage_name}' 执行失败: {e}"
|
execution.error = f"阶段 '{stage_name}' 执行失败: {e}"
|
||||||
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
||||||
_store.update_execution(
|
await _store.update_execution(
|
||||||
execution.execution_id,
|
execution.execution_id,
|
||||||
status="failed",
|
status="failed",
|
||||||
error=execution.error,
|
error=execution.error,
|
||||||
|
|
@ -373,13 +412,13 @@ async def _execute_workflow(
|
||||||
"execution_id": execution.execution_id,
|
"execution_id": execution.execution_id,
|
||||||
"stage": stage_name,
|
"stage": stage_name,
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
})
|
}, execution_id=execution.execution_id)
|
||||||
return
|
return
|
||||||
|
|
||||||
execution.status = "completed"
|
execution.status = "completed"
|
||||||
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
||||||
execution.current_stage = None
|
execution.current_stage = None
|
||||||
_store.update_execution(
|
await _store.update_execution(
|
||||||
execution.execution_id,
|
execution.execution_id,
|
||||||
status="completed",
|
status="completed",
|
||||||
completed_at=execution.completed_at,
|
completed_at=execution.completed_at,
|
||||||
|
|
@ -388,7 +427,7 @@ async def _execute_workflow(
|
||||||
await _broadcast_ws({
|
await _broadcast_ws({
|
||||||
"event": "execution_completed",
|
"event": "execution_completed",
|
||||||
"execution_id": execution.execution_id,
|
"execution_id": execution.execution_id,
|
||||||
})
|
}, execution_id=execution.execution_id)
|
||||||
|
|
||||||
|
|
||||||
_SAFE_VAR_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
|
_SAFE_VAR_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
|
||||||
|
|
@ -455,16 +494,25 @@ def _evaluate_condition(expression: str, variables: dict[str, Any]) -> bool:
|
||||||
raise ValueError(f"Invalid condition expression: {expression}")
|
raise ValueError(f"Invalid condition expression: {expression}")
|
||||||
|
|
||||||
|
|
||||||
async def _broadcast_ws(message: dict[str, Any]) -> None:
|
async def _broadcast_ws(message: dict[str, Any], execution_id: str | None = None) -> None:
|
||||||
"""Broadcast a message to all WebSocket subscribers."""
|
"""Broadcast a message to WebSocket subscribers for a specific execution."""
|
||||||
|
async with _ws_subscribers_lock:
|
||||||
|
targets = set()
|
||||||
|
if execution_id and execution_id in _ws_subscribers:
|
||||||
|
targets = set(_ws_subscribers[execution_id]) # snapshot
|
||||||
disconnected = []
|
disconnected = []
|
||||||
for ws in _ws_subscribers:
|
for ws in targets:
|
||||||
try:
|
try:
|
||||||
await ws.send_json(message)
|
await ws.send_json(message)
|
||||||
except Exception:
|
except Exception:
|
||||||
disconnected.append(ws)
|
disconnected.append(ws)
|
||||||
for ws in disconnected:
|
if disconnected:
|
||||||
_ws_subscribers.remove(ws)
|
async with _ws_subscribers_lock:
|
||||||
|
for ws in disconnected:
|
||||||
|
for eid in list(_ws_subscribers.keys()):
|
||||||
|
_ws_subscribers[eid].discard(ws)
|
||||||
|
if not _ws_subscribers[eid]:
|
||||||
|
del _ws_subscribers[eid]
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
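Review note: the new `_broadcast_ws` takes a snapshot of the target set under the lock, sends outside the lock, then reacquires it to prune dead connections. The sketch below isolates that flow. `FakeConn` and `broadcast` are hypothetical names, and unlike the diff this version prunes only the bucket it broadcast to rather than scanning every execution.

```python
from __future__ import annotations

import asyncio
from typing import Any


class FakeConn:
    """Hypothetical stand-in for a WebSocket; raises on send when `broken`."""

    def __init__(self, broken: bool = False) -> None:
        self.broken = broken
        self.received: list[dict[str, Any]] = []

    async def send_json(self, message: dict[str, Any]) -> None:
        if self.broken:
            raise ConnectionError("peer went away")
        self.received.append(message)


async def broadcast(
    subs: dict[str, set[FakeConn]],
    lock: asyncio.Lock,
    execution_id: str,
    message: dict[str, Any],
) -> int:
    """Send `message` to subscribers of one execution; return the delivered count."""
    async with lock:
        targets = set(subs.get(execution_id, ()))  # snapshot under the lock
    dead: list[FakeConn] = []
    for conn in targets:  # slow network I/O happens outside the lock
        try:
            await conn.send_json(message)
        except Exception:
            dead.append(conn)
    if dead:
        async with lock:
            conns = subs.get(execution_id)
            if conns is not None:
                conns.difference_update(dead)
                if not conns:
                    del subs[execution_id]
    return len(targets) - len(dead)
```

Sending outside the lock keeps one slow client from blocking subscribe and unsubscribe calls for everyone else.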
@ -494,7 +542,7 @@ async def create_workflow(request: Request, body: CreateWorkflowRequest, _auth:
|
||||||
variables_schema=body.variables_schema,
|
variables_schema=body.variables_schema,
|
||||||
output_schema=body.output_schema,
|
output_schema=body.output_schema,
|
||||||
)
|
)
|
||||||
saved = store.save(workflow)
|
saved = await store.save(workflow)
|
||||||
return saved.model_dump()
|
return saved.model_dump()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -527,7 +575,7 @@ async def update_workflow(
|
||||||
existing.variables_schema = body.variables_schema
|
existing.variables_schema = body.variables_schema
|
||||||
existing.output_schema = body.output_schema
|
existing.output_schema = body.output_schema
|
||||||
existing.version += 1
|
existing.version += 1
|
||||||
saved = store.save(existing)
|
saved = await store.save(existing)
|
||||||
return saved.model_dump()
|
return saved.model_dump()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -535,7 +583,7 @@ async def update_workflow(
|
||||||
async def delete_workflow(request: Request, workflow_id: str, _auth: None = Depends(_verify_api_key)):
|
async def delete_workflow(request: Request, workflow_id: str, _auth: None = Depends(_verify_api_key)):
|
||||||
"""Delete a workflow."""
|
"""Delete a workflow."""
|
||||||
store = _get_store(request)
|
store = _get_store(request)
|
||||||
deleted = store.delete(workflow_id)
|
deleted = await store.delete(workflow_id)
|
||||||
if not deleted:
|
if not deleted:
|
||||||
raise HTTPException(status_code=404, detail=f"工作流 '{workflow_id}' 不存在")
|
raise HTTPException(status_code=404, detail=f"工作流 '{workflow_id}' 不存在")
|
||||||
return {"message": "已删除"}
|
return {"message": "已删除"}
|
||||||
|
|
@ -551,14 +599,13 @@ async def execute_workflow(
|
||||||
if workflow is None:
|
if workflow is None:
|
||||||
raise HTTPException(status_code=404, detail=f"工作流 '{workflow_id}' 不存在")
|
raise HTTPException(status_code=404, detail=f"工作流 '{workflow_id}' 不存在")
|
||||||
|
|
||||||
execution = store.create_execution(workflow_id)
|
execution = await store.create_execution(workflow_id)
|
||||||
execution.variables = body.variables
|
execution.variables = body.variables
|
||||||
|
|
||||||
# Start execution in background
|
# Start execution in background
|
||||||
task = asyncio.create_task(
|
task = asyncio.create_task(
|
||||||
_execute_workflow(workflow, execution, body.variables, store=store)
|
_execute_workflow(workflow, execution, body.variables, store=store)
|
||||||
)
|
)
|
||||||
store._running_tasks = getattr(store, "_running_tasks", {})
|
|
||||||
store._running_tasks[execution.execution_id] = task
|
store._running_tasks[execution.execution_id] = task
|
||||||
task.add_done_callback(lambda t: store._running_tasks.pop(execution.execution_id, None))
|
task.add_done_callback(lambda t: store._running_tasks.pop(execution.execution_id, None))
|
||||||
|
|
||||||
|
|
@ -593,51 +640,60 @@ async def approve_execution(
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404, detail=f"执行记录 '{execution_id}' 不存在"
|
status_code=404, detail=f"执行记录 '{execution_id}' 不存在"
|
||||||
)
|
)
|
||||||
if execution.status != "paused":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400, detail="当前执行状态不是等待审批"
|
|
||||||
)
|
|
||||||
|
|
||||||
if body.approved:
|
exec_lock = store.get_execution_lock(execution_id)
|
||||||
if execution.current_stage:
|
async with exec_lock:
|
||||||
execution.stage_results[execution.current_stage] = {
|
# Re-fetch execution after acquiring lock
|
||||||
"status": "approved",
|
execution = store.get_execution(execution_id)
|
||||||
"approver": "user",
|
if execution is None:
|
||||||
"comment": body.comment,
|
raise HTTPException(
|
||||||
}
|
status_code=404, detail=f"执行记录 '{execution_id}' 不存在"
|
||||||
execution.status = "running"
|
)
|
||||||
store.update_execution(
|
if execution.status != "paused":
|
||||||
execution.execution_id,
|
raise HTTPException(
|
||||||
status="running",
|
status_code=400, detail="当前执行状态不是等待审批"
|
||||||
stage_results=execution.stage_results,
|
)
|
||||||
)
|
|
||||||
# Resume the waiting execution by setting the approval event
|
if body.approved:
|
||||||
stage_name = execution.current_stage
|
if execution.current_stage:
|
||||||
if stage_name:
|
execution.stage_results[execution.current_stage] = {
|
||||||
event_key = f"{execution_id}:{stage_name}"
|
"status": "approved",
|
||||||
if event_key in store._approval_events:
|
"approver": "user",
|
||||||
store._approval_events[event_key].set()
|
"comment": body.comment,
|
||||||
else:
|
}
|
||||||
execution.status = "cancelled"
|
execution.status = "running"
|
||||||
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
await store.update_execution(
|
||||||
if execution.current_stage:
|
execution.execution_id,
|
||||||
execution.stage_results[execution.current_stage] = {
|
status="running",
|
||||||
"status": "rejected",
|
stage_results=execution.stage_results,
|
||||||
"approver": "user",
|
)
|
||||||
"comment": body.comment,
|
# Resume the waiting execution by setting the approval event
|
||||||
}
|
stage_name = execution.current_stage
|
||||||
store.update_execution(
|
if stage_name:
|
||||||
execution.execution_id,
|
event_key = f"{execution_id}:{stage_name}"
|
||||||
status="cancelled",
|
if event_key in store._approval_events:
|
||||||
completed_at=execution.completed_at,
|
store._approval_events[event_key].set()
|
||||||
stage_results=execution.stage_results,
|
else:
|
||||||
)
|
execution.status = "cancelled"
|
||||||
# Set the approval event so the waiting coroutine can observe the cancelled state
|
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
||||||
stage_name = execution.current_stage
|
if execution.current_stage:
|
||||||
if stage_name:
|
execution.stage_results[execution.current_stage] = {
|
||||||
event_key = f"{execution_id}:{stage_name}"
|
"status": "rejected",
|
||||||
if event_key in store._approval_events:
|
"approver": "user",
|
||||||
store._approval_events[event_key].set()
|
"comment": body.comment,
|
||||||
|
}
|
||||||
|
await store.update_execution(
|
||||||
|
execution.execution_id,
|
||||||
|
status="cancelled",
|
||||||
|
completed_at=execution.completed_at,
|
||||||
|
stage_results=execution.stage_results,
|
||||||
|
)
|
||||||
|
# Set the approval event so the waiting coroutine can observe the cancelled state
|
||||||
|
stage_name = execution.current_stage
|
||||||
|
if stage_name:
|
||||||
|
event_key = f"{execution_id}:{stage_name}"
|
||||||
|
if event_key in store._approval_events:
|
||||||
|
store._approval_events[event_key].set()
|
||||||
|
|
||||||
return execution.model_dump()
|
return execution.model_dump()
|
||||||
|
|
||||||
|
|
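Review note: the approve handler above moves its status check inside a per-execution lock and re-fetches the execution after acquiring it, which closes the check-then-act window between two concurrent decisions. A reduced sketch of that idea follows. The `decide` function, the dict-based execution records, and the string return values are all assumptions made for illustration.

```python
from __future__ import annotations

import asyncio


async def decide(
    executions: dict[str, dict[str, str]],
    locks: dict[str, asyncio.Lock],
    execution_id: str,
    approved: bool,
) -> str:
    """Apply an approve or reject decision at most once per paused execution."""
    lock = locks.setdefault(execution_id, asyncio.Lock())
    async with lock:
        # Re-read state only after the lock is held; a value read earlier
        # may already be stale.
        execution = executions.get(execution_id)
        if execution is None:
            return "not_found"
        if execution["status"] != "paused":
            return "conflict"
        await asyncio.sleep(0)  # stands in for the awaited store update
        execution["status"] = "running" if approved else "cancelled"
        return execution["status"]
```

As far as these hunks show, entries in `_execution_locks` are not removed when an execution is evicted. The sketch leaves that cleanup out as well.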
@ -651,23 +707,33 @@ async def cancel_execution(request: Request, execution_id: str, _auth: None = De
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404, detail=f"执行记录 '{execution_id}' 不存在"
|
status_code=404, detail=f"执行记录 '{execution_id}' 不存在"
|
||||||
)
|
)
|
||||||
if execution.status not in ("running", "paused", "pending"):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400, detail="当前执行状态无法取消"
|
|
||||||
)
|
|
||||||
|
|
||||||
execution.status = "cancelled"
|
exec_lock = store.get_execution_lock(execution_id)
|
||||||
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
async with exec_lock:
|
||||||
store.update_execution(
|
# Re-fetch execution after acquiring lock
|
||||||
execution.execution_id,
|
execution = store.get_execution(execution_id)
|
||||||
status="cancelled",
|
if execution is None:
|
||||||
completed_at=execution.completed_at,
|
raise HTTPException(
|
||||||
)
|
status_code=404, detail=f"执行记录 '{execution_id}' 不存在"
|
||||||
# Set any pending approval event so a paused workflow can observe the cancelled state
|
)
|
||||||
if hasattr(execution, "current_stage") and execution.current_stage:
|
if execution.status not in ("running", "paused", "pending"):
|
||||||
event_key = f"{execution_id}:{execution.current_stage}"
|
raise HTTPException(
|
||||||
if event_key in store._approval_events:
|
status_code=400, detail="当前执行状态无法取消"
|
||||||
store._approval_events[event_key].set()
|
)
|
||||||
|
|
||||||
|
execution.status = "cancelled"
|
||||||
|
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
||||||
|
await store.update_execution(
|
||||||
|
execution.execution_id,
|
||||||
|
status="cancelled",
|
||||||
|
completed_at=execution.completed_at,
|
||||||
|
)
|
||||||
|
# Set any pending approval event so a paused workflow can observe the cancelled state
|
||||||
|
if hasattr(execution, "current_stage") and execution.current_stage:
|
||||||
|
event_key = f"{execution_id}:{execution.current_stage}"
|
||||||
|
if event_key in store._approval_events:
|
||||||
|
store._approval_events[event_key].set()
|
||||||
|
|
||||||
return execution.model_dump()
|
return execution.model_dump()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -683,21 +749,37 @@ async def workflow_websocket(websocket: WebSocket):
|
||||||
|
|
||||||
if configured_api_key:
|
if configured_api_key:
|
||||||
provided = websocket.query_params.get("api_key")
|
provided = websocket.query_params.get("api_key")
|
||||||
if provided != configured_api_key:
|
if not hmac.compare_digest((provided or "").encode(), configured_api_key.encode()):
|
||||||
await websocket.accept()
|
|
||||||
await websocket.send_json(
|
|
||||||
{"event": "error", "data": {"message": "Invalid or missing api_key"}}
|
|
||||||
)
|
|
||||||
await websocket.close(code=4001, reason="Invalid or missing api_key")
|
await websocket.close(code=4001, reason="Invalid or missing api_key")
|
||||||
return
|
return
|
||||||
|
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
_ws_subscribers.append(websocket)
|
|
||||||
|
# Determine execution_id from query params; may be None for backward compat
|
||||||
|
execution_id = websocket.query_params.get("execution_id")
|
||||||
|
subscribed_execution_id: str | None = None
|
||||||
|
|
||||||
|
if execution_id:
|
||||||
|
await _ws_subscribe(execution_id, websocket)
|
||||||
|
subscribed_execution_id = execution_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
raw = await asyncio.wait_for(websocket.receive_text(), timeout=120.0)
|
raw = await asyncio.wait_for(websocket.receive_text(), timeout=120.0)
|
||||||
|
# Handle subscription messages
|
||||||
|
try:
|
||||||
|
msg = json.loads(raw)
|
||||||
|
if isinstance(msg, dict) and msg.get("type") == "subscribe":
|
||||||
|
new_eid = msg.get("execution_id")
|
||||||
|
if new_eid:
|
||||||
|
# Unsubscribe from previous if any
|
||||||
|
if subscribed_execution_id:
|
||||||
|
await _ws_unsubscribe(subscribed_execution_id, websocket)
|
||||||
|
await _ws_subscribe(new_eid, websocket)
|
||||||
|
subscribed_execution_id = new_eid
|
||||||
|
except (json.JSONDecodeError, KeyError):
|
||||||
|
pass
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
await websocket.close(code=1000, reason="Heartbeat timeout")
|
await websocket.close(code=1000, reason="Heartbeat timeout")
|
||||||
return
|
return
|
||||||
|
|
@ -707,5 +789,5 @@ async def workflow_websocket(websocket: WebSocket):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Workflow WebSocket error: {e}")
|
logger.error(f"Workflow WebSocket error: {e}")
|
||||||
finally:
|
finally:
|
||||||
if websocket in _ws_subscribers:
|
if subscribed_execution_id:
|
||||||
_ws_subscribers.remove(websocket)
|
await _ws_unsubscribe(subscribed_execution_id, websocket)
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Awaitable
|
from typing import Any, Callable, Awaitable
|
||||||
|
|
@ -95,11 +96,14 @@ class ComputerUseTool(Tool):
|
||||||
self._max_retries = max_retries
|
self._max_retries = max_retries
|
||||||
self._request_timeout = request_timeout
|
self._request_timeout = request_timeout
|
||||||
self._http_client: httpx.AsyncClient | None = None
|
self._http_client: httpx.AsyncClient | None = None
|
||||||
|
self._client_lock = asyncio.Lock()
|
||||||
|
|
||||||
def _get_http_client(self) -> httpx.AsyncClient:
|
async def _get_http_client(self) -> httpx.AsyncClient:
|
||||||
"""Get or create a persistent httpx.AsyncClient for connection reuse."""
|
"""Get or create a persistent httpx.AsyncClient for connection reuse."""
|
||||||
if self._http_client is None or self._http_client.is_closed:
|
if self._http_client is None or self._http_client.is_closed:
|
||||||
self._http_client = httpx.AsyncClient(timeout=self._request_timeout)
|
async with self._client_lock:
|
||||||
|
if self._http_client is None or self._http_client.is_closed:
|
||||||
|
self._http_client = httpx.AsyncClient(timeout=self._request_timeout)
|
||||||
return self._http_client
|
return self._http_client
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
|
|
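Review note: `_get_http_client` above now uses double-checked locking, with a fast path outside the lock and a second check inside it. The sketch below shows the pattern with a hypothetical async factory rather than `httpx`, which makes the race observable. On a single event loop the lock matters chiefly when creation itself awaits, and the synchronous `httpx.AsyncClient(...)` constructor has no await point between the check and the assignment.

```python
from __future__ import annotations

import asyncio
from typing import Any, Awaitable, Callable


class LazyResource:
    """Creates a shared resource once, even under concurrent first use."""

    def __init__(self, factory: Callable[[], Awaitable[Any]]) -> None:
        self._factory = factory  # hypothetical async constructor
        self._instance: Any = None
        self._lock = asyncio.Lock()

    async def get(self) -> Any:
        if self._instance is None:  # fast path: no lock once initialized
            async with self._lock:
                if self._instance is None:  # re-check: another task may have won
                    self._instance = await self._factory()
        return self._instance
```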
@ -391,7 +395,7 @@ class ComputerUseTool(Tool):
|
||||||
"content-type": "application/json",
|
"content-type": "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
client = self._get_http_client()
|
client = await self._get_http_client()
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
self._api_base_url,
|
self._api_base_url,
|
||||||
json=request_body,
|
json=request_body,
|
||||||
|
|
|
||||||
|
|
@ -66,40 +66,33 @@ _SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
|
||||||
"docker images",
|
"docker images",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Dangerous command patterns: these commands require manual confirmation
|
# Dangerous command detection: exact token matching, to avoid substring false positives
|
||||||
_DANGEROUS_PATTERNS: tuple[str, ...] = (
|
|
||||||
"rm ",
|
# Binaries that are always dangerous (regardless of arguments)
|
||||||
"rm -",
|
_DANGEROUS_BINARIES: frozenset[str] = frozenset({
|
||||||
"rmdir",
|
"rm", "rmdir", "mkfs", "dd", "format", "shutdown", "reboot",
|
||||||
"mkfs",
|
"halt", "killall", "chown", "fdisk", "parted",
|
||||||
"dd ",
|
})
|
||||||
"format",
|
|
||||||
"del ",
|
# Binaries that are dangerous only with specific arguments: binary → set of dangerous flags/subcommands
|
||||||
"erase",
|
_DANGEROUS_BINARY_FLAGS: dict[str, set[str]] = {
|
||||||
"> /dev/",
|
"rm": {"-rf", "-fr", "-r", "-f"},
|
||||||
"shutdown",
|
"kill": {"-9", "-kill"},
|
||||||
"reboot",
|
"chmod": {"777", "000"},
|
||||||
"init 0",
|
"git": {"push --force", "push -f", "reset --hard", "clean -f"},
|
||||||
"init 6",
|
"pip": {"uninstall"},
|
||||||
"kill -9",
|
"npm": {"uninstall"},
|
||||||
"killall",
|
"docker": {"rm", "rmi", "system prune"},
|
||||||
"chmod 777",
|
}
|
||||||
"chown",
|
|
||||||
"mv /",
|
# Dangerous patterns that span tokens (compiled regexes)
|
||||||
"pip uninstall",
|
_DANGEROUS_ARG_PATTERNS: list[re.Pattern[str]] = [
|
||||||
"npm uninstall",
|
re.compile(r">\s*/dev/", re.IGNORECASE),
|
||||||
"apt remove",
|
re.compile(r">\s*/etc/", re.IGNORECASE),
|
||||||
"yum remove",
|
re.compile(r"drop\s+table", re.IGNORECASE),
|
||||||
"brew uninstall",
|
re.compile(r"drop\s+database", re.IGNORECASE),
|
||||||
"docker rm",
|
re.compile(r"truncate\s+table", re.IGNORECASE),
|
||||||
"docker rmi",
|
]
|
||||||
"git push --force",
|
|
||||||
"git reset --hard",
|
|
||||||
"git clean -f",
|
|
||||||
"drop table",
|
|
||||||
"drop database",
|
|
||||||
"truncate",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
_SHELL_OPERATORS = re.compile(r'[|;&]|&&|\|\||\$\(|\$\{|`|\$<|>|<|\n')
|
_SHELL_OPERATORS = re.compile(r'[|;&]|&&|\|\||\$\(|\$\{|`|\$<|>|<|\n')
|
||||||
|
|
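Review note: the tables above replace substring matching with a lookup on the binary name plus a few cross-token regexes. The sketch below shows one way such a check could be structured. It is a variation, not the commit's logic: it matches flags as whole tokens, whereas the diff matches them as substrings of the joined command, and it covers only the deny checks, not the safe-prefix allowlist or the dangerous-by-default fallback. All table contents and names are illustrative.

```python
from __future__ import annotations

import re
import shlex

# Illustrative tables; the commit's real tables are larger.
ALWAYS_DANGEROUS = frozenset({"rm", "mkfs", "shutdown"})
DANGEROUS_SUBCOMMANDS: dict[str, set[tuple[str, ...]]] = {
    "git": {("push", "--force"), ("reset", "--hard")},
    "docker": {("rm",), ("rmi",)},
}
DANGEROUS_REGEXES = [
    re.compile(r">\s*/dev/", re.IGNORECASE),
    re.compile(r"drop\s+table", re.IGNORECASE),
]


def is_dangerous(command: str) -> bool:
    """Return True when the command hits one of the deny checks."""
    try:
        tokens = shlex.split(command)
    except ValueError:
        return True  # unbalanced quotes: refuse to guess
    if not tokens:
        return False
    binary = tokens[0].lower()
    args = [t.lower() for t in tokens[1:]]
    # 1. The binary alone is enough.
    if binary in ALWAYS_DANGEROUS:
        return True
    # 2. The binary is dangerous only with specific subcommands or flags.
    for required in DANGEROUS_SUBCOMMANDS.get(binary, ()):
        if all(part in args for part in required):
            return True
    # 3. Patterns that span tokens, checked against the raw string.
    return any(p.search(command) for p in DANGEROUS_REGEXES)
```

With token matching, `echo format` is no longer flagged merely because the argument contains the word `format`.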
@ -300,10 +293,15 @@ class ShellTool(Tool):
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
proc.kill()
|
try:
|
||||||
|
proc.kill()
|
||||||
|
except ProcessLookupError:
|
||||||
|
logger.debug("Process already exited before kill()")
|
||||||
|
except OSError:
|
||||||
|
logger.debug("OSError killing process")
|
||||||
await proc.wait()
|
await proc.wait()
|
||||||
output = f"Command timed out after {timeout}s"
|
output = f"Command timed out after {timeout}s"
|
||||||
exit_code = -1
|
exit_code = proc.returncode if proc.returncode is not None else -1
|
||||||
else:
|
else:
|
||||||
output = stdout.decode("utf-8", errors="replace") if stdout else ""
|
output = stdout.decode("utf-8", errors="replace") if stdout else ""
|
||||||
exit_code = proc.returncode if proc.returncode is not None else 0
|
exit_code = proc.returncode if proc.returncode is not None else 0
|
||||||
|
|
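The timeout hunk above guards `proc.kill()` against a process that has already exited. A minimal, self-contained sketch of the same pattern follows; `run_with_timeout` is a hypothetical helper written for illustration and is not part of `ShellTool`.

```python
import asyncio
import sys


async def run_with_timeout(argv: list[str], timeout: float) -> tuple[int, str]:
    """Run argv and return (exit_code, output); hypothetical helper, not ShellTool."""
    proc = await asyncio.create_subprocess_exec(
        *argv,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,
    )
    try:
        stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=timeout)
    except asyncio.TimeoutError:
        try:
            proc.kill()
        except ProcessLookupError:
            # The process exited between the timeout firing and kill().
            pass
        # Reap the child so no zombie process is left behind.
        await proc.wait()
        code = proc.returncode if proc.returncode is not None else -1
        return code, ""
    output = stdout.decode("utf-8", errors="replace") if stdout else ""
    return (proc.returncode if proc.returncode is not None else 0), output


if __name__ == "__main__":
    slow = [sys.executable, "-c", "import time; time.sleep(30)"]
    print(asyncio.run(run_with_timeout(slow, timeout=0.5)))
```

Reading `proc.returncode` after `wait()` reports how the process actually ended rather than a fixed `-1`, which is what the changed `exit_code` line in the hunk does.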
@@ -394,10 +392,24 @@ class ShellTool(Tool):
|
||||||
if binary.lower() == prefix_stripped:
|
if binary.lower() == prefix_stripped:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Dangerous pattern check
|
# Dangerous pattern check — token-based matching
|
||||||
|
binary_lower = binary.lower()
|
||||||
|
|
||||||
|
# 1. Binary is always dangerous regardless of flags
|
||||||
|
if binary_lower in _DANGEROUS_BINARIES:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 2. Binary is dangerous with specific flags/subcommands
|
||||||
|
if binary_lower in _DANGEROUS_BINARY_FLAGS:
|
||||||
|
cmd_str = " ".join(tokens).lower()
|
||||||
|
for flag_pattern in _DANGEROUS_BINARY_FLAGS[binary_lower]:
|
||||||
|
if flag_pattern in cmd_str:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 3. Cross-token dangerous patterns (regex)
|
||||||
command_lower = command_stripped.lower()
|
command_lower = command_stripped.lower()
|
||||||
for pattern in _DANGEROUS_PATTERNS:
|
for pattern in _DANGEROUS_ARG_PATTERNS:
|
||||||
if pattern in command_lower:
|
if pattern.search(command_lower):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return True # Unknown commands are dangerous by default
|
return True # Unknown commands are dangerous by default
|
||||||
|
|
|
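The hunk above still tests flags with `flag_pattern in cmd_str`, which is a substring check on the joined command. A stricter variant compares whole tokens instead. The sketch below is illustrative only: the tables are trimmed-down, hypothetical versions of the ones in the diff, and the allow-list and default-deny behaviour of the real method are left out.

```python
import shlex

# Hypothetical, reduced tables for illustration; not the project's real ones.
ALWAYS_DANGEROUS = frozenset({"rm", "mkfs", "dd", "shutdown"})
DANGEROUS_FLAGS = {
    "kill": {"-9", "-kill"},
    "chmod": {"777", "000"},
}
# binary -> subcommand -> flags that make it dangerous (empty set: always).
DANGEROUS_SUBCOMMANDS = {
    "git": {"push": {"--force", "-f"}, "reset": {"--hard"}, "clean": {"-f"}},
    "docker": {"rm": set(), "rmi": set()},
}


def is_dangerous(command: str) -> bool:
    """Return True if the command matches a dangerous token pattern."""
    try:
        tokens = [t.lower() for t in shlex.split(command)]
    except ValueError:
        return True  # unbalanced quotes: fail closed
    if not tokens:
        return False
    binary, args = tokens[0], tokens[1:]

    # 1. Binary is dangerous regardless of arguments.
    if binary in ALWAYS_DANGEROUS:
        return True

    # 2. Binary is dangerous when a specific argument appears as a whole token.
    if any(arg in DANGEROUS_FLAGS.get(binary, ()) for arg in args):
        return True

    # 3. Binary is dangerous for a given subcommand (first non-flag argument).
    subcommands = DANGEROUS_SUBCOMMANDS.get(binary)
    if subcommands:
        sub = next((a for a in args if not a.startswith("-")), None)
        if sub in subcommands:
            required = subcommands[sub]
            return not required or bool(required & set(args))
    return False


if __name__ == "__main__":
    for cmd in ("docker run --rm alpine", "docker rm web", "git push origin main --force"):
        print(cmd, "->", is_dangerous(cmd))
```

With whole-token comparison, `docker run --rm alpine` is not flagged, whereas a substring test for `rm` in the joined command would flag it.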
||||||
|
|
@@ -1,8 +1,16 @@
|
||||||
"""Security utilities for URL validation."""
|
"""Security utilities for URL validation."""
|
||||||
|
|
||||||
import ipaddress
|
import ipaddress
|
||||||
|
import socket
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
_blocked_hostnames = {
|
||||||
|
"localhost",
|
||||||
|
"metadata.google.internal",
|
||||||
|
"metadata.internal",
|
||||||
|
"metadata.azure.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def is_safe_url(url: str) -> bool:
|
def is_safe_url(url: str) -> bool:
|
||||||
"""Check if URL is safe (not pointing to private/internal networks).
|
"""Check if URL is safe (not pointing to private/internal networks).
|
||||||
|
|
@@ -20,13 +28,6 @@ def is_safe_url(url: str) -> bool:
|
||||||
hostname = parsed.hostname
|
hostname = parsed.hostname
|
||||||
if not hostname:
|
if not hostname:
|
||||||
return False
|
return False
|
||||||
# Block known internal/metadata hostnames
|
|
||||||
_blocked_hostnames = {
|
|
||||||
"localhost",
|
|
||||||
"metadata.google.internal",
|
|
||||||
"metadata.internal",
|
|
||||||
"metadata.azure.com",
|
|
||||||
}
|
|
||||||
hostname_lower = hostname.lower()
|
hostname_lower = hostname.lower()
|
||||||
if hostname_lower in _blocked_hostnames:
|
if hostname_lower in _blocked_hostnames:
|
||||||
return False
|
return False
|
||||||
|
|
@@ -37,9 +38,15 @@ def is_safe_url(url: str) -> bool:
|
||||||
if _is_unsafe_ip(ip):
|
if _is_unsafe_ip(ip):
|
||||||
return False
|
return False
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# hostname is a domain, not a literal IP — DNS rebinding risk remains
|
# hostname is a domain — resolve DNS and check IPs
|
||||||
# (would need DNS resolution to fully mitigate)
|
try:
|
||||||
pass
|
addr_infos = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
|
||||||
|
for family, type_, proto, canonname, sockaddr in addr_infos:
|
||||||
|
ip = ipaddress.ip_address(sockaddr[0])
|
||||||
|
if _is_unsafe_ip(ip):
|
||||||
|
return False
|
||||||
|
except (socket.gaierror, socket.timeout, OSError):
|
||||||
|
return False # fail-closed on DNS errors
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
@@ -58,3 +65,36 @@ def _is_unsafe_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
|
||||||
if mapped.is_private or mapped.is_loopback or mapped.is_reserved or mapped.is_link_local:
|
if mapped.is_private or mapped.is_loopback or mapped.is_reserved or mapped.is_link_local:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def is_safe_url_async(url: str) -> bool:
|
||||||
|
"""Async variant of is_safe_url with DNS resolution."""
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
if parsed.scheme not in ("http", "https"):
|
||||||
|
return False
|
||||||
|
hostname = parsed.hostname
|
||||||
|
if not hostname:
|
||||||
|
return False
|
||||||
|
hostname_lower = hostname.lower()
|
||||||
|
if hostname_lower in _blocked_hostnames:
|
||||||
|
return False
|
||||||
|
if hostname_lower.endswith((".internal", ".local", ".localhost")):
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
ip = ipaddress.ip_address(hostname)
|
||||||
|
if _is_unsafe_ip(ip):
|
||||||
|
return False
|
||||||
|
except ValueError:
|
||||||
|
try:
|
||||||
|
import asyncio
|
||||||
|
addr_infos = await asyncio.get_running_loop().getaddrinfo(hostname, None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM)
|
||||||
|
for family, type_, proto, canonname, sockaddr in addr_infos:
|
||||||
|
ip = ipaddress.ip_address(sockaddr[0])
|
||||||
|
if _is_unsafe_ip(ip):
|
||||||
|
return False
|
||||||
|
except (socket.gaierror, socket.timeout, OSError):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
|
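The async variant above resolves the hostname and rejects the URL if any resolved address is internal. Note that `asyncio` has no module-level `getaddrinfo`; resolution goes through the running event loop. The following is a simplified sketch under that assumption. `resolves_to_safe_ips` and `_is_unsafe` are hypothetical names, and the check is reduced compared with `_is_unsafe_ip` in the diff (no IPv4-mapped handling, no hostname blocklist).

```python
import asyncio
import ipaddress
import socket
from urllib.parse import urlparse


def _is_unsafe(ip) -> bool:
    # Reduced check for illustration; the real _is_unsafe_ip covers more cases.
    return ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local


async def resolves_to_safe_ips(url: str) -> bool:
    """Return True only if every address the host resolves to is public."""
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https") or not parsed.hostname:
        return False
    host = parsed.hostname

    # Literal IP: no DNS lookup needed.
    try:
        return not _is_unsafe(ipaddress.ip_address(host))
    except ValueError:
        pass  # not a literal IP, fall through to DNS resolution

    loop = asyncio.get_running_loop()
    try:
        infos = await loop.getaddrinfo(
            host, None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
        )
    except OSError:  # socket.gaierror is a subclass of OSError
        return False  # fail closed on DNS errors
    if not infos:
        return False
    try:
        # sockaddr is the last element of each result; its first item is the IP.
        return all(not _is_unsafe(ipaddress.ip_address(info[4][0])) for info in infos)
    except ValueError:
        return False


if __name__ == "__main__":
    print(asyncio.run(resolves_to_safe_ips("http://127.0.0.1:8080/")))
```

A check like this validates the addresses seen at lookup time only. If the HTTP client resolves the name again when connecting, a rebinding window remains unless the client connects to the address that was validated.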
||||||
|
|
@@ -2,6 +2,8 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
|
@@ -119,7 +121,7 @@ class TestWorkflowStore:
|
||||||
|
|
||||||
store = WorkflowStore()
|
store = WorkflowStore()
|
||||||
wf = WorkflowDefinition(workflow_id="test-1", name="Test")
|
wf = WorkflowDefinition(workflow_id="test-1", name="Test")
|
||||||
store.save(wf)
|
asyncio.run(store.save(wf))
|
||||||
result = store.get("test-1")
|
result = store.get("test-1")
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.name == "Test"
|
assert result.name == "Test"
|
||||||
|
|
@@ -131,36 +133,45 @@ class TestWorkflowStore:
|
||||||
def test_list(self):
|
def test_list(self):
|
||||||
from agentkit.orchestrator.workflow_schema import WorkflowDefinition
|
from agentkit.orchestrator.workflow_schema import WorkflowDefinition
|
||||||
|
|
||||||
store = WorkflowStore()
|
async def _run():
|
||||||
for i in range(3):
|
store = WorkflowStore()
|
||||||
store.save(WorkflowDefinition(workflow_id=f"wf-{i}", name=f"Workflow {i}"))
|
for i in range(3):
|
||||||
summaries = store.list()
|
await store.save(WorkflowDefinition(workflow_id=f"wf-{i}", name=f"Workflow {i}"))
|
||||||
assert len(summaries) == 3
|
summaries = store.list()
|
||||||
|
assert len(summaries) == 3
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
def test_list_limit(self):
|
def test_list_limit(self):
|
||||||
from agentkit.orchestrator.workflow_schema import WorkflowDefinition
|
from agentkit.orchestrator.workflow_schema import WorkflowDefinition
|
||||||
|
|
||||||
store = WorkflowStore()
|
async def _run():
|
||||||
for i in range(5):
|
store = WorkflowStore()
|
||||||
store.save(WorkflowDefinition(workflow_id=f"wf-{i}", name=f"Workflow {i}"))
|
for i in range(5):
|
||||||
summaries = store.list(limit=2)
|
await store.save(WorkflowDefinition(workflow_id=f"wf-{i}", name=f"Workflow {i}"))
|
||||||
assert len(summaries) == 2
|
summaries = store.list(limit=2)
|
||||||
|
assert len(summaries) == 2
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
def test_delete(self):
|
def test_delete(self):
|
||||||
from agentkit.orchestrator.workflow_schema import WorkflowDefinition
|
from agentkit.orchestrator.workflow_schema import WorkflowDefinition
|
||||||
|
|
||||||
store = WorkflowStore()
|
async def _run():
|
||||||
store.save(WorkflowDefinition(workflow_id="del-1", name="Delete Me"))
|
store = WorkflowStore()
|
||||||
assert store.delete("del-1") is True
|
await store.save(WorkflowDefinition(workflow_id="del-1", name="Delete Me"))
|
||||||
assert store.get("del-1") is None
|
assert await store.delete("del-1") is True
|
||||||
|
assert store.get("del-1") is None
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
def test_delete_not_found(self):
|
def test_delete_not_found(self):
|
||||||
store = WorkflowStore()
|
store = WorkflowStore()
|
||||||
assert store.delete("nonexistent") is False
|
assert asyncio.run(store.delete("nonexistent")) is False
|
||||||
|
|
||||||
def test_create_and_get_execution(self):
|
def test_create_and_get_execution(self):
|
||||||
store = WorkflowStore()
|
store = WorkflowStore()
|
||||||
execution = store.create_execution("wf-1")
|
execution = asyncio.run(store.create_execution("wf-1"))
|
||||||
assert execution.workflow_id == "wf-1"
|
assert execution.workflow_id == "wf-1"
|
||||||
assert execution.status == "pending"
|
assert execution.status == "pending"
|
||||||
|
|
||||||
|
|
@@ -173,18 +184,51 @@ class TestWorkflowStore:
|
||||||
assert store.get_execution("nonexistent") is None
|
assert store.get_execution("nonexistent") is None
|
||||||
|
|
||||||
def test_update_execution(self):
|
def test_update_execution(self):
|
||||||
store = WorkflowStore()
|
async def _run():
|
||||||
execution = store.create_execution("wf-1")
|
store = WorkflowStore()
|
||||||
updated = store.update_execution(
|
execution = await store.create_execution("wf-1")
|
||||||
execution.execution_id, status="running", current_stage="step-1"
|
updated = await store.update_execution(
|
||||||
)
|
execution.execution_id, status="running", current_stage="step-1"
|
||||||
assert updated.status == "running"
|
)
|
||||||
assert updated.current_stage == "step-1"
|
assert updated.status == "running"
|
||||||
|
assert updated.current_stage == "step-1"
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
def test_update_execution_not_found(self):
|
def test_update_execution_not_found(self):
|
||||||
store = WorkflowStore()
|
store = WorkflowStore()
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
store.update_execution("nonexistent", status="running")
|
asyncio.run(store.update_execution("nonexistent", status="running"))
|
||||||
|
|
||||||
|
def test_running_tasks_initialized(self):
|
||||||
|
store = WorkflowStore()
|
||||||
|
assert hasattr(store, "_running_tasks")
|
||||||
|
assert isinstance(store._running_tasks, dict)
|
||||||
|
assert len(store._running_tasks) == 0
|
||||||
|
|
||||||
|
def test_execution_locks_initialized(self):
|
||||||
|
store = WorkflowStore()
|
||||||
|
assert hasattr(store, "_execution_locks")
|
||||||
|
assert isinstance(store._execution_locks, dict)
|
||||||
|
|
||||||
|
def test_evict_execution_cleans_approval_events(self):
|
||||||
|
async def _run():
|
||||||
|
store = WorkflowStore(max_executions=2)
|
||||||
|
e1 = await store.create_execution("wf-1")
|
||||||
|
e2 = await store.create_execution("wf-2")
|
||||||
|
# Add an approval event for e1
|
||||||
|
event_key = f"{e1.execution_id}:stage1"
|
||||||
|
event = asyncio.Event()
|
||||||
|
store._approval_events[event_key] = event
|
||||||
|
# Create a third execution to trigger eviction of e1
|
||||||
|
e3 = await store.create_execution("wf-3")
|
||||||
|
# e1 should be evicted, and its approval event should be cleaned up
|
||||||
|
assert store.get_execution(e1.execution_id) is None
|
||||||
|
assert event_key not in store._approval_events
|
||||||
|
# The event should have been set (to wake any waiting coroutine)
|
||||||
|
assert event.is_set()
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
|
||||||