289 lines
9.1 KiB
Python
289 lines
9.1 KiB
Python
"""U2 测试 — KB 持久化存储。
|
||
|
||
测试场景:
|
||
1. KB 元数据写入 PG,重启后仍存在(mock 验证 commit)
|
||
2. owner 用户可查询自己的 KB
|
||
3. delete_kb CASCADE 删除关联数据
|
||
4. 文档 CRUD 操作
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from contextlib import asynccontextmanager
|
||
from datetime import datetime, timezone
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
from agentkit.rag_platform.models import (
|
||
DocumentStatus,
|
||
KnowledgeBase,
|
||
)
|
||
from agentkit.rag_platform.store import KBStore
|
||
|
||
|
||
def _make_kb_row(**overrides):
    """Return a MagicMock that impersonates a KBModel row.

    Attributes start from the defaults listed below; every keyword in
    *overrides* replaces the matching default (or adds a new attribute).
    """
    fields = dict(
        id="kb-001",
        name="test-kb",
        description="",
        owner="user1",
        status="active",
        default_query_mode="blend",
        default_hit_processing="model_opt",
        caching_disabled=False,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    row = MagicMock()
    # configure_mock is used (not the constructor) because ``name`` is a
    # reserved MagicMock constructor argument.
    row.configure_mock(**{**fields, **overrides})
    return row
|
||
|
||
|
||
def _make_doc_row(**overrides):
    """Return a MagicMock that impersonates a DocumentModel row.

    Attributes start from the defaults listed below; every keyword in
    *overrides* replaces the matching default (or adds a new attribute).
    """
    fields = dict(
        id="doc-001",
        kb_id="kb-001",
        filename="test.pdf",
        file_type="pdf",
        file_size=1024,
        status="pending",
        error_message=None,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    row = MagicMock()
    row.configure_mock(**{**fields, **overrides})
    return row
|
||
|
||
|
||
def _make_mock_session_factory():
    """Create a mocked async session factory.

    Returns:
        ``(factory, mock_session)``. ``factory()`` is an async context
        manager that yields the same ``mock_session`` on every call, so a
        test can inspect everything the code under test did with it.

    ``mock_session`` mirrors the sync/async split of SQLAlchemy's
    ``AsyncSession``. The methods configured below as ``AsyncMock`` must be
    awaited. The ones configured as ``MagicMock`` are ordinary synchronous
    calls. Any other attribute is an ``AsyncMock`` child by default.
    """
    mock_session = AsyncMock()

    # Synchronous AsyncSession methods. Left as AsyncMock children they would
    # return coroutines nobody awaits: the call would look successful while
    # only emitting "coroutine was never awaited" warnings.
    mock_session.add = MagicMock()
    mock_session.add_all = MagicMock()
    mock_session.expunge = MagicMock()
    mock_session.expunge_all = MagicMock()
    mock_session.expire = MagicMock()
    mock_session.expire_all = MagicMock()

    # Coroutine methods; created explicitly so tests can set side effects
    # and assert awaits on them.
    mock_session.commit = AsyncMock()
    mock_session.rollback = AsyncMock()
    mock_session.refresh = AsyncMock()
    mock_session.delete = AsyncMock()
    mock_session.flush = AsyncMock()

    @asynccontextmanager
    async def factory():
        yield mock_session

    return factory, mock_session
|
||
|
||
|
||
def _make_mock_execute(result_rows):
    """Create an ``AsyncMock`` standing in for ``AsyncSession.execute``.

    Awaiting the returned mock yields a result object whose accessors all
    draw from *result_rows*:

    * ``result.scalars().all()`` -> the rows as a list
    * ``result.scalars().first()`` / ``.one_or_none()`` -> first row or ``None``
    * iterating ``result.scalars()`` -> the rows (works on every iteration)
    * ``result.scalar_one_or_none()`` / ``result.scalar()`` -> first row or ``None``

    The single-row accessors are simplified: unlike SQLAlchemy they do not
    raise when more than one row is present.

    Args:
        result_rows: Rows to return; any iterable, materialised once.
    """
    rows = list(result_rows)
    first = rows[0] if rows else None

    mock_scalars = MagicMock()
    mock_scalars.all.return_value = rows
    mock_scalars.first.return_value = first
    mock_scalars.one_or_none.return_value = first
    # MagicMock calls iter() on a list return_value, so each loop restarts
    # from the first row instead of exhausting a shared iterator.
    mock_scalars.__iter__.return_value = rows

    mock_result = MagicMock()
    mock_result.scalars.return_value = mock_scalars
    # Without these, the accessors would return a truthy MagicMock and
    # "not found" code paths would never be taken.
    mock_result.scalar_one_or_none.return_value = first
    mock_result.scalar.return_value = first

    return AsyncMock(return_value=mock_result)
|
||
|
||
|
||
class TestKBStoreCreate:
    """Tests for KBStore.create_kb."""

    @staticmethod
    def _emulate_db_defaults(mock_session):
        """Make *mock_session* fill in the values a real database would.

        * ``flush`` gives the first object passed to ``session.add`` (the KB
          row) the id ``"kb-001"`` if it has none yet. Code that reads
          ``kb.id`` right after flushing then sees a real value.
        * ``refresh`` sets id, status and timestamps on the refreshed object.
        """

        def _flush(*_args, **_kwargs):
            added = mock_session.add.call_args_list
            if added:
                kb_obj = added[0].args[0]
                if getattr(kb_obj, "id", None) is None:
                    kb_obj.id = "kb-001"

        async def _refresh(obj, *_args, **_kwargs):
            obj.id = "kb-001"
            obj.status = "active"
            obj.created_at = datetime.now(timezone.utc)
            obj.updated_at = datetime.now(timezone.utc)

        mock_session.flush.side_effect = _flush
        mock_session.refresh.side_effect = _refresh

    async def test_create_kb_returns_knowledge_base(self):
        """create_kb returns a KnowledgeBase domain model and commits once."""
        sf, mock_session = _make_mock_session_factory()
        self._emulate_db_defaults(mock_session)

        store = KBStore(sf)
        kb = await store.create_kb(name="test", owner="user1", description="desc")

        assert isinstance(kb, KnowledgeBase)
        assert kb.name == "test"
        assert kb.owner == "user1"
        assert kb.description == "desc"
        # Two rows are added in one transaction: the KB and its owner ACL.
        assert mock_session.add.call_count == 2
        mock_session.commit.assert_awaited_once()

    async def test_create_kb_creates_owner_acl(self):
        """create_kb adds an owner ACL entry in the same transaction."""
        sf, mock_session = _make_mock_session_factory()
        self._emulate_db_defaults(mock_session)

        store = KBStore(sf)
        await store.create_kb(name="test", owner="user1")

        # Check the count first so a missing ACL fails with a clear assertion
        # rather than an IndexError on call_args_list.
        assert mock_session.add.call_count == 2
        acl_obj = mock_session.add.call_args_list[1].args[0]
        assert acl_obj.user_id == "user1"
        assert acl_obj.role == "owner"
        mock_session.commit.assert_awaited_once()
|
||
|
||
|
||
class TestKBStoreQuery:
    """Tests for KBStore.get_kb and KBStore.list_kbs."""

    async def test_get_kb_returns_none_if_not_found(self):
        """get_kb returns None when no KB row matches."""
        sf, mock_session = _make_mock_session_factory()
        mock_session.execute = _make_mock_execute([])

        store = KBStore(sf)
        result = await store.get_kb("nonexistent")

        assert result is None

    async def test_get_kb_returns_knowledge_base(self):
        """get_kb maps the matching row to a KnowledgeBase domain model."""
        sf, mock_session = _make_mock_session_factory()
        mock_session.execute = _make_mock_execute([_make_kb_row()])

        store = KBStore(sf)
        result = await store.get_kb("kb-001")

        assert isinstance(result, KnowledgeBase)
        assert result.id == "kb-001"
        assert result.name == "test-kb"

    async def test_list_kbs_filters_by_owner(self):
        """list_kbs(owner=...) puts the owner into the query and maps every row."""
        sf, mock_session = _make_mock_session_factory()
        kb_rows = [_make_kb_row(id="kb-1"), _make_kb_row(id="kb-2", name="second")]
        mock_session.execute = _make_mock_execute(kb_rows)

        store = KBStore(sf)
        results = await store.list_kbs(owner="user1")

        assert [kb.id for kb in results] == ["kb-1", "kb-2"]
        assert results[1].name == "second"
        # The mocked execute returns the same rows whatever the query is, so
        # filtering can only be checked by inspecting the executed statement:
        # the owner value must be one of its bound parameters.
        # NOTE(review): assumes list_kbs passes a SQLAlchemy Core/ORM
        # statement to session.execute; adjust if it uses raw text() + params.
        bound_values = [
            value
            for awaited in mock_session.execute.await_args_list
            for value in awaited.args[0].compile().params.values()
        ]
        assert "user1" in bound_values

    async def test_list_kbs_returns_empty(self):
        """list_kbs returns an empty list when no rows exist."""
        sf, mock_session = _make_mock_session_factory()
        mock_session.execute = _make_mock_execute([])

        store = KBStore(sf)
        results = await store.list_kbs()

        assert results == []
|
||
|
||
|
||
class TestKBStoreDelete:
    """Tests for KBStore.delete_kb.

    Associated data is expected to be removed by the database's CASCADE
    rule. With a mocked session these tests can only check that the KB row
    itself is deleted and the transaction committed.
    """

    async def test_delete_kb_returns_true(self):
        """delete_kb deletes the found row, commits, and returns True."""
        sf, mock_session = _make_mock_session_factory()
        kb_row = _make_kb_row()
        mock_session.execute = _make_mock_execute([kb_row])

        store = KBStore(sf)
        result = await store.delete_kb("kb-001")

        assert result is True
        mock_session.delete.assert_awaited_once_with(kb_row)
        mock_session.commit.assert_awaited_once()

    async def test_delete_kb_returns_false_if_not_found(self):
        """delete_kb returns False and deletes nothing when the KB is missing."""
        sf, mock_session = _make_mock_session_factory()
        mock_session.execute = _make_mock_execute([])

        store = KBStore(sf)
        result = await store.delete_kb("nonexistent")

        assert result is False
        # The miss path must not hand a None/placeholder row to session.delete.
        mock_session.delete.assert_not_awaited()
|
||
|
||
|
||
class TestDocumentOperations:
|
||
"""文档 CRUD 测试。"""
|
||
|
||
async def test_add_document(self):
|
||
"""add_document 创建文档记录。"""
|
||
sf, mock_session = _make_mock_session_factory()
|
||
|
||
async def _refresh(obj):
|
||
obj.id = "doc-001"
|
||
obj.status = "pending"
|
||
obj.created_at = datetime.now(timezone.utc)
|
||
obj.updated_at = datetime.now(timezone.utc)
|
||
|
||
mock_session.refresh.side_effect = _refresh
|
||
|
||
store = KBStore(sf)
|
||
doc = await store.add_document("kb-001", "test.pdf", "pdf", 1024)
|
||
|
||
assert doc.filename == "test.pdf"
|
||
assert doc.kb_id == "kb-001"
|
||
mock_session.add.assert_called_once()
|
||
mock_session.commit.assert_awaited_once()
|
||
|
||
async def test_list_documents(self):
|
||
"""list_documents 返回指定 KB 的文档。"""
|
||
sf, mock_session = _make_mock_session_factory()
|
||
doc_rows = [_make_doc_row(id="d1"), _make_doc_row(id="d2", filename="second.pdf")]
|
||
mock_session.execute = _make_mock_execute(doc_rows)
|
||
|
||
store = KBStore(sf)
|
||
results = await store.list_documents("kb-001")
|
||
|
||
assert len(results) == 2
|
||
assert results[0].id == "d1"
|
||
|
||
async def test_update_document_status(self):
|
||
"""update_document_status 更新文档状态。"""
|
||
sf, mock_session = _make_mock_session_factory()
|
||
doc_row = _make_doc_row()
|
||
mock_session.execute = _make_mock_execute([doc_row])
|
||
|
||
store = KBStore(sf)
|
||
result = await store.update_document_status("doc-001", DocumentStatus.indexed)
|
||
|
||
assert result is not None
|
||
assert doc_row.status == "indexed"
|
||
mock_session.commit.assert_awaited_once()
|
||
|
||
async def test_update_document_status_not_found(self):
|
||
"""update_document_status 文档不存在返回 None。"""
|
||
sf, mock_session = _make_mock_session_factory()
|
||
mock_session.execute = _make_mock_execute([])
|
||
|
||
store = KBStore(sf)
|
||
result = await store.update_document_status("nonexistent", DocumentStatus.failed)
|
||
|
||
assert result is None
|
||
|
||
async def test_delete_document(self):
|
||
"""delete_document 删除成功。"""
|
||
sf, mock_session = _make_mock_session_factory()
|
||
doc_row = _make_doc_row()
|
||
mock_session.execute = _make_mock_execute([doc_row])
|
||
|
||
store = KBStore(sf)
|
||
result = await store.delete_document("doc-001")
|
||
|
||
assert result is True
|
||
mock_session.delete.assert_awaited_once_with(doc_row)
|