fix(security): resolve all P0/P1 findings from code review

This commit is contained in:
chiguyong 2026-06-10 07:12:41 +08:00
parent b34f74f598
commit 9e9f1314f6
22 changed files with 457 additions and 181 deletions

View File

@ -496,6 +496,6 @@ class PlanExecutor:
# 所有步骤要么完成要么跳过 # 所有步骤要么完成要么跳过
return TaskStatus.COMPLETED return TaskStatus.COMPLETED
if failed > 0: if failed > 0:
return TaskStatus.COMPLETED # 部分成功 return TaskStatus.PARTIALLY_COMPLETED # 部分成功
return TaskStatus.COMPLETED return TaskStatus.COMPLETED

View File

@ -13,6 +13,7 @@ class TaskStatus(str, Enum):
PENDING = "pending" PENDING = "pending"
RUNNING = "running" RUNNING = "running"
COMPLETED = "completed" COMPLETED = "completed"
PARTIALLY_COMPLETED = "partially_completed"
FAILED = "failed" FAILED = "failed"
CANCELLED = "cancelled" CANCELLED = "cancelled"
HANDOFF = "handoff" HANDOFF = "handoff"

View File

@ -12,14 +12,18 @@ from __future__ import annotations
import logging import logging
import math import math
import re
import uuid import uuid
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Any from typing import Any
from sqlalchemy import text from sqlalchemy import text
_SAFE_TABLE_NAME_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
from agentkit.evolution.experience_schema import EvolutionMetrics, TaskExperience from agentkit.evolution.experience_schema import EvolutionMetrics, TaskExperience
from agentkit.memory.embedder import Embedder from agentkit.memory.embedder import Embedder
from agentkit.utils.vector_math import compute_cosine_similarity
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -69,6 +73,8 @@ class ExperienceStore:
self._retrieve_limit = retrieve_limit self._retrieve_limit = retrieve_limit
self._pgvector_enabled = pgvector_enabled self._pgvector_enabled = pgvector_enabled
self._table_name = table_name self._table_name = table_name
if not _SAFE_TABLE_NAME_PATTERN.match(self._table_name):
raise ValueError(f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*")
async def record_experience(self, experience: TaskExperience) -> str: async def record_experience(self, experience: TaskExperience) -> str:
"""记录任务经验 """记录任务经验
@ -193,7 +199,7 @@ class ExperienceStore:
time_decay_score = (row.get("success_rate") or 0.5) * decay time_decay_score = (row.get("success_rate") or 0.5) * decay
if row_embedding is not None: if row_embedding is not None:
cosine_sim = _compute_cosine_similarity(query_embedding, row_embedding) cosine_sim = compute_cosine_similarity(query_embedding, row_embedding)
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
else: else:
score = time_decay_score score = time_decay_score
@ -251,7 +257,7 @@ class ExperienceStore:
time_decay_score = (entry.success_rate or 0.5) * decay time_decay_score = (entry.success_rate or 0.5) * decay
if self._embedder and query_embedding is not None and entry.embedding is not None: if self._embedder and query_embedding is not None and entry.embedding is not None:
cosine_sim = _compute_cosine_similarity(query_embedding, entry.embedding) cosine_sim = compute_cosine_similarity(query_embedding, entry.embedding)
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
else: else:
score = time_decay_score score = time_decay_score
@ -425,7 +431,7 @@ class InMemoryExperienceStore:
time_decay_score = exp.success_rate * decay time_decay_score = exp.success_rate * decay
if query_embedding is not None and exp.embedding is not None: if query_embedding is not None and exp.embedding is not None:
cosine_sim = _compute_cosine_similarity(query_embedding, exp.embedding) cosine_sim = compute_cosine_similarity(query_embedding, exp.embedding)
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
else: else:
score = time_decay_score score = time_decay_score
@ -485,21 +491,6 @@ class InMemoryExperienceStore:
# ── 辅助函数 ────────────────────────────────────────────── # ── 辅助函数 ──────────────────────────────────────────────
def _compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
"""计算两个向量的余弦相似度"""
if len(vec_a) != len(vec_b):
logger.warning(f"Vector dimension mismatch: {len(vec_a)} vs {len(vec_b)}")
return 0.0
if not vec_a:
return 0.0
dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
magnitude_a = sum(a**2 for a in vec_a) ** 0.5
magnitude_b = sum(b**2 for b in vec_b) ** 0.5
if magnitude_a == 0.0 or magnitude_b == 0.0:
return 0.0
return dot_product / (magnitude_a * magnitude_b)
def _parse_time_window(window: str) -> timedelta: def _parse_time_window(window: str) -> timedelta:
"""解析时间窗口字符串为 timedelta """解析时间窗口字符串为 timedelta

View File

@ -207,10 +207,12 @@ class PitfallDetector:
if error: if error:
s.failure_reasons.append(error) s.failure_reasons.append(error)
# 收集优化建议 # 收集优化建议 — only add to steps that are part of this experience
if hasattr(exp, "optimization_tips") and exp.optimization_tips: if hasattr(exp, 'optimization_tips') and exp.optimization_tips:
experience_steps = set(exp.steps) if hasattr(exp, 'steps') and exp.steps else set()
for step_name, s in stats.items(): for step_name, s in stats.items():
s.optimization_tips.extend(exp.optimization_tips) if not experience_steps or step_name in experience_steps:
s.optimization_tips.extend(exp.optimization_tips)
return stats return stats

View File

@ -6,8 +6,10 @@
from __future__ import annotations from __future__ import annotations
import ipaddress
import logging import logging
from typing import Any from typing import Any
from urllib.parse import urlparse
import httpx import httpx
@ -17,6 +19,33 @@ from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _escape_cql(value: str) -> str:
"""Escape special characters in CQL values."""
return value.replace("\\", "\\\\").replace('"', '\\"')
def _is_safe_url(url: str) -> bool:
"""Check if URL is safe (not pointing to private/internal networks)."""
try:
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
return False
hostname = parsed.hostname
if not hostname:
return False
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
return False
try:
ip = ipaddress.ip_address(hostname)
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
return False
except ValueError:
pass
return True
except Exception:
return False
class ConfluenceAdapter(KBAdapter): class ConfluenceAdapter(KBAdapter):
"""Confluence 知识库适配器 """Confluence 知识库适配器
@ -49,6 +78,8 @@ class ConfluenceAdapter(KBAdapter):
timeout=timeout, timeout=timeout,
) )
self._base_url = base_url.rstrip("/") self._base_url = base_url.rstrip("/")
if not _is_safe_url(self._base_url):
raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.")
self._username = username self._username = username
self._api_token = api_token self._api_token = api_token
self._space_keys = space_keys or [] self._space_keys = space_keys or []
@ -88,10 +119,10 @@ class ConfluenceAdapter(KBAdapter):
""" """
client = self._get_client() client = self._get_client()
try: try:
cql = f'text ~ "{query}"' cql = f'text ~ "{_escape_cql(query)}"'
if self._space_keys: if self._space_keys:
space_filter = " OR ".join( space_filter = " OR ".join(
f'space = "{key}"' for key in self._space_keys f'space = "{_escape_cql(key)}"' for key in self._space_keys
) )
cql = f'{cql} AND ({space_filter})' cql = f'{cql} AND ({space_filter})'

View File

@ -6,8 +6,11 @@
from __future__ import annotations from __future__ import annotations
import ipaddress
import logging import logging
import time
from typing import Any from typing import Any
from urllib.parse import urlparse
import httpx import httpx
@ -17,6 +20,28 @@ from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _is_safe_url(url: str) -> bool:
"""Check if URL is safe (not pointing to private/internal networks)."""
try:
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
return False
hostname = parsed.hostname
if not hostname:
return False
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
return False
try:
ip = ipaddress.ip_address(hostname)
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
return False
except ValueError:
pass
return True
except Exception:
return False
class FeishuKBAdapter(KBAdapter): class FeishuKBAdapter(KBAdapter):
"""飞书知识库适配器 """飞书知识库适配器
@ -51,8 +76,11 @@ class FeishuKBAdapter(KBAdapter):
self._app_id = app_id self._app_id = app_id
self._app_secret = app_secret self._app_secret = app_secret
self._base_url = base_url.rstrip("/") self._base_url = base_url.rstrip("/")
if not _is_safe_url(self._base_url):
raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.")
self._space_ids = space_ids or [] self._space_ids = space_ids or []
self._access_token: str | None = None self._access_token: str | None = None
self._token_expiry: float = 0.0
def _make_client(self) -> httpx.AsyncClient: def _make_client(self) -> httpx.AsyncClient:
"""创建飞书 API HTTP 客户端""" """创建飞书 API HTTP 客户端"""
@ -67,7 +95,7 @@ class FeishuKBAdapter(KBAdapter):
async def _get_access_token(self) -> str | None: async def _get_access_token(self) -> str | None:
"""获取飞书 tenant_access_token""" """获取飞书 tenant_access_token"""
if self._access_token: if self._access_token and time.time() < self._token_expiry:
return self._access_token return self._access_token
client = self._get_client() client = self._get_client()
@ -83,6 +111,8 @@ class FeishuKBAdapter(KBAdapter):
data = resp.json() data = resp.json()
if data.get("code") == 0: if data.get("code") == 0:
self._access_token = data.get("tenant_access_token") self._access_token = data.get("tenant_access_token")
expire_seconds = data.get("expire", 7200)
self._token_expiry = time.time() + expire_seconds - 300 # Refresh 5 minutes early
# 重建客户端以携带 token # 重建客户端以携带 token
await self.close() await self.close()
return self._access_token return self._access_token

View File

@ -6,8 +6,10 @@
from __future__ import annotations from __future__ import annotations
import ipaddress
import logging import logging
from typing import Any from typing import Any
from urllib.parse import urlparse
import httpx import httpx
@ -17,6 +19,31 @@ from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _is_safe_url(url: str) -> bool:
"""Check if URL is safe (not pointing to private/internal networks)."""
try:
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
return False
hostname = parsed.hostname
if not hostname:
return False
# Block common internal hostnames
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
return False
# Try to resolve as IP and check for private ranges
try:
ip = ipaddress.ip_address(hostname)
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
return False
except ValueError:
# Not an IP address, that's OK (it's a domain name)
pass
return True
except Exception:
return False
class GenericHTTPAdapter(KBAdapter): class GenericHTTPAdapter(KBAdapter):
"""通用 HTTP 知识库适配器 """通用 HTTP 知识库适配器
@ -53,6 +80,8 @@ class GenericHTTPAdapter(KBAdapter):
timeout=timeout, timeout=timeout,
) )
self._endpoint_url = endpoint_url.rstrip("/") self._endpoint_url = endpoint_url.rstrip("/")
if not _is_safe_url(self._endpoint_url):
raise ValueError(f"Unsafe endpoint_url: {self._endpoint_url}. Private/internal URLs are not allowed.")
self._auth_config = auth_config or {} self._auth_config = auth_config or {}
self._extra_headers = headers or {} self._extra_headers = headers or {}

View File

@ -10,6 +10,7 @@ from sqlalchemy import text
from agentkit.memory.base import Memory, MemoryItem from agentkit.memory.base import Memory, MemoryItem
from agentkit.memory.embedder import Embedder from agentkit.memory.embedder import Embedder
from agentkit.utils.vector_math import compute_cosine_similarity
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -123,7 +124,7 @@ class EpisodicMemory(Memory):
if row_embedding is None: if row_embedding is None:
return None return None
cosine = self._compute_cosine_similarity(query_embedding, row_embedding) cosine = compute_cosine_similarity(query_embedding, row_embedding)
if cosine < 0.1: if cosine < 0.1:
return None return None
@ -165,7 +166,7 @@ class EpisodicMemory(Memory):
entry_embedding = entry.embedding entry_embedding = entry.embedding
if entry_embedding is None: if entry_embedding is None:
continue continue
cosine = self._compute_cosine_similarity(query_embedding, entry_embedding) cosine = compute_cosine_similarity(query_embedding, entry_embedding)
if cosine > best_score: if cosine > best_score:
best_score = cosine best_score = cosine
best_item = entry best_item = entry
@ -260,7 +261,7 @@ class EpisodicMemory(Memory):
time_decay_score = (row.get("quality_score") or 0.5) * decay time_decay_score = (row.get("quality_score") or 0.5) * decay
if row_embedding is not None: if row_embedding is not None:
cosine_sim = self._compute_cosine_similarity(query_embedding, row_embedding) cosine_sim = compute_cosine_similarity(query_embedding, row_embedding)
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
else: else:
score = time_decay_score score = time_decay_score
@ -327,7 +328,7 @@ class EpisodicMemory(Memory):
# 混合评分alpha * cosine + (1 - alpha) * time_decay # 混合评分alpha * cosine + (1 - alpha) * time_decay
if self._embedder and query_embedding is not None and entry.embedding is not None: if self._embedder and query_embedding is not None and entry.embedding is not None:
cosine_sim = self._compute_cosine_similarity(query_embedding, entry.embedding) cosine_sim = compute_cosine_similarity(query_embedding, entry.embedding)
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
else: else:
score = time_decay_score score = time_decay_score
@ -375,20 +376,3 @@ class EpisodicMemory(Memory):
await db.rollback() await db.rollback()
logger.error(f"Failed to delete episodic memory: {e}") logger.error(f"Failed to delete episodic memory: {e}")
return False return False
@staticmethod
def _compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
"""计算两个向量的余弦相似度"""
if len(vec_a) != len(vec_b):
logger.warning(
f"Vector dimension mismatch: {len(vec_a)} vs {len(vec_b)}"
)
return 0.0
if not vec_a:
return 0.0
dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
magnitude_a = sum(a**2 for a in vec_a) ** 0.5
magnitude_b = sum(b**2 for b in vec_b) ** 0.5
if magnitude_a == 0.0 or magnitude_b == 0.0:
return 0.0
return dot_product / (magnitude_a * magnitude_b)

View File

@ -10,10 +10,13 @@ from __future__ import annotations
import json import json
import logging import logging
import re
import uuid import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any from typing import Any
_SAFE_TABLE_NAME_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
from agentkit.memory.chunking import Chunk, StructuralChunker, TextChunker from agentkit.memory.chunking import Chunk, StructuralChunker, TextChunker
from agentkit.memory.document_loader import Document as LoaderDocument from agentkit.memory.document_loader import Document as LoaderDocument
from agentkit.memory.embedder import Embedder from agentkit.memory.embedder import Embedder
@ -23,6 +26,7 @@ from agentkit.memory.knowledge_base import (
QueryResult, QueryResult,
SourceInfo, SourceInfo,
) )
from agentkit.utils.vector_math import compute_cosine_similarity
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -70,6 +74,8 @@ class LocalRAGService:
self._chunk_size = chunk_size self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap self._chunk_overlap = chunk_overlap
self._table_name = table_name self._table_name = table_name
if not _SAFE_TABLE_NAME_PATTERN.match(self._table_name):
raise ValueError(f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*")
self._pgvector_enabled = pgvector_enabled self._pgvector_enabled = pgvector_enabled
self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap) self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
self._structural_chunker = StructuralChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap) self._structural_chunker = StructuralChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
@ -335,7 +341,7 @@ class LocalRAGService:
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
continue continue
cosine = self._compute_cosine_similarity(query_embedding, stored_embedding) cosine = compute_cosine_similarity(query_embedding, stored_embedding)
if cosine < 0.1: if cosine < 0.1:
continue continue
@ -359,18 +365,6 @@ class LocalRAGService:
candidates.sort(key=lambda x: x.score, reverse=True) candidates.sort(key=lambda x: x.score, reverse=True)
return candidates[:top_k] return candidates[:top_k]
@staticmethod
def _compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
"""计算两个向量的余弦相似度"""
if len(vec_a) != len(vec_b) or not vec_a:
return 0.0
dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
magnitude_a = sum(a**2 for a in vec_a) ** 0.5
magnitude_b = sum(b**2 for b in vec_b) ** 0.5
if magnitude_a == 0.0 or magnitude_b == 0.0:
return 0.0
return dot_product / (magnitude_a * magnitude_b)
class InMemoryLocalRAGService: class InMemoryLocalRAGService:
"""基于内存的本地 RAG 服务 """基于内存的本地 RAG 服务
@ -447,7 +441,7 @@ class InMemoryLocalRAGService:
candidates = [] candidates = []
for chunk_id, chunk_data in self._chunks.items(): for chunk_id, chunk_data in self._chunks.items():
stored_embedding = chunk_data["embedding"] stored_embedding = chunk_data["embedding"]
cosine = self._compute_cosine_similarity(query_embedding, stored_embedding) cosine = compute_cosine_similarity(query_embedding, stored_embedding)
if cosine < 0.1: if cosine < 0.1:
continue continue
@ -511,15 +505,3 @@ class InMemoryLocalRAGService:
source_doc_id=doc.doc_id, source_doc_id=doc.doc_id,
metadata=doc.metadata, metadata=doc.metadata,
) )
@staticmethod
def _compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
"""计算两个向量的余弦相似度"""
if len(vec_a) != len(vec_b) or not vec_a:
return 0.0
dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
magnitude_a = sum(a**2 for a in vec_a) ** 0.5
magnitude_b = sum(b**2 for b in vec_b) ** 0.5
if magnitude_a == 0.0 or magnitude_b == 0.0:
return 0.0
return dot_product / (magnitude_a * magnitude_b)

View File

@ -457,6 +457,23 @@ async def list_path_optimizations(
@router.websocket("/evolution-dashboard/ws") @router.websocket("/evolution-dashboard/ws")
async def evolution_dashboard_ws(websocket: WebSocket): async def evolution_dashboard_ws(websocket: WebSocket):
"""自进化仪表盘实时更新 WebSocket""" """自进化仪表盘实时更新 WebSocket"""
# Authentication - check api_key
configured_api_key: str | None = None
if hasattr(websocket.app.state, "server_config") and websocket.app.state.server_config:
configured_api_key = websocket.app.state.server_config.api_key
if configured_api_key is None and hasattr(websocket.app.state, "api_key"):
configured_api_key = websocket.app.state.api_key
if configured_api_key:
provided = websocket.query_params.get("api_key")
if provided != configured_api_key:
await websocket.accept()
await websocket.send_json(
{"type": "error", "data": {"message": "Invalid or missing api_key"}}
)
await websocket.close(code=4001, reason="Invalid or missing api_key")
return
await websocket.accept() await websocket.accept()
_ws_connections.append(websocket) _ws_connections.append(websocket)

View File

@ -15,6 +15,8 @@ logger = logging.getLogger(__name__)
router = APIRouter(tags=["kb-management"]) router = APIRouter(tags=["kb-management"])
MAX_UPLOAD_SIZE = 50 * 1024 * 1024 # 50MB
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# In-memory Knowledge Source Store # In-memory Knowledge Source Store
@ -183,14 +185,18 @@ async def upload_document(
try: try:
from agentkit.memory.document_loader import DocumentLoader from agentkit.memory.document_loader import DocumentLoader
content = await file.read() content = await file.read(MAX_UPLOAD_SIZE + 1)
if len(content) > MAX_UPLOAD_SIZE:
raise HTTPException(status_code=413, detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE // (1024*1024)}MB")
loader = DocumentLoader() loader = DocumentLoader()
doc = loader.load_bytes(content, file.filename) doc = loader.load_bytes(content, file.filename)
# Estimate chunks based on content length (rough approximation) # Estimate chunks based on content length (rough approximation)
chunks = max(1, len(doc.content) // 500) chunks = max(1, len(doc.content) // 500)
except ImportError: except ImportError:
# DocumentLoader not available, use basic estimation # DocumentLoader not available, use basic estimation
content = await file.read() content = await file.read(MAX_UPLOAD_SIZE + 1)
if len(content) > MAX_UPLOAD_SIZE:
raise HTTPException(status_code=413, detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE // (1024*1024)}MB")
chunks = max(1, len(content) // 500) chunks = max(1, len(content) // 500)
except Exception as e: except Exception as e:
logger.warning(f"Document parsing failed: {e}") logger.warning(f"Document parsing failed: {e}")

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import logging import logging
import re
import uuid import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any from typing import Any
@ -39,6 +40,7 @@ class WorkflowStore:
self._executions: dict[str, WorkflowExecution] = {} self._executions: dict[str, WorkflowExecution] = {}
self._max_workflows = max_workflows self._max_workflows = max_workflows
self._max_executions = max_executions self._max_executions = max_executions
self._approval_events: dict[str, asyncio.Event] = {} # key: f"{execution_id}:{stage_name}"
def save(self, workflow: WorkflowDefinition) -> WorkflowDefinition: def save(self, workflow: WorkflowDefinition) -> WorkflowDefinition:
workflow.updated_at = datetime.now(timezone.utc).isoformat() workflow.updated_at = datetime.now(timezone.utc).isoformat()
@ -226,31 +228,70 @@ async def _execute_workflow(
try: try:
if stage.type == "approval": if stage.type == "approval":
# Pause execution and wait for approval # Pause execution and wait for approval via asyncio.Event
event_key = f"{execution.execution_id}:{stage_name}"
approval_event = asyncio.Event()
_store._approval_events[event_key] = approval_event
execution.status = "paused" execution.status = "paused"
execution.current_stage = stage_name
_store.update_execution( _store.update_execution(
execution.execution_id, execution.execution_id,
status="paused", status="paused",
current_stage=stage_name,
) )
await _broadcast_ws({ await _broadcast_ws({
"event": "approval_required", "event": "approval_required",
"execution_id": execution.execution_id, "execution_id": execution.execution_id,
"stage": stage_name, "stage": stage_name,
}) })
# In a real implementation, this would wait for external approval
# For now, we simulate auto-approval after a brief pause # Wait for approval with timeout
await asyncio.sleep(0.1) try:
execution.stage_results[stage_name] = { approval_timeout = stage.config.get("approval_timeout", 3600)
"status": "approved", await asyncio.wait_for(approval_event.wait(), timeout=approval_timeout)
"approver": "auto", # Check if execution was cancelled/rejected while waiting
"comment": "自动审批通过", if execution.status == "cancelled":
} await _broadcast_ws({
execution.status = "running" "event": "stage_failed",
_store.update_execution( "execution_id": execution.execution_id,
execution.execution_id, "stage": stage_name,
status="running", "error": "Approval rejected",
stage_results=execution.stage_results, })
) return
# Approval was granted — the /approve endpoint already set stage_results
# Only update status to running if not already set
if execution.status != "running":
execution.status = "running"
_store.update_execution(
execution.execution_id,
status="running",
)
except asyncio.TimeoutError:
execution.stage_results[stage_name] = {
"status": "timeout",
"approver": "none",
"comment": "审批超时",
}
execution.status = "failed"
execution.error = f"Approval timeout for stage {stage_name}"
execution.completed_at = datetime.now(timezone.utc).isoformat()
_store.update_execution(
execution.execution_id,
status="failed",
error=execution.error,
completed_at=execution.completed_at,
stage_results=execution.stage_results,
)
await _broadcast_ws({
"event": "stage_failed",
"execution_id": execution.execution_id,
"stage": stage_name,
"error": "Approval timeout",
})
return
finally:
_store._approval_events.pop(event_key, None)
elif stage.type == "condition": elif stage.type == "condition":
# Evaluate condition expression # Evaluate condition expression
condition_expr = stage.config.get("expression", "") condition_expr = stage.config.get("expression", "")
@ -318,23 +359,60 @@ async def _execute_workflow(
}) })
_SAFE_VAR_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
_SAFE_OPERATORS = {"==", "!=", ">", "<", ">=", "<="}
def _evaluate_condition(expression: str, variables: dict[str, Any]) -> bool: def _evaluate_condition(expression: str, variables: dict[str, Any]) -> bool:
"""Simple condition evaluation.""" """Evaluate a condition expression safely."""
expression = expression.strip()
if not expression: if not expression:
return True return True
if "==" in expression:
parts = expression.split("==", 1) # Try each operator (longer operators first to avoid partial matches)
left = variables.get(parts[0].strip(), parts[0].strip()) for op in sorted(_SAFE_OPERATORS, key=len, reverse=True):
right = parts[1].strip().strip("'\"") if op in expression:
return str(left) == right parts = expression.split(op, 1)
elif "!=" in expression: if len(parts) != 2:
parts = expression.split("!=", 1) continue
left = variables.get(parts[0].strip(), parts[0].strip()) left = parts[0].strip()
right = parts[1].strip().strip("'\"") right = parts[1].strip()
return str(left) != right
else: # Validate variable names
if left and not _SAFE_VAR_PATTERN.match(left):
raise ValueError(f"Invalid variable name in condition: {left}")
left_val = variables.get(left, left)
# Strip quotes from right side if present
if right.startswith('"') and right.endswith('"'):
right_val = right[1:-1]
elif right.startswith("'") and right.endswith("'"):
right_val = right[1:-1]
elif right and _SAFE_VAR_PATTERN.match(right):
right_val = variables.get(right, right)
else:
right_val = right
# Compare based on operator
if op == "==":
return str(left_val) == str(right_val)
if op == "!=":
return str(left_val) != str(right_val)
if op == ">":
return float(left_val) > float(right_val)
if op == "<":
return float(left_val) < float(right_val)
if op == ">=":
return float(left_val) >= float(right_val)
if op == "<=":
return float(left_val) <= float(right_val)
# Boolean check for variable existence
if _SAFE_VAR_PATTERN.match(expression):
return bool(variables.get(expression)) return bool(variables.get(expression))
raise ValueError(f"Invalid condition expression: {expression}")
async def _broadcast_ws(message: dict[str, Any]) -> None: async def _broadcast_ws(message: dict[str, Any]) -> None:
"""Broadcast a message to all WebSocket subscribers.""" """Broadcast a message to all WebSocket subscribers."""
@ -487,12 +565,12 @@ async def approve_execution(
status="running", status="running",
stage_results=execution.stage_results, stage_results=execution.stage_results,
) )
# Resume execution # Resume the waiting execution by setting the approval event
workflow = store.get(execution.workflow_id) stage_name = execution.current_stage
if workflow: if stage_name:
asyncio.create_task( event_key = f"{execution_id}:{stage_name}"
_execute_workflow(workflow, execution, execution.variables, store=store) if event_key in store._approval_events:
) store._approval_events[event_key].set()
else: else:
execution.status = "cancelled" execution.status = "cancelled"
execution.completed_at = datetime.now(timezone.utc).isoformat() execution.completed_at = datetime.now(timezone.utc).isoformat()
@ -508,6 +586,12 @@ async def approve_execution(
completed_at=execution.completed_at, completed_at=execution.completed_at,
stage_results=execution.stage_results, stage_results=execution.stage_results,
) )
# Set the approval event so the waiting coroutine can observe the cancelled state
stage_name = execution.current_stage
if stage_name:
event_key = f"{execution_id}:{stage_name}"
if event_key in store._approval_events:
store._approval_events[event_key].set()
return execution.model_dump() return execution.model_dump()

View File

@ -18,6 +18,8 @@ from dataclasses import dataclass
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# 自动应答规则:(prompt_pattern, response) # 自动应答规则:(prompt_pattern, response)
# WARNING: auto_respond is disabled by default for safety.
# Enable it only when you explicitly want automatic yes/confirm responses.
_AUTO_RESPOND_RULES: list[tuple[str, str]] = [ _AUTO_RESPOND_RULES: list[tuple[str, str]] = [
(r"\[y/N\]\s*$", "y"), (r"\[y/N\]\s*$", "y"),
(r"\[Y/n\]\s*$", "y"), (r"\[Y/n\]\s*$", "y"),
@ -61,7 +63,7 @@ class PTYSession:
def __init__( def __init__(
self, self,
auto_respond: bool = True, auto_respond: bool = False,
custom_rules: list[tuple[str, str]] | None = None, custom_rules: list[tuple[str, str]] | None = None,
default_timeout: float = 30.0, default_timeout: float = 30.0,
buffer_size: int = 4096, buffer_size: int = 4096,

View File

@ -9,6 +9,8 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import os import os
import re
import shlex
import time import time
from typing import Any, Callable, Awaitable from typing import Any, Callable, Awaitable
@ -21,6 +23,8 @@ logger = logging.getLogger(__name__)
# 安全白名单:这些命令前缀不需要确认 # 安全白名单:这些命令前缀不需要确认
_SAFE_COMMAND_PREFIXES: tuple[str, ...] = ( _SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
"cd",
"export",
"ls", "ls",
"cat", "cat",
"head", "head",
@ -48,6 +52,7 @@ _SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
"sort", "sort",
"uniq", "uniq",
"diff", "diff",
"sleep",
"git status", "git status",
"git log", "git log",
"git diff", "git diff",
@ -101,6 +106,9 @@ _DANGEROUS_PATTERNS: tuple[str, ...] = (
) )
_SHELL_OPERATORS = re.compile(r'[|;&]|&&|\|\||\$\(|`')
class ShellTool(Tool): class ShellTool(Tool):
"""Shell 命令执行工具 """Shell 命令执行工具
@ -364,18 +372,39 @@ class ShellTool(Tool):
""" """
command_stripped = command.strip() command_stripped = command.strip()
# 白名单检查 # Check for shell operators that chain commands (always dangerous)
for prefix in _SAFE_COMMAND_PREFIXES: if _SHELL_OPERATORS.search(command_stripped):
if command_stripped.startswith(prefix): return True
return False
# 危险模式检查 # Parse the actual binary being invoked
try:
tokens = shlex.split(command_stripped)
if not tokens:
return True
binary = os.path.basename(tokens[0])
except ValueError:
# Unparsable command - treat as dangerous
return True
# Whitelist check: first try full command prefix match, then binary-only match
for prefix in _SAFE_COMMAND_PREFIXES:
prefix_stripped = prefix.lower().strip()
if " " in prefix_stripped:
# Compound prefix like "git status" - match against full command
if command_stripped.lower().startswith(prefix_stripped):
return False
else:
# Simple prefix - match against binary name only
if binary.lower().startswith(prefix_stripped):
return False
# Dangerous pattern check
command_lower = command_stripped.lower() command_lower = command_stripped.lower()
for pattern in _DANGEROUS_PATTERNS: for pattern in _DANGEROUS_PATTERNS:
if pattern in command_lower: if pattern in command_lower:
return True return True
return False return True # Unknown commands are dangerous by default
async def _request_confirmation(self, command: str) -> bool: async def _request_confirmation(self, command: str) -> bool:
"""请求人工确认危险命令 """请求人工确认危险命令

View File

@ -9,6 +9,8 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import os import os
import re
import shlex
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
@ -17,6 +19,8 @@ from agentkit.tools.output_parser import OutputParser, ParsedOutput
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_ENV_KEY_PATTERN = re.compile(r'^[A-Za-z_][A-Za-z0-9_]*$')
@dataclass @dataclass
class CommandRecord: class CommandRecord:
@ -190,15 +194,13 @@ class TerminalSession:
# 注入 cd # 注入 cd
if self._cwd: if self._cwd:
# 使用 shlex.quote 风格的简单转义 parts.append(f"cd {shlex.quote(self._cwd)}")
cwd_escaped = self._cwd.replace("'", "'\\''")
parts.append(f"cd '{cwd_escaped}'")
# 注入环境变量 # 注入环境变量
for key, value in self._env.items(): for key, value in self._env.items():
# 跳过 os.environ 中已有的且值未变的变量,减少命令长度 if not _ENV_KEY_PATTERN.match(key):
val_escaped = value.replace("'", "'\\''") continue # Skip invalid env key names
parts.append(f"export {key}='{val_escaped}'") parts.append(f"export {shlex.quote(key)}={shlex.quote(value)}")
parts.append(command) parts.append(command)
return " && ".join(parts) return " && ".join(parts)

View File

@ -0,0 +1 @@
"""AgentKit utility modules."""

View File

@ -0,0 +1,30 @@
"""Shared vector math utilities."""
from __future__ import annotations
import math
def compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
    """Compute cosine similarity between two vectors.

    Args:
        vec_a: First vector.
        vec_b: Second vector.

    Returns:
        Cosine similarity in the closed interval [-1.0, 1.0]. Returns 0.0
        when the vectors differ in length, are empty, or when either one
        has zero magnitude.
    """
    if len(vec_a) != len(vec_b) or not vec_a:
        return 0.0

    norm_a = math.sqrt(sum(a * a for a in vec_a))
    norm_b = math.sqrt(sum(b * b for b in vec_b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0

    dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
    similarity = dot_product / (norm_a * norm_b)

    # Floating-point rounding can push the quotient marginally outside
    # [-1, 1] for (anti-)parallel vectors; clamp so the documented range
    # holds. Plain comparisons are used instead of min()/max() so that a
    # NaN result is passed through unchanged rather than turned into 1.0.
    if similarity > 1.0:
        return 1.0
    if similarity < -1.0:
        return -1.0
    return similarity

View File

@ -433,12 +433,34 @@ async def test_workflow_with_approval():
assert execution.status == "pending" assert execution.status == "pending"
assert execution.execution_id assert execution.execution_id
# 5. Execute workflow (runs in background) # 5. Execute workflow in background (approval stage will wait for event)
from agentkit.server.routes.workflows import _execute_workflow from agentkit.server.routes.workflows import _execute_workflow
await _execute_workflow(workflow, execution, variables={}, store=store) # Use a short approval timeout for testing
workflow.stages[1].config["approval_timeout"] = 5
# 6. Verify execution completed (auto-approval in test mode) async def _approve_after_pause():
"""Wait for execution to pause, then approve."""
for _ in range(100):
await asyncio.sleep(0.05)
updated = store.get_execution(execution.execution_id)
if updated and updated.status == "paused":
break
# Trigger approval
event_key = f"{execution.execution_id}:human_review"
if event_key in store._approval_events:
execution.stage_results["human_review"] = {
"status": "approved",
"approver": "test_user",
"comment": "Auto-approved in test",
}
store._approval_events[event_key].set()
approve_task = asyncio.create_task(_approve_after_pause())
await _execute_workflow(workflow, execution, variables={}, store=store)
await approve_task
# 6. Verify execution completed
updated = store.get_execution(execution.execution_id) updated = store.get_execution(execution.execution_id)
assert updated is not None assert updated is not None
assert updated.status == "completed" assert updated.status == "completed"
@ -452,7 +474,7 @@ async def test_workflow_with_approval():
approval_result = updated.stage_results["human_review"] approval_result = updated.stage_results["human_review"]
assert approval_result.get("status") in ("approved", "completed") assert approval_result.get("status") in ("approved", "completed")
# 9. Test manual approval flow # 9. Test second workflow with approval
workflow2 = WorkflowDefinition( workflow2 = WorkflowDefinition(
workflow_id="wf-manual-approval", workflow_id="wf-manual-approval",
name="手动审批流程", name="手动审批流程",
@ -468,6 +490,7 @@ async def test_workflow_with_approval():
agent="reviewer", agent="reviewer",
action="approve", action="approve",
type="approval", type="approval",
config={"approval_timeout": 5},
depends_on=["step1"], depends_on=["step1"],
), ),
WorkflowStage( WorkflowStage(
@ -481,28 +504,30 @@ async def test_workflow_with_approval():
) )
store.save(workflow2) store.save(workflow2)
# Simulate manual approval via API
execution2 = store.create_execution(workflow2.workflow_id) execution2 = store.create_execution(workflow2.workflow_id)
execution2.status = "paused"
execution2.current_stage = "approval_step"
store.update_execution(
execution2.execution_id,
status="paused",
current_stage="approval_step",
)
# Approve async def _approve2_after_pause():
execution2.stage_results["approval_step"] = { for _ in range(100):
"status": "approved", await asyncio.sleep(0.05)
"approver": "user", updated2 = store.get_execution(execution2.execution_id)
"comment": "LGTM", if updated2 and updated2.status == "paused":
} break
execution2.status = "running" event_key2 = f"{execution2.execution_id}:approval_step"
store.update_execution( if event_key2 in store._approval_events:
execution2.execution_id, execution2.stage_results["approval_step"] = {
status="running", "status": "approved",
stage_results=execution2.stage_results, "approver": "user",
) "comment": "LGTM",
}
store._approval_events[event_key2].set()
approve_task2 = asyncio.create_task(_approve2_after_pause())
await _execute_workflow(workflow2, execution2, variables={}, store=store)
await approve_task2
updated2 = store.get_execution(execution2.execution_id)
assert updated2 is not None
assert updated2.status == "completed"
# Verify approval was recorded # Verify approval was recorded
paused_exec = store.get_execution(execution2.execution_id) paused_exec = store.get_execution(execution2.execution_id)
@ -1014,7 +1039,27 @@ async def test_multi_source_rag_with_workflow(local_rag, mock_embedder):
execution = store.create_execution(workflow.workflow_id) execution = store.create_execution(workflow.workflow_id)
from agentkit.server.routes.workflows import _execute_workflow from agentkit.server.routes.workflows import _execute_workflow
# Set short approval timeout and handle approval
workflow.stages[1].config["approval_timeout"] = 5
async def _approve_kb_review():
    """Wait for the execution to pause, then approve the review_findings stage.

    Polls the store up to 100 times, sleeping 0.05 s before each check, and
    stops early once the execution reports status "paused". Afterwards, if an
    approval event is registered under "<execution_id>:review_findings", an
    approved result is recorded on the execution and the event is set.
    """
    polls_remaining = 100
    while polls_remaining:
        polls_remaining -= 1
        await asyncio.sleep(0.05)
        snapshot = store.get_execution(execution.execution_id)
        if snapshot and snapshot.status == "paused":
            break

    approval_key = f"{execution.execution_id}:review_findings"
    if approval_key not in store._approval_events:
        # Nothing is waiting on this stage; leave the execution untouched.
        return

    execution.stage_results["review_findings"] = {
        "status": "approved",
        "approver": "test_user",
        "comment": "Approved",
    }
    store._approval_events[approval_key].set()
approve_task = asyncio.create_task(_approve_kb_review())
await _execute_workflow(workflow, execution, variables={}, store=store) await _execute_workflow(workflow, execution, variables={}, store=store)
await approve_task
updated = store.get_execution(execution.execution_id) updated = store.get_execution(execution.execution_id)
assert updated.status == "completed" assert updated.status == "completed"

View File

@ -9,10 +9,10 @@ import pytest
from agentkit.evolution.experience_schema import EvolutionMetrics, TaskExperience from agentkit.evolution.experience_schema import EvolutionMetrics, TaskExperience
from agentkit.evolution.experience_store import ( from agentkit.evolution.experience_store import (
InMemoryExperienceStore, InMemoryExperienceStore,
_compute_cosine_similarity,
_parse_time_window, _parse_time_window,
) )
from agentkit.memory.embedder import MockEmbedder from agentkit.memory.embedder import MockEmbedder
from agentkit.utils.vector_math import compute_cosine_similarity
# ── Fixtures ────────────────────────────────────────────── # ── Fixtures ──────────────────────────────────────────────
@ -136,23 +136,23 @@ class TestEvolutionMetrics:
class TestHelperFunctions: class TestHelperFunctions:
def test_cosine_similarity_identical(self): def test_cosine_similarity_identical(self):
vec = [1.0, 0.0, 0.0] vec = [1.0, 0.0, 0.0]
assert _compute_cosine_similarity(vec, vec) == pytest.approx(1.0) assert compute_cosine_similarity(vec, vec) == pytest.approx(1.0)
def test_cosine_similarity_orthogonal(self): def test_cosine_similarity_orthogonal(self):
a = [1.0, 0.0] a = [1.0, 0.0]
b = [0.0, 1.0] b = [0.0, 1.0]
assert _compute_cosine_similarity(a, b) == pytest.approx(0.0) assert compute_cosine_similarity(a, b) == pytest.approx(0.0)
def test_cosine_similarity_opposite(self): def test_cosine_similarity_opposite(self):
a = [1.0, 0.0] a = [1.0, 0.0]
b = [-1.0, 0.0] b = [-1.0, 0.0]
assert _compute_cosine_similarity(a, b) == pytest.approx(-1.0) assert compute_cosine_similarity(a, b) == pytest.approx(-1.0)
def test_cosine_similarity_empty(self): def test_cosine_similarity_empty(self):
assert _compute_cosine_similarity([], []) == 0.0 assert compute_cosine_similarity([], []) == 0.0
def test_cosine_similarity_mismatched_dims(self): def test_cosine_similarity_mismatched_dims(self):
assert _compute_cosine_similarity([1.0], [1.0, 2.0]) == 0.0 assert compute_cosine_similarity([1.0], [1.0, 2.0]) == 0.0
def test_parse_time_window_hours(self): def test_parse_time_window_hours(self):
delta = _parse_time_window("24h") delta = _parse_time_window("24h")

View File

@ -1,6 +1,7 @@
"""Tests for KnowledgeBase adapters — 飞书、Confluence、通用 HTTP 适配器""" """Tests for KnowledgeBase adapters — 飞书、Confluence、通用 HTTP 适配器"""
import pytest import pytest
import time
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo, KnowledgeBase from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo, KnowledgeBase
@ -64,7 +65,7 @@ class TestKnowledgeBaseProtocol:
assert isinstance(adapter2, KnowledgeBase) assert isinstance(adapter2, KnowledgeBase)
adapter3 = GenericHTTPAdapter( adapter3 = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb", endpoint_url="https://example.com/api/kb",
) )
assert isinstance(adapter3, KnowledgeBase) assert isinstance(adapter3, KnowledgeBase)
@ -256,6 +257,8 @@ class TestFeishuKBAdapterSearch:
async def test_search_success(self, adapter): async def test_search_success(self, adapter):
# Mock authentication # Mock authentication
adapter._access_token = "t-xxx" adapter._access_token = "t-xxx"
adapter._token_expiry = time.time() + 7200
adapter._token_expiry = time.time() + 7200
mock_resp = MagicMock() mock_resp = MagicMock()
mock_resp.status_code = 200 mock_resp.status_code = 200
@ -307,6 +310,7 @@ class TestFeishuKBAdapterSearch:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_search_api_error(self, adapter): async def test_search_api_error(self, adapter):
adapter._access_token = "t-xxx" adapter._access_token = "t-xxx"
adapter._token_expiry = time.time() + 7200
mock_resp = MagicMock() mock_resp = MagicMock()
mock_resp.status_code = 200 mock_resp.status_code = 200
@ -328,6 +332,7 @@ class TestFeishuKBAdapterSearch:
import httpx import httpx
adapter._access_token = "t-xxx" adapter._access_token = "t-xxx"
adapter._token_expiry = time.time() + 7200
mock_resp = MagicMock() mock_resp = MagicMock()
mock_resp.status_code = 500 mock_resp.status_code = 500
@ -380,6 +385,7 @@ class TestFeishuKBAdapterListSources:
async def test_list_sources_success(self): async def test_list_sources_success(self):
adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret") adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
adapter._access_token = "t-xxx" adapter._access_token = "t-xxx"
adapter._token_expiry = time.time() + 7200
mock_resp = MagicMock() mock_resp = MagicMock()
mock_resp.status_code = 200 mock_resp.status_code = 200
@ -420,6 +426,7 @@ class TestFeishuKBAdapterGetDocument:
async def test_get_document_success(self): async def test_get_document_success(self):
adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret") adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
adapter._access_token = "t-xxx" adapter._access_token = "t-xxx"
adapter._token_expiry = time.time() + 7200
mock_resp = MagicMock() mock_resp = MagicMock()
mock_resp.status_code = 200 mock_resp.status_code = 200
@ -450,6 +457,7 @@ class TestFeishuKBAdapterGetDocument:
async def test_get_document_not_found(self): async def test_get_document_not_found(self):
adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret") adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
adapter._access_token = "t-xxx" adapter._access_token = "t-xxx"
adapter._token_expiry = time.time() + 7200
mock_resp = MagicMock() mock_resp = MagicMock()
mock_resp.status_code = 200 mock_resp.status_code = 200
@ -736,23 +744,23 @@ class TestGenericHTTPAdapterInit:
def test_basic_init(self): def test_basic_init(self):
adapter = GenericHTTPAdapter( adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb", endpoint_url="https://example.com/api/kb",
) )
assert adapter._endpoint_url == "http://localhost:8000/api/kb" assert adapter._endpoint_url == "https://example.com/api/kb"
assert adapter._auth_config == {} assert adapter._auth_config == {}
assert adapter._extra_headers == {} assert adapter._extra_headers == {}
assert adapter._source_type == "generic_http" assert adapter._source_type == "generic_http"
def test_init_with_auth_bearer(self): def test_init_with_auth_bearer(self):
adapter = GenericHTTPAdapter( adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb/", endpoint_url="https://example.com/api/kb/",
auth_config={"type": "bearer", "token": "sk-test"}, auth_config={"type": "bearer", "token": "sk-test"},
headers={"X-Custom": "value"}, headers={"X-Custom": "value"},
source_id="my-kb", source_id="my-kb",
source_name="My KB", source_name="My KB",
timeout=60, timeout=60,
) )
assert adapter._endpoint_url == "http://localhost:8000/api/kb" assert adapter._endpoint_url == "https://example.com/api/kb"
assert adapter._auth_config["type"] == "bearer" assert adapter._auth_config["type"] == "bearer"
assert adapter._extra_headers == {"X-Custom": "value"} assert adapter._extra_headers == {"X-Custom": "value"}
assert adapter._source_id == "my-kb" assert adapter._source_id == "my-kb"
@ -760,7 +768,7 @@ class TestGenericHTTPAdapterInit:
def test_client_bearer_auth_header(self): def test_client_bearer_auth_header(self):
adapter = GenericHTTPAdapter( adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb", endpoint_url="https://example.com/api/kb",
auth_config={"type": "bearer", "token": "sk-test"}, auth_config={"type": "bearer", "token": "sk-test"},
) )
client = adapter._make_client() client = adapter._make_client()
@ -768,7 +776,7 @@ class TestGenericHTTPAdapterInit:
def test_client_basic_auth_header(self): def test_client_basic_auth_header(self):
adapter = GenericHTTPAdapter( adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb", endpoint_url="https://example.com/api/kb",
auth_config={"type": "basic", "username": "user", "password": "pass"}, auth_config={"type": "basic", "username": "user", "password": "pass"},
) )
client = adapter._make_client() client = adapter._make_client()
@ -777,7 +785,7 @@ class TestGenericHTTPAdapterInit:
def test_client_api_key_header(self): def test_client_api_key_header(self):
adapter = GenericHTTPAdapter( adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb", endpoint_url="https://example.com/api/kb",
auth_config={"type": "api_key", "header_name": "X-API-Key", "api_key": "key123"}, auth_config={"type": "api_key", "header_name": "X-API-Key", "api_key": "key123"},
) )
client = adapter._make_client() client = adapter._make_client()
@ -790,7 +798,7 @@ class TestGenericHTTPAdapterSearch:
@pytest.fixture @pytest.fixture
def adapter(self): def adapter(self):
return GenericHTTPAdapter( return GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb", endpoint_url="https://example.com/api/kb",
auth_config={"type": "bearer", "token": "sk-test"}, auth_config={"type": "bearer", "token": "sk-test"},
) )
@ -892,7 +900,7 @@ class TestGenericHTTPAdapterIngest:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_ingest_success(self): async def test_ingest_success(self):
adapter = GenericHTTPAdapter( adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb", endpoint_url="https://example.com/api/kb",
) )
mock_resp = MagicMock() mock_resp = MagicMock()
@ -921,7 +929,7 @@ class TestGenericHTTPAdapterIngest:
import httpx import httpx
adapter = GenericHTTPAdapter( adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb", endpoint_url="https://example.com/api/kb",
) )
mock_resp = MagicMock() mock_resp = MagicMock()
@ -946,7 +954,7 @@ class TestGenericHTTPAdapterDeleteById:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_success(self): async def test_delete_success(self):
adapter = GenericHTTPAdapter( adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb", endpoint_url="https://example.com/api/kb",
) )
mock_resp = MagicMock() mock_resp = MagicMock()
@ -963,7 +971,7 @@ class TestGenericHTTPAdapterDeleteById:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_not_found(self): async def test_delete_not_found(self):
adapter = GenericHTTPAdapter( adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb", endpoint_url="https://example.com/api/kb",
) )
mock_resp = MagicMock() mock_resp = MagicMock()
@ -983,7 +991,7 @@ class TestGenericHTTPAdapterGetDocument:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_document_success(self): async def test_get_document_success(self):
adapter = GenericHTTPAdapter( adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb", endpoint_url="https://example.com/api/kb",
) )
mock_resp = MagicMock() mock_resp = MagicMock()
@ -1011,7 +1019,7 @@ class TestGenericHTTPAdapterGetDocument:
import httpx import httpx
adapter = GenericHTTPAdapter( adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb", endpoint_url="https://example.com/api/kb",
) )
mock_resp = MagicMock() mock_resp = MagicMock()
@ -1034,7 +1042,7 @@ class TestGenericHTTPAdapterHealthCheck:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_health_check_ok(self): async def test_health_check_ok(self):
adapter = GenericHTTPAdapter( adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb", endpoint_url="https://example.com/api/kb",
) )
mock_resp = MagicMock() mock_resp = MagicMock()
@ -1050,7 +1058,7 @@ class TestGenericHTTPAdapterHealthCheck:
async def test_health_check_fallback_to_root(self): async def test_health_check_fallback_to_root(self):
"""health endpoint 不存在时回退到根路径""" """health endpoint 不存在时回退到根路径"""
adapter = GenericHTTPAdapter( adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb", endpoint_url="https://example.com/api/kb",
) )
import httpx import httpx
@ -1076,7 +1084,7 @@ class TestGenericHTTPAdapterHealthCheck:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_health_check_connection_error(self): async def test_health_check_connection_error(self):
adapter = GenericHTTPAdapter( adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb", endpoint_url="https://example.com/api/kb",
) )
mock_client = AsyncMock() mock_client = AsyncMock()
@ -1092,7 +1100,7 @@ class TestGenericHTTPAdapterListSources:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_sources_success(self): async def test_list_sources_success(self):
adapter = GenericHTTPAdapter( adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb", endpoint_url="https://example.com/api/kb",
) )
mock_resp = MagicMock() mock_resp = MagicMock()
@ -1116,7 +1124,7 @@ class TestGenericHTTPAdapterListSources:
async def test_list_sources_endpoint_not_found(self): async def test_list_sources_endpoint_not_found(self):
"""sources endpoint 不存在时返回默认信息源""" """sources endpoint 不存在时返回默认信息源"""
adapter = GenericHTTPAdapter( adapter = GenericHTTPAdapter(
endpoint_url="http://localhost:8000/api/kb", endpoint_url="https://example.com/api/kb",
) )
mock_client = AsyncMock() mock_client = AsyncMock()
@ -1146,7 +1154,7 @@ class TestCrossAdapterIntegration:
username="user@test.com", username="user@test.com",
api_token="token", api_token="token",
), ),
GenericHTTPAdapter(endpoint_url="http://localhost:8000/api/kb"), GenericHTTPAdapter(endpoint_url="https://example.com/api/kb"),
] ]
for adapter in adapters: for adapter in adapters:
assert isinstance(adapter, KnowledgeBase) assert isinstance(adapter, KnowledgeBase)
@ -1166,7 +1174,7 @@ class TestCrossAdapterIntegration:
username="user@test.com", username="user@test.com",
api_token="token", api_token="token",
), ),
GenericHTTPAdapter(endpoint_url="http://localhost:8000/api/kb"), GenericHTTPAdapter(endpoint_url="https://example.com/api/kb"),
] ]
for adapter in adapters: for adapter in adapters:
assert hasattr(adapter, "search") assert hasattr(adapter, "search")
@ -1179,6 +1187,7 @@ class TestCrossAdapterIntegration:
# Feishu # Feishu
feishu = FeishuKBAdapter(app_id="cli_test", app_secret="secret") feishu = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
feishu._access_token = "t-xxx" feishu._access_token = "t-xxx"
feishu._token_expiry = time.time() + 7200
mock_resp = MagicMock() mock_resp = MagicMock()
mock_resp.status_code = 200 mock_resp.status_code = 200
mock_resp.raise_for_status = MagicMock() mock_resp.raise_for_status = MagicMock()
@ -1220,7 +1229,7 @@ class TestCrossAdapterIntegration:
assert all(isinstance(r, QueryResult) for r in results) assert all(isinstance(r, QueryResult) for r in results)
# GenericHTTP # GenericHTTP
generic = GenericHTTPAdapter(endpoint_url="http://localhost:8000/api/kb") generic = GenericHTTPAdapter(endpoint_url="https://example.com/api/kb")
mock_resp3 = MagicMock() mock_resp3 = MagicMock()
mock_resp3.status_code = 200 mock_resp3.status_code = 200
mock_resp3.raise_for_status = MagicMock() mock_resp3.raise_for_status = MagicMock()

View File

@ -12,6 +12,7 @@ from sqlalchemy.orm import DeclarativeBase
from agentkit.memory.episodic import EpisodicMemory from agentkit.memory.episodic import EpisodicMemory
from agentkit.memory.base import MemoryItem from agentkit.memory.base import MemoryItem
from agentkit.memory.embedder import MockEmbedder from agentkit.memory.embedder import MockEmbedder
from agentkit.utils.vector_math import compute_cosine_similarity
# ── 真实 SQLAlchemy 模型(用于测试) ───────────────────── # ── 真实 SQLAlchemy 模型(用于测试) ─────────────────────
@ -112,40 +113,40 @@ def _make_row_mapping(data: dict) -> _RowMapping:
class TestCosineSimilarity: class TestCosineSimilarity:
"""_compute_cosine_similarity 测试""" """compute_cosine_similarity 测试"""
def test_identical_vectors_return_one(self): def test_identical_vectors_return_one(self):
"""相同向量余弦相似度为 1""" """相同向量余弦相似度为 1"""
vec = [1.0, 0.0, 0.0] vec = [1.0, 0.0, 0.0]
assert EpisodicMemory._compute_cosine_similarity(vec, vec) == pytest.approx(1.0) assert compute_cosine_similarity(vec, vec) == pytest.approx(1.0)
def test_orthogonal_vectors_return_zero(self): def test_orthogonal_vectors_return_zero(self):
"""正交向量余弦相似度为 0""" """正交向量余弦相似度为 0"""
vec_a = [1.0, 0.0] vec_a = [1.0, 0.0]
vec_b = [0.0, 1.0] vec_b = [0.0, 1.0]
assert EpisodicMemory._compute_cosine_similarity(vec_a, vec_b) == pytest.approx(0.0) assert compute_cosine_similarity(vec_a, vec_b) == pytest.approx(0.0)
def test_opposite_vectors_return_minus_one(self): def test_opposite_vectors_return_minus_one(self):
"""相反向量余弦相似度为 -1""" """相反向量余弦相似度为 -1"""
vec_a = [1.0, 0.0] vec_a = [1.0, 0.0]
vec_b = [-1.0, 0.0] vec_b = [-1.0, 0.0]
assert EpisodicMemory._compute_cosine_similarity(vec_a, vec_b) == pytest.approx(-1.0) assert compute_cosine_similarity(vec_a, vec_b) == pytest.approx(-1.0)
def test_dimension_mismatch_returns_zero(self): def test_dimension_mismatch_returns_zero(self):
"""维度不匹配返回 0""" """维度不匹配返回 0"""
vec_a = [1.0, 2.0] vec_a = [1.0, 2.0]
vec_b = [1.0] vec_b = [1.0]
assert EpisodicMemory._compute_cosine_similarity(vec_a, vec_b) == 0.0 assert compute_cosine_similarity(vec_a, vec_b) == 0.0
def test_empty_vectors_return_zero(self): def test_empty_vectors_return_zero(self):
"""空向量返回 0""" """空向量返回 0"""
assert EpisodicMemory._compute_cosine_similarity([], []) == 0.0 assert compute_cosine_similarity([], []) == 0.0
def test_zero_vector_returns_zero(self): def test_zero_vector_returns_zero(self):
"""零向量返回 0""" """零向量返回 0"""
vec_a = [0.0, 0.0] vec_a = [0.0, 0.0]
vec_b = [1.0, 2.0] vec_b = [1.0, 2.0]
assert EpisodicMemory._compute_cosine_similarity(vec_a, vec_b) == 0.0 assert compute_cosine_similarity(vec_a, vec_b) == 0.0
# ── MockEmbedder 测试 ─────────────────────────────────── # ── MockEmbedder 测试 ───────────────────────────────────

View File

@ -25,7 +25,7 @@ class TestPTYSessionConstruction:
def test_default_construction(self): def test_default_construction(self):
pty = PTYSession() pty = PTYSession()
assert pty.is_running is False assert pty.is_running is False
assert pty._auto_respond is True assert pty._auto_respond is False
assert pty._default_timeout == 30.0 assert pty._default_timeout == 30.0
def test_custom_construction(self): def test_custom_construction(self):