From e3ae2f3a5676c9230ace8d7b9006f43e58ac73b1 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Thu, 25 Jun 2026 12:58:51 +0800 Subject: [PATCH] =?UTF-8?q?feat(rag=5Fplatform):=20U8=20=E2=80=94=20TaskIQ?= =?UTF-8?q?=20async=20task=20integration?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add tasks.py: TaskManager with vectorize/batch_index tasks, per-user concurrency limits, degraded mode (sync execution without broker), WorkerSweeper for timeout detection, error message sanitization Add taskiq>=0.11 and taskiq-redis>=0.5 to pyproject.toml Task parameter schema validation (VectorizeTaskParams, BatchIndexTaskParams) Tests: 41 new tests, 289 total passing --- pyproject.toml | 3 + src/agentkit/rag_platform/tasks.py | 705 +++++++++++++++++++++++++ tests/unit/rag_platform/test_tasks.py | 721 ++++++++++++++++++++++++++ 3 files changed, 1429 insertions(+) create mode 100644 src/agentkit/rag_platform/tasks.py create mode 100644 tests/unit/rag_platform/test_tasks.py diff --git a/pyproject.toml b/pyproject.toml index 54d9552..9f90092 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,9 @@ dependencies = [ "llama-index-embeddings-openai>=0.3", "pgvector>=0.3", "jieba>=0.42", + # 异步任务队列(U8 — 文档向量化) + "taskiq>=0.11", + "taskiq-redis>=0.5", ] [project.scripts] diff --git a/src/agentkit/rag_platform/tasks.py b/src/agentkit/rag_platform/tasks.py new file mode 100644 index 0000000..1db0aca --- /dev/null +++ b/src/agentkit/rag_platform/tasks.py @@ -0,0 +1,705 @@ +"""U8 — TaskIQ 异步任务集成 — 文档向量化与批量索引。 + +包装 TaskIQ broker,提供: +- 任务参数 schema 校验(Pydantic 模型) +- per-user 并发上限 +- 任务状态机(PENDING → RUNNING → COMPLETED | FAILED | CANCELLED) +- Worker liveness & sweeper(worker_heartbeat_at 超时检测) +- 错误消息净化(剥离内部路径/连接串) +- 降级模式:broker 未配置时同步执行 + +Redis 隔离:使用独立 db=1(与 bus/cache 的 db=0 分离),或 key 前缀 ``taskiq:``。 +任务参数在 Redis 中带 TTL(默认 1 小时)。 + +ponytail: TaskIQ broker 为可选依赖 — 通过 TYPE_CHECKING 导入, +未安装时模块仍可加载,TaskManager 退化为同步执行。 +""" + +from 
__future__ import annotations + +import asyncio +import inspect +import logging +import re +import uuid +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING, Any, Protocol + +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from agentkit.core.protocol import TaskStatus +from agentkit.rag_platform.document_processor import ( + DEFAULT_CHUNK_OVERLAP, + DEFAULT_CHUNK_SIZE, + DocumentProcessor, +) +from agentkit.server.task_store import InMemoryTaskStore, TaskRecord + +if TYPE_CHECKING: + from llama_index.core.embeddings import BaseEmbedding + from llama_index.vector_stores.postgres import PGVectorStore + + from agentkit.rag_platform.store import KBStore + +logger = logging.getLogger(__name__) + +# 默认配置 +DEFAULT_MAX_CONCURRENT_PER_USER = 3 +DEFAULT_WORKER_TTL_SECONDS = 300 # worker 心跳超时阈值 +DEFAULT_ERROR_MESSAGE_MAX_LENGTH = 500 +DEFAULT_TASK_TTL_SECONDS = 3600 # 任务参数 TTL + +# Redis 隔离 — 独立 db=1,与 bus/cache 的 db=0 分离 +TASKIQ_REDIS_DB = 1 +TASKIQ_KEY_PREFIX = "taskiq:" + + +async def _maybe_await(result: Any) -> Any: + """统一处理 sync/async 调用结果 — InMemoryTaskStore 方法为 sync,RedisTaskStore 为 async。""" + if inspect.isawaitable(result): + return await result + return result + + +# --------------------------------------------------------------------------- +# 任务参数模型 — schema 校验 +# --------------------------------------------------------------------------- + + +class VectorizeTaskParams(BaseModel): + """向量化任务参数 — 单文档 parse → segment → vectorize → index。""" + + model_config = ConfigDict() + + document_id: str = Field(min_length=1, max_length=128) + kb_id: str = Field(min_length=1, max_length=128) + file_path: str = Field(min_length=1, max_length=1024) + file_type: str = Field(min_length=1, max_length=64) + chunk_size: int = Field(default=DEFAULT_CHUNK_SIZE, ge=50, le=8192) + chunk_overlap: int = Field(default=DEFAULT_CHUNK_OVERLAP, ge=0, le=1024) + user_id: str = Field(min_length=1, max_length=128) + + 
@field_validator("chunk_overlap") + @classmethod + def _overlap_less_than_size(cls, v: int, info: Any) -> int: + """chunk_overlap 必须小于 chunk_size。""" + size = info.data.get("chunk_size", DEFAULT_CHUNK_SIZE) + if v >= size: + raise ValueError(f"chunk_overlap ({v}) must be < chunk_size ({size})") + return v + + +class BatchIndexTaskParams(BaseModel): + """批量索引任务参数 — 多文档并行向量化。""" + + model_config = ConfigDict() + + kb_id: str = Field(min_length=1, max_length=128) + document_ids: list[str] = Field(min_length=1, max_length=100) + user_id: str = Field(min_length=1, max_length=128) + file_paths: dict[str, str] = Field(default_factory=dict) # document_id -> file_path + file_types: dict[str, str] = Field(default_factory=dict) # document_id -> file_type + chunk_size: int = Field(default=DEFAULT_CHUNK_SIZE, ge=50, le=8192) + chunk_overlap: int = Field(default=DEFAULT_CHUNK_OVERLAP, ge=0, le=1024) + + +# --------------------------------------------------------------------------- +# 错误消息净化 +# --------------------------------------------------------------------------- + +# 匹配 Unix/Windows 文件路径 +_PATH_RE = re.compile( + r"(?:/[\w.\-]+)+|" # Unix 路径 /usr/local/... + r"(?:[A-Za-z]:\\[\w.\-]+(?:\\[\w.\-]+)*)" # Windows 路径 C:\Users\... +) +# 匹配 Redis/PostgreSQL 连接串 +_CONN_STR_RE = re.compile( + r"(redis://|postgresql://|postgres://|mysql://)" + r"[^\s\"']+", + re.IGNORECASE, +) +# 匹配环境变量风格的密钥 +_SECRET_RE = re.compile( + r"(api[_-]?key|token|password|secret)[\s:=]+[\w\-]+", + re.IGNORECASE, +) + + +def sanitize_error_message(msg: str, max_length: int = DEFAULT_ERROR_MESSAGE_MAX_LENGTH) -> str: + """净化错误消息 — 剥离内部路径、连接串、密钥。 + + Args: + msg: 原始错误消息 + max_length: 最大长度(截断) + + Returns: + 净化后的错误消息 + """ + if not msg: + return "" + # 连接串优先于路径 — 避免 /localhost:6379/0 被路径正则部分匹配 + cleaned = _CONN_STR_RE.sub("", msg) + cleaned = _PATH_RE.sub("", cleaned) + cleaned = _SECRET_RE.sub(r"\1=", cleaned) + if len(cleaned) > max_length: + cleaned = cleaned[: max_length - 3] + "..." 
+ return cleaned + + +# --------------------------------------------------------------------------- +# Worker liveness & sweeper +# --------------------------------------------------------------------------- + + +class TaskStoreProtocol(Protocol): + """TaskStore 协议 — InMemoryTaskStore / RedisTaskStore 均满足。""" + + async def create( + self, task_id: str, agent_name: str, input_data: dict, skill_name: str | None = None + ) -> TaskRecord: ... + + def get(self, task_id: str) -> TaskRecord | None: ... + + async def update_status( + self, task_id: str, status: TaskStatus, **kwargs: Any + ) -> TaskRecord: ... + + def list_tasks( + self, status: TaskStatus | None = None, limit: int = 100 + ) -> list[TaskRecord]: ... + + +class WorkerSweeper: + """检测超时任务并标记为 failed。 + + 扫描 status=RUNNING 且 worker_heartbeat_at 超过 TTL 的任务, + 将其状态置为 failed,error_message="worker_timeout"。 + + ponytail: 升级路径 — 当前依赖 list_tasks 全量扫描(O(n)), + 生产环境可改为 Redis ZSET 按心跳时间戳排序(O(log n))。 + """ + + def __init__( + self, + task_store: TaskStoreProtocol, + ttl_seconds: int = DEFAULT_WORKER_TTL_SECONDS, + ) -> None: + self._store = task_store + self._ttl = ttl_seconds + + async def sweep(self) -> int: + """扫描超时任务,返回清理数量。""" + now = datetime.now(timezone.utc) + cutoff = now - timedelta(seconds=self._ttl) + cleaned = 0 + + running_tasks = await _maybe_await( + self._store.list_tasks(status=TaskStatus.RUNNING, limit=1000) + ) + + for record in running_tasks: + heartbeat = record.metadata.get("worker_heartbeat_at") + if heartbeat is None: + # 无心跳记录 — 使用 started_at 作为回退 + heartbeat = record.started_at.isoformat() if record.started_at else None + if heartbeat is None: + continue + + try: + hb_time = datetime.fromisoformat(heartbeat) + except (ValueError, TypeError): + continue + + # 确保 timezone-aware + if hb_time.tzinfo is None: + hb_time = hb_time.replace(tzinfo=timezone.utc) + + if hb_time < cutoff: + try: + await _maybe_await( + self._store.update_status( + record.task_id, + TaskStatus.FAILED, + 
error_message="worker_timeout", + completed_at=now, + metadata={**record.metadata, "swept_at": now.isoformat()}, + ) + ) + cleaned += 1 + logger.warning( + "Swept timed-out task %s (heartbeat=%s, cutoff=%s)", + record.task_id, + heartbeat, + cutoff.isoformat(), + ) + except KeyError: + # 任务已被并发删除 + continue + return cleaned + + +# --------------------------------------------------------------------------- +# TaskManager — 任务提交、查询、取消 +# --------------------------------------------------------------------------- + + +class ConcurrencyLimitExceeded(Exception): + """per-user 并发上限超出。""" + + +class TaskNotFoundError(Exception): + """任务不存在。""" + + +class TaskManager: + """异步任务管理器 — 包装 TaskIQ broker。 + + 如果 broker 未配置或 TaskIQ 不可用,任务同步执行(降级模式)。 + + Args: + broker: TaskIQ broker 实例(可选) + task_store: 任务状态存储(默认 InMemoryTaskStore) + max_concurrent_per_user: per-user 并发上限 + task_ttl_seconds: 任务参数 TTL(Redis 模式) + """ + + def __init__( + self, + broker: Any = None, + task_store: TaskStoreProtocol | None = None, + max_concurrent_per_user: int = DEFAULT_MAX_CONCURRENT_PER_USER, + task_ttl_seconds: int = DEFAULT_TASK_TTL_SECONDS, + ) -> None: + self._broker = broker + self._store: TaskStoreProtocol = task_store or InMemoryTaskStore() + self._max_concurrent = max_concurrent_per_user + self._task_ttl = task_ttl_seconds + # user_id -> 当前 RUNNING 任务数 + self._running_counts: dict[str, int] = {} + self._lock = asyncio.Lock() + + @property + def store(self) -> TaskStoreProtocol: + """暴露底层 task store(供 sweeper 使用)。""" + return self._store + + @property + def broker(self) -> Any: + """暴露底层 broker(供 startup/shutdown 集成)。""" + return self._broker + + # ── 任务提交 ──────────────────────────────────────────────── + + async def submit_vectorize( + self, + params: VectorizeTaskParams, + dependencies: dict[str, Any] | None = None, + ) -> str: + """提交向量化任务,返回 task_id。 + + Args: + params: 向量化任务参数(已通过 schema 校验) + dependencies: 运行时依赖(store/vector_store/embed_model), + 降级模式下直接使用;broker 模式下由 worker 注入 + + 
Raises: + ConcurrencyLimitExceeded: per-user 并发上限超出 + """ + await self._check_concurrency(params.user_id) + + task_id = f"vec-{uuid.uuid4()}" + input_data = params.model_dump() + # 依赖对象不序列化 — 仅在降级模式内存中传递 + deps = dependencies or {} + + await _maybe_await( + self._store.create( + task_id=task_id, + agent_name="rag_vectorize", + input_data=input_data, + skill_name="vectorize", + ) + ) + # 记录 user_id 到 metadata 便于按用户查询 + record = await _maybe_await(self._store.get(task_id)) + if record is not None: + record.metadata["user_id"] = params.user_id + + logger.info( + "Submitted vectorize task %s for document=%s user=%s", + task_id, + params.document_id, + params.user_id, + ) + + if self._broker is not None: + # broker 模式 — 异步派发 + await self._dispatch_to_broker("vectorize", input_data, task_id) + else: + # 降级模式 — 同步执行(不阻塞调用方,使用 asyncio.create_task) + asyncio.create_task(self._run_vectorize_degraded(task_id, params, deps)) + + return task_id + + async def submit_batch_index( + self, + params: BatchIndexTaskParams, + dependencies: dict[str, Any] | None = None, + ) -> str: + """提交批量索引任务,返回 task_id。""" + await self._check_concurrency(params.user_id) + + task_id = f"batch-{uuid.uuid4()}" + input_data = params.model_dump() + deps = dependencies or {} + + await _maybe_await( + self._store.create( + task_id=task_id, + agent_name="rag_batch_index", + input_data=input_data, + skill_name="batch_index", + ) + ) + record = await _maybe_await(self._store.get(task_id)) + if record is not None: + record.metadata["user_id"] = params.user_id + + logger.info( + "Submitted batch index task %s for %d documents user=%s", + task_id, + len(params.document_ids), + params.user_id, + ) + + if self._broker is not None: + await self._dispatch_to_broker("batch_index", input_data, task_id) + else: + asyncio.create_task(self._run_batch_index_degraded(task_id, params, deps)) + + return task_id + + # ── 任务查询 ──────────────────────────────────────────────── + + async def get_task_status(self, task_id: str) 
-> TaskRecord | None: + """查询任务状态。""" + return await _maybe_await(self._store.get(task_id)) + + async def list_tasks( + self, + user_id: str | None = None, + status: TaskStatus | None = None, + limit: int = 100, + ) -> list[TaskRecord]: + """查询任务历史 — 可按 user_id 和 status 过滤。""" + tasks = await _maybe_await(self._store.list_tasks(status=status, limit=limit)) + if user_id is not None: + tasks = [t for t in tasks if t.metadata.get("user_id") == user_id] + return tasks + + async def cancel_task(self, task_id: str) -> bool: + """取消任务 — 仅 PENDING/RUNNING 可取消。 + + Returns: + True 如果取消成功 + Raises: + TaskNotFoundError: 任务不存在 + """ + record = await _maybe_await(self._store.get(task_id)) + if record is None: + raise TaskNotFoundError(f"Task {task_id} not found") + + if record.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED): + return False + + await _maybe_await( + self._store.update_status( + task_id, + TaskStatus.CANCELLED, + completed_at=datetime.now(timezone.utc), + ) + ) + # 释放并发计数 + user_id = record.metadata.get("user_id") + if user_id and self._running_counts.get(user_id, 0) > 0: + self._running_counts[user_id] -= 1 + return True + + # ── 心跳更新 ──────────────────────────────────────────────── + + async def update_heartbeat(self, task_id: str) -> None: + """更新 worker 心跳时间戳 — 由 worker 周期性调用。""" + record = await _maybe_await(self._store.get(task_id)) + if record is None: + return + now = datetime.now(timezone.utc).isoformat() + record.metadata["worker_heartbeat_at"] = now + + # ── 内部方法 ──────────────────────────────────────────────── + + async def _check_concurrency(self, user_id: str) -> None: + """检查 per-user 并发上限。""" + async with self._lock: + current = self._running_counts.get(user_id, 0) + if current >= self._max_concurrent: + raise ConcurrencyLimitExceeded( + f"User {user_id} has {current} running tasks (max={self._max_concurrent})" + ) + self._running_counts[user_id] = current + 1 + + async def _release_concurrency(self, user_id: str) -> 
None: + """释放并发计数。""" + async with self._lock: + if self._running_counts.get(user_id, 0) > 0: + self._running_counts[user_id] -= 1 + + async def _dispatch_to_broker(self, task_type: str, params: dict, task_id: str) -> None: + """派发任务到 TaskIQ broker。""" + # broker.kiq 返回 TaskiqTask — 实际执行由 worker 进程完成 + try: + await self._broker.kiq(task_type=task_type, params=params, task_id=task_id) + except Exception as e: + logger.error("Failed to dispatch task %s to broker: %s", task_id, e) + raise + + async def _run_vectorize_degraded( + self, + task_id: str, + params: VectorizeTaskParams, + deps: dict[str, Any], + ) -> None: + """降级模式 — 同步执行向量化任务(在 asyncio 任务中)。""" + now = datetime.now(timezone.utc) + try: + await _maybe_await( + self._store.update_status(task_id, TaskStatus.RUNNING, started_at=now) + ) + await self.update_heartbeat(task_id) + + await run_vectorize_task(params, **deps) + + await _maybe_await( + self._store.update_status( + task_id, + TaskStatus.COMPLETED, + completed_at=datetime.now(timezone.utc), + progress=1.0, + ) + ) + except Exception as e: + sanitized = sanitize_error_message(str(e)) + await _maybe_await( + self._store.update_status( + task_id, + TaskStatus.FAILED, + error_message=sanitized, + completed_at=datetime.now(timezone.utc), + ) + ) + logger.error("Vectorize task %s failed: %s", task_id, sanitized) + finally: + await self._release_concurrency(params.user_id) + + async def _run_batch_index_degraded( + self, + task_id: str, + params: BatchIndexTaskParams, + deps: dict[str, Any], + ) -> None: + """降级模式 — 同步执行批量索引任务。""" + now = datetime.now(timezone.utc) + total = len(params.document_ids) + succeeded: list[str] = [] + failed: list[str] = [] + + try: + await _maybe_await( + self._store.update_status(task_id, TaskStatus.RUNNING, started_at=now) + ) + await self.update_heartbeat(task_id) + + for i, doc_id in enumerate(params.document_ids): + await self.update_heartbeat(task_id) + try: + single_params = VectorizeTaskParams( + document_id=doc_id, + 
kb_id=params.kb_id, + file_path=params.file_paths.get(doc_id, ""), + file_type=params.file_types.get(doc_id, "txt"), + chunk_size=params.chunk_size, + chunk_overlap=params.chunk_overlap, + user_id=params.user_id, + ) + await run_vectorize_task(single_params, **deps) + succeeded.append(doc_id) + except Exception as e: + logger.warning("Batch item %s failed: %s", doc_id, e) + failed.append(doc_id) + + # 更新进度 + progress = (i + 1) / total + await _maybe_await( + self._store.update_status( + task_id, + TaskStatus.RUNNING, + progress=progress, + progress_message=f"Processed {i + 1}/{total}", + ) + ) + + final_status = TaskStatus.COMPLETED if not failed else TaskStatus.PARTIALLY_COMPLETED + await _maybe_await( + self._store.update_status( + task_id, + final_status, + completed_at=datetime.now(timezone.utc), + progress=1.0, + output_data={ + "succeeded": succeeded, + "failed": failed, + "total": total, + }, + ) + ) + except Exception as e: + sanitized = sanitize_error_message(str(e)) + await _maybe_await( + self._store.update_status( + task_id, + TaskStatus.FAILED, + error_message=sanitized, + completed_at=datetime.now(timezone.utc), + ) + ) + logger.error("Batch index task %s failed: %s", task_id, sanitized) + finally: + await self._release_concurrency(params.user_id) + + +# --------------------------------------------------------------------------- +# 任务执行函数 — 由 worker 调用 +# --------------------------------------------------------------------------- + + +async def run_vectorize_task( + params: VectorizeTaskParams, + store: "KBStore", + vector_store: "PGVectorStore", + embed_model: "BaseEmbedding", +) -> None: + """向量化任务 — 由 worker 执行。 + + 封装 DocumentProcessor.process() — parse → segment → vectorize → index。 + 状态转换由 DocumentProcessor 内部管理(pending → parsing → ... 
→ indexed | failed)。 + + Args: + params: 向量化任务参数 + store: KBStore 实例(用于文档状态更新) + vector_store: LlamaIndex PGVectorStore 实例 + embed_model: LlamaIndex embedding 模型 + + Raises: + Exception: 管道任一阶段失败 + """ + processor = DocumentProcessor( + chunk_size=params.chunk_size, + chunk_overlap=params.chunk_overlap, + ) + await processor.process( + params.file_path, + params.file_type, + params.kb_id, + params.document_id, + vector_store, + embed_model, + store, + ) + + +async def run_batch_index_task( + params: BatchIndexTaskParams, + store: "KBStore", + vector_store: "PGVectorStore", + embed_model: "BaseEmbedding", +) -> dict[str, list[str]]: + """批量索引任务 — 顺序处理多个文档。 + + Returns: + {"succeeded": [...], "failed": [...]} + """ + succeeded: list[str] = [] + failed: list[str] = [] + + for doc_id in params.document_ids: + try: + single_params = VectorizeTaskParams( + document_id=doc_id, + kb_id=params.kb_id, + file_path=params.file_paths.get(doc_id, ""), + file_type=params.file_types.get(doc_id, "txt"), + chunk_size=params.chunk_size, + chunk_overlap=params.chunk_overlap, + user_id=params.user_id, + ) + await run_vectorize_task(single_params, store, vector_store, embed_model) + succeeded.append(doc_id) + except Exception as e: + logger.warning("Batch item %s failed: %s", doc_id, e) + failed.append(doc_id) + + return {"succeeded": succeeded, "failed": failed} + + +# --------------------------------------------------------------------------- +# TaskIQ broker 工厂(可选 — 仅 Redis 可用时使用) +# --------------------------------------------------------------------------- + + +def create_broker(redis_url: str) -> Any: + """创建 TaskIQ Redis broker — 独立 db=1 隔离。 + + Args: + redis_url: Redis 连接 URL(db=0 用于 bus/cache,此处强制改为 db=1) + + Returns: + 配置好的 TaskiqBroker 实例 + + Raises: + ImportError: taskiq 未安装 + """ + from urllib.parse import urlparse, urlunparse + + from taskiq_redis import ListQueueBroker, RedisAsyncResultBackend + + from taskiq import TaskiqBroker + + # 强制使用 db=1 — 与 bus/cache 的 db=0 隔离 + 
parsed = urlparse(redis_url) + # 替换 path 为 /1 + new_path = "/1" + isolated_url = urlunparse(parsed._replace(path=new_path)) + + broker = TaskiqBroker(ListQueueBroker(url=isolated_url)) + broker.with_result_backend( + RedisAsyncResultBackend(redis_url=isolated_url, result_ex_time=DEFAULT_TASK_TTL_SECONDS) + ) + + logger.info("TaskIQ broker created with Redis db=1 isolation") + return broker + + +__all__ = [ + "BatchIndexTaskParams", + "ConcurrencyLimitExceeded", + "DEFAULT_MAX_CONCURRENT_PER_USER", + "DEFAULT_WORKER_TTL_SECONDS", + "TaskManager", + "TaskNotFoundError", + "TaskStoreProtocol", + "VectorizeTaskParams", + "WorkerSweeper", + "create_broker", + "run_batch_index_task", + "run_vectorize_task", + "sanitize_error_message", +] diff --git a/tests/unit/rag_platform/test_tasks.py b/tests/unit/rag_platform/test_tasks.py new file mode 100644 index 0000000..2096ab1 --- /dev/null +++ b/tests/unit/rag_platform/test_tasks.py @@ -0,0 +1,721 @@ +"""U8 测试 — TaskIQ 异步任务集成 — 文档向量化。 + +测试场景: +1. 任务参数 schema 校验(Pydantic 模型) +2. 错误消息净化(剥离内部路径/连接串/密钥) +3. 任务提交创建任务记录(PENDING) +4. 任务状态查询(pending → running → completed) +5. 任务历史可查询(按 user_id 过滤) +6. per-user 并发上限强制 +7. 任务取消 +8. Worker sweeper 检测超时任务 +9. 降级模式(无 broker)同步执行 +10. 
任务失败后错误消息净化 +""" + +from __future__ import annotations + +import asyncio +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentkit.core.protocol import TaskStatus +from agentkit.rag_platform.tasks import ( + BatchIndexTaskParams, + ConcurrencyLimitExceeded, + TaskManager, + TaskNotFoundError, + VectorizeTaskParams, + WorkerSweeper, + run_batch_index_task, + run_vectorize_task, + sanitize_error_message, +) +from agentkit.server.task_store import InMemoryTaskStore + + +# --------------------------------------------------------------------------- +# 测试 fixtures +# --------------------------------------------------------------------------- + + +def _make_vectorize_params(**overrides) -> VectorizeTaskParams: + """创建向量化任务参数。""" + defaults = { + "document_id": "doc-001", + "kb_id": "kb-001", + "file_path": "/tmp/test.txt", + "file_type": "txt", + "user_id": "user1", + } + defaults.update(overrides) + return VectorizeTaskParams(**defaults) + + +def _make_batch_params(**overrides) -> BatchIndexTaskParams: + """创建批量索引任务参数。""" + defaults = { + "kb_id": "kb-001", + "document_ids": ["doc-1", "doc-2"], + "user_id": "user1", + "file_paths": {"doc-1": "/tmp/a.txt", "doc-2": "/tmp/b.txt"}, + "file_types": {"doc-1": "txt", "doc-2": "txt"}, + } + defaults.update(overrides) + return BatchIndexTaskParams(**defaults) + + +def _make_mock_dependencies(): + """创建 mock 依赖(store/vector_store/embed_model)。""" + store = MagicMock() + store.update_document_status = AsyncMock() + vector_store = MagicMock() + vector_store.async_add = AsyncMock() + embed_model = MagicMock() + embed_model.aget_text_embedding = AsyncMock(return_value=[0.1] * 1536) + return {"store": store, "vector_store": vector_store, "embed_model": embed_model} + + +# --------------------------------------------------------------------------- +# 1. 
任务参数 schema 校验 +# --------------------------------------------------------------------------- + + +class TestVectorizeTaskParams: + """VectorizeTaskParams schema 校验测试。""" + + def test_valid_params(self): + """合法参数通过校验。""" + params = _make_vectorize_params() + assert params.document_id == "doc-001" + assert params.chunk_size == 512 + assert params.chunk_overlap == 50 + + def test_empty_document_id_rejected(self): + """空 document_id 被拒绝。""" + with pytest.raises(ValueError, match="document_id"): + VectorizeTaskParams( + document_id="", + kb_id="kb-1", + file_path="/tmp/x.txt", + file_type="txt", + user_id="u1", + ) + + def test_chunk_overlap_must_be_less_than_size(self): + """chunk_overlap 必须小于 chunk_size。""" + with pytest.raises(ValueError, match="chunk_overlap"): + VectorizeTaskParams( + document_id="doc-1", + kb_id="kb-1", + file_path="/tmp/x.txt", + file_type="txt", + user_id="u1", + chunk_size=100, + chunk_overlap=100, + ) + + def test_chunk_size_bounds(self): + """chunk_size 必须在 50-8192 之间。""" + with pytest.raises(ValueError): + _make_vectorize_params(chunk_size=10) + + def test_user_id_required(self): + """user_id 必填。""" + with pytest.raises(ValueError): + VectorizeTaskParams( + document_id="doc-1", + kb_id="kb-1", + file_path="/tmp/x.txt", + file_type="txt", + ) + + +class TestBatchIndexTaskParams: + """BatchIndexTaskParams schema 校验测试。""" + + def test_valid_params(self): + """合法参数通过校验。""" + params = _make_batch_params() + assert len(params.document_ids) == 2 + + def test_empty_document_ids_rejected(self): + """空 document_ids 列表被拒绝。""" + with pytest.raises(ValueError, match="document_ids"): + BatchIndexTaskParams( + kb_id="kb-1", + document_ids=[], + user_id="u1", + ) + + def test_too_many_document_ids_rejected(self): + """超过 100 个 document_ids 被拒绝。""" + with pytest.raises(ValueError): + BatchIndexTaskParams( + kb_id="kb-1", + document_ids=[f"doc-{i}" for i in range(101)], + user_id="u1", + ) + + +# 
--------------------------------------------------------------------------- +# 2. 错误消息净化 +# --------------------------------------------------------------------------- + + +class TestSanitizeErrorMessage: + """错误消息净化测试。""" + + def test_strips_unix_paths(self): + """剥离 Unix 文件路径。""" + msg = "Failed to open /usr/local/lib/python/file.py" + cleaned = sanitize_error_message(msg) + assert "/usr/local" not in cleaned + assert "" in cleaned + + def test_strips_windows_paths(self): + """剥离 Windows 文件路径。""" + msg = "Error in C:\\Users\\admin\\file.py at line 10" + cleaned = sanitize_error_message(msg) + assert "C:\\Users" not in cleaned + assert "" in cleaned + + def test_strips_redis_connection_string(self): + """剥离 Redis 连接串。""" + msg = "Cannot connect to redis://localhost:6379/0" + cleaned = sanitize_error_message(msg) + assert "redis://localhost" not in cleaned + assert "" in cleaned + + def test_strips_postgresql_connection_string(self): + """剥离 PostgreSQL 连接串。""" + msg = "DB error: postgresql://user:pass@host:5432/db" + cleaned = sanitize_error_message(msg) + assert "postgresql://" not in cleaned + assert "" in cleaned + + def test_redacts_api_keys(self): + """净化 API key。""" + msg = "Auth failed with api_key=sk-abc123def456" + cleaned = sanitize_error_message(msg) + assert "sk-abc123def456" not in cleaned + assert "" in cleaned + + def test_redacts_password(self): + """净化 password。""" + msg = "Login failed password=secret123" + cleaned = sanitize_error_message(msg) + assert "secret123" not in cleaned + assert "" in cleaned + + def test_truncates_long_messages(self): + """长消息截断。""" + msg = "x" * 1000 + cleaned = sanitize_error_message(msg, max_length=100) + assert len(cleaned) == 100 + assert cleaned.endswith("...") + + def test_empty_message_returns_empty(self): + """空消息返回空字符串。""" + assert sanitize_error_message("") == "" + + +# --------------------------------------------------------------------------- +# 3. 
TaskManager — 任务提交与状态查询 +# --------------------------------------------------------------------------- + + +class TestTaskManagerSubmit: + """任务提交测试。""" + + async def test_submit_vectorize_creates_pending_task(self): + """提交向量化任务创建 PENDING 状态记录。""" + store = InMemoryTaskStore() + manager = TaskManager(broker=None, task_store=store) + params = _make_vectorize_params() + + task_id = await manager.submit_vectorize(params) + + assert task_id.startswith("vec-") + record = await manager.get_task_status(task_id) + assert record is not None + assert record.status == TaskStatus.PENDING + assert record.input_data["document_id"] == "doc-001" + assert record.metadata["user_id"] == "user1" + + async def test_submit_batch_index_creates_pending_task(self): + """提交批量索引任务创建 PENDING 状态记录。""" + store = InMemoryTaskStore() + manager = TaskManager(broker=None, task_store=store) + params = _make_batch_params() + + task_id = await manager.submit_batch_index(params) + + assert task_id.startswith("batch-") + record = await manager.get_task_status(task_id) + assert record is not None + assert record.status == TaskStatus.PENDING + assert len(record.input_data["document_ids"]) == 2 + + async def test_submit_with_broker_dispatches(self): + """有 broker 时调用 broker.kiq 派发任务。""" + store = InMemoryTaskStore() + mock_broker = MagicMock() + mock_broker.kiq = AsyncMock() + manager = TaskManager(broker=mock_broker, task_store=store) + params = _make_vectorize_params() + + task_id = await manager.submit_vectorize(params) + + mock_broker.kiq.assert_awaited_once() + # 验证派发参数包含 task_type 和 task_id + call_kwargs = mock_broker.kiq.call_args.kwargs + assert call_kwargs["task_type"] == "vectorize" + assert call_kwargs["task_id"] == task_id + + +# --------------------------------------------------------------------------- +# 4. 
任务状态查询 +# --------------------------------------------------------------------------- + + +class TestTaskManagerQuery: + """任务状态查询测试。""" + + async def test_get_task_status_returns_none_if_not_found(self): + """查询不存在的任务返回 None。""" + store = InMemoryTaskStore() + manager = TaskManager(broker=None, task_store=store) + + result = await manager.get_task_status("nonexistent") + assert result is None + + async def test_list_tasks_filters_by_user(self): + """list_tasks 按 user_id 过滤。""" + store = InMemoryTaskStore() + manager = TaskManager(broker=None, task_store=store) + + # user1 提交 2 个任务 + await manager.submit_vectorize(_make_vectorize_params(user_id="user1")) + await manager.submit_vectorize(_make_vectorize_params(document_id="doc-2", user_id="user1")) + # user2 提交 1 个任务 + await manager.submit_vectorize(_make_vectorize_params(document_id="doc-3", user_id="user2")) + + user1_tasks = await manager.list_tasks(user_id="user1") + user2_tasks = await manager.list_tasks(user_id="user2") + all_tasks = await manager.list_tasks() + + assert len(user1_tasks) == 2 + assert len(user2_tasks) == 1 + assert len(all_tasks) == 3 + + async def test_list_tasks_filters_by_status(self): + """list_tasks 按 status 过滤。""" + store = InMemoryTaskStore() + manager = TaskManager(broker=None, task_store=store) + + task_id = await manager.submit_vectorize(_make_vectorize_params()) + # 手动置为 COMPLETED — InMemoryTaskStore.update_status 为 sync + store.update_status(task_id, TaskStatus.COMPLETED) + # 再提交一个 PENDING + await manager.submit_vectorize(_make_vectorize_params(document_id="doc-2")) + + pending = await manager.list_tasks(status=TaskStatus.PENDING) + completed = await manager.list_tasks(status=TaskStatus.COMPLETED) + + assert len(pending) == 1 + assert len(completed) == 1 + + +# --------------------------------------------------------------------------- +# 5. 
per-user 并发上限 +# --------------------------------------------------------------------------- + + +class TestConcurrencyLimit: + """per-user 并发上限测试。""" + + async def test_concurrency_limit_exceeded(self): + """超过 per-user 并发上限抛出异常。""" + store = InMemoryTaskStore() + manager = TaskManager( + broker=None, + task_store=store, + max_concurrent_per_user=2, + ) + + # 提交 2 个任务(达到上限) + await manager.submit_vectorize(_make_vectorize_params(document_id="d1")) + await manager.submit_vectorize(_make_vectorize_params(document_id="d2")) + + # 第 3 个应抛出异常 + with pytest.raises(ConcurrencyLimitExceeded): + await manager.submit_vectorize(_make_vectorize_params(document_id="d3")) + + async def test_concurrency_limit_per_user_isolated(self): + """不同用户的并发上限独立计算。""" + store = InMemoryTaskStore() + manager = TaskManager( + broker=None, + task_store=store, + max_concurrent_per_user=1, + ) + + # user1 提交 1 个任务 + await manager.submit_vectorize(_make_vectorize_params(user_id="user1")) + # user2 仍可提交 + task_id = await manager.submit_vectorize( + _make_vectorize_params(document_id="d2", user_id="user2") + ) + assert task_id is not None + + async def test_concurrency_released_after_cancel(self): + """取消任务后释放并发计数。""" + store = InMemoryTaskStore() + manager = TaskManager( + broker=None, + task_store=store, + max_concurrent_per_user=1, + ) + + task_id = await manager.submit_vectorize(_make_vectorize_params()) + # 取消任务 + await manager.cancel_task(task_id) + # 应可再次提交 + task_id2 = await manager.submit_vectorize(_make_vectorize_params(document_id="d2")) + assert task_id2 is not None + + +# --------------------------------------------------------------------------- +# 6. 
# Task cancellation
# ---------------------------------------------------------------------------


class TestTaskCancel:
    """Tests for ``TaskManager.cancel_task``."""

    async def test_cancel_pending_task(self):
        """Cancelling a PENDING task returns ``True`` and marks it CANCELLED."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())
        task_id = await manager.submit_vectorize(_make_vectorize_params())

        cancelled = await manager.cancel_task(task_id)
        record = await manager.get_task_status(task_id)

        assert cancelled is True
        assert record is not None
        assert record.status == TaskStatus.CANCELLED

    async def test_cancel_nonexistent_raises(self):
        """Cancelling an unknown task id raises ``TaskNotFoundError``."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())

        with pytest.raises(TaskNotFoundError):
            await manager.cancel_task("nonexistent")

    async def test_cancel_completed_returns_false(self):
        """Cancelling an already COMPLETED task returns ``False``."""
        task_store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=task_store)
        task_id = await manager.submit_vectorize(_make_vectorize_params())
        # Force the terminal state directly on the store.
        task_store.update_status(task_id, TaskStatus.COMPLETED)

        assert await manager.cancel_task(task_id) is False


# ---------------------------------------------------------------------------
# 7.
# WorkerSweeper — timed-out task detection
# ---------------------------------------------------------------------------


class TestWorkerSweeper:
    """Tests for ``WorkerSweeper.sweep`` heartbeat-timeout handling."""

    async def test_sweep_detects_timed_out_task(self):
        """A RUNNING task with a stale heartbeat is swept and marked FAILED."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        task_id = await manager.submit_vectorize(_make_vectorize_params())

        # Simulate RUNNING state with a heartbeat 600s in the past, i.e. older
        # than the 300s TTL used below. InMemoryTaskStore methods are sync.
        old_time = (datetime.now(timezone.utc) - timedelta(seconds=600)).isoformat()
        store.update_status(task_id, TaskStatus.RUNNING)
        record = store.get(task_id)
        assert record is not None
        record.metadata["worker_heartbeat_at"] = old_time

        sweeper = WorkerSweeper(task_store=store, ttl_seconds=300)
        cleaned = await sweeper.sweep()

        assert cleaned == 1
        record = store.get(task_id)
        assert record is not None
        assert record.status == TaskStatus.FAILED
        assert record.error_message == "worker_timeout"

    async def test_sweep_skips_fresh_heartbeat(self):
        """A RUNNING task whose heartbeat is within the TTL is left alone."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        task_id = await manager.submit_vectorize(_make_vectorize_params())

        # Heartbeat stamped "now" — well inside the 300s TTL.
        fresh_time = datetime.now(timezone.utc).isoformat()
        store.update_status(task_id, TaskStatus.RUNNING)
        record = store.get(task_id)
        assert record is not None
        record.metadata["worker_heartbeat_at"] = fresh_time

        sweeper = WorkerSweeper(task_store=store, ttl_seconds=300)
        cleaned = await sweeper.sweep()

        assert cleaned == 0
        record = store.get(task_id)
        assert record is not None
        assert record.status == TaskStatus.RUNNING

    async def test_sweep_skips_tasks_without_heartbeat(self):
        """A RUNNING task with no recorded heartbeat is not swept."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        task_id = await manager.submit_vectorize(_make_vectorize_params())
        store.update_status(task_id, TaskStatus.RUNNING)
        # Deliberately do not set worker_heartbeat_at.

        sweeper = WorkerSweeper(task_store=store, ttl_seconds=300)
        cleaned = await sweeper.sweep()

        assert cleaned == 0

    async def test_sweep_skips_non_running_tasks(self):
        """Only RUNNING tasks are considered, even if the heartbeat is stale."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        task_id = await manager.submit_vectorize(_make_vectorize_params())
        # The status is not moved to RUNNING here, so the stale heartbeat
        # must not cause the task to be swept.
        old_time = (datetime.now(timezone.utc) - timedelta(seconds=600)).isoformat()
        record = store.get(task_id)
        assert record is not None
        record.metadata["worker_heartbeat_at"] = old_time

        sweeper = WorkerSweeper(task_store=store, ttl_seconds=300)
        cleaned = await sweeper.sweep()

        assert cleaned == 0


# ---------------------------------------------------------------------------
# 8. Degraded mode — execution without a broker
# ---------------------------------------------------------------------------


class TestDegradedMode:
    """Tests for degraded mode, i.e. a ``TaskManager`` built with ``broker=None``."""

    async def test_vectorize_success_transitions_to_completed(self):
        """A successful vectorize run ends in COMPLETED with progress 1.0."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        params = _make_vectorize_params()
        deps = _make_mock_dependencies()

        # Mock run_vectorize_task so only the TaskManager state transitions
        # are exercised, without depending on llama_index.
        with patch(
            "agentkit.rag_platform.tasks.run_vectorize_task",
            new_callable=AsyncMock,
        ):
            task_id = await manager.submit_vectorize(params, dependencies=deps)

            # Give the degraded-mode task time to finish. The wait is kept
            # inside the patch context so the mock is still installed when
            # the task body runs.
            await asyncio.sleep(0.1)

        record = await manager.get_task_status(task_id)
        assert record is not None
        assert record.status == TaskStatus.COMPLETED
        assert record.progress == 1.0

    async def test_vectorize_failure_sanitizes_error(self):
        """A failing vectorize run ends in FAILED with a sanitized message."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        params = _make_vectorize_params()
        deps = _make_mock_dependencies()

        # Mock run_vectorize_task to raise an error whose text embeds a
        # filesystem path.
        with patch(
            "agentkit.rag_platform.tasks.run_vectorize_task",
            new_callable=AsyncMock,
            side_effect=FileNotFoundError(
                "[Errno 2] No such file or directory: /usr/local/data/missing.txt"
            ),
        ):
            task_id = await manager.submit_vectorize(params, dependencies=deps)

            # Give the degraded-mode task time to finish while the mock is
            # still installed.
            await asyncio.sleep(0.1)

        record = await manager.get_task_status(task_id)
        assert record is not None
        assert record.status == TaskStatus.FAILED

        message = record.error_message or ""
        # The internal path must not leak into the stored message.
        assert "/usr/local" not in message
        # The previous check here was `"" in message`, which is always true
        # (the empty string is a substring of every string). Require at least
        # that an error message was recorded.
        # NOTE(review): the exact placeholder the sanitizer substitutes for
        # paths is not visible here — assert on that token once confirmed.
        assert message

    async def test_degraded_mode_does_not_block_caller(self):
        """``submit_vectorize`` returns before the background work completes."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        params = _make_vectorize_params()
        deps = _make_mock_dependencies()

        # Stand-in for a slow vectorize run.
        async def _slow_task(*args, **kwargs):
            await asyncio.sleep(0.5)

        with patch(
            "agentkit.rag_platform.tasks.run_vectorize_task",
            new_callable=AsyncMock,
            side_effect=_slow_task,
        ):
            # submit_vectorize should return right away rather than wait for
            # the slow task.
            task_id = await manager.submit_vectorize(params, dependencies=deps)

            # Query immediately: the task is either not started yet (PENDING)
            # or started but not finished (RUNNING).
            record = await manager.get_task_status(task_id)
            assert record is not None
            assert record.status in (TaskStatus.PENDING, TaskStatus.RUNNING)

            # Let the slow task finish before the patch is removed.
            await asyncio.sleep(0.6)


