fischer-agentkit/tests/unit/rag_platform/test_tasks.py

722 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""U8 测试 — TaskIQ 异步任务集成 — 文档向量化。
测试场景:
1. 任务参数 schema 校验(Pydantic 模型)
2. 错误消息净化(剥离内部路径/连接串/密钥)
3. 任务提交创建任务记录(PENDING)
4. 任务状态查询(pending → running → completed)
5. 任务历史可查询(按 user_id 过滤)
6. per-user 并发上限强制
7. 任务取消
8. Worker sweeper 检测超时任务
9. 降级模式(无 broker,同步执行)
10. 任务失败后错误消息净化
"""
from __future__ import annotations
import asyncio
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agentkit.core.protocol import TaskStatus
from agentkit.rag_platform.tasks import (
BatchIndexTaskParams,
ConcurrencyLimitExceeded,
TaskManager,
TaskNotFoundError,
VectorizeTaskParams,
WorkerSweeper,
run_batch_index_task,
run_vectorize_task,
sanitize_error_message,
)
from agentkit.server.task_store import InMemoryTaskStore
# ---------------------------------------------------------------------------
# 测试 fixtures
# ---------------------------------------------------------------------------
def _make_vectorize_params(**overrides) -> VectorizeTaskParams:
    """Build a ``VectorizeTaskParams`` from test defaults.

    Any keyword given in *overrides* replaces (or adds to) the default
    field values before the model is constructed.
    """
    base_fields = dict(
        document_id="doc-001",
        kb_id="kb-001",
        file_path="/tmp/test.txt",
        file_type="txt",
        user_id="user1",
    )
    # Overrides win over the defaults on key collisions.
    return VectorizeTaskParams(**{**base_fields, **overrides})
def _make_batch_params(**overrides) -> BatchIndexTaskParams:
    """Build a ``BatchIndexTaskParams`` from test defaults.

    Any keyword given in *overrides* replaces (or adds to) the default
    field values before the model is constructed.
    """
    base_fields = dict(
        kb_id="kb-001",
        document_ids=["doc-1", "doc-2"],
        user_id="user1",
        file_paths={"doc-1": "/tmp/a.txt", "doc-2": "/tmp/b.txt"},
        file_types={"doc-1": "txt", "doc-2": "txt"},
    )
    # Overrides win over the defaults on key collisions.
    return BatchIndexTaskParams(**{**base_fields, **overrides})
def _make_mock_dependencies():
"""创建 mock 依赖store/vector_store/embed_model"""
store = MagicMock()
store.update_document_status = AsyncMock()
vector_store = MagicMock()
vector_store.async_add = AsyncMock()
embed_model = MagicMock()
embed_model.aget_text_embedding = AsyncMock(return_value=[0.1] * 1536)
return {"store": store, "vector_store": vector_store, "embed_model": embed_model}
# ---------------------------------------------------------------------------
# 1. 任务参数 schema 校验
# ---------------------------------------------------------------------------
class TestVectorizeTaskParams:
    """Schema validation tests for ``VectorizeTaskParams``."""

    def test_valid_params(self):
        """Valid parameters pass validation and expose the default chunking values."""
        params = _make_vectorize_params()
        assert params.document_id == "doc-001"
        assert params.chunk_size == 512
        assert params.chunk_overlap == 50

    def test_empty_document_id_rejected(self):
        """An empty document_id is rejected."""
        with pytest.raises(ValueError, match="document_id"):
            VectorizeTaskParams(
                document_id="",
                kb_id="kb-1",
                file_path="/tmp/x.txt",
                file_type="txt",
                user_id="u1",
            )

    def test_chunk_overlap_must_be_less_than_size(self):
        """chunk_overlap must be strictly smaller than chunk_size."""
        with pytest.raises(ValueError, match="chunk_overlap"):
            VectorizeTaskParams(
                document_id="doc-1",
                kb_id="kb-1",
                file_path="/tmp/x.txt",
                file_type="txt",
                user_id="u1",
                chunk_size=100,
                chunk_overlap=100,
            )

    def test_chunk_size_bounds(self):
        """chunk_size must lie within 50-8192; both ends of the range are checked."""
        # Below the documented minimum.
        with pytest.raises(ValueError):
            _make_vectorize_params(chunk_size=10)
        # Just above the documented maximum. The default chunk_overlap (50)
        # stays below chunk_size, so only the size bound can trigger here.
        # NOTE(review): the 8192 limit is taken from this test's documented
        # range — confirm against the model's field constraints.
        with pytest.raises(ValueError):
            _make_vectorize_params(chunk_size=8193)

    def test_user_id_required(self):
        """user_id is a required field."""
        with pytest.raises(ValueError):
            VectorizeTaskParams(
                document_id="doc-1",
                kb_id="kb-1",
                file_path="/tmp/x.txt",
                file_type="txt",
            )
class TestBatchIndexTaskParams:
    """Schema validation tests for ``BatchIndexTaskParams``."""

    def test_valid_params(self):
        """Valid parameters pass validation."""
        batch = _make_batch_params()
        assert len(batch.document_ids) == 2

    def test_empty_document_ids_rejected(self):
        """An empty document_ids list is rejected."""
        fields = {"kb_id": "kb-1", "document_ids": [], "user_id": "u1"}
        with pytest.raises(ValueError, match="document_ids"):
            BatchIndexTaskParams(**fields)

    def test_too_many_document_ids_rejected(self):
        """More than 100 document_ids are rejected."""
        oversized_ids = [f"doc-{i}" for i in range(101)]
        with pytest.raises(ValueError):
            BatchIndexTaskParams(
                kb_id="kb-1",
                document_ids=oversized_ids,
                user_id="u1",
            )
# ---------------------------------------------------------------------------
# 2. 错误消息净化
# ---------------------------------------------------------------------------
class TestSanitizeErrorMessage:
    """Tests for ``sanitize_error_message``."""

    def test_strips_unix_paths(self):
        """Unix file paths are replaced by a placeholder."""
        sanitized = sanitize_error_message(
            "Failed to open /usr/local/lib/python/file.py"
        )
        assert "/usr/local" not in sanitized
        assert "<path>" in sanitized

    def test_strips_windows_paths(self):
        """Windows file paths are replaced by a placeholder."""
        sanitized = sanitize_error_message(
            "Error in C:\\Users\\admin\\file.py at line 10"
        )
        assert "C:\\Users" not in sanitized
        assert "<path>" in sanitized

    def test_strips_redis_connection_string(self):
        """Redis connection strings are replaced by a placeholder."""
        sanitized = sanitize_error_message(
            "Cannot connect to redis://localhost:6379/0"
        )
        assert "redis://localhost" not in sanitized
        assert "<connection-string>" in sanitized

    def test_strips_postgresql_connection_string(self):
        """PostgreSQL connection strings are replaced by a placeholder."""
        sanitized = sanitize_error_message(
            "DB error: postgresql://user:pass@host:5432/db"
        )
        assert "postgresql://" not in sanitized
        assert "<connection-string>" in sanitized

    def test_redacts_api_keys(self):
        """API keys are redacted."""
        sanitized = sanitize_error_message(
            "Auth failed with api_key=sk-abc123def456"
        )
        assert "sk-abc123def456" not in sanitized
        assert "<redacted>" in sanitized

    def test_redacts_password(self):
        """Passwords are redacted."""
        sanitized = sanitize_error_message("Login failed password=secret123")
        assert "secret123" not in sanitized
        assert "<redacted>" in sanitized

    def test_truncates_long_messages(self):
        """Long messages are truncated to max_length and end with an ellipsis."""
        sanitized = sanitize_error_message("x" * 1000, max_length=100)
        assert len(sanitized) == 100
        assert sanitized.endswith("...")

    def test_empty_message_returns_empty(self):
        """An empty message yields an empty string."""
        assert sanitize_error_message("") == ""
# ---------------------------------------------------------------------------
# 3. TaskManager — 任务提交与状态查询
# ---------------------------------------------------------------------------
class TestTaskManagerSubmit:
    """Task submission tests."""

    async def test_submit_vectorize_creates_pending_task(self):
        """Submitting a vectorize task creates a PENDING record."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())

        task_id = await manager.submit_vectorize(_make_vectorize_params())
        assert task_id.startswith("vec-")

        stored = await manager.get_task_status(task_id)
        assert stored is not None
        assert stored.status == TaskStatus.PENDING
        assert stored.input_data["document_id"] == "doc-001"
        assert stored.metadata["user_id"] == "user1"

    async def test_submit_batch_index_creates_pending_task(self):
        """Submitting a batch-index task creates a PENDING record."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())

        task_id = await manager.submit_batch_index(_make_batch_params())
        assert task_id.startswith("batch-")

        stored = await manager.get_task_status(task_id)
        assert stored is not None
        assert stored.status == TaskStatus.PENDING
        assert len(stored.input_data["document_ids"]) == 2

    async def test_submit_with_broker_dispatches(self):
        """With a broker configured, submission dispatches through ``broker.kiq``."""
        broker = MagicMock(kiq=AsyncMock())
        manager = TaskManager(broker=broker, task_store=InMemoryTaskStore())

        task_id = await manager.submit_vectorize(_make_vectorize_params())

        broker.kiq.assert_awaited_once()
        # The dispatch keyword arguments must carry task_type and task_id.
        dispatched = broker.kiq.call_args.kwargs
        assert dispatched["task_type"] == "vectorize"
        assert dispatched["task_id"] == task_id
# ---------------------------------------------------------------------------
# 4. 任务状态查询
# ---------------------------------------------------------------------------
class TestTaskManagerQuery:
    """Task status query tests."""

    async def test_get_task_status_returns_none_if_not_found(self):
        """Looking up an unknown task id returns None."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())
        assert await manager.get_task_status("nonexistent") is None

    async def test_list_tasks_filters_by_user(self):
        """list_tasks filters by user_id."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())

        # Two tasks for user1, one for user2.
        submissions = (
            ("doc-001", "user1"),
            ("doc-2", "user1"),
            ("doc-3", "user2"),
        )
        for document_id, user_id in submissions:
            await manager.submit_vectorize(
                _make_vectorize_params(document_id=document_id, user_id=user_id)
            )

        for_user1 = await manager.list_tasks(user_id="user1")
        for_user2 = await manager.list_tasks(user_id="user2")
        unfiltered = await manager.list_tasks()

        assert len(for_user1) == 2
        assert len(for_user2) == 1
        assert len(unfiltered) == 3

    async def test_list_tasks_filters_by_status(self):
        """list_tasks filters by status."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)

        finished_id = await manager.submit_vectorize(_make_vectorize_params())
        # Force the first task to COMPLETED; InMemoryTaskStore.update_status
        # is called synchronously here (no await).
        store.update_status(finished_id, TaskStatus.COMPLETED)
        # A second submission stays PENDING.
        await manager.submit_vectorize(_make_vectorize_params(document_id="doc-2"))

        still_pending = await manager.list_tasks(status=TaskStatus.PENDING)
        done = await manager.list_tasks(status=TaskStatus.COMPLETED)

        assert len(still_pending) == 1
        assert len(done) == 1
# ---------------------------------------------------------------------------
# 5. per-user 并发上限
# ---------------------------------------------------------------------------
class TestConcurrencyLimit:
    """Per-user concurrency limit tests."""

    async def test_concurrency_limit_exceeded(self):
        """Going over the per-user limit raises ConcurrencyLimitExceeded."""
        manager = TaskManager(
            broker=None,
            task_store=InMemoryTaskStore(),
            max_concurrent_per_user=2,
        )

        # Two submissions reach the limit.
        for document_id in ("d1", "d2"):
            await manager.submit_vectorize(
                _make_vectorize_params(document_id=document_id)
            )

        # The third submission must be refused.
        with pytest.raises(ConcurrencyLimitExceeded):
            await manager.submit_vectorize(_make_vectorize_params(document_id="d3"))

    async def test_concurrency_limit_per_user_isolated(self):
        """Each user's concurrency limit is counted independently."""
        manager = TaskManager(
            broker=None,
            task_store=InMemoryTaskStore(),
            max_concurrent_per_user=1,
        )

        # user1 uses up its single slot.
        await manager.submit_vectorize(_make_vectorize_params(user_id="user1"))

        # user2 can still submit.
        other_user_params = _make_vectorize_params(document_id="d2", user_id="user2")
        task_id = await manager.submit_vectorize(other_user_params)
        assert task_id is not None

    async def test_concurrency_released_after_cancel(self):
        """Cancelling a task frees its concurrency slot."""
        manager = TaskManager(
            broker=None,
            task_store=InMemoryTaskStore(),
            max_concurrent_per_user=1,
        )

        first_id = await manager.submit_vectorize(_make_vectorize_params())
        await manager.cancel_task(first_id)

        # With the slot released, a new submission is accepted.
        second_id = await manager.submit_vectorize(
            _make_vectorize_params(document_id="d2")
        )
        assert second_id is not None
