fix: resolve key P2 findings from code review

- Shell whitelist: use exact binary match instead of startswith
- Shell audit log: use deque(maxlen=10000) to cap memory
- Terminal history: use deque(maxlen) for O(1) eviction
- Path optimizer: cap _pending_paths at 50 entries per task_type
- Pitfall detector: only add tips to matching steps, not all
- Experience store: handle non-numeric _parse_time_window input
- Extract shared is_safe_url() to utils/security.py (DRY)
- Workflow condition evaluator: handle float() ValueError
This commit is contained in:
chiguyong 2026-06-10 09:01:23 +08:00
parent b46a10973f
commit 1d1805753c
18 changed files with 271 additions and 177 deletions

View File

@ -186,7 +186,7 @@ class Orchestrator:
if failed_count == len(plan.subtasks):
status = TaskStatus.FAILED
elif failed_count > 0:
status = TaskStatus.COMPLETED # Partial success
status = TaskStatus.PARTIALLY_COMPLETED
else:
status = TaskStatus.COMPLETED
@ -804,7 +804,7 @@ class Orchestrator:
if failed_count == len(merged_results):
status = TaskStatus.FAILED
elif failed_count > 0:
status = TaskStatus.COMPLETED
status = TaskStatus.PARTIALLY_COMPLETED
else:
status = TaskStatus.COMPLETED

View File

@ -384,9 +384,8 @@ class PlanExecutor:
return "human"
if action == FailureAction.ABORT:
# 将失败步骤本身也标记为 SKIPPED
step.status = PlanStepStatus.SKIPPED
exec_result.status = PlanStepStatus.SKIPPED
# The failed step itself keeps FAILED status; only remaining PENDING steps are skipped
# (step.status and exec_result.status are already FAILED from _execute_step_with_retry)
# 中止所有后续步骤
self._abort_remaining_steps(step_map, step_results, plan)
return "adjusted"
@ -492,10 +491,10 @@ class PlanExecutor:
return TaskStatus.COMPLETED
if failed == total:
return TaskStatus.FAILED
if failed > 0:
return TaskStatus.PARTIALLY_COMPLETED # 部分成功
if completed + skipped == total:
# 所有步骤要么完成要么跳过
return TaskStatus.COMPLETED
if failed > 0:
return TaskStatus.PARTIALLY_COMPLETED # 部分成功
return TaskStatus.COMPLETED

View File

@ -497,7 +497,10 @@ def _parse_time_window(window: str) -> timedelta:
支持格式: "1h", "24h", "7d", "30d"
"""
unit = window[-1].lower()
value = int(window[:-1])
try:
value = int(window[:-1])
except ValueError:
return timedelta(hours=24)
if unit == "h":
return timedelta(hours=value)
elif unit == "d":

View File

@ -137,6 +137,8 @@ class PathOptimizer:
# 样本量不足 → 不更新,记录待观察
if new_path.sample_count < self._min_sample_count:
self._pending_paths.setdefault(task_type, []).append(new_path)
if len(self._pending_paths[task_type]) > 50:
self._pending_paths[task_type] = self._pending_paths[task_type][-50:]
reason = (
f"样本量不足({new_path.sample_count} < {self._min_sample_count}"
f"记录待观察"

View File

@ -211,7 +211,7 @@ class PitfallDetector:
if hasattr(exp, 'optimization_tips') and exp.optimization_tips:
experience_steps = set(exp.steps) if hasattr(exp, 'steps') and exp.steps else set()
for step_name, s in stats.items():
if not experience_steps or step_name in experience_steps:
if experience_steps and step_name in experience_steps:
s.optimization_tips.extend(exp.optimization_tips)
return stats

View File

@ -6,15 +6,14 @@
from __future__ import annotations
import ipaddress
import logging
from typing import Any
from urllib.parse import urlparse
import httpx
from agentkit.memory.adapters.base import KBAdapter
from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo
from agentkit.utils.security import is_safe_url
logger = logging.getLogger(__name__)
@ -24,28 +23,6 @@ def _escape_cql(value: str) -> str:
return value.replace("\\", "\\\\").replace('"', '\\"')
def _is_safe_url(url: str) -> bool:
"""Check if URL is safe (not pointing to private/internal networks)."""
try:
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
return False
hostname = parsed.hostname
if not hostname:
return False
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
return False
try:
ip = ipaddress.ip_address(hostname)
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
return False
except ValueError:
pass
return True
except Exception:
return False
class ConfluenceAdapter(KBAdapter):
"""Confluence 知识库适配器
@ -78,7 +55,7 @@ class ConfluenceAdapter(KBAdapter):
timeout=timeout,
)
self._base_url = base_url.rstrip("/")
if not _is_safe_url(self._base_url):
if not is_safe_url(self._base_url):
raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.")
self._username = username
self._api_token = api_token

View File

@ -6,42 +6,19 @@
from __future__ import annotations
import ipaddress
import logging
import time
from typing import Any
from urllib.parse import urlparse
import httpx
from agentkit.memory.adapters.base import KBAdapter
from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo
from agentkit.utils.security import is_safe_url
logger = logging.getLogger(__name__)
def _is_safe_url(url: str) -> bool:
"""Check if URL is safe (not pointing to private/internal networks)."""
try:
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
return False
hostname = parsed.hostname
if not hostname:
return False
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
return False
try:
ip = ipaddress.ip_address(hostname)
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
return False
except ValueError:
pass
return True
except Exception:
return False
class FeishuKBAdapter(KBAdapter):
"""飞书知识库适配器
@ -76,7 +53,7 @@ class FeishuKBAdapter(KBAdapter):
self._app_id = app_id
self._app_secret = app_secret
self._base_url = base_url.rstrip("/")
if not _is_safe_url(self._base_url):
if not is_safe_url(self._base_url):
raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.")
self._space_ids = space_ids or []
self._access_token: str | None = None
@ -113,8 +90,8 @@ class FeishuKBAdapter(KBAdapter):
self._access_token = data.get("tenant_access_token")
expire_seconds = data.get("expire", 7200)
self._token_expiry = time.time() + expire_seconds - 300 # Refresh 5 minutes early
# 重建客户端以携带 token
await self.close()
# Invalidate cached client so it's rebuilt with the new token
self._client = None
return self._access_token
else:
logger.error(

View File

@ -6,44 +6,18 @@
from __future__ import annotations
import ipaddress
import logging
from typing import Any
from urllib.parse import urlparse
import httpx
from agentkit.memory.adapters.base import KBAdapter
from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo
from agentkit.utils.security import is_safe_url
logger = logging.getLogger(__name__)
def _is_safe_url(url: str) -> bool:
"""Check if URL is safe (not pointing to private/internal networks)."""
try:
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
return False
hostname = parsed.hostname
if not hostname:
return False
# Block common internal hostnames
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
return False
# Try to resolve as IP and check for private ranges
try:
ip = ipaddress.ip_address(hostname)
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
return False
except ValueError:
# Not an IP address, that's OK (it's a domain name)
pass
return True
except Exception:
return False
class GenericHTTPAdapter(KBAdapter):
"""通用 HTTP 知识库适配器
@ -80,7 +54,7 @@ class GenericHTTPAdapter(KBAdapter):
timeout=timeout,
)
self._endpoint_url = endpoint_url.rstrip("/")
if not _is_safe_url(self._endpoint_url):
if not is_safe_url(self._endpoint_url):
raise ValueError(f"Unsafe endpoint_url: {self._endpoint_url}. Private/internal URLs are not allowed.")
self._auth_config = auth_config or {}
self._extra_headers = headers or {}

View File

@ -228,19 +228,27 @@ class LocalRAGService:
try:
from sqlalchemy import text as sql_text
for chunk in chunks:
# 生成嵌入
embedding = await self._embedder.embed(chunk.content)
# Batch embedding generation
embeddings: list[list[float]] = []
if hasattr(self._embedder, "embed_batch"):
embeddings = await self._embedder.embed_batch([c.content for c in chunks])
else:
for chunk in chunks:
embedding = await self._embedder.embed(chunk.content)
embeddings.append(embedding)
sql = sql_text(
f"INSERT INTO {self._table_name} "
f"(chunk_id, source_doc_id, source_title, doc_format, "
f"content, embedding, chunk_metadata, doc_metadata, created_at) "
f"VALUES (:chunk_id, :doc_id, :title, :format, "
f":content, :embedding, :chunk_meta, :doc_meta, :created_at)"
)
# Batch INSERT using executemany
sql = sql_text(
f"INSERT INTO {self._table_name} "
f"(chunk_id, source_doc_id, source_title, doc_format, "
f"content, embedding, chunk_metadata, doc_metadata, created_at) "
f"VALUES (:chunk_id, :doc_id, :title, :format, "
f":content, :embedding, :chunk_meta, :doc_meta, :created_at)"
)
await db.execute(sql, {
now = datetime.now(timezone.utc)
params_list = [
{
"chunk_id": chunk.chunk_id,
"doc_id": doc.doc_id,
"title": doc.title,
@ -249,9 +257,12 @@ class LocalRAGService:
"embedding": str(embedding),
"chunk_meta": json.dumps(chunk.metadata, ensure_ascii=False),
"doc_meta": json.dumps(doc.metadata, ensure_ascii=False),
"created_at": datetime.now(timezone.utc),
})
"created_at": now,
}
for chunk, embedding in zip(chunks, embeddings)
]
await db.execute(sql, params_list)
await db.commit()
except Exception as e:
await db.rollback()

View File

@ -10,7 +10,8 @@ from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any
from fastapi import APIRouter, HTTPException, Request, WebSocket, WebSocketDisconnect
from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket, WebSocketDisconnect, Security
from fastapi.security import APIKeyHeader, APIKeyQuery
from pydantic import BaseModel
from agentkit.core.protocol import TaskMessage
@ -21,6 +22,37 @@ logger = logging.getLogger(__name__)
router = APIRouter(tags=["portal"])
# ---------------------------------------------------------------------------
# API Key Authentication
# ---------------------------------------------------------------------------
_api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
_api_key_query = APIKeyQuery(name="api_key", auto_error=False)
async def _verify_api_key(
    request: Request,
    api_key_header: str | None = Security(_api_key_header),
    api_key_query: str | None = Security(_api_key_query),
) -> None:
    """Reject the request unless it carries the configured API key.

    The expected key is read from ``request.app.state.server_config.api_key``
    and, if that yields nothing, from ``request.app.state.api_key``. When
    neither provides a key, authentication is disabled and every request is
    allowed (backwards compatibility).

    Args:
        request: Incoming request; only ``request.app.state`` is inspected.
        api_key_header: Key supplied via the ``X-API-Key`` header, if any.
        api_key_query: Key supplied via the ``api_key`` query parameter, if any.

    Raises:
        HTTPException: 401 when a key is configured and the supplied key is
            missing or does not match.
    """
    # Function-scope import so this block does not depend on the module header.
    import secrets

    state = request.app.state
    configured_api_key: str | None = None
    server_config = getattr(state, "server_config", None)
    if server_config:
        configured_api_key = server_config.api_key
    if configured_api_key is None:
        configured_api_key = getattr(state, "api_key", None)

    # If no API key is configured, allow all requests (backwards compat)
    if configured_api_key is None:
        return

    # The header takes precedence; an empty header falls through to the query value.
    provided = api_key_header or api_key_query
    # Compare in constant time so response timing does not reveal how much of
    # a guessed key matched. Bytes are used because compare_digest rejects
    # non-ASCII str arguments.
    if provided is None or not secrets.compare_digest(
        provided.encode("utf-8"), configured_api_key.encode("utf-8")
    ):
        raise HTTPException(
            status_code=401,
            detail="Invalid or missing API key. Provide via X-API-Key header or api_key query parameter.",
        )
# ---------------------------------------------------------------------------
# In-memory Conversation Store
@ -241,7 +273,7 @@ async def _resolve_for_chat(
@router.post("/portal/chat", response_model=ChatResponse)
async def chat(request: ChatRequest, req: Request):
async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify_api_key)):
"""Send a chat message and get a response with intent routing."""
agent, skill, matched_skill, routing_method, confidence = await _resolve_for_chat(
request, req
@ -291,7 +323,7 @@ async def chat(request: ChatRequest, req: Request):
@router.post("/portal/chat/stream")
async def chat_stream(request: ChatRequest, req: Request):
async def chat_stream(request: ChatRequest, req: Request, _auth: None = Depends(_verify_api_key)):
"""Stream chat responses via SSE."""
from sse_starlette.sse import EventSourceResponse
@ -368,7 +400,7 @@ async def chat_stream(request: ChatRequest, req: Request):
@router.get("/portal/capabilities", response_model=CapabilitiesResponse)
async def get_capabilities(req: Request):
async def get_capabilities(req: Request, _auth: None = Depends(_verify_api_key)):
"""List all available capabilities with their status."""
skill_registry = req.app.state.skill_registry
all_skills = skill_registry.list_skills()
@ -399,7 +431,7 @@ async def get_capabilities(req: Request):
@router.get("/portal/conversations")
async def list_conversations(limit: int = 20):
async def list_conversations(limit: int = 20, _auth: None = Depends(_verify_api_key)):
"""List recent conversations."""
convs = _conversation_store.list_conversations(limit=limit)
return [
@ -414,7 +446,7 @@ async def list_conversations(limit: int = 20):
@router.get("/portal/conversations/{conversation_id}")
async def get_conversation(conversation_id: str, limit: int = 50):
async def get_conversation(conversation_id: str, limit: int = 50, _auth: None = Depends(_verify_api_key)):
"""Get conversation history."""
history = _conversation_store.get_history(conversation_id, limit=limit)
if not history and conversation_id not in _conversation_store._conversations:

View File

@ -10,7 +10,8 @@ import uuid
from datetime import datetime, timezone
from typing import Any
from fastapi import APIRouter, HTTPException, Request, WebSocket, WebSocketDisconnect
from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket, WebSocketDisconnect, Security
from fastapi.security import APIKeyHeader, APIKeyQuery
from agentkit.orchestrator.workflow_schema import (
ApproveRequest,
@ -26,6 +27,37 @@ logger = logging.getLogger(__name__)
router = APIRouter(tags=["workflows"])
# ---------------------------------------------------------------------------
# API Key Authentication
# ---------------------------------------------------------------------------
_api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
_api_key_query = APIKeyQuery(name="api_key", auto_error=False)
async def _verify_api_key(
    request: Request,
    api_key_header: str | None = Security(_api_key_header),
    api_key_query: str | None = Security(_api_key_query),
) -> None:
    """Reject the request unless it carries the configured API key.

    The expected key is read from ``request.app.state.server_config.api_key``
    and, if that yields nothing, from ``request.app.state.api_key``. When
    neither provides a key, authentication is disabled and every request is
    allowed (backwards compatibility).

    Args:
        request: Incoming request; only ``request.app.state`` is inspected.
        api_key_header: Key supplied via the ``X-API-Key`` header, if any.
        api_key_query: Key supplied via the ``api_key`` query parameter, if any.

    Raises:
        HTTPException: 401 when a key is configured and the supplied key is
            missing or does not match.
    """
    # Function-scope import so this block does not depend on the module header.
    import secrets

    state = request.app.state
    configured_api_key: str | None = None
    server_config = getattr(state, "server_config", None)
    if server_config:
        configured_api_key = server_config.api_key
    if configured_api_key is None:
        configured_api_key = getattr(state, "api_key", None)

    # If no API key is configured, allow all requests (backwards compat)
    if configured_api_key is None:
        return

    # The header takes precedence; an empty header falls through to the query value.
    provided = api_key_header or api_key_query
    # Compare in constant time so response timing does not reveal how much of
    # a guessed key matched. Bytes are used because compare_digest rejects
    # non-ASCII str arguments.
    if provided is None or not secrets.compare_digest(
        provided.encode("utf-8"), configured_api_key.encode("utf-8")
    ):
        raise HTTPException(
            status_code=401,
            detail="Invalid or missing API key. Provide via X-API-Key header or api_key query parameter.",
        )
# ---------------------------------------------------------------------------
# In-memory Workflow Store
@ -398,14 +430,19 @@ def _evaluate_condition(expression: str, variables: dict[str, Any]) -> bool:
return str(left_val) == str(right_val)
if op == "!=":
return str(left_val) != str(right_val)
try:
left_num = float(left_val)
right_num = float(right_val)
except (ValueError, TypeError):
return False
if op == ">":
return float(left_val) > float(right_val)
return left_num > right_num
if op == "<":
return float(left_val) < float(right_val)
return left_num < right_num
if op == ">=":
return float(left_val) >= float(right_val)
return left_num >= right_num
if op == "<=":
return float(left_val) <= float(right_val)
return left_num <= right_num
# Boolean check for variable existence
if _SAFE_VAR_PATTERN.match(expression):
@ -513,9 +550,12 @@ async def execute_workflow(
execution.variables = body.variables
# Start execution in background
asyncio.create_task(
task = asyncio.create_task(
_execute_workflow(workflow, execution, body.variables, store=store)
)
store._running_tasks = getattr(store, "_running_tasks", {})
store._running_tasks[execution.execution_id] = task
task.add_done_callback(lambda t: store._running_tasks.pop(execution.execution_id, None))
return {
"execution_id": execution.execution_id,
@ -538,7 +578,8 @@ async def get_execution(request: Request, execution_id: str):
@router.post("/workflows/executions/{execution_id}/approve")
async def approve_execution(
request: Request, execution_id: str, body: ApproveRequest
request: Request, execution_id: str, body: ApproveRequest,
_auth: None = Depends(_verify_api_key),
):
"""Approve a paused approval node."""
store = _get_store(request)
@ -617,6 +658,11 @@ async def cancel_execution(request: Request, execution_id: str):
status="cancelled",
completed_at=execution.completed_at,
)
# Set any pending approval event so a paused workflow can observe the cancelled state
if hasattr(execution, "current_stage") and execution.current_stage:
event_key = f"{execution_id}:{execution.current_stage}"
if event_key in store._approval_events:
store._approval_events[event_key].set()
return execution.model_dump()

View File

@ -8,8 +8,10 @@
from __future__ import annotations
import base64
import ipaddress
import logging
from typing import Any, Callable, Awaitable
from urllib.parse import urlparse
import httpx
@ -41,6 +43,28 @@ _FALLBACK_SHELL_SUGGESTIONS: dict[str, list[str]] = {
}
def _is_safe_url(url: str) -> bool:
"""Check if URL is safe (not pointing to private/internal networks)."""
try:
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
return False
hostname = parsed.hostname
if not hostname:
return False
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
return False
try:
ip = ipaddress.ip_address(hostname)
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
return False
except ValueError:
pass
return True
except Exception:
return False
class ComputerUseTool(Tool):
"""Computer Use 工具
@ -82,6 +106,8 @@ class ComputerUseTool(Tool):
self._api_key = api_key
self._model = model
self._api_base_url = api_base_url
if not _is_safe_url(self._api_base_url):
raise ValueError(f"Unsafe api_base_url: {self._api_base_url}. Private/internal URLs are not allowed.")
self._session_manager = ComputerUseSessionManager(
session_factory=session_factory or InMemoryComputerUseSession,
)
@ -89,6 +115,18 @@ class ComputerUseTool(Tool):
self._fallback_callback = fallback_callback
self._max_retries = max_retries
self._request_timeout = request_timeout
self._http_client: httpx.AsyncClient | None = None
def _get_http_client(self) -> httpx.AsyncClient:
"""Get or create a persistent httpx.AsyncClient for connection reuse."""
if self._http_client is None or self._http_client.is_closed:
self._http_client = httpx.AsyncClient(timeout=self._request_timeout)
return self._http_client
async def close(self) -> None:
    """Close the persistent HTTP client if one exists and is still open."""
    client = self._http_client
    if not client or client.is_closed:
        # Nothing was created, or it has already been closed.
        return
    await client.aclose()
@staticmethod
def _default_input_schema() -> dict[str, Any]:
@ -374,47 +412,47 @@ class ComputerUseTool(Tool):
"content-type": "application/json",
}
async with httpx.AsyncClient(timeout=self._request_timeout) as client:
response = await client.post(
self._api_base_url,
json=request_body,
headers=headers,
client = self._get_http_client()
response = await client.post(
self._api_base_url,
json=request_body,
headers=headers,
)
if response.status_code != 200:
error_detail = response.text[:500]
return ActionResult(
success=False,
action=action,
error=f"Anthropic API error {response.status_code}: {error_detail}",
)
if response.status_code != 200:
error_detail = response.text[:500]
data = response.json()
# 解析 API 响应中的 tool_use 内容
for block in data.get("content", []):
if block.get("type") == "tool_use" and block.get("name") == "computer":
tool_input_resp = block.get("input", {})
resp_action = tool_input_resp.get("action", action)
return ActionResult(
success=False,
action=action,
error=f"Anthropic API error {response.status_code}: {error_detail}",
success=True,
action=resp_action,
output=f"API executed: {resp_action}",
metadata={"api_response": data},
)
data = response.json()
# API 没有返回 tool_use可能是纯文本响应
text_output = ""
for block in data.get("content", []):
if block.get("type") == "text":
text_output += block.get("text", "")
# 解析 API 响应中的 tool_use 内容
for block in data.get("content", []):
if block.get("type") == "tool_use" and block.get("name") == "computer":
tool_input_resp = block.get("input", {})
resp_action = tool_input_resp.get("action", action)
return ActionResult(
success=True,
action=resp_action,
output=f"API executed: {resp_action}",
metadata={"api_response": data},
)
# API 没有返回 tool_use可能是纯文本响应
text_output = ""
for block in data.get("content", []):
if block.get("type") == "text":
text_output += block.get("text", "")
return ActionResult(
success=True,
action=action,
output=text_output[:500] if text_output else "API call completed",
metadata={"api_response": data},
)
return ActionResult(
success=True,
action=action,
output=text_output[:500] if text_output else "API call completed",
metadata={"api_response": data},
)
def _validate_params(self, action: str, kwargs: dict[str, Any]) -> str | None:
"""验证操作参数

View File

@ -166,6 +166,9 @@ class PTYSession:
try:
exit_code = await self._read_until_exit(timeout)
except asyncio.TimeoutError:
if self._process and self._process.returncode is None:
self._process.kill()
await self._process.wait()
self._output_buffer += "\n[PTY 命令执行超时]"
return PTYOutput(
output=self._output_buffer,

View File

@ -12,6 +12,7 @@ import os
import re
import shlex
import time
from collections import deque
from typing import Any, Callable, Awaitable
from agentkit.tools.base import Tool
@ -26,9 +27,6 @@ _SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
"cd",
"export",
"ls",
"cat",
"head",
"tail",
"grep",
"find",
"pwd",
@ -66,8 +64,6 @@ _SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
"npm list",
"docker ps",
"docker images",
"curl",
"wget",
)
# 危险命令模式:这些命令需要人工确认
@ -155,7 +151,7 @@ class ShellTool(Tool):
self._confirm_callback = confirm_callback
self._default_timeout = default_timeout
self._max_output_length = max_output_length
self._audit_log: list[dict[str, Any]] = []
self._audit_log: deque[dict[str, Any]] = deque(maxlen=10000)
@staticmethod
def _default_input_schema() -> dict[str, Any]:
@ -394,8 +390,8 @@ class ShellTool(Tool):
if command_stripped.lower().startswith(prefix_stripped):
return False
else:
# Simple prefix - match against binary name only
if binary.lower().startswith(prefix_stripped):
# Simple prefix - match against binary name exactly
if binary.lower() == prefix_stripped:
return False
# Dangerous pattern check

View File

@ -12,6 +12,7 @@ import os
import re
import shlex
import time
from collections import deque
from dataclasses import dataclass, field
from typing import Any
@ -65,8 +66,7 @@ class TerminalSession:
self.session_id = session_id
self._cwd = cwd or os.getcwd()
self._env: dict[str, str] = dict(env or os.environ)
self._history: list[CommandRecord] = []
self._max_history = max_history
self._history: deque[CommandRecord] = deque(maxlen=max_history)
self._output_parser = OutputParser()
self._created_at = time.time()
@ -271,10 +271,8 @@ class TerminalSession:
self._env[key] = value
def _add_history(self, record: CommandRecord) -> None:
"""添加命令记录到历史,超出上限时移除最旧记录"""
"""添加命令记录到历史,deque maxlen 自动淘汰最旧记录"""
self._history.append(record)
while len(self._history) > self._max_history:
self._history.pop(0)
def close(self) -> None:
"""关闭会话,清理资源"""

View File

@ -0,0 +1,26 @@
"""Security utilities for URL validation."""
import ipaddress
from urllib.parse import urlparse
def is_safe_url(url: str) -> bool:
    """Return True when *url* is an http(s) URL that does not name an internal host.

    This is a check on the URL text only: the hostname is never resolved, so
    a public DNS name that resolves to a private address is still accepted.

    Rejected inputs:
      * anything that is not a parseable ``http``/``https`` URL with a host;
      * ``localhost`` (also ``*.localhost`` and trailing-dot spellings) and
        the listed cloud metadata hostnames;
      * IP literals that are private, loopback, reserved, link-local,
        multicast or unspecified;
      * hosts whose last label is numeric (``2130706433``, ``0x7f000001``,
        ``127.1``) without being a canonical IP literal, because many HTTP
        stacks still interpret such hosts as IPv4 addresses.

    Args:
        url: The URL to validate.

    Returns:
        True if the URL passes every check above, False otherwise.
    """
    try:
        parsed = urlparse(url)
        hostname = parsed.hostname
    except (AttributeError, TypeError, ValueError):
        # Non-string input or a malformed netloc (e.g. unbalanced IPv6 brackets).
        return False
    if parsed.scheme not in ("http", "https") or not hostname:
        return False

    # "localhost." is the fully-qualified spelling of the same host.
    host = hostname.rstrip(".").lower()
    if not host:
        return False
    if host in ("localhost", "metadata.google.internal", "metadata.internal"):
        return False
    if host.endswith(".localhost"):
        return False

    try:
        ip = ipaddress.ip_address(host)
    except ValueError:
        ip = None
    if ip is not None:
        return not (
            ip.is_private
            or ip.is_loopback
            or ip.is_reserved
            or ip.is_link_local
            or ip.is_multicast
            or ip.is_unspecified
        )

    # Not a canonical IP literal. Public DNS names never end in a numeric
    # label, so a numeric (decimal or 0x-hex) last label is a legacy IPv4
    # spelling that would slip past the ipaddress checks above.
    last_label = host.rpartition(".")[2]
    hex_digits = last_label[2:] if last_label.startswith("0x") else None
    if last_label.isascii() and (
        last_label.isdigit()
        or (hex_digits is not None and all(c in "0123456789abcdef" for c in hex_digits))
    ):
        return False
    return True

View File

@ -2,29 +2,39 @@
from __future__ import annotations
import math
import logging
logger = logging.getLogger(__name__)
try:
import numpy as np
_HAS_NUMPY = True
except ImportError:
_HAS_NUMPY = False
def compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
    """Compute the cosine similarity between two vectors.

    Uses numpy when it was importable at module load (``_HAS_NUMPY``) and a
    pure-Python computation otherwise.

    Args:
        vec_a: First vector.
        vec_b: Second vector.

    Returns:
        The cosine similarity, nominally between -1 and 1. Returns 0.0 when
        the vectors differ in length (a warning is logged), when they are
        empty, or when either vector has zero magnitude.
    """
    if len(vec_a) != len(vec_b):
        logger.warning("Vector length mismatch: %d vs %d, returning 0.0", len(vec_a), len(vec_b))
        return 0.0
    if not vec_a:
        # Equal lengths and empty: there is no direction to compare.
        return 0.0

    if _HAS_NUMPY:
        a = np.array(vec_a, dtype=np.float64)
        b = np.array(vec_b, dtype=np.float64)
        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)
        if norm_a == 0.0 or norm_b == 0.0:
            return 0.0
        return float(np.dot(a, b) / (norm_a * norm_b))

    # Pure Python fallback
    dot_product = sum(x * y for x, y in zip(vec_a, vec_b))
    norm_a = sum(x * x for x in vec_a) ** 0.5
    norm_b = sum(y * y for y in vec_b) ** 0.5
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot_product / (norm_a * norm_b)

View File

@ -442,7 +442,9 @@ class TestPlanExecutorPlanAdjustment:
executor = PlanExecutor(agent_pool=pool, max_retries=0, on_step_failed=on_failed)
result = await executor.execute(plan, make_task())
assert result.step_results["s0"].status == PlanStepStatus.SKIPPED
# The failed step should remain FAILED (not SKIPPED)
assert result.step_results["s0"].status == PlanStepStatus.FAILED
# Remaining steps should be SKIPPED
assert result.step_results["s1"].status == PlanStepStatus.SKIPPED
assert result.adjusted is True