fix(review): resolve P0/P1 findings from final code review

This commit is contained in:
chiguyong 2026-06-10 09:57:29 +08:00
parent 1d1805753c
commit 658e188939
8 changed files with 66 additions and 39 deletions

View File

@ -293,6 +293,7 @@ class Orchestrator:
for i, defn in enumerate(subtask_defs):
depends_on = [
f"task-{i}" for i in defn.get("depends_on", [])
if isinstance(i, int) and 0 <= i < len(subtask_defs)
]
subtasks.append(SubTask(
task_id=f"task-{i}",
@ -799,7 +800,7 @@ class Orchestrator:
aggregated = await self._aggregate_results(retry_plan, merged_results, task)
failed_count = sum(
1 for r in merged_results.values() if r.get("status") != "completed"
1 for r in merged_results.values() if r.get("status") == "failed"
)
if failed_count == len(merged_results):
status = TaskStatus.FAILED

View File

@ -494,7 +494,9 @@ class PlanExecutor:
if failed > 0:
return TaskStatus.PARTIALLY_COMPLETED # 部分成功
if completed + skipped == total:
# 所有步骤要么完成要么跳过
return TaskStatus.COMPLETED
# 所有步骤要么完成要么跳过 — 至少需要一个完成才算成功
if completed > 0:
return TaskStatus.COMPLETED
return TaskStatus.FAILED # 全部跳过 = 没有实际完成
return TaskStatus.COMPLETED
return TaskStatus.PARTIALLY_COMPLETED

View File

@ -203,6 +203,11 @@ async def lifespan(app: FastAPI):
if mcp_manager is not None:
await mcp_manager.stop_all()
# Close Redis client for working memory
working_redis = getattr(app.state, "working_redis_client", None)
if working_redis is not None:
await working_redis.aclose()
if server_config is not None:
server_config.stop_watching()
@ -493,6 +498,7 @@ def create_app(
redis_url = server_config.memory["working"].get("redis_url", "redis://localhost:6379")
redis_client = aioredis.from_url(redis_url, decode_responses=True)
working = WorkingMemory(redis=redis_client)
app.state.working_redis_client = redis_client
if server_config.memory.get("semantic", {}).get("enabled"):
sem_conf = server_config.memory["semantic"]

View File

@ -414,7 +414,11 @@ def _evaluate_condition(expression: str, variables: dict[str, Any]) -> bool:
if left and not _SAFE_VAR_PATTERN.match(left):
raise ValueError(f"Invalid variable name in condition: {left}")
left_val = variables.get(left, left)
if left and left not in variables:
# Missing variable: treat as None/empty — condition evaluates to False
left_val = None
else:
left_val = variables.get(left, left)
# Strip quotes from right side if present
if right.startswith('"') and right.endswith('"'):
right_val = right[1:-1]
@ -469,7 +473,7 @@ async def _broadcast_ws(message: dict[str, Any]) -> None:
@router.get("/workflows")
async def list_workflows(request: Request, limit: int = 50):
async def list_workflows(request: Request, limit: int = 50, _auth: None = Depends(_verify_api_key)):
"""List all workflows."""
store = _get_store(request)
summaries = store.list(limit=limit)
@ -477,7 +481,7 @@ async def list_workflows(request: Request, limit: int = 50):
@router.post("/workflows", status_code=201)
async def create_workflow(request: Request, body: CreateWorkflowRequest):
async def create_workflow(request: Request, body: CreateWorkflowRequest, _auth: None = Depends(_verify_api_key)):
"""Create a new workflow."""
store = _get_store(request)
_validate_workflow_stages(body.stages)
@ -495,7 +499,7 @@ async def create_workflow(request: Request, body: CreateWorkflowRequest):
@router.get("/workflows/{workflow_id}")
async def get_workflow(request: Request, workflow_id: str):
async def get_workflow(request: Request, workflow_id: str, _auth: None = Depends(_verify_api_key)):
"""Get a workflow by ID."""
store = _get_store(request)
workflow = store.get(workflow_id)
@ -506,7 +510,8 @@ async def get_workflow(request: Request, workflow_id: str):
@router.put("/workflows/{workflow_id}")
async def update_workflow(
request: Request, workflow_id: str, body: CreateWorkflowRequest
request: Request, workflow_id: str, body: CreateWorkflowRequest,
_auth: None = Depends(_verify_api_key),
):
"""Update an existing workflow."""
store = _get_store(request)
@ -527,7 +532,7 @@ async def update_workflow(
@router.delete("/workflows/{workflow_id}")
async def delete_workflow(request: Request, workflow_id: str):
async def delete_workflow(request: Request, workflow_id: str, _auth: None = Depends(_verify_api_key)):
"""Delete a workflow."""
store = _get_store(request)
deleted = store.delete(workflow_id)
@ -638,7 +643,7 @@ async def approve_execution(
@router.post("/workflows/executions/{execution_id}/cancel")
async def cancel_execution(request: Request, execution_id: str):
async def cancel_execution(request: Request, execution_id: str, _auth: None = Depends(_verify_api_key)):
"""Cancel a running execution."""
store = _get_store(request)
execution = store.get_execution(execution_id)

View File

@ -8,10 +8,8 @@
from __future__ import annotations
import base64
import ipaddress
import logging
from typing import Any, Callable, Awaitable
from urllib.parse import urlparse
import httpx
@ -43,26 +41,7 @@ _FALLBACK_SHELL_SUGGESTIONS: dict[str, list[str]] = {
}
def _is_safe_url(url: str) -> bool:
"""Check if URL is safe (not pointing to private/internal networks)."""
try:
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
return False
hostname = parsed.hostname
if not hostname:
return False
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
return False
try:
ip = ipaddress.ip_address(hostname)
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
return False
except ValueError:
pass
return True
except Exception:
return False
from agentkit.utils.security import is_safe_url as _is_safe_url
class ComputerUseTool(Tool):

View File

@ -102,7 +102,7 @@ _DANGEROUS_PATTERNS: tuple[str, ...] = (
)
_SHELL_OPERATORS = re.compile(r'[|;&]|&&|\|\||\$\(|`')
_SHELL_OPERATORS = re.compile(r'[|;&]|&&|\|\||\$\(|\$\{|`|\$<|>|<|\n')
class ShellTool(Tool):

View File

@ -5,7 +5,14 @@ from urllib.parse import urlparse
def is_safe_url(url: str) -> bool:
"""Check if URL is safe (not pointing to private/internal networks)."""
"""Check if URL is safe (not pointing to private/internal networks).
Validates against:
- Private/loopback/reserved/link-local/unspecified IPs
- IPv6-mapped IPv4 addresses (e.g. ::ffff:127.0.0.1)
- Common metadata/internal hostnames
- Non-HTTP schemes
"""
try:
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
@ -13,14 +20,41 @@ def is_safe_url(url: str) -> bool:
hostname = parsed.hostname
if not hostname:
return False
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
# Block known internal/metadata hostnames
_blocked_hostnames = {
"localhost",
"metadata.google.internal",
"metadata.internal",
"metadata.azure.com",
}
hostname_lower = hostname.lower()
if hostname_lower in _blocked_hostnames:
return False
if hostname_lower.endswith((".internal", ".local", ".localhost")):
return False
try:
ip = ipaddress.ip_address(hostname)
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
if _is_unsafe_ip(ip):
return False
except ValueError:
# hostname is a domain, not a literal IP — DNS rebinding risk remains
# (would need DNS resolution to fully mitigate)
pass
return True
except Exception:
return False
def _is_unsafe_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
"""Check if an IP address is unsafe (private, loopback, etc.).
Handles IPv6-mapped IPv4 addresses by checking the embedded IPv4.
"""
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local or ip.is_unspecified:
return True
# Check IPv6-mapped IPv4 (e.g. ::ffff:127.0.0.1)
if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped is not None:
mapped = ip.ipv4_mapped
if mapped.is_private or mapped.is_loopback or mapped.is_reserved or mapped.is_link_local:
return True
return False

View File

@ -662,8 +662,8 @@ class TestPlanExecutorOverallStatus:
executor = PlanExecutor(agent_pool=pool, max_retries=0)
result = await executor.execute(plan, make_task())
# 默认策略 SKIP → 全部跳过 → COMPLETED
assert result.status == TaskStatus.COMPLETED
# 默认策略 SKIP → 全部跳过 → FAILED没有实际完成的步骤
assert result.status == TaskStatus.FAILED
@pytest.mark.asyncio
async def test_empty_plan(self):