"""U2 测试 — KB 持久化存储。 测试场景: 1. KB 元数据写入 PG,重启后仍存在(mock 验证 commit) 2. owner 用户可查询自己的 KB 3. delete_kb CASCADE 删除关联数据 4. 文档 CRUD 操作 """ from __future__ import annotations from contextlib import asynccontextmanager from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock from agentkit.rag_platform.models import ( DocumentStatus, KnowledgeBase, ) from agentkit.rag_platform.store import KBStore def _make_kb_row(**overrides): """创建模拟 KBModel 行。""" defaults = { "id": "kb-001", "name": "test-kb", "description": "", "owner": "user1", "status": "active", "default_query_mode": "blend", "default_hit_processing": "model_opt", "caching_disabled": False, "created_at": datetime.now(timezone.utc), "updated_at": datetime.now(timezone.utc), } defaults.update(overrides) mock = MagicMock() mock.configure_mock(**defaults) return mock def _make_doc_row(**overrides): """创建模拟 DocumentModel 行。""" defaults = { "id": "doc-001", "kb_id": "kb-001", "filename": "test.pdf", "file_type": "pdf", "file_size": 1024, "status": "pending", "error_message": None, "created_at": datetime.now(timezone.utc), "updated_at": datetime.now(timezone.utc), } defaults.update(overrides) mock = MagicMock() mock.configure_mock(**defaults) return mock def _make_mock_session_factory(): """创建 mock session factory,返回 (factory, mock_session)。""" mock_session = AsyncMock() mock_session.add = MagicMock() mock_session.commit = AsyncMock() mock_session.rollback = AsyncMock() mock_session.refresh = AsyncMock() mock_session.delete = AsyncMock() mock_session.flush = AsyncMock() @asynccontextmanager async def factory(): yield mock_session return factory, mock_session def _make_mock_execute(result_rows): """创建 mock execute,返回 scalars().all() = result_rows。""" mock_result = MagicMock() mock_scalars = MagicMock() mock_scalars.all.return_value = result_rows mock_scalars.first.return_value = result_rows[0] if result_rows else None mock_result.scalars.return_value = mock_scalars return AsyncMock(return_value=mock_result) class TestKBStoreCreate: """KB 创建测试。""" async def test_create_kb_returns_knowledge_base(self): """create_kb 返回 KnowledgeBase 领域模型。""" sf, mock_session = _make_mock_session_factory() # flush 后 kb.id 可用 def _flush_side_effect(): mock_session._kb_id = "kb-001" mock_session.flush.side_effect = _flush_side_effect # refresh 后设置所有必需属性 async def _refresh(obj): obj.id = "kb-001" obj.status = "active" obj.created_at = datetime.now(timezone.utc) obj.updated_at = datetime.now(timezone.utc) mock_session.refresh.side_effect = _refresh store = KBStore(sf) kb = await store.create_kb(name="test", owner="user1", description="desc") assert isinstance(kb, KnowledgeBase) assert kb.name == "test" assert kb.owner == "user1" assert kb.description == "desc" # 验证 add 被调用(KB + ACL) assert mock_session.add.call_count == 2 mock_session.commit.assert_awaited_once() async def test_create_kb_creates_owner_acl(self): """create_kb 在同一事务中创建 owner ACL 条目。""" sf, mock_session = _make_mock_session_factory() async def _refresh(obj): obj.id = "kb-001" obj.status = "active" obj.created_at = datetime.now(timezone.utc) obj.updated_at = datetime.now(timezone.utc) mock_session.refresh.side_effect = _refresh store = KBStore(sf) await store.create_kb(name="test", owner="user1") # 验证第二个 add 调用是 ACL 条目 second_add_call = mock_session.add.call_args_list[1] acl_obj = second_add_call.args[0] assert acl_obj.user_id == "user1" assert acl_obj.role == "owner" class TestKBStoreQuery: """KB 查询测试。""" async def test_get_kb_returns_none_if_not_found(self): """get_kb 返回 None 当 KB 不存在。""" sf, mock_session = _make_mock_session_factory() mock_session.execute = _make_mock_execute([]) store = KBStore(sf) result = await store.get_kb("nonexistent") assert result is None async def test_get_kb_returns_knowledge_base(self): """get_kb 返回 KnowledgeBase 领域模型。""" sf, mock_session = _make_mock_session_factory() kb_row = _make_kb_row() mock_session.execute = _make_mock_execute([kb_row]) store = KBStore(sf) result = await store.get_kb("kb-001") assert result is not None assert result.id == "kb-001" assert result.name == "test-kb" async def test_list_kbs_filters_by_owner(self): """list_kbs 按 owner 过滤。""" sf, mock_session = _make_mock_session_factory() kb_rows = [_make_kb_row(id="kb-1"), _make_kb_row(id="kb-2", name="second")] mock_session.execute = _make_mock_execute(kb_rows) store = KBStore(sf) results = await store.list_kbs(owner="user1") assert len(results) == 2 assert results[0].id == "kb-1" async def test_list_kbs_returns_empty(self): """list_kbs 空列表正常返回。""" sf, mock_session = _make_mock_session_factory() mock_session.execute = _make_mock_execute([]) store = KBStore(sf) results = await store.list_kbs() assert results == [] class TestKBStoreDelete: """KB 删除测试。""" async def test_delete_kb_returns_true(self): """delete_kb 删除成功返回 True。""" sf, mock_session = _make_mock_session_factory() kb_row = _make_kb_row() mock_session.execute = _make_mock_execute([kb_row]) store = KBStore(sf) result = await store.delete_kb("kb-001") assert result is True mock_session.delete.assert_awaited_once_with(kb_row) mock_session.commit.assert_awaited_once() async def test_delete_kb_returns_false_if_not_found(self): """delete_kb KB 不存在返回 False。""" sf, mock_session = _make_mock_session_factory() mock_session.execute = _make_mock_execute([]) store = KBStore(sf) result = await store.delete_kb("nonexistent") assert result is False class TestDocumentOperations: """文档 CRUD 测试。""" async def test_add_document(self): """add_document 创建文档记录。""" sf, mock_session = _make_mock_session_factory() async def _refresh(obj): obj.id = "doc-001" obj.status = "pending" obj.created_at = datetime.now(timezone.utc) obj.updated_at = datetime.now(timezone.utc) mock_session.refresh.side_effect = _refresh store = KBStore(sf) doc = await store.add_document("kb-001", "test.pdf", "pdf", 1024) assert doc.filename == "test.pdf" assert doc.kb_id == "kb-001" mock_session.add.assert_called_once() mock_session.commit.assert_awaited_once() async def test_list_documents(self): """list_documents 返回指定 KB 的文档。""" sf, mock_session = _make_mock_session_factory() doc_rows = [_make_doc_row(id="d1"), _make_doc_row(id="d2", filename="second.pdf")] mock_session.execute = _make_mock_execute(doc_rows) store = KBStore(sf) results = await store.list_documents("kb-001") assert len(results) == 2 assert results[0].id == "d1" async def test_update_document_status(self): """update_document_status 更新文档状态。""" sf, mock_session = _make_mock_session_factory() doc_row = _make_doc_row() mock_session.execute = _make_mock_execute([doc_row]) store = KBStore(sf) result = await store.update_document_status("doc-001", DocumentStatus.indexed) assert result is not None assert doc_row.status == "indexed" mock_session.commit.assert_awaited_once() async def test_update_document_status_not_found(self): """update_document_status 文档不存在返回 None。""" sf, mock_session = _make_mock_session_factory() mock_session.execute = _make_mock_execute([]) store = KBStore(sf) result = await store.update_document_status("nonexistent", DocumentStatus.failed) assert result is None async def test_delete_document(self): """delete_document 删除成功。""" sf, mock_session = _make_mock_session_factory() doc_row = _make_doc_row() mock_session.execute = _make_mock_execute([doc_row]) store = KBStore(sf) result = await store.delete_document("doc-001") assert result is True mock_session.delete.assert_awaited_once_with(doc_row)