fix: resolve key P2 findings from code review
- Shell whitelist: use exact binary match instead of startswith - Shell audit log: use deque(maxlen=10000) to cap memory - Terminal history: use deque(maxlen) for O(1) eviction - Path optimizer: cap _pending_paths at 50 entries per task_type - Pitfall detector: only add tips to matching steps, not all - Experience store: handle non-numeric _parse_time_window input - Extract shared is_safe_url() to utils/security.py (DRY) - Workflow condition evaluator: handle float() ValueError
This commit is contained in:
parent
b46a10973f
commit
1d1805753c
|
|
@ -186,7 +186,7 @@ class Orchestrator:
|
|||
if failed_count == len(plan.subtasks):
|
||||
status = TaskStatus.FAILED
|
||||
elif failed_count > 0:
|
||||
status = TaskStatus.COMPLETED # Partial success
|
||||
status = TaskStatus.PARTIALLY_COMPLETED
|
||||
else:
|
||||
status = TaskStatus.COMPLETED
|
||||
|
||||
|
|
@ -804,7 +804,7 @@ class Orchestrator:
|
|||
if failed_count == len(merged_results):
|
||||
status = TaskStatus.FAILED
|
||||
elif failed_count > 0:
|
||||
status = TaskStatus.COMPLETED
|
||||
status = TaskStatus.PARTIALLY_COMPLETED
|
||||
else:
|
||||
status = TaskStatus.COMPLETED
|
||||
|
||||
|
|
|
|||
|
|
@ -384,9 +384,8 @@ class PlanExecutor:
|
|||
return "human"
|
||||
|
||||
if action == FailureAction.ABORT:
|
||||
# 将失败步骤本身也标记为 SKIPPED
|
||||
step.status = PlanStepStatus.SKIPPED
|
||||
exec_result.status = PlanStepStatus.SKIPPED
|
||||
# The failed step itself keeps FAILED status; only remaining PENDING steps are skipped
|
||||
# (step.status and exec_result.status are already FAILED from _execute_step_with_retry)
|
||||
# 中止所有后续步骤
|
||||
self._abort_remaining_steps(step_map, step_results, plan)
|
||||
return "adjusted"
|
||||
|
|
@ -492,10 +491,10 @@ class PlanExecutor:
|
|||
return TaskStatus.COMPLETED
|
||||
if failed == total:
|
||||
return TaskStatus.FAILED
|
||||
if failed > 0:
|
||||
return TaskStatus.PARTIALLY_COMPLETED # 部分成功
|
||||
if completed + skipped == total:
|
||||
# 所有步骤要么完成要么跳过
|
||||
return TaskStatus.COMPLETED
|
||||
if failed > 0:
|
||||
return TaskStatus.PARTIALLY_COMPLETED # 部分成功
|
||||
|
||||
return TaskStatus.COMPLETED
|
||||
|
|
|
|||
|
|
@ -497,7 +497,10 @@ def _parse_time_window(window: str) -> timedelta:
|
|||
支持格式: "1h", "24h", "7d", "30d"
|
||||
"""
|
||||
unit = window[-1].lower()
|
||||
value = int(window[:-1])
|
||||
try:
|
||||
value = int(window[:-1])
|
||||
except ValueError:
|
||||
return timedelta(hours=24)
|
||||
if unit == "h":
|
||||
return timedelta(hours=value)
|
||||
elif unit == "d":
|
||||
|
|
|
|||
|
|
@ -137,6 +137,8 @@ class PathOptimizer:
|
|||
# 样本量不足 → 不更新,记录待观察
|
||||
if new_path.sample_count < self._min_sample_count:
|
||||
self._pending_paths.setdefault(task_type, []).append(new_path)
|
||||
if len(self._pending_paths[task_type]) > 50:
|
||||
self._pending_paths[task_type] = self._pending_paths[task_type][-50:]
|
||||
reason = (
|
||||
f"样本量不足({new_path.sample_count} < {self._min_sample_count}),"
|
||||
f"记录待观察"
|
||||
|
|
|
|||
|
|
@ -211,7 +211,7 @@ class PitfallDetector:
|
|||
if hasattr(exp, 'optimization_tips') and exp.optimization_tips:
|
||||
experience_steps = set(exp.steps) if hasattr(exp, 'steps') and exp.steps else set()
|
||||
for step_name, s in stats.items():
|
||||
if not experience_steps or step_name in experience_steps:
|
||||
if experience_steps and step_name in experience_steps:
|
||||
s.optimization_tips.extend(exp.optimization_tips)
|
||||
|
||||
return stats
|
||||
|
|
|
|||
|
|
@ -6,15 +6,14 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from agentkit.memory.adapters.base import KBAdapter
|
||||
from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo
|
||||
from agentkit.utils.security import is_safe_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -24,28 +23,6 @@ def _escape_cql(value: str) -> str:
|
|||
return value.replace("\\", "\\\\").replace('"', '\\"')
|
||||
|
||||
|
||||
def _is_safe_url(url: str) -> bool:
|
||||
"""Check if URL is safe (not pointing to private/internal networks)."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return False
|
||||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
return False
|
||||
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
|
||||
return False
|
||||
try:
|
||||
ip = ipaddress.ip_address(hostname)
|
||||
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
|
||||
return False
|
||||
except ValueError:
|
||||
pass
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class ConfluenceAdapter(KBAdapter):
|
||||
"""Confluence 知识库适配器
|
||||
|
||||
|
|
@ -78,7 +55,7 @@ class ConfluenceAdapter(KBAdapter):
|
|||
timeout=timeout,
|
||||
)
|
||||
self._base_url = base_url.rstrip("/")
|
||||
if not _is_safe_url(self._base_url):
|
||||
if not is_safe_url(self._base_url):
|
||||
raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.")
|
||||
self._username = username
|
||||
self._api_token = api_token
|
||||
|
|
|
|||
|
|
@ -6,42 +6,19 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from agentkit.memory.adapters.base import KBAdapter
|
||||
from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo
|
||||
from agentkit.utils.security import is_safe_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_safe_url(url: str) -> bool:
|
||||
"""Check if URL is safe (not pointing to private/internal networks)."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return False
|
||||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
return False
|
||||
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
|
||||
return False
|
||||
try:
|
||||
ip = ipaddress.ip_address(hostname)
|
||||
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
|
||||
return False
|
||||
except ValueError:
|
||||
pass
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class FeishuKBAdapter(KBAdapter):
|
||||
"""飞书知识库适配器
|
||||
|
||||
|
|
@ -76,7 +53,7 @@ class FeishuKBAdapter(KBAdapter):
|
|||
self._app_id = app_id
|
||||
self._app_secret = app_secret
|
||||
self._base_url = base_url.rstrip("/")
|
||||
if not _is_safe_url(self._base_url):
|
||||
if not is_safe_url(self._base_url):
|
||||
raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.")
|
||||
self._space_ids = space_ids or []
|
||||
self._access_token: str | None = None
|
||||
|
|
@ -113,8 +90,8 @@ class FeishuKBAdapter(KBAdapter):
|
|||
self._access_token = data.get("tenant_access_token")
|
||||
expire_seconds = data.get("expire", 7200)
|
||||
self._token_expiry = time.time() + expire_seconds - 300 # Refresh 5 minutes early
|
||||
# 重建客户端以携带 token
|
||||
await self.close()
|
||||
# Invalidate cached client so it's rebuilt with the new token
|
||||
self._client = None
|
||||
return self._access_token
|
||||
else:
|
||||
logger.error(
|
||||
|
|
|
|||
|
|
@ -6,44 +6,18 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from agentkit.memory.adapters.base import KBAdapter
|
||||
from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo
|
||||
from agentkit.utils.security import is_safe_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_safe_url(url: str) -> bool:
|
||||
"""Check if URL is safe (not pointing to private/internal networks)."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return False
|
||||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
return False
|
||||
# Block common internal hostnames
|
||||
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
|
||||
return False
|
||||
# Try to resolve as IP and check for private ranges
|
||||
try:
|
||||
ip = ipaddress.ip_address(hostname)
|
||||
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
|
||||
return False
|
||||
except ValueError:
|
||||
# Not an IP address, that's OK (it's a domain name)
|
||||
pass
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class GenericHTTPAdapter(KBAdapter):
|
||||
"""通用 HTTP 知识库适配器
|
||||
|
||||
|
|
@ -80,7 +54,7 @@ class GenericHTTPAdapter(KBAdapter):
|
|||
timeout=timeout,
|
||||
)
|
||||
self._endpoint_url = endpoint_url.rstrip("/")
|
||||
if not _is_safe_url(self._endpoint_url):
|
||||
if not is_safe_url(self._endpoint_url):
|
||||
raise ValueError(f"Unsafe endpoint_url: {self._endpoint_url}. Private/internal URLs are not allowed.")
|
||||
self._auth_config = auth_config or {}
|
||||
self._extra_headers = headers or {}
|
||||
|
|
|
|||
|
|
@ -228,19 +228,27 @@ class LocalRAGService:
|
|||
try:
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
for chunk in chunks:
|
||||
# 生成嵌入
|
||||
embedding = await self._embedder.embed(chunk.content)
|
||||
# Batch embedding generation
|
||||
embeddings: list[list[float]] = []
|
||||
if hasattr(self._embedder, "embed_batch"):
|
||||
embeddings = await self._embedder.embed_batch([c.content for c in chunks])
|
||||
else:
|
||||
for chunk in chunks:
|
||||
embedding = await self._embedder.embed(chunk.content)
|
||||
embeddings.append(embedding)
|
||||
|
||||
sql = sql_text(
|
||||
f"INSERT INTO {self._table_name} "
|
||||
f"(chunk_id, source_doc_id, source_title, doc_format, "
|
||||
f"content, embedding, chunk_metadata, doc_metadata, created_at) "
|
||||
f"VALUES (:chunk_id, :doc_id, :title, :format, "
|
||||
f":content, :embedding, :chunk_meta, :doc_meta, :created_at)"
|
||||
)
|
||||
# Batch INSERT using executemany
|
||||
sql = sql_text(
|
||||
f"INSERT INTO {self._table_name} "
|
||||
f"(chunk_id, source_doc_id, source_title, doc_format, "
|
||||
f"content, embedding, chunk_metadata, doc_metadata, created_at) "
|
||||
f"VALUES (:chunk_id, :doc_id, :title, :format, "
|
||||
f":content, :embedding, :chunk_meta, :doc_meta, :created_at)"
|
||||
)
|
||||
|
||||
await db.execute(sql, {
|
||||
now = datetime.now(timezone.utc)
|
||||
params_list = [
|
||||
{
|
||||
"chunk_id": chunk.chunk_id,
|
||||
"doc_id": doc.doc_id,
|
||||
"title": doc.title,
|
||||
|
|
@ -249,9 +257,12 @@ class LocalRAGService:
|
|||
"embedding": str(embedding),
|
||||
"chunk_meta": json.dumps(chunk.metadata, ensure_ascii=False),
|
||||
"doc_meta": json.dumps(doc.metadata, ensure_ascii=False),
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
})
|
||||
"created_at": now,
|
||||
}
|
||||
for chunk, embedding in zip(chunks, embeddings)
|
||||
]
|
||||
|
||||
await db.execute(sql, params_list)
|
||||
await db.commit()
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
|
|
|
|||
|
|
@ -10,7 +10,8 @@ from dataclasses import dataclass, field
|
|||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request, WebSocket, WebSocketDisconnect
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket, WebSocketDisconnect, Security
|
||||
from fastapi.security import APIKeyHeader, APIKeyQuery
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agentkit.core.protocol import TaskMessage
|
||||
|
|
@ -21,6 +22,37 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
router = APIRouter(tags=["portal"])
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# API Key Authentication
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
_api_key_query = APIKeyQuery(name="api_key", auto_error=False)
|
||||
|
||||
|
||||
async def _verify_api_key(
|
||||
request: Request,
|
||||
api_key_header: str | None = Security(_api_key_header),
|
||||
api_key_query: str | None = Security(_api_key_query),
|
||||
) -> None:
|
||||
"""Verify API key for REST endpoints. Raises HTTPException if invalid."""
|
||||
configured_api_key: str | None = None
|
||||
if hasattr(request.app.state, "server_config") and request.app.state.server_config:
|
||||
configured_api_key = request.app.state.server_config.api_key
|
||||
if configured_api_key is None and hasattr(request.app.state, "api_key"):
|
||||
configured_api_key = request.app.state.api_key
|
||||
|
||||
# If no API key is configured, allow all requests (backwards compat)
|
||||
if configured_api_key is None:
|
||||
return
|
||||
|
||||
provided = api_key_header or api_key_query
|
||||
if provided != configured_api_key:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid or missing API key. Provide via X-API-Key header or api_key query parameter.",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# In-memory Conversation Store
|
||||
|
|
@ -241,7 +273,7 @@ async def _resolve_for_chat(
|
|||
|
||||
|
||||
@router.post("/portal/chat", response_model=ChatResponse)
|
||||
async def chat(request: ChatRequest, req: Request):
|
||||
async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify_api_key)):
|
||||
"""Send a chat message and get a response with intent routing."""
|
||||
agent, skill, matched_skill, routing_method, confidence = await _resolve_for_chat(
|
||||
request, req
|
||||
|
|
@ -291,7 +323,7 @@ async def chat(request: ChatRequest, req: Request):
|
|||
|
||||
|
||||
@router.post("/portal/chat/stream")
|
||||
async def chat_stream(request: ChatRequest, req: Request):
|
||||
async def chat_stream(request: ChatRequest, req: Request, _auth: None = Depends(_verify_api_key)):
|
||||
"""Stream chat responses via SSE."""
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
|
|
@ -368,7 +400,7 @@ async def chat_stream(request: ChatRequest, req: Request):
|
|||
|
||||
|
||||
@router.get("/portal/capabilities", response_model=CapabilitiesResponse)
|
||||
async def get_capabilities(req: Request):
|
||||
async def get_capabilities(req: Request, _auth: None = Depends(_verify_api_key)):
|
||||
"""List all available capabilities with their status."""
|
||||
skill_registry = req.app.state.skill_registry
|
||||
all_skills = skill_registry.list_skills()
|
||||
|
|
@ -399,7 +431,7 @@ async def get_capabilities(req: Request):
|
|||
|
||||
|
||||
@router.get("/portal/conversations")
|
||||
async def list_conversations(limit: int = 20):
|
||||
async def list_conversations(limit: int = 20, _auth: None = Depends(_verify_api_key)):
|
||||
"""List recent conversations."""
|
||||
convs = _conversation_store.list_conversations(limit=limit)
|
||||
return [
|
||||
|
|
@ -414,7 +446,7 @@ async def list_conversations(limit: int = 20):
|
|||
|
||||
|
||||
@router.get("/portal/conversations/{conversation_id}")
|
||||
async def get_conversation(conversation_id: str, limit: int = 50):
|
||||
async def get_conversation(conversation_id: str, limit: int = 50, _auth: None = Depends(_verify_api_key)):
|
||||
"""Get conversation history."""
|
||||
history = _conversation_store.get_history(conversation_id, limit=limit)
|
||||
if not history and conversation_id not in _conversation_store._conversations:
|
||||
|
|
|
|||
|
|
@ -10,7 +10,8 @@ import uuid
|
|||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request, WebSocket, WebSocketDisconnect
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket, WebSocketDisconnect, Security
|
||||
from fastapi.security import APIKeyHeader, APIKeyQuery
|
||||
|
||||
from agentkit.orchestrator.workflow_schema import (
|
||||
ApproveRequest,
|
||||
|
|
@ -26,6 +27,37 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
router = APIRouter(tags=["workflows"])
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# API Key Authentication
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
_api_key_query = APIKeyQuery(name="api_key", auto_error=False)
|
||||
|
||||
|
||||
async def _verify_api_key(
|
||||
request: Request,
|
||||
api_key_header: str | None = Security(_api_key_header),
|
||||
api_key_query: str | None = Security(_api_key_query),
|
||||
) -> None:
|
||||
"""Verify API key for REST endpoints. Raises HTTPException if invalid."""
|
||||
configured_api_key: str | None = None
|
||||
if hasattr(request.app.state, "server_config") and request.app.state.server_config:
|
||||
configured_api_key = request.app.state.server_config.api_key
|
||||
if configured_api_key is None and hasattr(request.app.state, "api_key"):
|
||||
configured_api_key = request.app.state.api_key
|
||||
|
||||
# If no API key is configured, allow all requests (backwards compat)
|
||||
if configured_api_key is None:
|
||||
return
|
||||
|
||||
provided = api_key_header or api_key_query
|
||||
if provided != configured_api_key:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid or missing API key. Provide via X-API-Key header or api_key query parameter.",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# In-memory Workflow Store
|
||||
|
|
@ -398,14 +430,19 @@ def _evaluate_condition(expression: str, variables: dict[str, Any]) -> bool:
|
|||
return str(left_val) == str(right_val)
|
||||
if op == "!=":
|
||||
return str(left_val) != str(right_val)
|
||||
try:
|
||||
left_num = float(left_val)
|
||||
right_num = float(right_val)
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
if op == ">":
|
||||
return float(left_val) > float(right_val)
|
||||
return left_num > right_num
|
||||
if op == "<":
|
||||
return float(left_val) < float(right_val)
|
||||
return left_num < right_num
|
||||
if op == ">=":
|
||||
return float(left_val) >= float(right_val)
|
||||
return left_num >= right_num
|
||||
if op == "<=":
|
||||
return float(left_val) <= float(right_val)
|
||||
return left_num <= right_num
|
||||
|
||||
# Boolean check for variable existence
|
||||
if _SAFE_VAR_PATTERN.match(expression):
|
||||
|
|
@ -513,9 +550,12 @@ async def execute_workflow(
|
|||
execution.variables = body.variables
|
||||
|
||||
# Start execution in background
|
||||
asyncio.create_task(
|
||||
task = asyncio.create_task(
|
||||
_execute_workflow(workflow, execution, body.variables, store=store)
|
||||
)
|
||||
store._running_tasks = getattr(store, "_running_tasks", {})
|
||||
store._running_tasks[execution.execution_id] = task
|
||||
task.add_done_callback(lambda t: store._running_tasks.pop(execution.execution_id, None))
|
||||
|
||||
return {
|
||||
"execution_id": execution.execution_id,
|
||||
|
|
@ -538,7 +578,8 @@ async def get_execution(request: Request, execution_id: str):
|
|||
|
||||
@router.post("/workflows/executions/{execution_id}/approve")
|
||||
async def approve_execution(
|
||||
request: Request, execution_id: str, body: ApproveRequest
|
||||
request: Request, execution_id: str, body: ApproveRequest,
|
||||
_auth: None = Depends(_verify_api_key),
|
||||
):
|
||||
"""Approve a paused approval node."""
|
||||
store = _get_store(request)
|
||||
|
|
@ -617,6 +658,11 @@ async def cancel_execution(request: Request, execution_id: str):
|
|||
status="cancelled",
|
||||
completed_at=execution.completed_at,
|
||||
)
|
||||
# Set any pending approval event so a paused workflow can observe the cancelled state
|
||||
if hasattr(execution, "current_stage") and execution.current_stage:
|
||||
event_key = f"{execution_id}:{execution.current_stage}"
|
||||
if event_key in store._approval_events:
|
||||
store._approval_events[event_key].set()
|
||||
return execution.model_dump()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -8,8 +8,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import ipaddress
|
||||
import logging
|
||||
from typing import Any, Callable, Awaitable
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
|
|
@ -41,6 +43,28 @@ _FALLBACK_SHELL_SUGGESTIONS: dict[str, list[str]] = {
|
|||
}
|
||||
|
||||
|
||||
def _is_safe_url(url: str) -> bool:
|
||||
"""Check if URL is safe (not pointing to private/internal networks)."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return False
|
||||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
return False
|
||||
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
|
||||
return False
|
||||
try:
|
||||
ip = ipaddress.ip_address(hostname)
|
||||
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
|
||||
return False
|
||||
except ValueError:
|
||||
pass
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class ComputerUseTool(Tool):
|
||||
"""Computer Use 工具
|
||||
|
||||
|
|
@ -82,6 +106,8 @@ class ComputerUseTool(Tool):
|
|||
self._api_key = api_key
|
||||
self._model = model
|
||||
self._api_base_url = api_base_url
|
||||
if not _is_safe_url(self._api_base_url):
|
||||
raise ValueError(f"Unsafe api_base_url: {self._api_base_url}. Private/internal URLs are not allowed.")
|
||||
self._session_manager = ComputerUseSessionManager(
|
||||
session_factory=session_factory or InMemoryComputerUseSession,
|
||||
)
|
||||
|
|
@ -89,6 +115,18 @@ class ComputerUseTool(Tool):
|
|||
self._fallback_callback = fallback_callback
|
||||
self._max_retries = max_retries
|
||||
self._request_timeout = request_timeout
|
||||
self._http_client: httpx.AsyncClient | None = None
|
||||
|
||||
def _get_http_client(self) -> httpx.AsyncClient:
|
||||
"""Get or create a persistent httpx.AsyncClient for connection reuse."""
|
||||
if self._http_client is None or self._http_client.is_closed:
|
||||
self._http_client = httpx.AsyncClient(timeout=self._request_timeout)
|
||||
return self._http_client
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the persistent HTTP client."""
|
||||
if self._http_client and not self._http_client.is_closed:
|
||||
await self._http_client.aclose()
|
||||
|
||||
@staticmethod
|
||||
def _default_input_schema() -> dict[str, Any]:
|
||||
|
|
@ -374,47 +412,47 @@ class ComputerUseTool(Tool):
|
|||
"content-type": "application/json",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=self._request_timeout) as client:
|
||||
response = await client.post(
|
||||
self._api_base_url,
|
||||
json=request_body,
|
||||
headers=headers,
|
||||
client = self._get_http_client()
|
||||
response = await client.post(
|
||||
self._api_base_url,
|
||||
json=request_body,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_detail = response.text[:500]
|
||||
return ActionResult(
|
||||
success=False,
|
||||
action=action,
|
||||
error=f"Anthropic API error {response.status_code}: {error_detail}",
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_detail = response.text[:500]
|
||||
data = response.json()
|
||||
|
||||
# 解析 API 响应中的 tool_use 内容
|
||||
for block in data.get("content", []):
|
||||
if block.get("type") == "tool_use" and block.get("name") == "computer":
|
||||
tool_input_resp = block.get("input", {})
|
||||
resp_action = tool_input_resp.get("action", action)
|
||||
return ActionResult(
|
||||
success=False,
|
||||
action=action,
|
||||
error=f"Anthropic API error {response.status_code}: {error_detail}",
|
||||
success=True,
|
||||
action=resp_action,
|
||||
output=f"API executed: {resp_action}",
|
||||
metadata={"api_response": data},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
# API 没有返回 tool_use,可能是纯文本响应
|
||||
text_output = ""
|
||||
for block in data.get("content", []):
|
||||
if block.get("type") == "text":
|
||||
text_output += block.get("text", "")
|
||||
|
||||
# 解析 API 响应中的 tool_use 内容
|
||||
for block in data.get("content", []):
|
||||
if block.get("type") == "tool_use" and block.get("name") == "computer":
|
||||
tool_input_resp = block.get("input", {})
|
||||
resp_action = tool_input_resp.get("action", action)
|
||||
return ActionResult(
|
||||
success=True,
|
||||
action=resp_action,
|
||||
output=f"API executed: {resp_action}",
|
||||
metadata={"api_response": data},
|
||||
)
|
||||
|
||||
# API 没有返回 tool_use,可能是纯文本响应
|
||||
text_output = ""
|
||||
for block in data.get("content", []):
|
||||
if block.get("type") == "text":
|
||||
text_output += block.get("text", "")
|
||||
|
||||
return ActionResult(
|
||||
success=True,
|
||||
action=action,
|
||||
output=text_output[:500] if text_output else "API call completed",
|
||||
metadata={"api_response": data},
|
||||
)
|
||||
return ActionResult(
|
||||
success=True,
|
||||
action=action,
|
||||
output=text_output[:500] if text_output else "API call completed",
|
||||
metadata={"api_response": data},
|
||||
)
|
||||
|
||||
def _validate_params(self, action: str, kwargs: dict[str, Any]) -> str | None:
|
||||
"""验证操作参数
|
||||
|
|
|
|||
|
|
@ -166,6 +166,9 @@ class PTYSession:
|
|||
try:
|
||||
exit_code = await self._read_until_exit(timeout)
|
||||
except asyncio.TimeoutError:
|
||||
if self._process and self._process.returncode is None:
|
||||
self._process.kill()
|
||||
await self._process.wait()
|
||||
self._output_buffer += "\n[PTY 命令执行超时]"
|
||||
return PTYOutput(
|
||||
output=self._output_buffer,
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import os
|
|||
import re
|
||||
import shlex
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Any, Callable, Awaitable
|
||||
|
||||
from agentkit.tools.base import Tool
|
||||
|
|
@ -26,9 +27,6 @@ _SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
|
|||
"cd",
|
||||
"export",
|
||||
"ls",
|
||||
"cat",
|
||||
"head",
|
||||
"tail",
|
||||
"grep",
|
||||
"find",
|
||||
"pwd",
|
||||
|
|
@ -66,8 +64,6 @@ _SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
|
|||
"npm list",
|
||||
"docker ps",
|
||||
"docker images",
|
||||
"curl",
|
||||
"wget",
|
||||
)
|
||||
|
||||
# 危险命令模式:这些命令需要人工确认
|
||||
|
|
@ -155,7 +151,7 @@ class ShellTool(Tool):
|
|||
self._confirm_callback = confirm_callback
|
||||
self._default_timeout = default_timeout
|
||||
self._max_output_length = max_output_length
|
||||
self._audit_log: list[dict[str, Any]] = []
|
||||
self._audit_log: deque[dict[str, Any]] = deque(maxlen=10000)
|
||||
|
||||
@staticmethod
|
||||
def _default_input_schema() -> dict[str, Any]:
|
||||
|
|
@ -394,8 +390,8 @@ class ShellTool(Tool):
|
|||
if command_stripped.lower().startswith(prefix_stripped):
|
||||
return False
|
||||
else:
|
||||
# Simple prefix - match against binary name only
|
||||
if binary.lower().startswith(prefix_stripped):
|
||||
# Simple prefix - match against binary name exactly
|
||||
if binary.lower() == prefix_stripped:
|
||||
return False
|
||||
|
||||
# Dangerous pattern check
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import os
|
|||
import re
|
||||
import shlex
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -65,8 +66,7 @@ class TerminalSession:
|
|||
self.session_id = session_id
|
||||
self._cwd = cwd or os.getcwd()
|
||||
self._env: dict[str, str] = dict(env or os.environ)
|
||||
self._history: list[CommandRecord] = []
|
||||
self._max_history = max_history
|
||||
self._history: deque[CommandRecord] = deque(maxlen=max_history)
|
||||
self._output_parser = OutputParser()
|
||||
self._created_at = time.time()
|
||||
|
||||
|
|
@ -271,10 +271,8 @@ class TerminalSession:
|
|||
self._env[key] = value
|
||||
|
||||
def _add_history(self, record: CommandRecord) -> None:
|
||||
"""添加命令记录到历史,超出上限时移除最旧记录"""
|
||||
"""添加命令记录到历史,deque maxlen 自动淘汰最旧记录"""
|
||||
self._history.append(record)
|
||||
while len(self._history) > self._max_history:
|
||||
self._history.pop(0)
|
||||
|
||||
def close(self) -> None:
|
||||
"""关闭会话,清理资源"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,26 @@
|
|||
"""Security utilities for URL validation."""
|
||||
|
||||
import ipaddress
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
def is_safe_url(url: str) -> bool:
|
||||
"""Check if URL is safe (not pointing to private/internal networks)."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return False
|
||||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
return False
|
||||
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
|
||||
return False
|
||||
try:
|
||||
ip = ipaddress.ip_address(hostname)
|
||||
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
|
||||
return False
|
||||
except ValueError:
|
||||
pass
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
|
@ -2,29 +2,39 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
_HAS_NUMPY = True
|
||||
except ImportError:
|
||||
_HAS_NUMPY = False
|
||||
|
||||
|
||||
def compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
|
||||
"""Compute cosine similarity between two vectors.
|
||||
|
||||
Args:
|
||||
vec_a: First vector.
|
||||
vec_b: Second vector.
|
||||
|
||||
Returns:
|
||||
Cosine similarity score between -1 and 1.
|
||||
Uses numpy for performance when available, falls back to pure Python.
|
||||
"""
|
||||
if len(vec_a) != len(vec_b):
|
||||
return 0.0
|
||||
if not vec_a:
|
||||
logger.warning("Vector length mismatch: %d vs %d, returning 0.0", len(vec_a), len(vec_b))
|
||||
return 0.0
|
||||
|
||||
if _HAS_NUMPY:
|
||||
a = np.array(vec_a, dtype=np.float64)
|
||||
b = np.array(vec_b, dtype=np.float64)
|
||||
norm_a = np.linalg.norm(a)
|
||||
norm_b = np.linalg.norm(b)
|
||||
if norm_a == 0.0 or norm_b == 0.0:
|
||||
return 0.0
|
||||
return float(np.dot(a, b) / (norm_a * norm_b))
|
||||
|
||||
# Pure Python fallback
|
||||
dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
|
||||
norm_a = math.sqrt(sum(a * a for a in vec_a))
|
||||
norm_b = math.sqrt(sum(b * b for b in vec_b))
|
||||
|
||||
norm_a = sum(a * a for a in vec_a) ** 0.5
|
||||
norm_b = sum(b * b for b in vec_b) ** 0.5
|
||||
if norm_a == 0.0 or norm_b == 0.0:
|
||||
return 0.0
|
||||
|
||||
return dot_product / (norm_a * norm_b)
|
||||
|
|
|
|||
|
|
@ -442,7 +442,9 @@ class TestPlanExecutorPlanAdjustment:
|
|||
executor = PlanExecutor(agent_pool=pool, max_retries=0, on_step_failed=on_failed)
|
||||
result = await executor.execute(plan, make_task())
|
||||
|
||||
assert result.step_results["s0"].status == PlanStepStatus.SKIPPED
|
||||
# The failed step should remain FAILED (not SKIPPED)
|
||||
assert result.step_results["s0"].status == PlanStepStatus.FAILED
|
||||
# Remaining steps should be SKIPPED
|
||||
assert result.step_results["s1"].status == PlanStepStatus.SKIPPED
|
||||
assert result.adjusted is True
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue