fix(review): resolve P0/P1 findings from final code review
This commit is contained in:
parent
1d1805753c
commit
658e188939
|
|
@ -293,6 +293,7 @@ class Orchestrator:
|
|||
for i, defn in enumerate(subtask_defs):
|
||||
depends_on = [
|
||||
f"task-{i}" for i in defn.get("depends_on", [])
|
||||
if isinstance(i, int) and 0 <= i < len(subtask_defs)
|
||||
]
|
||||
subtasks.append(SubTask(
|
||||
task_id=f"task-{i}",
|
||||
|
|
@ -799,7 +800,7 @@ class Orchestrator:
|
|||
aggregated = await self._aggregate_results(retry_plan, merged_results, task)
|
||||
|
||||
failed_count = sum(
|
||||
1 for r in merged_results.values() if r.get("status") != "completed"
|
||||
1 for r in merged_results.values() if r.get("status") == "failed"
|
||||
)
|
||||
if failed_count == len(merged_results):
|
||||
status = TaskStatus.FAILED
|
||||
|
|
|
|||
|
|
@ -494,7 +494,9 @@ class PlanExecutor:
|
|||
if failed > 0:
|
||||
return TaskStatus.PARTIALLY_COMPLETED # 部分成功
|
||||
if completed + skipped == total:
|
||||
# 所有步骤要么完成要么跳过
|
||||
return TaskStatus.COMPLETED
|
||||
# 所有步骤要么完成要么跳过 — 至少需要一个完成才算成功
|
||||
if completed > 0:
|
||||
return TaskStatus.COMPLETED
|
||||
return TaskStatus.FAILED # 全部跳过 = 没有实际完成
|
||||
|
||||
return TaskStatus.COMPLETED
|
||||
return TaskStatus.PARTIALLY_COMPLETED
|
||||
|
|
|
|||
|
|
@ -203,6 +203,11 @@ async def lifespan(app: FastAPI):
|
|||
if mcp_manager is not None:
|
||||
await mcp_manager.stop_all()
|
||||
|
||||
# Close Redis client for working memory
|
||||
working_redis = getattr(app.state, "working_redis_client", None)
|
||||
if working_redis is not None:
|
||||
await working_redis.aclose()
|
||||
|
||||
if server_config is not None:
|
||||
server_config.stop_watching()
|
||||
|
||||
|
|
@ -493,6 +498,7 @@ def create_app(
|
|||
redis_url = server_config.memory["working"].get("redis_url", "redis://localhost:6379")
|
||||
redis_client = aioredis.from_url(redis_url, decode_responses=True)
|
||||
working = WorkingMemory(redis=redis_client)
|
||||
app.state.working_redis_client = redis_client
|
||||
|
||||
if server_config.memory.get("semantic", {}).get("enabled"):
|
||||
sem_conf = server_config.memory["semantic"]
|
||||
|
|
|
|||
|
|
@ -414,7 +414,11 @@ def _evaluate_condition(expression: str, variables: dict[str, Any]) -> bool:
|
|||
if left and not _SAFE_VAR_PATTERN.match(left):
|
||||
raise ValueError(f"Invalid variable name in condition: {left}")
|
||||
|
||||
left_val = variables.get(left, left)
|
||||
if left and left not in variables:
|
||||
# Missing variable: treat as None/empty — condition evaluates to False
|
||||
left_val = None
|
||||
else:
|
||||
left_val = variables.get(left, left)
|
||||
# Strip quotes from right side if present
|
||||
if right.startswith('"') and right.endswith('"'):
|
||||
right_val = right[1:-1]
|
||||
|
|
@ -469,7 +473,7 @@ async def _broadcast_ws(message: dict[str, Any]) -> None:
|
|||
|
||||
|
||||
@router.get("/workflows")
|
||||
async def list_workflows(request: Request, limit: int = 50):
|
||||
async def list_workflows(request: Request, limit: int = 50, _auth: None = Depends(_verify_api_key)):
|
||||
"""List all workflows."""
|
||||
store = _get_store(request)
|
||||
summaries = store.list(limit=limit)
|
||||
|
|
@ -477,7 +481,7 @@ async def list_workflows(request: Request, limit: int = 50):
|
|||
|
||||
|
||||
@router.post("/workflows", status_code=201)
|
||||
async def create_workflow(request: Request, body: CreateWorkflowRequest):
|
||||
async def create_workflow(request: Request, body: CreateWorkflowRequest, _auth: None = Depends(_verify_api_key)):
|
||||
"""Create a new workflow."""
|
||||
store = _get_store(request)
|
||||
_validate_workflow_stages(body.stages)
|
||||
|
|
@ -495,7 +499,7 @@ async def create_workflow(request: Request, body: CreateWorkflowRequest):
|
|||
|
||||
|
||||
@router.get("/workflows/{workflow_id}")
|
||||
async def get_workflow(request: Request, workflow_id: str):
|
||||
async def get_workflow(request: Request, workflow_id: str, _auth: None = Depends(_verify_api_key)):
|
||||
"""Get a workflow by ID."""
|
||||
store = _get_store(request)
|
||||
workflow = store.get(workflow_id)
|
||||
|
|
@ -506,7 +510,8 @@ async def get_workflow(request: Request, workflow_id: str):
|
|||
|
||||
@router.put("/workflows/{workflow_id}")
|
||||
async def update_workflow(
|
||||
request: Request, workflow_id: str, body: CreateWorkflowRequest
|
||||
request: Request, workflow_id: str, body: CreateWorkflowRequest,
|
||||
_auth: None = Depends(_verify_api_key),
|
||||
):
|
||||
"""Update an existing workflow."""
|
||||
store = _get_store(request)
|
||||
|
|
@ -527,7 +532,7 @@ async def update_workflow(
|
|||
|
||||
|
||||
@router.delete("/workflows/{workflow_id}")
|
||||
async def delete_workflow(request: Request, workflow_id: str):
|
||||
async def delete_workflow(request: Request, workflow_id: str, _auth: None = Depends(_verify_api_key)):
|
||||
"""Delete a workflow."""
|
||||
store = _get_store(request)
|
||||
deleted = store.delete(workflow_id)
|
||||
|
|
@ -638,7 +643,7 @@ async def approve_execution(
|
|||
|
||||
|
||||
@router.post("/workflows/executions/{execution_id}/cancel")
|
||||
async def cancel_execution(request: Request, execution_id: str):
|
||||
async def cancel_execution(request: Request, execution_id: str, _auth: None = Depends(_verify_api_key)):
|
||||
"""Cancel a running execution."""
|
||||
store = _get_store(request)
|
||||
execution = store.get_execution(execution_id)
|
||||
|
|
|
|||
|
|
@ -8,10 +8,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import ipaddress
|
||||
import logging
|
||||
from typing import Any, Callable, Awaitable
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
|
|
@ -43,26 +41,7 @@ _FALLBACK_SHELL_SUGGESTIONS: dict[str, list[str]] = {
|
|||
}
|
||||
|
||||
|
||||
def _is_safe_url(url: str) -> bool:
|
||||
"""Check if URL is safe (not pointing to private/internal networks)."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return False
|
||||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
return False
|
||||
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
|
||||
return False
|
||||
try:
|
||||
ip = ipaddress.ip_address(hostname)
|
||||
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
|
||||
return False
|
||||
except ValueError:
|
||||
pass
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
from agentkit.utils.security import is_safe_url as _is_safe_url
|
||||
|
||||
|
||||
class ComputerUseTool(Tool):
|
||||
|
|
|
|||
|
|
@ -102,7 +102,7 @@ _DANGEROUS_PATTERNS: tuple[str, ...] = (
|
|||
)
|
||||
|
||||
|
||||
_SHELL_OPERATORS = re.compile(r'[|;&]|&&|\|\||\$\(|`')
|
||||
_SHELL_OPERATORS = re.compile(r'[|;&]|&&|\|\||\$\(|\$\{|`|\$<|>|<|\n')
|
||||
|
||||
|
||||
class ShellTool(Tool):
|
||||
|
|
|
|||
|
|
@ -5,7 +5,14 @@ from urllib.parse import urlparse
|
|||
|
||||
|
||||
def is_safe_url(url: str) -> bool:
|
||||
"""Check if URL is safe (not pointing to private/internal networks)."""
|
||||
"""Check if URL is safe (not pointing to private/internal networks).
|
||||
|
||||
Validates against:
|
||||
- Private/loopback/reserved/link-local/unspecified IPs
|
||||
- IPv6-mapped IPv4 addresses (e.g. ::ffff:127.0.0.1)
|
||||
- Common metadata/internal hostnames
|
||||
- Non-HTTP schemes
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
|
|
@ -13,14 +20,41 @@ def is_safe_url(url: str) -> bool:
|
|||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
return False
|
||||
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
|
||||
# Block known internal/metadata hostnames
|
||||
_blocked_hostnames = {
|
||||
"localhost",
|
||||
"metadata.google.internal",
|
||||
"metadata.internal",
|
||||
"metadata.azure.com",
|
||||
}
|
||||
hostname_lower = hostname.lower()
|
||||
if hostname_lower in _blocked_hostnames:
|
||||
return False
|
||||
if hostname_lower.endswith((".internal", ".local", ".localhost")):
|
||||
return False
|
||||
try:
|
||||
ip = ipaddress.ip_address(hostname)
|
||||
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
|
||||
if _is_unsafe_ip(ip):
|
||||
return False
|
||||
except ValueError:
|
||||
# hostname is a domain, not a literal IP — DNS rebinding risk remains
|
||||
# (would need DNS resolution to fully mitigate)
|
||||
pass
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _is_unsafe_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
|
||||
"""Check if an IP address is unsafe (private, loopback, etc.).
|
||||
|
||||
Handles IPv6-mapped IPv4 addresses by checking the embedded IPv4.
|
||||
"""
|
||||
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local or ip.is_unspecified:
|
||||
return True
|
||||
# Check IPv6-mapped IPv4 (e.g. ::ffff:127.0.0.1)
|
||||
if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped is not None:
|
||||
mapped = ip.ipv4_mapped
|
||||
if mapped.is_private or mapped.is_loopback or mapped.is_reserved or mapped.is_link_local:
|
||||
return True
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -662,8 +662,8 @@ class TestPlanExecutorOverallStatus:
|
|||
executor = PlanExecutor(agent_pool=pool, max_retries=0)
|
||||
result = await executor.execute(plan, make_task())
|
||||
|
||||
# 默认策略 SKIP → 全部跳过 → COMPLETED
|
||||
assert result.status == TaskStatus.COMPLETED
|
||||
# 默认策略 SKIP → 全部跳过 → FAILED(没有实际完成的步骤)
|
||||
assert result.status == TaskStatus.FAILED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_plan(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue