# Source listing: 722 lines, 26 KiB, Python.
"""U8 测试 — TaskIQ 异步任务集成 — 文档向量化。
|
||
|
||
测试场景:
|
||
1. 任务参数 schema 校验(Pydantic 模型)
|
||
2. 错误消息净化(剥离内部路径/连接串/密钥)
|
||
3. 任务提交创建任务记录(PENDING)
|
||
4. 任务状态查询(pending → running → completed)
|
||
5. 任务历史可查询(按 user_id 过滤)
|
||
6. per-user 并发上限强制
|
||
7. 任务取消
|
||
8. Worker sweeper 检测超时任务
|
||
9. 降级模式(无 broker)同步执行
|
||
10. 任务失败后错误消息净化
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
from datetime import datetime, timedelta, timezone
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import pytest
|
||
|
||
from agentkit.core.protocol import TaskStatus
|
||
from agentkit.rag_platform.tasks import (
|
||
BatchIndexTaskParams,
|
||
ConcurrencyLimitExceeded,
|
||
TaskManager,
|
||
TaskNotFoundError,
|
||
VectorizeTaskParams,
|
||
WorkerSweeper,
|
||
run_batch_index_task,
|
||
run_vectorize_task,
|
||
sanitize_error_message,
|
||
)
|
||
from agentkit.server.task_store import InMemoryTaskStore
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 测试 fixtures
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _make_vectorize_params(**overrides) -> VectorizeTaskParams:
    """Build a ``VectorizeTaskParams`` populated with test defaults.

    Every keyword passed in ``overrides`` replaces the default of the same name.
    """
    base_fields = dict(
        document_id="doc-001",
        kb_id="kb-001",
        file_path="/tmp/test.txt",
        file_type="txt",
        user_id="user1",
    )
    # Later mapping wins, so caller overrides take precedence over defaults.
    return VectorizeTaskParams(**{**base_fields, **overrides})
|
||
|
||
|
||
def _make_batch_params(**overrides) -> BatchIndexTaskParams:
    """Build a two-document ``BatchIndexTaskParams`` populated with test defaults.

    Every keyword passed in ``overrides`` replaces the default of the same name.
    """
    doc_ids = ["doc-1", "doc-2"]
    base_fields = {
        "kb_id": "kb-001",
        "document_ids": doc_ids,
        "user_id": "user1",
        "file_paths": {"doc-1": "/tmp/a.txt", "doc-2": "/tmp/b.txt"},
        "file_types": {doc_id: "txt" for doc_id in doc_ids},
    }
    # Later mapping wins, so caller overrides take precedence over defaults.
    return BatchIndexTaskParams(**{**base_fields, **overrides})
|
||
|
||
|
||
def _make_mock_dependencies():
|
||
"""创建 mock 依赖(store/vector_store/embed_model)。"""
|
||
store = MagicMock()
|
||
store.update_document_status = AsyncMock()
|
||
vector_store = MagicMock()
|
||
vector_store.async_add = AsyncMock()
|
||
embed_model = MagicMock()
|
||
embed_model.aget_text_embedding = AsyncMock(return_value=[0.1] * 1536)
|
||
return {"store": store, "vector_store": vector_store, "embed_model": embed_model}
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 1. 任务参数 schema 校验
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestVectorizeTaskParams:
    """Schema validation tests for ``VectorizeTaskParams``."""

    def test_valid_params(self):
        """Valid parameters pass validation and pick up the default chunking values."""
        params = _make_vectorize_params()
        assert params.document_id == "doc-001"
        assert params.chunk_size == 512
        assert params.chunk_overlap == 50

    def test_empty_document_id_rejected(self):
        """An empty ``document_id`` is rejected."""
        with pytest.raises(ValueError, match="document_id"):
            VectorizeTaskParams(
                document_id="",
                kb_id="kb-1",
                file_path="/tmp/x.txt",
                file_type="txt",
                user_id="u1",
            )

    def test_chunk_overlap_must_be_less_than_size(self):
        """``chunk_overlap`` must be strictly smaller than ``chunk_size``."""
        with pytest.raises(ValueError, match="chunk_overlap"):
            VectorizeTaskParams(
                document_id="doc-1",
                kb_id="kb-1",
                file_path="/tmp/x.txt",
                file_type="txt",
                user_id="u1",
                chunk_size=100,
                chunk_overlap=100,
            )

    @pytest.mark.parametrize(
        ("chunk_size", "chunk_overlap"),
        [
            # Below the lower bound. The overlap is kept smaller than the size so
            # that only the bound check (not the overlap < size rule) can reject it.
            (10, 5),
            # Above the upper bound; the default-sized overlap is otherwise valid.
            (8193, 50),
        ],
        ids=["below_min", "above_max"],
    )
    def test_chunk_size_bounds(self, chunk_size, chunk_overlap):
        """``chunk_size`` must lie within 50-8192; values on either side are rejected."""
        with pytest.raises(ValueError):
            _make_vectorize_params(chunk_size=chunk_size, chunk_overlap=chunk_overlap)

    def test_user_id_required(self):
        """``user_id`` is a required field."""
        with pytest.raises(ValueError):
            VectorizeTaskParams(
                document_id="doc-1",
                kb_id="kb-1",
                file_path="/tmp/x.txt",
                file_type="txt",
            )
|
||
|
||
|
||
class TestBatchIndexTaskParams:
    """Schema validation tests for ``BatchIndexTaskParams``."""

    def test_valid_params(self):
        """A well-formed two-document batch passes validation."""
        assert len(_make_batch_params().document_ids) == 2

    def test_empty_document_ids_rejected(self):
        """An empty ``document_ids`` list is rejected."""
        kwargs = {"kb_id": "kb-1", "document_ids": [], "user_id": "u1"}
        with pytest.raises(ValueError, match="document_ids"):
            BatchIndexTaskParams(**kwargs)

    def test_too_many_document_ids_rejected(self):
        """More than 100 ``document_ids`` are rejected."""
        oversized = [f"doc-{i}" for i in range(101)]
        with pytest.raises(ValueError):
            BatchIndexTaskParams(kb_id="kb-1", document_ids=oversized, user_id="u1")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 2. 错误消息净化
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestSanitizeErrorMessage:
    """Tests for ``sanitize_error_message``."""

    @staticmethod
    def _check(raw, leaked, placeholder):
        """Sanitize ``raw``; ``leaked`` must be gone and ``placeholder`` present."""
        result = sanitize_error_message(raw)
        assert leaked not in result
        assert placeholder in result

    def test_strips_unix_paths(self):
        """Unix filesystem paths are replaced with ``<path>``."""
        self._check("Failed to open /usr/local/lib/python/file.py", "/usr/local", "<path>")

    def test_strips_windows_paths(self):
        """Windows filesystem paths are replaced with ``<path>``."""
        self._check("Error in C:\\Users\\admin\\file.py at line 10", "C:\\Users", "<path>")

    def test_strips_redis_connection_string(self):
        """Redis URLs are replaced with ``<connection-string>``."""
        self._check(
            "Cannot connect to redis://localhost:6379/0",
            "redis://localhost",
            "<connection-string>",
        )

    def test_strips_postgresql_connection_string(self):
        """PostgreSQL URLs are replaced with ``<connection-string>``."""
        self._check(
            "DB error: postgresql://user:pass@host:5432/db",
            "postgresql://",
            "<connection-string>",
        )

    def test_redacts_api_keys(self):
        """API key values are replaced with ``<redacted>``."""
        self._check("Auth failed with api_key=sk-abc123def456", "sk-abc123def456", "<redacted>")

    def test_redacts_password(self):
        """Password values are replaced with ``<redacted>``."""
        self._check("Login failed password=secret123", "secret123", "<redacted>")

    def test_truncates_long_messages(self):
        """Messages longer than ``max_length`` are cut to exactly that length with '...'."""
        truncated = sanitize_error_message("x" * 1000, max_length=100)
        assert len(truncated) == 100
        assert truncated.endswith("...")

    def test_empty_message_returns_empty(self):
        """An empty input yields an empty string."""
        assert sanitize_error_message("") == ""
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 3. TaskManager — 任务提交与状态查询
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestTaskManagerSubmit:
    """Submission tests: each submit creates a PENDING task record."""

    @staticmethod
    def _manager(broker=None):
        """Return a ``TaskManager`` backed by a fresh ``InMemoryTaskStore``."""
        return TaskManager(broker=broker, task_store=InMemoryTaskStore())

    async def test_submit_vectorize_creates_pending_task(self):
        """A vectorize submission yields a ``vec-`` id and a PENDING record."""
        manager = self._manager()
        task_id = await manager.submit_vectorize(_make_vectorize_params())

        assert task_id.startswith("vec-")
        record = await manager.get_task_status(task_id)
        assert record is not None
        assert record.status == TaskStatus.PENDING
        assert record.input_data["document_id"] == "doc-001"
        assert record.metadata["user_id"] == "user1"

    async def test_submit_batch_index_creates_pending_task(self):
        """A batch-index submission yields a ``batch-`` id and a PENDING record."""
        manager = self._manager()
        task_id = await manager.submit_batch_index(_make_batch_params())

        assert task_id.startswith("batch-")
        record = await manager.get_task_status(task_id)
        assert record is not None
        assert record.status == TaskStatus.PENDING
        assert len(record.input_data["document_ids"]) == 2

    async def test_submit_with_broker_dispatches(self):
        """With a broker configured, submission awaits ``broker.kiq`` once."""
        broker = MagicMock()
        broker.kiq = AsyncMock()
        manager = self._manager(broker)

        task_id = await manager.submit_vectorize(_make_vectorize_params())

        broker.kiq.assert_awaited_once()
        # The dispatched call must identify both the task type and the task id.
        dispatched = broker.kiq.call_args.kwargs
        assert dispatched["task_type"] == "vectorize"
        assert dispatched["task_id"] == task_id
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 4. 任务状态查询
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestTaskManagerQuery:
    """Task status lookup and listing tests."""

    @staticmethod
    def _setup():
        """Return ``(store, manager)`` sharing a fresh in-memory task store."""
        store = InMemoryTaskStore()
        return store, TaskManager(broker=None, task_store=store)

    async def test_get_task_status_returns_none_if_not_found(self):
        """Looking up an unknown task id yields ``None``."""
        _, manager = self._setup()
        assert await manager.get_task_status("nonexistent") is None

    async def test_list_tasks_filters_by_user(self):
        """``list_tasks(user_id=...)`` returns only that user's tasks."""
        _, manager = self._setup()
        # Two tasks for user1, one for user2.
        submissions = (("doc-001", "user1"), ("doc-2", "user1"), ("doc-3", "user2"))
        for doc_id, owner in submissions:
            await manager.submit_vectorize(
                _make_vectorize_params(document_id=doc_id, user_id=owner)
            )

        for_user1 = await manager.list_tasks(user_id="user1")
        for_user2 = await manager.list_tasks(user_id="user2")
        everything = await manager.list_tasks()

        assert (len(for_user1), len(for_user2), len(everything)) == (2, 1, 3)

    async def test_list_tasks_filters_by_status(self):
        """``list_tasks(status=...)`` returns only tasks in that status."""
        store, manager = self._setup()
        done_id = await manager.submit_vectorize(_make_vectorize_params())
        # InMemoryTaskStore.update_status is synchronous, hence no await.
        store.update_status(done_id, TaskStatus.COMPLETED)
        # A second submission stays PENDING.
        await manager.submit_vectorize(_make_vectorize_params(document_id="doc-2"))

        pending = await manager.list_tasks(status=TaskStatus.PENDING)
        completed = await manager.list_tasks(status=TaskStatus.COMPLETED)

        assert len(pending) == 1
        assert len(completed) == 1
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 5. per-user 并发上限
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestConcurrencyLimit:
    """Per-user concurrency limit tests."""

    @staticmethod
    def _manager(limit):
        """Return a ``TaskManager`` on a fresh in-memory store with the given per-user limit."""
        return TaskManager(
            broker=None,
            task_store=InMemoryTaskStore(),
            max_concurrent_per_user=limit,
        )

    async def test_concurrency_limit_exceeded(self):
        """Submitting beyond the per-user limit raises ``ConcurrencyLimitExceeded``."""
        manager = self._manager(2)

        # Two submissions reach the limit.
        await manager.submit_vectorize(_make_vectorize_params(document_id="d1"))
        await manager.submit_vectorize(_make_vectorize_params(document_id="d2"))

        # The third one must be refused.
        with pytest.raises(ConcurrencyLimitExceeded):
            await manager.submit_vectorize(_make_vectorize_params(document_id="d3"))

    async def test_concurrency_limit_per_user_isolated(self):
        """Each user's limit is counted independently of other users."""
        manager = self._manager(1)

        await manager.submit_vectorize(_make_vectorize_params(user_id="user1"))
        # Precondition: user1 really is at its limit. Without this check the
        # test would also pass if no limit were enforced at all.
        with pytest.raises(ConcurrencyLimitExceeded):
            await manager.submit_vectorize(
                _make_vectorize_params(document_id="d1b", user_id="user1")
            )

        # ...yet user2 can still submit, and the new record belongs to user2.
        task_id = await manager.submit_vectorize(
            _make_vectorize_params(document_id="d2", user_id="user2")
        )
        record = await manager.get_task_status(task_id)
        assert record is not None
        assert record.status == TaskStatus.PENDING
        assert record.metadata["user_id"] == "user2"

    async def test_concurrency_released_after_cancel(self):
        """Cancelling a task frees its slot in the per-user count."""
        manager = self._manager(1)

        task_id = await manager.submit_vectorize(_make_vectorize_params())
        # Precondition: the user is at the limit, so the later success is
        # attributable to the cancellation rather than a missing limit.
        with pytest.raises(ConcurrencyLimitExceeded):
            await manager.submit_vectorize(_make_vectorize_params(document_id="d2"))

        await manager.cancel_task(task_id)

        # The slot is free again.
        task_id2 = await manager.submit_vectorize(_make_vectorize_params(document_id="d2"))
        record = await manager.get_task_status(task_id2)
        assert record is not None
        assert record.status == TaskStatus.PENDING
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 6. 任务取消
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestTaskCancel:
    """Task cancellation tests."""

    @staticmethod
    async def _store_manager_and_task():
        """Create an in-memory store and manager, then submit one vectorize task."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        task_id = await manager.submit_vectorize(_make_vectorize_params())
        return store, manager, task_id

    async def test_cancel_pending_task(self):
        """Cancelling a PENDING task returns True and marks it CANCELLED."""
        _, manager, task_id = await self._store_manager_and_task()

        assert await manager.cancel_task(task_id) is True

        record = await manager.get_task_status(task_id)
        assert record is not None
        assert record.status == TaskStatus.CANCELLED

    async def test_cancel_nonexistent_raises(self):
        """Cancelling an unknown task id raises ``TaskNotFoundError``."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())
        with pytest.raises(TaskNotFoundError):
            await manager.cancel_task("nonexistent")

    async def test_cancel_completed_returns_false(self):
        """Cancelling an already COMPLETED task returns False."""
        store, manager, task_id = await self._store_manager_and_task()
        store.update_status(task_id, TaskStatus.COMPLETED)

        assert await manager.cancel_task(task_id) is False
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 7. WorkerSweeper — 超时任务检测
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestWorkerSweeper:
    """Tests for ``WorkerSweeper`` detection of timed-out tasks."""

    @staticmethod
    async def _submit(store):
        """Submit one vectorize task through a manager bound to ``store``."""
        manager = TaskManager(broker=None, task_store=store)
        return await manager.submit_vectorize(_make_vectorize_params())

    @staticmethod
    def _set_heartbeat(store, task_id, age_seconds):
        """Write a ``worker_heartbeat_at`` that is ``age_seconds`` old (UTC, ISO format).

        Mutates the record returned by ``store.get``; InMemoryTaskStore methods are sync.
        """
        record = store.get(task_id)
        assert record is not None
        stamp = datetime.now(timezone.utc) - timedelta(seconds=age_seconds)
        record.metadata["worker_heartbeat_at"] = stamp.isoformat()

    async def test_sweep_detects_timed_out_task(self):
        """A RUNNING task with a heartbeat older than the TTL is marked FAILED."""
        store = InMemoryTaskStore()
        task_id = await self._submit(store)
        store.update_status(task_id, TaskStatus.RUNNING)
        self._set_heartbeat(store, task_id, age_seconds=600)

        swept = await WorkerSweeper(task_store=store, ttl_seconds=300).sweep()

        assert swept == 1
        record = store.get(task_id)
        assert record is not None
        assert record.status == TaskStatus.FAILED
        assert record.error_message == "worker_timeout"

    async def test_sweep_skips_fresh_heartbeat(self):
        """A RUNNING task with a recent heartbeat is left alone."""
        store = InMemoryTaskStore()
        task_id = await self._submit(store)
        store.update_status(task_id, TaskStatus.RUNNING)
        self._set_heartbeat(store, task_id, age_seconds=0)

        swept = await WorkerSweeper(task_store=store, ttl_seconds=300).sweep()

        assert swept == 0
        record = store.get(task_id)
        assert record is not None
        assert record.status == TaskStatus.RUNNING

    async def test_sweep_skips_tasks_without_heartbeat(self):
        """A RUNNING task that never reported a heartbeat is not swept."""
        store = InMemoryTaskStore()
        task_id = await self._submit(store)
        store.update_status(task_id, TaskStatus.RUNNING)
        # Deliberately no worker_heartbeat_at.

        swept = await WorkerSweeper(task_store=store, ttl_seconds=300).sweep()

        assert swept == 0

    async def test_sweep_skips_non_running_tasks(self):
        """Only RUNNING tasks are considered, even when the heartbeat is stale."""
        store = InMemoryTaskStore()
        task_id = await self._submit(store)
        # The task stays PENDING, so its stale heartbeat must not get it swept.
        self._set_heartbeat(store, task_id, age_seconds=600)

        swept = await WorkerSweeper(task_store=store, ttl_seconds=300).sweep()

        assert swept == 0
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 8. 降级模式 — 无 broker 进程内后台执行(不阻塞调用方)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestDegradedMode:
    """Degraded-mode tests: with no broker, tasks run in-process in the background.

    ``run_vectorize_task`` is patched, so these tests exercise only the
    ``TaskManager`` state transitions and do not depend on llama_index. Each
    test keeps the patch active until the background task has finished. If the
    ``with patch(...)`` block were left first, a task that starts late could
    resolve the real ``run_vectorize_task``.
    """

    @staticmethod
    async def _wait_for(manager, task_id, predicate, timeout=2.0):
        """Poll the task record until ``predicate(record)`` holds and return it.

        Replaces fixed ``asyncio.sleep`` waits: it returns as soon as the state
        is reached and tolerates slow schedulers. Raises ``AssertionError``
        (with the last seen status) if ``timeout`` seconds pass first.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while True:
            record = await manager.get_task_status(task_id)
            if record is not None and predicate(record):
                return record
            if loop.time() >= deadline:
                last_status = None if record is None else record.status
                raise AssertionError(
                    f"task {task_id} did not reach the expected state within "
                    f"{timeout}s (last status: {last_status})"
                )
            await asyncio.sleep(0.01)

    async def test_vectorize_success_transitions_to_completed(self):
        """A successful vectorize run ends COMPLETED with progress 1.0."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        params = _make_vectorize_params()
        deps = _make_mock_dependencies()

        with patch(
            "agentkit.rag_platform.tasks.run_vectorize_task",
            new_callable=AsyncMock,
        ):
            task_id = await manager.submit_vectorize(params, dependencies=deps)
            # Wait while the patch is still active.
            record = await self._wait_for(
                manager,
                task_id,
                lambda r: r.status == TaskStatus.COMPLETED and r.progress == 1.0,
            )

        assert record.status == TaskStatus.COMPLETED
        assert record.progress == 1.0

    async def test_vectorize_failure_sanitizes_error(self):
        """A failing vectorize run ends FAILED with a sanitized error message."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        params = _make_vectorize_params()
        deps = _make_mock_dependencies()

        # The mocked task raises an error that embeds an internal file path.
        with patch(
            "agentkit.rag_platform.tasks.run_vectorize_task",
            new_callable=AsyncMock,
            side_effect=FileNotFoundError(
                "[Errno 2] No such file or directory: /usr/local/data/missing.txt"
            ),
        ):
            task_id = await manager.submit_vectorize(params, dependencies=deps)
            record = await self._wait_for(
                manager,
                task_id,
                lambda r: r.status == TaskStatus.FAILED and bool(r.error_message),
            )

        assert record.status == TaskStatus.FAILED
        # The internal path must have been replaced with the <path> placeholder.
        assert "/usr/local" not in (record.error_message or "")
        assert "<path>" in (record.error_message or "")

    async def test_degraded_mode_does_not_block_caller(self):
        """Submitting returns while the task is still executing in the background."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        params = _make_vectorize_params()
        deps = _make_mock_dependencies()
        release = asyncio.Event()

        async def _blocked_task(*args, **kwargs):
            # Stays in progress until the test explicitly releases it, which is
            # deterministic, unlike a fixed-length sleep.
            await release.wait()

        with patch(
            "agentkit.rag_platform.tasks.run_vectorize_task",
            new_callable=AsyncMock,
            side_effect=_blocked_task,
        ):
            # If submit awaited the task it would never return. The timeout
            # turns that hang into a test failure.
            task_id = await asyncio.wait_for(
                manager.submit_vectorize(params, dependencies=deps), timeout=1.0
            )

            # The task is still unfinished: either not started yet or running.
            record = await manager.get_task_status(task_id)
            assert record is not None
            assert record.status in (TaskStatus.PENDING, TaskStatus.RUNNING)

            # Let the background task finish before the patch is removed.
            release.set()
            await self._wait_for(
                manager,
                task_id,
                lambda r: r.status not in (TaskStatus.PENDING, TaskStatus.RUNNING),
            )
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 9. 任务执行函数测试
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestRunVectorizeTask:
    """Tests for ``run_vectorize_task`` with ``DocumentProcessor.process`` mocked out."""

    async def test_run_vectorize_calls_document_processor(self):
        """The task delegates to ``DocumentProcessor.process`` exactly once."""
        deps = _make_mock_dependencies()
        # Mocking process() keeps the test independent of llama_index.
        with patch(
            "agentkit.rag_platform.tasks.DocumentProcessor.process",
            new_callable=AsyncMock,
        ) as process_mock:
            await run_vectorize_task(_make_vectorize_params(), **deps)

        process_mock.assert_awaited_once()

    async def test_run_vectorize_failure_propagates(self):
        """An exception from the processor propagates out of the task."""
        params, deps = _make_vectorize_params(), _make_mock_dependencies()
        failing_process = patch(
            "agentkit.rag_platform.tasks.DocumentProcessor.process",
            new_callable=AsyncMock,
            side_effect=RuntimeError("processing failed"),
        )
        with failing_process, pytest.raises(RuntimeError, match="processing failed"):
            await run_vectorize_task(params, **deps)
|
||
|
||
|
||
class TestRunBatchIndexTask:
    """Tests for ``run_batch_index_task`` with ``DocumentProcessor.process`` mocked out."""

    async def test_batch_index_returns_results(self):
        """When every document succeeds, all ids are in ``succeeded`` and none in ``failed``."""
        # Same fields the test previously spelled out inline; reuse the shared factory.
        params = _make_batch_params(kb_id="kb-1")
        deps = _make_mock_dependencies()

        with patch(
            "agentkit.rag_platform.tasks.DocumentProcessor.process",
            new_callable=AsyncMock,
        ):
            result = await run_batch_index_task(params, **deps)

        assert "succeeded" in result
        assert "failed" in result
        assert len(result["succeeded"]) == 2
        assert len(result["failed"]) == 0
        assert set(result["succeeded"]) == {"doc-1", "doc-2"}

    async def test_batch_index_partial_failure(self):
        """A failing document lands in ``failed`` without affecting the others."""
        params = _make_batch_params(
            kb_id="kb-1",
            document_ids=["doc-1", "doc-missing"],
            file_paths={"doc-1": "/tmp/a.txt", "doc-missing": "/nonexistent/file.txt"},
            file_types={"doc-1": "txt", "doc-missing": "txt"},
        )
        deps = _make_mock_dependencies()
        calls = 0

        async def _fail_second_call(*args, **kwargs):
            # The first processed document succeeds; the second raises. This
            # assumes documents are processed in ``document_ids`` order.
            nonlocal calls
            calls += 1
            if calls == 2:
                raise FileNotFoundError("missing file")

        with patch(
            "agentkit.rag_platform.tasks.DocumentProcessor.process",
            new_callable=AsyncMock,
            side_effect=_fail_second_call,
        ):
            result = await run_batch_index_task(params, **deps)

        assert "doc-1" in result["succeeded"]
        assert "doc-missing" in result["failed"]
        # Each document is reported exactly once, in the correct bucket.
        assert "doc-missing" not in result["succeeded"]
        assert "doc-1" not in result["failed"]
        assert len(result["succeeded"]) == 1
        assert len(result["failed"]) == 1
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 10. 心跳更新测试
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestHeartbeat:
    """Worker heartbeat tests."""

    async def test_update_heartbeat_sets_timestamp(self):
        """``update_heartbeat`` writes a ``fromisoformat``-parseable ``worker_heartbeat_at``."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())
        task_id = await manager.submit_vectorize(_make_vectorize_params())

        await manager.update_heartbeat(task_id)

        record = await manager.get_task_status(task_id)
        assert record is not None
        assert "worker_heartbeat_at" in record.metadata
        # Raises (and fails the test) if the stored value is not an ISO timestamp.
        datetime.fromisoformat(record.metadata["worker_heartbeat_at"])

    async def test_update_heartbeat_nonexistent_task_noop(self):
        """``update_heartbeat`` on an unknown task id returns without raising."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())
        await manager.update_heartbeat("nonexistent")
|