feat(rag_platform): U8 — TaskIQ async task integration
- Add tasks.py: TaskManager with vectorize/batch_index tasks, per-user concurrency limits, degraded mode (sync execution without broker), WorkerSweeper for timeout detection, and error message sanitization.
- Add taskiq>=0.11 and taskiq-redis>=0.5 to pyproject.toml.
- Add task parameter schema validation (VectorizeTaskParams, BatchIndexTaskParams).
- Tests: 41 new tests, 289 total passing.
This commit is contained in:
parent
d026a91f43
commit
e3ae2f3a56
|
|
@ -43,6 +43,9 @@ dependencies = [
|
|||
"llama-index-embeddings-openai>=0.3",
|
||||
"pgvector>=0.3",
|
||||
"jieba>=0.42",
|
||||
# 异步任务队列(U8 — 文档向量化)
|
||||
"taskiq>=0.11",
|
||||
"taskiq-redis>=0.5",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,705 @@
|
|||
"""U8 — TaskIQ 异步任务集成 — 文档向量化与批量索引。
|
||||
|
||||
包装 TaskIQ broker,提供:
|
||||
- 任务参数 schema 校验(Pydantic 模型)
|
||||
- per-user 并发上限
|
||||
- 任务状态机(PENDING → RUNNING → COMPLETED | FAILED | CANCELLED)
|
||||
- Worker liveness & sweeper(worker_heartbeat_at 超时检测)
|
||||
- 错误消息净化(剥离内部路径/连接串)
|
||||
- 降级模式:broker 未配置时同步执行
|
||||
|
||||
Redis 隔离:使用独立 db=1(与 bus/cache 的 db=0 分离),或 key 前缀 ``taskiq:``。
|
||||
任务参数在 Redis 中带 TTL(默认 1 小时)。
|
||||
|
||||
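ponytail: the TaskIQ broker is an optional dependency — taskiq / taskiq_redis are
imported lazily inside ``create_broker``, so this module still loads when they are
not installed and TaskManager falls back to in-process (degraded) execution.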
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from agentkit.core.protocol import TaskStatus
|
||||
from agentkit.rag_platform.document_processor import (
|
||||
DEFAULT_CHUNK_OVERLAP,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DocumentProcessor,
|
||||
)
|
||||
from agentkit.server.task_store import InMemoryTaskStore, TaskRecord
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.vector_stores.postgres import PGVectorStore
|
||||
|
||||
from agentkit.rag_platform.store import KBStore
|
||||
|
||||
logger = logging.getLogger(__name__)

# Default configuration values.
DEFAULT_MAX_CONCURRENT_PER_USER = 3
DEFAULT_WORKER_TTL_SECONDS = 5 * 60  # worker heartbeat timeout threshold, in seconds
DEFAULT_ERROR_MESSAGE_MAX_LENGTH = 500
DEFAULT_TASK_TTL_SECONDS = 60 * 60  # TTL applied to task parameters, in seconds

# Redis isolation — a dedicated db=1, kept apart from the bus/cache db=0.
TASKIQ_REDIS_DB = 1
TASKIQ_KEY_PREFIX = "taskiq:"
|
||||
|
||||
|
||||
async def _maybe_await(result: Any) -> Any:
    """Resolve *result* whether it is a plain value or an awaitable.

    Lets callers treat synchronous and asynchronous task-store methods the
    same way: an awaitable is awaited and its value returned, anything else
    is returned unchanged.
    """
    return (await result) if inspect.isawaitable(result) else result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 任务参数模型 — schema 校验
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class VectorizeTaskParams(BaseModel):
    """Parameters of a single-document vectorization task.

    Describes one run of the parse → segment → vectorize → index pipeline.
    """

    model_config = ConfigDict()

    document_id: str = Field(min_length=1, max_length=128)
    kb_id: str = Field(min_length=1, max_length=128)
    file_path: str = Field(min_length=1, max_length=1024)
    file_type: str = Field(min_length=1, max_length=64)
    chunk_size: int = Field(default=DEFAULT_CHUNK_SIZE, ge=50, le=8192)
    chunk_overlap: int = Field(default=DEFAULT_CHUNK_OVERLAP, ge=0, le=1024)
    user_id: str = Field(min_length=1, max_length=128)

    @field_validator("chunk_overlap")
    @classmethod
    def _overlap_less_than_size(cls, v: int, info: Any) -> int:
        """Reject a chunk_overlap that is not strictly smaller than chunk_size."""
        # chunk_size is declared before chunk_overlap, so it is looked up in the
        # already-validated data; the default is used when it is not present there.
        chunk_size = info.data.get("chunk_size", DEFAULT_CHUNK_SIZE)
        if v < chunk_size:
            return v
        raise ValueError(f"chunk_overlap ({v}) must be < chunk_size ({chunk_size})")
|
||||
|
||||
|
||||
class BatchIndexTaskParams(BaseModel):
    """Parameters of a batch indexing task covering several documents.

    The documents are vectorized one after another by the batch runners in
    this module. ``file_paths`` and ``file_types`` are keyed by document id.
    """

    model_config = ConfigDict()

    kb_id: str = Field(min_length=1, max_length=128)
    document_ids: list[str] = Field(min_length=1, max_length=100)
    user_id: str = Field(min_length=1, max_length=128)
    file_paths: dict[str, str] = Field(default_factory=dict)  # document_id -> file_path
    file_types: dict[str, str] = Field(default_factory=dict)  # document_id -> file_type
    chunk_size: int = Field(default=DEFAULT_CHUNK_SIZE, ge=50, le=8192)
    chunk_overlap: int = Field(default=DEFAULT_CHUNK_OVERLAP, ge=0, le=1024)

    @field_validator("chunk_overlap")
    @classmethod
    def _overlap_less_than_size(cls, v: int, info: Any) -> int:
        """Reject a chunk_overlap that is not strictly smaller than chunk_size.

        Mirrors the check on VectorizeTaskParams: the batch values are copied
        into one VectorizeTaskParams per document, so an invalid pair would
        otherwise be accepted here and then fail for every single document.
        """
        size = info.data.get("chunk_size", DEFAULT_CHUNK_SIZE)
        if v >= size:
            raise ValueError(f"chunk_overlap ({v}) must be < chunk_size ({size})")
        return v
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 错误消息净化
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Matches Unix / Windows file paths.
_PATH_RE = re.compile(
    r"(?:/[\w.\-]+)+|"  # Unix path: /usr/local/...
    r"(?:[A-Za-z]:\\[\w.\-]+(?:\\[\w.\-]+)*)"  # Windows path: C:\Users\...
)
# Matches Redis / PostgreSQL / MySQL connection strings, including the TLS
# variant ``rediss://`` and driver-qualified schemes such as
# ``postgresql+asyncpg://`` — both would otherwise slip through and leave the
# embedded credentials in the message.
_CONN_STR_RE = re.compile(
    r"(rediss?|postgres(?:ql)?|mysql)(?:\+[\w.\-]+)?://"
    r"[^\s\"']+",
    re.IGNORECASE,
)
# Matches "key = value" style secrets. The value runs up to the next
# whitespace, quote, comma or semicolon so that tokens containing '.', '+',
# '/' or '=' are redacted in full; optional quotes around the key and value
# cover dict/JSON-like renderings such as 'password': 'x'.
_SECRET_RE = re.compile(
    r"(api[_-]?key|token|password|secret)[\"']?[\s:=]+[\"']?[^\s\"',;]+[\"']?",
    re.IGNORECASE,
)
|
||||
|
||||
|
||||
def sanitize_error_message(msg: str, max_length: int = DEFAULT_ERROR_MESSAGE_MAX_LENGTH) -> str:
    """Sanitize an error message by stripping paths, connection strings and secrets.

    Args:
        msg: Raw error message.
        max_length: Maximum length of the returned string; longer results are
            truncated and, when there is room for it, end with "...".

    Returns:
        The sanitized message, never longer than ``max_length`` characters
        (an empty string when ``msg`` is empty).
    """
    if not msg:
        return ""

    # Connection strings go first so that e.g. /localhost:6379/0 is not
    # partially consumed by the path pattern.
    cleaned = _CONN_STR_RE.sub("<connection-string>", msg)
    cleaned = _PATH_RE.sub("<path>", cleaned)
    cleaned = _SECRET_RE.sub(r"\1=<redacted>", cleaned)

    if len(cleaned) <= max_length:
        return cleaned

    ellipsis = "..."
    if max_length < len(ellipsis):
        # No room for the ellipsis: a negative slice bound here would keep
        # almost the whole message instead of shortening it.
        return cleaned[: max(max_length, 0)]
    return cleaned[: max_length - len(ellipsis)] + ellipsis
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Worker liveness & sweeper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TaskStoreProtocol(Protocol):
    """Structural interface of the task stores used by this module.

    Intended to be satisfied by both InMemoryTaskStore and RedisTaskStore.
    Callers in this module pass every result through ``_maybe_await``, so an
    implementation may expose these methods as sync or async.
    """

    async def create(
        self,
        task_id: str,
        agent_name: str,
        input_data: dict,
        skill_name: str | None = None,
    ) -> TaskRecord:
        """Create the record for a new task and return it."""
        ...

    def get(self, task_id: str) -> TaskRecord | None:
        """Return the record for ``task_id``, or None when there is none."""
        ...

    async def update_status(
        self,
        task_id: str,
        status: TaskStatus,
        **kwargs: Any,
    ) -> TaskRecord:
        """Set the status of ``task_id``; extra record fields are passed as keyword arguments."""
        ...

    def list_tasks(
        self,
        status: TaskStatus | None = None,
        limit: int = 100,
    ) -> list[TaskRecord]:
        """Return up to ``limit`` records, optionally restricted to one status."""
        ...
|
||||
|
||||
|
||||
class WorkerSweeper:
    """Detect timed-out tasks and mark them as failed.

    Scans tasks whose status is RUNNING and whose ``worker_heartbeat_at``
    metadata entry is older than the TTL, then sets them to FAILED with
    error_message="worker_timeout".

    ponytail (upgrade path): this currently relies on a full ``list_tasks``
    scan (O(n)); in production it could be replaced by a Redis ZSET ordered
    by heartbeat timestamp (O(log n)).
    """

    def __init__(
        self,
        task_store: TaskStoreProtocol,
        ttl_seconds: int = DEFAULT_WORKER_TTL_SECONDS,
    ) -> None:
        self._store = task_store
        self._ttl = ttl_seconds

    async def sweep(self) -> int:
        """Scan for timed-out tasks and return how many were marked as failed."""
        now = datetime.now(timezone.utc)
        cutoff = now - timedelta(seconds=self._ttl)

        running = await _maybe_await(
            self._store.list_tasks(status=TaskStatus.RUNNING, limit=1000)
        )

        swept = 0
        for record in running:
            raw = record.metadata.get("worker_heartbeat_at")
            if raw is None and record.started_at:
                # No heartbeat recorded — fall back to the start time.
                raw = record.started_at.isoformat()
            if raw is None:
                continue

            try:
                last_seen = datetime.fromisoformat(raw)
            except (ValueError, TypeError):
                continue

            # Make sure the comparison below is between timezone-aware values.
            if last_seen.tzinfo is None:
                last_seen = last_seen.replace(tzinfo=timezone.utc)

            if last_seen >= cutoff:
                continue

            try:
                await _maybe_await(
                    self._store.update_status(
                        record.task_id,
                        TaskStatus.FAILED,
                        error_message="worker_timeout",
                        completed_at=now,
                        metadata={**record.metadata, "swept_at": now.isoformat()},
                    )
                )
                swept += 1
                logger.warning(
                    "Swept timed-out task %s (heartbeat=%s, cutoff=%s)",
                    record.task_id,
                    raw,
                    cutoff.isoformat(),
                )
            except KeyError:
                # The task was deleted concurrently.
                continue
        return swept
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TaskManager — 任务提交、查询、取消
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ConcurrencyLimitExceeded(Exception):
    """Raised when a user already holds the maximum number of concurrent tasks."""
|
||||
|
||||
|
||||
class TaskNotFoundError(Exception):
    """Raised when the requested task does not exist."""
|
||||
|
||||
|
||||
class TaskManager:
    """Async task manager that wraps an optional TaskIQ broker.

    When no broker is configured, tasks run in-process as asyncio background
    tasks (degraded mode).

    Per-user concurrency is tracked in memory. Every accepted task holds one
    slot for its user; the slot is released exactly once — when the in-process
    runner finishes, when the task is cancelled, when submission fails, or
    when a later submission finds the task in a terminal state (or missing)
    in the task store.

    Args:
        broker: TaskIQ broker instance (optional).
        task_store: Task state store (defaults to InMemoryTaskStore).
        max_concurrent_per_user: Maximum number of slots one user may hold.
        task_ttl_seconds: TTL for task parameters (Redis mode); stored on the
            instance but not otherwise used by this class.
    """

    def __init__(
        self,
        broker: Any = None,
        task_store: TaskStoreProtocol | None = None,
        max_concurrent_per_user: int = DEFAULT_MAX_CONCURRENT_PER_USER,
        task_ttl_seconds: int = DEFAULT_TASK_TTL_SECONDS,
    ) -> None:
        self._broker = broker
        # Explicit None check: a caller-supplied store that merely evaluates as
        # falsy must not be silently replaced by a fresh in-memory store.
        self._store: TaskStoreProtocol = (
            task_store if task_store is not None else InMemoryTaskStore()
        )
        self._max_concurrent = max_concurrent_per_user
        self._task_ttl = task_ttl_seconds
        # user_id -> number of slots currently held
        self._running_counts: dict[str, int] = {}
        # task_id -> user_id for every task that still holds a slot; popping an
        # entry is what makes a release happen at most once per task.
        self._slot_owners: dict[str, str] = {}
        # task_id -> in-process runner. The event loop only keeps weak
        # references to tasks, so a strong reference is held until completion.
        self._background_tasks: dict[str, asyncio.Task[Any]] = {}
        self._lock = asyncio.Lock()

    @property
    def store(self) -> TaskStoreProtocol:
        """Expose the underlying task store (for use by the sweeper)."""
        return self._store

    @property
    def broker(self) -> Any:
        """Expose the underlying broker (for startup/shutdown integration)."""
        return self._broker

    # ── Task submission ─────────────────────────────────────────

    async def submit_vectorize(
        self,
        params: VectorizeTaskParams,
        dependencies: dict[str, Any] | None = None,
    ) -> str:
        """Submit a vectorization task and return its task_id.

        Args:
            params: Vectorization task parameters (already schema-validated).
            dependencies: Runtime dependencies (store/vector_store/embed_model)
                passed as keyword arguments to ``run_vectorize_task`` in
                degraded mode; not forwarded to the broker.

        Raises:
            ConcurrencyLimitExceeded: The per-user concurrency limit is reached.
        """
        task_id = f"vec-{uuid.uuid4()}"
        input_data = params.model_dump()
        # Dependency objects are never serialized — they only travel in memory.
        deps = dependencies or {}

        await self._admit(task_id, params.user_id, "rag_vectorize", "vectorize", input_data)
        logger.info(
            "Submitted vectorize task %s for document=%s user=%s",
            task_id,
            params.document_id,
            params.user_id,
        )
        await self._launch(
            task_id,
            "vectorize",
            input_data,
            lambda: self._run_vectorize_degraded(task_id, params, deps),
        )
        return task_id

    async def submit_batch_index(
        self,
        params: BatchIndexTaskParams,
        dependencies: dict[str, Any] | None = None,
    ) -> str:
        """Submit a batch indexing task and return its task_id.

        Args:
            params: Batch indexing parameters (already schema-validated).
            dependencies: Runtime dependencies used in degraded mode only.

        Raises:
            ConcurrencyLimitExceeded: The per-user concurrency limit is reached.
        """
        task_id = f"batch-{uuid.uuid4()}"
        input_data = params.model_dump()
        deps = dependencies or {}

        await self._admit(task_id, params.user_id, "rag_batch_index", "batch_index", input_data)
        logger.info(
            "Submitted batch index task %s for %d documents user=%s",
            task_id,
            len(params.document_ids),
            params.user_id,
        )
        await self._launch(
            task_id,
            "batch_index",
            input_data,
            lambda: self._run_batch_index_degraded(task_id, params, deps),
        )
        return task_id

    # ── Task queries ────────────────────────────────────────────

    async def get_task_status(self, task_id: str) -> TaskRecord | None:
        """Return the record of ``task_id``, or None when it does not exist."""
        return await _maybe_await(self._store.get(task_id))

    async def list_tasks(
        self,
        user_id: str | None = None,
        status: TaskStatus | None = None,
        limit: int = 100,
    ) -> list[TaskRecord]:
        """List task records, optionally filtered by user_id and status.

        ``limit`` is applied by the store before the user filter, so a
        user-filtered result can contain fewer than ``limit`` records even
        when more of that user's tasks exist.
        """
        tasks = await _maybe_await(self._store.list_tasks(status=status, limit=limit))
        if user_id is not None:
            tasks = [t for t in tasks if t.metadata.get("user_id") == user_id]
        return tasks

    async def cancel_task(self, task_id: str) -> bool:
        """Cancel a task that has not reached a terminal state.

        In degraded mode the in-process runner is cancelled as well. In broker
        mode only the stored status changes; nothing in this class reaches the
        remote worker.

        Returns:
            True when the task was cancelled, False when it had already
            reached a terminal state.

        Raises:
            TaskNotFoundError: The task does not exist.
        """
        record = await _maybe_await(self._store.get(task_id))
        if record is None:
            raise TaskNotFoundError(f"Task {task_id} not found")

        if self._is_terminal(record.status):
            return False

        # Stop the in-process runner first so that it cannot overwrite the
        # CANCELLED status with COMPLETED/FAILED afterwards.
        runner = self._background_tasks.get(task_id)
        if runner is not None:
            runner.cancel()

        await _maybe_await(
            self._store.update_status(
                task_id,
                TaskStatus.CANCELLED,
                completed_at=datetime.now(timezone.utc),
            )
        )
        # Released through the per-task owner map, so the runner's own cleanup
        # cannot decrement the user's counter a second time.
        await self._release_slot(task_id)
        return True

    # ── Heartbeat ───────────────────────────────────────────────

    async def update_heartbeat(self, task_id: str) -> None:
        """Record the current UTC time as the task's worker heartbeat.

        Does nothing when the task does not exist.
        """
        record = await _maybe_await(self._store.get(task_id))
        if record is None:
            return
        # NOTE(review): this mutates the record returned by get() in place and
        # never writes it back; that only persists if the store hands out its
        # live record object — confirm for non-in-memory stores.
        record.metadata["worker_heartbeat_at"] = datetime.now(timezone.utc).isoformat()

    # ── Internals ───────────────────────────────────────────────

    @staticmethod
    def _is_terminal(status: Any) -> bool:
        """Return True when ``status`` is a final state."""
        terminal = [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]
        # PARTIALLY_COMPLETED is what the batch runner uses as a final state;
        # looked up defensively because the enum is defined outside this module.
        partial = getattr(TaskStatus, "PARTIALLY_COMPLETED", None)
        if partial is not None:
            terminal.append(partial)
        return status in terminal

    async def _check_concurrency(self, user_id: str) -> None:
        """Reserve one slot for ``user_id`` or raise ConcurrencyLimitExceeded."""
        async with self._lock:
            current = self._running_counts.get(user_id, 0)
            if current >= self._max_concurrent:
                raise ConcurrencyLimitExceeded(
                    f"User {user_id} has {current} running tasks (max={self._max_concurrent})"
                )
            self._running_counts[user_id] = current + 1

    async def _release_concurrency(self, user_id: str) -> None:
        """Give back one slot of ``user_id`` (never drops below zero)."""
        async with self._lock:
            if self._running_counts.get(user_id, 0) > 0:
                self._running_counts[user_id] -= 1

    async def _release_slot(self, task_id: str) -> None:
        """Release the slot held by ``task_id``; a no-op if already released."""
        owner = self._slot_owners.pop(task_id, None)
        if owner is not None:
            await self._release_concurrency(owner)

    async def _reap_finished_slots(self, user_id: str) -> None:
        """Release slots of this user's tasks that the store reports as finished.

        Nothing in this class is notified when a broker-dispatched task ends
        or when the sweeper fails a task, so without this step such tasks
        would hold their slot forever.
        """
        held = [tid for tid, owner in self._slot_owners.items() if owner == user_id]
        for tid in held:
            record = await _maybe_await(self._store.get(tid))
            if record is None or self._is_terminal(record.status):
                await self._release_slot(tid)

    async def _admit(
        self,
        task_id: str,
        user_id: str,
        agent_name: str,
        skill_name: str,
        input_data: dict,
    ) -> None:
        """Reserve a slot and create the task record; undo the slot on failure."""
        if self._running_counts.get(user_id, 0) >= self._max_concurrent:
            # Only reconcile against the store when the user is at the limit.
            await self._reap_finished_slots(user_id)
        await self._check_concurrency(user_id)
        self._slot_owners[task_id] = user_id
        try:
            await _maybe_await(
                self._store.create(
                    task_id=task_id,
                    agent_name=agent_name,
                    input_data=input_data,
                    skill_name=skill_name,
                )
            )
            # Keep user_id in the metadata so tasks can be queried per user.
            record = await _maybe_await(self._store.get(task_id))
            if record is not None:
                record.metadata["user_id"] = user_id
        except BaseException:
            # Includes cancellation of the submitting coroutine: the reserved
            # slot must not outlive a submission that never completed.
            await self._release_slot(task_id)
            raise

    async def _launch(
        self,
        task_id: str,
        task_type: str,
        input_data: dict,
        degraded_runner: Any,
    ) -> None:
        """Start the task: dispatch to the broker, or run it in-process.

        ``degraded_runner`` is a zero-argument callable returning the runner
        coroutine; it is only called when no broker is configured.
        """
        if self._broker is None:
            # Degraded mode — run in the background without blocking the caller.
            self._spawn_background(task_id, degraded_runner())
            return

        try:
            await self._dispatch_to_broker(task_type, input_data, task_id)
        except BaseException as exc:
            await self._release_slot(task_id)
            if isinstance(exc, Exception):
                await self._mark_dispatch_failed(task_id, exc)
            raise

    async def _mark_dispatch_failed(self, task_id: str, exc: Exception) -> None:
        """Best-effort: mark a task FAILED after its dispatch raised.

        Without this the record would stay PENDING indefinitely; the sweeper
        in this module only scans RUNNING tasks.
        """
        try:
            await _maybe_await(
                self._store.update_status(
                    task_id,
                    TaskStatus.FAILED,
                    error_message=sanitize_error_message(str(exc)),
                    completed_at=datetime.now(timezone.utc),
                )
            )
        except Exception:
            logger.exception("Could not mark undispatched task %s as failed", task_id)

    def _spawn_background(self, task_id: str, coro: Any) -> None:
        """Schedule ``coro`` as an asyncio task and keep a reference to it."""
        task = asyncio.create_task(coro)
        self._background_tasks[task_id] = task
        task.add_done_callback(
            lambda finished, tid=task_id: self._on_background_done(tid, finished)
        )

    def _on_background_done(self, task_id: str, task: Any) -> None:
        """Drop the reference to a finished runner and log an escaped error."""
        self._background_tasks.pop(task_id, None)
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            logger.error(
                "Background task %s crashed: %s",
                task_id,
                sanitize_error_message(str(error)),
            )

    async def _dispatch_to_broker(self, task_type: str, params: dict, task_id: str) -> None:
        """Dispatch a task to the TaskIQ broker; failures are logged and re-raised."""
        # NOTE(review): this calls ``kiq`` on the broker object itself — confirm
        # the configured broker actually exposes it with these keyword arguments.
        try:
            await self._broker.kiq(task_type=task_type, params=params, task_id=task_id)
        except Exception as e:
            # Broker errors can embed the Redis URL, so sanitize before logging.
            logger.error(
                "Failed to dispatch task %s to broker: %s",
                task_id,
                sanitize_error_message(str(e)),
            )
            raise

    async def _run_vectorize_degraded(
        self,
        task_id: str,
        params: VectorizeTaskParams,
        deps: dict[str, Any],
    ) -> None:
        """Degraded mode — run one vectorization task inside an asyncio task."""
        now = datetime.now(timezone.utc)
        try:
            await _maybe_await(
                self._store.update_status(task_id, TaskStatus.RUNNING, started_at=now)
            )
            await self.update_heartbeat(task_id)

            await run_vectorize_task(params, **deps)

            await _maybe_await(
                self._store.update_status(
                    task_id,
                    TaskStatus.COMPLETED,
                    completed_at=datetime.now(timezone.utc),
                    progress=1.0,
                )
            )
        except Exception as e:
            sanitized = sanitize_error_message(str(e))
            await _maybe_await(
                self._store.update_status(
                    task_id,
                    TaskStatus.FAILED,
                    error_message=sanitized,
                    completed_at=datetime.now(timezone.utc),
                )
            )
            logger.error("Vectorize task %s failed: %s", task_id, sanitized)
        finally:
            # Also runs on cancellation; a no-op if cancel_task released it.
            await self._release_slot(task_id)

    async def _run_batch_index_degraded(
        self,
        task_id: str,
        params: BatchIndexTaskParams,
        deps: dict[str, Any],
    ) -> None:
        """Degraded mode — run a batch indexing task, one document at a time.

        A failing document is recorded and the batch continues. The task ends
        as COMPLETED when no document failed, PARTIALLY_COMPLETED otherwise.
        """
        now = datetime.now(timezone.utc)
        total = len(params.document_ids)
        succeeded: list[str] = []
        failed: list[str] = []

        try:
            await _maybe_await(
                self._store.update_status(task_id, TaskStatus.RUNNING, started_at=now)
            )
            await self.update_heartbeat(task_id)

            for done, doc_id in enumerate(params.document_ids, start=1):
                await self.update_heartbeat(task_id)
                try:
                    single_params = VectorizeTaskParams(
                        document_id=doc_id,
                        kb_id=params.kb_id,
                        file_path=params.file_paths.get(doc_id, ""),
                        file_type=params.file_types.get(doc_id, "txt"),
                        chunk_size=params.chunk_size,
                        chunk_overlap=params.chunk_overlap,
                        user_id=params.user_id,
                    )
                    await run_vectorize_task(single_params, **deps)
                    succeeded.append(doc_id)
                except Exception as e:
                    logger.warning(
                        "Batch item %s failed: %s",
                        doc_id,
                        sanitize_error_message(str(e)),
                    )
                    failed.append(doc_id)

                # Report progress after every document.
                await _maybe_await(
                    self._store.update_status(
                        task_id,
                        TaskStatus.RUNNING,
                        progress=done / total,
                        progress_message=f"Processed {done}/{total}",
                    )
                )

            final_status = TaskStatus.COMPLETED if not failed else TaskStatus.PARTIALLY_COMPLETED
            await _maybe_await(
                self._store.update_status(
                    task_id,
                    final_status,
                    completed_at=datetime.now(timezone.utc),
                    progress=1.0,
                    output_data={
                        "succeeded": succeeded,
                        "failed": failed,
                        "total": total,
                    },
                )
            )
        except Exception as e:
            sanitized = sanitize_error_message(str(e))
            await _maybe_await(
                self._store.update_status(
                    task_id,
                    TaskStatus.FAILED,
                    error_message=sanitized,
                    completed_at=datetime.now(timezone.utc),
                )
            )
            logger.error("Batch index task %s failed: %s", task_id, sanitized)
        finally:
            # Also runs on cancellation; a no-op if cancel_task released it.
            await self._release_slot(task_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 任务执行函数 — 由 worker 调用
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def run_vectorize_task(
    params: VectorizeTaskParams,
    store: "KBStore",
    vector_store: "PGVectorStore",
    embed_model: "BaseEmbedding",
) -> None:
    """Run the vectorization pipeline for one document (called by a worker).

    Thin wrapper around ``DocumentProcessor.process()``, which performs
    parse → segment → vectorize → index. Document status transitions are
    expected to be handled inside DocumentProcessor; they are not visible here.

    Args:
        params: Vectorization task parameters.
        store: KBStore instance (used for document status updates).
        vector_store: LlamaIndex PGVectorStore instance.
        embed_model: LlamaIndex embedding model.

    Raises:
        Exception: Whatever the processor raises; nothing is caught here.
    """
    pipeline = DocumentProcessor(
        chunk_size=params.chunk_size,
        chunk_overlap=params.chunk_overlap,
    )
    positional_args = (
        params.file_path,
        params.file_type,
        params.kb_id,
        params.document_id,
        vector_store,
        embed_model,
        store,
    )
    await pipeline.process(*positional_args)
|
||||
|
||||
|
||||
async def run_batch_index_task(
    params: BatchIndexTaskParams,
    store: "KBStore",
    vector_store: "PGVectorStore",
    embed_model: "BaseEmbedding",
) -> dict[str, list[str]]:
    """Index several documents one after another.

    A document that fails — including one whose per-document parameters do
    not validate — is recorded and the batch continues with the next one.

    Args:
        params: Batch indexing parameters.
        store: KBStore instance, passed on to ``run_vectorize_task``.
        vector_store: LlamaIndex PGVectorStore instance.
        embed_model: LlamaIndex embedding model.

    Returns:
        {"succeeded": [...], "failed": [...]} with the document ids in
        processing order.
    """
    succeeded: list[str] = []
    failed: list[str] = []

    for doc_id in params.document_ids:
        try:
            single_params = VectorizeTaskParams(
                document_id=doc_id,
                kb_id=params.kb_id,
                file_path=params.file_paths.get(doc_id, ""),
                file_type=params.file_types.get(doc_id, "txt"),
                chunk_size=params.chunk_size,
                chunk_overlap=params.chunk_overlap,
                user_id=params.user_id,
            )
            await run_vectorize_task(single_params, store, vector_store, embed_model)
        except Exception as e:
            # Pipeline errors can embed file paths or connection strings, so
            # apply the same sanitization used for stored error messages.
            logger.warning(
                "Batch item %s failed: %s",
                doc_id,
                sanitize_error_message(str(e)),
            )
            failed.append(doc_id)
        else:
            succeeded.append(doc_id)

    return {"succeeded": succeeded, "failed": failed}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TaskIQ broker 工厂(可选 — 仅 Redis 可用时使用)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_broker(redis_url: str) -> Any:
    """Create a TaskIQ Redis broker isolated on its own Redis database.

    Args:
        redis_url: Redis connection URL. Whatever database it names (db=0 is
            used by bus/cache), the path is replaced so that the broker and
            its result backend use ``TASKIQ_REDIS_DB``.

    Returns:
        A ``ListQueueBroker`` with a ``RedisAsyncResultBackend`` attached;
        results expire after ``DEFAULT_TASK_TTL_SECONDS``.

    Raises:
        ImportError: taskiq-redis is not installed.
    """
    from urllib.parse import urlparse, urlunparse

    # Imported lazily so the module loads without the optional dependency.
    # ListQueueBroker is itself the broker; it needs no wrapper class.
    from taskiq_redis import ListQueueBroker, RedisAsyncResultBackend

    # Force the dedicated database — isolated from the bus/cache db=0.
    parsed = urlparse(redis_url)
    isolated_url = urlunparse(parsed._replace(path=f"/{TASKIQ_REDIS_DB}"))

    broker = ListQueueBroker(url=isolated_url).with_result_backend(
        RedisAsyncResultBackend(
            redis_url=isolated_url,
            result_ex_time=DEFAULT_TASK_TTL_SECONDS,
        )
    )

    logger.info("TaskIQ broker created with Redis db=%d isolation", TASKIQ_REDIS_DB)
    return broker
|
||||
|
||||
|
||||
# Public API of this module.
__all__ = [
    # Parameter models
    "BatchIndexTaskParams",
    # Errors and defaults
    "ConcurrencyLimitExceeded",
    "DEFAULT_MAX_CONCURRENT_PER_USER",
    "DEFAULT_WORKER_TTL_SECONDS",
    # Task management
    "TaskManager",
    "TaskNotFoundError",
    "TaskStoreProtocol",
    "VectorizeTaskParams",
    "WorkerSweeper",
    # Broker factory and task entry points
    "create_broker",
    "run_batch_index_task",
    "run_vectorize_task",
    "sanitize_error_message",
]
|
||||
|
|
@ -0,0 +1,721 @@
|
|||
"""U8 测试 — TaskIQ 异步任务集成 — 文档向量化。
|
||||
|
||||
测试场景:
|
||||
1. 任务参数 schema 校验(Pydantic 模型)
|
||||
2. 错误消息净化(剥离内部路径/连接串/密钥)
|
||||
3. 任务提交创建任务记录(PENDING)
|
||||
4. 任务状态查询(pending → running → completed)
|
||||
5. 任务历史可查询(按 user_id 过滤)
|
||||
6. per-user 并发上限强制
|
||||
7. 任务取消
|
||||
8. Worker sweeper 检测超时任务
|
||||
9. 降级模式(无 broker)同步执行
|
||||
10. 任务失败后错误消息净化
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.protocol import TaskStatus
|
||||
from agentkit.rag_platform.tasks import (
|
||||
BatchIndexTaskParams,
|
||||
ConcurrencyLimitExceeded,
|
||||
TaskManager,
|
||||
TaskNotFoundError,
|
||||
VectorizeTaskParams,
|
||||
WorkerSweeper,
|
||||
run_batch_index_task,
|
||||
run_vectorize_task,
|
||||
sanitize_error_message,
|
||||
)
|
||||
from agentkit.server.task_store import InMemoryTaskStore
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 测试 fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_vectorize_params(**overrides) -> VectorizeTaskParams:
    """Build vectorization task parameters, with keyword overrides applied on top of the defaults."""
    base = {
        "document_id": "doc-001",
        "kb_id": "kb-001",
        "file_path": "/tmp/test.txt",
        "file_type": "txt",
        "user_id": "user1",
    }
    return VectorizeTaskParams(**{**base, **overrides})
|
||||
|
||||
|
||||
def _make_batch_params(**overrides) -> BatchIndexTaskParams:
    """Build batch indexing task parameters, with keyword overrides applied on top of the defaults."""
    base = {
        "kb_id": "kb-001",
        "document_ids": ["doc-1", "doc-2"],
        "user_id": "user1",
        "file_paths": {"doc-1": "/tmp/a.txt", "doc-2": "/tmp/b.txt"},
        "file_types": {"doc-1": "txt", "doc-2": "txt"},
    }
    return BatchIndexTaskParams(**{**base, **overrides})
|
||||
|
||||
|
||||
def _make_mock_dependencies():
    """Build mock runtime dependencies (store / vector_store / embed_model)."""
    deps = {name: MagicMock() for name in ("store", "vector_store", "embed_model")}
    deps["store"].update_document_status = AsyncMock()
    deps["vector_store"].async_add = AsyncMock()
    deps["embed_model"].aget_text_embedding = AsyncMock(return_value=[0.1] * 1536)
    return deps
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. 任务参数 schema 校验
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestVectorizeTaskParams:
    """Schema validation tests for VectorizeTaskParams."""

    def test_valid_params(self):
        """Well-formed parameters pass validation and pick up the default chunking values."""
        built = _make_vectorize_params()
        assert built.document_id == "doc-001"
        assert (built.chunk_size, built.chunk_overlap) == (512, 50)

    def test_empty_document_id_rejected(self):
        """An empty document_id is rejected."""
        fields = {
            "document_id": "",
            "kb_id": "kb-1",
            "file_path": "/tmp/x.txt",
            "file_type": "txt",
            "user_id": "u1",
        }
        with pytest.raises(ValueError, match="document_id"):
            VectorizeTaskParams(**fields)

    def test_chunk_overlap_must_be_less_than_size(self):
        """chunk_overlap equal to chunk_size is rejected."""
        fields = {
            "document_id": "doc-1",
            "kb_id": "kb-1",
            "file_path": "/tmp/x.txt",
            "file_type": "txt",
            "user_id": "u1",
            "chunk_size": 100,
            "chunk_overlap": 100,
        }
        with pytest.raises(ValueError, match="chunk_overlap"):
            VectorizeTaskParams(**fields)

    def test_chunk_size_bounds(self):
        """A chunk_size below the lower bound of 50 is rejected."""
        with pytest.raises(ValueError):
            _make_vectorize_params(chunk_size=10)

    def test_user_id_required(self):
        """Omitting user_id is rejected."""
        fields_without_user = {
            "document_id": "doc-1",
            "kb_id": "kb-1",
            "file_path": "/tmp/x.txt",
            "file_type": "txt",
        }
        with pytest.raises(ValueError):
            VectorizeTaskParams(**fields_without_user)
|
||||
|
||||
|
||||
class TestBatchIndexTaskParams:
    """Schema validation tests for BatchIndexTaskParams."""

    def test_valid_params(self):
        """Well-formed parameters pass validation."""
        built = _make_batch_params()
        assert len(built.document_ids) == 2

    def test_empty_document_ids_rejected(self):
        """An empty document_ids list is rejected."""
        fields = {"kb_id": "kb-1", "document_ids": [], "user_id": "u1"}
        with pytest.raises(ValueError, match="document_ids"):
            BatchIndexTaskParams(**fields)

    def test_too_many_document_ids_rejected(self):
        """More than 100 document_ids are rejected."""
        too_many = [f"doc-{i}" for i in range(101)]
        with pytest.raises(ValueError):
            BatchIndexTaskParams(kb_id="kb-1", document_ids=too_many, user_id="u1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. 错误消息净化
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSanitizeErrorMessage:
    """Tests for error message sanitization."""

    def test_strips_unix_paths(self):
        """Unix file paths are replaced by a placeholder."""
        result = sanitize_error_message("Failed to open /usr/local/lib/python/file.py")
        assert "/usr/local" not in result
        assert "<path>" in result

    def test_strips_windows_paths(self):
        """Windows file paths are replaced by a placeholder."""
        result = sanitize_error_message("Error in C:\\Users\\admin\\file.py at line 10")
        assert "C:\\Users" not in result
        assert "<path>" in result

    def test_strips_redis_connection_string(self):
        """Redis connection strings are replaced by a placeholder."""
        result = sanitize_error_message("Cannot connect to redis://localhost:6379/0")
        assert "redis://localhost" not in result
        assert "<connection-string>" in result

    def test_strips_postgresql_connection_string(self):
        """PostgreSQL connection strings are replaced by a placeholder."""
        result = sanitize_error_message("DB error: postgresql://user:pass@host:5432/db")
        assert "postgresql://" not in result
        assert "<connection-string>" in result

    def test_redacts_api_keys(self):
        """API keys are redacted."""
        result = sanitize_error_message("Auth failed with api_key=sk-abc123def456")
        assert "sk-abc123def456" not in result
        assert "<redacted>" in result

    def test_redacts_password(self):
        """Passwords are redacted."""
        result = sanitize_error_message("Login failed password=secret123")
        assert "secret123" not in result
        assert "<redacted>" in result

    def test_truncates_long_messages(self):
        """Long messages are cut to max_length and end with an ellipsis."""
        result = sanitize_error_message("x" * 1000, max_length=100)
        assert len(result) == 100
        assert result.endswith("...")

    def test_empty_message_returns_empty(self):
        """An empty message yields an empty string."""
        assert sanitize_error_message("") == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. TaskManager — 任务提交与状态查询
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTaskManagerSubmit:
    """Task submission tests."""

    async def test_submit_vectorize_creates_pending_task(self):
        """Submitting a vectorize task creates a record in PENDING state."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())

        task_id = await manager.submit_vectorize(_make_vectorize_params())
        record = await manager.get_task_status(task_id)

        assert task_id.startswith("vec-")
        assert record is not None
        assert record.status == TaskStatus.PENDING
        assert record.input_data["document_id"] == "doc-001"
        assert record.metadata["user_id"] == "user1"

    async def test_submit_batch_index_creates_pending_task(self):
        """Submitting a batch index task creates a record in PENDING state."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())

        task_id = await manager.submit_batch_index(_make_batch_params())
        record = await manager.get_task_status(task_id)

        assert task_id.startswith("batch-")
        assert record is not None
        assert record.status == TaskStatus.PENDING
        assert len(record.input_data["document_ids"]) == 2

    async def test_submit_with_broker_dispatches(self):
        """With a broker configured, submission dispatches through broker.kiq."""
        fake_broker = MagicMock()
        fake_broker.kiq = AsyncMock()
        manager = TaskManager(broker=fake_broker, task_store=InMemoryTaskStore())

        task_id = await manager.submit_vectorize(_make_vectorize_params())

        fake_broker.kiq.assert_awaited_once()
        # The dispatch must carry both the task type and the task id.
        dispatched = fake_broker.kiq.call_args.kwargs
        assert dispatched["task_type"] == "vectorize"
        assert dispatched["task_id"] == task_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. 任务状态查询
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTaskManagerQuery:
    """Task status lookup and listing tests."""

    async def test_get_task_status_returns_none_if_not_found(self):
        """Looking up an unknown task id yields ``None``."""
        mgr = TaskManager(broker=None, task_store=InMemoryTaskStore())

        assert await mgr.get_task_status("nonexistent") is None

    async def test_list_tasks_filters_by_user(self):
        """``list_tasks`` restricts results to the requested ``user_id``."""
        mgr = TaskManager(broker=None, task_store=InMemoryTaskStore())

        # Two tasks for user1 followed by one task for user2.
        submissions = (
            {"user_id": "user1"},
            {"document_id": "doc-2", "user_id": "user1"},
            {"document_id": "doc-3", "user_id": "user2"},
        )
        for overrides in submissions:
            await mgr.submit_vectorize(_make_vectorize_params(**overrides))

        for_user1 = await mgr.list_tasks(user_id="user1")
        for_user2 = await mgr.list_tasks(user_id="user2")
        everything = await mgr.list_tasks()

        assert len(for_user1) == 2
        assert len(for_user2) == 1
        assert len(everything) == 3

    async def test_list_tasks_filters_by_status(self):
        """``list_tasks`` restricts results to the requested status."""
        task_store = InMemoryTaskStore()
        mgr = TaskManager(broker=None, task_store=task_store)

        finished_id = await mgr.submit_vectorize(_make_vectorize_params())
        # Force the first task to COMPLETED directly on the store;
        # InMemoryTaskStore.update_status is called synchronously here.
        task_store.update_status(finished_id, TaskStatus.COMPLETED)
        # A second submission stays PENDING.
        await mgr.submit_vectorize(_make_vectorize_params(document_id="doc-2"))

        still_pending = await mgr.list_tasks(status=TaskStatus.PENDING)
        already_done = await mgr.list_tasks(status=TaskStatus.COMPLETED)

        assert len(still_pending) == 1
        assert len(already_done) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. per-user 并发上限
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConcurrencyLimit:
    """Per-user concurrency limit tests."""

    async def test_concurrency_limit_exceeded(self):
        """Submitting beyond the per-user limit raises ``ConcurrencyLimitExceeded``."""
        mgr = TaskManager(
            broker=None,
            task_store=InMemoryTaskStore(),
            max_concurrent_per_user=2,
        )

        # Fill the quota with two submissions.
        for doc_id in ("d1", "d2"):
            await mgr.submit_vectorize(_make_vectorize_params(document_id=doc_id))

        # The third submission must be rejected.
        with pytest.raises(ConcurrencyLimitExceeded):
            await mgr.submit_vectorize(_make_vectorize_params(document_id="d3"))

    async def test_concurrency_limit_per_user_isolated(self):
        """Each user's limit is counted independently of other users."""
        mgr = TaskManager(
            broker=None,
            task_store=InMemoryTaskStore(),
            max_concurrent_per_user=1,
        )

        # user1 uses up their single slot.
        await mgr.submit_vectorize(_make_vectorize_params(user_id="user1"))

        # user2 is unaffected and can still submit.
        other_params = _make_vectorize_params(document_id="d2", user_id="user2")
        other_id = await mgr.submit_vectorize(other_params)
        assert other_id is not None

    async def test_concurrency_released_after_cancel(self):
        """Cancelling a task frees its concurrency slot."""
        mgr = TaskManager(
            broker=None,
            task_store=InMemoryTaskStore(),
            max_concurrent_per_user=1,
        )

        first_id = await mgr.submit_vectorize(_make_vectorize_params())
        await mgr.cancel_task(first_id)

        # With the slot released, a new submission is accepted again.
        second_id = await mgr.submit_vectorize(
            _make_vectorize_params(document_id="d2")
        )
        assert second_id is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. 任务取消
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTaskCancel:
    """Task cancellation tests."""

    async def test_cancel_pending_task(self):
        """Cancelling a PENDING task succeeds and marks it CANCELLED."""
        mgr = TaskManager(broker=None, task_store=InMemoryTaskStore())
        pending_id = await mgr.submit_vectorize(_make_vectorize_params())

        cancelled = await mgr.cancel_task(pending_id)
        assert cancelled is True

        stored = await mgr.get_task_status(pending_id)
        assert stored is not None
        assert stored.status == TaskStatus.CANCELLED

    async def test_cancel_nonexistent_raises(self):
        """Cancelling an unknown task id raises ``TaskNotFoundError``."""
        mgr = TaskManager(broker=None, task_store=InMemoryTaskStore())

        with pytest.raises(TaskNotFoundError):
            await mgr.cancel_task("nonexistent")

    async def test_cancel_completed_returns_false(self):
        """Cancelling an already-completed task returns ``False``."""
        task_store = InMemoryTaskStore()
        mgr = TaskManager(broker=None, task_store=task_store)
        done_id = await mgr.submit_vectorize(_make_vectorize_params())
        # Mark the task finished directly on the store before cancelling.
        task_store.update_status(done_id, TaskStatus.COMPLETED)

        cancelled = await mgr.cancel_task(done_id)

        assert cancelled is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7. WorkerSweeper — 超时任务检测
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWorkerSweeper:
    """Tests for ``WorkerSweeper`` heartbeat-timeout detection."""

    async def test_sweep_detects_timed_out_task(self):
        """A RUNNING task with an expired heartbeat is marked FAILED."""
        task_store = InMemoryTaskStore()
        mgr = TaskManager(broker=None, task_store=task_store)
        stuck_id = await mgr.submit_vectorize(_make_vectorize_params())

        # Simulate a RUNNING task whose heartbeat is 600s old (TTL is 300s).
        # The store is driven directly through its synchronous methods.
        stale_stamp = (datetime.now(timezone.utc) - timedelta(seconds=600)).isoformat()
        task_store.update_status(stuck_id, TaskStatus.RUNNING)
        entry = task_store.get(stuck_id)
        assert entry is not None
        entry.metadata["worker_heartbeat_at"] = stale_stamp

        swept = await WorkerSweeper(task_store=task_store, ttl_seconds=300).sweep()

        assert swept == 1
        entry = task_store.get(stuck_id)
        assert entry is not None
        assert entry.status == TaskStatus.FAILED
        assert entry.error_message == "worker_timeout"

    async def test_sweep_skips_fresh_heartbeat(self):
        """A RUNNING task with a recent heartbeat is left untouched."""
        task_store = InMemoryTaskStore()
        mgr = TaskManager(broker=None, task_store=task_store)
        live_id = await mgr.submit_vectorize(_make_vectorize_params())

        # Heartbeat stamped "now", well within the TTL.
        current_stamp = datetime.now(timezone.utc).isoformat()
        task_store.update_status(live_id, TaskStatus.RUNNING)
        entry = task_store.get(live_id)
        assert entry is not None
        entry.metadata["worker_heartbeat_at"] = current_stamp

        swept = await WorkerSweeper(task_store=task_store, ttl_seconds=300).sweep()

        assert swept == 0
        entry = task_store.get(live_id)
        assert entry is not None
        assert entry.status == TaskStatus.RUNNING

    async def test_sweep_skips_tasks_without_heartbeat(self):
        """A RUNNING task that never recorded a heartbeat is not swept."""
        task_store = InMemoryTaskStore()
        mgr = TaskManager(broker=None, task_store=task_store)
        silent_id = await mgr.submit_vectorize(_make_vectorize_params())
        # Deliberately leave worker_heartbeat_at unset.
        task_store.update_status(silent_id, TaskStatus.RUNNING)

        swept = await WorkerSweeper(task_store=task_store, ttl_seconds=300).sweep()

        assert swept == 0

    async def test_sweep_skips_non_running_tasks(self):
        """Only RUNNING tasks are considered; a stale PENDING task is ignored."""
        task_store = InMemoryTaskStore()
        mgr = TaskManager(broker=None, task_store=task_store)
        queued_id = await mgr.submit_vectorize(_make_vectorize_params())

        # The task stays PENDING even though its heartbeat looks expired.
        stale_stamp = (datetime.now(timezone.utc) - timedelta(seconds=600)).isoformat()
        entry = task_store.get(queued_id)
        assert entry is not None
        entry.metadata["worker_heartbeat_at"] = stale_stamp

        swept = await WorkerSweeper(task_store=task_store, ttl_seconds=300).sweep()

        assert swept == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 8. 降级模式 — 无 broker 同步执行
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDegradedMode:
    """Degraded-mode tests: tasks are executed locally when no broker is set."""

    async def test_vectorize_success_transitions_to_completed(self):
        """A successful vectorize run ends in COMPLETED with full progress."""
        mgr = TaskManager(broker=None, task_store=InMemoryTaskStore())
        vec_params = _make_vectorize_params()
        mock_deps = _make_mock_dependencies()

        # Replace run_vectorize_task so only TaskManager's state transitions
        # are exercised, without depending on llama_index.
        # NOTE(review): the wait is kept inside the patch context so the task
        # runs against the mock — confirm against the original indentation.
        with patch(
            "agentkit.rag_platform.tasks.run_vectorize_task",
            new_callable=AsyncMock,
        ):
            new_id = await mgr.submit_vectorize(vec_params, dependencies=mock_deps)

            # Give the degraded-mode task time to finish.
            await asyncio.sleep(0.1)

        stored = await mgr.get_task_status(new_id)
        assert stored is not None
        assert stored.status == TaskStatus.COMPLETED
        assert stored.progress == 1.0

    async def test_vectorize_failure_sanitizes_error(self):
        """A failed vectorize run ends in FAILED with a sanitized error message."""
        mgr = TaskManager(broker=None, task_store=InMemoryTaskStore())
        vec_params = _make_vectorize_params()
        mock_deps = _make_mock_dependencies()

        # The mocked task raises an exception whose text contains a filesystem path.
        failure = FileNotFoundError(
            "[Errno 2] No such file or directory: /usr/local/data/missing.txt"
        )
        with patch(
            "agentkit.rag_platform.tasks.run_vectorize_task",
            new_callable=AsyncMock,
            side_effect=failure,
        ):
            new_id = await mgr.submit_vectorize(vec_params, dependencies=mock_deps)

            # Give the degraded-mode task time to finish.
            await asyncio.sleep(0.1)

        stored = await mgr.get_task_status(new_id)
        assert stored is not None
        assert stored.status == TaskStatus.FAILED
        # The path must have been replaced by the <path> placeholder.
        recorded_error = stored.error_message or ""
        assert "/usr/local" not in recorded_error
        assert "<path>" in recorded_error

    async def test_degraded_mode_does_not_block_caller(self):
        """Submission returns before the background task has finished."""
        mgr = TaskManager(broker=None, task_store=InMemoryTaskStore())
        vec_params = _make_vectorize_params()
        mock_deps = _make_mock_dependencies()

        async def _slow_task(*args, **kwargs):
            # Stand-in for a long-running vectorize job.
            await asyncio.sleep(0.5)

        with patch(
            "agentkit.rag_platform.tasks.run_vectorize_task",
            new_callable=AsyncMock,
            side_effect=_slow_task,
        ):
            # submit_vectorize should return immediately, not wait for the job.
            new_id = await mgr.submit_vectorize(vec_params, dependencies=mock_deps)

            # Query right away: the task is either not started yet (PENDING)
            # or started but unfinished (RUNNING).
            stored = await mgr.get_task_status(new_id)
            assert stored is not None
            assert stored.status in (TaskStatus.PENDING, TaskStatus.RUNNING)

            # Let the background task run to completion before the test ends.
            await asyncio.sleep(0.6)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 9. 任务执行函数测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunVectorizeTask:
    """Tests for the ``run_vectorize_task`` function."""

    async def test_run_vectorize_calls_document_processor(self):
        """``run_vectorize_task`` awaits ``DocumentProcessor.process`` exactly once."""
        vec_params = _make_vectorize_params()
        mock_deps = _make_mock_dependencies()

        # Patch DocumentProcessor.process so llama_index is not required.
        process_patch = patch(
            "agentkit.rag_platform.tasks.DocumentProcessor.process",
            new_callable=AsyncMock,
        )
        with process_patch as fake_process:
            await run_vectorize_task(vec_params, **mock_deps)

        fake_process.assert_awaited_once()

    async def test_run_vectorize_failure_propagates(self):
        """An exception raised by the processor propagates to the caller."""
        vec_params = _make_vectorize_params()
        mock_deps = _make_mock_dependencies()

        process_patch = patch(
            "agentkit.rag_platform.tasks.DocumentProcessor.process",
            new_callable=AsyncMock,
            side_effect=RuntimeError("processing failed"),
        )
        with process_patch, pytest.raises(RuntimeError, match="processing failed"):
            await run_vectorize_task(vec_params, **mock_deps)
|
||||
|
||||
|
||||
class TestRunBatchIndexTask:
    """Tests for the ``run_batch_index_task`` function."""

    async def test_batch_index_returns_results(self):
        """The result reports ``succeeded`` and ``failed`` document lists."""
        batch_params = BatchIndexTaskParams(
            kb_id="kb-1",
            document_ids=["doc-1", "doc-2"],
            user_id="user1",
            file_paths={"doc-1": "/tmp/a.txt", "doc-2": "/tmp/b.txt"},
            file_types={"doc-1": "txt", "doc-2": "txt"},
        )
        mock_deps = _make_mock_dependencies()

        # Every DocumentProcessor.process call succeeds.
        with patch(
            "agentkit.rag_platform.tasks.DocumentProcessor.process",
            new_callable=AsyncMock,
        ):
            outcome = await run_batch_index_task(batch_params, **mock_deps)

        for key in ("succeeded", "failed"):
            assert key in outcome
        assert len(outcome["succeeded"]) == 2
        assert len(outcome["failed"]) == 0

    async def test_batch_index_partial_failure(self):
        """A document whose processing fails is reported under ``failed``."""
        batch_params = BatchIndexTaskParams(
            kb_id="kb-1",
            document_ids=["doc-1", "doc-missing"],
            user_id="user1",
            file_paths={"doc-1": "/tmp/a.txt", "doc-missing": "/nonexistent/file.txt"},
            file_types={"doc-1": "txt", "doc-missing": "txt"},
        )
        mock_deps = _make_mock_dependencies()

        # The first process() call succeeds; the second one raises.
        invocations = 0

        async def _fail_on_second_call(*args, **kwargs):
            nonlocal invocations
            invocations += 1
            if invocations == 2:
                raise FileNotFoundError("missing file")

        with patch(
            "agentkit.rag_platform.tasks.DocumentProcessor.process",
            new_callable=AsyncMock,
            side_effect=_fail_on_second_call,
        ):
            outcome = await run_batch_index_task(batch_params, **mock_deps)

        assert "doc-1" in outcome["succeeded"]
        assert "doc-missing" in outcome["failed"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 10. 心跳更新测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHeartbeat:
    """Worker heartbeat tests."""

    async def test_update_heartbeat_sets_timestamp(self):
        """``update_heartbeat`` records ``worker_heartbeat_at`` in the metadata."""
        mgr = TaskManager(broker=None, task_store=InMemoryTaskStore())
        new_id = await mgr.submit_vectorize(_make_vectorize_params())

        await mgr.update_heartbeat(new_id)

        stored = await mgr.get_task_status(new_id)
        assert stored is not None
        assert "worker_heartbeat_at" in stored.metadata
        # Parsing raises ValueError if the stored value is not ISO-8601.
        stamp = stored.metadata["worker_heartbeat_at"]
        datetime.fromisoformat(stamp)

    async def test_update_heartbeat_nonexistent_task_noop(self):
        """``update_heartbeat`` on an unknown task id does not raise."""
        mgr = TaskManager(broker=None, task_store=InMemoryTaskStore())

        # Completing without an exception is the whole assertion.
        await mgr.update_heartbeat("nonexistent")
|
||||
Loading…
Reference in New Issue