fix(security): resolve all P0/P1 findings from code review

This commit is contained in:
chiguyong 2026-06-10 07:12:41 +08:00
parent b34f74f598
commit 9e9f1314f6
22 changed files with 457 additions and 181 deletions

View File

@ -496,6 +496,6 @@ class PlanExecutor:
# 所有步骤要么完成要么跳过
return TaskStatus.COMPLETED
if failed > 0:
return TaskStatus.COMPLETED # 部分成功
return TaskStatus.PARTIALLY_COMPLETED # 部分成功
return TaskStatus.COMPLETED

View File

@ -13,6 +13,7 @@ class TaskStatus(str, Enum):
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
PARTIALLY_COMPLETED = "partially_completed"
FAILED = "failed"
CANCELLED = "cancelled"
HANDOFF = "handoff"

View File

@ -12,14 +12,18 @@ from __future__ import annotations
import logging
import math
import re
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any
from sqlalchemy import text
_SAFE_TABLE_NAME_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
from agentkit.evolution.experience_schema import EvolutionMetrics, TaskExperience
from agentkit.memory.embedder import Embedder
from agentkit.utils.vector_math import compute_cosine_similarity
logger = logging.getLogger(__name__)
@ -69,6 +73,8 @@ class ExperienceStore:
self._retrieve_limit = retrieve_limit
self._pgvector_enabled = pgvector_enabled
self._table_name = table_name
if not _SAFE_TABLE_NAME_PATTERN.match(self._table_name):
raise ValueError(f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*")
async def record_experience(self, experience: TaskExperience) -> str:
"""记录任务经验
@ -193,7 +199,7 @@ class ExperienceStore:
time_decay_score = (row.get("success_rate") or 0.5) * decay
if row_embedding is not None:
cosine_sim = _compute_cosine_similarity(query_embedding, row_embedding)
cosine_sim = compute_cosine_similarity(query_embedding, row_embedding)
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
else:
score = time_decay_score
@ -251,7 +257,7 @@ class ExperienceStore:
time_decay_score = (entry.success_rate or 0.5) * decay
if self._embedder and query_embedding is not None and entry.embedding is not None:
cosine_sim = _compute_cosine_similarity(query_embedding, entry.embedding)
cosine_sim = compute_cosine_similarity(query_embedding, entry.embedding)
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
else:
score = time_decay_score
@ -425,7 +431,7 @@ class InMemoryExperienceStore:
time_decay_score = exp.success_rate * decay
if query_embedding is not None and exp.embedding is not None:
cosine_sim = _compute_cosine_similarity(query_embedding, exp.embedding)
cosine_sim = compute_cosine_similarity(query_embedding, exp.embedding)
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
else:
score = time_decay_score
@ -485,21 +491,6 @@ class InMemoryExperienceStore:
# ── 辅助函数 ──────────────────────────────────────────────
def _compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
"""计算两个向量的余弦相似度"""
if len(vec_a) != len(vec_b):
logger.warning(f"Vector dimension mismatch: {len(vec_a)} vs {len(vec_b)}")
return 0.0
if not vec_a:
return 0.0
dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
magnitude_a = sum(a**2 for a in vec_a) ** 0.5
magnitude_b = sum(b**2 for b in vec_b) ** 0.5
if magnitude_a == 0.0 or magnitude_b == 0.0:
return 0.0
return dot_product / (magnitude_a * magnitude_b)
def _parse_time_window(window: str) -> timedelta:
"""解析时间窗口字符串为 timedelta

View File

@ -207,10 +207,12 @@ class PitfallDetector:
if error:
s.failure_reasons.append(error)
# 收集优化建议
if hasattr(exp, "optimization_tips") and exp.optimization_tips:
# 收集优化建议 — only add to steps that are part of this experience
if hasattr(exp, 'optimization_tips') and exp.optimization_tips:
experience_steps = set(exp.steps) if hasattr(exp, 'steps') and exp.steps else set()
for step_name, s in stats.items():
s.optimization_tips.extend(exp.optimization_tips)
if not experience_steps or step_name in experience_steps:
s.optimization_tips.extend(exp.optimization_tips)
return stats

View File

@ -6,8 +6,10 @@
from __future__ import annotations
import ipaddress
import logging
from typing import Any
from urllib.parse import urlparse
import httpx
@ -17,6 +19,33 @@ from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo
logger = logging.getLogger(__name__)
def _escape_cql(value: str) -> str:
"""Escape special characters in CQL values."""
return value.replace("\\", "\\\\").replace('"', '\\"')
def _is_safe_url(url: str) -> bool:
"""Check if URL is safe (not pointing to private/internal networks)."""
try:
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
return False
hostname = parsed.hostname
if not hostname:
return False
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
return False
try:
ip = ipaddress.ip_address(hostname)
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
return False
except ValueError:
pass
return True
except Exception:
return False
class ConfluenceAdapter(KBAdapter):
"""Confluence 知识库适配器
@ -49,6 +78,8 @@ class ConfluenceAdapter(KBAdapter):
timeout=timeout,
)
self._base_url = base_url.rstrip("/")
if not _is_safe_url(self._base_url):
raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.")
self._username = username
self._api_token = api_token
self._space_keys = space_keys or []
@ -88,10 +119,10 @@ class ConfluenceAdapter(KBAdapter):
"""
client = self._get_client()
try:
cql = f'text ~ "{query}"'
cql = f'text ~ "{_escape_cql(query)}"'
if self._space_keys:
space_filter = " OR ".join(
f'space = "{key}"' for key in self._space_keys
f'space = "{_escape_cql(key)}"' for key in self._space_keys
)
cql = f'{cql} AND ({space_filter})'

View File

@ -6,8 +6,11 @@
from __future__ import annotations
import ipaddress
import logging
import time
from typing import Any
from urllib.parse import urlparse
import httpx
@ -17,6 +20,28 @@ from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo
logger = logging.getLogger(__name__)
def _is_safe_url(url: str) -> bool:
"""Check if URL is safe (not pointing to private/internal networks)."""
try:
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
return False
hostname = parsed.hostname
if not hostname:
return False
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
return False
try:
ip = ipaddress.ip_address(hostname)
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
return False
except ValueError:
pass
return True
except Exception:
return False
class FeishuKBAdapter(KBAdapter):
"""飞书知识库适配器
@ -51,8 +76,11 @@ class FeishuKBAdapter(KBAdapter):
self._app_id = app_id
self._app_secret = app_secret
self._base_url = base_url.rstrip("/")
if not _is_safe_url(self._base_url):
raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.")
self._space_ids = space_ids or []
self._access_token: str | None = None
self._token_expiry: float = 0.0
def _make_client(self) -> httpx.AsyncClient:
"""创建飞书 API HTTP 客户端"""
@ -67,7 +95,7 @@ class FeishuKBAdapter(KBAdapter):
async def _get_access_token(self) -> str | None:
"""获取飞书 tenant_access_token"""
if self._access_token:
if self._access_token and time.time() < self._token_expiry:
return self._access_token
client = self._get_client()
@ -83,6 +111,8 @@ class FeishuKBAdapter(KBAdapter):
data = resp.json()
if data.get("code") == 0:
self._access_token = data.get("tenant_access_token")
expire_seconds = data.get("expire", 7200)
self._token_expiry = time.time() + expire_seconds - 300 # Refresh 5 minutes early
# 重建客户端以携带 token
await self.close()
return self._access_token

View File

@ -6,8 +6,10 @@
from __future__ import annotations
import ipaddress
import logging
from typing import Any
from urllib.parse import urlparse
import httpx
@ -17,6 +19,31 @@ from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo
logger = logging.getLogger(__name__)
def _is_safe_url(url: str) -> bool:
"""Check if URL is safe (not pointing to private/internal networks)."""
try:
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
return False
hostname = parsed.hostname
if not hostname:
return False
# Block common internal hostnames
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
return False
# Try to resolve as IP and check for private ranges
try:
ip = ipaddress.ip_address(hostname)
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
return False
except ValueError:
# Not an IP address, that's OK (it's a domain name)
pass
return True
except Exception:
return False
class GenericHTTPAdapter(KBAdapter):
"""通用 HTTP 知识库适配器
@ -53,6 +80,8 @@ class GenericHTTPAdapter(KBAdapter):
timeout=timeout,
)
self._endpoint_url = endpoint_url.rstrip("/")
if not _is_safe_url(self._endpoint_url):
raise ValueError(f"Unsafe endpoint_url: {self._endpoint_url}. Private/internal URLs are not allowed.")
self._auth_config = auth_config or {}
self._extra_headers = headers or {}

View File

@ -10,6 +10,7 @@ from sqlalchemy import text
from agentkit.memory.base import Memory, MemoryItem
from agentkit.memory.embedder import Embedder
from agentkit.utils.vector_math import compute_cosine_similarity
logger = logging.getLogger(__name__)
@ -123,7 +124,7 @@ class EpisodicMemory(Memory):
if row_embedding is None:
return None
cosine = self._compute_cosine_similarity(query_embedding, row_embedding)
cosine = compute_cosine_similarity(query_embedding, row_embedding)
if cosine < 0.1:
return None
@ -165,7 +166,7 @@ class EpisodicMemory(Memory):
entry_embedding = entry.embedding
if entry_embedding is None:
continue
cosine = self._compute_cosine_similarity(query_embedding, entry_embedding)
cosine = compute_cosine_similarity(query_embedding, entry_embedding)
if cosine > best_score:
best_score = cosine
best_item = entry
@ -260,7 +261,7 @@ class EpisodicMemory(Memory):
time_decay_score = (row.get("quality_score") or 0.5) * decay
if row_embedding is not None:
cosine_sim = self._compute_cosine_similarity(query_embedding, row_embedding)
cosine_sim = compute_cosine_similarity(query_embedding, row_embedding)
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
else:
score = time_decay_score
@ -327,7 +328,7 @@ class EpisodicMemory(Memory):
# 混合评分alpha * cosine + (1 - alpha) * time_decay
if self._embedder and query_embedding is not None and entry.embedding is not None:
cosine_sim = self._compute_cosine_similarity(query_embedding, entry.embedding)
cosine_sim = compute_cosine_similarity(query_embedding, entry.embedding)
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
else:
score = time_decay_score
@ -375,20 +376,3 @@ class EpisodicMemory(Memory):
await db.rollback()
logger.error(f"Failed to delete episodic memory: {e}")
return False
@staticmethod
def _compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
"""计算两个向量的余弦相似度"""
if len(vec_a) != len(vec_b):
logger.warning(
f"Vector dimension mismatch: {len(vec_a)} vs {len(vec_b)}"
)
return 0.0
if not vec_a:
return 0.0
dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
magnitude_a = sum(a**2 for a in vec_a) ** 0.5
magnitude_b = sum(b**2 for b in vec_b) ** 0.5
if magnitude_a == 0.0 or magnitude_b == 0.0:
return 0.0
return dot_product / (magnitude_a * magnitude_b)

View File