# ---------------------------------------------------------------------------
# 9.
# Task execution function tests
# ---------------------------------------------------------------------------


class TestRunVectorizeTask:
    """Tests for the ``run_vectorize_task`` function."""

    async def test_run_vectorize_calls_document_processor(self):
        """``run_vectorize_task`` awaits ``DocumentProcessor.process`` once."""
        task_params = _make_vectorize_params()
        dependencies = _make_mock_dependencies()

        # Patch DocumentProcessor.process so llama_index is not needed.
        processor_patch = patch(
            "agentkit.rag_platform.tasks.DocumentProcessor.process",
            new_callable=AsyncMock,
        )
        with processor_patch as mock_process:
            await run_vectorize_task(task_params, **dependencies)

        mock_process.assert_awaited_once()

    async def test_run_vectorize_failure_propagates(self):
        """An error raised by the processor propagates to the caller."""
        task_params = _make_vectorize_params()
        dependencies = _make_mock_dependencies()

        failing_processor = patch(
            "agentkit.rag_platform.tasks.DocumentProcessor.process",
            new_callable=AsyncMock,
            side_effect=RuntimeError("processing failed"),
        )
        with failing_processor, pytest.raises(RuntimeError, match="processing failed"):
            await run_vectorize_task(task_params, **dependencies)


