feat(rag_platform): U2 — KB persistence + per-KB ACL

Add PostgreSQL-backed KB store replacing in-memory KnowledgeSourceStore:
- models.py: ORM models (KBModel, DocumentModel, KBAclModel) using
  SQLAlchemy 2 DeclarativeBase + Mapped style
- store.py: KBStore with async CRUD for KBs and documents,
  create_kb creates owner ACL in same transaction
- acl.py: filter_kb_by_user_acl(), grant_access(), revoke_access(),
  list_acl() — follows filter_kb_sources_by_department pattern

Schema: rag_platform_kbs, rag_platform_documents, rag_platform_kb_acl
with FK CASCADE on kb_id. UniqueConstraint on (kb_id, user_id).

Tests: 23 unit tests covering KB CRUD, document operations, ACL
filtering, grant/revoke. All 37 rag_platform tests pass.
This commit is contained in:
chiguyong 2026-06-25 11:01:04 +08:00
parent 27d0184392
commit c1a21f57a1
5 changed files with 807 additions and 0 deletions

View File

@ -0,0 +1,109 @@
"""Per-KB 访问控制 — ACL 过滤逻辑。
遵循 filter_kb_sources_by_department 模式实现 filter_kb_by_user_acl()
ACL 变更必须触发检索缓存失效KTD5
"""
from __future__ import annotations
import logging
from sqlalchemy import select
from agentkit.rag_platform.models import KBAclModel
logger = logging.getLogger(__name__)
async def filter_kb_by_user_acl(
session_factory,
user_id: str,
all_kb_ids: list[str],
) -> list[str]:
"""返回用户有权访问的 KB ID 子集。
权限规则用户是 KB owner viewer 时有访问权限
filter_kb_sources_by_department 并行运行 部门过滤和 ACL 过滤取交集
Args:
session_factory: SQLAlchemy async session factory
user_id: 用户 ID
all_kb_ids: 所有已知 KB ID通常来自 KBStore.list_kbs()
Returns:
排序后的授权 KB ID 列表
"""
if not all_kb_ids:
return []
async with session_factory() as db:
stmt = (
select(KBAclModel.kb_id)
.where(KBAclModel.user_id == user_id)
.where(KBAclModel.kb_id.in_(all_kb_ids))
.distinct()
)
result = await db.execute(stmt)
authorized = {row[0] for row in result.all()}
visible = authorized & set(all_kb_ids)
return sorted(visible)
async def grant_access(
session_factory,
kb_id: str,
user_id: str,
role: str = "viewer",
) -> bool:
"""授予用户对 KB 的访问权限。
Args:
role: "owner" "viewer"
Returns:
True 如果授权成功False 如果已存在相同条目
"""
from sqlalchemy import insert
async with session_factory() as db:
stmt = insert(KBAclModel).values(kb_id=kb_id, user_id=user_id, role=role)
try:
await db.execute(stmt)
await db.commit()
logger.info("Granted %s access to kb=%s for user=%s", role, kb_id, user_id)
return True
except Exception: # UniqueConstraint violation — 已存在
await db.rollback()
return False
async def revoke_access(session_factory, kb_id: str, user_id: str) -> bool:
"""撤销用户对 KB 的访问权限。
不会撤销 owner 的权限owner 始终有访问权限
ACL 变更后必须触发检索缓存失效
"""
from sqlalchemy import delete
async with session_factory() as db:
stmt = (
delete(KBAclModel)
.where(KBAclModel.kb_id == kb_id)
.where(KBAclModel.user_id == user_id)
.where(KBAclModel.role != "owner") # 不撤销 owner
)
result = await db.execute(stmt)
await db.commit()
return result.rowcount > 0 # type: ignore[union-attr]
async def list_acl(session_factory, kb_id: str) -> list[dict]:
"""列出 KB 的所有 ACL 条目。"""
async with session_factory() as db:
stmt = select(KBAclModel).where(KBAclModel.kb_id == kb_id)
result = await db.execute(stmt)
return [
{"user_id": row.user_id, "role": row.role, "created_at": row.created_at}
for row in result.scalars().all()
]

View File

