From 9e9f1314f6aeca0d248e916a52ac03c9ccc79504 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Wed, 10 Jun 2026 07:12:41 +0800 Subject: [PATCH] fix(security): resolve all P0/P1 findings from code review --- src/agentkit/core/plan_executor.py | 2 +- src/agentkit/core/protocol.py | 1 + src/agentkit/evolution/experience_store.py | 27 ++-- src/agentkit/evolution/pitfall_detector.py | 8 +- src/agentkit/memory/adapters/confluence.py | 35 +++- src/agentkit/memory/adapters/feishu.py | 32 +++- src/agentkit/memory/adapters/generic_http.py | 29 ++++ src/agentkit/memory/episodic.py | 26 +-- src/agentkit/memory/local_rag.py | 34 +--- .../server/routes/evolution_dashboard.py | 17 ++ src/agentkit/server/routes/kb_management.py | 10 +- src/agentkit/server/routes/workflows.py | 150 ++++++++++++++---- src/agentkit/tools/pty_session.py | 4 +- src/agentkit/tools/shell.py | 41 ++++- src/agentkit/tools/terminal_session.py | 14 +- src/agentkit/utils/__init__.py | 1 + src/agentkit/utils/vector_math.py | 30 ++++ .../integration/test_goal_driven_scenario.py | 93 ++++++++--- tests/unit/evolution/test_experience_store.py | 12 +- tests/unit/memory/test_adapters.py | 55 ++++--- tests/unit/test_episodic_vector_search.py | 15 +- tests/unit/tools/test_pty_session.py | 2 +- 22 files changed, 457 insertions(+), 181 deletions(-) create mode 100644 src/agentkit/utils/__init__.py create mode 100644 src/agentkit/utils/vector_math.py diff --git a/src/agentkit/core/plan_executor.py b/src/agentkit/core/plan_executor.py index 89f62a8..2ed2a29 100644 --- a/src/agentkit/core/plan_executor.py +++ b/src/agentkit/core/plan_executor.py @@ -496,6 +496,6 @@ class PlanExecutor: # 所有步骤要么完成要么跳过 return TaskStatus.COMPLETED if failed > 0: - return TaskStatus.COMPLETED # 部分成功 + return TaskStatus.PARTIALLY_COMPLETED # 部分成功 return TaskStatus.COMPLETED diff --git a/src/agentkit/core/protocol.py b/src/agentkit/core/protocol.py index 91e76ac..6b5286c 100644 --- a/src/agentkit/core/protocol.py +++ b/src/agentkit/core/protocol.py @@ -13,6 +13,7 @@ class TaskStatus(str, Enum): PENDING = "pending" RUNNING = "running" COMPLETED = "completed" + PARTIALLY_COMPLETED = "partially_completed" FAILED = "failed" CANCELLED = "cancelled" HANDOFF = "handoff" diff --git a/src/agentkit/evolution/experience_store.py b/src/agentkit/evolution/experience_store.py index 8c4d41a..89e32c1 100644 --- a/src/agentkit/evolution/experience_store.py +++ b/src/agentkit/evolution/experience_store.py @@ -12,14 +12,18 @@ from __future__ import annotations import logging import math +import re import uuid from datetime import datetime, timedelta, timezone from typing import Any from sqlalchemy import text +_SAFE_TABLE_NAME_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$') + from agentkit.evolution.experience_schema import EvolutionMetrics, TaskExperience from agentkit.memory.embedder import Embedder +from agentkit.utils.vector_math import compute_cosine_similarity logger = logging.getLogger(__name__) @@ -69,6 +73,8 @@ class ExperienceStore: self._retrieve_limit = retrieve_limit self._pgvector_enabled = pgvector_enabled self._table_name = table_name + if not _SAFE_TABLE_NAME_PATTERN.match(self._table_name): + raise ValueError(f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*") async def record_experience(self, experience: TaskExperience) -> str: """记录任务经验 @@ -193,7 +199,7 @@ class ExperienceStore: time_decay_score = (row.get("success_rate") or 0.5) * decay if row_embedding is not None: - cosine_sim = _compute_cosine_similarity(query_embedding, row_embedding) + cosine_sim = compute_cosine_similarity(query_embedding, row_embedding) score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score else: score = time_decay_score @@ -251,7 +257,7 @@ class ExperienceStore: time_decay_score = (entry.success_rate or 0.5) * decay if self._embedder and query_embedding is not None and entry.embedding is not None: - cosine_sim = _compute_cosine_similarity(query_embedding, entry.embedding) + cosine_sim = compute_cosine_similarity(query_embedding, entry.embedding) score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score else: score = time_decay_score @@ -425,7 +431,7 @@ class InMemoryExperienceStore: time_decay_score = exp.success_rate * decay if query_embedding is not None and exp.embedding is not None: - cosine_sim = _compute_cosine_similarity(query_embedding, exp.embedding) + cosine_sim = compute_cosine_similarity(query_embedding, exp.embedding) score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score else: score = time_decay_score @@ -485,21 +491,6 @@ class InMemoryExperienceStore: # ── 辅助函数 ────────────────────────────────────────────── -def _compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float: - """计算两个向量的余弦相似度""" - if len(vec_a) != len(vec_b): - logger.warning(f"Vector dimension mismatch: {len(vec_a)} vs {len(vec_b)}") - return 0.0 - if not vec_a: - return 0.0 - dot_product = sum(a * b for a, b in zip(vec_a, vec_b)) - magnitude_a = sum(a**2 for a in vec_a) ** 0.5 - magnitude_b = sum(b**2 for b in vec_b) ** 0.5 - if magnitude_a == 0.0 or magnitude_b == 0.0: - return 0.0 - return dot_product / (magnitude_a * magnitude_b) - - def _parse_time_window(window: str) -> timedelta: """解析时间窗口字符串为 timedelta diff --git a/src/agentkit/evolution/pitfall_detector.py b/src/agentkit/evolution/pitfall_detector.py index 87bdfc5..3567aa4 100644 --- a/src/agentkit/evolution/pitfall_detector.py +++ b/src/agentkit/evolution/pitfall_detector.py @@ -207,10 +207,12 @@ class PitfallDetector: if error: s.failure_reasons.append(error) - # 收集优化建议 - if hasattr(exp, "optimization_tips") and exp.optimization_tips: + # 收集优化建议 — only add to steps that are part of this experience + if hasattr(exp, 'optimization_tips') and exp.optimization_tips: + experience_steps = set(exp.steps) if hasattr(exp, 'steps') and exp.steps else set() for step_name, s in stats.items(): - s.optimization_tips.extend(exp.optimization_tips) + if not experience_steps or step_name in experience_steps: + s.optimization_tips.extend(exp.optimization_tips) return stats diff --git a/src/agentkit/memory/adapters/confluence.py b/src/agentkit/memory/adapters/confluence.py index 0bed8da..4858b1c 100644 --- a/src/agentkit/memory/adapters/confluence.py +++ b/src/agentkit/memory/adapters/confluence.py @@ -6,8 +6,10 @@ from __future__ import annotations +import ipaddress import logging from typing import Any +from urllib.parse import urlparse import httpx @@ -17,6 +19,33 @@ from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo logger = logging.getLogger(__name__) +def _escape_cql(value: str) -> str: + """Escape special characters in CQL values.""" + return value.replace("\\", "\\\\").replace('"', '\\"') + + +def _is_safe_url(url: str) -> bool: + """Check if URL is safe (not pointing to private/internal networks).""" + try: + parsed = urlparse(url) + if parsed.scheme not in ("http", "https"): + return False + hostname = parsed.hostname + if not hostname: + return False + if hostname in ("localhost", "metadata.google.internal", "metadata.internal"): + return False + try: + ip = ipaddress.ip_address(hostname) + if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local: + return False + except ValueError: + pass + return True + except Exception: + return False + + class ConfluenceAdapter(KBAdapter): """Confluence 知识库适配器 @@ -49,6 +78,8 @@ class ConfluenceAdapter(KBAdapter): timeout=timeout, ) self._base_url = base_url.rstrip("/") + if not _is_safe_url(self._base_url): + raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.") self._username = username self._api_token = api_token self._space_keys = space_keys or [] @@ -88,10 +119,10 @@ class ConfluenceAdapter(KBAdapter): """ client = self._get_client() try: - cql = f'text ~ "{query}"' + cql = f'text ~ "{_escape_cql(query)}"' if self._space_keys: space_filter = " OR ".join( - f'space = "{key}"' for key in self._space_keys + f'space = "{_escape_cql(key)}"' for key in self._space_keys ) cql = f'{cql} AND ({space_filter})' diff --git a/src/agentkit/memory/adapters/feishu.py b/src/agentkit/memory/adapters/feishu.py index 6672618..214b731 100644 --- a/src/agentkit/memory/adapters/feishu.py +++ b/src/agentkit/memory/adapters/feishu.py @@ -6,8 +6,11 @@ from __future__ import annotations +import ipaddress import logging +import time from typing import Any +from urllib.parse import urlparse import httpx @@ -17,6 +20,28 @@ from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo logger = logging.getLogger(__name__) +def _is_safe_url(url: str) -> bool: + """Check if URL is safe (not pointing to private/internal networks).""" + try: + parsed = urlparse(url) + if parsed.scheme not in ("http", "https"): + return False + hostname = parsed.hostname + if not hostname: + return False + if hostname in ("localhost", "metadata.google.internal", "metadata.internal"): + return False + try: + ip = ipaddress.ip_address(hostname) + if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local: + return False + except ValueError: + pass + return True + except Exception: + return False + + class FeishuKBAdapter(KBAdapter): """飞书知识库适配器 @@ -51,8 +76,11 @@ class FeishuKBAdapter(KBAdapter): self._app_id = app_id self._app_secret = app_secret self._base_url = base_url.rstrip("/") + if not _is_safe_url(self._base_url): + raise ValueError(f"Unsafe base_url: {self._base_url}. Private/internal URLs are not allowed.") self._space_ids = space_ids or [] self._access_token: str | None = None + self._token_expiry: float = 0.0 def _make_client(self) -> httpx.AsyncClient: """创建飞书 API HTTP 客户端""" @@ -67,7 +95,7 @@ class FeishuKBAdapter(KBAdapter): async def _get_access_token(self) -> str | None: """获取飞书 tenant_access_token""" - if self._access_token: + if self._access_token and time.time() < self._token_expiry: return self._access_token client = self._get_client() @@ -83,6 +111,8 @@ class FeishuKBAdapter(KBAdapter): data = resp.json() if data.get("code") == 0: self._access_token = data.get("tenant_access_token") + expire_seconds = data.get("expire", 7200) + self._token_expiry = time.time() + expire_seconds - 300 # Refresh 5 minutes early # 重建客户端以携带 token await self.close() return self._access_token diff --git a/src/agentkit/memory/adapters/generic_http.py b/src/agentkit/memory/adapters/generic_http.py index 7de5826..7528196 100644 --- a/src/agentkit/memory/adapters/generic_http.py +++ b/src/agentkit/memory/adapters/generic_http.py @@ -6,8 +6,10 @@ from __future__ import annotations +import ipaddress import logging from typing import Any +from urllib.parse import urlparse import httpx @@ -17,6 +19,31 @@ from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo logger = logging.getLogger(__name__) +def _is_safe_url(url: str) -> bool: + """Check if URL is safe (not pointing to private/internal networks).""" + try: + parsed = urlparse(url) + if parsed.scheme not in ("http", "https"): + return False + hostname = parsed.hostname + if not hostname: + return False + # Block common internal hostnames + if hostname in ("localhost", "metadata.google.internal", "metadata.internal"): + return False + # Try to resolve as IP and check for private ranges + try: + ip = ipaddress.ip_address(hostname) + if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local: + return False + except ValueError: + # Not an IP address, that's OK (it's a domain name) + pass + return True + except Exception: + return False + + class GenericHTTPAdapter(KBAdapter): """通用 HTTP 知识库适配器 @@ -53,6 +80,8 @@ class GenericHTTPAdapter(KBAdapter): timeout=timeout, ) self._endpoint_url = endpoint_url.rstrip("/") + if not _is_safe_url(self._endpoint_url): + raise ValueError(f"Unsafe endpoint_url: {self._endpoint_url}. Private/internal URLs are not allowed.") self._auth_config = auth_config or {} self._extra_headers = headers or {} diff --git a/src/agentkit/memory/episodic.py b/src/agentkit/memory/episodic.py index 5db5350..f6817e4 100644 --- a/src/agentkit/memory/episodic.py +++ b/src/agentkit/memory/episodic.py @@ -10,6 +10,7 @@ from sqlalchemy import text from agentkit.memory.base import Memory, MemoryItem from agentkit.memory.embedder import Embedder +from agentkit.utils.vector_math import compute_cosine_similarity logger = logging.getLogger(__name__) @@ -123,7 +124,7 @@ class EpisodicMemory(Memory): if row_embedding is None: return None - cosine = self._compute_cosine_similarity(query_embedding, row_embedding) + cosine = compute_cosine_similarity(query_embedding, row_embedding) if cosine < 0.1: return None @@ -165,7 +166,7 @@ class EpisodicMemory(Memory): entry_embedding = entry.embedding if entry_embedding is None: continue - cosine = self._compute_cosine_similarity(query_embedding, entry_embedding) + cosine = compute_cosine_similarity(query_embedding, entry_embedding) if cosine > best_score: best_score = cosine best_item = entry @@ -260,7 +261,7 @@ class EpisodicMemory(Memory): time_decay_score = (row.get("quality_score") or 0.5) * decay if row_embedding is not None: - cosine_sim = self._compute_cosine_similarity(query_embedding, row_embedding) + cosine_sim = compute_cosine_similarity(query_embedding, row_embedding) score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score else: score = time_decay_score @@ -327,7 +328,7 @@ class EpisodicMemory(Memory): # 混合评分:alpha * cosine + (1 - alpha) * time_decay if self._embedder and query_embedding is not None and entry.embedding is not None: - cosine_sim = self._compute_cosine_similarity(query_embedding, entry.embedding) + cosine_sim = compute_cosine_similarity(query_embedding, entry.embedding) score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score else: score = time_decay_score @@ -375,20 +376,3 @@ class EpisodicMemory(Memory): await db.rollback() logger.error(f"Failed to delete episodic memory: {e}") return False - - @staticmethod - def _compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float: - """计算两个向量的余弦相似度""" - if len(vec_a) != len(vec_b): - logger.warning( - f"Vector dimension mismatch: {len(vec_a)} vs {len(vec_b)}" - ) - return 0.0 - if not vec_a: - return 0.0 - dot_product = sum(a * b for a, b in zip(vec_a, vec_b)) - magnitude_a = sum(a**2 for a in vec_a) ** 0.5 - magnitude_b = sum(b**2 for b in vec_b) ** 0.5 - if magnitude_a == 0.0 or magnitude_b == 0.0: - return 0.0 - return dot_product / (magnitude_a * magnitude_b) diff --git a/src/agentkit/memory/local_rag.py b/src/agentkit/memory/local_rag.py index 6e9e2c5..9fbfaf0 100644 --- a/src/agentkit/memory/local_rag.py +++ b/src/agentkit/memory/local_rag.py @@ -10,10 +10,13 @@ from __future__ import annotations import json import logging +import re import uuid from datetime import datetime, timezone from typing import Any +_SAFE_TABLE_NAME_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$') + from agentkit.memory.chunking import Chunk, StructuralChunker, TextChunker from agentkit.memory.document_loader import Document as LoaderDocument from agentkit.memory.embedder import Embedder @@ -23,6 +26,7 @@ from agentkit.memory.knowledge_base import ( QueryResult, SourceInfo, ) +from agentkit.utils.vector_math import compute_cosine_similarity logger = logging.getLogger(__name__) @@ -70,6 +74,8 @@ class LocalRAGService: self._chunk_size = chunk_size self._chunk_overlap = chunk_overlap self._table_name = table_name + if not _SAFE_TABLE_NAME_PATTERN.match(self._table_name): + raise ValueError(f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*") self._pgvector_enabled = pgvector_enabled self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap) self._structural_chunker = StructuralChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap) @@ -335,7 +341,7 @@ class LocalRAGService: except (json.JSONDecodeError, TypeError): continue - cosine = self._compute_cosine_similarity(query_embedding, stored_embedding) + cosine = compute_cosine_similarity(query_embedding, stored_embedding) if cosine < 0.1: continue @@ -359,18 +365,6 @@ class LocalRAGService: candidates.sort(key=lambda x: x.score, reverse=True) return candidates[:top_k] - @staticmethod - def _compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float: - """计算两个向量的余弦相似度""" - if len(vec_a) != len(vec_b) or not vec_a: - return 0.0 - dot_product = sum(a * b for a, b in zip(vec_a, vec_b)) - magnitude_a = sum(a**2 for a in vec_a) ** 0.5 - magnitude_b = sum(b**2 for b in vec_b) ** 0.5 - if magnitude_a == 0.0 or magnitude_b == 0.0: - return 0.0 - return dot_product / (magnitude_a * magnitude_b) - class InMemoryLocalRAGService: """基于内存的本地 RAG 服务 @@ -447,7 +441,7 @@ class InMemoryLocalRAGService: candidates = [] for chunk_id, chunk_data in self._chunks.items(): stored_embedding = chunk_data["embedding"] - cosine = self._compute_cosine_similarity(query_embedding, stored_embedding) + cosine = compute_cosine_similarity(query_embedding, stored_embedding) if cosine < 0.1: continue @@ -511,15 +505,3 @@ class InMemoryLocalRAGService: source_doc_id=doc.doc_id, metadata=doc.metadata, ) - - @staticmethod - def _compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float: - """计算两个向量的余弦相似度""" - if len(vec_a) != len(vec_b) or not vec_a: - return 0.0 - dot_product = sum(a * b for a, b in zip(vec_a, vec_b)) - magnitude_a = sum(a**2 for a in vec_a) ** 0.5 - magnitude_b = sum(b**2 for b in vec_b) ** 0.5 - if magnitude_a == 0.0 or magnitude_b == 0.0: - return 0.0 - return dot_product / (magnitude_a * magnitude_b) diff --git a/src/agentkit/server/routes/evolution_dashboard.py b/src/agentkit/server/routes/evolution_dashboard.py index 4fea39b..bcbd7af 100644 --- a/src/agentkit/server/routes/evolution_dashboard.py +++ b/src/agentkit/server/routes/evolution_dashboard.py @@ -457,6 +457,23 @@ async def list_path_optimizations( @router.websocket("/evolution-dashboard/ws") async def evolution_dashboard_ws(websocket: WebSocket): """自进化仪表盘实时更新 WebSocket""" + # Authentication - check api_key + configured_api_key: str | None = None + if hasattr(websocket.app.state, "server_config") and websocket.app.state.server_config: + configured_api_key = websocket.app.state.server_config.api_key + if configured_api_key is None and hasattr(websocket.app.state, "api_key"): + configured_api_key = websocket.app.state.api_key + + if configured_api_key: + provided = websocket.query_params.get("api_key") + if provided != configured_api_key: + await websocket.accept() + await websocket.send_json( + {"type": "error", "data": {"message": "Invalid or missing api_key"}} + ) + await websocket.close(code=4001, reason="Invalid or missing api_key") + return + await websocket.accept() _ws_connections.append(websocket) diff --git a/src/agentkit/server/routes/kb_management.py b/src/agentkit/server/routes/kb_management.py index d75ac09..9ab760c 100644 --- a/src/agentkit/server/routes/kb_management.py +++ b/src/agentkit/server/routes/kb_management.py @@ -15,6 +15,8 @@ logger = logging.getLogger(__name__) router = APIRouter(tags=["kb-management"]) +MAX_UPLOAD_SIZE = 50 * 1024 * 1024 # 50MB + # --------------------------------------------------------------------------- # In-memory Knowledge Source Store @@ -183,14 +185,18 @@ async def upload_document( try: from agentkit.memory.document_loader import DocumentLoader - content = await file.read() + content = await file.read(MAX_UPLOAD_SIZE + 1) + if len(content) > MAX_UPLOAD_SIZE: + raise HTTPException(status_code=413, detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE // (1024*1024)}MB") loader = DocumentLoader() doc = loader.load_bytes(content, file.filename) # Estimate chunks based on content length (rough approximation) chunks = max(1, len(doc.content) // 500) except ImportError: # DocumentLoader not available, use basic estimation - content = await file.read() + content = await file.read(MAX_UPLOAD_SIZE + 1) + if len(content) > MAX_UPLOAD_SIZE: + raise HTTPException(status_code=413, detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE // (1024*1024)}MB") chunks = max(1, len(content) // 500) except Exception as e: logger.warning(f"Document parsing failed: {e}") diff --git a/src/agentkit/server/routes/workflows.py b/src/agentkit/server/routes/workflows.py index 71e30f4..ac13703 100644 --- a/src/agentkit/server/routes/workflows.py +++ b/src/agentkit/server/routes/workflows.py @@ -5,6 +5,7 @@ from __future__ import annotations import asyncio import json import logging +import re import uuid from datetime import datetime, timezone from typing import Any @@ -39,6 +40,7 @@ class WorkflowStore: self._executions: dict[str, WorkflowExecution] = {} self._max_workflows = max_workflows self._max_executions = max_executions + self._approval_events: dict[str, asyncio.Event] = {} # key: f"{execution_id}:{stage_name}" def save(self, workflow: WorkflowDefinition) -> WorkflowDefinition: workflow.updated_at = datetime.now(timezone.utc).isoformat() @@ -226,31 +228,70 @@ async def _execute_workflow( try: if stage.type == "approval": - # Pause execution and wait for approval + # Pause execution and wait for approval via asyncio.Event + event_key = f"{execution.execution_id}:{stage_name}" + approval_event = asyncio.Event() + _store._approval_events[event_key] = approval_event + execution.status = "paused" + execution.current_stage = stage_name _store.update_execution( execution.execution_id, status="paused", + current_stage=stage_name, ) await _broadcast_ws({ "event": "approval_required", "execution_id": execution.execution_id, "stage": stage_name, }) - # In a real implementation, this would wait for external approval - # For now, we simulate auto-approval after a brief pause - await asyncio.sleep(0.1) - execution.stage_results[stage_name] = { - "status": "approved", - "approver": "auto", - "comment": "自动审批通过", - } - execution.status = "running" - _store.update_execution( - execution.execution_id, - status="running", - stage_results=execution.stage_results, - ) + + # Wait for approval with timeout + try: + approval_timeout = stage.config.get("approval_timeout", 3600) + await asyncio.wait_for(approval_event.wait(), timeout=approval_timeout) + # Check if execution was cancelled/rejected while waiting + if execution.status == "cancelled": + await _broadcast_ws({ + "event": "stage_failed", + "execution_id": execution.execution_id, + "stage": stage_name, + "error": "Approval rejected", + }) + return + # Approval was granted — the /approve endpoint already set stage_results + # Only update status to running if not already set + if execution.status != "running": + execution.status = "running" + _store.update_execution( + execution.execution_id, + status="running", + ) + except asyncio.TimeoutError: + execution.stage_results[stage_name] = { + "status": "timeout", + "approver": "none", + "comment": "审批超时", + } + execution.status = "failed" + execution.error = f"Approval timeout for stage {stage_name}" + execution.completed_at = datetime.now(timezone.utc).isoformat() + _store.update_execution( + execution.execution_id, + status="failed", + error=execution.error, + completed_at=execution.completed_at, + stage_results=execution.stage_results, + ) + await _broadcast_ws({ + "event": "stage_failed", + "execution_id": execution.execution_id, + "stage": stage_name, + "error": "Approval timeout", + }) + return + finally: + _store._approval_events.pop(event_key, None) elif stage.type == "condition": # Evaluate condition expression condition_expr = stage.config.get("expression", "") @@ -318,23 +359,60 @@ async def _execute_workflow( }) +_SAFE_VAR_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$') +_SAFE_OPERATORS = {"==", "!=", ">", "<", ">=", "<="} + + def _evaluate_condition(expression: str, variables: dict[str, Any]) -> bool: - """Simple condition evaluation.""" + """Evaluate a condition expression safely.""" + expression = expression.strip() if not expression: return True - if "==" in expression: - parts = expression.split("==", 1) - left = variables.get(parts[0].strip(), parts[0].strip()) - right = parts[1].strip().strip("'\"") - return str(left) == right - elif "!=" in expression: - parts = expression.split("!=", 1) - left = variables.get(parts[0].strip(), parts[0].strip()) - right = parts[1].strip().strip("'\"") - return str(left) != right - else: + + # Try each operator (longer operators first to avoid partial matches) + for op in sorted(_SAFE_OPERATORS, key=len, reverse=True): + if op in expression: + parts = expression.split(op, 1) + if len(parts) != 2: + continue + left = parts[0].strip() + right = parts[1].strip() + + # Validate variable names + if left and not _SAFE_VAR_PATTERN.match(left): + raise ValueError(f"Invalid variable name in condition: {left}") + + left_val = variables.get(left, left) + # Strip quotes from right side if present + if right.startswith('"') and right.endswith('"'): + right_val = right[1:-1] + elif right.startswith("'") and right.endswith("'"): + right_val = right[1:-1] + elif right and _SAFE_VAR_PATTERN.match(right): + right_val = variables.get(right, right) + else: + right_val = right + + # Compare based on operator + if op == "==": + return str(left_val) == str(right_val) + if op == "!=": + return str(left_val) != str(right_val) + if op == ">": + return float(left_val) > float(right_val) + if op == "<": + return float(left_val) < float(right_val) + if op == ">=": + return float(left_val) >= float(right_val) + if op == "<=": + return float(left_val) <= float(right_val) + + # Boolean check for variable existence + if _SAFE_VAR_PATTERN.match(expression): return bool(variables.get(expression)) + raise ValueError(f"Invalid condition expression: {expression}") + async def _broadcast_ws(message: dict[str, Any]) -> None: """Broadcast a message to all WebSocket subscribers.""" @@ -487,12 +565,12 @@ async def approve_execution( status="running", stage_results=execution.stage_results, ) - # Resume execution - workflow = store.get(execution.workflow_id) - if workflow: - asyncio.create_task( - _execute_workflow(workflow, execution, execution.variables, store=store) - ) + # Resume the waiting execution by setting the approval event + stage_name = execution.current_stage + if stage_name: + event_key = f"{execution_id}:{stage_name}" + if event_key in store._approval_events: + store._approval_events[event_key].set() else: execution.status = "cancelled" execution.completed_at = datetime.now(timezone.utc).isoformat() @@ -508,6 +586,12 @@ async def approve_execution( completed_at=execution.completed_at, stage_results=execution.stage_results, ) + # Set the approval event so the waiting coroutine can observe the cancelled state + stage_name = execution.current_stage + if stage_name: + event_key = f"{execution_id}:{stage_name}" + if event_key in store._approval_events: + store._approval_events[event_key].set() return execution.model_dump() diff --git a/src/agentkit/tools/pty_session.py b/src/agentkit/tools/pty_session.py index fdf10e1..7c41f9b 100644 --- a/src/agentkit/tools/pty_session.py +++ b/src/agentkit/tools/pty_session.py @@ -18,6 +18,8 @@ from dataclasses import dataclass logger = logging.getLogger(__name__) # 自动应答规则:(prompt_pattern, response) +# WARNING: auto_respond is disabled by default for safety. +# Enable it only when you explicitly want automatic yes/confirm responses. _AUTO_RESPOND_RULES: list[tuple[str, str]] = [ (r"\[y/N\]\s*$", "y"), (r"\[Y/n\]\s*$", "y"), @@ -61,7 +63,7 @@ class PTYSession: def __init__( self, - auto_respond: bool = True, + auto_respond: bool = False, custom_rules: list[tuple[str, str]] | None = None, default_timeout: float = 30.0, buffer_size: int = 4096, diff --git a/src/agentkit/tools/shell.py b/src/agentkit/tools/shell.py index 2d9bfdb..35940ea 100644 --- a/src/agentkit/tools/shell.py +++ b/src/agentkit/tools/shell.py @@ -9,6 +9,8 @@ from __future__ import annotations import asyncio import logging import os +import re +import shlex import time from typing import Any, Callable, Awaitable @@ -21,6 +23,8 @@ logger = logging.getLogger(__name__) # 安全白名单:这些命令前缀不需要确认 _SAFE_COMMAND_PREFIXES: tuple[str, ...] = ( + "cd", + "export", "ls", "cat", "head", @@ -48,6 +52,7 @@ _SAFE_COMMAND_PREFIXES: tuple[str, ...] = ( "sort", "uniq", "diff", + "sleep", "git status", "git log", "git diff", @@ -101,6 +106,9 @@ _DANGEROUS_PATTERNS: tuple[str, ...] = ( ) +_SHELL_OPERATORS = re.compile(r'[|;&]|&&|\|\||\$\(|`') + + class ShellTool(Tool): """Shell 命令执行工具 @@ -364,18 +372,39 @@ class ShellTool(Tool): """ command_stripped = command.strip() - # 白名单检查 - for prefix in _SAFE_COMMAND_PREFIXES: - if command_stripped.startswith(prefix): - return False + # Check for shell operators that chain commands (always dangerous) + if _SHELL_OPERATORS.search(command_stripped): + return True - # 危险模式检查 + # Parse the actual binary being invoked + try: + tokens = shlex.split(command_stripped) + if not tokens: + return True + binary = os.path.basename(tokens[0]) + except ValueError: + # Unparsable command - treat as dangerous + return True + + # Whitelist check: first try full command prefix match, then binary-only match + for prefix in _SAFE_COMMAND_PREFIXES: + prefix_stripped = prefix.lower().strip() + if " " in prefix_stripped: + # Compound prefix like "git status" - match against full command + if command_stripped.lower().startswith(prefix_stripped): + return False + else: + # Simple prefix - match against binary name only + if binary.lower().startswith(prefix_stripped): + return False + + # Dangerous pattern check command_lower = command_stripped.lower() for pattern in _DANGEROUS_PATTERNS: if pattern in command_lower: return True - return False + return True # Unknown commands are dangerous by default async def _request_confirmation(self, command: str) -> bool: """请求人工确认危险命令 diff --git a/src/agentkit/tools/terminal_session.py b/src/agentkit/tools/terminal_session.py index 6ab72db..59dcb08 100644 --- a/src/agentkit/tools/terminal_session.py +++ b/src/agentkit/tools/terminal_session.py @@ -9,6 +9,8 @@ from __future__ import annotations import asyncio import logging import os +import re +import shlex import time from dataclasses import dataclass, field from typing import Any @@ -17,6 +19,8 @@ from agentkit.tools.output_parser import OutputParser, ParsedOutput logger = logging.getLogger(__name__) +_ENV_KEY_PATTERN = re.compile(r'^[A-Za-z_][A-Za-z0-9_]*$') + @dataclass class CommandRecord: @@ -190,15 +194,13 @@ class TerminalSession: # 注入 cd if self._cwd: - # 使用 shlex.quote 风格的简单转义 - cwd_escaped = self._cwd.replace("'", "'\\''") - parts.append(f"cd '{cwd_escaped}'") + parts.append(f"cd {shlex.quote(self._cwd)}") # 注入环境变量 for key, value in self._env.items(): - # 跳过 os.environ 中已有的且值未变的变量,减少命令长度 - val_escaped = value.replace("'", "'\\''") - parts.append(f"export {key}='{val_escaped}'") + if not _ENV_KEY_PATTERN.match(key): + continue # Skip invalid env key names + parts.append(f"export {shlex.quote(key)}={shlex.quote(value)}") parts.append(command) return " && ".join(parts) diff --git a/src/agentkit/utils/__init__.py b/src/agentkit/utils/__init__.py new file mode 100644 index 0000000..2257399 --- /dev/null +++ b/src/agentkit/utils/__init__.py @@ -0,0 +1 @@ +"""AgentKit utility modules.""" diff --git a/src/agentkit/utils/vector_math.py b/src/agentkit/utils/vector_math.py new file mode 100644 index 0000000..857cbfc --- /dev/null +++ b/src/agentkit/utils/vector_math.py @@ -0,0 +1,30 @@ +"""Shared vector math utilities.""" + +from __future__ import annotations + +import math + + +def compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float: + """Compute cosine similarity between two vectors. + + Args: + vec_a: First vector. + vec_b: Second vector. + + Returns: + Cosine similarity score between -1 and 1. + """ + if len(vec_a) != len(vec_b): + return 0.0 + if not vec_a: + return 0.0 + + dot_product = sum(a * b for a, b in zip(vec_a, vec_b)) + norm_a = math.sqrt(sum(a * a for a in vec_a)) + norm_b = math.sqrt(sum(b * b for b in vec_b)) + + if norm_a == 0.0 or norm_b == 0.0: + return 0.0 + + return dot_product / (norm_a * norm_b) diff --git a/tests/integration/test_goal_driven_scenario.py b/tests/integration/test_goal_driven_scenario.py index f76d301..7552265 100644 --- a/tests/integration/test_goal_driven_scenario.py +++ b/tests/integration/test_goal_driven_scenario.py @@ -433,12 +433,34 @@ async def test_workflow_with_approval(): assert execution.status == "pending" assert execution.execution_id - # 5. Execute workflow (runs in background) + # 5. Execute workflow in background (approval stage will wait for event) from agentkit.server.routes.workflows import _execute_workflow - await _execute_workflow(workflow, execution, variables={}, store=store) + # Use a short approval timeout for testing + workflow.stages[1].config["approval_timeout"] = 5 - # 6. Verify execution completed (auto-approval in test mode) + async def _approve_after_pause(): + """Wait for execution to pause, then approve.""" + for _ in range(100): + await asyncio.sleep(0.05) + updated = store.get_execution(execution.execution_id) + if updated and updated.status == "paused": + break + # Trigger approval + event_key = f"{execution.execution_id}:human_review" + if event_key in store._approval_events: + execution.stage_results["human_review"] = { + "status": "approved", + "approver": "test_user", + "comment": "Auto-approved in test", + } + store._approval_events[event_key].set() + + approve_task = asyncio.create_task(_approve_after_pause()) + await _execute_workflow(workflow, execution, variables={}, store=store) + await approve_task + + # 6. Verify execution completed updated = store.get_execution(execution.execution_id) assert updated is not None assert updated.status == "completed" @@ -452,7 +474,7 @@ async def test_workflow_with_approval(): approval_result = updated.stage_results["human_review"] assert approval_result.get("status") in ("approved", "completed") - # 9. Test manual approval flow + # 9. Test second workflow with approval workflow2 = WorkflowDefinition( workflow_id="wf-manual-approval", name="手动审批流程", @@ -468,6 +490,7 @@ async def test_workflow_with_approval(): agent="reviewer", action="approve", type="approval", + config={"approval_timeout": 5}, depends_on=["step1"], ), WorkflowStage( @@ -481,28 +504,30 @@ async def test_workflow_with_approval(): ) store.save(workflow2) - # Simulate manual approval via API execution2 = store.create_execution(workflow2.workflow_id) - execution2.status = "paused" - execution2.current_stage = "approval_step" - store.update_execution( - execution2.execution_id, - status="paused", - current_stage="approval_step", - ) - # Approve - execution2.stage_results["approval_step"] = { - "status": "approved", - "approver": "user", - "comment": "LGTM", - } - execution2.status = "running" - store.update_execution( - execution2.execution_id, - status="running", - stage_results=execution2.stage_results, - ) + async def _approve2_after_pause(): + for _ in range(100): + await asyncio.sleep(0.05) + updated2 = store.get_execution(execution2.execution_id) + if updated2 and updated2.status == "paused": + break + event_key2 = f"{execution2.execution_id}:approval_step" + if event_key2 in store._approval_events: + execution2.stage_results["approval_step"] = { + "status": "approved", + "approver": "user", + "comment": "LGTM", + } + store._approval_events[event_key2].set() + + approve_task2 = asyncio.create_task(_approve2_after_pause()) + await _execute_workflow(workflow2, execution2, variables={}, store=store) + await approve_task2 + + updated2 = store.get_execution(execution2.execution_id) + assert updated2 is not None + assert updated2.status == "completed" # Verify approval was recorded paused_exec = store.get_execution(execution2.execution_id) @@ -1014,7 +1039,27 @@ async def test_multi_source_rag_with_workflow(local_rag, mock_embedder): execution = store.create_execution(workflow.workflow_id) from agentkit.server.routes.workflows import _execute_workflow + # Set short approval timeout and handle approval + workflow.stages[1].config["approval_timeout"] = 5 + + async def _approve_kb_review(): + for _ in range(100): + await asyncio.sleep(0.05) + upd = store.get_execution(execution.execution_id) + if upd and upd.status == "paused": + break + event_key = f"{execution.execution_id}:review_findings" + if event_key in store._approval_events: + execution.stage_results["review_findings"] = { + "status": "approved", + "approver": "test_user", + "comment": "Approved", + } + store._approval_events[event_key].set() + + approve_task = asyncio.create_task(_approve_kb_review()) await _execute_workflow(workflow, execution, variables={}, store=store) + await approve_task updated = store.get_execution(execution.execution_id) assert updated.status == "completed" diff --git a/tests/unit/evolution/test_experience_store.py b/tests/unit/evolution/test_experience_store.py index 1587233..d5faf2d 100644 --- a/tests/unit/evolution/test_experience_store.py +++ b/tests/unit/evolution/test_experience_store.py @@ -9,10 +9,10 @@ import pytest from agentkit.evolution.experience_schema import EvolutionMetrics, TaskExperience from agentkit.evolution.experience_store import ( InMemoryExperienceStore, - _compute_cosine_similarity, _parse_time_window, ) from agentkit.memory.embedder import MockEmbedder +from agentkit.utils.vector_math import compute_cosine_similarity # ── Fixtures ────────────────────────────────────────────── @@ -136,23 +136,23 @@ class TestEvolutionMetrics: class TestHelperFunctions: def test_cosine_similarity_identical(self): vec = [1.0, 0.0, 0.0] - assert _compute_cosine_similarity(vec, vec) == pytest.approx(1.0) + assert compute_cosine_similarity(vec, vec) == pytest.approx(1.0) def test_cosine_similarity_orthogonal(self): a = [1.0, 0.0] b = [0.0, 1.0] - assert _compute_cosine_similarity(a, b) == pytest.approx(0.0) + assert compute_cosine_similarity(a, b) == pytest.approx(0.0) def test_cosine_similarity_opposite(self): a = [1.0, 0.0] b = [-1.0, 0.0] - assert _compute_cosine_similarity(a, b) == pytest.approx(-1.0) + assert compute_cosine_similarity(a, b) == pytest.approx(-1.0) def test_cosine_similarity_empty(self): - assert _compute_cosine_similarity([], []) == 0.0 + assert compute_cosine_similarity([], []) == 0.0 def test_cosine_similarity_mismatched_dims(self): - assert _compute_cosine_similarity([1.0], [1.0, 2.0]) == 0.0 + assert compute_cosine_similarity([1.0], [1.0, 2.0]) == 0.0 def test_parse_time_window_hours(self): delta = _parse_time_window("24h") diff --git a/tests/unit/memory/test_adapters.py b/tests/unit/memory/test_adapters.py index ac7ca6b..481dbbc 100644 --- a/tests/unit/memory/test_adapters.py +++ b/tests/unit/memory/test_adapters.py @@ -1,6 +1,7 @@ """Tests for KnowledgeBase adapters — 飞书、Confluence、通用 HTTP 适配器""" import pytest +import time from unittest.mock import AsyncMock, MagicMock, patch from agentkit.memory.knowledge_base import Document, QueryResult, SourceInfo, KnowledgeBase @@ -64,7 +65,7 @@ class TestKnowledgeBaseProtocol: assert isinstance(adapter2, KnowledgeBase) adapter3 = GenericHTTPAdapter( - endpoint_url="http://localhost:8000/api/kb", + endpoint_url="https://example.com/api/kb", ) assert isinstance(adapter3, KnowledgeBase) @@ -256,6 +257,8 @@ class TestFeishuKBAdapterSearch: async def test_search_success(self, adapter): # Mock authentication adapter._access_token = "t-xxx" + adapter._token_expiry = time.time() + 7200 + adapter._token_expiry = time.time() + 7200 mock_resp = MagicMock() mock_resp.status_code = 200 @@ -307,6 +310,7 @@ class TestFeishuKBAdapterSearch: @pytest.mark.asyncio async def test_search_api_error(self, adapter): adapter._access_token = "t-xxx" + adapter._token_expiry = time.time() + 7200 mock_resp = MagicMock() mock_resp.status_code = 200 @@ -328,6 +332,7 @@ class TestFeishuKBAdapterSearch: import httpx adapter._access_token = "t-xxx" + adapter._token_expiry = time.time() + 7200 mock_resp = MagicMock() mock_resp.status_code = 500 @@ -380,6 +385,7 @@ class TestFeishuKBAdapterListSources: async def test_list_sources_success(self): adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret") adapter._access_token = "t-xxx" + adapter._token_expiry = time.time() + 7200 mock_resp = MagicMock() mock_resp.status_code = 200 @@ -420,6 +426,7 @@ class TestFeishuKBAdapterGetDocument: async def test_get_document_success(self): adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret") adapter._access_token = "t-xxx" + adapter._token_expiry = time.time() + 7200 mock_resp = MagicMock() mock_resp.status_code = 200 @@ -450,6 +457,7 @@ class TestFeishuKBAdapterGetDocument: async def test_get_document_not_found(self): adapter = FeishuKBAdapter(app_id="cli_test", app_secret="secret") adapter._access_token = "t-xxx" + adapter._token_expiry = time.time() + 7200 mock_resp = MagicMock() mock_resp.status_code = 200 @@ -736,23 +744,23 @@ class TestGenericHTTPAdapterInit: def test_basic_init(self): adapter = GenericHTTPAdapter( - endpoint_url="http://localhost:8000/api/kb", + endpoint_url="https://example.com/api/kb", ) - assert adapter._endpoint_url == "http://localhost:8000/api/kb" + assert adapter._endpoint_url == "https://example.com/api/kb" assert adapter._auth_config == {} assert adapter._extra_headers == {} assert adapter._source_type == "generic_http" def test_init_with_auth_bearer(self): adapter = GenericHTTPAdapter( - endpoint_url="http://localhost:8000/api/kb/", + endpoint_url="https://example.com/api/kb/", auth_config={"type": "bearer", "token": "sk-test"}, headers={"X-Custom": "value"}, source_id="my-kb", source_name="My KB", timeout=60, ) - assert adapter._endpoint_url == "http://localhost:8000/api/kb" + assert adapter._endpoint_url == "https://example.com/api/kb" assert adapter._auth_config["type"] == "bearer" assert adapter._extra_headers == {"X-Custom": "value"} assert adapter._source_id == "my-kb" @@ -760,7 +768,7 @@ class TestGenericHTTPAdapterInit: def test_client_bearer_auth_header(self): adapter = GenericHTTPAdapter( - endpoint_url="http://localhost:8000/api/kb", + endpoint_url="https://example.com/api/kb", auth_config={"type": "bearer", "token": "sk-test"}, ) client = adapter._make_client() @@ -768,7 +776,7 @@ class TestGenericHTTPAdapterInit: def test_client_basic_auth_header(self): adapter = GenericHTTPAdapter( - endpoint_url="http://localhost:8000/api/kb", + endpoint_url="https://example.com/api/kb", auth_config={"type": "basic", "username": "user", "password": "pass"}, ) client = adapter._make_client() @@ -777,7 +785,7 @@ class TestGenericHTTPAdapterInit: def test_client_api_key_header(self): adapter = GenericHTTPAdapter( - endpoint_url="http://localhost:8000/api/kb", + endpoint_url="https://example.com/api/kb", auth_config={"type": "api_key", "header_name": "X-API-Key", "api_key": "key123"}, ) client = adapter._make_client() @@ -790,7 +798,7 @@ class TestGenericHTTPAdapterSearch: @pytest.fixture def adapter(self): return GenericHTTPAdapter( - endpoint_url="http://localhost:8000/api/kb", + endpoint_url="https://example.com/api/kb", auth_config={"type": "bearer", "token": "sk-test"}, ) @@ -892,7 +900,7 @@ class TestGenericHTTPAdapterIngest: @pytest.mark.asyncio async def test_ingest_success(self): adapter = GenericHTTPAdapter( - endpoint_url="http://localhost:8000/api/kb", + endpoint_url="https://example.com/api/kb", ) mock_resp = MagicMock() @@ -921,7 +929,7 @@ class TestGenericHTTPAdapterIngest: import httpx adapter = GenericHTTPAdapter( - endpoint_url="http://localhost:8000/api/kb", + endpoint_url="https://example.com/api/kb", ) mock_resp = MagicMock() @@ -946,7 +954,7 @@ class TestGenericHTTPAdapterDeleteById: @pytest.mark.asyncio async def test_delete_success(self): adapter = GenericHTTPAdapter( - endpoint_url="http://localhost:8000/api/kb", + endpoint_url="https://example.com/api/kb", ) mock_resp = MagicMock() @@ -963,7 +971,7 @@ class TestGenericHTTPAdapterDeleteById: @pytest.mark.asyncio async def test_delete_not_found(self): adapter = GenericHTTPAdapter( - endpoint_url="http://localhost:8000/api/kb", + endpoint_url="https://example.com/api/kb", ) mock_resp = MagicMock() @@ -983,7 +991,7 @@ class TestGenericHTTPAdapterGetDocument: @pytest.mark.asyncio async def test_get_document_success(self): adapter = GenericHTTPAdapter( - endpoint_url="http://localhost:8000/api/kb", + endpoint_url="https://example.com/api/kb", ) mock_resp = MagicMock() @@ -1011,7 +1019,7 @@ class TestGenericHTTPAdapterGetDocument: import httpx adapter = GenericHTTPAdapter( - endpoint_url="http://localhost:8000/api/kb", + endpoint_url="https://example.com/api/kb", ) mock_resp = MagicMock() @@ -1034,7 +1042,7 @@ class TestGenericHTTPAdapterHealthCheck: @pytest.mark.asyncio async def test_health_check_ok(self): adapter = GenericHTTPAdapter( - endpoint_url="http://localhost:8000/api/kb", + endpoint_url="https://example.com/api/kb", ) mock_resp = MagicMock() @@ -1050,7 +1058,7 @@ class TestGenericHTTPAdapterHealthCheck: async def test_health_check_fallback_to_root(self): """health endpoint 不存在时回退到根路径""" adapter = GenericHTTPAdapter( - endpoint_url="http://localhost:8000/api/kb", + endpoint_url="https://example.com/api/kb", ) import httpx @@ -1076,7 +1084,7 @@ class TestGenericHTTPAdapterHealthCheck: @pytest.mark.asyncio async def test_health_check_connection_error(self): adapter = GenericHTTPAdapter( - endpoint_url="http://localhost:8000/api/kb", + endpoint_url="https://example.com/api/kb", ) mock_client = AsyncMock() @@ -1092,7 +1100,7 @@ class TestGenericHTTPAdapterListSources: @pytest.mark.asyncio async def test_list_sources_success(self): adapter = GenericHTTPAdapter( - endpoint_url="http://localhost:8000/api/kb", + endpoint_url="https://example.com/api/kb", ) mock_resp = MagicMock() @@ -1116,7 +1124,7 @@ class TestGenericHTTPAdapterListSources: async def test_list_sources_endpoint_not_found(self): """sources endpoint 不存在时返回默认信息源""" adapter = GenericHTTPAdapter( - endpoint_url="http://localhost:8000/api/kb", + endpoint_url="https://example.com/api/kb", ) mock_client = AsyncMock() @@ -1146,7 +1154,7 @@ class TestCrossAdapterIntegration: username="user@test.com", api_token="token", ), - GenericHTTPAdapter(endpoint_url="http://localhost:8000/api/kb"), + GenericHTTPAdapter(endpoint_url="https://example.com/api/kb"), ] for adapter in adapters: assert isinstance(adapter, KnowledgeBase) @@ -1166,7 +1174,7 @@ class TestCrossAdapterIntegration: username="user@test.com", api_token="token", ), - GenericHTTPAdapter(endpoint_url="http://localhost:8000/api/kb"), + GenericHTTPAdapter(endpoint_url="https://example.com/api/kb"), ] for adapter in adapters: assert hasattr(adapter, "search") @@ -1179,6 +1187,7 @@ class TestCrossAdapterIntegration: # Feishu feishu = FeishuKBAdapter(app_id="cli_test", app_secret="secret") feishu._access_token = "t-xxx" + feishu._token_expiry = time.time() + 7200 mock_resp = MagicMock() mock_resp.status_code = 200 mock_resp.raise_for_status = MagicMock() @@ -1220,7 +1229,7 @@ class TestCrossAdapterIntegration: assert all(isinstance(r, QueryResult) for r in results) # GenericHTTP - generic = GenericHTTPAdapter(endpoint_url="http://localhost:8000/api/kb") + generic = GenericHTTPAdapter(endpoint_url="https://example.com/api/kb") mock_resp3 = MagicMock() mock_resp3.status_code = 200 mock_resp3.raise_for_status = MagicMock() diff --git a/tests/unit/test_episodic_vector_search.py b/tests/unit/test_episodic_vector_search.py index 2fe4e80..36a28ec 100644 --- a/tests/unit/test_episodic_vector_search.py +++ b/tests/unit/test_episodic_vector_search.py @@ -12,6 +12,7 @@ from sqlalchemy.orm import DeclarativeBase from agentkit.memory.episodic import EpisodicMemory from agentkit.memory.base import MemoryItem from agentkit.memory.embedder import MockEmbedder +from agentkit.utils.vector_math import compute_cosine_similarity # ── 真实 SQLAlchemy 模型(用于测试) ───────────────────── @@ -112,40 +113,40 @@ def _make_row_mapping(data: dict) -> _RowMapping: class TestCosineSimilarity: - """_compute_cosine_similarity 测试""" + """compute_cosine_similarity 测试""" def test_identical_vectors_return_one(self): """相同向量余弦相似度为 1""" vec = [1.0, 0.0, 0.0] - assert EpisodicMemory._compute_cosine_similarity(vec, vec) == pytest.approx(1.0) + assert compute_cosine_similarity(vec, vec) == pytest.approx(1.0) def test_orthogonal_vectors_return_zero(self): """正交向量余弦相似度为 0""" vec_a = [1.0, 0.0] vec_b = [0.0, 1.0] - assert EpisodicMemory._compute_cosine_similarity(vec_a, vec_b) == pytest.approx(0.0) + assert compute_cosine_similarity(vec_a, vec_b) == pytest.approx(0.0) def test_opposite_vectors_return_minus_one(self): """相反向量余弦相似度为 -1""" vec_a = [1.0, 0.0] vec_b = [-1.0, 0.0] - assert EpisodicMemory._compute_cosine_similarity(vec_a, vec_b) == pytest.approx(-1.0) + assert compute_cosine_similarity(vec_a, vec_b) == pytest.approx(-1.0) def test_dimension_mismatch_returns_zero(self): """维度不匹配返回 0""" vec_a = [1.0, 2.0] vec_b = [1.0] - assert EpisodicMemory._compute_cosine_similarity(vec_a, vec_b) == 0.0 + assert compute_cosine_similarity(vec_a, vec_b) == 0.0 def test_empty_vectors_return_zero(self): """空向量返回 0""" - assert EpisodicMemory._compute_cosine_similarity([], []) == 0.0 + assert compute_cosine_similarity([], []) == 0.0 def test_zero_vector_returns_zero(self): """零向量返回 0""" vec_a = [0.0, 0.0] vec_b = [1.0, 2.0] - assert EpisodicMemory._compute_cosine_similarity(vec_a, vec_b) == 0.0 + assert compute_cosine_similarity(vec_a, vec_b) == 0.0 # ── MockEmbedder 测试 ─────────────────────────────────── diff --git a/tests/unit/tools/test_pty_session.py b/tests/unit/tools/test_pty_session.py index ce76382..18f0659 100644 --- a/tests/unit/tools/test_pty_session.py +++ b/tests/unit/tools/test_pty_session.py @@ -25,7 +25,7 @@ class TestPTYSessionConstruction: def test_default_construction(self): pty = PTYSession() assert pty.is_running is False - assert pty._auto_respond is True + assert pty._auto_respond is False assert pty._default_timeout == 30.0 def test_custom_construction(self):