"""U8 测试 — TaskIQ 异步任务集成 — 文档向量化。 测试场景: 1. 任务参数 schema 校验(Pydantic 模型) 2. 错误消息净化(剥离内部路径/连接串/密钥) 3. 任务提交创建任务记录(PENDING) 4. 任务状态查询(pending → running → completed) 5. 任务历史可查询(按 user_id 过滤) 6. per-user 并发上限强制 7. 任务取消 8. Worker sweeper 检测超时任务 9. 降级模式(无 broker)同步执行 10. 任务失败后错误消息净化 """ from __future__ import annotations import asyncio from datetime import datetime, timedelta, timezone from unittest.mock import AsyncMock, MagicMock, patch import pytest from agentkit.core.protocol import TaskStatus from agentkit.rag_platform.tasks import ( BatchIndexTaskParams, ConcurrencyLimitExceeded, TaskManager, TaskNotFoundError, VectorizeTaskParams, WorkerSweeper, run_batch_index_task, run_vectorize_task, sanitize_error_message, ) from agentkit.server.task_store import InMemoryTaskStore # --------------------------------------------------------------------------- # 测试 fixtures # --------------------------------------------------------------------------- def _make_vectorize_params(**overrides) -> VectorizeTaskParams: """创建向量化任务参数。""" defaults = { "document_id": "doc-001", "kb_id": "kb-001", "file_path": "/tmp/test.txt", "file_type": "txt", "user_id": "user1", } defaults.update(overrides) return VectorizeTaskParams(**defaults) def _make_batch_params(**overrides) -> BatchIndexTaskParams: """创建批量索引任务参数。""" defaults = { "kb_id": "kb-001", "document_ids": ["doc-1", "doc-2"], "user_id": "user1", "file_paths": {"doc-1": "/tmp/a.txt", "doc-2": "/tmp/b.txt"}, "file_types": {"doc-1": "txt", "doc-2": "txt"}, } defaults.update(overrides) return BatchIndexTaskParams(**defaults) def _make_mock_dependencies(): """创建 mock 依赖(store/vector_store/embed_model)。""" store = MagicMock() store.update_document_status = AsyncMock() vector_store = MagicMock() vector_store.async_add = AsyncMock() embed_model = MagicMock() embed_model.aget_text_embedding = AsyncMock(return_value=[0.1] * 1536) return {"store": store, "vector_store": vector_store, "embed_model": embed_model} # 
# ---------------------------------------------------------------------------
# 1. Task parameter schema validation
# ---------------------------------------------------------------------------


class TestVectorizeTaskParams:
    """Schema validation tests for VectorizeTaskParams."""

    def test_valid_params(self):
        """Valid parameters pass validation and receive the default chunking values."""
        params = _make_vectorize_params()
        assert params.document_id == "doc-001"
        assert params.chunk_size == 512
        assert params.chunk_overlap == 50

    def test_empty_document_id_rejected(self):
        """An empty document_id is rejected."""
        with pytest.raises(ValueError, match="document_id"):
            VectorizeTaskParams(
                document_id="",
                kb_id="kb-1",
                file_path="/tmp/x.txt",
                file_type="txt",
                user_id="u1",
            )

    def test_chunk_overlap_must_be_less_than_size(self):
        """chunk_overlap must be strictly smaller than chunk_size."""
        with pytest.raises(ValueError, match="chunk_overlap"):
            VectorizeTaskParams(
                document_id="doc-1",
                kb_id="kb-1",
                file_path="/tmp/x.txt",
                file_type="txt",
                user_id="u1",
                chunk_size=100,
                chunk_overlap=100,
            )

    def test_chunk_size_bounds(self):
        """chunk_size must lie within 50-8192; values outside either end are rejected."""
        with pytest.raises(ValueError):
            _make_vectorize_params(chunk_size=10)
        # The documented range has an upper bound too; exercise it separately
        # so each pytest.raises covers exactly one invalid value.
        with pytest.raises(ValueError):
            _make_vectorize_params(chunk_size=8193)

    def test_user_id_required(self):
        """user_id is a required field."""
        with pytest.raises(ValueError):
            VectorizeTaskParams(
                document_id="doc-1",
                kb_id="kb-1",
                file_path="/tmp/x.txt",
                file_type="txt",
            )


class TestBatchIndexTaskParams:
    """Schema validation tests for BatchIndexTaskParams."""

    def test_valid_params(self):
        """Valid parameters pass validation."""
        params = _make_batch_params()
        assert len(params.document_ids) == 2

    def test_empty_document_ids_rejected(self):
        """An empty document_ids list is rejected."""
        with pytest.raises(ValueError, match="document_ids"):
            BatchIndexTaskParams(
                kb_id="kb-1",
                document_ids=[],
                user_id="u1",
            )

    def test_too_many_document_ids_rejected(self):
        """More than 100 document_ids are rejected."""
        with pytest.raises(ValueError):
            BatchIndexTaskParams(
                kb_id="kb-1",
                document_ids=[f"doc-{i}" for i in range(101)],
                user_id="u1",
            )
# ---------------------------------------------------------------------------
# 2. Error message sanitization
# ---------------------------------------------------------------------------


# NOTE(review): the redaction tests previously ended with `assert "" in cleaned`,
# which is always true (the placeholder literal seems to have been lost). They
# now check that the non-sensitive prefix of the message survives; tighten to
# the real placeholder token (e.g. "<path>") once confirmed against the code.
class TestSanitizeErrorMessage:
    """Error message sanitization tests.

    Each redaction test checks that the sensitive fragment is removed and that
    the surrounding, non-sensitive context is preserved.
    """

    def test_strips_unix_paths(self):
        """Unix file paths are stripped."""
        msg = "Failed to open /usr/local/lib/python/file.py"
        cleaned = sanitize_error_message(msg)
        assert "/usr/local" not in cleaned
        assert cleaned.startswith("Failed to open")

    def test_strips_windows_paths(self):
        """Windows file paths are stripped."""
        msg = "Error in C:\\Users\\admin\\file.py at line 10"
        cleaned = sanitize_error_message(msg)
        assert "C:\\Users" not in cleaned
        assert cleaned.startswith("Error in")

    def test_strips_redis_connection_string(self):
        """Redis connection strings are stripped."""
        msg = "Cannot connect to redis://localhost:6379/0"
        cleaned = sanitize_error_message(msg)
        assert "redis://localhost" not in cleaned
        assert cleaned.startswith("Cannot connect to")

    def test_strips_postgresql_connection_string(self):
        """PostgreSQL connection strings are stripped."""
        msg = "DB error: postgresql://user:pass@host:5432/db"
        cleaned = sanitize_error_message(msg)
        assert "postgresql://" not in cleaned
        assert cleaned.startswith("DB error:")

    def test_redacts_api_keys(self):
        """API keys are redacted."""
        msg = "Auth failed with api_key=sk-abc123def456"
        cleaned = sanitize_error_message(msg)
        assert "sk-abc123def456" not in cleaned
        assert cleaned.startswith("Auth failed with")

    def test_redacts_password(self):
        """Passwords are redacted."""
        msg = "Login failed password=secret123"
        cleaned = sanitize_error_message(msg)
        assert "secret123" not in cleaned
        assert cleaned.startswith("Login failed")

    def test_truncates_long_messages(self):
        """Long messages are truncated to max_length, ending with an ellipsis."""
        msg = "x" * 1000
        cleaned = sanitize_error_message(msg, max_length=100)
        assert len(cleaned) == 100
        assert cleaned.endswith("...")

    def test_empty_message_returns_empty(self):
        """An empty message yields an empty string."""
        assert sanitize_error_message("") == ""
# ---------------------------------------------------------------------------
# 3. TaskManager — task submission and status queries
# ---------------------------------------------------------------------------


class TestTaskManagerSubmit:
    """Submission tests: each submit stores a PENDING record in the task store."""

    async def test_submit_vectorize_creates_pending_task(self):
        """A vectorize submission stores a PENDING record carrying its input and owner."""
        mgr = TaskManager(broker=None, task_store=InMemoryTaskStore())

        new_id = await mgr.submit_vectorize(_make_vectorize_params())
        assert new_id.startswith("vec-")

        stored = await mgr.get_task_status(new_id)
        assert stored is not None
        assert stored.status == TaskStatus.PENDING
        assert stored.input_data["document_id"] == "doc-001"
        assert stored.metadata["user_id"] == "user1"

    async def test_submit_batch_index_creates_pending_task(self):
        """A batch-index submission stores a PENDING record listing every document."""
        mgr = TaskManager(broker=None, task_store=InMemoryTaskStore())

        new_id = await mgr.submit_batch_index(_make_batch_params())
        assert new_id.startswith("batch-")

        stored = await mgr.get_task_status(new_id)
        assert stored is not None
        assert stored.status == TaskStatus.PENDING
        assert len(stored.input_data["document_ids"]) == 2

    async def test_submit_with_broker_dispatches(self):
        """When a broker is configured, submission dispatches through ``broker.kiq``."""
        broker = MagicMock()
        broker.kiq = AsyncMock()
        mgr = TaskManager(broker=broker, task_store=InMemoryTaskStore())

        new_id = await mgr.submit_vectorize(_make_vectorize_params())

        broker.kiq.assert_awaited_once()
        # The dispatch must carry the task type and the id handed back to the caller.
        dispatched = broker.kiq.call_args.kwargs
        assert dispatched["task_type"] == "vectorize"
        assert dispatched["task_id"] == new_id
# ---------------------------------------------------------------------------
# 4. Task status queries
# ---------------------------------------------------------------------------


class TestTaskManagerQuery:
    """Tests for single-task lookup and filtered task listing."""

    async def test_get_task_status_returns_none_if_not_found(self):
        """Looking up an unknown task id yields None."""
        mgr = TaskManager(broker=None, task_store=InMemoryTaskStore())
        assert await mgr.get_task_status("nonexistent") is None

    async def test_list_tasks_filters_by_user(self):
        """list_tasks narrows results to the given user_id (or returns all without one)."""
        mgr = TaskManager(broker=None, task_store=InMemoryTaskStore())

        # Two tasks owned by user1, one owned by user2.
        for doc_id, owner in (("doc-001", "user1"), ("doc-2", "user1"), ("doc-3", "user2")):
            await mgr.submit_vectorize(
                _make_vectorize_params(document_id=doc_id, user_id=owner)
            )

        counts = [
            len(await mgr.list_tasks(user_id="user1")),
            len(await mgr.list_tasks(user_id="user2")),
            len(await mgr.list_tasks()),
        ]
        assert counts == [2, 1, 3]

    async def test_list_tasks_filters_by_status(self):
        """list_tasks narrows results to the given status."""
        task_store = InMemoryTaskStore()
        mgr = TaskManager(broker=None, task_store=task_store)

        finished_id = await mgr.submit_vectorize(_make_vectorize_params())
        # Force COMPLETED directly on the store (InMemoryTaskStore.update_status is sync).
        task_store.update_status(finished_id, TaskStatus.COMPLETED)
        # A second submission remains PENDING.
        await mgr.submit_vectorize(_make_vectorize_params(document_id="doc-2"))

        counts = [
            len(await mgr.list_tasks(status=wanted))
            for wanted in (TaskStatus.PENDING, TaskStatus.COMPLETED)
        ]
        assert counts == [1, 1]
# ---------------------------------------------------------------------------
# 5. Per-user concurrency limit
# ---------------------------------------------------------------------------


class TestConcurrencyLimit:
    """Per-user concurrency limit tests."""

    @staticmethod
    def _limited_manager(limit: int) -> TaskManager:
        """Build a broker-less manager with ``max_concurrent_per_user=limit``."""
        return TaskManager(
            broker=None,
            task_store=InMemoryTaskStore(),
            max_concurrent_per_user=limit,
        )

    async def test_concurrency_limit_exceeded(self):
        """Submitting beyond the per-user limit raises ConcurrencyLimitExceeded."""
        mgr = self._limited_manager(2)

        # Fill the quota exactly.
        for doc_id in ("d1", "d2"):
            await mgr.submit_vectorize(_make_vectorize_params(document_id=doc_id))

        # One more for the same user must be refused.
        with pytest.raises(ConcurrencyLimitExceeded):
            await mgr.submit_vectorize(_make_vectorize_params(document_id="d3"))

    async def test_concurrency_limit_per_user_isolated(self):
        """Each user's quota is counted independently."""
        mgr = self._limited_manager(1)

        # user1 consumes its single slot ...
        await mgr.submit_vectorize(_make_vectorize_params(user_id="user1"))
        # ... which must not block user2.
        other_id = await mgr.submit_vectorize(
            _make_vectorize_params(document_id="d2", user_id="user2")
        )
        assert other_id is not None

    async def test_concurrency_released_after_cancel(self):
        """Cancelling a task frees its slot in the per-user quota."""
        mgr = self._limited_manager(1)

        first_id = await mgr.submit_vectorize(_make_vectorize_params())
        await mgr.cancel_task(first_id)

        # The freed slot admits a new submission for the same user.
        second_id = await mgr.submit_vectorize(_make_vectorize_params(document_id="d2"))
        assert second_id is not None
# ---------------------------------------------------------------------------
# 6. Task cancellation
# ---------------------------------------------------------------------------


class TestTaskCancel:
    """Task cancellation tests."""

    @staticmethod
    def _store_and_manager() -> tuple[InMemoryTaskStore, TaskManager]:
        """Return a fresh in-memory store and a broker-less manager bound to it."""
        task_store = InMemoryTaskStore()
        return task_store, TaskManager(broker=None, task_store=task_store)

    async def test_cancel_pending_task(self):
        """Cancelling a PENDING task succeeds and moves it to CANCELLED."""
        _, mgr = self._store_and_manager()
        tid = await mgr.submit_vectorize(_make_vectorize_params())

        assert await mgr.cancel_task(tid) is True

        rec = await mgr.get_task_status(tid)
        assert rec is not None
        assert rec.status == TaskStatus.CANCELLED

    async def test_cancel_nonexistent_raises(self):
        """Cancelling an unknown task id raises TaskNotFoundError."""
        _, mgr = self._store_and_manager()
        with pytest.raises(TaskNotFoundError):
            await mgr.cancel_task("nonexistent")

    async def test_cancel_completed_returns_false(self):
        """Cancelling an already COMPLETED task returns False."""
        task_store, mgr = self._store_and_manager()
        tid = await mgr.submit_vectorize(_make_vectorize_params())
        task_store.update_status(tid, TaskStatus.COMPLETED)

        assert await mgr.cancel_task(tid) is False
# ---------------------------------------------------------------------------
# 7. WorkerSweeper — timed-out task detection
# ---------------------------------------------------------------------------


class TestWorkerSweeper:
    """Worker sweeper tests."""

    async def test_sweep_detects_timed_out_task(self):
        """A RUNNING task whose heartbeat is older than the TTL is marked FAILED."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        task_id = await manager.submit_vectorize(_make_vectorize_params())

        # Simulate RUNNING with a stale heartbeat (InMemoryTaskStore methods are sync).
        old_time = (datetime.now(timezone.utc) - timedelta(seconds=600)).isoformat()
        store.update_status(task_id, TaskStatus.RUNNING)
        record = store.get(task_id)
        assert record is not None
        record.metadata["worker_heartbeat_at"] = old_time

        sweeper = WorkerSweeper(task_store=store, ttl_seconds=300)
        cleaned = await sweeper.sweep()

        assert cleaned == 1
        record = store.get(task_id)
        assert record is not None
        assert record.status == TaskStatus.FAILED
        assert record.error_message == "worker_timeout"

    async def test_sweep_skips_fresh_heartbeat(self):
        """A RUNNING task whose heartbeat is within the TTL is left alone."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        task_id = await manager.submit_vectorize(_make_vectorize_params())

        # Fresh heartbeat.
        fresh_time = datetime.now(timezone.utc).isoformat()
        store.update_status(task_id, TaskStatus.RUNNING)
        record = store.get(task_id)
        assert record is not None
        record.metadata["worker_heartbeat_at"] = fresh_time

        sweeper = WorkerSweeper(task_store=store, ttl_seconds=300)
        cleaned = await sweeper.sweep()

        assert cleaned == 0
        record = store.get(task_id)
        assert record is not None
        assert record.status == TaskStatus.RUNNING

    async def test_sweep_skips_tasks_without_heartbeat(self):
        """A RUNNING task with no heartbeat recorded is skipped."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        task_id = await manager.submit_vectorize(_make_vectorize_params())
        store.update_status(task_id, TaskStatus.RUNNING)
        # Deliberately no worker_heartbeat_at.

        sweeper = WorkerSweeper(task_store=store, ttl_seconds=300)
        cleaned = await sweeper.sweep()

        assert cleaned == 0
        # The count alone would not catch a sweeper that mutates the record
        # without counting it, so check the status is untouched as well.
        record = store.get(task_id)
        assert record is not None
        assert record.status == TaskStatus.RUNNING

    async def test_sweep_skips_non_running_tasks(self):
        """Only RUNNING tasks are swept; a stale PENDING task is untouched."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        task_id = await manager.submit_vectorize(_make_vectorize_params())

        # Still PENDING, but with a stale heartbeat — must not be cleaned.
        old_time = (datetime.now(timezone.utc) - timedelta(seconds=600)).isoformat()
        record = store.get(task_id)
        assert record is not None
        record.metadata["worker_heartbeat_at"] = old_time

        sweeper = WorkerSweeper(task_store=store, ttl_seconds=300)
        cleaned = await sweeper.sweep()

        assert cleaned == 0
        record = store.get(task_id)
        assert record is not None
        assert record.status == TaskStatus.PENDING


# ---------------------------------------------------------------------------
# 8. Degraded mode — no broker, task runs in the background in-process
# ---------------------------------------------------------------------------


class TestDegradedMode:
    """Degraded-mode tests: with no broker the task runs in the background
    without blocking the caller."""

    @staticmethod
    async def _wait_until(manager: TaskManager, task_id: str, done, timeout: float = 2.0):
        """Poll the task record until ``done(record)`` is true or ``timeout`` elapses.

        Replaces fixed sleeps, which are flaky on slow machines. Returns the
        last record seen (possibly ``None`` or not yet ``done``) so that the
        caller's assertions report the actual final state.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while True:
            record = await manager.get_task_status(task_id)
            if (record is not None and done(record)) or loop.time() >= deadline:
                return record
            await asyncio.sleep(0.01)

    async def test_vectorize_success_transitions_to_completed(self):
        """A successful vectorization ends in COMPLETED with progress 1.0."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        params = _make_vectorize_params()
        deps = _make_mock_dependencies()

        # Mock run_vectorize_task so this exercises TaskManager's state
        # transitions without depending on llama_index. The wait happens inside
        # the `with` so the patch is still active when the background task runs.
        with patch(
            "agentkit.rag_platform.tasks.run_vectorize_task",
            new_callable=AsyncMock,
        ):
            task_id = await manager.submit_vectorize(params, dependencies=deps)
            record = await self._wait_until(
                manager,
                task_id,
                lambda r: r.status == TaskStatus.COMPLETED and r.progress == 1.0,
            )

        assert record is not None
        assert record.status == TaskStatus.COMPLETED
        assert record.progress == 1.0

    async def test_vectorize_failure_sanitizes_error(self):
        """A failed vectorization ends in FAILED with a sanitized error message."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        params = _make_vectorize_params()
        deps = _make_mock_dependencies()

        # run_vectorize_task raises an error whose text contains an internal path.
        with patch(
            "agentkit.rag_platform.tasks.run_vectorize_task",
            new_callable=AsyncMock,
            side_effect=FileNotFoundError(
                "[Errno 2] No such file or directory: /usr/local/data/missing.txt"
            ),
        ):
            task_id = await manager.submit_vectorize(params, dependencies=deps)
            record = await self._wait_until(
                manager,
                task_id,
                lambda r: r.status == TaskStatus.FAILED and bool(r.error_message),
            )

        assert record is not None
        assert record.status == TaskStatus.FAILED
        error = record.error_message or ""
        # The internal path must be gone ...
        assert "/usr/local" not in error
        # ... while the human-readable part of the error survives. (This used to
        # be `assert "" in ...`, which is always true.)
        assert "No such file or directory" in error

    async def test_degraded_mode_does_not_block_caller(self):
        """Submission returns immediately while the task runs in the background."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        params = _make_vectorize_params()
        deps = _make_mock_dependencies()

        # Simulate a slow run_vectorize_task.
        async def _slow_task(*args, **kwargs):
            await asyncio.sleep(0.5)

        with patch(
            "agentkit.rag_platform.tasks.run_vectorize_task",
            new_callable=AsyncMock,
            side_effect=_slow_task,
        ):
            # submit_vectorize must return without waiting for the task.
            task_id = await manager.submit_vectorize(params, dependencies=deps)

            # Queried right away, the task is PENDING (not yet started) or
            # RUNNING (started but unfinished) — never already terminal.
            record = await manager.get_task_status(task_id)
            assert record is not None
            assert record.status in (TaskStatus.PENDING, TaskStatus.RUNNING)

            # Let the background task finish while the patch is still active.
            await self._wait_until(
                manager,
                task_id,
                lambda r: r.status in (TaskStatus.COMPLETED, TaskStatus.FAILED),
            )
# ---------------------------------------------------------------------------
# 9. Task execution function tests
# ---------------------------------------------------------------------------


class TestRunVectorizeTask:
    """Tests for run_vectorize_task."""

    async def test_run_vectorize_calls_document_processor(self):
        """run_vectorize_task awaits DocumentProcessor.process exactly once."""
        params = _make_vectorize_params()
        deps = _make_mock_dependencies()

        # Mock DocumentProcessor.process so the test does not depend on llama_index.
        with patch(
            "agentkit.rag_platform.tasks.DocumentProcessor.process",
            new_callable=AsyncMock,
        ) as mock_process:
            await run_vectorize_task(params, **deps)

        mock_process.assert_awaited_once()

    async def test_run_vectorize_failure_propagates(self):
        """Errors from DocumentProcessor.process propagate out of run_vectorize_task."""
        params = _make_vectorize_params()
        deps = _make_mock_dependencies()

        with patch(
            "agentkit.rag_platform.tasks.DocumentProcessor.process",
            new_callable=AsyncMock,
            side_effect=RuntimeError("processing failed"),
        ):
            with pytest.raises(RuntimeError, match="processing failed"):
                await run_vectorize_task(params, **deps)


class TestRunBatchIndexTask:
    """Tests for run_batch_index_task."""

    async def test_batch_index_returns_results(self):
        """run_batch_index_task reports succeeded and failed document lists."""
        params = BatchIndexTaskParams(
            kb_id="kb-1",
            document_ids=["doc-1", "doc-2"],
            user_id="user1",
            file_paths={"doc-1": "/tmp/a.txt", "doc-2": "/tmp/b.txt"},
            file_types={"doc-1": "txt", "doc-2": "txt"},
        )
        deps = _make_mock_dependencies()

        # DocumentProcessor.process is mocked to succeed for every document.
        with patch(
            "agentkit.rag_platform.tasks.DocumentProcessor.process",
            new_callable=AsyncMock,
        ):
            result = await run_batch_index_task(params, **deps)

        assert "succeeded" in result
        assert "failed" in result
        assert len(result["succeeded"]) == 2
        assert len(result["failed"]) == 0

    async def test_batch_index_partial_failure(self):
        """A document that fails processing is reported in the failed list."""
        params = BatchIndexTaskParams(
            kb_id="kb-1",
            document_ids=["doc-1", "doc-missing"],
            user_id="user1",
            file_paths={"doc-1": "/tmp/a.txt", "doc-missing": "/nonexistent/file.txt"},
            file_types={"doc-1": "txt", "doc-missing": "txt"},
        )
        deps = _make_mock_dependencies()

        calls = 0

        async def _side_effect(*args, **kwargs):
            # First process() call succeeds, the second one fails.
            nonlocal calls
            calls += 1
            if calls == 2:
                raise FileNotFoundError("missing file")

        with patch(
            "agentkit.rag_platform.tasks.DocumentProcessor.process",
            new_callable=AsyncMock,
            side_effect=_side_effect,
        ):
            result = await run_batch_index_task(params, **deps)

        assert "doc-1" in result["succeeded"]
        assert "doc-missing" in result["failed"]


# ---------------------------------------------------------------------------
# 10. Heartbeat update tests
# ---------------------------------------------------------------------------


class TestHeartbeat:
    """Worker heartbeat tests."""

    async def test_update_heartbeat_sets_timestamp(self):
        """update_heartbeat records a parseable ISO timestamp in worker_heartbeat_at."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)
        task_id = await manager.submit_vectorize(_make_vectorize_params())

        await manager.update_heartbeat(task_id)

        record = await manager.get_task_status(task_id)
        assert record is not None
        assert "worker_heartbeat_at" in record.metadata
        # The timestamp must be parseable as ISO-8601.
        datetime.fromisoformat(record.metadata["worker_heartbeat_at"])

    async def test_update_heartbeat_nonexistent_task_noop(self):
        """update_heartbeat on an unknown task neither raises nor creates a record."""
        store = InMemoryTaskStore()
        manager = TaskManager(broker=None, task_store=store)

        # Must not raise ...
        await manager.update_heartbeat("nonexistent")
        # ... and must not have side effects such as upserting a record.
        assert await manager.get_task_status("nonexistent") is None