@ -104,3 +104,71 @@ class QueryResult(BaseModel):
metadata: dict = Field(default_factory=dict) metadata: dict = Field(default_factory=dict)
document_id: str document_id: str
kb_id: str kb_id: str
# ---------------------------------------------------------------------------
# ORM Models (SQLAlchemy 2 DeclarativeBase + Mapped)
# ---------------------------------------------------------------------------
from sqlalchemy import DateTime, ForeignKey, String, Text, UniqueConstraint # noqa: E402
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column # noqa: E402
class ORMBase(DeclarativeBase):
"""rag_platform ORM 基类 — 与 memory/models.py 的 declarative_base() 独立。"""
pass
class KBModel(ORMBase):
"""知识库 ORM 模型。"""
__tablename__ = "rag_platform_kbs"
id: Mapped[str] = mapped_column(String, primary_key=True, default=_uuid)
name: Mapped[str] = mapped_column(String(255))
description: Mapped[str] = mapped_column(Text, default="")
owner: Mapped[str] = mapped_column(String(255), index=True)
status: Mapped[str] = mapped_column(String(32), default="active", index=True)
default_query_mode: Mapped[str] = mapped_column(String(32), default="blend")
default_hit_processing: Mapped[str] = mapped_column(String(32), default="model_opt")
caching_disabled: Mapped[bool] = mapped_column(default=False)
created_at: Mapped[datetime] = mapped_column(DateTime, default=_utcnow)
updated_at: Mapped[datetime] = mapped_column(DateTime, default=_utcnow)
class DocumentModel(ORMBase):
"""知识库文档 ORM 模型。"""
__tablename__ = "rag_platform_documents"
id: Mapped[str] = mapped_column(String, primary_key=True, default=_uuid)
kb_id: Mapped[str] = mapped_column(
String, ForeignKey("rag_platform_kbs.id", ondelete="CASCADE"), index=True
)
filename: Mapped[str] = mapped_column(String(512))
file_type: Mapped[str] = mapped_column(String(64))
file_size: Mapped[int] = mapped_column(default=0)
status: Mapped[str] = mapped_column(String(32), default="pending", index=True)
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=_utcnow)
updated_at: Mapped[datetime] = mapped_column(DateTime, default=_utcnow)
class KBAclModel(ORMBase):
"""Per-KB 访问控制 ORM 模型。
kb_id FK rag_platform_kbs.id ON DELETE CASCADEKTD5 不变量
ACL 条目与 KB 元数据共享事务边界同一 PG DB
"""
__tablename__ = "rag_platform_kb_acl"
__table_args__ = (UniqueConstraint("kb_id", "user_id", name="uq_kb_acl_kb_user"),)
id: Mapped[str] = mapped_column(String, primary_key=True, default=_uuid)
kb_id: Mapped[str] = mapped_column(
String, ForeignKey("rag_platform_kbs.id", ondelete="CASCADE"), index=True
)
user_id: Mapped[str] = mapped_column(String(255), index=True)
role: Mapped[str] = mapped_column(String(32), default="viewer") # owner | viewer
created_at: Mapped[datetime] = mapped_column(DateTime, default=_utcnow)

View File