@ -10,10 +10,13 @@ from __future__ import annotations
import json
import logging
import re
import uuid
from datetime import datetime, timezone
from typing import Any
_SAFE_TABLE_NAME_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
from agentkit.memory.chunking import Chunk, StructuralChunker, TextChunker
from agentkit.memory.document_loader import Document as LoaderDocument
from agentkit.memory.embedder import Embedder
@ -23,6 +26,7 @@ from agentkit.memory.knowledge_base import (
QueryResult,
SourceInfo,
)
from agentkit.utils.vector_math import compute_cosine_similarity
logger = logging.getLogger(__name__)
@ -70,6 +74,8 @@ class LocalRAGService:
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
self._table_name = table_name
if not _SAFE_TABLE_NAME_PATTERN.match(self._table_name):
raise ValueError(f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*")
self._pgvector_enabled = pgvector_enabled
self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
self._structural_chunker = StructuralChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
@ -335,7 +341,7 @@ class LocalRAGService:
except (json.JSONDecodeError, TypeError):
continue
cosine = self._compute_cosine_similarity(query_embedding, stored_embedding)
cosine = compute_cosine_similarity(query_embedding, stored_embedding)
if cosine < 0.1:
continue
@ -359,18 +365,6 @@ class LocalRAGService:
candidates.sort(key=lambda x: x.score, reverse=True)
return candidates[:top_k]
@staticmethod
def _compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
"""计算两个向量的余弦相似度"""
if len(vec_a) != len(vec_b) or not vec_a:
return 0.0
dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
magnitude_a = sum(a**2 for a in vec_a) ** 0.5
magnitude_b = sum(b**2 for b in vec_b) ** 0.5
if magnitude_a == 0.0 or magnitude_b == 0.0:
return 0.0
return dot_product / (magnitude_a * magnitude_b)
class InMemoryLocalRAGService:
"""基于内存的本地 RAG 服务
@ -447,7 +441,7 @@ class InMemoryLocalRAGService:
candidates = []
for chunk_id, chunk_data in self._chunks.items():
stored_embedding = chunk_data["embedding"]
cosine = self._compute_cosine_similarity(query_embedding, stored_embedding)
cosine = compute_cosine_similarity(query_embedding, stored_embedding)
if cosine < 0.1:
continue
@ -511,15 +505,3 @@ class InMemoryLocalRAGService:
source_doc_id=doc.doc_id,
metadata=doc.metadata,
)
@staticmethod
def _compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
"""计算两个向量的余弦相似度"""
if len(vec_a) != len(vec_b) or not vec_a:
return 0.0
dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
magnitude_a = sum(a**2 for a in vec_a) ** 0.5
magnitude_b = sum(b**2 for b in vec_b) ** 0.5
if magnitude_a == 0.0 or magnitude_b == 0.0:
return 0.0
return dot_product / (magnitude_a * magnitude_b)

View File

@ -457,6 +457,23 @@ async def list_path_optimizations(
@router.websocket("/evolution-dashboard/ws")
async def evolution_dashboard_ws(websocket: WebSocket):
"""自进化仪表盘实时更新 WebSocket"""
# Authentication - check api_key
configured_api_key: str | None = None
if hasattr(websocket.app.state, "server_config") and websocket.app.state.server_config:
configured_api_key = websocket.app.state.server_config.api_key
if configured_api_key is None and hasattr(websocket.app.state, "api_key"):
configured_api_key = websocket.app.state.api_key
if configured_api_key:
provided = websocket.query_params.get("api_key")
if provided != configured_api_key:
await websocket.accept()
await websocket.send_json(
{"type": "error", "data": {"message": "Invalid or missing api_key"}}
)
await websocket.close(code=4001, reason="Invalid or missing api_key")
return
await websocket.accept()
_ws_connections.append(websocket)

View File

@ -15,6 +15,8 @@ logger = logging.getLogger(__name__)
router = APIRouter(tags=["kb-management"])
MAX_UPLOAD_SIZE = 50 * 1024 * 1024 # 50MB
# ---------------------------------------------------------------------------
# In-memory Knowledge Source Store
@ -183,14 +185,18 @@ async def upload_document(
try:
from agentkit.memory.document_loader import DocumentLoader
content = await file.read()
content = await file.read(MAX_UPLOAD_SIZE + 1)
if len(content) > MAX_UPLOAD_SIZE:
raise HTTPException(status_code=413, detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE // (1024*1024)}MB")
loader = DocumentLoader()
doc = loader.load_bytes(content, file.filename)
# Estimate chunks based on content length (rough approximation)
chunks = max(1, len(doc.content) // 500)
except ImportError:
# DocumentLoader not available, use basic estimation
content = await file.read()
content = await file.read(MAX_UPLOAD_SIZE + 1)
if len(content) > MAX_UPLOAD_SIZE:
raise HTTPException(status_code=413, detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE // (1024*1024)}MB")
chunks = max(1, len(content) // 500)
except Exception as e:
logger.warning(f"Document parsing failed: {e}")

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio
import json
import logging
import re
import uuid
from datetime import datetime, timezone
from typing import Any
@ -39,6 +40,7 @@ class WorkflowStore:
self._executions: dict[str, WorkflowExecution] = {}
self._max_workflows = max_workflows
self._max_executions = max_executions
self._approval_events: dict[str, asyncio.Event] = {} # key: f"{execution_id}:{stage_name}"
def save(self, workflow: WorkflowDefinition) -> WorkflowDefinition:
workflow.updated_at = datetime.now(timezone.utc).isoformat()
@ -226,31 +228,70 @@ async def _execute_workflow(
try:
if stage.type == "approval":
# Pause execution and wait for approval
# Pause execution and wait for approval via asyncio.Event
event_key = f"{execution.execution_id}:{stage_name}"
approval_event = asyncio.Event()
_store._approval_events[event_key] = approval_event
execution.status = "paused"
execution.current_stage = stage_name
_store.update_execution(
execution.execution_id,
status="paused",
current_stage=stage_name,
)
await _broadcast_ws({
"event": "approval_required",
"execution_id": execution.execution_id,
"stage": stage_name,
})
# In a real implementation, this would wait for external approval
# For now, we simulate auto-approval after a brief pause
await asyncio.sleep(0.1)
execution.stage_results[stage_name] = {
"status": "approved",
"approver": "auto",
"comment": "自动审批通过",
}
execution.status = "running"
_store.update_execution(
execution.execution_id,
status="running",
stage_results=execution.stage_results,
)
# Wait for approval with timeout
try:
approval_timeout = stage.config.get("approval_timeout", 3600)
await asyncio.wait_for(approval_event.wait(), timeout=approval_timeout)
# Check if execution was cancelled/rejected while waiting
if execution.status == "cancelled":
await _broadcast_ws({
"event": "stage_failed",
"execution_id": execution.execution_id,
"stage": stage_name,
"error": "Approval rejected",
})
return
# Approval was granted — the /approve endpoint already set stage_results
# Only update status to running if not already set
if execution.status != "running":
execution.status = "running"
_store.update_execution(
execution.execution_id,
status="running",
)
except asyncio.TimeoutError:
execution.stage_results[stage_name] = {
"status": "timeout",
"approver": "none",
"comment": "审批超时",
}
execution.status = "failed"
execution.error = f"Approval timeout for stage {stage_name}"
execution.completed_at = datetime.now(timezone.utc).isoformat()
_store.update_execution(
execution.execution_id,
status="failed",
error=execution.error,
completed_at=execution.completed_at,
stage_results=execution.stage_results,
)
await _broadcast_ws({
"event": "stage_failed",
"execution_id": execution.execution_id,
"stage": stage_name,
"error": "Approval timeout",
})
return
finally:
_store._approval_events.pop(event_key, None)
elif stage.type == "condition":
# Evaluate condition expression
condition_expr = stage.config.get("expression", "")
@ -318,23 +359,60 @@ async def _execute_workflow(
})
_SAFE_VAR_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
_SAFE_OPERATORS = {"==", "!=", ">", "<", ">=", "<="}
def _evaluate_condition(expression: str, variables: dict[str, Any]) -> bool:
"""Simple condition evaluation."""
"""Evaluate a condition expression safely."""
expression = expression.strip()
if not expression:
return True
if "==" in expression:
parts = expression.split("==", 1)
left = variables.get(parts[0].strip(), parts[0].strip())
right = parts[1].strip().strip("'\"")
return str(left) == right
elif "!=" in expression:
parts = expression.split("!=", 1)
left = variables.get(parts[0].strip(), parts[0].strip())
right = parts[1].strip().strip("'\"")
return str(left) != right
else:
# Try each operator (longer operators first to avoid partial matches)
for op in sorted(_SAFE_OPERATORS, key=len, reverse=True):
if op in expression:
parts = expression.split(op, 1)
if len(parts) != 2:
continue
left = parts[0].strip()
right = parts[1].strip()
# Validate variable names
if left and not _SAFE_VAR_PATTERN.match(left):
raise ValueError(f"Invalid variable name in condition: {left}")
left_val = variables.get(left, left)
# Strip quotes from right side if present
if right.startswith('"') and right.endswith('"'):
right_val = right[1:-1]
elif right.startswith("'") and right.endswith("'"):
right_val = right[1:-1]
elif right and _SAFE_VAR_PATTERN.match(right):
right_val = variables.get(right, right)
else:
right_val = right
# Compare based on operator
if op == "==":
return str(left_val) == str(right_val)
if op == "!=":
return str(left_val) != str(right_val)
if op == ">":
return float(left_val) > float(right_val)
if op == "<":
return float(left_val) < float(right_val)
if op == ">=":
return float(left_val) >= float(right_val)
if op == "<=":
return float(left_val) <= float(right_val)
# Boolean check for variable existence
if _SAFE_VAR_PATTERN.match(expression):
return bool(variables.get(expression))
raise ValueError(f"Invalid condition expression: {expression}")
async def _broadcast_ws(message: dict[str, Any]) -> None:
"""Broadcast a message to all WebSocket subscribers."""
@ -487,12 +565,12 @@ async def approve_execution(
status="running",
stage_results=execution.stage_results,
)
# Resume execution
workflow = store.get(execution.workflow_id)
if workflow:
asyncio.create_task(
_execute_workflow(workflow, execution, execution.variables, store=store)
)
# Resume the waiting execution by setting the approval event
stage_name = execution.current_stage
if stage_name:
event_key = f"{execution_id}:{stage_name}"
if event_key in store._approval_events:
store._approval_events[event_key].set()
else:
execution.status = "cancelled"
execution.completed_at = datetime.now(timezone.utc).isoformat()
@ -508,6 +586,12 @@ async def approve_execution(
completed_at=execution.completed_at,
stage_results=execution.stage_results,
)
# Set the approval event so the waiting coroutine can observe the cancelled state
stage_name = execution.current_stage
if stage_name:
event_key = f"{execution_id}:{stage_name}"
if event_key in store._approval_events:
store._approval_events[event_key].set()
return execution.model_dump()

View File

@ -18,6 +18,8 @@ from dataclasses import dataclass
logger = logging.getLogger(__name__)
# 自动应答规则:(prompt_pattern, response)
# WARNING: auto_respond is disabled by default for safety.
# Enable it only when you explicitly want automatic yes/confirm responses.
_AUTO_RESPOND_RULES: list[tuple[str, str]] = [
(r"\[y/N\]\s*$", "y"),
(r"\[Y/n\]\s*$", "y"),
@ -61,7 +63,7 @@ class PTYSession:
def __init__(
self,
auto_respond: bool = True,
auto_respond: bool = False,
custom_rules: list[tuple[str, str]] | None = None,
default_timeout: float = 30.0,
buffer_size: int = 4096,

View File

@ -9,6 +9,8 @@ from __future__ import annotations
import asyncio
import logging
import os
import re
import shlex
import time
from typing import Any, Callable, Awaitable
@ -21,6 +23,8 @@ logger = logging.getLogger(__name__)
# 安全白名单:这些命令前缀不需要确认
_SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
"cd",
"export",
"ls",
"cat",
"head",
@ -48,6 +52,7 @@ _SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
"sort",
"uniq",
"diff",
"sleep",
"git status",
"git log",
"git diff",
@ -101,6 +106,9 @@ _DANGEROUS_PATTERNS: tuple[str, ...] = (
)
_SHELL_OPERATORS = re.compile(r'[|;&]|&&|\|\||\$\(|`')
class ShellTool(Tool):
"""Shell 命令执行工具
@ -364,18 +372,39 @@ class ShellTool(Tool):
"""
command_stripped = command.strip()
# 白名单检查
for prefix in _SAFE_COMMAND_PREFIXES:
if command_stripped.startswith(prefix):
return False
# Check for shell operators that chain commands (always dangerous)
if _SHELL_OPERATORS.search(command_stripped):
return True
# 危险模式检查
# Parse the actual binary being invoked
try:
tokens = shlex.split(command_stripped)
if not tokens:
return True
binary = os.path.basename(tokens[0])
except ValueError:
# Unparsable command - treat as dangerous
return True
# Whitelist check: first try full command prefix match, then binary-only match
for prefix in _SAFE_COMMAND_PREFIXES:
prefix_stripped = prefix.lower().strip()
if " " in prefix_stripped:
# Compound prefix like "git status" - match against full command
if command_stripped.lower().startswith(prefix_stripped):
return False
else:
# Simple prefix - match against binary name only
if binary.lower().startswith(prefix_stripped):
return False
# Dangerous pattern check
command_lower = command_stripped.lower()
for pattern in _DANGEROUS_PATTERNS:
if pattern in command_lower:
return True
return False
return True # Unknown commands are dangerous by default
async def _request_confirmation(self, command: str) -> bool:
"""请求人工确认危险命令

View File

@ -9,6 +9,8 @@ from __future__ import annotations
import asyncio
import logging
import os
import re
import shlex
import time
from dataclasses import dataclass, field
from typing import Any
@ -17,6 +19,8 @@ from agentkit.tools.output_parser import OutputParser, ParsedOutput
logger = logging.getLogger(__name__)
_ENV_KEY_PATTERN = re.compile(r'^[A-Za-z_][A-Za-z0-9_]*$')
@dataclass
class CommandRecord:
@ -190,15 +194,13 @@ class TerminalSession:
# 注入 cd
if self._cwd:
# 使用 shlex.quote 风格的简单转义
cwd_escaped = self._cwd.replace("'", "'\\''")
parts.append(f"cd '{cwd_escaped}'")
parts.append(f"cd {shlex.quote(self._cwd)}")
# 注入环境变量
for key, value in self._env.items():
# 跳过 os.environ 中已有的且值未变的变量,减少命令长度
val_escaped = value.replace("'", "'\\''")
parts.append(f"export {key}='{val_escaped}'")
if not _ENV_KEY_PATTERN.match(key):
continue # Skip invalid env key names
parts.append(f"export {shlex.quote(key)}={shlex.quote(value)}")
parts.append(command)
return " && ".join(parts)

View File

@ -0,0 +1 @@
"""AgentKit utility modules."""

View File

@ -0,0 +1,30 @@
"""Shared vector math utilities."""
from __future__ import annotations
import math
def compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
"""Compute cosine similarity between two vectors.
Args:
vec_a: First vector.
vec_b: Second vector.
Returns:
Cosine similarity score between -1 and 1.
"""
if len(vec_a) != len(vec_b):
return 0.0
if not vec_a:
return 0.0
dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
norm_a = math.sqrt(sum(a * a for a in vec_a))
norm_b = math.sqrt(sum(b * b for b in vec_b))
if norm_a == 0.0 or norm_b == 0.0:
return 0.0
return dot_product / (norm_a * norm_b)

View File

@ -433,12 +433,34 @@ async def test_workflow_with_approval():
assert execution.status == "pending"
assert execution.execution_id
# 5. Execute workflow (runs in background)
# 5. Execute workflow in background (approval stage will wait for event)
from agentkit.server.routes.workflows import _execute_workflow
await _execute_workflow(workflow, execution, variables={}, store=store)
# Use a short approval timeout for testing
workflow.stages[1].config["approval_timeout"] = 5
# 6. Verify execution completed (auto-approval in test mode)
async def _approve_after_pause():
"""Wait for execution to pause, then approve."""
for _ in range(100):
await asyncio.sleep(0.05)
updated = store.get_execution(execution.execution_id)
if updated and updated.status == "paused":
break
# Trigger approval
event_key = f"{execution.execution_id}:human_review"
if event_key in store._approval_events:
execution.stage_results["human_review"] = {
"status": "approved",
"approver": "test_user",
"comment": "Auto-approved in test",
}
store._approval_events[event_key].set()
approve_task = asyncio.create_task(_approve_after_pause())
await _execute_workflow(workflow, execution, variables={}, store=store)
await approve_task
# 6. Verify execution completed
updated = store.get_execution(execution.execution_id)
assert updated is not None
assert updated.status == "completed"
@ -452,7 +474,7 @@ async def test_workflow_with_approval():
approval_result = updated.stage_results["human_review"]
assert approval_result.get("status") in ("approved", "completed")
# 9. Test manual approval flow
# 9. Test second workflow with approval
workflow2 = WorkflowDefinition(
workflow_id="wf-manual-approval",
name="手动审批流程",
@ -468,6 +490,7 @@ async def test_workflow_with_approval():
agent="reviewer",
action="approve",
type="approval",
config={"approval_timeout": 5},
depends_on=["step1"],
),
WorkflowStage(
@ -481,28 +504,30 @@ async def test_workflow_with_approval():
)
store.save(workflow2)
# Simulate manual approval via API
execution2 = store.create_execution(workflow2.workflow_id)
execution2.status = "paused"
execution2.current_stage = "approval_step"
store.update_execution(
execution2.execution_id,
status="paused",
current_stage="approval_step",
)
# Approve
execution2.stage_results["approval_step"] = {
"status": "approved",
"approver": "user",
"comment": "LGTM",
}
execution2.status = "running"
store.update_execution(
execution2.execution_id,
status="running",
stage_results=execution2.stage_results,
)
async def _approve2_after_pause():
for _ in range(100):
await asyncio.sleep(0.05)
updated2 = store.get_execution(execution2.execution_id)
if updated2 and updated2.status == "paused":
break
event_key2 = f"{execution2.execution_id}:approval_step"
if event_key2 in store._approval_events:
execution2.stage_results["approval_step"] = {
"status": "approved",
"approver": "user",
"comment": "LGTM",
}
store._approval_events[event_key2].set()
approve_task2 = asyncio.create_task(_approve2_after_pause())
await _execute_workflow(workflow2, execution2, variables={}, store=store)
await approve_task2
updated2 = store.get_execution(execution2.execution_id)
assert updated2 is not None
assert updated2.status == "completed"
# Verify approval was recorded
paused_exec = store.get_execution(execution2.execution_id)
@ -1014,7 +1039,27 @@ async def test_multi_source_rag_with_workflow(local_rag, mock_embedder):
execution = store.create_execution(workflow.workflow_id)
from agentkit.server.routes.workflows import _execute_workflow
# Set short approval timeout and handle approval
workflow.stages[1].config["approval_timeout"] = 5
async def _approve_kb_review():
for _ in range(100):
await asyncio.sleep(0.05)
upd = store.get_execution(execution.execution_id)
if upd and upd.status == "paused":
break
event_key = f"{execution.execution_id}:review_findings"
if event_key in store._approval_events:
execution.stage_results["review_findings"] = {
"status": "approved",
"approver": "test_user",
"comment": "Approved",
}
store._approval_events[event_key].set()
approve_task = asyncio.create_task(_approve_kb_review())
await _execute_workflow(workflow, execution, variables={}, store=store)
await approve_task
updated = store.get_execution(execution.execution_id)
assert updated.status == "completed"

View File

@ -9,10 +9,10 @@ import pytest
from agentkit.evolution.experience_schema import EvolutionMetrics, TaskExperience
from agentkit.evolution.experience_store import (
InMemoryExperienceStore,
_compute_cosine_similarity,
_parse_time_window,
)
from agentkit.memory.embedder import MockEmbedder
from agentkit.utils.vector_math import compute_cosine_similarity
# ── Fixtures ──────────────────────────────────────────────
@ -136,23 +136,23 @@ class TestEvolutionMetrics:
class TestHelperFunctions:
def test_cosine_similarity_identical(self):
vec = [1.0, 0.0, 0.0]
assert _compute_cosine_similarity(vec, vec) == pytest.approx(1.0)
assert compute_cosine_similarity(vec, vec) == pytest.approx(1.0)
def test_cosine_similarity_orthogonal(self):
a = [1.0, 0.0]
b = [0.0, 1.0]
assert _compute_cosine_similarity(a, b) == pytest.approx(0.0)
assert compute_cosine_similarity(a, b) == pytest.approx(0.0)
def test_cosine_similarity_opposite(self):
a = [1.0, 0.0]
b = [-1.0, 0.0]
assert _compute_cosine_similarity(a, b) == pytest.approx(-1.0)
assert compute_cosine_similarity(a, b) == pytest.approx(-1.0)
def test_cosine_similarity_empty(self):
assert _compute_cosine_similarity([], []) == 0.0
assert compute_cosine_similarity([], []) == 0.0
def test_cosine_similarity_mismatched_dims(self):
assert _compute_cosine_similarity([1.0], [1.0, 2.0]) == 0.0
assert compute_cosine_similarity([1.0], [1.0, 2.0]) == 0.0
def test_parse_time_window_hours(self):
delta = _parse_time_window("24h")

View File

@ -1,6 +1,7 @@
"""Tests for KnowledgeBase adapters — 飞书、Confluence、通用 HTTP 适配器"""
import pytest
import time
from unittest.mock import AsyncMock, MagicMock, patch
from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo, KnowledgeBase
@ -64,7 +65,7 @@ class TestKnowledgeBaseProtocol:
assert isinstance(adapter2, KnowledgeBase)
adapter3 = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb",
endpoint_url="https://example.com/api/kb",
)
assert isinstance(adapter3, KnowledgeBase)
@ -256,6 +257,8 @@ class TestFeishuKBAdapterSearch:
async def test_search_success(self, adapter):
# Mock authentication
adapter._access_token = "t-xxx"
adapter._token_expiry = time.time() + 7200
adapter._token_expiry = time.time() + 7200
mock_resp = MagicMock()
mock_resp.status_code = 200
@ -307,6 +310,7 @@ class TestFeishuKBAdapterSearch:
@pytest.mark.asyncio
async def test_search_api_error(self, adapter):
adapter._access_token = "t-xxx"
adapter._token_expiry = time.time() + 7200
mock_resp = MagicMock()
mock_resp.status_code = 200
@ -328,6 +332,7 @@ class TestFeishuKBAdapterSearch:
import httpx
adapter._access_token = "t-xxx"
adapter._token_expiry = time.time() + 7200
mock_resp = MagicMock()
mock_resp.status_code = 500
@ -380,6 +385,7 @@ class TestFeishuKBAdapterListSources:
async def test_list_sources_success(self):
adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
adapter._access_token = "t-xxx"
adapter._token_expiry = time.time() + 7200
mock_resp = MagicMock()
mock_resp.status_code = 200
@ -420,6 +426,7 @@ class TestFeishuKBAdapterGetDocument:
async def test_get_document_success(self):
adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
adapter._access_token = "t-xxx"
adapter._token_expiry = time.time() + 7200
mock_resp = MagicMock()
mock_resp.status_code = 200
@ -450,6 +457,7 @@ class TestFeishuKBAdapterGetDocument:
async def test_get_document_not_found(self):
adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
adapter._access_token = "t-xxx"
adapter._token_expiry = time.time() + 7200
mock_resp = MagicMock()
mock_resp.status_code = 200
@ -736,23 +744,23 @@ class TestGenericHTTPAdapterInit:
def test_basic_init(self):
adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb",
endpoint_url="https://example.com/api/kb",
)
assert adapter._endpoint_url == "http://localhost:8000/api/kb"
assert adapter._endpoint_url == "https://example.com/api/kb"
assert adapter._auth_config == {}
assert adapter._extra_headers == {}
assert adapter._source_type == "generic_http"
def test_init_with_auth_bearer(self):
adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb/",
endpoint_url="https://example.com/api/kb/",
auth_config={"type": "bearer", "token": "sk-test"},
headers={"X-Custom": "value"},
source_id="my-kb",
source_name="My KB",
timeout=60,
)
assert adapter._endpoint_url == "http://localhost:8000/api/kb"
assert adapter._endpoint_url == "https://example.com/api/kb"
assert adapter._auth_config["type"] == "bearer"
assert adapter._extra_headers == {"X-Custom": "value"}
assert adapter._source_id == "my-kb"
@ -760,7 +768,7 @@ class TestGenericHTTPAdapterInit:
def test_client_bearer_auth_header(self):
adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb",
endpoint_url="https://example.com/api/kb",
auth_config={"type": "bearer", "token": "sk-test"},
)
client = adapter._make_client()
@ -768,7 +776,7 @@ class TestGenericHTTPAdapterInit:
def test_client_basic_auth_header(self):
adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb",
endpoint_url="https://example.com/api/kb",
auth_config={"type": "basic", "username": "user", "password": "pass"},
)
client = adapter._make_client()
@ -777,7 +785,7 @@ class TestGenericHTTPAdapterInit:
def test_client_api_key_header(self):
adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb",
endpoint_url="https://example.com/api/kb",
auth_config={"type": "api_key", "header_name": "X-API-Key", "api_key": "key123"},
)
client = adapter._make_client()
@ -790,7 +798,7 @@ class TestGenericHTTPAdapterSearch:
@pytest.fixture
def adapter(self):
return GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb",
endpoint_url="https://example.com/api/kb",
auth_config={"type": "bearer", "token": "sk-test"},
)
@ -892,7 +900,7 @@ class TestGenericHTTPAdapterIngest:
@pytest.mark.asyncio
async def test_ingest_success(self):
adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb",
endpoint_url="https://example.com/api/kb",
)
mock_resp = MagicMock()
@ -921,7 +929,7 @@ class TestGenericHTTPAdapterIngest:
import httpx
adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb",
endpoint_url="https://example.com/api/kb",
)
mock_resp = MagicMock()
@ -946,7 +954,7 @@ class TestGenericHTTPAdapterDeleteById:
@pytest.mark.asyncio
async def test_delete_success(self):
adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb",
endpoint_url="https://example.com/api/kb",
)
mock_resp = MagicMock()
@ -963,7 +971,7 @@ class TestGenericHTTPAdapterDeleteById:
@pytest.mark.asyncio
async def test_delete_not_found(self):
adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb",
endpoint_url="https://example.com/api/kb",
)
mock_resp = MagicMock()
@ -983,7 +991,7 @@ class TestGenericHTTPAdapterGetDocument:
@pytest.mark.asyncio
async def test_get_document_success(self):
adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb",
endpoint_url="https://example.com/api/kb",
)
mock_resp = MagicMock()
@ -1011,7 +1019,7 @@ class TestGenericHTTPAdapterGetDocument:
import httpx
adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb",
endpoint_url="https://example.com/api/kb",
)
mock_resp = MagicMock()
@ -1034,7 +1042,7 @@ class TestGenericHTTPAdapterHealthCheck:
@pytest.mark.asyncio
async def test_health_check_ok(self):
adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb",
endpoint_url="https://example.com/api/kb",
)
mock_resp = MagicMock()
@ -1050,7 +1058,7 @@ class TestGenericHTTPAdapterHealthCheck:
async def test_health_check_fallback_to_root(self):
"""health endpoint 不存在时回退到根路径"""
adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb",
endpoint_url="https://example.com/api/kb",
)
import httpx
@ -1076,7 +1084,7 @@ class TestGenericHTTPAdapterHealthCheck:
@pytest.mark.asyncio
async def test_health_check_connection_error(self):
adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb",
endpoint_url="https://example.com/api/kb",
)
mock_client = AsyncMock()
@ -1092,7 +1100,7 @@ class TestGenericHTTPAdapterListSources:
@pytest.mark.asyncio
async def test_list_sources_success(self):
adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb",
endpoint_url="https://example.com/api/kb",
)
mock_resp = MagicMock()
@ -1116,7 +1124,7 @@ class TestGenericHTTPAdapterListSources:
async def test_list_sources_endpoint_not_found(self):
"""sources endpoint 不存在时返回默认信息源"""
adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb",
endpoint_url="https://example.com/api/kb",
)
mock_client = AsyncMock()
@ -1146,7 +1154,7 @@ class TestCrossAdapterIntegration:
username="user@test.com",
api_token="token",
),
GenericHTTPAdapter(endpoint_url="http://localhost:8000/api/kb"),
GenericHTTPAdapter(endpoint_url="https://example.com/api/kb"),
]
for adapter in adapters:
assert isinstance(adapter, KnowledgeBase)
@ -1166,7 +1174,7 @@ class TestCrossAdapterIntegration:
username="user@test.com",
api_token="token",
),
GenericHTTPAdapter(endpoint_url="http://localhost:8000/api/kb"),
GenericHTTPAdapter(endpoint_url="https://example.com/api/kb"),
]
for adapter in adapters:
assert hasattr(adapter, "search")
@ -1179,6 +1187,7 @@ class TestCrossAdapterIntegration:
# Feishu
feishu = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
feishu._access_token = "t-xxx"
feishu._token_expiry = time.time() + 7200
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.raise_for_status = MagicMock()
@ -1220,7 +1229,7 @@ class TestCrossAdapterIntegration:
assert all(isinstance(r, QueryResult) for r in results)
# GenericHTTP
generic = GenericHTTPAdapter(endpoint_url="http://localhost:8000/api/kb")
generic = GenericHTTPAdapter(endpoint_url="https://example.com/api/kb")
mock_resp3 = MagicMock()
mock_resp3.status_code = 200
mock_resp3.raise_for_status = MagicMock()

View File

@ -12,6 +12,7 @@ from sqlalchemy.orm import DeclarativeBase
from agentkit.memory.episodic import EpisodicMemory
from agentkit.memory.base import MemoryItem
from agentkit.memory.embedder import MockEmbedder
from agentkit.utils.vector_math import compute_cosine_similarity
# ── 真实 SQLAlchemy 模型(用于测试) ─────────────────────
@ -112,40 +113,40 @@ def _make_row_mapping(data: dict) -> _RowMapping:
class TestCosineSimilarity:
"""_compute_cosine_similarity 测试"""
"""compute_cosine_similarity 测试"""
def test_identical_vectors_return_one(self):
"""相同向量余弦相似度为 1"""
vec = [1.0, 0.0, 0.0]
assert EpisodicMemory._compute_cosine_similarity(vec, vec) == pytest.approx(1.0)
assert compute_cosine_similarity(vec, vec) == pytest.approx(1.0)
def test_orthogonal_vectors_return_zero(self):
"""正交向量余弦相似度为 0"""
vec_a = [1.0, 0.0]
vec_b = [0.0, 1.0]
assert EpisodicMemory._compute_cosine_similarity(vec_a, vec_b) == pytest.approx(0.0)
assert compute_cosine_similarity(vec_a, vec_b) == pytest.approx(0.0)
def test_opposite_vectors_return_minus_one(self):
"""相反向量余弦相似度为 -1"""
vec_a = [1.0, 0.0]
vec_b = [-1.0, 0.0]
assert EpisodicMemory._compute_cosine_similarity(vec_a, vec_b) == pytest.approx(-1.0)
assert compute_cosine_similarity(vec_a, vec_b) == pytest.approx(-1.0)
def test_dimension_mismatch_returns_zero(self):
"""维度不匹配返回 0"""
vec_a = [1.0, 2.0]
vec_b = [1.0]
assert EpisodicMemory._compute_cosine_similarity(vec_a, vec_b) == 0.0
assert compute_cosine_similarity(vec_a, vec_b) == 0.0
def test_empty_vectors_return_zero(self):
"""空向量返回 0"""
assert EpisodicMemory._compute_cosine_similarity([], []) == 0.0
assert compute_cosine_similarity([], []) == 0.0
def test_zero_vector_returns_zero(self):
"""零向量返回 0"""
vec_a = [0.0, 0.0]
vec_b = [1.0, 2.0]
assert EpisodicMemory._compute_cosine_similarity(vec_a, vec_b) == 0.0
assert compute_cosine_similarity(vec_a, vec_b) == 0.0
# ── MockEmbedder 测试 ───────────────────────────────────

View File

@ -25,7 +25,7 @@ class TestPTYSessionConstruction:
def test_default_construction(self):
pty = PTYSession()
assert pty.is_running is False
assert pty._auto_respond is True
assert pty._auto_respond is False
assert pty._default_timeout == 30.0
def test_custom_construction(self):