# ---------------------------------------------------------------------------
# 6. 任务取消
# ---------------------------------------------------------------------------
class TestTaskCancel:
    """Task cancellation tests."""

    async def test_cancel_pending_task(self):
        """Cancelling a PENDING task succeeds and marks it CANCELLED."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())
        task_id = await manager.submit_vectorize(_make_vectorize_params())

        cancelled = await manager.cancel_task(task_id)
        assert cancelled is True

        stored = await manager.get_task_status(task_id)
        assert stored is not None
        assert stored.status == TaskStatus.CANCELLED

    async def test_cancel_nonexistent_raises(self):
        """Cancelling an unknown task raises TaskNotFoundError."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())
        with pytest.raises(TaskNotFoundError):
            await manager.cancel_task("nonexistent")

    async def test_cancel_completed_returns_false(self):
        """Cancelling an already completed task returns False."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        task_id = await manager.submit_vectorize(_make_vectorize_params())
        store.update_status(task_id, TaskStatus.COMPLETED)

        cancelled = await manager.cancel_task(task_id)
        assert cancelled is False
# ---------------------------------------------------------------------------
# 7. WorkerSweeper — 超时任务检测
# ---------------------------------------------------------------------------
class TestWorkerSweeper:
    """Worker sweeper tests."""

    async def test_sweep_detects_timed_out_task(self):
        """The sweeper marks a RUNNING task with a stale heartbeat as failed."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        task_id = await manager.submit_vectorize(_make_vectorize_params())

        # Simulate RUNNING with a heartbeat 600s old (TTL below is 300s).
        # The InMemoryTaskStore methods are called synchronously here.
        stale_heartbeat = datetime.now(timezone.utc) - timedelta(seconds=600)
        stale_iso = stale_heartbeat.isoformat()
        store.update_status(task_id, TaskStatus.RUNNING)
        running = store.get(task_id)
        assert running is not None
        running.metadata["worker_heartbeat_at"] = stale_iso

        swept = await WorkerSweeper(task_store=store, ttl_seconds=300).sweep()
        assert swept == 1

        after = store.get(task_id)
        assert after is not None
        assert after.status == TaskStatus.FAILED
        assert after.error_message == "worker_timeout"

    async def test_sweep_skips_fresh_heartbeat(self):
        """The sweeper leaves tasks whose heartbeat has not expired."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        task_id = await manager.submit_vectorize(_make_vectorize_params())

        # A heartbeat taken right now is well inside the TTL.
        fresh_iso = datetime.now(timezone.utc).isoformat()
        store.update_status(task_id, TaskStatus.RUNNING)
        running = store.get(task_id)
        assert running is not None
        running.metadata["worker_heartbeat_at"] = fresh_iso

        swept = await WorkerSweeper(task_store=store, ttl_seconds=300).sweep()
        assert swept == 0

        after = store.get(task_id)
        assert after is not None
        assert after.status == TaskStatus.RUNNING

    async def test_sweep_skips_tasks_without_heartbeat(self):
        """The sweeper skips tasks that have no heartbeat recorded."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        task_id = await manager.submit_vectorize(_make_vectorize_params())
        # RUNNING, but worker_heartbeat_at is deliberately not set.
        store.update_status(task_id, TaskStatus.RUNNING)

        swept = await WorkerSweeper(task_store=store, ttl_seconds=300).sweep()
        assert swept == 0

    async def test_sweep_skips_non_running_tasks(self):
        """The sweeper only handles tasks in RUNNING state."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        task_id = await manager.submit_vectorize(_make_vectorize_params())

        # The task stays PENDING; even a stale heartbeat must not get it swept.
        stale_heartbeat = datetime.now(timezone.utc) - timedelta(seconds=600)
        stale_iso = stale_heartbeat.isoformat()
        pending = store.get(task_id)
        assert pending is not None
        pending.metadata["worker_heartbeat_at"] = stale_iso

        swept = await WorkerSweeper(task_store=store, ttl_seconds=300).sweep()
        assert swept == 0
# ---------------------------------------------------------------------------
# 8. 降级模式 — 无 broker 同步执行
# ---------------------------------------------------------------------------
class TestDegradedMode:
    """Degraded-mode tests: behaviour when no broker is configured."""

    async def test_vectorize_success_transitions_to_completed(self):
        """Degraded mode: a successful vectorize run ends in COMPLETED."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())
        task_params = _make_vectorize_params()
        mocks = _make_mock_dependencies()

        # Patch run_vectorize_task so only TaskManager's state transitions
        # are exercised, without depending on llama_index.
        with patch(
            "agentkit.rag_platform.tasks.run_vectorize_task",
            new_callable=AsyncMock,
        ):
            task_id = await manager.submit_vectorize(task_params, dependencies=mocks)
            # Give the degraded-mode task time to finish while the patch
            # is still active.
            await asyncio.sleep(0.1)

        stored = await manager.get_task_status(task_id)
        assert stored is not None
        assert stored.status == TaskStatus.COMPLETED
        assert stored.progress == 1.0

    async def test_vectorize_failure_sanitizes_error(self):
        """Degraded mode: a failed vectorize run stores a sanitized error message."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())
        task_params = _make_vectorize_params()
        mocks = _make_mock_dependencies()

        # The patched run_vectorize_task raises an error that contains a path.
        path_error = FileNotFoundError(
            "[Errno 2] No such file or directory: /usr/local/data/missing.txt"
        )
        with patch(
            "agentkit.rag_platform.tasks.run_vectorize_task",
            new_callable=AsyncMock,
            side_effect=path_error,
        ):
            task_id = await manager.submit_vectorize(task_params, dependencies=mocks)
            # Give the degraded-mode task time to finish while the patch
            # is still active.
            await asyncio.sleep(0.1)

        stored = await manager.get_task_status(task_id)
        assert stored is not None
        assert stored.status == TaskStatus.FAILED
        # The stored message must be sanitized (path replaced by <path>).
        assert "/usr/local" not in (stored.error_message or "")
        assert "<path>" in (stored.error_message or "")

    async def test_degraded_mode_does_not_block_caller(self):
        """Degraded mode: the task runs in the background without blocking the caller."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())
        task_params = _make_vectorize_params()
        mocks = _make_mock_dependencies()

        # Stand-in for run_vectorize_task that simulates slow work.
        async def _simulate_slow_work(*args, **kwargs):
            await asyncio.sleep(0.5)

        with patch(
            "agentkit.rag_platform.tasks.run_vectorize_task",
            new_callable=AsyncMock,
            side_effect=_simulate_slow_work,
        ):
            # submit_vectorize should return at once, not wait for the task.
            task_id = await manager.submit_vectorize(task_params, dependencies=mocks)

            # Query immediately: the task is either still PENDING (not yet
            # started) or RUNNING (started but unfinished).
            stored = await manager.get_task_status(task_id)
            assert stored is not None
            assert stored.status in (TaskStatus.PENDING, TaskStatus.RUNNING)

            # Let the background task finish before the patch is removed.
            await asyncio.sleep(0.6)
# ---------------------------------------------------------------------------
# 9. 任务执行函数测试
# ---------------------------------------------------------------------------
class TestRunVectorizeTask:
    """Tests for the ``run_vectorize_task`` function."""

    async def test_run_vectorize_calls_document_processor(self):
        """run_vectorize_task awaits DocumentProcessor.process exactly once."""
        task_params = _make_vectorize_params()
        mocks = _make_mock_dependencies()

        # DocumentProcessor.process is patched so llama_index is not needed.
        with patch(
            "agentkit.rag_platform.tasks.DocumentProcessor.process",
            new_callable=AsyncMock,
        ) as process_mock:
            await run_vectorize_task(task_params, **mocks)

        process_mock.assert_awaited_once()

    async def test_run_vectorize_failure_propagates(self):
        """An exception raised during processing propagates to the caller."""
        task_params = _make_vectorize_params()
        mocks = _make_mock_dependencies()

        process_patch = patch(
            "agentkit.rag_platform.tasks.DocumentProcessor.process",
            new_callable=AsyncMock,
            side_effect=RuntimeError("processing failed"),
        )
        with process_patch, pytest.raises(RuntimeError, match="processing failed"):
            await run_vectorize_task(task_params, **mocks)
class TestRunBatchIndexTask:
    """Tests for the ``run_batch_index_task`` function."""

    async def test_batch_index_returns_results(self):
        """run_batch_index_task returns the succeeded and failed lists."""
        batch = BatchIndexTaskParams(
            kb_id="kb-1",
            document_ids=["doc-1", "doc-2"],
            user_id="user1",
            file_paths={"doc-1": "/tmp/a.txt", "doc-2": "/tmp/b.txt"},
            file_types={"doc-1": "txt", "doc-2": "txt"},
        )
        mocks = _make_mock_dependencies()

        # DocumentProcessor.process is patched to succeed for every document.
        with patch(
            "agentkit.rag_platform.tasks.DocumentProcessor.process",
            new_callable=AsyncMock,
        ):
            outcome = await run_batch_index_task(batch, **mocks)

        assert "succeeded" in outcome
        assert "failed" in outcome
        assert len(outcome["succeeded"]) == 2
        assert len(outcome["failed"]) == 0

    async def test_batch_index_partial_failure(self):
        """A partially failing batch reports the failed document."""
        batch = BatchIndexTaskParams(
            kb_id="kb-1",
            document_ids=["doc-1", "doc-missing"],
            user_id="user1",
            file_paths={"doc-1": "/tmp/a.txt", "doc-missing": "/nonexistent/file.txt"},
            file_types={"doc-1": "txt", "doc-missing": "txt"},
        )
        mocks = _make_mock_dependencies()

        # Patched process: the second invocation raises, every other
        # invocation succeeds.
        invocations = 0

        async def _fail_on_second_call(*args, **kwargs):
            nonlocal invocations
            invocations += 1
            if invocations == 2:
                raise FileNotFoundError("missing file")

        with patch(
            "agentkit.rag_platform.tasks.DocumentProcessor.process",
            new_callable=AsyncMock,
            side_effect=_fail_on_second_call,
        ):
            outcome = await run_batch_index_task(batch, **mocks)

        assert "doc-1" in outcome["succeeded"]
        assert "doc-missing" in outcome["failed"]
# ---------------------------------------------------------------------------
# 10. 心跳更新测试
# ---------------------------------------------------------------------------
class TestHeartbeat:
    """Worker heartbeat tests."""

    async def test_update_heartbeat_sets_timestamp(self):
        """update_heartbeat stores a worker_heartbeat_at timestamp in the metadata."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())
        task_id = await manager.submit_vectorize(_make_vectorize_params())

        await manager.update_heartbeat(task_id)

        stored = await manager.get_task_status(task_id)
        assert stored is not None
        assert "worker_heartbeat_at" in stored.metadata
        # The stored value must parse as ISO 8601; fromisoformat raises
        # ValueError otherwise, which fails the test.
        datetime.fromisoformat(stored.metadata["worker_heartbeat_at"])

    async def test_update_heartbeat_nonexistent_task_noop(self):
        """update_heartbeat on an unknown task id does not raise."""
        manager = TaskManager(broker=None, task_store=InMemoryTaskStore())
        # Completing without an exception is the whole assertion.
        await manager.update_heartbeat("nonexistent")