@ -0,0 +1,172 @@
"""KB 持久化存储 — PostgreSQL 后端,替换内存 KnowledgeSourceStore。
遵循 memory/episodic.py async session 模式
ACL 条目与 KB 元数据共享事务边界同一 PG DB
"""
from __future__ import annotations
import logging
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from agentkit.rag_platform.models import (
Document,
DocumentStatus,
KBStatus,
KnowledgeBase,
KBAclModel,
KBModel,
DocumentModel,
ORMBase,
)
logger = logging.getLogger(__name__)
def create_session_factory(database_url: str):
"""创建 async session factory遵循 memory/models.py 模式)。"""
engine = create_async_engine(database_url, echo=False)
return sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
async def ensure_tables(database_url: str) -> None:
"""创建 rag_platform 表幂等IF NOT EXISTS"""
engine = create_async_engine(database_url, echo=False)
async with engine.begin() as conn:
await conn.run_sync(ORMBase.metadata.create_all)
await engine.dispose()
logger.info("rag_platform tables ensured")
class KBStore:
"""KB/Document 持久化存储PostgreSQL async
替换内存 KnowledgeSourceStore KB 元数据重启不丢失
"""
def __init__(self, session_factory) -> None:
self._sf = session_factory
async def create_kb(
self,
name: str,
owner: str,
description: str = "",
default_query_mode: str = "blend",
default_hit_processing: str = "model_opt",
caching_disabled: bool = False,
) -> KnowledgeBase:
"""创建知识库 + owner ACL 条目(同一事务)。"""
async with self._sf() as db:
kb = KBModel(
name=name,
owner=owner,
description=description,
default_query_mode=default_query_mode,
default_hit_processing=default_hit_processing,
caching_disabled=caching_disabled,
)
db.add(kb)
await db.flush() # 获取 kb.id
acl = KBAclModel(kb_id=kb.id, user_id=owner, role="owner")
db.add(acl)
await db.commit()
await db.refresh(kb)
return KnowledgeBase.model_validate(kb)
async def get_kb(self, kb_id: str) -> KnowledgeBase | None:
async with self._sf() as db:
stmt = select(KBModel).where(KBModel.id == kb_id)
result = await db.execute(stmt)
row = result.scalars().first()
return KnowledgeBase.model_validate(row) if row else None
async def list_kbs(self, owner: str | None = None) -> list[KnowledgeBase]:
async with self._sf() as db:
stmt = select(KBModel).where(KBModel.status == KBStatus.active.value)
if owner:
stmt = stmt.where(KBModel.owner == owner)
stmt = stmt.order_by(KBModel.created_at.desc())
result = await db.execute(stmt)
return [KnowledgeBase.model_validate(r) for r in result.scalars().all()]
async def delete_kb(self, kb_id: str) -> bool:
"""删除 KB — CASCADE 自动删除关联 documents 和 ACL 条目。"""
async with self._sf() as db:
stmt = select(KBModel).where(KBModel.id == kb_id)
result = await db.execute(stmt)
kb = result.scalars().first()
if kb is None:
return False
await db.delete(kb)
await db.commit()
return True
async def add_document(
self,
kb_id: str,
filename: str,
file_type: str,
file_size: int,
) -> Document:
async with self._sf() as db:
doc = DocumentModel(
kb_id=kb_id,
filename=filename,
file_type=file_type,
file_size=file_size,
)
db.add(doc)
await db.commit()
await db.refresh(doc)
return Document.model_validate(doc)
async def get_document(self, document_id: str) -> Document | None:
async with self._sf() as db:
stmt = select(DocumentModel).where(DocumentModel.id == document_id)
result = await db.execute(stmt)
row = result.scalars().first()
return Document.model_validate(row) if row else None
async def list_documents(self, kb_id: str) -> list[Document]:
async with self._sf() as db:
stmt = (
select(DocumentModel)
.where(DocumentModel.kb_id == kb_id)
.order_by(DocumentModel.created_at.desc())
)
result = await db.execute(stmt)
return [Document.model_validate(r) for r in result.scalars().all()]
async def update_document_status(
self,
document_id: str,
status: DocumentStatus,
error_message: str | None = None,
) -> Document | None:
async with self._sf() as db:
stmt = select(DocumentModel).where(DocumentModel.id == document_id)
result = await db.execute(stmt)
doc = result.scalars().first()
if doc is None:
return None
doc.status = status.value
doc.error_message = error_message
await db.commit()
await db.refresh(doc)
return Document.model_validate(doc)
async def delete_document(self, document_id: str) -> bool:
async with self._sf() as db:
stmt = select(DocumentModel).where(DocumentModel.id == document_id)
result = await db.execute(stmt)
doc = result.scalars().first()
if doc is None:
return False
await db.delete(doc)
await db.commit()
return True

View File

@ -0,0 +1,170 @@
"""U2 测试 — Per-KB ACL 逻辑。
测试场景
1. owner 用户可查询自己的 KB
2. authorized 用户查询 KB 被拒绝
3. Agent 检索时 ACL 过滤生效仅返回授权 KB 的结果
4. grant/revoke 操作
"""
from __future__ import annotations
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock
from agentkit.rag_platform.acl import (
filter_kb_by_user_acl,
grant_access,
list_acl,
revoke_access,
)
def _make_mock_session_factory(execute_result=None):
"""创建 mock session factory。"""
mock_session = AsyncMock()
mock_session.commit = AsyncMock()
mock_session.rollback = AsyncMock()
if execute_result is not None:
mock_session.execute = execute_result
@asynccontextmanager
async def factory():
yield mock_session
return factory, mock_session
def _make_execute_returning_rows(rows):
"""创建 mock execute 返回指定行(用于 select 查询)。"""
mock_result = MagicMock()
mock_result.all.return_value = rows
return AsyncMock(return_value=mock_result)
def _make_execute_returning_scalars(rows):
"""创建 mock execute 返回 scalars用于 ORM select"""
mock_result = MagicMock()
mock_scalars = MagicMock()
mock_scalars.all.return_value = rows
mock_result.scalars.return_value = mock_scalars
return AsyncMock(return_value=mock_result)
class TestFilterKbByUserAcl:
"""ACL 过滤测试。"""
async def test_empty_kb_ids_returns_empty(self):
"""all_kb_ids 为空时返回空列表。"""
sf, _ = _make_mock_session_factory()
result = await filter_kb_by_user_acl(sf, "user1", [])
assert result == []
async def test_authorized_kb_returned(self):
"""用户有权限的 KB 被返回。"""
mock_execute = _make_execute_returning_rows([("kb-1",), ("kb-2",)])
sf, _ = _make_mock_session_factory(mock_execute)
result = await filter_kb_by_user_acl(sf, "user1", ["kb-1", "kb-2", "kb-3"])
assert "kb-1" in result
assert "kb-2" in result
assert "kb-3" not in result # 未授权
async def test_unauthorized_kb_filtered_out(self):
"""非 authorized 用户的 KB 被过滤掉。"""
mock_execute = _make_execute_returning_rows([]) # 无 ACL 条目
sf, _ = _make_mock_session_factory(mock_execute)
result = await filter_kb_by_user_acl(sf, "stranger", ["kb-1", "kb-2"])
assert result == []
async def test_result_is_sorted(self):
"""返回结果按 ID 排序。"""
mock_execute = _make_execute_returning_rows([("kb-3",), ("kb-1",)])
sf, _ = _make_mock_session_factory(mock_execute)
result = await filter_kb_by_user_acl(sf, "user1", ["kb-1", "kb-2", "kb-3"])
assert result == ["kb-1", "kb-3"]
class TestGrantAccess:
"""授权测试。"""
async def test_grant_access_success(self):
"""grant_access 成功授权。"""
mock_execute = AsyncMock() # insert 不返回结果
sf, mock_session = _make_mock_session_factory(mock_execute)
result = await grant_access(sf, "kb-1", "user2", "viewer")
assert result is True
mock_session.commit.assert_awaited_once()
async def test_grant_access_duplicate_returns_false(self):
"""grant_access 已存在条目返回 False。"""
mock_execute = AsyncMock(side_effect=Exception("unique constraint"))
sf, mock_session = _make_mock_session_factory(mock_execute)
result = await grant_access(sf, "kb-1", "user2", "viewer")
assert result is False
mock_session.rollback.assert_awaited_once()
class TestRevokeAccess:
"""撤销授权测试。"""
async def test_revoke_access_success(self):
"""revoke_access 成功撤销。"""
mock_result = MagicMock()
mock_result.rowcount = 1
mock_execute = AsyncMock(return_value=mock_result)
sf, mock_session = _make_mock_session_factory(mock_execute)
result = await revoke_access(sf, "kb-1", "user2")
assert result is True
mock_session.commit.assert_awaited_once()
async def test_revoke_access_not_found_returns_false(self):
"""revoke_access 条目不存在返回 False。"""
mock_result = MagicMock()
mock_result.rowcount = 0
mock_execute = AsyncMock(return_value=mock_result)
sf, mock_session = _make_mock_session_factory(mock_execute)
result = await revoke_access(sf, "kb-1", "stranger")
assert result is False
class TestListAcl:
"""ACL 列表测试。"""
async def test_list_acl_returns_entries(self):
"""list_acl 返回 ACL 条目列表。"""
acl_row1 = MagicMock(user_id="user1", role="owner", created_at=None)
acl_row2 = MagicMock(user_id="user2", role="viewer", created_at=None)
mock_execute = _make_execute_returning_scalars([acl_row1, acl_row2])
sf, _ = _make_mock_session_factory(mock_execute)
result = await list_acl(sf, "kb-1")
assert len(result) == 2
assert result[0]["user_id"] == "user1"
assert result[0]["role"] == "owner"
assert result[1]["user_id"] == "user2"
assert result[1]["role"] == "viewer"
async def test_list_acl_empty(self):
"""list_acl 无条目返回空列表。"""
mock_execute = _make_execute_returning_scalars([])
sf, _ = _make_mock_session_factory(mock_execute)
result = await list_acl(sf, "kb-1")
assert result == []

View File

@ -0,0 +1,288 @@
"""U2 测试 — KB 持久化存储。
测试场景
1. KB 元数据写入 PG重启后仍存在mock 验证 commit
2. owner 用户可查询自己的 KB
3. delete_kb CASCADE 删除关联数据
4. 文档 CRUD 操作
"""
from __future__ import annotations
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock
from agentkit.rag_platform.models import (
DocumentStatus,
KnowledgeBase,
)
from agentkit.rag_platform.store import KBStore
def _make_kb_row(**overrides):
"""创建模拟 KBModel 行。"""
defaults = {
"id": "kb-001",
"name": "test-kb",
"description": "",
"owner": "user1",
"status": "active",
"default_query_mode": "blend",
"default_hit_processing": "model_opt",
"caching_disabled": False,
"created_at": datetime.now(timezone.utc),
"updated_at": datetime.now(timezone.utc),
}
defaults.update(overrides)
mock = MagicMock()
mock.configure_mock(**defaults)
return mock
def _make_doc_row(**overrides):
"""创建模拟 DocumentModel 行。"""
defaults = {
"id": "doc-001",
"kb_id": "kb-001",
"filename": "test.pdf",
"file_type": "pdf",
"file_size": 1024,
"status": "pending",
"error_message": None,
"created_at": datetime.now(timezone.utc),
"updated_at": datetime.now(timezone.utc),
}
defaults.update(overrides)
mock = MagicMock()
mock.configure_mock(**defaults)
return mock
def _make_mock_session_factory():
"""创建 mock session factory返回 (factory, mock_session)。"""
mock_session = AsyncMock()
mock_session.add = MagicMock()
mock_session.commit = AsyncMock()
mock_session.rollback = AsyncMock()
mock_session.refresh = AsyncMock()
mock_session.delete = AsyncMock()
mock_session.flush = AsyncMock()
@asynccontextmanager
async def factory():
yield mock_session
return factory, mock_session
def _make_mock_execute(result_rows):
"""创建 mock execute返回 scalars().all() = result_rows。"""
mock_result = MagicMock()
mock_scalars = MagicMock()
mock_scalars.all.return_value = result_rows
mock_scalars.first.return_value = result_rows[0] if result_rows else None
mock_result.scalars.return_value = mock_scalars
return AsyncMock(return_value=mock_result)
class TestKBStoreCreate:
"""KB 创建测试。"""
async def test_create_kb_returns_knowledge_base(self):
"""create_kb 返回 KnowledgeBase 领域模型。"""
sf, mock_session = _make_mock_session_factory()
# flush 后 kb.id 可用
def _flush_side_effect():
mock_session._kb_id = "kb-001"
mock_session.flush.side_effect = _flush_side_effect
# refresh 后设置所有必需属性
async def _refresh(obj):
obj.id = "kb-001"
obj.status = "active"
obj.created_at = datetime.now(timezone.utc)
obj.updated_at = datetime.now(timezone.utc)
mock_session.refresh.side_effect = _refresh
store = KBStore(sf)
kb = await store.create_kb(name="test", owner="user1", description="desc")
assert isinstance(kb, KnowledgeBase)
assert kb.name == "test"
assert kb.owner == "user1"
assert kb.description == "desc"
# 验证 add 被调用KB + ACL
assert mock_session.add.call_count == 2
mock_session.commit.assert_awaited_once()
async def test_create_kb_creates_owner_acl(self):
"""create_kb 在同一事务中创建 owner ACL 条目。"""
sf, mock_session = _make_mock_session_factory()
async def _refresh(obj):
obj.id = "kb-001"
obj.status = "active"
obj.created_at = datetime.now(timezone.utc)
obj.updated_at = datetime.now(timezone.utc)
mock_session.refresh.side_effect = _refresh
store = KBStore(sf)
await store.create_kb(name="test", owner="user1")
# 验证第二个 add 调用是 ACL 条目
second_add_call = mock_session.add.call_args_list[1]
acl_obj = second_add_call.args[0]
assert acl_obj.user_id == "user1"
assert acl_obj.role == "owner"
class TestKBStoreQuery:
"""KB 查询测试。"""
async def test_get_kb_returns_none_if_not_found(self):
"""get_kb 返回 None 当 KB 不存在。"""
sf, mock_session = _make_mock_session_factory()
mock_session.execute = _make_mock_execute([])
store = KBStore(sf)
result = await store.get_kb("nonexistent")
assert result is None
async def test_get_kb_returns_knowledge_base(self):
"""get_kb 返回 KnowledgeBase 领域模型。"""
sf, mock_session = _make_mock_session_factory()
kb_row = _make_kb_row()
mock_session.execute = _make_mock_execute([kb_row])
store = KBStore(sf)
result = await store.get_kb("kb-001")
assert result is not None
assert result.id == "kb-001"
assert result.name == "test-kb"
async def test_list_kbs_filters_by_owner(self):
"""list_kbs 按 owner 过滤。"""
sf, mock_session = _make_mock_session_factory()
kb_rows = [_make_kb_row(id="kb-1"), _make_kb_row(id="kb-2", name="second")]
mock_session.execute = _make_mock_execute(kb_rows)
store = KBStore(sf)
results = await store.list_kbs(owner="user1")
assert len(results) == 2
assert results[0].id == "kb-1"
async def test_list_kbs_returns_empty(self):
"""list_kbs 空列表正常返回。"""
sf, mock_session = _make_mock_session_factory()
mock_session.execute = _make_mock_execute([])
store = KBStore(sf)
results = await store.list_kbs()
assert results == []
class TestKBStoreDelete:
"""KB 删除测试。"""
async def test_delete_kb_returns_true(self):
"""delete_kb 删除成功返回 True。"""
sf, mock_session = _make_mock_session_factory()
kb_row = _make_kb_row()
mock_session.execute = _make_mock_execute([kb_row])
store = KBStore(sf)
result = await store.delete_kb("kb-001")
assert result is True
mock_session.delete.assert_awaited_once_with(kb_row)
mock_session.commit.assert_awaited_once()
async def test_delete_kb_returns_false_if_not_found(self):
"""delete_kb KB 不存在返回 False。"""
sf, mock_session = _make_mock_session_factory()
mock_session.execute = _make_mock_execute([])
store = KBStore(sf)
result = await store.delete_kb("nonexistent")
assert result is False
class TestDocumentOperations:
"""文档 CRUD 测试。"""
async def test_add_document(self):
"""add_document 创建文档记录。"""
sf, mock_session = _make_mock_session_factory()
async def _refresh(obj):
obj.id = "doc-001"
obj.status = "pending"
obj.created_at = datetime.now(timezone.utc)
obj.updated_at = datetime.now(timezone.utc)
mock_session.refresh.side_effect = _refresh
store = KBStore(sf)
doc = await store.add_document("kb-001", "test.pdf", "pdf", 1024)
assert doc.filename == "test.pdf"
assert doc.kb_id == "kb-001"
mock_session.add.assert_called_once()
mock_session.commit.assert_awaited_once()
async def test_list_documents(self):
"""list_documents 返回指定 KB 的文档。"""
sf, mock_session = _make_mock_session_factory()
doc_rows = [_make_doc_row(id="d1"), _make_doc_row(id="d2", filename="second.pdf")]
mock_session.execute = _make_mock_execute(doc_rows)
store = KBStore(sf)
results = await store.list_documents("kb-001")
assert len(results) == 2
assert results[0].id == "d1"
async def test_update_document_status(self):
"""update_document_status 更新文档状态。"""
sf, mock_session = _make_mock_session_factory()
doc_row = _make_doc_row()
mock_session.execute = _make_mock_execute([doc_row])
store = KBStore(sf)
result = await store.update_document_status("doc-001", DocumentStatus.indexed)
assert result is not None
assert doc_row.status == "indexed"
mock_session.commit.assert_awaited_once()
async def test_update_document_status_not_found(self):
"""update_document_status 文档不存在返回 None。"""
sf, mock_session = _make_mock_session_factory()
mock_session.execute = _make_mock_execute([])
store = KBStore(sf)
result = await store.update_document_status("nonexistent", DocumentStatus.failed)
assert result is None
async def test_delete_document(self):
"""delete_document 删除成功。"""
sf, mock_session = _make_mock_session_factory()
doc_row = _make_doc_row()
mock_session.execute = _make_mock_execute([doc_row])
store = KBStore(sf)
result = await store.delete_document("doc-001")
assert result is True
mock_session.delete.assert_awaited_once_with(doc_row)