fix(security): resolve all P0/P1 findings from code review
This commit is contained in:
parent
b34f74f598
commit
9e9f1314f6
|
|
@ -496,6 +496,6 @@ class PlanExecutor:
|
||||||
# 所有步骤要么完成要么跳过
|
# 所有步骤要么完成要么跳过
|
||||||
return TaskStatus.COMPLETED
|
return TaskStatus.COMPLETED
|
||||||
if failed > 0:
|
if failed > 0:
|
||||||
return TaskStatus.COMPLETED # 部分成功
|
return TaskStatus.PARTIALLY_COMPLETED # 部分成功
|
||||||
|
|
||||||
return TaskStatus.COMPLETED
|
return TaskStatus.COMPLETED
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ class TaskStatus(str, Enum):
|
||||||
PENDING = "pending"
|
PENDING = "pending"
|
||||||
RUNNING = "running"
|
RUNNING = "running"
|
||||||
COMPLETED = "completed"
|
COMPLETED = "completed"
|
||||||
|
PARTIALLY_COMPLETED = "partially_completed"
|
||||||
FAILED = "failed"
|
FAILED = "failed"
|
||||||
CANCELLED = "cancelled"
|
CANCELLED = "cancelled"
|
||||||
HANDOFF = "handoff"
|
HANDOFF = "handoff"
|
||||||
|
|
|
||||||
|
|
@ -12,14 +12,18 @@ from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
_SAFE_TABLE_NAME_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
|
||||||
|
|
||||||
from agentkit.evolution.experience_schema import EvolutionMetrics, TaskExperience
|
from agentkit.evolution.experience_schema import EvolutionMetrics, TaskExperience
|
||||||
from agentkit.memory.embedder import Embedder
|
from agentkit.memory.embedder import Embedder
|
||||||
|
from agentkit.utils.vector_math import compute_cosine_similarity
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -69,6 +73,8 @@ class ExperienceStore:
|
||||||
self._retrieve_limit = retrieve_limit
|
self._retrieve_limit = retrieve_limit
|
||||||
self._pgvector_enabled = pgvector_enabled
|
self._pgvector_enabled = pgvector_enabled
|
||||||
self._table_name = table_name
|
self._table_name = table_name
|
||||||
|
if not _SAFE_TABLE_NAME_PATTERN.match(self._table_name):
|
||||||
|
raise ValueError(f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*")
|
||||||
|
|
||||||
async def record_experience(self, experience: TaskExperience) -> str:
|
async def record_experience(self, experience: TaskExperience) -> str:
|
||||||
"""记录任务经验
|
"""记录任务经验
|
||||||
|
|
@ -193,7 +199,7 @@ class ExperienceStore:
|
||||||
time_decay_score = (row.get("success_rate") or 0.5) * decay
|
time_decay_score = (row.get("success_rate") or 0.5) * decay
|
||||||
|
|
||||||
if row_embedding is not None:
|
if row_embedding is not None:
|
||||||
cosine_sim = _compute_cosine_similarity(query_embedding, row_embedding)
|
cosine_sim = compute_cosine_similarity(query_embedding, row_embedding)
|
||||||
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
|
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
|
||||||
else:
|
else:
|
||||||
score = time_decay_score
|
score = time_decay_score
|
||||||
|
|
@ -251,7 +257,7 @@ class ExperienceStore:
|
||||||
time_decay_score = (entry.success_rate or 0.5) * decay
|
time_decay_score = (entry.success_rate or 0.5) * decay
|
||||||
|
|
||||||
if self._embedder and query_embedding is not None and entry.embedding is not None:
|
if self._embedder and query_embedding is not None and entry.embedding is not None:
|
||||||
cosine_sim = _compute_cosine_similarity(query_embedding, entry.embedding)
|
cosine_sim = compute_cosine_similarity(query_embedding, entry.embedding)
|
||||||
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
|
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
|
||||||
else:
|
else:
|
||||||
score = time_decay_score
|
score = time_decay_score
|
||||||
|
|
@ -425,7 +431,7 @@ class InMemoryExperienceStore:
|
||||||
time_decay_score = exp.success_rate * decay
|
time_decay_score = exp.success_rate * decay
|
||||||
|
|
||||||
if query_embedding is not None and exp.embedding is not None:
|
if query_embedding is not None and exp.embedding is not None:
|
||||||
cosine_sim = _compute_cosine_similarity(query_embedding, exp.embedding)
|
cosine_sim = compute_cosine_similarity(query_embedding, exp.embedding)
|
||||||
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
|
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
|
||||||
else:
|
else:
|
||||||
score = time_decay_score
|
score = time_decay_score
|
||||||
|
|
@ -485,21 +491,6 @@ class InMemoryExperienceStore:
|
||||||
# ── 辅助函数 ──────────────────────────────────────────────
|
# ── 辅助函数 ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def _compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
|
|
||||||
"""计算两个向量的余弦相似度"""
|
|
||||||
if len(vec_a) != len(vec_b):
|
|
||||||
logger.warning(f"Vector dimension mismatch: {len(vec_a)} vs {len(vec_b)}")
|
|
||||||
return 0.0
|
|
||||||
if not vec_a:
|
|
||||||
return 0.0
|
|
||||||
dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
|
|
||||||
magnitude_a = sum(a**2 for a in vec_a) ** 0.5
|
|
||||||
magnitude_b = sum(b**2 for b in vec_b) ** 0.5
|
|
||||||
if magnitude_a == 0.0 or magnitude_b == 0.0:
|
|
||||||
return 0.0
|
|
||||||
return dot_product / (magnitude_a * magnitude_b)
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_time_window(window: str) -> timedelta:
|
def _parse_time_window(window: str) -> timedelta:
|
||||||
"""解析时间窗口字符串为 timedelta
|
"""解析时间窗口字符串为 timedelta
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -207,10 +207,12 @@ class PitfallDetector:
|
||||||
if error:
|
if error:
|
||||||
s.failure_reasons.append(error)
|
s.failure_reasons.append(error)
|
||||||
|
|
||||||
# 收集优化建议
|
# 收集优化建议 — only add to steps that are part of this experience
|
||||||
if hasattr(exp, "optimization_tips") and exp.optimization_tips:
|
if hasattr(exp, 'optimization_tips') and exp.optimization_tips:
|
||||||
|
experience_steps = set(exp.steps) if hasattr(exp, 'steps') and exp.steps else set()
|
||||||
for step_name, s in stats.items():
|
for step_name, s in stats.items():
|
||||||
s.optimization_tips.extend(exp.optimization_tips)
|
if not experience_steps or step_name in experience_steps:
|
||||||
|
s.optimization_tips.extend(exp.optimization_tips)
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,10 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ipaddress
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
|
@ -17,6 +19,33 @@ from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _escape_cql(value: str) -> str:
|
||||||
|
"""Escape special characters in CQL values."""
|
||||||
|
return value.replace("\\", "\\\\").replace('"', '\\"')
|
||||||
|
|
||||||
|
|
||||||
|
def _is_safe_url(url: str) -> bool:
|
||||||
|
"""Check if URL is safe (not pointing to private/internal networks)."""
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
if parsed.scheme not in ("http", "https"):
|
||||||
|
return False
|
||||||
|
hostname = parsed.hostname
|
||||||
|
if not hostname:
|
||||||
|
return False
|
||||||
|
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
ip = ipaddress.ip_address(hostname)
|
||||||
|
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
|
||||||
|
return False
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class ConfluenceAdapter(KBAdapter):
|
class ConfluenceAdapter(KBAdapter):
|
||||||
"""Confluence 知识库适配器
|
"""Confluence 知识库适配器
|
||||||
|
|
||||||
|
|
@ -49,6 +78,8 @@ class ConfluenceAdapter(KBAdapter):
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
self._base_url = base_url.rstrip("/")
|
self._base_url = base_url.rstrip("/")
|
||||||
|
if not _is_safe_url(self._base_url):
|
||||||
|
raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.")
|
||||||
self._username = username
|
self._username = username
|
||||||
self._api_token = api_token
|
self._api_token = api_token
|
||||||
self._space_keys = space_keys or []
|
self._space_keys = space_keys or []
|
||||||
|
|
@ -88,10 +119,10 @@ class ConfluenceAdapter(KBAdapter):
|
||||||
"""
|
"""
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
try:
|
try:
|
||||||
cql = f'text ~ "{query}"'
|
cql = f'text ~ "{_escape_cql(query)}"'
|
||||||
if self._space_keys:
|
if self._space_keys:
|
||||||
space_filter = " OR ".join(
|
space_filter = " OR ".join(
|
||||||
f'space = "{key}"' for key in self._space_keys
|
f'space = "{_escape_cql(key)}"' for key in self._space_keys
|
||||||
)
|
)
|
||||||
cql = f'{cql} AND ({space_filter})'
|
cql = f'{cql} AND ({space_filter})'
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,11 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ipaddress
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
|
@ -17,6 +20,28 @@ from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_safe_url(url: str) -> bool:
|
||||||
|
"""Check if URL is safe (not pointing to private/internal networks)."""
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
if parsed.scheme not in ("http", "https"):
|
||||||
|
return False
|
||||||
|
hostname = parsed.hostname
|
||||||
|
if not hostname:
|
||||||
|
return False
|
||||||
|
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
ip = ipaddress.ip_address(hostname)
|
||||||
|
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
|
||||||
|
return False
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class FeishuKBAdapter(KBAdapter):
|
class FeishuKBAdapter(KBAdapter):
|
||||||
"""飞书知识库适配器
|
"""飞书知识库适配器
|
||||||
|
|
||||||
|
|
@ -51,8 +76,11 @@ class FeishuKBAdapter(KBAdapter):
|
||||||
self._app_id = app_id
|
self._app_id = app_id
|
||||||
self._app_secret = app_secret
|
self._app_secret = app_secret
|
||||||
self._base_url = base_url.rstrip("/")
|
self._base_url = base_url.rstrip("/")
|
||||||
|
if not _is_safe_url(self._base_url):
|
||||||
|
raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.")
|
||||||
self._space_ids = space_ids or []
|
self._space_ids = space_ids or []
|
||||||
self._access_token: str | None = None
|
self._access_token: str | None = None
|
||||||
|
self._token_expiry: float = 0.0
|
||||||
|
|
||||||
def _make_client(self) -> httpx.AsyncClient:
|
def _make_client(self) -> httpx.AsyncClient:
|
||||||
"""创建飞书 API HTTP 客户端"""
|
"""创建飞书 API HTTP 客户端"""
|
||||||
|
|
@ -67,7 +95,7 @@ class FeishuKBAdapter(KBAdapter):
|
||||||
|
|
||||||
async def _get_access_token(self) -> str | None:
|
async def _get_access_token(self) -> str | None:
|
||||||
"""获取飞书 tenant_access_token"""
|
"""获取飞书 tenant_access_token"""
|
||||||
if self._access_token:
|
if self._access_token and time.time() < self._token_expiry:
|
||||||
return self._access_token
|
return self._access_token
|
||||||
|
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
|
|
@ -83,6 +111,8 @@ class FeishuKBAdapter(KBAdapter):
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
if data.get("code") == 0:
|
if data.get("code") == 0:
|
||||||
self._access_token = data.get("tenant_access_token")
|
self._access_token = data.get("tenant_access_token")
|
||||||
|
expire_seconds = data.get("expire", 7200)
|
||||||
|
self._token_expiry = time.time() + expire_seconds - 300 # Refresh 5 minutes early
|
||||||
# 重建客户端以携带 token
|
# 重建客户端以携带 token
|
||||||
await self.close()
|
await self.close()
|
||||||
return self._access_token
|
return self._access_token
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,10 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ipaddress
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
|
@ -17,6 +19,31 @@ from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_safe_url(url: str) -> bool:
|
||||||
|
"""Check if URL is safe (not pointing to private/internal networks)."""
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
if parsed.scheme not in ("http", "https"):
|
||||||
|
return False
|
||||||
|
hostname = parsed.hostname
|
||||||
|
if not hostname:
|
||||||
|
return False
|
||||||
|
# Block common internal hostnames
|
||||||
|
if hostname in ("localhost", "metadata.google.internal", "metadata.internal"):
|
||||||
|
return False
|
||||||
|
# Try to resolve as IP and check for private ranges
|
||||||
|
try:
|
||||||
|
ip = ipaddress.ip_address(hostname)
|
||||||
|
if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local:
|
||||||
|
return False
|
||||||
|
except ValueError:
|
||||||
|
# Not an IP address, that's OK (it's a domain name)
|
||||||
|
pass
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class GenericHTTPAdapter(KBAdapter):
|
class GenericHTTPAdapter(KBAdapter):
|
||||||
"""通用 HTTP 知识库适配器
|
"""通用 HTTP 知识库适配器
|
||||||
|
|
||||||
|
|
@ -53,6 +80,8 @@ class GenericHTTPAdapter(KBAdapter):
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
self._endpoint_url = endpoint_url.rstrip("/")
|
self._endpoint_url = endpoint_url.rstrip("/")
|
||||||
|
if not _is_safe_url(self._endpoint_url):
|
||||||
|
raise ValueError(f"Unsafe endpoint_url: {self._endpoint_url}. Private/internal URLs are not allowed.")
|
||||||
self._auth_config = auth_config or {}
|
self._auth_config = auth_config or {}
|
||||||
self._extra_headers = headers or {}
|
self._extra_headers = headers or {}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ from sqlalchemy import text
|
||||||
|
|
||||||
from agentkit.memory.base import Memory, MemoryItem
|
from agentkit.memory.base import Memory, MemoryItem
|
||||||
from agentkit.memory.embedder import Embedder
|
from agentkit.memory.embedder import Embedder
|
||||||
|
from agentkit.utils.vector_math import compute_cosine_similarity
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -123,7 +124,7 @@ class EpisodicMemory(Memory):
|
||||||
if row_embedding is None:
|
if row_embedding is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
cosine = self._compute_cosine_similarity(query_embedding, row_embedding)
|
cosine = compute_cosine_similarity(query_embedding, row_embedding)
|
||||||
if cosine < 0.1:
|
if cosine < 0.1:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -165,7 +166,7 @@ class EpisodicMemory(Memory):
|
||||||
entry_embedding = entry.embedding
|
entry_embedding = entry.embedding
|
||||||
if entry_embedding is None:
|
if entry_embedding is None:
|
||||||
continue
|
continue
|
||||||
cosine = self._compute_cosine_similarity(query_embedding, entry_embedding)
|
cosine = compute_cosine_similarity(query_embedding, entry_embedding)
|
||||||
if cosine > best_score:
|
if cosine > best_score:
|
||||||
best_score = cosine
|
best_score = cosine
|
||||||
best_item = entry
|
best_item = entry
|
||||||
|
|
@ -260,7 +261,7 @@ class EpisodicMemory(Memory):
|
||||||
time_decay_score = (row.get("quality_score") or 0.5) * decay
|
time_decay_score = (row.get("quality_score") or 0.5) * decay
|
||||||
|
|
||||||
if row_embedding is not None:
|
if row_embedding is not None:
|
||||||
cosine_sim = self._compute_cosine_similarity(query_embedding, row_embedding)
|
cosine_sim = compute_cosine_similarity(query_embedding, row_embedding)
|
||||||
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
|
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
|
||||||
else:
|
else:
|
||||||
score = time_decay_score
|
score = time_decay_score
|
||||||
|
|
@ -327,7 +328,7 @@ class EpisodicMemory(Memory):
|
||||||
|
|
||||||
# 混合评分:alpha * cosine + (1 - alpha) * time_decay
|
# 混合评分:alpha * cosine + (1 - alpha) * time_decay
|
||||||
if self._embedder and query_embedding is not None and entry.embedding is not None:
|
if self._embedder and query_embedding is not None and entry.embedding is not None:
|
||||||
cosine_sim = self._compute_cosine_similarity(query_embedding, entry.embedding)
|
cosine_sim = compute_cosine_similarity(query_embedding, entry.embedding)
|
||||||
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
|
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
|
||||||
else:
|
else:
|
||||||
score = time_decay_score
|
score = time_decay_score
|
||||||
|
|
@ -375,20 +376,3 @@ class EpisodicMemory(Memory):
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Failed to delete episodic memory: {e}")
|
logger.error(f"Failed to delete episodic memory: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
|
|
||||||
"""计算两个向量的余弦相似度"""
|
|
||||||
if len(vec_a) != len(vec_b):
|
|
||||||
logger.warning(
|
|
||||||
f"Vector dimension mismatch: {len(vec_a)} vs {len(vec_b)}"
|
|
||||||
)
|
|
||||||
return 0.0
|
|
||||||
if not vec_a:
|
|
||||||
return 0.0
|
|
||||||
dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
|
|
||||||
magnitude_a = sum(a**2 for a in vec_a) ** 0.5
|
|
||||||
magnitude_b = sum(b**2 for b in vec_b) ** 0.5
|
|
||||||
if magnitude_a == 0.0 or magnitude_b == 0.0:
|
|
||||||
return 0.0
|
|
||||||
return dot_product / (magnitude_a * magnitude_b)
|
|
||||||
|
|
|
||||||
|
|
@ -10,10 +10,13 @@ from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
_SAFE_TABLE_NAME_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
|
||||||
|
|
||||||
from agentkit.memory.chunking import Chunk, StructuralChunker, TextChunker
|
from agentkit.memory.chunking import Chunk, StructuralChunker, TextChunker
|
||||||
from agentkit.memory.document_loader import Document as LoaderDocument
|
from agentkit.memory.document_loader import Document as LoaderDocument
|
||||||
from agentkit.memory.embedder import Embedder
|
from agentkit.memory.embedder import Embedder
|
||||||
|
|
@ -23,6 +26,7 @@ from agentkit.memory.knowledge_base import (
|
||||||
QueryResult,
|
QueryResult,
|
||||||
SourceInfo,
|
SourceInfo,
|
||||||
)
|
)
|
||||||
|
from agentkit.utils.vector_math import compute_cosine_similarity
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -70,6 +74,8 @@ class LocalRAGService:
|
||||||
self._chunk_size = chunk_size
|
self._chunk_size = chunk_size
|
||||||
self._chunk_overlap = chunk_overlap
|
self._chunk_overlap = chunk_overlap
|
||||||
self._table_name = table_name
|
self._table_name = table_name
|
||||||
|
if not _SAFE_TABLE_NAME_PATTERN.match(self._table_name):
|
||||||
|
raise ValueError(f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*")
|
||||||
self._pgvector_enabled = pgvector_enabled
|
self._pgvector_enabled = pgvector_enabled
|
||||||
self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||||
self._structural_chunker = StructuralChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
self._structural_chunker = StructuralChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||||
|
|
@ -335,7 +341,7 @@ class LocalRAGService:
|
||||||
except (json.JSONDecodeError, TypeError):
|
except (json.JSONDecodeError, TypeError):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
cosine = self._compute_cosine_similarity(query_embedding, stored_embedding)
|
cosine = compute_cosine_similarity(query_embedding, stored_embedding)
|
||||||
if cosine < 0.1:
|
if cosine < 0.1:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
@ -359,18 +365,6 @@ class LocalRAGService:
|
||||||
candidates.sort(key=lambda x: x.score, reverse=True)
|
candidates.sort(key=lambda x: x.score, reverse=True)
|
||||||
return candidates[:top_k]
|
return candidates[:top_k]
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
|
|
||||||
"""计算两个向量的余弦相似度"""
|
|
||||||
if len(vec_a) != len(vec_b) or not vec_a:
|
|
||||||
return 0.0
|
|
||||||
dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
|
|
||||||
magnitude_a = sum(a**2 for a in vec_a) ** 0.5
|
|
||||||
magnitude_b = sum(b**2 for b in vec_b) ** 0.5
|
|
||||||
if magnitude_a == 0.0 or magnitude_b == 0.0:
|
|
||||||
return 0.0
|
|
||||||
return dot_product / (magnitude_a * magnitude_b)
|
|
||||||
|
|
||||||
|
|
||||||
class InMemoryLocalRAGService:
|
class InMemoryLocalRAGService:
|
||||||
"""基于内存的本地 RAG 服务
|
"""基于内存的本地 RAG 服务
|
||||||
|
|
@ -447,7 +441,7 @@ class InMemoryLocalRAGService:
|
||||||
candidates = []
|
candidates = []
|
||||||
for chunk_id, chunk_data in self._chunks.items():
|
for chunk_id, chunk_data in self._chunks.items():
|
||||||
stored_embedding = chunk_data["embedding"]
|
stored_embedding = chunk_data["embedding"]
|
||||||
cosine = self._compute_cosine_similarity(query_embedding, stored_embedding)
|
cosine = compute_cosine_similarity(query_embedding, stored_embedding)
|
||||||
if cosine < 0.1:
|
if cosine < 0.1:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
@ -511,15 +505,3 @@ class InMemoryLocalRAGService:
|
||||||
source_doc_id=doc.doc_id,
|
source_doc_id=doc.doc_id,
|
||||||
metadata=doc.metadata,
|
metadata=doc.metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
|
|
||||||
"""计算两个向量的余弦相似度"""
|
|
||||||
if len(vec_a) != len(vec_b) or not vec_a:
|
|
||||||
return 0.0
|
|
||||||
dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
|
|
||||||
magnitude_a = sum(a**2 for a in vec_a) ** 0.5
|
|
||||||
magnitude_b = sum(b**2 for b in vec_b) ** 0.5
|
|
||||||
if magnitude_a == 0.0 or magnitude_b == 0.0:
|
|
||||||
return 0.0
|
|
||||||
return dot_product / (magnitude_a * magnitude_b)
|
|
||||||
|
|
|
||||||
|
|
@ -457,6 +457,23 @@ async def list_path_optimizations(
|
||||||
@router.websocket("/evolution-dashboard/ws")
|
@router.websocket("/evolution-dashboard/ws")
|
||||||
async def evolution_dashboard_ws(websocket: WebSocket):
|
async def evolution_dashboard_ws(websocket: WebSocket):
|
||||||
"""自进化仪表盘实时更新 WebSocket"""
|
"""自进化仪表盘实时更新 WebSocket"""
|
||||||
|
# Authentication - check api_key
|
||||||
|
configured_api_key: str | None = None
|
||||||
|
if hasattr(websocket.app.state, "server_config") and websocket.app.state.server_config:
|
||||||
|
configured_api_key = websocket.app.state.server_config.api_key
|
||||||
|
if configured_api_key is None and hasattr(websocket.app.state, "api_key"):
|
||||||
|
configured_api_key = websocket.app.state.api_key
|
||||||
|
|
||||||
|
if configured_api_key:
|
||||||
|
provided = websocket.query_params.get("api_key")
|
||||||
|
if provided != configured_api_key:
|
||||||
|
await websocket.accept()
|
||||||
|
await websocket.send_json(
|
||||||
|
{"type": "error", "data": {"message": "Invalid or missing api_key"}}
|
||||||
|
)
|
||||||
|
await websocket.close(code=4001, reason="Invalid or missing api_key")
|
||||||
|
return
|
||||||
|
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
_ws_connections.append(websocket)
|
_ws_connections.append(websocket)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,8 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(tags=["kb-management"])
|
router = APIRouter(tags=["kb-management"])
|
||||||
|
|
||||||
|
MAX_UPLOAD_SIZE = 50 * 1024 * 1024 # 50MB
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# In-memory Knowledge Source Store
|
# In-memory Knowledge Source Store
|
||||||
|
|
@ -183,14 +185,18 @@ async def upload_document(
|
||||||
try:
|
try:
|
||||||
from agentkit.memory.document_loader import DocumentLoader
|
from agentkit.memory.document_loader import DocumentLoader
|
||||||
|
|
||||||
content = await file.read()
|
content = await file.read(MAX_UPLOAD_SIZE + 1)
|
||||||
|
if len(content) > MAX_UPLOAD_SIZE:
|
||||||
|
raise HTTPException(status_code=413, detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE // (1024*1024)}MB")
|
||||||
loader = DocumentLoader()
|
loader = DocumentLoader()
|
||||||
doc = loader.load_bytes(content, file.filename)
|
doc = loader.load_bytes(content, file.filename)
|
||||||
# Estimate chunks based on content length (rough approximation)
|
# Estimate chunks based on content length (rough approximation)
|
||||||
chunks = max(1, len(doc.content) // 500)
|
chunks = max(1, len(doc.content) // 500)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# DocumentLoader not available, use basic estimation
|
# DocumentLoader not available, use basic estimation
|
||||||
content = await file.read()
|
content = await file.read(MAX_UPLOAD_SIZE + 1)
|
||||||
|
if len(content) > MAX_UPLOAD_SIZE:
|
||||||
|
raise HTTPException(status_code=413, detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE // (1024*1024)}MB")
|
||||||
chunks = max(1, len(content) // 500)
|
chunks = max(1, len(content) // 500)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Document parsing failed: {e}")
|
logger.warning(f"Document parsing failed: {e}")
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
@ -39,6 +40,7 @@ class WorkflowStore:
|
||||||
self._executions: dict[str, WorkflowExecution] = {}
|
self._executions: dict[str, WorkflowExecution] = {}
|
||||||
self._max_workflows = max_workflows
|
self._max_workflows = max_workflows
|
||||||
self._max_executions = max_executions
|
self._max_executions = max_executions
|
||||||
|
self._approval_events: dict[str, asyncio.Event] = {} # key: f"{execution_id}:{stage_name}"
|
||||||
|
|
||||||
def save(self, workflow: WorkflowDefinition) -> WorkflowDefinition:
|
def save(self, workflow: WorkflowDefinition) -> WorkflowDefinition:
|
||||||
workflow.updated_at = datetime.now(timezone.utc).isoformat()
|
workflow.updated_at = datetime.now(timezone.utc).isoformat()
|
||||||
|
|
@ -226,31 +228,70 @@ async def _execute_workflow(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if stage.type == "approval":
|
if stage.type == "approval":
|
||||||
# Pause execution and wait for approval
|
# Pause execution and wait for approval via asyncio.Event
|
||||||
|
event_key = f"{execution.execution_id}:{stage_name}"
|
||||||
|
approval_event = asyncio.Event()
|
||||||
|
_store._approval_events[event_key] = approval_event
|
||||||
|
|
||||||
execution.status = "paused"
|
execution.status = "paused"
|
||||||
|
execution.current_stage = stage_name
|
||||||
_store.update_execution(
|
_store.update_execution(
|
||||||
execution.execution_id,
|
execution.execution_id,
|
||||||
status="paused",
|
status="paused",
|
||||||
|
current_stage=stage_name,
|
||||||
)
|
)
|
||||||
await _broadcast_ws({
|
await _broadcast_ws({
|
||||||
"event": "approval_required",
|
"event": "approval_required",
|
||||||
"execution_id": execution.execution_id,
|
"execution_id": execution.execution_id,
|
||||||
"stage": stage_name,
|
"stage": stage_name,
|
||||||
})
|
})
|
||||||
# In a real implementation, this would wait for external approval
|
|
||||||
# For now, we simulate auto-approval after a brief pause
|
# Wait for approval with timeout
|
||||||
await asyncio.sleep(0.1)
|
try:
|
||||||
execution.stage_results[stage_name] = {
|
approval_timeout = stage.config.get("approval_timeout", 3600)
|
||||||
"status": "approved",
|
await asyncio.wait_for(approval_event.wait(), timeout=approval_timeout)
|
||||||
"approver": "auto",
|
# Check if execution was cancelled/rejected while waiting
|
||||||
"comment": "自动审批通过",
|
if execution.status == "cancelled":
|
||||||
}
|
await _broadcast_ws({
|
||||||
execution.status = "running"
|
"event": "stage_failed",
|
||||||
_store.update_execution(
|
"execution_id": execution.execution_id,
|
||||||
execution.execution_id,
|
"stage": stage_name,
|
||||||
status="running",
|
"error": "Approval rejected",
|
||||||
stage_results=execution.stage_results,
|
})
|
||||||
)
|
return
|
||||||
|
# Approval was granted — the /approve endpoint already set stage_results
|
||||||
|
# Only update status to running if not already set
|
||||||
|
if execution.status != "running":
|
||||||
|
execution.status = "running"
|
||||||
|
_store.update_execution(
|
||||||
|
execution.execution_id,
|
||||||
|
status="running",
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
execution.stage_results[stage_name] = {
|
||||||
|
"status": "timeout",
|
||||||
|
"approver": "none",
|
||||||
|
"comment": "审批超时",
|
||||||
|
}
|
||||||
|
execution.status = "failed"
|
||||||
|
execution.error = f"Approval timeout for stage {stage_name}"
|
||||||
|
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
||||||
|
_store.update_execution(
|
||||||
|
execution.execution_id,
|
||||||
|
status="failed",
|
||||||
|
error=execution.error,
|
||||||
|
completed_at=execution.completed_at,
|
||||||
|
stage_results=execution.stage_results,
|
||||||
|
)
|
||||||
|
await _broadcast_ws({
|
||||||
|
"event": "stage_failed",
|
||||||
|
"execution_id": execution.execution_id,
|
||||||
|
"stage": stage_name,
|
||||||
|
"error": "Approval timeout",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
finally:
|
||||||
|
_store._approval_events.pop(event_key, None)
|
||||||
elif stage.type == "condition":
|
elif stage.type == "condition":
|
||||||
# Evaluate condition expression
|
# Evaluate condition expression
|
||||||
condition_expr = stage.config.get("expression", "")
|
condition_expr = stage.config.get("expression", "")
|
||||||
|
|
@ -318,23 +359,60 @@ async def _execute_workflow(
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
_SAFE_VAR_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
|
||||||
|
_SAFE_OPERATORS = {"==", "!=", ">", "<", ">=", "<="}
|
||||||
|
|
||||||
|
|
||||||
def _evaluate_condition(expression: str, variables: dict[str, Any]) -> bool:
|
def _evaluate_condition(expression: str, variables: dict[str, Any]) -> bool:
|
||||||
"""Simple condition evaluation."""
|
"""Evaluate a condition expression safely."""
|
||||||
|
expression = expression.strip()
|
||||||
if not expression:
|
if not expression:
|
||||||
return True
|
return True
|
||||||
if "==" in expression:
|
|
||||||
parts = expression.split("==", 1)
|
# Try each operator (longer operators first to avoid partial matches)
|
||||||
left = variables.get(parts[0].strip(), parts[0].strip())
|
for op in sorted(_SAFE_OPERATORS, key=len, reverse=True):
|
||||||
right = parts[1].strip().strip("'\"")
|
if op in expression:
|
||||||
return str(left) == right
|
parts = expression.split(op, 1)
|
||||||
elif "!=" in expression:
|
if len(parts) != 2:
|
||||||
parts = expression.split("!=", 1)
|
continue
|
||||||
left = variables.get(parts[0].strip(), parts[0].strip())
|
left = parts[0].strip()
|
||||||
right = parts[1].strip().strip("'\"")
|
right = parts[1].strip()
|
||||||
return str(left) != right
|
|
||||||
else:
|
# Validate variable names
|
||||||
|
if left and not _SAFE_VAR_PATTERN.match(left):
|
||||||
|
raise ValueError(f"Invalid variable name in condition: {left}")
|
||||||
|
|
||||||
|
left_val = variables.get(left, left)
|
||||||
|
# Strip quotes from right side if present
|
||||||
|
if right.startswith('"') and right.endswith('"'):
|
||||||
|
right_val = right[1:-1]
|
||||||
|
elif right.startswith("'") and right.endswith("'"):
|
||||||
|
right_val = right[1:-1]
|
||||||
|
elif right and _SAFE_VAR_PATTERN.match(right):
|
||||||
|
right_val = variables.get(right, right)
|
||||||
|
else:
|
||||||
|
right_val = right
|
||||||
|
|
||||||
|
# Compare based on operator
|
||||||
|
if op == "==":
|
||||||
|
return str(left_val) == str(right_val)
|
||||||
|
if op == "!=":
|
||||||
|
return str(left_val) != str(right_val)
|
||||||
|
if op == ">":
|
||||||
|
return float(left_val) > float(right_val)
|
||||||
|
if op == "<":
|
||||||
|
return float(left_val) < float(right_val)
|
||||||
|
if op == ">=":
|
||||||
|
return float(left_val) >= float(right_val)
|
||||||
|
if op == "<=":
|
||||||
|
return float(left_val) <= float(right_val)
|
||||||
|
|
||||||
|
# Boolean check for variable existence
|
||||||
|
if _SAFE_VAR_PATTERN.match(expression):
|
||||||
return bool(variables.get(expression))
|
return bool(variables.get(expression))
|
||||||
|
|
||||||
|
raise ValueError(f"Invalid condition expression: {expression}")
|
||||||
|
|
||||||
|
|
||||||
async def _broadcast_ws(message: dict[str, Any]) -> None:
|
async def _broadcast_ws(message: dict[str, Any]) -> None:
|
||||||
"""Broadcast a message to all WebSocket subscribers."""
|
"""Broadcast a message to all WebSocket subscribers."""
|
||||||
|
|
@ -487,12 +565,12 @@ async def approve_execution(
|
||||||
status="running",
|
status="running",
|
||||||
stage_results=execution.stage_results,
|
stage_results=execution.stage_results,
|
||||||
)
|
)
|
||||||
# Resume execution
|
# Resume the waiting execution by setting the approval event
|
||||||
workflow = store.get(execution.workflow_id)
|
stage_name = execution.current_stage
|
||||||
if workflow:
|
if stage_name:
|
||||||
asyncio.create_task(
|
event_key = f"{execution_id}:{stage_name}"
|
||||||
_execute_workflow(workflow, execution, execution.variables, store=store)
|
if event_key in store._approval_events:
|
||||||
)
|
store._approval_events[event_key].set()
|
||||||
else:
|
else:
|
||||||
execution.status = "cancelled"
|
execution.status = "cancelled"
|
||||||
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
execution.completed_at = datetime.now(timezone.utc).isoformat()
|
||||||
|
|
@ -508,6 +586,12 @@ async def approve_execution(
|
||||||
completed_at=execution.completed_at,
|
completed_at=execution.completed_at,
|
||||||
stage_results=execution.stage_results,
|
stage_results=execution.stage_results,
|
||||||
)
|
)
|
||||||
|
# Set the approval event so the waiting coroutine can observe the cancelled state
|
||||||
|
stage_name = execution.current_stage
|
||||||
|
if stage_name:
|
||||||
|
event_key = f"{execution_id}:{stage_name}"
|
||||||
|
if event_key in store._approval_events:
|
||||||
|
store._approval_events[event_key].set()
|
||||||
|
|
||||||
return execution.model_dump()
|
return execution.model_dump()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,8 @@ from dataclasses import dataclass
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# 自动应答规则:(prompt_pattern, response)
|
# 自动应答规则:(prompt_pattern, response)
|
||||||
|
# WARNING: auto_respond is disabled by default for safety.
|
||||||
|
# Enable it only when you explicitly want automatic yes/confirm responses.
|
||||||
_AUTO_RESPOND_RULES: list[tuple[str, str]] = [
|
_AUTO_RESPOND_RULES: list[tuple[str, str]] = [
|
||||||
(r"\[y/N\]\s*$", "y"),
|
(r"\[y/N\]\s*$", "y"),
|
||||||
(r"\[Y/n\]\s*$", "y"),
|
(r"\[Y/n\]\s*$", "y"),
|
||||||
|
|
@ -61,7 +63,7 @@ class PTYSession:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
auto_respond: bool = True,
|
auto_respond: bool = False,
|
||||||
custom_rules: list[tuple[str, str]] | None = None,
|
custom_rules: list[tuple[str, str]] | None = None,
|
||||||
default_timeout: float = 30.0,
|
default_timeout: float = 30.0,
|
||||||
buffer_size: int = 4096,
|
buffer_size: int = 4096,
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,8 @@ from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
import shlex
|
||||||
import time
|
import time
|
||||||
from typing import Any, Callable, Awaitable
|
from typing import Any, Callable, Awaitable
|
||||||
|
|
||||||
|
|
@ -21,6 +23,8 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# 安全白名单:这些命令前缀不需要确认
|
# 安全白名单:这些命令前缀不需要确认
|
||||||
_SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
|
_SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
|
||||||
|
"cd",
|
||||||
|
"export",
|
||||||
"ls",
|
"ls",
|
||||||
"cat",
|
"cat",
|
||||||
"head",
|
"head",
|
||||||
|
|
@ -48,6 +52,7 @@ _SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
|
||||||
"sort",
|
"sort",
|
||||||
"uniq",
|
"uniq",
|
||||||
"diff",
|
"diff",
|
||||||
|
"sleep",
|
||||||
"git status",
|
"git status",
|
||||||
"git log",
|
"git log",
|
||||||
"git diff",
|
"git diff",
|
||||||
|
|
@ -101,6 +106,9 @@ _DANGEROUS_PATTERNS: tuple[str, ...] = (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_SHELL_OPERATORS = re.compile(r'[|;&]|&&|\|\||\$\(|`')
|
||||||
|
|
||||||
|
|
||||||
class ShellTool(Tool):
|
class ShellTool(Tool):
|
||||||
"""Shell 命令执行工具
|
"""Shell 命令执行工具
|
||||||
|
|
||||||
|
|
@ -364,18 +372,39 @@ class ShellTool(Tool):
|
||||||
"""
|
"""
|
||||||
command_stripped = command.strip()
|
command_stripped = command.strip()
|
||||||
|
|
||||||
# 白名单检查
|
# Check for shell operators that chain commands (always dangerous)
|
||||||
for prefix in _SAFE_COMMAND_PREFIXES:
|
if _SHELL_OPERATORS.search(command_stripped):
|
||||||
if command_stripped.startswith(prefix):
|
return True
|
||||||
return False
|
|
||||||
|
|
||||||
# 危险模式检查
|
# Parse the actual binary being invoked
|
||||||
|
try:
|
||||||
|
tokens = shlex.split(command_stripped)
|
||||||
|
if not tokens:
|
||||||
|
return True
|
||||||
|
binary = os.path.basename(tokens[0])
|
||||||
|
except ValueError:
|
||||||
|
# Unparsable command - treat as dangerous
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Whitelist check: first try full command prefix match, then binary-only match
|
||||||
|
for prefix in _SAFE_COMMAND_PREFIXES:
|
||||||
|
prefix_stripped = prefix.lower().strip()
|
||||||
|
if " " in prefix_stripped:
|
||||||
|
# Compound prefix like "git status" - match against full command
|
||||||
|
if command_stripped.lower().startswith(prefix_stripped):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
# Simple prefix - match against binary name only
|
||||||
|
if binary.lower().startswith(prefix_stripped):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Dangerous pattern check
|
||||||
command_lower = command_stripped.lower()
|
command_lower = command_stripped.lower()
|
||||||
for pattern in _DANGEROUS_PATTERNS:
|
for pattern in _DANGEROUS_PATTERNS:
|
||||||
if pattern in command_lower:
|
if pattern in command_lower:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return True # Unknown commands are dangerous by default
|
||||||
|
|
||||||
async def _request_confirmation(self, command: str) -> bool:
|
async def _request_confirmation(self, command: str) -> bool:
|
||||||
"""请求人工确认危险命令
|
"""请求人工确认危险命令
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,8 @@ from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
import shlex
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
@ -17,6 +19,8 @@ from agentkit.tools.output_parser import OutputParser, ParsedOutput
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_ENV_KEY_PATTERN = re.compile(r'^[A-Za-z_][A-Za-z0-9_]*$')
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CommandRecord:
|
class CommandRecord:
|
||||||
|
|
@ -190,15 +194,13 @@ class TerminalSession:
|
||||||
|
|
||||||
# 注入 cd
|
# 注入 cd
|
||||||
if self._cwd:
|
if self._cwd:
|
||||||
# 使用 shlex.quote 风格的简单转义
|
parts.append(f"cd {shlex.quote(self._cwd)}")
|
||||||
cwd_escaped = self._cwd.replace("'", "'\\''")
|
|
||||||
parts.append(f"cd '{cwd_escaped}'")
|
|
||||||
|
|
||||||
# 注入环境变量
|
# 注入环境变量
|
||||||
for key, value in self._env.items():
|
for key, value in self._env.items():
|
||||||
# 跳过 os.environ 中已有的且值未变的变量,减少命令长度
|
if not _ENV_KEY_PATTERN.match(key):
|
||||||
val_escaped = value.replace("'", "'\\''")
|
continue # Skip invalid env key names
|
||||||
parts.append(f"export {key}='{val_escaped}'")
|
parts.append(f"export {shlex.quote(key)}={shlex.quote(value)}")
|
||||||
|
|
||||||
parts.append(command)
|
parts.append(command)
|
||||||
return " && ".join(parts)
|
return " && ".join(parts)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
"""AgentKit utility modules."""
|
||||||
|
|
@ -0,0 +1,30 @@
|
||||||
|
"""Shared vector math utilities."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
def compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
|
||||||
|
"""Compute cosine similarity between two vectors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vec_a: First vector.
|
||||||
|
vec_b: Second vector.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cosine similarity score between -1 and 1.
|
||||||
|
"""
|
||||||
|
if len(vec_a) != len(vec_b):
|
||||||
|
return 0.0
|
||||||
|
if not vec_a:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
|
||||||
|
norm_a = math.sqrt(sum(a * a for a in vec_a))
|
||||||
|
norm_b = math.sqrt(sum(b * b for b in vec_b))
|
||||||
|
|
||||||
|
if norm_a == 0.0 or norm_b == 0.0:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
return dot_product / (norm_a * norm_b)
|
||||||
|
|
@ -433,12 +433,34 @@ async def test_workflow_with_approval():
|
||||||
assert execution.status == "pending"
|
assert execution.status == "pending"
|
||||||
assert execution.execution_id
|
assert execution.execution_id
|
||||||
|
|
||||||
# 5. Execute workflow (runs in background)
|
# 5. Execute workflow in background (approval stage will wait for event)
|
||||||
from agentkit.server.routes.workflows import _execute_workflow
|
from agentkit.server.routes.workflows import _execute_workflow
|
||||||
|
|
||||||
await _execute_workflow(workflow, execution, variables={}, store=store)
|
# Use a short approval timeout for testing
|
||||||
|
workflow.stages[1].config["approval_timeout"] = 5
|
||||||
|
|
||||||
# 6. Verify execution completed (auto-approval in test mode)
|
async def _approve_after_pause():
|
||||||
|
"""Wait for execution to pause, then approve."""
|
||||||
|
for _ in range(100):
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
updated = store.get_execution(execution.execution_id)
|
||||||
|
if updated and updated.status == "paused":
|
||||||
|
break
|
||||||
|
# Trigger approval
|
||||||
|
event_key = f"{execution.execution_id}:human_review"
|
||||||
|
if event_key in store._approval_events:
|
||||||
|
execution.stage_results["human_review"] = {
|
||||||
|
"status": "approved",
|
||||||
|
"approver": "test_user",
|
||||||
|
"comment": "Auto-approved in test",
|
||||||
|
}
|
||||||
|
store._approval_events[event_key].set()
|
||||||
|
|
||||||
|
approve_task = asyncio.create_task(_approve_after_pause())
|
||||||
|
await _execute_workflow(workflow, execution, variables={}, store=store)
|
||||||
|
await approve_task
|
||||||
|
|
||||||
|
# 6. Verify execution completed
|
||||||
updated = store.get_execution(execution.execution_id)
|
updated = store.get_execution(execution.execution_id)
|
||||||
assert updated is not None
|
assert updated is not None
|
||||||
assert updated.status == "completed"
|
assert updated.status == "completed"
|
||||||
|
|
@ -452,7 +474,7 @@ async def test_workflow_with_approval():
|
||||||
approval_result = updated.stage_results["human_review"]
|
approval_result = updated.stage_results["human_review"]
|
||||||
assert approval_result.get("status") in ("approved", "completed")
|
assert approval_result.get("status") in ("approved", "completed")
|
||||||
|
|
||||||
# 9. Test manual approval flow
|
# 9. Test second workflow with approval
|
||||||
workflow2 = WorkflowDefinition(
|
workflow2 = WorkflowDefinition(
|
||||||
workflow_id="wf-manual-approval",
|
workflow_id="wf-manual-approval",
|
||||||
name="手动审批流程",
|
name="手动审批流程",
|
||||||
|
|
@ -468,6 +490,7 @@ async def test_workflow_with_approval():
|
||||||
agent="reviewer",
|
agent="reviewer",
|
||||||
action="approve",
|
action="approve",
|
||||||
type="approval",
|
type="approval",
|
||||||
|
config={"approval_timeout": 5},
|
||||||
depends_on=["step1"],
|
depends_on=["step1"],
|
||||||
),
|
),
|
||||||
WorkflowStage(
|
WorkflowStage(
|
||||||
|
|
@ -481,28 +504,30 @@ async def test_workflow_with_approval():
|
||||||
)
|
)
|
||||||
store.save(workflow2)
|
store.save(workflow2)
|
||||||
|
|
||||||
# Simulate manual approval via API
|
|
||||||
execution2 = store.create_execution(workflow2.workflow_id)
|
execution2 = store.create_execution(workflow2.workflow_id)
|
||||||
execution2.status = "paused"
|
|
||||||
execution2.current_stage = "approval_step"
|
|
||||||
store.update_execution(
|
|
||||||
execution2.execution_id,
|
|
||||||
status="paused",
|
|
||||||
current_stage="approval_step",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Approve
|
async def _approve2_after_pause():
|
||||||
execution2.stage_results["approval_step"] = {
|
for _ in range(100):
|
||||||
"status": "approved",
|
await asyncio.sleep(0.05)
|
||||||
"approver": "user",
|
updated2 = store.get_execution(execution2.execution_id)
|
||||||
"comment": "LGTM",
|
if updated2 and updated2.status == "paused":
|
||||||
}
|
break
|
||||||
execution2.status = "running"
|
event_key2 = f"{execution2.execution_id}:approval_step"
|
||||||
store.update_execution(
|
if event_key2 in store._approval_events:
|
||||||
execution2.execution_id,
|
execution2.stage_results["approval_step"] = {
|
||||||
status="running",
|
"status": "approved",
|
||||||
stage_results=execution2.stage_results,
|
"approver": "user",
|
||||||
)
|
"comment": "LGTM",
|
||||||
|
}
|
||||||
|
store._approval_events[event_key2].set()
|
||||||
|
|
||||||
|
approve_task2 = asyncio.create_task(_approve2_after_pause())
|
||||||
|
await _execute_workflow(workflow2, execution2, variables={}, store=store)
|
||||||
|
await approve_task2
|
||||||
|
|
||||||
|
updated2 = store.get_execution(execution2.execution_id)
|
||||||
|
assert updated2 is not None
|
||||||
|
assert updated2.status == "completed"
|
||||||
|
|
||||||
# Verify approval was recorded
|
# Verify approval was recorded
|
||||||
paused_exec = store.get_execution(execution2.execution_id)
|
paused_exec = store.get_execution(execution2.execution_id)
|
||||||
|
|
@ -1014,7 +1039,27 @@ async def test_multi_source_rag_with_workflow(local_rag, mock_embedder):
|
||||||
execution = store.create_execution(workflow.workflow_id)
|
execution = store.create_execution(workflow.workflow_id)
|
||||||
from agentkit.server.routes.workflows import _execute_workflow
|
from agentkit.server.routes.workflows import _execute_workflow
|
||||||
|
|
||||||
|
# Set short approval timeout and handle approval
|
||||||
|
workflow.stages[1].config["approval_timeout"] = 5
|
||||||
|
|
||||||
|
async def _approve_kb_review():
|
||||||
|
for _ in range(100):
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
upd = store.get_execution(execution.execution_id)
|
||||||
|
if upd and upd.status == "paused":
|
||||||
|
break
|
||||||
|
event_key = f"{execution.execution_id}:review_findings"
|
||||||
|
if event_key in store._approval_events:
|
||||||
|
execution.stage_results["review_findings"] = {
|
||||||
|
"status": "approved",
|
||||||
|
"approver": "test_user",
|
||||||
|
"comment": "Approved",
|
||||||
|
}
|
||||||
|
store._approval_events[event_key].set()
|
||||||
|
|
||||||
|
approve_task = asyncio.create_task(_approve_kb_review())
|
||||||
await _execute_workflow(workflow, execution, variables={}, store=store)
|
await _execute_workflow(workflow, execution, variables={}, store=store)
|
||||||
|
await approve_task
|
||||||
|
|
||||||
updated = store.get_execution(execution.execution_id)
|
updated = store.get_execution(execution.execution_id)
|
||||||
assert updated.status == "completed"
|
assert updated.status == "completed"
|
||||||
|
|
|
||||||
|
|
@ -9,10 +9,10 @@ import pytest
|
||||||
from agentkit.evolution.experience_schema import EvolutionMetrics, TaskExperience
|
from agentkit.evolution.experience_schema import EvolutionMetrics, TaskExperience
|
||||||
from agentkit.evolution.experience_store import (
|
from agentkit.evolution.experience_store import (
|
||||||
InMemoryExperienceStore,
|
InMemoryExperienceStore,
|
||||||
_compute_cosine_similarity,
|
|
||||||
_parse_time_window,
|
_parse_time_window,
|
||||||
)
|
)
|
||||||
from agentkit.memory.embedder import MockEmbedder
|
from agentkit.memory.embedder import MockEmbedder
|
||||||
|
from agentkit.utils.vector_math import compute_cosine_similarity
|
||||||
|
|
||||||
|
|
||||||
# ── Fixtures ──────────────────────────────────────────────
|
# ── Fixtures ──────────────────────────────────────────────
|
||||||
|
|
@ -136,23 +136,23 @@ class TestEvolutionMetrics:
|
||||||
class TestHelperFunctions:
|
class TestHelperFunctions:
|
||||||
def test_cosine_similarity_identical(self):
|
def test_cosine_similarity_identical(self):
|
||||||
vec = [1.0, 0.0, 0.0]
|
vec = [1.0, 0.0, 0.0]
|
||||||
assert _compute_cosine_similarity(vec, vec) == pytest.approx(1.0)
|
assert compute_cosine_similarity(vec, vec) == pytest.approx(1.0)
|
||||||
|
|
||||||
def test_cosine_similarity_orthogonal(self):
|
def test_cosine_similarity_orthogonal(self):
|
||||||
a = [1.0, 0.0]
|
a = [1.0, 0.0]
|
||||||
b = [0.0, 1.0]
|
b = [0.0, 1.0]
|
||||||
assert _compute_cosine_similarity(a, b) == pytest.approx(0.0)
|
assert compute_cosine_similarity(a, b) == pytest.approx(0.0)
|
||||||
|
|
||||||
def test_cosine_similarity_opposite(self):
|
def test_cosine_similarity_opposite(self):
|
||||||
a = [1.0, 0.0]
|
a = [1.0, 0.0]
|
||||||
b = [-1.0, 0.0]
|
b = [-1.0, 0.0]
|
||||||
assert _compute_cosine_similarity(a, b) == pytest.approx(-1.0)
|
assert compute_cosine_similarity(a, b) == pytest.approx(-1.0)
|
||||||
|
|
||||||
def test_cosine_similarity_empty(self):
|
def test_cosine_similarity_empty(self):
|
||||||
assert _compute_cosine_similarity([], []) == 0.0
|
assert compute_cosine_similarity([], []) == 0.0
|
||||||
|
|
||||||
def test_cosine_similarity_mismatched_dims(self):
|
def test_cosine_similarity_mismatched_dims(self):
|
||||||
assert _compute_cosine_similarity([1.0], [1.0, 2.0]) == 0.0
|
assert compute_cosine_similarity([1.0], [1.0, 2.0]) == 0.0
|
||||||
|
|
||||||
def test_parse_time_window_hours(self):
|
def test_parse_time_window_hours(self):
|
||||||
delta = _parse_time_window("24h")
|
delta = _parse_time_window("24h")
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
"""Tests for KnowledgeBase adapters — 飞书、Confluence、通用 HTTP 适配器"""
|
"""Tests for KnowledgeBase adapters — 飞书、Confluence、通用 HTTP 适配器"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import time
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo, KnowledgeBase
|
from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo, KnowledgeBase
|
||||||
|
|
@ -64,7 +65,7 @@ class TestKnowledgeBaseProtocol:
|
||||||
assert isinstance(adapter2, KnowledgeBase)
|
assert isinstance(adapter2, KnowledgeBase)
|
||||||
|
|
||||||
adapter3 = GenericHTTPAdapter(
|
adapter3 = GenericHTTPAdapter(
|
||||||
endpoint_url="http://localhost:8000/api/kb",
|
endpoint_url="https://example.com/api/kb",
|
||||||
)
|
)
|
||||||
assert isinstance(adapter3, KnowledgeBase)
|
assert isinstance(adapter3, KnowledgeBase)
|
||||||
|
|
||||||
|
|
@ -256,6 +257,8 @@ class TestFeishuKBAdapterSearch:
|
||||||
async def test_search_success(self, adapter):
|
async def test_search_success(self, adapter):
|
||||||
# Mock authentication
|
# Mock authentication
|
||||||
adapter._access_token = "t-xxx"
|
adapter._access_token = "t-xxx"
|
||||||
|
adapter._token_expiry = time.time() + 7200
|
||||||
|
adapter._token_expiry = time.time() + 7200
|
||||||
|
|
||||||
mock_resp = MagicMock()
|
mock_resp = MagicMock()
|
||||||
mock_resp.status_code = 200
|
mock_resp.status_code = 200
|
||||||
|
|
@ -307,6 +310,7 @@ class TestFeishuKBAdapterSearch:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_search_api_error(self, adapter):
|
async def test_search_api_error(self, adapter):
|
||||||
adapter._access_token = "t-xxx"
|
adapter._access_token = "t-xxx"
|
||||||
|
adapter._token_expiry = time.time() + 7200
|
||||||
|
|
||||||
mock_resp = MagicMock()
|
mock_resp = MagicMock()
|
||||||
mock_resp.status_code = 200
|
mock_resp.status_code = 200
|
||||||
|
|
@ -328,6 +332,7 @@ class TestFeishuKBAdapterSearch:
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
adapter._access_token = "t-xxx"
|
adapter._access_token = "t-xxx"
|
||||||
|
adapter._token_expiry = time.time() + 7200
|
||||||
|
|
||||||
mock_resp = MagicMock()
|
mock_resp = MagicMock()
|
||||||
mock_resp.status_code = 500
|
mock_resp.status_code = 500
|
||||||
|
|
@ -380,6 +385,7 @@ class TestFeishuKBAdapterListSources:
|
||||||
async def test_list_sources_success(self):
|
async def test_list_sources_success(self):
|
||||||
adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
|
adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
|
||||||
adapter._access_token = "t-xxx"
|
adapter._access_token = "t-xxx"
|
||||||
|
adapter._token_expiry = time.time() + 7200
|
||||||
|
|
||||||
mock_resp = MagicMock()
|
mock_resp = MagicMock()
|
||||||
mock_resp.status_code = 200
|
mock_resp.status_code = 200
|
||||||
|
|
@ -420,6 +426,7 @@ class TestFeishuKBAdapterGetDocument:
|
||||||
async def test_get_document_success(self):
|
async def test_get_document_success(self):
|
||||||
adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
|
adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
|
||||||
adapter._access_token = "t-xxx"
|
adapter._access_token = "t-xxx"
|
||||||
|
adapter._token_expiry = time.time() + 7200
|
||||||
|
|
||||||
mock_resp = MagicMock()
|
mock_resp = MagicMock()
|
||||||
mock_resp.status_code = 200
|
mock_resp.status_code = 200
|
||||||
|
|
@ -450,6 +457,7 @@ class TestFeishuKBAdapterGetDocument:
|
||||||
async def test_get_document_not_found(self):
|
async def test_get_document_not_found(self):
|
||||||
adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
|
adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
|
||||||
adapter._access_token = "t-xxx"
|
adapter._access_token = "t-xxx"
|
||||||
|
adapter._token_expiry = time.time() + 7200
|
||||||
|
|
||||||
mock_resp = MagicMock()
|
mock_resp = MagicMock()
|
||||||
mock_resp.status_code = 200
|
mock_resp.status_code = 200
|
||||||
|
|
@ -736,23 +744,23 @@ class TestGenericHTTPAdapterInit:
|
||||||
|
|
||||||
def test_basic_init(self):
|
def test_basic_init(self):
|
||||||
adapter = GenericHTTPAdapter(
|
adapter = GenericHTTPAdapter(
|
||||||
endpoint_url="http://localhost:8000/api/kb",
|
endpoint_url="https://example.com/api/kb",
|
||||||
)
|
)
|
||||||
assert adapter._endpoint_url == "http://localhost:8000/api/kb"
|
assert adapter._endpoint_url == "https://example.com/api/kb"
|
||||||
assert adapter._auth_config == {}
|
assert adapter._auth_config == {}
|
||||||
assert adapter._extra_headers == {}
|
assert adapter._extra_headers == {}
|
||||||
assert adapter._source_type == "generic_http"
|
assert adapter._source_type == "generic_http"
|
||||||
|
|
||||||
def test_init_with_auth_bearer(self):
|
def test_init_with_auth_bearer(self):
|
||||||
adapter = GenericHTTPAdapter(
|
adapter = GenericHTTPAdapter(
|
||||||
endpoint_url="http://localhost:8000/api/kb/",
|
endpoint_url="https://example.com/api/kb/",
|
||||||
auth_config={"type": "bearer", "token": "sk-test"},
|
auth_config={"type": "bearer", "token": "sk-test"},
|
||||||
headers={"X-Custom": "value"},
|
headers={"X-Custom": "value"},
|
||||||
source_id="my-kb",
|
source_id="my-kb",
|
||||||
source_name="My KB",
|
source_name="My KB",
|
||||||
timeout=60,
|
timeout=60,
|
||||||
)
|
)
|
||||||
assert adapter._endpoint_url == "http://localhost:8000/api/kb"
|
assert adapter._endpoint_url == "https://example.com/api/kb"
|
||||||
assert adapter._auth_config["type"] == "bearer"
|
assert adapter._auth_config["type"] == "bearer"
|
||||||
assert adapter._extra_headers == {"X-Custom": "value"}
|
assert adapter._extra_headers == {"X-Custom": "value"}
|
||||||
assert adapter._source_id == "my-kb"
|
assert adapter._source_id == "my-kb"
|
||||||
|
|
@ -760,7 +768,7 @@ class TestGenericHTTPAdapterInit:
|
||||||
|
|
||||||
def test_client_bearer_auth_header(self):
|
def test_client_bearer_auth_header(self):
|
||||||
adapter = GenericHTTPAdapter(
|
adapter = GenericHTTPAdapter(
|
||||||
endpoint_url="http://localhost:8000/api/kb",
|
endpoint_url="https://example.com/api/kb",
|
||||||
auth_config={"type": "bearer", "token": "sk-test"},
|
auth_config={"type": "bearer", "token": "sk-test"},
|
||||||
)
|
)
|
||||||
client = adapter._make_client()
|
client = adapter._make_client()
|
||||||
|
|
@ -768,7 +776,7 @@ class TestGenericHTTPAdapterInit:
|
||||||
|
|
||||||
def test_client_basic_auth_header(self):
|
def test_client_basic_auth_header(self):
|
||||||
adapter = GenericHTTPAdapter(
|
adapter = GenericHTTPAdapter(
|
||||||
endpoint_url="http://localhost:8000/api/kb",
|
endpoint_url="https://example.com/api/kb",
|
||||||
auth_config={"type": "basic", "username": "user", "password": "pass"},
|
auth_config={"type": "basic", "username": "user", "password": "pass"},
|
||||||
)
|
)
|
||||||
client = adapter._make_client()
|
client = adapter._make_client()
|
||||||
|
|
@ -777,7 +785,7 @@ class TestGenericHTTPAdapterInit:
|
||||||
|
|
||||||
def test_client_api_key_header(self):
|
def test_client_api_key_header(self):
|
||||||
adapter = GenericHTTPAdapter(
|
adapter = GenericHTTPAdapter(
|
||||||
endpoint_url="http://localhost:8000/api/kb",
|
endpoint_url="https://example.com/api/kb",
|
||||||
auth_config={"type": "api_key", "header_name": "X-API-Key", "api_key": "key123"},
|
auth_config={"type": "api_key", "header_name": "X-API-Key", "api_key": "key123"},
|
||||||
)
|
)
|
||||||
client = adapter._make_client()
|
client = adapter._make_client()
|
||||||
|
|
@ -790,7 +798,7 @@ class TestGenericHTTPAdapterSearch:
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def adapter(self):
|
def adapter(self):
|
||||||
return GenericHTTPAdapter(
|
return GenericHTTPAdapter(
|
||||||
endpoint_url="http://localhost:8000/api/kb",
|
endpoint_url="https://example.com/api/kb",
|
||||||
auth_config={"type": "bearer", "token": "sk-test"},
|
auth_config={"type": "bearer", "token": "sk-test"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -892,7 +900,7 @@ class TestGenericHTTPAdapterIngest:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_ingest_success(self):
|
async def test_ingest_success(self):
|
||||||
adapter = GenericHTTPAdapter(
|
adapter = GenericHTTPAdapter(
|
||||||
endpoint_url="http://localhost:8000/api/kb",
|
endpoint_url="https://example.com/api/kb",
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_resp = MagicMock()
|
mock_resp = MagicMock()
|
||||||
|
|
@ -921,7 +929,7 @@ class TestGenericHTTPAdapterIngest:
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
adapter = GenericHTTPAdapter(
|
adapter = GenericHTTPAdapter(
|
||||||
endpoint_url="http://localhost:8000/api/kb",
|
endpoint_url="https://example.com/api/kb",
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_resp = MagicMock()
|
mock_resp = MagicMock()
|
||||||
|
|
@ -946,7 +954,7 @@ class TestGenericHTTPAdapterDeleteById:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_delete_success(self):
|
async def test_delete_success(self):
|
||||||
adapter = GenericHTTPAdapter(
|
adapter = GenericHTTPAdapter(
|
||||||
endpoint_url="http://localhost:8000/api/kb",
|
endpoint_url="https://example.com/api/kb",
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_resp = MagicMock()
|
mock_resp = MagicMock()
|
||||||
|
|
@ -963,7 +971,7 @@ class TestGenericHTTPAdapterDeleteById:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_delete_not_found(self):
|
async def test_delete_not_found(self):
|
||||||
adapter = GenericHTTPAdapter(
|
adapter = GenericHTTPAdapter(
|
||||||
endpoint_url="http://localhost:8000/api/kb",
|
endpoint_url="https://example.com/api/kb",
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_resp = MagicMock()
|
mock_resp = MagicMock()
|
||||||
|
|
@ -983,7 +991,7 @@ class TestGenericHTTPAdapterGetDocument:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_document_success(self):
|
async def test_get_document_success(self):
|
||||||
adapter = GenericHTTPAdapter(
|
adapter = GenericHTTPAdapter(
|
||||||
endpoint_url="http://localhost:8000/api/kb",
|
endpoint_url="https://example.com/api/kb",
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_resp = MagicMock()
|
mock_resp = MagicMock()
|
||||||
|
|
@ -1011,7 +1019,7 @@ class TestGenericHTTPAdapterGetDocument:
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
adapter = GenericHTTPAdapter(
|
adapter = GenericHTTPAdapter(
|
||||||
endpoint_url="http://localhost:8000/api/kb",
|
endpoint_url="https://example.com/api/kb",
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_resp = MagicMock()
|
mock_resp = MagicMock()
|
||||||
|
|
@ -1034,7 +1042,7 @@ class TestGenericHTTPAdapterHealthCheck:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_health_check_ok(self):
|
async def test_health_check_ok(self):
|
||||||
adapter = GenericHTTPAdapter(
|
adapter = GenericHTTPAdapter(
|
||||||
endpoint_url="http://localhost:8000/api/kb",
|
endpoint_url="https://example.com/api/kb",
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_resp = MagicMock()
|
mock_resp = MagicMock()
|
||||||
|
|
@ -1050,7 +1058,7 @@ class TestGenericHTTPAdapterHealthCheck:
|
||||||
async def test_health_check_fallback_to_root(self):
|
async def test_health_check_fallback_to_root(self):
|
||||||
"""health endpoint 不存在时回退到根路径"""
|
"""health endpoint 不存在时回退到根路径"""
|
||||||
adapter = GenericHTTPAdapter(
|
adapter = GenericHTTPAdapter(
|
||||||
endpoint_url="http://localhost:8000/api/kb",
|
endpoint_url="https://example.com/api/kb",
|
||||||
)
|
)
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
@ -1076,7 +1084,7 @@ class TestGenericHTTPAdapterHealthCheck:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_health_check_connection_error(self):
|
async def test_health_check_connection_error(self):
|
||||||
adapter = GenericHTTPAdapter(
|
adapter = GenericHTTPAdapter(
|
||||||
endpoint_url="http://localhost:8000/api/kb",
|
endpoint_url="https://example.com/api/kb",
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_client = AsyncMock()
|
mock_client = AsyncMock()
|
||||||
|
|
@ -1092,7 +1100,7 @@ class TestGenericHTTPAdapterListSources:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_sources_success(self):
|
async def test_list_sources_success(self):
|
||||||
adapter = GenericHTTPAdapter(
|
adapter = GenericHTTPAdapter(
|
||||||
endpoint_url="http://localhost:8000/api/kb",
|
endpoint_url="https://example.com/api/kb",
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_resp = MagicMock()
|
mock_resp = MagicMock()
|
||||||
|
|
@ -1116,7 +1124,7 @@ class TestGenericHTTPAdapterListSources:
|
||||||
async def test_list_sources_endpoint_not_found(self):
|
async def test_list_sources_endpoint_not_found(self):
|
||||||
"""sources endpoint 不存在时返回默认信息源"""
|
"""sources endpoint 不存在时返回默认信息源"""
|
||||||
adapter = GenericHTTPAdapter(
|
adapter = GenericHTTPAdapter(
|
||||||
endpoint_url="http://localhost:8000/api/kb",
|
endpoint_url="https://example.com/api/kb",
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_client = AsyncMock()
|
mock_client = AsyncMock()
|
||||||
|
|
@ -1146,7 +1154,7 @@ class TestCrossAdapterIntegration:
|
||||||
username="user@test.com",
|
username="user@test.com",
|
||||||
api_token="token",
|
api_token="token",
|
||||||
),
|
),
|
||||||
GenericHTTPAdapter(endpoint_url="http://localhost:8000/api/kb"),
|
GenericHTTPAdapter(endpoint_url="https://example.com/api/kb"),
|
||||||
]
|
]
|
||||||
for adapter in adapters:
|
for adapter in adapters:
|
||||||
assert isinstance(adapter, KnowledgeBase)
|
assert isinstance(adapter, KnowledgeBase)
|
||||||
|
|
@ -1166,7 +1174,7 @@ class TestCrossAdapterIntegration:
|
||||||
username="user@test.com",
|
username="user@test.com",
|
||||||
api_token="token",
|
api_token="token",
|
||||||
),
|
),
|
||||||
GenericHTTPAdapter(endpoint_url="http://localhost:8000/api/kb"),
|
GenericHTTPAdapter(endpoint_url="https://example.com/api/kb"),
|
||||||
]
|
]
|
||||||
for adapter in adapters:
|
for adapter in adapters:
|
||||||
assert hasattr(adapter, "search")
|
assert hasattr(adapter, "search")
|
||||||
|
|
@ -1179,6 +1187,7 @@ class TestCrossAdapterIntegration:
|
||||||
# Feishu
|
# Feishu
|
||||||
feishu = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
|
feishu = FeishuKBAdapter(app_id="cli_test", app_secret="secret")
|
||||||
feishu._access_token = "t-xxx"
|
feishu._access_token = "t-xxx"
|
||||||
|
feishu._token_expiry = time.time() + 7200
|
||||||
mock_resp = MagicMock()
|
mock_resp = MagicMock()
|
||||||
mock_resp.status_code = 200
|
mock_resp.status_code = 200
|
||||||
mock_resp.raise_for_status = MagicMock()
|
mock_resp.raise_for_status = MagicMock()
|
||||||
|
|
@ -1220,7 +1229,7 @@ class TestCrossAdapterIntegration:
|
||||||
assert all(isinstance(r, QueryResult) for r in results)
|
assert all(isinstance(r, QueryResult) for r in results)
|
||||||
|
|
||||||
# GenericHTTP
|
# GenericHTTP
|
||||||
generic = GenericHTTPAdapter(endpoint_url="http://localhost:8000/api/kb")
|
generic = GenericHTTPAdapter(endpoint_url="https://example.com/api/kb")
|
||||||
mock_resp3 = MagicMock()
|
mock_resp3 = MagicMock()
|
||||||
mock_resp3.status_code = 200
|
mock_resp3.status_code = 200
|
||||||
mock_resp3.raise_for_status = MagicMock()
|
mock_resp3.raise_for_status = MagicMock()
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ from sqlalchemy.orm import DeclarativeBase
|
||||||
from agentkit.memory.episodic import EpisodicMemory
|
from agentkit.memory.episodic import EpisodicMemory
|
||||||
from agentkit.memory.base import MemoryItem
|
from agentkit.memory.base import MemoryItem
|
||||||
from agentkit.memory.embedder import MockEmbedder
|
from agentkit.memory.embedder import MockEmbedder
|
||||||
|
from agentkit.utils.vector_math import compute_cosine_similarity
|
||||||
|
|
||||||
|
|
||||||
# ── 真实 SQLAlchemy 模型(用于测试) ─────────────────────
|
# ── 真实 SQLAlchemy 模型(用于测试) ─────────────────────
|
||||||
|
|
@ -112,40 +113,40 @@ def _make_row_mapping(data: dict) -> _RowMapping:
|
||||||
|
|
||||||
|
|
||||||
class TestCosineSimilarity:
|
class TestCosineSimilarity:
|
||||||
"""_compute_cosine_similarity 测试"""
|
"""compute_cosine_similarity 测试"""
|
||||||
|
|
||||||
def test_identical_vectors_return_one(self):
|
def test_identical_vectors_return_one(self):
|
||||||
"""相同向量余弦相似度为 1"""
|
"""相同向量余弦相似度为 1"""
|
||||||
vec = [1.0, 0.0, 0.0]
|
vec = [1.0, 0.0, 0.0]
|
||||||
assert EpisodicMemory._compute_cosine_similarity(vec, vec) == pytest.approx(1.0)
|
assert compute_cosine_similarity(vec, vec) == pytest.approx(1.0)
|
||||||
|
|
||||||
def test_orthogonal_vectors_return_zero(self):
|
def test_orthogonal_vectors_return_zero(self):
|
||||||
"""正交向量余弦相似度为 0"""
|
"""正交向量余弦相似度为 0"""
|
||||||
vec_a = [1.0, 0.0]
|
vec_a = [1.0, 0.0]
|
||||||
vec_b = [0.0, 1.0]
|
vec_b = [0.0, 1.0]
|
||||||
assert EpisodicMemory._compute_cosine_similarity(vec_a, vec_b) == pytest.approx(0.0)
|
assert compute_cosine_similarity(vec_a, vec_b) == pytest.approx(0.0)
|
||||||
|
|
||||||
def test_opposite_vectors_return_minus_one(self):
|
def test_opposite_vectors_return_minus_one(self):
|
||||||
"""相反向量余弦相似度为 -1"""
|
"""相反向量余弦相似度为 -1"""
|
||||||
vec_a = [1.0, 0.0]
|
vec_a = [1.0, 0.0]
|
||||||
vec_b = [-1.0, 0.0]
|
vec_b = [-1.0, 0.0]
|
||||||
assert EpisodicMemory._compute_cosine_similarity(vec_a, vec_b) == pytest.approx(-1.0)
|
assert compute_cosine_similarity(vec_a, vec_b) == pytest.approx(-1.0)
|
||||||
|
|
||||||
def test_dimension_mismatch_returns_zero(self):
|
def test_dimension_mismatch_returns_zero(self):
|
||||||
"""维度不匹配返回 0"""
|
"""维度不匹配返回 0"""
|
||||||
vec_a = [1.0, 2.0]
|
vec_a = [1.0, 2.0]
|
||||||
vec_b = [1.0]
|
vec_b = [1.0]
|
||||||
assert EpisodicMemory._compute_cosine_similarity(vec_a, vec_b) == 0.0
|
assert compute_cosine_similarity(vec_a, vec_b) == 0.0
|
||||||
|
|
||||||
def test_empty_vectors_return_zero(self):
|
def test_empty_vectors_return_zero(self):
|
||||||
"""空向量返回 0"""
|
"""空向量返回 0"""
|
||||||
assert EpisodicMemory._compute_cosine_similarity([], []) == 0.0
|
assert compute_cosine_similarity([], []) == 0.0
|
||||||
|
|
||||||
def test_zero_vector_returns_zero(self):
|
def test_zero_vector_returns_zero(self):
|
||||||
"""零向量返回 0"""
|
"""零向量返回 0"""
|
||||||
vec_a = [0.0, 0.0]
|
vec_a = [0.0, 0.0]
|
||||||
vec_b = [1.0, 2.0]
|
vec_b = [1.0, 2.0]
|
||||||
assert EpisodicMemory._compute_cosine_similarity(vec_a, vec_b) == 0.0
|
assert compute_cosine_similarity(vec_a, vec_b) == 0.0
|
||||||
|
|
||||||
|
|
||||||
# ── MockEmbedder 测试 ───────────────────────────────────
|
# ── MockEmbedder 测试 ───────────────────────────────────
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ class TestPTYSessionConstruction:
|
||||||
def test_default_construction(self):
|
def test_default_construction(self):
|
||||||
pty = PTYSession()
|
pty = PTYSession()
|
||||||
assert pty.is_running is False
|
assert pty.is_running is False
|
||||||
assert pty._auto_respond is True
|
assert pty._auto_respond is False
|
||||||
assert pty._default_timeout == 30.0
|
assert pty._default_timeout == 30.0
|
||||||
|
|
||||||
def test_custom_construction(self):
|
def test_custom_construction(self):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue