fix(security): resolve all P0/P1 findings from code review
This commit is contained in:
parent
b34f74f598
commit
9e9f1314f6
|
|
@ -496,6 +496,6 @@ class PlanExecutor:
|
|||
# 所有步骤要么完成要么跳过
|
||||
return TaskStatus.COMPLETED
|
||||
if failed > 0:
|
||||
return TaskStatus.COMPLETED # 部分成功
|
||||
return TaskStatus.PARTIALLY_COMPLETED # 部分成功
|
||||
|
||||
return TaskStatus.COMPLETED
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ class TaskStatus(str, Enum):
|
|||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
PARTIALLY_COMPLETED = "partially_completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
HANDOFF = "handoff"
|
||||
|
|
|
|||
|
|
@ -12,14 +12,18 @@ from __future__ import annotations
|
|||
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
_SAFE_TABLE_NAME_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
|
||||
|
||||
from agentkit.evolution.experience_schema import EvolutionMetrics, TaskExperience
|
||||
from agentkit.memory.embedder import Embedder
|
||||
from agentkit.utils.vector_math import compute_cosine_similarity
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -69,6 +73,8 @@ class ExperienceStore:
|
|||
self._retrieve_limit = retrieve_limit
|
||||
self._pgvector_enabled = pgvector_enabled
|
||||
self._table_name = table_name
|
||||
if not _SAFE_TABLE_NAME_PATTERN.match(self._table_name):
|
||||
raise ValueError(f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*")
|
||||
|
||||
async def record_experience(self, experience: TaskExperience) -> str:
|
||||
"""记录任务经验
|
||||
|
|
@ -193,7 +199,7 @@ class ExperienceStore:
|
|||
time_decay_score = (row.get("success_rate") or 0.5) * decay
|
||||
|
||||
if row_embedding is not None:
|
||||
cosine_sim = _compute_cosine_similarity(query_embedding, row_embedding)
|
||||
cosine_sim = compute_cosine_similarity(query_embedding, row_embedding)
|
||||
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
|
||||
else:
|
||||
score = time_decay_score
|
||||
|
|
@ -251,7 +257,7 @@ class ExperienceStore:
|
|||
time_decay_score = (entry.success_rate or 0.5) * decay
|
||||
|
||||
if self._embedder and query_embedding is not None and entry.embedding is not None:
|
||||
cosine_sim = _compute_cosine_similarity(query_embedding, entry.embedding)
|
||||
cosine_sim = compute_cosine_similarity(query_embedding, entry.embedding)
|
||||
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
|
||||
else:
|
||||
score = time_decay_score
|
||||
|
|
@ -425,7 +431,7 @@ class InMemoryExperienceStore:
|
|||
time_decay_score = exp.success_rate * decay
|
||||
|
||||
if query_embedding is not None and exp.embedding is not None:
|
||||
cosine_sim = _compute_cosine_similarity(query_embedding, exp.embedding)
|
||||
cosine_sim = compute_cosine_similarity(query_embedding, exp.embedding)
|
||||
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
|
||||
else:
|
||||
score = time_decay_score
|
||||
|
|
@ -485,21 +491,6 @@ class InMemoryExperienceStore:
|
|||
# ── 辅助函数 ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def _compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
|
||||
"""计算两个向量的余弦相似度"""
|
||||
if len(vec_a) != len(vec_b):
|
||||
logger.warning(f"Vector dimension mismatch: {len(vec_a)} vs {len(vec_b)}")
|
||||
return 0.0
|
||||
if not vec_a:
|
||||
return 0.0
|
||||
dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
|
||||
magnitude_a = sum(a**2 for a in vec_a) ** 0.5
|
||||
magnitude_b = sum(b**2 for b in vec_b) ** 0.5
|
||||
if magnitude_a == 0.0 or magnitude_b == 0.0:
|
||||
return 0.0
|
||||
return dot_product / (magnitude_a * magnitude_b)
|
||||
|
||||
|
||||
def _parse_time_window(window: str) -> timedelta:
|
||||
"""解析时间窗口字符串为 timedelta
|
||||
|
||||
|
|
|
|||
|
|
@ -207,10 +207,12 @@ class PitfallDetector:
|
|||
if error:
|
||||
s.failure_reasons.append(error)
|
||||
|
||||
# 收集优化建议
|
||||
if hasattr(exp, "optimization_tips") and exp.optimization_tips:
|
||||
# 收集优化建议 — only add to steps that are part of this experience
|
||||
if hasattr(exp, 'optimization_tips') and exp.optimization_tips:
|
||||
experience_steps = set(exp.steps) if hasattr(exp, 'steps') and exp.steps else set()
|
||||
for step_name, s in stats.items():
|
||||
s.optimization_tips.extend(exp.optimization_tips)
|
||||
if not experience_steps or step_name in experience_steps:
|
||||
s.optimization_tips.extend(exp.optimization_tips)
|
||||
|
||||
return stats
|
||||
|
||||
|
|
|
|||
|
|
@ -6,8 +6,10 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
|
|
@ -17,6 +19,33 @@ from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _escape_cql(value: str) -> str:
|
||||
"""Escape special characters in CQL values."""
|
||||
return value.replace("\\", "\\\\").replace('"', '\\"')
|
||||
|
||||
|
||||
def _is_safe_url(url: str) -> bool:
|
||||
"""Check if URL is safe (not pointing to private/internal networks)."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return False
|
||||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
return False
|
||||
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
|
||||
return False
|
||||
try:
|
||||
ip = ipaddress.ip_address(hostname)
|
||||
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
|
||||
return False
|
||||
except ValueError:
|
||||
pass
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class ConfluenceAdapter(KBAdapter):
|
||||
"""Confluence 知识库适配器
|
||||
|
||||
|
|
@ -49,6 +78,8 @@ class ConfluenceAdapter(KBAdapter):
|
|||
timeout=timeout,
|
||||
)
|
||||
self._base_url = base_url.rstrip("/")
|
||||
if not _is_safe_url(self._base_url):
|
||||
raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.")
|
||||
self._username = username
|
||||
self._api_token = api_token
|
||||
self._space_keys = space_keys or []
|
||||
|
|
@ -88,10 +119,10 @@ class ConfluenceAdapter(KBAdapter):
|
|||
"""
|
||||
client = self._get_client()
|
||||
try:
|
||||
cql = f'text ~ "{query}"'
|
||||
cql = f'text ~ "{_escape_cql(query)}"'
|
||||
if self._space_keys:
|
||||
space_filter = " OR ".join(
|
||||
f'space = "{key}"' for key in self._space_keys
|
||||
f'space = "{_escape_cql(key)}"' for key in self._space_keys
|
||||
)
|
||||
cql = f'{cql} AND ({space_filter})'
|
||||
|
||||
|
|
|
|||
|
|
@ -6,8 +6,11 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
|
|
@ -17,6 +20,28 @@ from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_safe_url(url: str) -> bool:
|
||||
"""Check if URL is safe (not pointing to private/internal networks)."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return False
|
||||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
return False
|
||||
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
|
||||
return False
|
||||
try:
|
||||
ip = ipaddress.ip_address(hostname)
|
||||
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
|
||||
return False
|
||||
except ValueError:
|
||||
pass
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class FeishuKBAdapter(KBAdapter):
|
||||
"""飞书知识库适配器
|
||||
|
||||
|
|
@ -51,8 +76,11 @@ class FeishuKBAdapter(KBAdapter):
|
|||
self._app_id = app_id
|
||||
self._app_secret = app_secret
|
||||
self._base_url = base_url.rstrip("/")
|
||||
if not _is_safe_url(self._base_url):
|
||||
raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.")
|
||||
self._space_ids = space_ids or []
|
||||
self._access_token: str | None = None
|
||||
self._token_expiry: float = 0.0
|
||||
|
||||
def _make_client(self) -> httpx.AsyncClient:
|
||||
"""创建飞书 API HTTP 客户端"""
|
||||
|
|
@ -67,7 +95,7 @@ class FeishuKBAdapter(KBAdapter):
|
|||
|
||||
async def _get_access_token(self) -> str | None:
|
||||
"""获取飞书 tenant_access_token"""
|
||||
if self._access_token:
|
||||
if self._access_token and time.time() < self._token_expiry:
|
||||
return self._access_token
|
||||
|
||||
client = self._get_client()
|
||||
|
|
@ -83,6 +111,8 @@ class FeishuKBAdapter(KBAdapter):
|
|||
data = resp.json()
|
||||
if data.get("code") == 0:
|
||||
self._access_token = data.get("tenant_access_token")
|
||||
expire_seconds = data.get("expire", 7200)
|
||||
self._token_expiry = time.time() + expire_seconds - 300 # Refresh 5 minutes early
|
||||
# 重建客户端以携带 token
|
||||
await self.close()
|
||||
return self._access_token
|
||||
|
|
|
|||
|
|
@ -6,8 +6,10 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
|
|
@ -17,6 +19,31 @@ from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_safe_url(url: str) -> bool:
|
||||
"""Check if URL is safe (not pointing to private/internal networks)."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return False
|
||||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
return False
|
||||
# Block common internal hostnames
|
||||
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
|
||||
return False
|
||||
# Try to resolve as IP and check for private ranges
|
||||
try:
|
||||
ip = ipaddress.ip_address(hostname)
|
||||
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
|
||||
return False
|
||||
except ValueError:
|
||||
# Not an IP address, that's OK (it's a domain name)
|
||||
pass
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class GenericHTTPAdapter(KBAdapter):
|
||||
"""通用 HTTP 知识库适配器
|
||||
|
||||
|
|
@ -53,6 +80,8 @@ class GenericHTTPAdapter(KBAdapter):
|
|||
timeout=timeout,
|
||||
)
|
||||
self._endpoint_url = endpoint_url.rstrip("/")
|
||||
if not _is_safe_url(self._endpoint_url):
|
||||
raise ValueError(f"Unsafe endpoint_url: {self._endpoint_url}. Private/internal URLs are not allowed.")
|
||||
self._auth_config = auth_config or {}
|
||||
self._extra_headers = headers or {}
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from sqlalchemy import text
|
|||
|
||||
from agentkit.memory.base import Memory, MemoryItem
|
||||
from agentkit.memory.embedder import Embedder
|
||||
from agentkit.utils.vector_math import compute_cosine_similarity
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -123,7 +124,7 @@ class EpisodicMemory(Memory):
|
|||
if row_embedding is None:
|
||||
return None
|
||||
|
||||
cosine = self._compute_cosine_similarity(query_embedding, row_embedding)
|
||||
cosine = compute_cosine_similarity(query_embedding, row_embedding)
|
||||
if cosine < 0.1:
|
||||
return None
|
||||
|
||||
|
|
@ -165,7 +166,7 @@ class EpisodicMemory(Memory):
|
|||
entry_embedding = entry.embedding
|
||||
if entry_embedding is None:
|
||||
continue
|
||||
cosine = self._compute_cosine_similarity(query_embedding, entry_embedding)
|
||||
cosine = compute_cosine_similarity(query_embedding, entry_embedding)
|
||||
if cosine > best_score:
|
||||
best_score = cosine
|
||||
best_item = entry
|
||||
|
|
@ -260,7 +261,7 @@ class EpisodicMemory(Memory):
|
|||
time_decay_score = (row.get("quality_score") or 0.5) * decay
|
||||
|
||||
if row_embedding is not None:
|
||||
cosine_sim = self._compute_cosine_similarity(query_embedding, row_embedding)
|
||||
cosine_sim = compute_cosine_similarity(query_embedding, row_embedding)
|
||||
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
|
||||
else:
|
||||
score = time_decay_score
|
||||
|
|
@ -327,7 +328,7 @@ class EpisodicMemory(Memory):
|
|||
|
||||
# 混合评分:alpha * cosine + (1 - alpha) * time_decay
|
||||
if self._embedder and query_embedding is not None and entry.embedding is not None:
|
||||
cosine_sim = self._compute_cosine_similarity(query_embedding, entry.embedding)
|
||||
cosine_sim = compute_cosine_similarity(query_embedding, entry.embedding)
|
||||
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
|
||||
else:
|
||||
score = time_decay_score
|
||||
|
|
@ -375,20 +376,3 @@ class EpisodicMemory(Memory):
|
|||
await db.rollback()
|
||||
logger.error(f"Failed to delete episodic memory: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
|
||||
"""计算两个向量的余弦相似度"""
|
||||
if len(vec_a) != len(vec_b):
|
||||
logger.warning(
|
||||
f"Vector dimension mismatch: {len(vec_a)} vs {len(vec_b)}"
|
||||
)
|
||||
return 0.0
|
||||
if not vec_a:
|
||||
return 0.0
|
||||
dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
|
||||
magnitude_a = sum(a**2 for a in vec_a) ** 0.5
|
||||
magnitude_b = sum(b**2 for b in vec_b) ** 0.5
|
||||
if magnitude_a == 0.0 or magnitude_b == 0.0:
|
||||
return 0.0
|
||||
return dot_product / (magnitude_a * magnitude_b)
|
||||
|
|
|
|||
|
|
@ -10,10 +10,13 @@ from __future__ import annotations
|
|||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
_SAFE_TABLE_NAME_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
|
||||
|
||||
from agentkit.memory.chunking import Chunk, StructuralChunker, TextChunker
|
||||
from agentkit.memory.document_loader import Document as LoaderDocument
|
||||
from agentkit.memory.embedder import Embedder
|
||||
|
|
@ -23,6 +26,7 @@ from agentkit.memory.knowledge_base import (
|
|||
QueryResult,
|
||||
SourceInfo,
|
||||
)
|
||||
from agentkit.utils.vector_math import compute_cosine_similarity
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -70,6 +74,8 @@ class LocalRAGService:
|
|||
self._chunk_size = chunk_size
|
||||
self._chunk_overlap = chunk_overlap
|
||||
self._table_name = table_name
|
||||
if not _SAFE_TABLE_NAME_PATTERN.match(self._table_name):
|
||||
raise ValueError(f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*")
|
||||
self._pgvector_enabled = pgvector_enabled
|
||||
self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||
self._structural_chunker = StructuralChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||
|
|
@ -335,7 +341,7 @@ class LocalRAGService:
|
|||
except (json.JSONDecodeError, TypeError):
|
||||
continue
|
||||
|
||||
cosine = self._compute_cosine_similarity(query_embedding, stored_embedding)
|
||||
cosine = compute_cosine_similarity(query_embedding, stored_embedding)
|
||||
if cosine < 0.1:
|
||||
continue
|
||||
|
||||
|
|
@ -359,18 +365,6 @@ class LocalRAGService:
|
|||
candidates.sort(key=lambda x: x.score, reverse=True)
|
||||
return candidates[:top_k]
|
||||
|
||||
@staticmethod
|
||||
def _compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
|
||||
"""计算两个向量的余弦相似度"""
|
||||
if len(vec_a) != len(vec_b) or not vec_a:
|
||||
return 0.0
|
||||
dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
|
||||
magnitude_a = sum(a**2 for a in vec_a) ** 0.5
|
||||
magnitude_b = sum(b**2 for b in vec_b) ** 0.5
|
||||
if magnitude_a == 0.0 or magnitude_b == 0.0:
|
||||
return 0.0
|
||||
return dot_product / (magnitude_a * magnitude_b)
|
||||
|
||||
|
||||
class InMemoryLocalRAGService:
|
||||
"""基于内存的本地 RAG 服务
|
||||
|
|
@ -447,7 +441,7 @@ class InMemoryLocalRAGService:
|
|||
candidates = []
|
||||
for chunk_id, chunk_data in self._chunks.items():
|
||||
stored_embedding = chunk_data["embedding"]
|
||||
cosine = self._compute_cosine_similarity(query_embedding, stored_embedding)
|
||||
cosine = compute_cosine_similarity(query_embedding, stored_embedding)
|
||||
if cosine < 0.1:
|
||||
continue
|
||||
|
||||
|
|
@ -511,15 +505,3 @@ class InMemoryLocalRAGService:
|
|||
source_doc_id=doc.doc_id,
|
||||
metadata=doc.metadata,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
|
||||
"""计算两个向量的余弦相似度"""
|
||||
if len(vec_a) != len(vec_b) or not vec_a:
|
||||
return 0.0
|
||||
dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
|
||||
magnitude_a = sum(a**2 for a in vec_a) ** 0.5
|
||||
magnitude_b = sum(b**2 for b in vec_b) ** 0.5
|
||||
if magnitude_a == 0.0 or magnitude_b == 0.0:
|
||||
return 0.0
|
||||
return dot_product / (magnitude_a * magnitude_b)
|
||||
|
|
|
|||
|
|
@ -457,6 +457,23 @@ async def list_path_optimizations(
|
|||
@router.websocket("/evolution-dashboard/ws")
|
||||
async def evolution_dashboard_ws(websocket: WebSocket):
|
||||
"""自进化仪表盘实时更新 WebSocket"""
|
||||
# Authentication - check api_key
|
||||
configured_api_key: str | None = None
|
||||
if hasattr(websocket.app.state, "server_config") and websocket.app.state.server_config:
|
||||
configured_api_key = websocket.app.state.server_config.api_key
|
||||
if configured_api_key is None and hasattr(websocket.app.state, "api_key"):
|
||||
configured_api_key = websocket.app.state.api_key
|
||||
|
||||
if configured_api_key:
|
||||
provided = websocket.query_params.get("api_key")
|
||||
if provided != configured_api_key:
|
||||
await websocket.accept()
|
||||
await websocket.send_json(
|
||||
{"type": "error", "data": {"message": "Invalid or missing api_key"}}
|
||||
)
|
||||
await websocket.close(code=4001, reason="Invalid or missing api_key")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
_ws_connections.append(websocket)
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
router = APIRouter(tags=["kb-management"])
|
||||
|
||||
MAX_UPLOAD_SIZE = 50 * 1024 * 1024 # 50MB
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# In-memory Knowledge Source Store
|
||||
|
|
@ -183,14 +185,18 @@ async def upload_document(
|
|||
try:
|
||||
from agentkit.memory.document_loader import DocumentLoader
|
||||
|
||||
content = await file.read()
|
||||
content = await file.read(MAX_UPLOAD_SIZE + 1)
|
||||
if len(content) > MAX_UPLOAD_SIZE:
|
||||
raise HTTPException(status_code=413, detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE // (1024*1024)}MB")
|
||||
loader = DocumentLoader()
|
||||
doc = loader.load_bytes(content, file.filename)
|
||||
# Estimate chunks based on content length (rough approximation)
|
||||
chunks = max(1, len(doc.content) // 500)
|
||||
except ImportError:
|
||||
# DocumentLoader not available, use basic estimation
|
||||
content = await file.read()
|
||||
content = await file.read(MAX_UPLOAD_SIZE + 1)
|
||||
if len(content) > MAX_UPLOAD_SIZE:
|
||||
raise HTTPException(status_code=413, detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE // (1024*1024)}MB")
|
||||
chunks = max(1, len(content) // 500)
|
||||
except Exception as e:
|
||||
logger.warning(f"Document parsing failed: {e}")
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
|
@ -39,6 +40,7 @@ class WorkflowStore:
|
|||
self._executions: dict[str, WorkflowExecution] = {}
|
||||
self._max_workflows = max_workflows
|
||||
self._max_executions = max_executions
|
||||
self._approval_events: dict[str, asyncio.Event] = {} # key: f"{execution_id}:{stage_name}"
|
||||
|
||||
def save(self, workflow: WorkflowDefinition) -> WorkflowDefinition:
|
||||
workflow.updated_at = datetime.now(timezone.utc).isoformat()
|
||||
|
|
@ -226,31 +228,70 @@ async def _execute_workflow(
|
|||
|
||||
try:
|
||||
if stage.type == "approval":
|
||||
# Pause execution and wait for approval
|
||||
# Pause execution and wait for approval via asyncio.Event
|
||||
event_key = f"{execution.execution_id}:{stage_name}"
|
||||
approval_event = asyncio.Event()
|
||||
_store._approval_events[event_key] = approval_event
|
||||
|
||||
execution.status = "paused"
|
||||
execution.current_stage = stage_name
|
||||
_store.update_execution(
|
||||
execution.execution_id,
|
||||
status="paused",
|
||||
current_stage=stage_name,
|
||||
)
|
||||
await _broadcast_ws({
|
||||
"event": "approval_required",
|
||||
"execution_id": execution.execution_id,
|
||||
"stage": stage_name,
|
||||
})
|
||||
# In a real implementation, this would wait for external approval
|
||||
# For now, we simulate auto-approval after a brief pause
|
||||
await asyncio.sleep(0.1)
|
||||
execution.stage_results[stage_name] = {
|
||||
"status": "approved",
|
||||
"approver": "auto",
|
||||
"comment": "自动审批通过",
|
||||
}
|
||||
execution.status = "running"
|
||||
_store.update_execution(
|
||||
execution.execution_id,
|
||||
status="running",
|
||||
stage_results=execution.stage_results,
|
||||
)
|
||||
|
||||
# Wait for approval with timeout
|
||||
try:
|
||||
approval_timeout = stage.config.get("approval_timeout", 3600)
|
||||
await asyncio.wait_for(approval_event.wait(), timeout=approval_timeout)
|
||||
# Check if execution was cancelled/rejected while waiting
|
||||
if execution.status == "cancelled":
|
||||
await _broadcast_ws({
|
||||
"event": "stage_failed",
|
||||
"execution_id": execution.execution_id,
|
||||
"stage": stage_name,
|
||||
"error": "Approval rejected",
|
||||
})
|
||||
return
|
||||
# Approval was granted — the /approve endpoint already set stage_results
|
||||
# Only update status to running if not already set
|
||||
if execution.status != "running":
|
||||
execution.status = "running"
|
||||
_store.update_execution(
|
||||
execution.execution_id,
|
||||
status="running",
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
execution.stage_results[stage_name] = {
|
||||
"status": "timeout",
|
||||
"approver": "none",
|
||||
"comment": "审批超时",
|
||||
}
|
||||
execution.status = "failed"
|
||||
execution.error = f"Approval timeout for stage {stage_name}"
|
||||
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
||||
_store.update_execution(
|
||||
execution.execution_id,
|
||||
status="failed",
|
||||
error=execution.error,
|
||||
completed_at=execution.completed_at,
|
||||
stage_results=execution.stage_results,
|
||||
)
|
||||
await _broadcast_ws({
|
||||
"event": "stage_failed",
|
||||
"execution_id": execution.execution_id,
|
||||
"stage": stage_name,
|
||||
"error": "Approval timeout",
|
||||
})
|
||||
return
|
||||
finally:
|
||||
_store._approval_events.pop(event_key, None)
|
||||
elif stage.type == "condition":
|
||||
# Evaluate condition expression
|
||||
condition_expr = stage.config.get("expression", "")
|
||||
|
|
@ -318,23 +359,60 @@ async def _execute_workflow(
|
|||
})
|
||||
|
||||
|
||||
_SAFE_VAR_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
|
||||
_SAFE_OPERATORS = {"==", "!=", ">", "<", ">=", "<="}
|
||||
|
||||
|
||||
def _evaluate_condition(expression: str, variables: dict[str, Any]) -> bool:
|
||||
"""Simple condition evaluation."""
|
||||
"""Evaluate a condition expression safely."""
|
||||
expression = expression.strip()
|
||||
if not expression:
|
||||
return True
|
||||
if "==" in expression:
|
||||
parts = expression.split("==", 1)
|
||||
left = variables.get(parts[0].strip(), parts[0].strip())
|
||||
right = parts[1].strip().strip("'\"")
|
||||
return str(left) == right
|
||||
elif "!=" in expression:
|
||||
parts = expression.split("!=", 1)
|
||||
left = variables.get(parts[0].strip(), parts[0].strip())
|
||||
right = parts[1].strip().strip("'\"")
|
||||
return str(left) != right
|
||||
else:
|
||||
|
||||
# Try each operator (longer operators first to avoid partial matches)
|
||||
for op in sorted(_SAFE_OPERATORS, key=len, reverse=True):
|
||||
if op in expression:
|
||||
parts = expression.split(op, 1)
|
||||
if len(parts) != 2:
|
||||
continue
|
||||
left = parts[0].strip()
|
||||
right = parts[1].strip()
|
||||
|
||||
# Validate variable names
|
||||
if left and not _SAFE_VAR_PATTERN.match(left):
|
||||
raise ValueError(f"Invalid variable name in condition: {left}")
|
||||
|
||||
left_val = variables.get(left, left)
|
||||
# Strip quotes from right side if present
|
||||
if right.startswith('"') and right.endswith('"'):
|
||||
right_val = right[1:-1]
|
||||
elif right.startswith("'") and right.endswith("'"):
|
||||
right_val = right[1:-1]
|
||||
elif right and _SAFE_VAR_PATTERN.match(right):
|
||||
right_val = variables.get(right, right)
|
||||
else:
|
||||
right_val = right
|
||||
|
||||
# Compare based on operator
|
||||
if op == "==":
|
||||
return str(left_val) == str(right_val)
|
||||
if op == "!=":
|
||||
return str(left_val) != str(right_val)
|
||||
if op == ">":
|
||||
return float(left_val) > float(right_val)
|
||||
if op == "<":
|
||||
return float(left_val) < float(right_val)
|
||||
if op == ">=":
|
||||
return float(left_val) >= float(right_val)
|
||||
if op == "<=":
|
||||
return float(left_val) <= float(right_val)
|
||||
|
||||
# Boolean check for variable existence
|
||||
if _SAFE_VAR_PATTERN.match(expression):
|
||||
return bool(variables.get(expression))
|
||||
|
||||
raise ValueError(f"Invalid condition expression: {expression}")
|
||||
|
||||
|
||||
async def _broadcast_ws(message: dict[str, Any]) -> None:
|
||||
"""Broadcast a message to all WebSocket subscribers."""
|
||||
|
|
@ -487,12 +565,12 @@ async def approve_execution(
|
|||
status="running",
|
||||
stage_results=execution.stage_results,
|
||||
)
|
||||
# Resume execution
|
||||
workflow = store.get(execution.workflow_id)
|
||||
if workflow:
|
||||
asyncio.create_task(
|
||||
_execute_workflow(workflow, execution, execution.variables, store=store)
|
||||
)
|
||||
# Resume the waiting execution by setting the approval event
|
||||
stage_name = execution.current_stage
|
||||
if stage_name:
|
||||
event_key = f"{execution_id}:{stage_name}"
|
||||
if event_key in store._approval_events:
|
||||
store._approval_events[event_key].set()
|
||||
else:
|
||||
execution.status = "cancelled"
|
||||
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
||||
|
|
@ -508,6 +586,12 @@ async def approve_execution(
|
|||
completed_at=execution.completed_at,
|
||||
stage_results=execution.stage_results,
|
||||
)
|
||||
# Set the approval event so the waiting coroutine can observe the cancelled state
|
||||
stage_name = execution.current_stage
|
||||
if stage_name:
|
||||
event_key = f"{execution_id}:{stage_name}"
|
||||
if event_key in store._approval_events:
|
||||
store._approval_events[event_key].set()
|
||||
|
||||
return execution.model_dump()
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ from dataclasses import dataclass
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 自动应答规则:(prompt_pattern, response)
|
||||
# WARNING: auto_respond is disabled by default for safety.
|
||||
# Enable it only when you explicitly want automatic yes/confirm responses.
|
||||
_AUTO_RESPOND_RULES: list[tuple[str, str]] = [
|
||||
(r"\[y/N\]\s*$", "y"),
|
||||
(r"\[Y/n\]\s*$", "y"),
|
||||
|
|
@ -61,7 +63,7 @@ class PTYSession:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
auto_respond: bool = True,
|
||||
auto_respond: bool = False,
|
||||
custom_rules: list[tuple[str, str]] | None = None,
|
||||
default_timeout: float = 30.0,
|
||||
buffer_size: int = 4096,
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import time
|
||||
from typing import Any, Callable, Awaitable
|
||||
|
||||
|
|
@ -21,6 +23,8 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
# 安全白名单:这些命令前缀不需要确认
|
||||
_SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
|
||||
"cd",
|
||||
"export",
|
||||
"ls",
|
||||
"cat",
|
||||
"head",
|
||||
|
|
@ -48,6 +52,7 @@ _SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
|
|||
"sort",
|
||||
"uniq",
|
||||
"diff",
|
||||
"sleep",
|
||||
"git status",
|
||||
"git log",
|
||||
"git diff",
|
||||
|
|
@ -101,6 +106,9 @@ _DANGEROUS_PATTERNS: tuple[str, ...] = (
|
|||
)
|
||||
|
||||
|
||||
_SHELL_OPERATORS = re.compile(r'[|;&]|&&|\|\||\$\(|`')
|
||||
|
||||
|
||||
class ShellTool(Tool):
|
||||
"""Shell 命令执行工具
|
||||
|
||||
|
|
@ -364,18 +372,39 @@ class ShellTool(Tool):
|
|||
"""
|
||||
command_stripped = command.strip()
|
||||
|
||||
# 白名单检查
|
||||
for prefix in _SAFE_COMMAND_PREFIXES:
|
||||
if command_stripped.startswith(prefix):
|
||||
return False
|
||||
# Check for shell operators that chain commands (always dangerous)
|
||||
if _SHELL_OPERATORS.search(command_stripped):
|
||||
return True
|
||||
|
||||
# 危险模式检查
|
||||
# Parse the actual binary being invoked
|
||||
try:
|
||||
tokens = shlex.split(command_stripped)
|
||||
if not tokens:
|
||||
return True
|
||||
binary = os.path.basename(tokens[0])
|
||||
except ValueError:
|
||||
# Unparsable command - treat as dangerous
|
||||
return True
|
||||
|
||||
# Whitelist check: first try full command prefix match, then binary-only match
|
||||
for prefix in _SAFE_COMMAND_PREFIXES:
|
||||
prefix_stripped = prefix.lower().strip()
|
||||
if " " in prefix_stripped:
|
||||
# Compound prefix like "git status" - match against full command
|
||||
if command_stripped.lower().startswith(prefix_stripped):
|
||||
return False
|
||||
else:
|
||||
# Simple prefix - match against binary name only
|
||||
if binary.lower().startswith(prefix_stripped):
|
||||
return False
|
||||
|
||||
# Dangerous pattern check
|
||||
command_lower = command_stripped.lower()
|
||||
for pattern in _DANGEROUS_PATTERNS:
|
||||
if pattern in command_lower:
|
||||
return True
|
||||
|
||||
return False
|
||||
return True # Unknown commands are dangerous by default
|
||||
|
||||
async def _request_confirmation(self, command: str) -> bool:
|
||||
"""请求人工确认危险命令
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
|
@ -17,6 +19,8 @@ from agentkit.tools.output_parser import OutputParser, ParsedOutput
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ENV_KEY_PATTERN = re.compile(r'^[A-Za-z_][A-Za-z0-9_]*$')
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommandRecord:
|
||||
|
|
@ -190,15 +194,13 @@ class TerminalSession:
|
|||
|
||||
# 注入 cd
|
||||
if self._cwd:
|
||||
# 使用 shlex.quote 风格的简单转义
|
||||
cwd_escaped = self._cwd.replace("'", "'\\''")
|
||||
parts.append(f"cd '{cwd_escaped}'")
|
||||
parts.append(f"cd {shlex.quote(self._cwd)}")
|
||||
|
||||
# 注入环境变量
|
||||
for key, value in self._env.items():
|
||||
# 跳过 os.environ 中已有的且值未变的变量,减少命令长度
|
||||
val_escaped = value.replace("'", "'\\''")
|
||||
parts.append(f"export {key}='{val_escaped}'")
|
||||
if not _ENV_KEY_PATTERN.match(key):
|
||||
continue # Skip invalid env key names
|
||||
parts.append(f"export {shlex.quote(key)}={shlex.quote(value)}")
|
||||
|
||||
parts.append(command)
|
||||
return " && ".join(parts)
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
"""AgentKit utility modules."""
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
"""Shared vector math utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
|
||||
def compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
|
||||
"""Compute cosine similarity between two vectors.
|
||||
|
||||
Args:
|
||||
vec_a: First vector.
|
||||
vec_b: Second vector.
|
||||
|
||||
Returns:
|
||||
Cosine similarity score between -1 and 1.
|
||||
"""
|
||||
if len(vec_a) != len(vec_b):
|
||||
return 0.0
|
||||
if not vec_a:
|
||||
return 0.0
|
||||
|
||||
dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
|
||||
norm_a = math.sqrt(sum(a * a for a in vec_a))
|
||||
norm_b = math.sqrt(sum(b * b for b in vec_b))
|
||||
|
||||
if norm_a == 0.0 or norm_b == 0.0:
|
||||
return 0.0
|
||||
|
||||
return dot_product / (norm_a * norm_b)
|
||||
|
|
@ -433,12 +433,34 @@ async def test_workflow_with_approval():
|
|||
assert execution.status == "pending"
|
||||
assert execution.execution_id
|
||||
|
||||
# 5. Execute workflow (runs in background)
|
||||
# 5. Execute workflow in background (approval stage will wait for event)
|
||||
from agentkit.server.routes.workflows import _execute_workflow
|
||||
|
||||
await _execute_workflow(workflow, execution, variables={}, store=store)
|
||||
# Use a short approval timeout for testing
|
||||
workflow.stages[1].config["approval_timeout"] = 5
|
||||
|
||||
# 6. Verify execution completed (auto-approval in test mode)
|
||||
async def _approve_after_pause():
|
||||
"""Wait for execution to pause, then approve."""
|
||||
for _ in range(100):
|
||||
await asyncio.sleep(0.05)
|
||||
updated = store.get_execution(execution.execution_id)
|
||||
if updated and updated.status == "paused":
|
||||
break
|
||||
# Trigger approval
|
||||
event_key = f"{execution.execution_id}:human_review"
|
||||
if event_key in store._approval_events:
|
||||
execution.stage_results["human_review"] = {
|
||||
"status": "approved",
|
||||
"approver": "test_user",
|
||||
"comment": "Auto-approved in test",
|
||||
}
|
||||
store._approval_events[event_key].set()
|
||||
|
||||
approve_task = asyncio.create_task(_approve_after_pause())
|
||||
await _execute_workflow(workflow, execution, variables={}, store=store)
|
||||
await approve_task
|
||||
|
||||
# 6. Verify execution completed
|
||||
updated = store.get_execution(execution.execution_id)
|
||||
assert updated is not None
|
||||
assert updated.status == "completed"
|
||||
|
|
@ -452,7 +474,7 @@ async def test_workflow_with_approval():
|
|||
approval_result = updated.stage_results["human_review"]
|
||||
assert approval_result.get("status") in ("approved", "completed")
|
||||
|
||||
# 9. Test manual approval flow
|
||||
# 9. Test second workflow with approval
|
||||
workflow2 = WorkflowDefinition(
|
||||
workflow_id="wf-manual-approval",
|
||||
name="手动审批流程",
|
||||
|
|
@ -468,6 +490,7 @@ async def test_workflow_with_approval():
|
|||
agent="reviewer",
|
||||
action="approve",
|
||||
type="approval",
|
||||
config={"approval_timeout": 5},
|
||||
depends_on=["step1"],
|
||||
),
|
||||
WorkflowStage(
|
||||
|
|
@ -481,28 +504,30 @@ async def test_workflow_with_approval():
|
|||
)
|
||||
store.save(workflow2)
|
||||
|
||||
# Simulate manual approval via API
|
||||
execution2 = store.create_execution(workflow2.workflow_id)
|
||||
execution2.status = "paused"
|
||||
execution2.current_stage = "approval_step"
|
||||
store.update_execution(
|
||||
execution2.execution_id,
|
||||
status="paused",
|
||||
current_stage="approval_step",
|
||||
)
|
||||
|
||||
# Approve
|
||||
execution2.stage_results["approval_step"] = {
|
||||
"status": "approved",
|
||||
"approver": "user",
|
||||
"comment": "LGTM",
|
||||
}
|
||||
execution2.status = "running"
|
||||
store.update_execution(
|
||||
execution2.execution_id,
|
||||
status="running",
|
||||
stage_results=execution2.stage_results,
|
||||
)
|
||||
async def _approve2_after_pause():
|
||||
for _ in range(100):
|
||||
await asyncio.sleep(0.05)
|
||||
updated2 = store.get_execution(execution2.execution_id)
|
||||
if updated2 and updated2.status == "paused":
|
||||
break
|
||||
event_key2 = f"{execution2.execution_id}:approval_step"
|
||||
if event_key2 in store._approval_events:
|
||||
execution2.stage_results["approval_step"] = {
|
||||
"status": "approved",
|
||||
"approver": "user",
|
||||
"comment": "LGTM",
|
||||
}
|
||||
store._approval_events[event_key2].set()
|
||||
|
||||
approve_task2 = asyncio.create_task(_approve2_after_pause())
|
||||
await _execute_workflow(workflow2, execution2, variables={}, store=store)
|
||||
await approve_task2
|
||||
|
||||
updated2 = store.get_execution(execution2.execution_id)
|
||||
assert updated2 is not None
|
||||
assert updated2.status == "completed"
|
||||
|
||||
# Verify approval was recorded
|
||||
paused_exec = store.get_execution(execution2.execution_id)
|
||||
|
|
@ -1014,7 +1039,27 @@ async def test_multi_source_rag_with_workflow(local_rag, mock_embedder):
|
|||
execution = store.create_execution(workflow.workflow_id)
|
||||
from agentkit.server.routes.workflows import _execute_workflow
|
||||
|
||||
# Set short approval timeout and handle approval
|
||||
workflow.stages[1].config["approval_timeout"] = 5
|
||||
|
||||
async def _approve_kb_review():
|
||||
for _ in range(100):
|
||||
await asyncio.sleep(0.05)
|
||||
upd = store.get_execution(execution.execution_id)
|
||||
if upd and upd.status == "paused":
|
||||
break
|
||||
event_key = f"{execution.execution_id}:review_findings"
|
||||
if event_key in store._approval_events:
|
||||
execution.stage_results["review_findings"] = {
|
||||
"status": "approved",
|
||||
"approver": "test_user",
|
||||
"comment": "Approved",
|
||||
}
|
||||
store._approval_events[event_key].set()
|
||||
|
||||
approve_task = asyncio.create_task(_approve_kb_review())
|
||||
await _execute_workflow(workflow, execution, variables={}, store=store)
|
||||
await approve_task
|
||||
|
||||
updated = store.get_execution(execution.execution_id)
|
||||
assert updated.status == "completed"
|
||||
|
|
|
|||
|
|
@ -9,10 +9,10 @@ import pytest
|
|||
from agentkit.evolution.experience_schema import EvolutionMetrics, TaskExperience
|
||||
from agentkit.evolution.experience_store import (
|
||||
InMemoryExperienceStore,
|
||||
_compute_cosine_similarity,
|
||||
_parse_time_window,
|
||||
)
|
||||
from agentkit.memory.embedder import MockEmbedder
|
||||
from agentkit.utils.vector_math import compute_cosine_similarity
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────
|
||||
|
|
@ -136,23 +136,23 @@ class TestEvolutionMetrics:
|
|||
class TestHelperFunctions:
|
||||
def test_cosine_similarity_identical(self):
|
||||
vec = [1.0, 0.0, 0.0]
|
||||
assert _compute_cosine_similarity(vec, vec) == pytest.approx(1.0)
|
||||
assert compute_cosine_similarity(vec, vec) == pytest.approx(1.0)
|
||||
|
||||
def test_cosine_similarity_orthogonal(self):
|
||||
a = [1.0, 0.0]
|
||||
b = [0.0, 1.0]
|
||||
assert _compute_cosine_similarity(a, b) == pytest.approx(0.0)
|
||||
assert compute_cosine_similarity(a, b) == pytest.approx(0.0)
|
||||
|
||||
def test_cosine_similarity_opposite(self):
|
||||
a = [1.0, 0.0]
|
||||
b = [-1.0, 0.0]
|
||||
assert _compute_cosine_similarity(a, b) == pytest.approx(-1.0)
|
||||
assert compute_cosine_similarity(a, b) == pytest.approx(-1.0)
|
||||
|
||||
def test_cosine_similarity_empty(self):
|
||||
assert _compute_cosine_similarity([], []) == 0.0
|
||||
assert compute_cosine_similarity([], []) == 0.0
|
||||
|
||||
def test_cosine_similarity_mismatched_dims(self):
|
||||
assert _compute_cosine_similarity([1.0], [1.0, 2.0]) == 0.0
|
||||
assert compute_cosine_similarity([1.0], [1.0, 2.0]) == 0.0
|
||||
|
||||
def test_parse_time_window_hours(self):
|
||||
delta = _parse_time_window("24h")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""Tests for KnowledgeBase adapters — 飞书、Confluence、通用 HTTP 适配器"""
|
||||
|
||||
import pytest
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo, KnowledgeBase
|
||||
|
|
@ -64,7 +65,7 @@ class TestKnowledgeBaseProtocol:
|
|||
assert isinstance(adapter2, KnowledgeBase)
|
||||
|
||||
adapter3 = GenericHTTPAdapter(
|
||||
endpoint_url="http://localhost:8000/api/kb",
|
||||
endpoint_url="https://example.com/api/kb",
|
||||
)
|
||||
assert isinstance(adapter3, KnowledgeBase)
|
||||
|
||||
|
|
@ -256,6 +257,8 @@ class TestFeishuKBAdapterSearch:
|
|||
async def test_search_success(self, adapter):
|
||||
# Mock authentication
|
||||
adapter._access_token = "t-xxx"
|
||||
adapter._token_expiry = time.time() + 7200
|
||||
adapter._token_expiry = time.time() + 7200
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
|
|
@ -307,6 +310,7 @@ class TestFeishuKBAdapterSearch:
|
|||
@pytest.mark.asyncio
|
||||
async def test_search_api_error(self, adapter):
|
||||
adapter._access_token = "t-xxx"
|
||||
adapter._token_expiry = time.time() + 7200
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
|
|
@ -328,6 +332,7 @@ class TestFeishuKBAdapterSearch:
|
|||
import httpx
|
||||
|
||||
adapter._access_token = "t-xxx"
|
||||
adapter._token_expiry = time.time() + 7200
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 500
|
||||
|
|
@ -380,6 +385,7 @@ class TestFeishuKBAdapterListSources:
|
|||
async def test_list_sources_success(self):
|
||||
adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
|
||||
adapter._access_token = "t-xxx"
|
||||
adapter._token_expiry = time.time() + 7200
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
|
|
@ -420,6 +426,7 @@ class TestFeishuKBAdapterGetDocument:
|
|||
async def test_get_document_success(self):
|
||||
adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
|
||||
adapter._access_token = "t-xxx"
|
||||
adapter._token_expiry = time.time() + 7200
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
|
|
@ -450,6 +457,7 @@ class TestFeishuKBAdapterGetDocument:
|
|||
async def test_get_document_not_found(self):
|
||||
adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
|
||||
adapter._access_token = "t-xxx"
|
||||
adapter._token_expiry = time.time() + 7200
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
|
|
@ -736,23 +744,23 @@ class TestGenericHTTPAdapterInit:
|
|||
|
||||
def test_basic_init(self):
|
||||
adapter = GenericHTTPAdapter(
|
||||
endpoint_url="http://localhost:8000/api/kb",
|
||||
endpoint_url="https://example.com/api/kb",
|
||||
)
|
||||
assert adapter._endpoint_url == "http://localhost:8000/api/kb"
|
||||
assert adapter._endpoint_url == "https://example.com/api/kb"
|
||||
assert adapter._auth_config == {}
|
||||
assert adapter._extra_headers == {}
|
||||
assert adapter._source_type == "generic_http"
|
||||
|
||||
def test_init_with_auth_bearer(self):
|
||||
adapter = GenericHTTPAdapter(
|
||||
endpoint_url="http://localhost:8000/api/kb/",
|
||||
endpoint_url="https://example.com/api/kb/",
|
||||
auth_config={"type": "bearer", "token": "sk-test"},
|
||||
headers={"X-Custom": "value"},
|
||||
source_id="my-kb",
|
||||
source_name="My KB",
|
||||
timeout=60,
|
||||
)
|
||||
assert adapter._endpoint_url == "http://localhost:8000/api/kb"
|
||||
assert adapter._endpoint_url == "https://example.com/api/kb"
|
||||
assert adapter._auth_config["type"] == "bearer"
|
||||
assert adapter._extra_headers == {"X-Custom": "value"}
|
||||
assert adapter._source_id == "my-kb"
|
||||
|
|
@ -760,7 +768,7 @@ class TestGenericHTTPAdapterInit:
|
|||
|
||||
def test_client_bearer_auth_header(self):
|
||||
adapter = GenericHTTPAdapter(
|
||||
endpoint_url="http://localhost:8000/api/kb",
|
||||
endpoint_url="https://example.com/api/kb",
|
||||
auth_config={"type": "bearer", "token": "sk-test"},
|
||||
)
|
||||
client = adapter._make_client()
|
||||
|
|
@ -768,7 +776,7 @@ class TestGenericHTTPAdapterInit:
|
|||
|
||||
def test_client_basic_auth_header(self):
|
||||
adapter = GenericHTTPAdapter(
|
||||
endpoint_url="http://localhost:8000/api/kb",
|
||||
endpoint_url="https://example.com/api/kb",
|
||||
auth_config={"type": "basic", "username": "user", "password": "pass"},
|
||||
)
|
||||
client = adapter._make_client()
|
||||
|
|
@ -777,7 +785,7 @@ class TestGenericHTTPAdapterInit:
|
|||
|
||||
def test_client_api_key_header(self):
|
||||
adapter = GenericHTTPAdapter(
|
||||
endpoint_url="http://localhost:8000/api/kb",
|
||||
endpoint_url="https://example.com/api/kb",
|
||||
auth_config={"type": "api_key", "header_name": "X-API-Key", "api_key": "key123"},
|
||||
)
|
||||
client = adapter._make_client()
|
||||
|
|
@ -790,7 +798,7 @@ class TestGenericHTTPAdapterSearch:
|
|||
@pytest.fixture
|
||||
def adapter(self):
|
||||
return GenericHTTPAdapter(
|
||||
endpoint_url="http://localhost:8000/api/kb",
|
||||
endpoint_url="https://example.com/api/kb",
|
||||
auth_config={"type": "bearer", "token": "sk-test"},
|
||||
)
|
||||
|
||||
|
|
@ -892,7 +900,7 @@ class TestGenericHTTPAdapterIngest:
|
|||
@pytest.mark.asyncio
|
||||
async def test_ingest_success(self):
|
||||
adapter = GenericHTTPAdapter(
|
||||
endpoint_url="http://localhost:8000/api/kb",
|
||||
endpoint_url="https://example.com/api/kb",
|
||||
)
|
||||
|
||||
mock_resp = MagicMock()
|
||||
|
|
@ -921,7 +929,7 @@ class TestGenericHTTPAdapterIngest:
|
|||
import httpx
|
||||
|
||||
adapter = GenericHTTPAdapter(
|
||||
endpoint_url="http://localhost:8000/api/kb",
|
||||
endpoint_url="https://example.com/api/kb",
|
||||
)
|
||||
|
||||
mock_resp = MagicMock()
|
||||
|
|
@ -946,7 +954,7 @@ class TestGenericHTTPAdapterDeleteById:
|
|||
@pytest.mark.asyncio
|
||||
async def test_delete_success(self):
|
||||
adapter = GenericHTTPAdapter(
|
||||
endpoint_url="http://localhost:8000/api/kb",
|
||||
endpoint_url="https://example.com/api/kb",
|
||||
)
|
||||
|
||||
mock_resp = MagicMock()
|
||||
|
|
@ -963,7 +971,7 @@ class TestGenericHTTPAdapterDeleteById:
|
|||
@pytest.mark.asyncio
|
||||
async def test_delete_not_found(self):
|
||||
adapter = GenericHTTPAdapter(
|
||||
endpoint_url="http://localhost:8000/api/kb",
|
||||
endpoint_url="https://example.com/api/kb",
|
||||
)
|
||||
|
||||
mock_resp = MagicMock()
|
||||
|
|
@ -983,7 +991,7 @@ class TestGenericHTTPAdapterGetDocument:
|
|||
@pytest.mark.asyncio
|
||||
async def test_get_document_success(self):
|
||||
adapter = GenericHTTPAdapter(
|
||||
endpoint_url="http://localhost:8000/api/kb",
|
||||
endpoint_url="https://example.com/api/kb",
|
||||
)
|
||||
|
||||
mock_resp = MagicMock()
|
||||
|
|
@ -1011,7 +1019,7 @@ class TestGenericHTTPAdapterGetDocument:
|
|||
import httpx
|
||||
|
||||
adapter = GenericHTTPAdapter(
|
||||
endpoint_url="http://localhost:8000/api/kb",
|
||||
endpoint_url="https://example.com/api/kb",
|
||||
)
|
||||
|
||||
mock_resp = MagicMock()
|
||||
|
|
@ -1034,7 +1042,7 @@ class TestGenericHTTPAdapterHealthCheck:
|
|||
@pytest.mark.asyncio
|
||||
async def test_health_check_ok(self):
|
||||
adapter = GenericHTTPAdapter(
|
||||
endpoint_url="http://localhost:8000/api/kb",
|
||||
endpoint_url="https://example.com/api/kb",
|
||||
)
|
||||
|
||||
mock_resp = MagicMock()
|
||||
|
|
@ -1050,7 +1058,7 @@ class TestGenericHTTPAdapterHealthCheck:
|
|||
async def test_health_check_fallback_to_root(self):
|
||||
"""health endpoint 不存在时回退到根路径"""
|
||||
adapter = GenericHTTPAdapter(
|
||||
endpoint_url="http://localhost:8000/api/kb",
|
||||
endpoint_url="https://example.com/api/kb",
|
||||
)
|
||||
|
||||
import httpx
|
||||
|
|
@ -1076,7 +1084,7 @@ class TestGenericHTTPAdapterHealthCheck:
|
|||
@pytest.mark.asyncio
|
||||
async def test_health_check_connection_error(self):
|
||||
adapter = GenericHTTPAdapter(
|
||||
endpoint_url="http://localhost:8000/api/kb",
|
||||
endpoint_url="https://example.com/api/kb",
|
||||
)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
|
|
@ -1092,7 +1100,7 @@ class TestGenericHTTPAdapterListSources:
|
|||
@pytest.mark.asyncio
|
||||
async def test_list_sources_success(self):
|
||||
adapter = GenericHTTPAdapter(
|
||||
endpoint_url="http://localhost:8000/api/kb",
|
||||
endpoint_url="https://example.com/api/kb",
|
||||
)
|
||||
|
||||
mock_resp = MagicMock()
|
||||
|
|
@ -1116,7 +1124,7 @@ class TestGenericHTTPAdapterListSources:
|
|||
async def test_list_sources_endpoint_not_found(self):
|
||||
"""sources endpoint 不存在时返回默认信息源"""
|
||||
adapter = GenericHTTPAdapter(
|
||||
endpoint_url="http://localhost:8000/api/kb",
|
||||
endpoint_url="https://example.com/api/kb",
|
||||
)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
|
|
@ -1146,7 +1154,7 @@ class TestCrossAdapterIntegration:
|
|||
username="user@test.com",
|
||||
api_token="token",
|
||||
),
|
||||
GenericHTTPAdapter(endpoint_url="http://localhost:8000/api/kb"),
|
||||
GenericHTTPAdapter(endpoint_url="https://example.com/api/kb"),
|
||||
]
|
||||
for adapter in adapters:
|
||||
assert isinstance(adapter, KnowledgeBase)
|
||||
|
|
@ -1166,7 +1174,7 @@ class TestCrossAdapterIntegration:
|
|||
username="user@test.com",
|
||||
api_token="token",
|
||||
),
|
||||
GenericHTTPAdapter(endpoint_url="http://localhost:8000/api/kb"),
|
||||
GenericHTTPAdapter(endpoint_url="https://example.com/api/kb"),
|
||||
]
|
||||
for adapter in adapters:
|
||||
assert hasattr(adapter, "search")
|
||||
|
|
@ -1179,6 +1187,7 @@ class TestCrossAdapterIntegration:
|
|||
# Feishu
|
||||
feishu = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
|
||||
feishu._access_token = "t-xxx"
|
||||
feishu._token_expiry = time.time() + 7200
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
|
|
@ -1220,7 +1229,7 @@ class TestCrossAdapterIntegration:
|
|||
assert all(isinstance(r, QueryResult) for r in results)
|
||||
|
||||
# GenericHTTP
|
||||
generic = GenericHTTPAdapter(endpoint_url="http://localhost:8000/api/kb")
|
||||
generic = GenericHTTPAdapter(endpoint_url="https://example.com/api/kb")
|
||||
mock_resp3 = MagicMock()
|
||||
mock_resp3.status_code = 200
|
||||
mock_resp3.raise_for_status = MagicMock()
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from sqlalchemy.orm import DeclarativeBase
|
|||
from agentkit.memory.episodic import EpisodicMemory
|
||||
from agentkit.memory.base import MemoryItem
|
||||
from agentkit.memory.embedder import MockEmbedder
|
||||
from agentkit.utils.vector_math import compute_cosine_similarity
|
||||
|
||||
|
||||
# ── 真实 SQLAlchemy 模型(用于测试) ─────────────────────
|
||||
|
|
@ -112,40 +113,40 @@ def _make_row_mapping(data: dict) -> _RowMapping:
|
|||
|
||||
|
||||
class TestCosineSimilarity:
|
||||
"""_compute_cosine_similarity 测试"""
|
||||
"""compute_cosine_similarity 测试"""
|
||||
|
||||
def test_identical_vectors_return_one(self):
|
||||
"""相同向量余弦相似度为 1"""
|
||||
vec = [1.0, 0.0, 0.0]
|
||||
assert EpisodicMemory._compute_cosine_similarity(vec, vec) == pytest.approx(1.0)
|
||||
assert compute_cosine_similarity(vec, vec) == pytest.approx(1.0)
|
||||
|
||||
def test_orthogonal_vectors_return_zero(self):
|
||||
"""正交向量余弦相似度为 0"""
|
||||
vec_a = [1.0, 0.0]
|
||||
vec_b = [0.0, 1.0]
|
||||
assert EpisodicMemory._compute_cosine_similarity(vec_a, vec_b) == pytest.approx(0.0)
|
||||
assert compute_cosine_similarity(vec_a, vec_b) == pytest.approx(0.0)
|
||||
|
||||
def test_opposite_vectors_return_minus_one(self):
|
||||
"""相反向量余弦相似度为 -1"""
|
||||
vec_a = [1.0, 0.0]
|
||||
vec_b = [-1.0, 0.0]
|
||||
assert EpisodicMemory._compute_cosine_similarity(vec_a, vec_b) == pytest.approx(-1.0)
|
||||
assert compute_cosine_similarity(vec_a, vec_b) == pytest.approx(-1.0)
|
||||
|
||||
def test_dimension_mismatch_returns_zero(self):
|
||||
"""维度不匹配返回 0"""
|
||||
vec_a = [1.0, 2.0]
|
||||
vec_b = [1.0]
|
||||
assert EpisodicMemory._compute_cosine_similarity(vec_a, vec_b) == 0.0
|
||||
assert compute_cosine_similarity(vec_a, vec_b) == 0.0
|
||||
|
||||
def test_empty_vectors_return_zero(self):
|
||||
"""空向量返回 0"""
|
||||
assert EpisodicMemory._compute_cosine_similarity([], []) == 0.0
|
||||
assert compute_cosine_similarity([], []) == 0.0
|
||||
|
||||
def test_zero_vector_returns_zero(self):
|
||||
"""零向量返回 0"""
|
||||
vec_a = [0.0, 0.0]
|
||||
vec_b = [1.0, 2.0]
|
||||
assert EpisodicMemory._compute_cosine_similarity(vec_a, vec_b) == 0.0
|
||||
assert compute_cosine_similarity(vec_a, vec_b) == 0.0
|
||||
|
||||
|
||||
# ── MockEmbedder 测试 ───────────────────────────────────
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class TestPTYSessionConstruction:
|
|||
def test_default_construction(self):
|
||||
pty = PTYSession()
|
||||
assert pty.is_running is False
|
||||
assert pty._auto_respond is True
|
||||
assert pty._auto_respond is False
|
||||
assert pty._default_timeout == 30.0
|
||||
|
||||
def test_custom_construction(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue