fischer-agentkit/tests/unit/rag_platform/test_store.py

289 lines
9.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""U2 测试 — KB 持久化存储。
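Test scenarios:
1. KB metadata is written to PG and is still present after a restart (verified through the mocked commit).
2. The owner user can query their own KBs.
3. delete_kb removes the associated data via CASCADE.
4. Document CRUD operations.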
"""
from __future__ import annotations
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock
from agentkit.rag_platform.models import (
DocumentStatus,
KnowledgeBase,
)
from agentkit.rag_platform.store import KBStore
def _make_kb_row(**overrides):
    """Build a MagicMock that stands in for a KBModel row.

    Each default attribute below can be replaced through ``overrides``;
    keys that are not listed here are set as additional attributes.
    """
    attrs = {
        "id": "kb-001",
        "name": "test-kb",
        "description": "",
        "owner": "user1",
        "status": "active",
        "default_query_mode": "blend",
        "default_hit_processing": "model_opt",
        "caching_disabled": False,
        "created_at": datetime.now(timezone.utc),
        "updated_at": datetime.now(timezone.utc),
        **overrides,
    }
    # configure_mock is used instead of MagicMock(**attrs) so that "name"
    # ends up as a regular attribute rather than the mock's display name.
    row = MagicMock()
    row.configure_mock(**attrs)
    return row
def _make_doc_row(**overrides):
    """Build a MagicMock that stands in for a DocumentModel row.

    Each default attribute below can be replaced through ``overrides``;
    keys that are not listed here are set as additional attributes.
    """
    attrs = {
        "id": "doc-001",
        "kb_id": "kb-001",
        "filename": "test.pdf",
        "file_type": "pdf",
        "file_size": 1024,
        "status": "pending",
        "error_message": None,
        "created_at": datetime.now(timezone.utc),
        "updated_at": datetime.now(timezone.utc),
        **overrides,
    }
    row = MagicMock()
    row.configure_mock(**attrs)
    return row
def _make_mock_session_factory():
    """Create a mock session factory and return ``(factory, mock_session)``.

    ``factory`` is an async context manager callable that yields the same
    ``mock_session`` on every use, so a test can inspect the calls made on it.
    """
    session = AsyncMock()
    # add() is stubbed as a plain synchronous call; the rest are awaitable.
    session.add = MagicMock()
    for awaitable_name in ("commit", "rollback", "refresh", "delete", "flush"):
        setattr(session, awaitable_name, AsyncMock())

    @asynccontextmanager
    async def factory():
        yield session

    return factory, session
def _make_mock_execute(result_rows):
    """Create a mock ``execute`` whose awaited result wraps ``result_rows``.

    On the awaited result, ``scalars().all()`` returns ``result_rows`` and
    ``scalars().first()`` returns its first element, or None when it is empty.
    """
    if result_rows:
        first_row = result_rows[0]
    else:
        first_row = None

    result = MagicMock()
    scalars = result.scalars.return_value
    scalars.all.return_value = result_rows
    scalars.first.return_value = first_row
    return AsyncMock(return_value=result)
class TestKBStoreCreate:
    """Tests for KB creation."""

    @staticmethod
    async def _fake_refresh(obj):
        """Stand-in for ``session.refresh``: set the attributes these tests need on ``obj``."""
        obj.id = "kb-001"
        obj.status = "active"
        obj.created_at = datetime.now(timezone.utc)
        obj.updated_at = datetime.now(timezone.utc)

    async def test_create_kb_returns_knowledge_base(self):
        """create_kb returns a KnowledgeBase domain model."""
        sf, mock_session = _make_mock_session_factory()
        mock_session.refresh.side_effect = self._fake_refresh

        store = KBStore(sf)
        kb = await store.create_kb(name="test", owner="user1", description="desc")

        assert isinstance(kb, KnowledgeBase)
        assert kb.name == "test"
        assert kb.owner == "user1"
        assert kb.description == "desc"
        # add() is expected twice: once for the KB and once for the ACL entry.
        assert mock_session.add.call_count == 2
        mock_session.commit.assert_awaited_once()

    async def test_create_kb_creates_owner_acl(self):
        """create_kb creates the owner ACL entry alongside the KB."""
        sf, mock_session = _make_mock_session_factory()
        mock_session.refresh.side_effect = self._fake_refresh

        store = KBStore(sf)
        await store.create_kb(name="test", owner="user1")

        # Check the count first so a missing ACL add fails as an assertion,
        # not as an IndexError on the lookup below.
        assert mock_session.add.call_count == 2
        # The second add() call carries the ACL entry.
        acl_obj = mock_session.add.call_args_list[1].args[0]
        assert acl_obj.user_id == "user1"
        assert acl_obj.role == "owner"
class TestKBStoreQuery:
    """Tests for KB queries."""

    async def test_get_kb_returns_none_if_not_found(self):
        """get_kb returns None when the KB does not exist."""
        session_factory, session = _make_mock_session_factory()
        session.execute = _make_mock_execute([])

        found = await KBStore(session_factory).get_kb("nonexistent")

        assert found is None

    async def test_get_kb_returns_knowledge_base(self):
        """get_kb returns a KnowledgeBase domain model."""
        session_factory, session = _make_mock_session_factory()
        session.execute = _make_mock_execute([_make_kb_row()])

        found = await KBStore(session_factory).get_kb("kb-001")

        assert found is not None
        assert found.id == "kb-001"
        assert found.name == "test-kb"

    async def test_list_kbs_filters_by_owner(self):
        """list_kbs called with an owner returns the rows produced by the query."""
        session_factory, session = _make_mock_session_factory()
        first_row = _make_kb_row(id="kb-1")
        second_row = _make_kb_row(id="kb-2", name="second")
        session.execute = _make_mock_execute([first_row, second_row])

        listed = await KBStore(session_factory).list_kbs(owner="user1")

        assert len(listed) == 2
        assert listed[0].id == "kb-1"

    async def test_list_kbs_returns_empty(self):
        """list_kbs returns an empty list when there are no rows."""
        session_factory, session = _make_mock_session_factory()
        session.execute = _make_mock_execute([])

        listed = await KBStore(session_factory).list_kbs()

        assert listed == []
class TestKBStoreDelete:
    """Tests for KB deletion."""

    async def test_delete_kb_returns_true(self):
        """delete_kb returns True after deleting an existing KB."""
        sf, mock_session = _make_mock_session_factory()
        kb_row = _make_kb_row()
        mock_session.execute = _make_mock_execute([kb_row])

        store = KBStore(sf)
        result = await store.delete_kb("kb-001")

        assert result is True
        mock_session.delete.assert_awaited_once_with(kb_row)
        mock_session.commit.assert_awaited_once()

    async def test_delete_kb_returns_false_if_not_found(self):
        """delete_kb returns False when the KB does not exist."""
        sf, mock_session = _make_mock_session_factory()
        mock_session.execute = _make_mock_execute([])

        store = KBStore(sf)
        result = await store.delete_kb("nonexistent")

        assert result is False
        # Without this check the test would still pass if the store issued a
        # delete for a row that was never found.
        mock_session.delete.assert_not_awaited()
class TestDocumentOperations:
    """Tests for document CRUD operations."""

    async def test_add_document(self):
        """add_document creates a document record."""
        session_factory, session = _make_mock_session_factory()

        async def fake_refresh(obj):
            # Set the attributes this test needs to be present after refresh.
            obj.id = "doc-001"
            obj.status = "pending"
            obj.created_at = datetime.now(timezone.utc)
            obj.updated_at = datetime.now(timezone.utc)

        session.refresh.side_effect = fake_refresh

        created = await KBStore(session_factory).add_document("kb-001", "test.pdf", "pdf", 1024)

        assert created.filename == "test.pdf"
        assert created.kb_id == "kb-001"
        session.add.assert_called_once()
        session.commit.assert_awaited_once()

    async def test_list_documents(self):
        """list_documents returns the documents of the given KB."""
        session_factory, session = _make_mock_session_factory()
        first_doc = _make_doc_row(id="d1")
        second_doc = _make_doc_row(id="d2", filename="second.pdf")
        session.execute = _make_mock_execute([first_doc, second_doc])

        listed = await KBStore(session_factory).list_documents("kb-001")

        assert len(listed) == 2
        assert listed[0].id == "d1"

    async def test_update_document_status(self):
        """update_document_status updates the document's status."""
        session_factory, session = _make_mock_session_factory()
        stored_doc = _make_doc_row()
        session.execute = _make_mock_execute([stored_doc])

        updated = await KBStore(session_factory).update_document_status(
            "doc-001", DocumentStatus.indexed
        )

        assert updated is not None
        assert stored_doc.status == "indexed"
        session.commit.assert_awaited_once()

    async def test_update_document_status_not_found(self):
        """update_document_status returns None when the document does not exist."""
        session_factory, session = _make_mock_session_factory()
        session.execute = _make_mock_execute([])

        updated = await KBStore(session_factory).update_document_status(
            "nonexistent", DocumentStatus.failed
        )

        assert updated is None

    async def test_delete_document(self):
        """delete_document succeeds for an existing document."""
        session_factory, session = _make_mock_session_factory()
        stored_doc = _make_doc_row()
        session.execute = _make_mock_execute([stored_doc])

        deleted = await KBStore(session_factory).delete_document("doc-001")

        assert deleted is True
        session.delete.assert_awaited_once_with(stored_doc)