feat(rag_platform): U8 — TaskIQ async task integration

Add tasks.py: TaskManager with vectorize/batch_index tasks, per-user concurrency limits, degraded mode (sync execution without broker), WorkerSweeper for timeout detection, error message sanitization
Add taskiq>=0.11 and taskiq-redis>=0.5 to pyproject.toml
Task parameter schema validation (VectorizeTaskParams, BatchIndexTaskParams)

Tests: 41 new tests, 289 total passing
This commit is contained in:
chiguyong 2026-06-25 12:58:51 +08:00
parent d026a91f43
commit e3ae2f3a56
3 changed files with 1429 additions and 0 deletions

View File

@ -43,6 +43,9 @@ dependencies = [
"llama-index-embeddings-openai>=0.3",
"pgvector>=0.3",
"jieba>=0.42",
# 异步任务队列（U8 — 文档向量化）
"taskiq>=0.11",
"taskiq-redis>=0.5",
]
[project.scripts]

View File

@ -0,0 +1,705 @@
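"""U8 — TaskIQ async task integration: document vectorization and batch indexing.

Wraps a TaskIQ broker and provides:
- task parameter schema validation (Pydantic models)
- a per-user concurrency limit
- a task state machine: PENDING → RUNNING → COMPLETED | FAILED | CANCELLED
- worker liveness and a sweeper (``worker_heartbeat_at`` timeout detection)
- error-message sanitization (strips internal paths / connection strings)
- a degraded mode: tasks run in-process when no broker is configured

Redis isolation: a dedicated db=1, separate from the bus/cache db=0, with the
key prefix ``taskiq:``. Task parameters carry a TTL in Redis (1 hour by default).

ponytail: the TaskIQ broker is an optional dependency — it is imported lazily
inside ``create_broker``. When it is not installed the module still loads and
TaskManager falls back to in-process execution.
"""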
from __future__ import annotations
import asyncio
import inspect
import logging
import re
import uuid
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Protocol
from pydantic import BaseModel, ConfigDict, Field, field_validator
from agentkit.core.protocol import TaskStatus
from agentkit.rag_platform.document_processor import (
DEFAULT_CHUNK_OVERLAP,
DEFAULT_CHUNK_SIZE,
DocumentProcessor,
)
from agentkit.server.task_store import InMemoryTaskStore, TaskRecord
if TYPE_CHECKING:
from llama_index.core.embeddings import BaseEmbedding
from llama_index.vector_stores.postgres import PGVectorStore
from agentkit.rag_platform.store import KBStore
logger = logging.getLogger(__name__)
# 默认配置
DEFAULT_MAX_CONCURRENT_PER_USER = 3
DEFAULT_WORKER_TTL_SECONDS = 300 # worker 心跳超时阈值
DEFAULT_ERROR_MESSAGE_MAX_LENGTH = 500
DEFAULT_TASK_TTL_SECONDS = 3600 # 任务参数 TTL
# Redis 隔离 — 独立 db=1与 bus/cache 的 db=0 分离
TASKIQ_REDIS_DB = 1
TASKIQ_KEY_PREFIX = "taskiq:"
async def _maybe_await(result: Any) -> Any:
"""统一处理 sync/async 调用结果 — InMemoryTaskStore 方法为 syncRedisTaskStore 为 async。"""
if inspect.isawaitable(result):
return await result
return result
# ---------------------------------------------------------------------------
# 任务参数模型 — schema 校验
# ---------------------------------------------------------------------------
class VectorizeTaskParams(BaseModel):
    """Parameters of a single-document task: parse → segment → vectorize → index."""

    model_config = ConfigDict()

    document_id: str = Field(min_length=1, max_length=128)
    kb_id: str = Field(min_length=1, max_length=128)
    file_path: str = Field(min_length=1, max_length=1024)
    file_type: str = Field(min_length=1, max_length=64)
    chunk_size: int = Field(default=DEFAULT_CHUNK_SIZE, ge=50, le=8192)
    chunk_overlap: int = Field(default=DEFAULT_CHUNK_OVERLAP, ge=0, le=1024)
    user_id: str = Field(min_length=1, max_length=128)

    @field_validator("chunk_overlap")
    @classmethod
    def _overlap_less_than_size(cls, v: int, info: Any) -> int:
        """Reject an overlap that is not strictly smaller than the chunk size."""
        # chunk_size is declared before chunk_overlap, so it is already in
        # info.data unless its own validation failed; fall back to the default then.
        limit = info.data.get("chunk_size", DEFAULT_CHUNK_SIZE)
        if v < limit:
            return v
        raise ValueError(f"chunk_overlap ({v}) must be < chunk_size ({limit})")
class BatchIndexTaskParams(BaseModel):
    """Parameters of a batch-index task covering several documents.

    ``file_paths`` and ``file_types`` map a document id to its file path and
    file type. ``chunk_size`` / ``chunk_overlap`` are applied to every
    document in the batch.
    """

    model_config = ConfigDict()

    kb_id: str = Field(min_length=1, max_length=128)
    document_ids: list[str] = Field(min_length=1, max_length=100)
    user_id: str = Field(min_length=1, max_length=128)
    file_paths: dict[str, str] = Field(default_factory=dict)  # document_id -> file_path
    file_types: dict[str, str] = Field(default_factory=dict)  # document_id -> file_type
    chunk_size: int = Field(default=DEFAULT_CHUNK_SIZE, ge=50, le=8192)
    chunk_overlap: int = Field(default=DEFAULT_CHUNK_OVERLAP, ge=0, le=1024)

    @field_validator("chunk_overlap")
    @classmethod
    def _overlap_less_than_size(cls, v: int, info: Any) -> int:
        """Reject an overlap that is not strictly smaller than the chunk size.

        Mirrors the per-document model: without this check an invalid
        combination is accepted at submission time and only surfaces later,
        as a validation failure for every single document of the batch.
        """
        size = info.data.get("chunk_size", DEFAULT_CHUNK_SIZE)
        if v >= size:
            raise ValueError(f"chunk_overlap ({v}) must be < chunk_size ({size})")
        return v
# ---------------------------------------------------------------------------
# 错误消息净化
# ---------------------------------------------------------------------------
# Matches Unix and Windows file-system paths.
_PATH_RE = re.compile(
    r"(?:/[\w.\-]+)+|"  # Unix path: /usr/local/...
    r"(?:[A-Za-z]:\\[\w.\-]+(?:\\[\w.\-]+)*)"  # Windows path: C:\Users\...
)
# Matches any "<scheme>://..." URL. A fixed scheme list misses driver-qualified
# and TLS variants (postgresql+asyncpg://, rediss://, amqp://, mongodb://, ...),
# whose user:password@host part would then survive sanitization.
_CONN_STR_RE = re.compile(
    r"[a-z][a-z0-9+.\-]*://[^\s\"']+",
    re.IGNORECASE,
)
# Matches "name=value" style secrets (api_key, token, password, secret).
_SECRET_RE = re.compile(
    r"(api[_-]?key|token|password|secret)[\s:=]+[\w\-]+",
    re.IGNORECASE,
)


def sanitize_error_message(msg: str, max_length: int = DEFAULT_ERROR_MESSAGE_MAX_LENGTH) -> str:
    """Strip internal paths, connection strings and secrets from an error message.

    Args:
        msg: Raw error message.
        max_length: Maximum length of the returned string. Longer messages
            are cut and end with ``...``; when ``max_length`` is too small to
            hold the ellipsis the message is simply cut.

    Returns:
        The sanitized message, never longer than ``max_length`` characters
        (an empty string for empty input).
    """
    if not msg:
        return ""
    # Connection strings go first so that the path pattern cannot consume
    # parts of them (e.g. the "/0" of redis://localhost:6379/0).
    cleaned = _CONN_STR_RE.sub("<connection-string>", msg)
    cleaned = _PATH_RE.sub("<path>", cleaned)
    cleaned = _SECRET_RE.sub(r"\1=<redacted>", cleaned)
    if len(cleaned) > max_length:
        if max_length <= 3:
            # No room for the ellipsis; a negative slice bound would
            # otherwise produce a result longer than max_length.
            cleaned = cleaned[: max(max_length, 0)]
        else:
            cleaned = cleaned[: max_length - 3] + "..."
    return cleaned
# ---------------------------------------------------------------------------
# Worker liveness & sweeper
# ---------------------------------------------------------------------------
class TaskStoreProtocol(Protocol):
    """Structural interface of a task store (InMemoryTaskStore / RedisTaskStore).

    Implementations may be synchronous or asynchronous; callers in this
    module route every call through ``_maybe_await``.
    """

    async def create(
        self,
        task_id: str,
        agent_name: str,
        input_data: dict,
        skill_name: str | None = None,
    ) -> TaskRecord:
        """Create and return a new task record."""

    def get(self, task_id: str) -> TaskRecord | None:
        """Return the record for ``task_id``, or None when it does not exist."""

    async def update_status(
        self,
        task_id: str,
        status: TaskStatus,
        **kwargs: Any,
    ) -> TaskRecord:
        """Set the status of a task, plus any extra record fields in ``kwargs``."""

    def list_tasks(
        self,
        status: TaskStatus | None = None,
        limit: int = 100,
    ) -> list[TaskRecord]:
        """Return up to ``limit`` records, optionally filtered by status."""
class WorkerSweeper:
    """Detect RUNNING tasks whose worker went silent and mark them failed.

    A task counts as timed out when its ``worker_heartbeat_at`` metadata
    (or, lacking that, its ``started_at``) is older than the TTL. Such tasks
    are moved to FAILED with ``error_message="worker_timeout"``.

    ponytail: upgrade path — this relies on a ``list_tasks`` scan, O(n).
    Production could switch to a Redis ZSET ordered by heartbeat timestamp,
    O(log n).
    """

    def __init__(
        self,
        task_store: TaskStoreProtocol,
        ttl_seconds: int = DEFAULT_WORKER_TTL_SECONDS,
    ) -> None:
        self._store = task_store
        self._ttl = ttl_seconds

    async def sweep(self) -> int:
        """Fail every timed-out RUNNING task and return how many were swept."""
        now = datetime.now(timezone.utc)
        cutoff = now - timedelta(seconds=self._ttl)
        swept = 0
        candidates = await _maybe_await(
            self._store.list_tasks(status=TaskStatus.RUNNING, limit=1000)
        )
        for record in candidates:
            heartbeat = record.metadata.get("worker_heartbeat_at")
            if heartbeat is None and record.started_at:
                # No heartbeat recorded yet: fall back to the start time.
                heartbeat = record.started_at.isoformat()
            if heartbeat is None:
                continue

            try:
                last_seen = datetime.fromisoformat(heartbeat)
            except (ValueError, TypeError):
                # Unparseable heartbeat value: leave the task alone.
                continue
            if last_seen.tzinfo is None:
                # Naive timestamps are treated as UTC so the comparison is valid.
                last_seen = last_seen.replace(tzinfo=timezone.utc)
            if last_seen >= cutoff:
                continue

            try:
                await _maybe_await(
                    self._store.update_status(
                        record.task_id,
                        TaskStatus.FAILED,
                        error_message="worker_timeout",
                        completed_at=now,
                        metadata={**record.metadata, "swept_at": now.isoformat()},
                    )
                )
            except KeyError:
                # The task was deleted concurrently.
                continue
            swept += 1
            logger.warning(
                "Swept timed-out task %s (heartbeat=%s, cutoff=%s)",
                record.task_id,
                heartbeat,
                cutoff.isoformat(),
            )
        return swept
# ---------------------------------------------------------------------------
# TaskManager — 任务提交、查询、取消
# ---------------------------------------------------------------------------
class ConcurrencyLimitExceeded(Exception):
    """Raised when a user already holds the maximum number of task slots."""
class TaskNotFoundError(Exception):
    """Raised when the requested task id is unknown to the task store."""
class TaskManager:
    """Submit, query and cancel vectorization tasks on top of a TaskIQ broker.

    When ``broker`` is None the manager runs in degraded mode: every task is
    executed in-process as a background asyncio task instead of being
    dispatched to a worker.

    Per-user concurrency is tracked in process memory. Each accepted task
    holds one slot, and a slot is released exactly once: when the degraded
    runner finishes, when the task is cancelled, or when submission fails.

    NOTE(review): in broker mode nothing in this class observes task
    completion, so a slot is only released by ``cancel_task`` or a failed
    submission. Confirm that worker-side completion frees slots elsewhere.

    Args:
        broker: TaskIQ broker instance (optional).
        task_store: Task state store; defaults to an InMemoryTaskStore.
        max_concurrent_per_user: Maximum number of slots one user may hold.
        task_ttl_seconds: TTL of task parameters (Redis mode).
    """

    def __init__(
        self,
        broker: Any = None,
        task_store: TaskStoreProtocol | None = None,
        max_concurrent_per_user: int = DEFAULT_MAX_CONCURRENT_PER_USER,
        task_ttl_seconds: int = DEFAULT_TASK_TTL_SECONDS,
    ) -> None:
        self._broker = broker
        # Compare against None explicitly: a store that is falsy (for example
        # one defining __len__ and currently empty) must not be replaced.
        self._store: TaskStoreProtocol = (
            task_store if task_store is not None else InMemoryTaskStore()
        )
        self._max_concurrent = max_concurrent_per_user
        self._task_ttl = task_ttl_seconds
        # user_id -> number of slots currently held
        self._running_counts: dict[str, int] = {}
        # task_id -> user_id for every task that still holds a slot
        self._slot_owners: dict[str, str] = {}
        # task_id -> degraded-mode asyncio task. The event loop keeps only weak
        # references to tasks, so a strong reference is held here until done.
        self._background_tasks: dict[str, asyncio.Task[None]] = {}
        self._lock = asyncio.Lock()

    @property
    def store(self) -> TaskStoreProtocol:
        """The underlying task store (used by the sweeper)."""
        return self._store

    @property
    def broker(self) -> Any:
        """The underlying broker (used for startup/shutdown integration)."""
        return self._broker

    # ── Submission ──────────────────────────────────────────────

    async def submit_vectorize(
        self,
        params: VectorizeTaskParams,
        dependencies: dict[str, Any] | None = None,
    ) -> str:
        """Submit a vectorization task and return its task id.

        Args:
            params: Validated vectorization parameters.
            dependencies: Runtime dependencies (store / vector_store /
                embed_model). Used directly in degraded mode; in broker mode
                the worker is expected to provide them.

        Raises:
            ConcurrencyLimitExceeded: The user already holds the maximum
                number of slots.
            Exception: Creating the record or dispatching to the broker
                failed; the slot is released before the error propagates.
        """
        await self._check_concurrency(params.user_id)
        task_id = f"vec-{uuid.uuid4()}"
        self._slot_owners[task_id] = params.user_id
        try:
            input_data = params.model_dump()
            await self._create_record(
                task_id, "rag_vectorize", "vectorize", input_data, params.user_id
            )
            logger.info(
                "Submitted vectorize task %s for document=%s user=%s",
                task_id,
                params.document_id,
                params.user_id,
            )
            if self._broker is not None:
                # Broker mode: hand the task over to a worker.
                await self._dispatch_to_broker("vectorize", input_data, task_id)
        except Exception as exc:
            await self._abort_submission(task_id, exc)
            raise
        if self._broker is None:
            # Degraded mode: run in-process without blocking the caller.
            # Dependency objects are never serialized; they stay in memory.
            self._spawn_background(
                task_id,
                self._run_vectorize_degraded(task_id, params, dependencies or {}),
            )
        return task_id

    async def submit_batch_index(
        self,
        params: BatchIndexTaskParams,
        dependencies: dict[str, Any] | None = None,
    ) -> str:
        """Submit a batch-index task and return its task id.

        Args:
            params: Validated batch parameters.
            dependencies: Runtime dependencies, used in degraded mode.

        Raises:
            ConcurrencyLimitExceeded: The user already holds the maximum
                number of slots.
            Exception: Creating the record or dispatching to the broker
                failed; the slot is released before the error propagates.
        """
        await self._check_concurrency(params.user_id)
        task_id = f"batch-{uuid.uuid4()}"
        self._slot_owners[task_id] = params.user_id
        try:
            input_data = params.model_dump()
            await self._create_record(
                task_id, "rag_batch_index", "batch_index", input_data, params.user_id
            )
            logger.info(
                "Submitted batch index task %s for %d documents user=%s",
                task_id,
                len(params.document_ids),
                params.user_id,
            )
            if self._broker is not None:
                await self._dispatch_to_broker("batch_index", input_data, task_id)
        except Exception as exc:
            await self._abort_submission(task_id, exc)
            raise
        if self._broker is None:
            self._spawn_background(
                task_id,
                self._run_batch_index_degraded(task_id, params, dependencies or {}),
            )
        return task_id

    # ── Queries ─────────────────────────────────────────────────

    async def get_task_status(self, task_id: str) -> TaskRecord | None:
        """Return the task record, or None when the task does not exist."""
        return await _maybe_await(self._store.get(task_id))

    async def list_tasks(
        self,
        user_id: str | None = None,
        status: TaskStatus | None = None,
        limit: int = 100,
    ) -> list[TaskRecord]:
        """List task records, optionally filtered by user and status.

        The user filter is applied after the store has applied ``limit``, so
        fewer than ``limit`` records may be returned for a given user.
        """
        tasks = await _maybe_await(self._store.list_tasks(status=status, limit=limit))
        if user_id is not None:
            tasks = [t for t in tasks if t.metadata.get("user_id") == user_id]
        return tasks

    async def cancel_task(self, task_id: str) -> bool:
        """Cancel a task; only PENDING/RUNNING tasks can be cancelled.

        In degraded mode the in-process runner is cancelled as well, so it
        can neither overwrite the CANCELLED status nor release the slot a
        second time.

        Returns:
            True when the task was cancelled, False when it had already
            reached COMPLETED, FAILED or CANCELLED.

        Raises:
            TaskNotFoundError: The task does not exist.
        """
        record = await _maybe_await(self._store.get(task_id))
        if record is None:
            raise TaskNotFoundError(f"Task {task_id} not found")
        if record.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
            return False
        # Stop the runner first so it cannot write a status after this point.
        background = self._background_tasks.get(task_id)
        if background is not None:
            background.cancel()
        try:
            await _maybe_await(
                self._store.update_status(
                    task_id,
                    TaskStatus.CANCELLED,
                    completed_at=datetime.now(timezone.utc),
                )
            )
        finally:
            # A runner cancelled before its first step never reaches its own
            # ``finally``; release here. The release is idempotent per task.
            await self._release_slot(task_id)
        return True

    # ── Heartbeat ───────────────────────────────────────────────

    async def update_heartbeat(self, task_id: str) -> None:
        """Record the current UTC time as the task's worker heartbeat.

        NOTE(review): this mutates the record returned by ``get`` in place.
        That persists only if the store returns its live record object —
        confirm for stores other than InMemoryTaskStore.
        """
        record = await _maybe_await(self._store.get(task_id))
        if record is None:
            return
        record.metadata["worker_heartbeat_at"] = datetime.now(timezone.utc).isoformat()

    # ── Internals ───────────────────────────────────────────────

    async def _check_concurrency(self, user_id: str) -> None:
        """Reserve one slot for ``user_id`` or raise ConcurrencyLimitExceeded."""
        async with self._lock:
            current = self._running_counts.get(user_id, 0)
            if current >= self._max_concurrent:
                raise ConcurrencyLimitExceeded(
                    f"User {user_id} has {current} running tasks (max={self._max_concurrent})"
                )
            self._running_counts[user_id] = current + 1

    async def _release_concurrency(self, user_id: str) -> None:
        """Give back one slot of ``user_id`` (never drops below zero)."""
        async with self._lock:
            remaining = self._running_counts.get(user_id, 0) - 1
            if remaining > 0:
                self._running_counts[user_id] = remaining
            else:
                # Drop idle users so the dict does not grow without bound.
                self._running_counts.pop(user_id, None)

    async def _release_slot(self, task_id: str) -> None:
        """Release the slot held by ``task_id``; a no-op if already released."""
        user_id = self._slot_owners.pop(task_id, None)
        if user_id is not None:
            await self._release_concurrency(user_id)

    async def _create_record(
        self,
        task_id: str,
        agent_name: str,
        skill_name: str,
        input_data: dict,
        user_id: str,
    ) -> None:
        """Create the task record and tag it with the submitting user."""
        await _maybe_await(
            self._store.create(
                task_id=task_id,
                agent_name=agent_name,
                input_data=input_data,
                skill_name=skill_name,
            )
        )
        # The user id in metadata is what list_tasks filters on.
        record = await _maybe_await(self._store.get(task_id))
        if record is not None:
            record.metadata["user_id"] = user_id

    async def _abort_submission(self, task_id: str, exc: Exception) -> None:
        """Undo a failed submission: free the slot and fail the record, if any."""
        await self._release_slot(task_id)
        try:
            record = await _maybe_await(self._store.get(task_id))
            if record is not None:
                await _maybe_await(
                    self._store.update_status(
                        task_id,
                        TaskStatus.FAILED,
                        error_message=sanitize_error_message(str(exc)),
                        completed_at=datetime.now(timezone.utc),
                    )
                )
        except Exception:
            # Best effort only: the original submission error is what the
            # caller must see, so a bookkeeping failure is logged, not raised.
            logger.exception("Could not mark task %s as failed after submission error", task_id)

    def _spawn_background(self, task_id: str, coro: Any) -> None:
        """Start ``coro`` as an asyncio task and keep it referenced until done."""
        task = asyncio.create_task(coro)
        self._background_tasks[task_id] = task

        def _on_done(done: asyncio.Task[None]) -> None:
            self._background_tasks.pop(task_id, None)
            if not done.cancelled() and done.exception() is not None:
                logger.error("Background task %s crashed: %r", task_id, done.exception())

        task.add_done_callback(_on_done)

    async def _dispatch_to_broker(self, task_type: str, params: dict, task_id: str) -> None:
        """Dispatch a task to the broker; execution happens in a worker process.

        NOTE(review): TaskIQ exposes ``kiq`` on decorated tasks, not on
        ``AsyncBroker`` — confirm the configured broker object provides it.
        """
        try:
            await self._broker.kiq(task_type=task_type, params=params, task_id=task_id)
        except Exception as e:
            logger.error("Failed to dispatch task %s to broker: %s", task_id, e)
            raise

    async def _run_vectorize_degraded(
        self,
        task_id: str,
        params: VectorizeTaskParams,
        deps: dict[str, Any],
    ) -> None:
        """Degraded mode: run one vectorization task inside this process."""
        now = datetime.now(timezone.utc)
        try:
            await _maybe_await(
                self._store.update_status(task_id, TaskStatus.RUNNING, started_at=now)
            )
            await self.update_heartbeat(task_id)
            await run_vectorize_task(params, **deps)
            await _maybe_await(
                self._store.update_status(
                    task_id,
                    TaskStatus.COMPLETED,
                    completed_at=datetime.now(timezone.utc),
                    progress=1.0,
                )
            )
        except Exception as e:
            # asyncio.CancelledError is a BaseException and passes through,
            # leaving the CANCELLED status written by cancel_task intact.
            sanitized = sanitize_error_message(str(e))
            await _maybe_await(
                self._store.update_status(
                    task_id,
                    TaskStatus.FAILED,
                    error_message=sanitized,
                    completed_at=datetime.now(timezone.utc),
                )
            )
            logger.error("Vectorize task %s failed: %s", task_id, sanitized)
        finally:
            await self._release_slot(task_id)

    async def _run_batch_index_degraded(
        self,
        task_id: str,
        params: BatchIndexTaskParams,
        deps: dict[str, Any],
    ) -> None:
        """Degraded mode: process the batch document by document in this process.

        A failing document does not abort the batch; it is recorded in the
        ``failed`` list of the task's output data.
        """
        now = datetime.now(timezone.utc)
        total = len(params.document_ids)
        succeeded: list[str] = []
        failed: list[str] = []
        try:
            await _maybe_await(
                self._store.update_status(task_id, TaskStatus.RUNNING, started_at=now)
            )
            await self.update_heartbeat(task_id)
            for i, doc_id in enumerate(params.document_ids):
                await self.update_heartbeat(task_id)
                try:
                    single_params = VectorizeTaskParams(
                        document_id=doc_id,
                        kb_id=params.kb_id,
                        file_path=params.file_paths.get(doc_id, ""),
                        file_type=params.file_types.get(doc_id, "txt"),
                        chunk_size=params.chunk_size,
                        chunk_overlap=params.chunk_overlap,
                        user_id=params.user_id,
                    )
                    await run_vectorize_task(single_params, **deps)
                    succeeded.append(doc_id)
                except Exception as e:
                    logger.warning("Batch item %s failed: %s", doc_id, e)
                    failed.append(doc_id)
                # Report progress after every document.
                progress = (i + 1) / total
                await _maybe_await(
                    self._store.update_status(
                        task_id,
                        TaskStatus.RUNNING,
                        progress=progress,
                        progress_message=f"Processed {i + 1}/{total}",
                    )
                )
            final_status = TaskStatus.COMPLETED if not failed else TaskStatus.PARTIALLY_COMPLETED
            await _maybe_await(
                self._store.update_status(
                    task_id,
                    final_status,
                    completed_at=datetime.now(timezone.utc),
                    progress=1.0,
                    output_data={
                        "succeeded": succeeded,
                        "failed": failed,
                        "total": total,
                    },
                )
            )
        except Exception as e:
            sanitized = sanitize_error_message(str(e))
            await _maybe_await(
                self._store.update_status(
                    task_id,
                    TaskStatus.FAILED,
                    error_message=sanitized,
                    completed_at=datetime.now(timezone.utc),
                )
            )
            logger.error("Batch index task %s failed: %s", task_id, sanitized)
        finally:
            await self._release_slot(task_id)
# ---------------------------------------------------------------------------
# 任务执行函数 — 由 worker 调用
# ---------------------------------------------------------------------------
async def run_vectorize_task(
    params: VectorizeTaskParams,
    store: "KBStore",
    vector_store: "PGVectorStore",
    embed_model: "BaseEmbedding",
) -> None:
    """Run the vectorization pipeline for one document (executed by a worker).

    Builds a DocumentProcessor with the requested chunking settings and
    delegates the whole pipeline to ``DocumentProcessor.process``.

    Args:
        params: Vectorization task parameters.
        store: KBStore instance, passed through to the processor.
        vector_store: LlamaIndex PGVectorStore instance.
        embed_model: LlamaIndex embedding model.

    Raises:
        Exception: Whatever ``DocumentProcessor.process`` raises.
    """
    pipeline = DocumentProcessor(
        chunk_size=params.chunk_size,
        chunk_overlap=params.chunk_overlap,
    )
    positional_args = (
        params.file_path,
        params.file_type,
        params.kb_id,
        params.document_id,
        vector_store,
        embed_model,
        store,
    )
    await pipeline.process(*positional_args)
async def run_batch_index_task(
    params: BatchIndexTaskParams,
    store: "KBStore",
    vector_store: "PGVectorStore",
    embed_model: "BaseEmbedding",
) -> dict[str, list[str]]:
    """Vectorize the documents of a batch one after another.

    A document that fails (including one whose per-document parameters do
    not validate) is logged and recorded; the remaining documents are still
    processed.

    Returns:
        {"succeeded": [...], "failed": [...]} — document ids in processing order.
    """
    outcome: dict[str, list[str]] = {"succeeded": [], "failed": []}
    for doc_id in params.document_ids:
        try:
            item = VectorizeTaskParams(
                document_id=doc_id,
                kb_id=params.kb_id,
                file_path=params.file_paths.get(doc_id, ""),
                file_type=params.file_types.get(doc_id, "txt"),
                chunk_size=params.chunk_size,
                chunk_overlap=params.chunk_overlap,
                user_id=params.user_id,
            )
            await run_vectorize_task(item, store, vector_store, embed_model)
        except Exception as e:
            logger.warning("Batch item %s failed: %s", doc_id, e)
            outcome["failed"].append(doc_id)
        else:
            outcome["succeeded"].append(doc_id)
    return outcome
# ---------------------------------------------------------------------------
# TaskIQ broker 工厂(可选 — 仅 Redis 可用时使用)
# ---------------------------------------------------------------------------
def create_broker(redis_url: str) -> Any:
    """Create a TaskIQ Redis broker isolated on its own Redis database.

    Args:
        redis_url: Redis connection URL. Whatever database it selects
            (db=0 is used by bus/cache) is replaced by ``TASKIQ_REDIS_DB``.

    Returns:
        A ``ListQueueBroker`` with a ``RedisAsyncResultBackend`` attached;
        results expire after ``DEFAULT_TASK_TTL_SECONDS``.

    Raises:
        ImportError: taskiq-redis is not installed.
    """
    from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse

    # ListQueueBroker is itself the broker object; taskiq exports no
    # "TaskiqBroker" wrapper class to put around it.
    from taskiq_redis import ListQueueBroker, RedisAsyncResultBackend

    parsed = urlparse(redis_url)
    # A "db" query parameter would take precedence over the path in redis-py
    # URL parsing, so it is removed to make the isolation effective.
    query = urlencode(
        [
            (key, value)
            for key, value in parse_qsl(parsed.query, keep_blank_values=True)
            if key != "db"
        ]
    )
    isolated_url = urlunparse(parsed._replace(path=f"/{TASKIQ_REDIS_DB}", query=query))

    result_backend = RedisAsyncResultBackend(
        redis_url=isolated_url,
        result_ex_time=DEFAULT_TASK_TTL_SECONDS,
    )
    broker = ListQueueBroker(url=isolated_url).with_result_backend(result_backend)
    logger.info("TaskIQ broker created with Redis db=%d isolation", TASKIQ_REDIS_DB)
    return broker
__all__ = [
"BatchIndexTaskParams",
"ConcurrencyLimitExceeded",
"DEFAULT_MAX_CONCURRENT_PER_USER",
"DEFAULT_WORKER_TTL_SECONDS",
"TaskManager",
"TaskNotFoundError",
"TaskStoreProtocol",
"VectorizeTaskParams",
"WorkerSweeper",
"create_broker",
"run_batch_index_task",
"run_vectorize_task",
"sanitize_error_message",
]

View File

@ -0,0 +1,721 @@
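"""U8 tests — TaskIQ async task integration (document vectorization).

Scenarios covered:
1. Task parameter schema validation (Pydantic models)
2. Error-message sanitization (internal paths / connection strings / secrets)
3. Task submission creates a task record (PENDING)
4. Task status queries (pending → running → completed)
5. Task history can be queried and filtered by user_id
6. The per-user concurrency limit is enforced
7. Task cancellation
8. The worker sweeper detects timed-out tasks
9. Degraded mode (no broker): in-process execution
10. Error messages are sanitized after a task failure
"""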
from __future__ import annotations
import asyncio
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agentkit.core.protocol import TaskStatus
from agentkit.rag_platform.tasks import (
BatchIndexTaskParams,
ConcurrencyLimitExceeded,
TaskManager,
TaskNotFoundError,
VectorizeTaskParams,
WorkerSweeper,
run_batch_index_task,
run_vectorize_task,
sanitize_error_message,
)
from agentkit.server.task_store import InMemoryTaskStore
# ---------------------------------------------------------------------------
# 测试 fixtures
# ---------------------------------------------------------------------------
def _make_vectorize_params(**overrides) -> VectorizeTaskParams:
    """Build VectorizeTaskParams from test defaults, with keyword overrides."""
    fields = {
        "document_id": "doc-001",
        "kb_id": "kb-001",
        "file_path": "/tmp/test.txt",
        "file_type": "txt",
        "user_id": "user1",
        **overrides,
    }
    return VectorizeTaskParams(**fields)
def _make_batch_params(**overrides) -> BatchIndexTaskParams:
    """Build BatchIndexTaskParams from test defaults, with keyword overrides."""
    fields = {
        "kb_id": "kb-001",
        "document_ids": ["doc-1", "doc-2"],
        "user_id": "user1",
        "file_paths": {"doc-1": "/tmp/a.txt", "doc-2": "/tmp/b.txt"},
        "file_types": {"doc-1": "txt", "doc-2": "txt"},
        **overrides,
    }
    return BatchIndexTaskParams(**fields)
def _make_mock_dependencies():
    """Build mock runtime dependencies: store, vector_store and embed_model."""
    kb_store = MagicMock()
    kb_store.update_document_status = AsyncMock()

    vectors = MagicMock()
    vectors.async_add = AsyncMock()

    embedder = MagicMock()
    embedder.aget_text_embedding = AsyncMock(return_value=[0.1] * 1536)

    return {"store": kb_store, "vector_store": vectors, "embed_model": embedder}
# ---------------------------------------------------------------------------
# 1. 任务参数 schema 校验
# ---------------------------------------------------------------------------
class TestVectorizeTaskParams:
    """Schema validation tests for VectorizeTaskParams."""

    def test_valid_params(self):
        """Valid parameters pass validation and pick up the chunking defaults."""
        built = _make_vectorize_params()
        assert built.document_id == "doc-001"
        assert built.chunk_size == 512
        assert built.chunk_overlap == 50

    def test_empty_document_id_rejected(self):
        """An empty document_id is rejected."""
        kwargs = {
            "document_id": "",
            "kb_id": "kb-1",
            "file_path": "/tmp/x.txt",
            "file_type": "txt",
            "user_id": "u1",
        }
        with pytest.raises(ValueError, match="document_id"):
            VectorizeTaskParams(**kwargs)

    def test_chunk_overlap_must_be_less_than_size(self):
        """chunk_overlap must be strictly smaller than chunk_size."""
        kwargs = {
            "document_id": "doc-1",
            "kb_id": "kb-1",
            "file_path": "/tmp/x.txt",
            "file_type": "txt",
            "user_id": "u1",
            "chunk_size": 100,
            "chunk_overlap": 100,
        }
        with pytest.raises(ValueError, match="chunk_overlap"):
            VectorizeTaskParams(**kwargs)

    def test_chunk_size_bounds(self):
        """A chunk_size below the lower bound of 50 is rejected."""
        with pytest.raises(ValueError):
            _make_vectorize_params(chunk_size=10)

    def test_user_id_required(self):
        """user_id is mandatory."""
        kwargs = {
            "document_id": "doc-1",
            "kb_id": "kb-1",
            "file_path": "/tmp/x.txt",
            "file_type": "txt",
        }
        with pytest.raises(ValueError):
            VectorizeTaskParams(**kwargs)
class TestBatchIndexTaskParams:
    """Schema validation tests for BatchIndexTaskParams."""

    def test_valid_params(self):
        """Valid parameters pass validation."""
        built = _make_batch_params()
        assert len(built.document_ids) == 2

    def test_empty_document_ids_rejected(self):
        """An empty document_ids list is rejected."""
        kwargs = {"kb_id": "kb-1", "document_ids": [], "user_id": "u1"}
        with pytest.raises(ValueError, match="document_ids"):
            BatchIndexTaskParams(**kwargs)

    def test_too_many_document_ids_rejected(self):
        """More than 100 document ids are rejected."""
        too_many = [f"doc-{i}" for i in range(101)]
        with pytest.raises(ValueError):
            BatchIndexTaskParams(kb_id="kb-1", document_ids=too_many, user_id="u1")
# ---------------------------------------------------------------------------
# 2. 错误消息净化
# ---------------------------------------------------------------------------
class TestSanitizeErrorMessage:
    """Tests for sanitize_error_message."""

    def test_strips_unix_paths(self):
        """Unix file paths are replaced by a placeholder."""
        result = sanitize_error_message("Failed to open /usr/local/lib/python/file.py")
        assert "/usr/local" not in result
        assert "<path>" in result

    def test_strips_windows_paths(self):
        """Windows file paths are replaced by a placeholder."""
        result = sanitize_error_message("Error in C:\\Users\\admin\\file.py at line 10")
        assert "C:\\Users" not in result
        assert "<path>" in result

    def test_strips_redis_connection_string(self):
        """Redis connection strings are replaced by a placeholder."""
        result = sanitize_error_message("Cannot connect to redis://localhost:6379/0")
        assert "redis://localhost" not in result
        assert "<connection-string>" in result

    def test_strips_postgresql_connection_string(self):
        """PostgreSQL connection strings are replaced by a placeholder."""
        result = sanitize_error_message("DB error: postgresql://user:pass@host:5432/db")
        assert "postgresql://" not in result
        assert "<connection-string>" in result

    def test_redacts_api_keys(self):
        """API keys are redacted."""
        result = sanitize_error_message("Auth failed with api_key=sk-abc123def456")
        assert "sk-abc123def456" not in result
        assert "<redacted>" in result

    def test_redacts_password(self):
        """Passwords are redacted."""
        result = sanitize_error_message("Login failed password=secret123")
        assert "secret123" not in result
        assert "<redacted>" in result

    def test_truncates_long_messages(self):
        """Long messages are cut to max_length and end with an ellipsis."""
        long_message = "x" * 1000
        result = sanitize_error_message(long_message, max_length=100)
        assert len(result) == 100
        assert result.endswith("...")

    def test_empty_message_returns_empty(self):
        """An empty message yields an empty string."""
        assert sanitize_error_message("") == ""
# ---------------------------------------------------------------------------
# 3. TaskManager — 任务提交与状态查询
# ---------------------------------------------------------------------------
class TestTaskManagerSubmit:
    """Task submission tests."""

    async def test_submit_vectorize_creates_pending_task(self):
        """Submitting a vectorize task creates a PENDING record."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())

        task_id = await manager.submit_vectorize(_make_vectorize_params())

        assert task_id.startswith("vec-")
        record = await manager.get_task_status(task_id)
        assert record is not None
        assert record.status == TaskStatus.PENDING
        assert record.input_data["document_id"] == "doc-001"
        assert record.metadata["user_id"] == "user1"

    async def test_submit_batch_index_creates_pending_task(self):
        """Submitting a batch-index task creates a PENDING record."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())

        task_id = await manager.submit_batch_index(_make_batch_params())

        assert task_id.startswith("batch-")
        record = await manager.get_task_status(task_id)
        assert record is not None
        assert record.status == TaskStatus.PENDING
        assert len(record.input_data["document_ids"]) == 2

    async def test_submit_with_broker_dispatches(self):
        """With a broker configured, the task is dispatched through broker.kiq."""
        fake_broker = MagicMock()
        fake_broker.kiq = AsyncMock()
        manager = TaskManager(broker=fake_broker, task_store=InMemoryTaskStore())

        task_id = await manager.submit_vectorize(_make_vectorize_params())

        fake_broker.kiq.assert_awaited_once()
        # The dispatch must carry both the task type and the task id.
        dispatched = fake_broker.kiq.call_args.kwargs
        assert dispatched["task_type"] == "vectorize"
        assert dispatched["task_id"] == task_id
# ---------------------------------------------------------------------------
# 4. 任务状态查询
# ---------------------------------------------------------------------------
class TestTaskManagerQuery:
    """Task status query tests."""

    async def test_get_task_status_returns_none_if_not_found(self):
        """Querying an unknown task returns None."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())
        assert await manager.get_task_status("nonexistent") is None

    async def test_list_tasks_filters_by_user(self):
        """list_tasks filters by user_id."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())
        # user1 submits two tasks, user2 submits one.
        submissions = [
            _make_vectorize_params(user_id="user1"),
            _make_vectorize_params(document_id="doc-2", user_id="user1"),
            _make_vectorize_params(document_id="doc-3", user_id="user2"),
        ]
        for submission in submissions:
            await manager.submit_vectorize(submission)

        for_user1 = await manager.list_tasks(user_id="user1")
        for_user2 = await manager.list_tasks(user_id="user2")
        everything = await manager.list_tasks()

        assert len(for_user1) == 2
        assert len(for_user2) == 1
        assert len(everything) == 3

    async def test_list_tasks_filters_by_status(self):
        """list_tasks filters by status."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        first_id = await manager.submit_vectorize(_make_vectorize_params())
        # Force COMPLETED by hand — InMemoryTaskStore.update_status is sync.
        store.update_status(first_id, TaskStatus.COMPLETED)
        # A second submission stays PENDING.
        await manager.submit_vectorize(_make_vectorize_params(document_id="doc-2"))

        still_pending = await manager.list_tasks(status=TaskStatus.PENDING)
        finished = await manager.list_tasks(status=TaskStatus.COMPLETED)

        assert len(still_pending) == 1
        assert len(finished) == 1
# ---------------------------------------------------------------------------
# 5. per-user 并发上限
# ---------------------------------------------------------------------------
class TestConcurrencyLimit:
    """Per-user concurrency limit tests."""

    async def test_concurrency_limit_exceeded(self):
        """Exceeding the per-user limit raises ConcurrencyLimitExceeded."""
        manager = TaskManager(
            broker=None,
            task_store=InMemoryTaskStore(),
            max_concurrent_per_user=2,
        )
        # Two submissions reach the limit.
        for doc in ("d1", "d2"):
            await manager.submit_vectorize(_make_vectorize_params(document_id=doc))
        # The third one must be refused.
        with pytest.raises(ConcurrencyLimitExceeded):
            await manager.submit_vectorize(_make_vectorize_params(document_id="d3"))

    async def test_concurrency_limit_per_user_isolated(self):
        """Each user's limit is counted independently."""
        manager = TaskManager(
            broker=None,
            task_store=InMemoryTaskStore(),
            max_concurrent_per_user=1,
        )
        # user1 uses up its single slot.
        await manager.submit_vectorize(_make_vectorize_params(user_id="user1"))
        # user2 can still submit.
        other_params = _make_vectorize_params(document_id="d2", user_id="user2")
        other_id = await manager.submit_vectorize(other_params)
        assert other_id is not None

    async def test_concurrency_released_after_cancel(self):
        """Cancelling a task releases its concurrency slot."""
        manager = TaskManager(
            broker=None,
            task_store=InMemoryTaskStore(),
            max_concurrent_per_user=1,
        )
        first_id = await manager.submit_vectorize(_make_vectorize_params())
        await manager.cancel_task(first_id)
        # The slot is free again, so another submission is accepted.
        second_id = await manager.submit_vectorize(_make_vectorize_params(document_id="d2"))
        assert second_id is not None
# ---------------------------------------------------------------------------
# 6. 任务取消
# ---------------------------------------------------------------------------
class TestTaskCancel:
    """Task cancellation tests."""

    async def test_cancel_pending_task(self):
        """Cancelling a PENDING task succeeds."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())
        task_id = await manager.submit_vectorize(_make_vectorize_params())

        cancelled = await manager.cancel_task(task_id)

        assert cancelled is True
        record = await manager.get_task_status(task_id)
        assert record is not None
        assert record.status == TaskStatus.CANCELLED

    async def test_cancel_nonexistent_raises(self):
        """Cancelling an unknown task raises TaskNotFoundError."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())
        with pytest.raises(TaskNotFoundError):
            await manager.cancel_task("nonexistent")

    async def test_cancel_completed_returns_false(self):
        """Cancelling an already completed task returns False."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        task_id = await manager.submit_vectorize(_make_vectorize_params())
        store.update_status(task_id, TaskStatus.COMPLETED)

        cancelled = await manager.cancel_task(task_id)

        assert cancelled is False
# ---------------------------------------------------------------------------
# 7. WorkerSweeper — 超时任务检测
# ---------------------------------------------------------------------------
class TestWorkerSweeper:
    """Worker sweeper tests."""

    async def test_sweep_detects_timed_out_task(self):
        """The sweeper marks a task with a stale heartbeat as failed."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        task_id = await manager.submit_vectorize(_make_vectorize_params())
        # Simulate RUNNING plus an expired heartbeat — store methods are sync.
        stale = (datetime.now(timezone.utc) - timedelta(seconds=600)).isoformat()
        store.update_status(task_id, TaskStatus.RUNNING)
        before = store.get(task_id)
        assert before is not None
        before.metadata["worker_heartbeat_at"] = stale

        swept = await WorkerSweeper(task_store=store, ttl_seconds=300).sweep()

        assert swept == 1
        after = store.get(task_id)
        assert after is not None
        assert after.status == TaskStatus.FAILED
        assert after.error_message == "worker_timeout"

    async def test_sweep_skips_fresh_heartbeat(self):
        """The sweeper leaves a task with a recent heartbeat alone."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        task_id = await manager.submit_vectorize(_make_vectorize_params())
        recent = datetime.now(timezone.utc).isoformat()
        store.update_status(task_id, TaskStatus.RUNNING)
        before = store.get(task_id)
        assert before is not None
        before.metadata["worker_heartbeat_at"] = recent

        swept = await WorkerSweeper(task_store=store, ttl_seconds=300).sweep()

        assert swept == 0
        after = store.get(task_id)
        assert after is not None
        assert after.status == TaskStatus.RUNNING

    async def test_sweep_skips_tasks_without_heartbeat(self):
        """The sweeper skips a task that has no heartbeat recorded."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        task_id = await manager.submit_vectorize(_make_vectorize_params())
        # RUNNING, but worker_heartbeat_at is deliberately left unset.
        store.update_status(task_id, TaskStatus.RUNNING)

        swept = await WorkerSweeper(task_store=store, ttl_seconds=300).sweep()

        assert swept == 0

    async def test_sweep_skips_non_running_tasks(self):
        """The sweeper only considers RUNNING tasks."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        task_id = await manager.submit_vectorize(_make_vectorize_params())
        # Still PENDING: a stale heartbeat must not get it swept.
        stale = (datetime.now(timezone.utc) - timedelta(seconds=600)).isoformat()
        record = store.get(task_id)
        assert record is not None
        record.metadata["worker_heartbeat_at"] = stale

        swept = await WorkerSweeper(task_store=store, ttl_seconds=300).sweep()

        assert swept == 0
# ---------------------------------------------------------------------------
# 8. 降级模式 — 无 broker 同步执行
# ---------------------------------------------------------------------------
class TestDegradedMode:
    """Degraded-mode tests: TaskManager executing tasks with no broker."""

    # Dotted path of the task function replaced by a mock in every test here.
    _RUN_VECTORIZE = "agentkit.rag_platform.tasks.run_vectorize_task"

    async def test_vectorize_success_transitions_to_completed(self):
        """A successful vectorize run ends in COMPLETED with progress 1.0."""
        task_store = InMemoryTaskStore()
        task_manager = TaskManager(broker=None, task_store=task_store)
        vec_params = _make_vectorize_params()
        mock_deps = _make_mock_dependencies()

        # Mock run_vectorize_task so only TaskManager's state transitions are
        # exercised, without depending on llama_index.
        with patch(self._RUN_VECTORIZE, new_callable=AsyncMock):
            tid = await task_manager.submit_vectorize(
                vec_params, dependencies=mock_deps
            )
            # Give the degraded-mode task time to finish.
            await asyncio.sleep(0.1)
            outcome = await task_manager.get_task_status(tid)
            assert outcome is not None
            assert outcome.status == TaskStatus.COMPLETED
            assert outcome.progress == 1.0

    async def test_vectorize_failure_sanitizes_error(self):
        """A failing vectorize run ends in FAILED with a sanitized message."""
        task_store = InMemoryTaskStore()
        task_manager = TaskManager(broker=None, task_store=task_store)
        vec_params = _make_vectorize_params()
        mock_deps = _make_mock_dependencies()

        # The mocked task raises an exception whose text contains a file path.
        leaky_error = FileNotFoundError(
            "[Errno 2] No such file or directory: /usr/local/data/missing.txt"
        )
        with patch(
            self._RUN_VECTORIZE,
            new_callable=AsyncMock,
            side_effect=leaky_error,
        ):
            tid = await task_manager.submit_vectorize(
                vec_params, dependencies=mock_deps
            )
            # Give the degraded-mode task time to finish.
            await asyncio.sleep(0.1)
            outcome = await task_manager.get_task_status(tid)
            assert outcome is not None
            assert outcome.status == TaskStatus.FAILED
            # The stored message must have the path replaced by "<path>".
            stored_message = outcome.error_message or ""
            assert "/usr/local" not in stored_message
            assert "<path>" in stored_message

    async def test_degraded_mode_does_not_block_caller(self):
        """submit_vectorize returns before the background task has finished."""
        task_store = InMemoryTaskStore()
        task_manager = TaskManager(broker=None, task_store=task_store)
        vec_params = _make_vectorize_params()
        mock_deps = _make_mock_dependencies()

        # Stand-in for run_vectorize_task that simulates a slow operation.
        async def _simulate_slow_run(*args, **kwargs):
            await asyncio.sleep(0.5)

        with patch(
            self._RUN_VECTORIZE,
            new_callable=AsyncMock,
            side_effect=_simulate_slow_run,
        ):
            # submit_vectorize should return at once, not wait for completion.
            tid = await task_manager.submit_vectorize(
                vec_params, dependencies=mock_deps
            )
            # Query immediately: PENDING (submitted, not yet started) or
            # RUNNING (started, not yet finished) are both acceptable.
            snapshot = await task_manager.get_task_status(tid)
            assert snapshot is not None
            in_flight = (TaskStatus.PENDING, TaskStatus.RUNNING)
            assert snapshot.status in in_flight
            # Let the background task run to completion before the test ends.
            await asyncio.sleep(0.6)
# ---------------------------------------------------------------------------
# 9. 任务执行函数测试
# ---------------------------------------------------------------------------
class TestRunVectorizeTask:
    """Tests for the run_vectorize_task function."""

    # Dotted path of the processor method replaced by a mock in these tests.
    _PROCESS = "agentkit.rag_platform.tasks.DocumentProcessor.process"

    async def test_run_vectorize_calls_document_processor(self):
        """run_vectorize_task awaits DocumentProcessor.process exactly once."""
        vec_params = _make_vectorize_params()
        mock_deps = _make_mock_dependencies()
        # Mock DocumentProcessor.process so llama_index is not required.
        with patch(self._PROCESS, new_callable=AsyncMock) as process_mock:
            await run_vectorize_task(vec_params, **mock_deps)
            process_mock.assert_awaited_once()

    async def test_run_vectorize_failure_propagates(self):
        """An exception raised by the processor propagates to the caller."""
        vec_params = _make_vectorize_params()
        mock_deps = _make_mock_dependencies()
        boom = RuntimeError("processing failed")
        with patch(
            self._PROCESS,
            new_callable=AsyncMock,
            side_effect=boom,
        ), pytest.raises(RuntimeError, match="processing failed"):
            await run_vectorize_task(vec_params, **mock_deps)
class TestRunBatchIndexTask:
    """Tests for the run_batch_index_task function."""

    # Dotted path of the processor method replaced by a mock in these tests.
    _PROCESS = "agentkit.rag_platform.tasks.DocumentProcessor.process"

    async def test_batch_index_returns_results(self):
        """The result carries "succeeded" and "failed" collections."""
        batch_params = BatchIndexTaskParams(
            kb_id="kb-1",
            document_ids=["doc-1", "doc-2"],
            user_id="user1",
            file_paths={"doc-1": "/tmp/a.txt", "doc-2": "/tmp/b.txt"},
            file_types={"doc-1": "txt", "doc-2": "txt"},
        )
        mock_deps = _make_mock_dependencies()
        # Mock DocumentProcessor.process so every document succeeds.
        with patch(self._PROCESS, new_callable=AsyncMock):
            outcome = await run_batch_index_task(batch_params, **mock_deps)
        assert "succeeded" in outcome
        assert "failed" in outcome
        assert len(outcome["succeeded"]) == 2
        assert len(outcome["failed"]) == 0

    async def test_batch_index_partial_failure(self):
        """A document whose processing raises is reported under "failed"."""
        batch_params = BatchIndexTaskParams(
            kb_id="kb-1",
            document_ids=["doc-1", "doc-missing"],
            user_id="user1",
            file_paths={
                "doc-1": "/tmp/a.txt",
                "doc-missing": "/nonexistent/file.txt",
            },
            file_types={"doc-1": "txt", "doc-missing": "txt"},
        )
        mock_deps = _make_mock_dependencies()

        # Mocked process: the first call succeeds, the second call raises.
        invocations = 0

        async def _fail_on_second_call(*args, **kwargs):
            nonlocal invocations
            invocations += 1
            if invocations == 2:
                raise FileNotFoundError("missing file")

        with patch(
            self._PROCESS,
            new_callable=AsyncMock,
            side_effect=_fail_on_second_call,
        ):
            outcome = await run_batch_index_task(batch_params, **mock_deps)
        assert "doc-1" in outcome["succeeded"]
        assert "doc-missing" in outcome["failed"]
# ---------------------------------------------------------------------------
# 10. 心跳更新测试
# ---------------------------------------------------------------------------
class TestHeartbeat:
    """Tests for TaskManager.update_heartbeat."""

    async def test_update_heartbeat_sets_timestamp(self):
        """update_heartbeat stores a worker_heartbeat_at value in metadata."""
        task_store = InMemoryTaskStore()
        task_manager = TaskManager(broker=None, task_store=task_store)
        tid = await task_manager.submit_vectorize(_make_vectorize_params())

        await task_manager.update_heartbeat(tid)

        snapshot = await task_manager.get_task_status(tid)
        assert snapshot is not None
        assert "worker_heartbeat_at" in snapshot.metadata
        # The stored value must parse as ISO 8601; fromisoformat raises if not.
        heartbeat_stamp = snapshot.metadata["worker_heartbeat_at"]
        datetime.fromisoformat(heartbeat_stamp)

    async def test_update_heartbeat_nonexistent_task_noop(self):
        """update_heartbeat on an unknown task id must not raise."""
        task_store = InMemoryTaskStore()
        task_manager = TaskManager(broker=None, task_store=task_store)
        # Passing means the call completes without an exception.
        await task_manager.update_heartbeat("nonexistent")