class TestRunBatchIndexTask:
    """Tests for the ``run_batch_index_task`` function."""

    async def test_batch_index_returns_results(self):
        """The result carries ``succeeded`` and ``failed`` lists."""
        batch_params = BatchIndexTaskParams(
            kb_id="kb-1",
            document_ids=["doc-1", "doc-2"],
            user_id="user1",
            file_paths={"doc-1": "/tmp/a.txt", "doc-2": "/tmp/b.txt"},
            file_types={"doc-1": "txt", "doc-2": "txt"},
        )
        dependencies = _make_mock_dependencies()

        # Patch DocumentProcessor.process so every document succeeds.
        with patch(
            "agentkit.rag_platform.tasks.DocumentProcessor.process",
            new_callable=AsyncMock,
        ):
            outcome = await run_batch_index_task(batch_params, **dependencies)

        for key in ("succeeded", "failed"):
            assert key in outcome
        assert len(outcome["succeeded"]) == 2
        assert len(outcome["failed"]) == 0

    async def test_batch_index_partial_failure(self):
        """A document whose processing fails is reported under ``failed``."""
        batch_params = BatchIndexTaskParams(
            kb_id="kb-1",
            document_ids=["doc-1", "doc-missing"],
            user_id="user1",
            file_paths={"doc-1": "/tmp/a.txt", "doc-missing": "/nonexistent/file.txt"},
            file_types={"doc-1": "txt", "doc-missing": "txt"},
        )
        dependencies = _make_mock_dependencies()

        # Patched process(): the first call succeeds, the second call raises.
        calls_seen = 0

        async def _fail_on_second_call(*args, **kwargs):
            nonlocal calls_seen
            calls_seen += 1
            if calls_seen == 2:
                raise FileNotFoundError("missing file")

        with patch(
            "agentkit.rag_platform.tasks.DocumentProcessor.process",
            new_callable=AsyncMock,
            side_effect=_fail_on_second_call,
        ):
            outcome = await run_batch_index_task(batch_params, **dependencies)

        assert "doc-1" in outcome["succeeded"]
        assert "doc-missing" in outcome["failed"]


# ---------------------------------------------------------------------------
# 10. Heartbeat update tests
# ---------------------------------------------------------------------------


class TestHeartbeat:
    """Tests for the worker heartbeat recorded by ``TaskManager``."""

    async def test_update_heartbeat_sets_timestamp(self):
        """``update_heartbeat`` stores a ``worker_heartbeat_at`` timestamp."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())
        task_id = await manager.submit_vectorize(_make_vectorize_params())

        await manager.update_heartbeat(task_id)

        record = await manager.get_task_status(task_id)
        assert record is not None
        assert "worker_heartbeat_at" in record.metadata
        # The stored value must parse as ISO 8601; fromisoformat raises otherwise.
        datetime.fromisoformat(record.metadata["worker_heartbeat_at"])

    async def test_update_heartbeat_nonexistent_task_noop(self):
        """``update_heartbeat`` on an unknown task id does not raise."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())

        # Completing without an exception is the whole assertion.
        await manager.update_heartbeat("nonexistent")