feat(rag_platform): U2 — KB persistence + per-KB ACL
Add PostgreSQL-backed KB store replacing in-memory KnowledgeSourceStore: - models.py: ORM models (KBModel, DocumentModel, KBAclModel) using SQLAlchemy 2 DeclarativeBase + Mapped style - store.py: KBStore with async CRUD for KBs and documents, create_kb creates owner ACL in same transaction - acl.py: filter_kb_by_user_acl(), grant_access(), revoke_access(), list_acl() — follows filter_kb_sources_by_department pattern Schema: rag_platform_kbs, rag_platform_documents, rag_platform_kb_acl with FK CASCADE on kb_id. UniqueConstraint on (kb_id, user_id). Tests: 23 unit tests covering KB CRUD, document operations, ACL filtering, grant/revoke. All 37 rag_platform tests pass.
This commit is contained in:
parent
27d0184392
commit
c1a21f57a1
|
|
@ -0,0 +1,109 @@
|
|||
"""Per-KB 访问控制 — ACL 过滤逻辑。
|
||||
|
||||
遵循 filter_kb_sources_by_department 模式,实现 filter_kb_by_user_acl()。
|
||||
ACL 变更必须触发检索缓存失效(KTD5)。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from agentkit.rag_platform.models import KBAclModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def filter_kb_by_user_acl(
    session_factory,
    user_id: str,
    all_kb_ids: list[str],
) -> list[str]:
    """Return the subset of ``all_kb_ids`` the user holds an ACL entry for.

    Any ACL row for ``(kb_id, user_id)`` grants visibility; the query does not
    look at the role column, so both "owner" and "viewer" entries count.
    Combining the result with department-based filtering is left to the
    caller; this function only applies the ACL.

    Args:
        session_factory: Zero-argument callable returning an async context
            manager that yields the DB session.
        user_id: ID of the user whose access is checked.
        all_kb_ids: Candidate KB IDs to filter.

    Returns:
        The authorized KB IDs, sorted ascending. Empty when ``all_kb_ids`` is
        empty (no query is issued in that case).
    """
    if not all_kb_ids:
        return []

    acl_query = (
        select(KBAclModel.kb_id)
        .distinct()
        .where(
            KBAclModel.user_id == user_id,
            KBAclModel.kb_id.in_(all_kb_ids),
        )
    )

    async with session_factory() as db:
        rows = (await db.execute(acl_query)).all()

    granted = set()
    for row in rows:
        granted.add(row[0])

    # Intersect again so nothing outside the candidate list can leak through.
    return sorted(granted.intersection(all_kb_ids))
|
||||
|
||||
|
||||
async def grant_access(
    session_factory,
    kb_id: str,
    user_id: str,
    role: str = "viewer",
) -> bool:
    """Grant a user access to a knowledge base.

    Inserts one ACL row ``(kb_id, user_id, role)`` and commits it. This
    function does not invalidate any retrieval cache itself; callers that
    cache ACL-dependent results must do so after a successful grant.

    Args:
        session_factory: Zero-argument callable returning an async context
            manager that yields the DB session.
        kb_id: ID of the knowledge base.
        user_id: ID of the user receiving access.
        role: Role stored on the ACL row ("owner" or "viewer").

    Returns:
        True when the row was inserted and committed. False when the insert
        or commit failed; the transaction is rolled back in that case. An
        integrity violation (typically an existing entry for the same
        ``kb_id``/``user_id``) is logged at INFO, any other failure is logged
        with its traceback so it is not mistaken for a duplicate.
    """
    from sqlalchemy import insert
    from sqlalchemy.exc import IntegrityError

    async with session_factory() as db:
        stmt = insert(KBAclModel).values(kb_id=kb_id, user_id=user_id, role=role)
        try:
            await db.execute(stmt)
            await db.commit()
        except IntegrityError:
            # Constraint rejected the row, most likely the unique (kb_id, user_id) pair.
            await db.rollback()
            logger.info(
                "ACL entry not created for kb=%s user=%s: integrity violation",
                kb_id,
                user_id,
            )
            return False
        except Exception:
            # Keep the historical "return False" contract, but leave a trace:
            # this is NOT a duplicate entry and needs attention.
            logger.exception(
                "Unexpected error granting %s access to kb=%s for user=%s",
                role,
                kb_id,
                user_id,
            )
            await db.rollback()
            return False

        logger.info("Granted %s access to kb=%s for user=%s", role, kb_id, user_id)
        return True
|
||||
|
||||
|
||||
async def revoke_access(session_factory, kb_id: str, user_id: str) -> bool:
    """Revoke a user's access to a knowledge base.

    Deletes the ACL row for ``(kb_id, user_id)`` unless its role is "owner";
    owner entries are never removed by this function. No cache invalidation
    happens here; callers must trigger it after an ACL change.

    Args:
        session_factory: Zero-argument callable returning an async context
            manager that yields the DB session.
        kb_id: ID of the knowledge base.
        user_id: ID of the user losing access.

    Returns:
        True when at least one row was deleted, False otherwise.
    """
    from sqlalchemy import delete

    removal = delete(KBAclModel).where(
        KBAclModel.kb_id == kb_id,
        KBAclModel.user_id == user_id,
        KBAclModel.role != "owner",  # owner rows are protected
    )

    async with session_factory() as db:
        outcome = await db.execute(removal)
        await db.commit()
        deleted_rows = outcome.rowcount  # type: ignore[union-attr]
        return deleted_rows > 0
|
||||
|
||||
|
||||
async def list_acl(session_factory, kb_id: str) -> list[dict]:
    """List every ACL entry of one knowledge base.

    Args:
        session_factory: Zero-argument callable returning an async context
            manager that yields the DB session.
        kb_id: ID of the knowledge base.

    Returns:
        One dict per ACL row with the keys ``user_id``, ``role`` and
        ``created_at``; empty list when the KB has no entries.
    """
    entries: list[dict] = []
    async with session_factory() as db:
        query = select(KBAclModel).where(KBAclModel.kb_id == kb_id)
        acl_rows = (await db.execute(query)).scalars().all()
        for acl in acl_rows:
            entries.append(
                {
                    "user_id": acl.user_id,
                    "role": acl.role,
                    "created_at": acl.created_at,
                }
            )
    return entries
|
||||
|
|
@ -104,3 +104,71 @@ class QueryResult(BaseModel):
|
|||
metadata: dict = Field(default_factory=dict)
|
||||
document_id: str
|
||||
kb_id: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ORM Models (SQLAlchemy 2 DeclarativeBase + Mapped)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, String, Text, UniqueConstraint # noqa: E402
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column # noqa: E402
|
||||
|
||||
|
||||
class ORMBase(DeclarativeBase):
    """Declarative base for the rag_platform ORM models.

    Kept separate from the ``declarative_base()`` used in memory/models.py,
    so these tables live in their own metadata.
    """
|
||||
|
||||
|
||||
class KBModel(ORMBase):
    """ORM model for a knowledge base row (table ``rag_platform_kbs``)."""

    __tablename__ = "rag_platform_kbs"

    # Primary key; filled by _uuid when the caller does not supply one.
    id: Mapped[str] = mapped_column(String, primary_key=True, default=_uuid)
    name: Mapped[str] = mapped_column(String(255))
    description: Mapped[str] = mapped_column(Text, default="")
    # Indexed because KBs are listed per owner.
    owner: Mapped[str] = mapped_column(String(255), index=True)
    status: Mapped[str] = mapped_column(String(32), default="active", index=True)
    default_query_mode: Mapped[str] = mapped_column(String(32), default="blend")
    default_hit_processing: Mapped[str] = mapped_column(String(32), default="model_opt")
    caching_disabled: Mapped[bool] = mapped_column(default=False)
    # NOTE(review): DateTime is declared without timezone=True; confirm that
    # _utcnow returns values compatible with that column type.
    created_at: Mapped[datetime] = mapped_column(DateTime, default=_utcnow)
    # onupdate refreshes the timestamp whenever an UPDATE is emitted for the
    # row; without it updated_at would stay at its insert-time value forever.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, default=_utcnow, onupdate=_utcnow
    )
|
||||
|
||||
|
||||
class DocumentModel(ORMBase):
    """ORM model for a document belonging to a knowledge base.

    Rows live in ``rag_platform_documents`` and reference their KB through
    ``kb_id`` with ``ON DELETE CASCADE``.
    """

    __tablename__ = "rag_platform_documents"

    # Primary key; filled by _uuid when the caller does not supply one.
    id: Mapped[str] = mapped_column(String, primary_key=True, default=_uuid)
    kb_id: Mapped[str] = mapped_column(
        String, ForeignKey("rag_platform_kbs.id", ondelete="CASCADE"), index=True
    )
    filename: Mapped[str] = mapped_column(String(512))
    file_type: Mapped[str] = mapped_column(String(64))
    file_size: Mapped[int] = mapped_column(default=0)
    status: Mapped[str] = mapped_column(String(32), default="pending", index=True)
    error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
    # NOTE(review): DateTime is declared without timezone=True; confirm that
    # _utcnow returns values compatible with that column type.
    created_at: Mapped[datetime] = mapped_column(DateTime, default=_utcnow)
    # onupdate refreshes the timestamp on every UPDATE (e.g. a status change);
    # without it updated_at would never move past the insert-time value.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, default=_utcnow, onupdate=_utcnow
    )
|
||||
|
||||
|
||||
class KBAclModel(ORMBase):
    """ORM model for one per-KB access-control entry.

    ``kb_id`` references ``rag_platform_kbs.id`` with ``ON DELETE CASCADE``
    (KTD5 invariant), and a unique constraint allows at most one entry per
    ``(kb_id, user_id)`` pair. ACL rows are stored in the same database as the
    KB metadata so both can be written in one transaction.
    """

    __tablename__ = "rag_platform_kb_acl"
    __table_args__ = (
        UniqueConstraint("kb_id", "user_id", name="uq_kb_acl_kb_user"),
    )

    id: Mapped[str] = mapped_column(
        String,
        primary_key=True,
        default=_uuid,
    )
    kb_id: Mapped[str] = mapped_column(
        String,
        ForeignKey("rag_platform_kbs.id", ondelete="CASCADE"),
        index=True,
    )
    user_id: Mapped[str] = mapped_column(String(255), index=True)
    # Allowed values: "owner" | "viewer"
    role: Mapped[str] = mapped_column(String(32), default="viewer")
    created_at: Mapped[datetime] = mapped_column(DateTime, default=_utcnow)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,172 @@
|
|||
"""KB 持久化存储 — PostgreSQL 后端,替换内存 KnowledgeSourceStore。
|
||||
|
||||
遵循 memory/episodic.py 的 async session 模式。
|
||||
ACL 条目与 KB 元数据共享事务边界(同一 PG DB)。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from agentkit.rag_platform.models import (
|
||||
Document,
|
||||
DocumentStatus,
|
||||
KBStatus,
|
||||
KnowledgeBase,
|
||||
KBAclModel,
|
||||
KBModel,
|
||||
DocumentModel,
|
||||
ORMBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_session_factory(database_url: str):
    """Build an async session factory bound to a new engine.

    Args:
        database_url: SQLAlchemy URL handed to ``create_async_engine``.

    Returns:
        A ``sessionmaker`` that produces ``AsyncSession`` objects with
        ``expire_on_commit`` turned off.
    """
    return sessionmaker(
        create_async_engine(database_url, echo=False),
        class_=AsyncSession,
        expire_on_commit=False,
    )
|
||||
|
||||
|
||||
async def ensure_tables(database_url: str) -> None:
    """Create the rag_platform tables registered on ``ORMBase.metadata``.

    A temporary engine is created for the call and disposed before returning.
    The success message is logged only when table creation completed.

    Args:
        database_url: SQLAlchemy URL handed to ``create_async_engine``.

    Raises:
        Exception: Whatever the engine or ``create_all`` raises is propagated
            after the engine has been disposed.
    """
    engine = create_async_engine(database_url, echo=False)
    try:
        async with engine.begin() as conn:
            await conn.run_sync(ORMBase.metadata.create_all)
    finally:
        # Always release the connection pool, even when create_all fails.
        await engine.dispose()
    logger.info("rag_platform tables ensured")
|
||||
|
||||
|
||||
class KBStore:
    """Persistent KB/document store on top of an async SQLAlchemy session.

    Every public method opens its own session from the factory passed to
    ``__init__`` and converts ORM rows into the ``KnowledgeBase`` /
    ``Document`` domain models with ``model_validate`` before returning.
    """

    def __init__(self, session_factory) -> None:
        """Remember the session factory.

        Args:
            session_factory: Zero-argument callable returning an async context
                manager that yields the DB session.
        """
        self._sf = session_factory

    @staticmethod
    async def _fetch_by_id(db, model, row_id: str):
        """Return the ``model`` row whose ``id`` equals ``row_id``, or None."""
        result = await db.execute(select(model).where(model.id == row_id))
        return result.scalars().first()

    async def create_kb(
        self,
        name: str,
        owner: str,
        description: str = "",
        default_query_mode: str = "blend",
        default_hit_processing: str = "model_opt",
        caching_disabled: bool = False,
    ) -> KnowledgeBase:
        """Create a knowledge base together with its owner ACL entry.

        Both rows are written in the same session and made durable by a
        single commit.

        Args:
            name: Display name of the KB.
            owner: User ID stored as the KB owner and used for the ACL entry.
            description: Free-text description.
            default_query_mode: Stored as-is on the KB row.
            default_hit_processing: Stored as-is on the KB row.
            caching_disabled: Stored as-is on the KB row.

        Returns:
            The created KB as a ``KnowledgeBase`` domain model.
        """
        async with self._sf() as db:
            kb = KBModel(
                name=name,
                owner=owner,
                description=description,
                default_query_mode=default_query_mode,
                default_hit_processing=default_hit_processing,
                caching_disabled=caching_disabled,
            )
            db.add(kb)
            # Flush first so kb.id is populated before the ACL row uses it.
            await db.flush()

            db.add(KBAclModel(kb_id=kb.id, user_id=owner, role="owner"))
            await db.commit()
            await db.refresh(kb)
            return KnowledgeBase.model_validate(kb)

    async def get_kb(self, kb_id: str) -> KnowledgeBase | None:
        """Return the KB with the given ID, or None when it does not exist."""
        async with self._sf() as db:
            row = await self._fetch_by_id(db, KBModel, kb_id)
            return KnowledgeBase.model_validate(row) if row else None

    async def list_kbs(self, owner: str | None = None) -> list[KnowledgeBase]:
        """List active KBs, newest first.

        Args:
            owner: When not None, only KBs whose owner equals this value are
                returned. An empty string is treated as a real owner value
                (matching nothing unless such rows exist), not as "no filter".

        Returns:
            The matching KBs ordered by ``created_at`` descending.
        """
        async with self._sf() as db:
            stmt = select(KBModel).where(KBModel.status == KBStatus.active.value)
            if owner is not None:
                stmt = stmt.where(KBModel.owner == owner)
            stmt = stmt.order_by(KBModel.created_at.desc())
            result = await db.execute(stmt)
            return [KnowledgeBase.model_validate(r) for r in result.scalars().all()]

    async def delete_kb(self, kb_id: str) -> bool:
        """Delete a KB row.

        Documents and ACL entries reference the KB with ``ON DELETE CASCADE``,
        so removing them is left to the database.

        Returns:
            True when the KB existed and was deleted, False when not found.
        """
        async with self._sf() as db:
            kb = await self._fetch_by_id(db, KBModel, kb_id)
            if kb is None:
                return False
            await db.delete(kb)
            await db.commit()
            return True

    async def add_document(
        self,
        kb_id: str,
        filename: str,
        file_type: str,
        file_size: int,
    ) -> Document:
        """Insert a document row for a KB and return it as a domain model.

        Args:
            kb_id: ID of the KB the document belongs to.
            filename: Original file name.
            file_type: File type label stored on the row.
            file_size: Size value stored on the row.

        Returns:
            The created document.
        """
        async with self._sf() as db:
            doc = DocumentModel(
                kb_id=kb_id,
                filename=filename,
                file_type=file_type,
                file_size=file_size,
            )
            db.add(doc)
            await db.commit()
            await db.refresh(doc)
            return Document.model_validate(doc)

    async def get_document(self, document_id: str) -> Document | None:
        """Return the document with the given ID, or None when not found."""
        async with self._sf() as db:
            row = await self._fetch_by_id(db, DocumentModel, document_id)
            return Document.model_validate(row) if row else None

    async def list_documents(self, kb_id: str) -> list[Document]:
        """List the documents of one KB ordered by ``created_at`` descending."""
        async with self._sf() as db:
            stmt = (
                select(DocumentModel)
                .where(DocumentModel.kb_id == kb_id)
                .order_by(DocumentModel.created_at.desc())
            )
            result = await db.execute(stmt)
            return [Document.model_validate(r) for r in result.scalars().all()]

    async def update_document_status(
        self,
        document_id: str,
        status: DocumentStatus,
        error_message: str | None = None,
    ) -> Document | None:
        """Set the status (and error message) of a document.

        Args:
            document_id: ID of the document to update.
            status: New status; its ``.value`` is stored.
            error_message: Stored verbatim. The default None overwrites, and
                thereby clears, any previously stored message.

        Returns:
            The updated document, or None when the document does not exist.
        """
        async with self._sf() as db:
            doc = await self._fetch_by_id(db, DocumentModel, document_id)
            if doc is None:
                return None
            doc.status = status.value
            doc.error_message = error_message
            await db.commit()
            await db.refresh(doc)
            return Document.model_validate(doc)

    async def delete_document(self, document_id: str) -> bool:
        """Delete a document row.

        Returns:
            True when the document existed and was deleted, False otherwise.
        """
        async with self._sf() as db:
            doc = await self._fetch_by_id(db, DocumentModel, document_id)
            if doc is None:
                return False
            await db.delete(doc)
            await db.commit()
            return True
|
||||
|
|
@ -0,0 +1,170 @@
|
|||
"""U2 测试 — Per-KB ACL 逻辑。
|
||||
|
||||
测试场景:
|
||||
1. owner 用户可查询自己的 KB
|
||||
2. 非 authorized 用户查询 KB 被拒绝
|
||||
3. Agent 检索时 ACL 过滤生效(仅返回授权 KB 的结果)
|
||||
4. grant/revoke 操作
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from agentkit.rag_platform.acl import (
|
||||
filter_kb_by_user_acl,
|
||||
grant_access,
|
||||
list_acl,
|
||||
revoke_access,
|
||||
)
|
||||
|
||||
|
||||
def _make_mock_session_factory(execute_result=None):
    """Build a fake async session factory for the ACL tests.

    Args:
        execute_result: Optional replacement for the session's ``execute``
            attribute; left as the default AsyncMock child when None.

    Returns:
        A ``(factory, session)`` tuple: ``factory()`` is an async context
        manager yielding ``session``, the shared AsyncMock.
    """
    session = AsyncMock()
    for coroutine_name in ("commit", "rollback"):
        setattr(session, coroutine_name, AsyncMock())

    if execute_result is not None:
        session.execute = execute_result

    @asynccontextmanager
    async def open_session():
        yield session

    return open_session, session
|
||||
|
||||
|
||||
def _make_execute_returning_rows(rows):
    """Return an AsyncMock ``execute`` whose awaited result yields ``rows`` from ``.all()``.

    Meant for plain column selects, where the code under test reads
    ``result.all()``.
    """
    query_result = MagicMock(**{"all.return_value": rows})
    return AsyncMock(return_value=query_result)
|
||||
|
||||
|
||||
def _make_execute_returning_scalars(rows):
    """Return an AsyncMock ``execute`` whose result yields ``rows`` from ``.scalars().all()``.

    Meant for ORM entity selects, where the code under test reads
    ``result.scalars().all()``.
    """
    query_result = MagicMock()
    query_result.scalars.return_value.all.return_value = rows
    return AsyncMock(return_value=query_result)
|
||||
|
||||
|
||||
class TestFilterKbByUserAcl:
    """Tests for filter_kb_by_user_acl."""

    async def test_empty_kb_ids_returns_empty(self):
        """An empty candidate list yields an empty result."""
        session_factory, _ = _make_mock_session_factory()

        assert await filter_kb_by_user_acl(session_factory, "user1", []) == []

    async def test_authorized_kb_returned(self):
        """KBs with an ACL row for the user are returned, others are not."""
        acl_rows = [("kb-1",), ("kb-2",)]
        session_factory, _ = _make_mock_session_factory(
            _make_execute_returning_rows(acl_rows)
        )

        visible = await filter_kb_by_user_acl(
            session_factory, "user1", ["kb-1", "kb-2", "kb-3"]
        )

        assert "kb-1" in visible
        assert "kb-2" in visible
        # kb-3 has no ACL row, so it must be filtered out
        assert "kb-3" not in visible

    async def test_unauthorized_kb_filtered_out(self):
        """A user without any ACL row sees nothing."""
        no_acl_rows = _make_execute_returning_rows([])
        session_factory, _ = _make_mock_session_factory(no_acl_rows)

        visible = await filter_kb_by_user_acl(
            session_factory, "stranger", ["kb-1", "kb-2"]
        )

        assert visible == []

    async def test_result_is_sorted(self):
        """The returned IDs are sorted regardless of row order."""
        unsorted_rows = _make_execute_returning_rows([("kb-3",), ("kb-1",)])
        session_factory, _ = _make_mock_session_factory(unsorted_rows)

        visible = await filter_kb_by_user_acl(
            session_factory, "user1", ["kb-1", "kb-2", "kb-3"]
        )

        assert visible == ["kb-1", "kb-3"]
|
||||
|
||||
|
||||
class TestGrantAccess:
    """Tests for grant_access."""

    async def test_grant_access_success(self):
        """A successful insert returns True and commits once."""
        # The code never inspects the insert result, so a bare AsyncMock is enough.
        session_factory, session = _make_mock_session_factory(AsyncMock())

        granted = await grant_access(session_factory, "kb-1", "user2", "viewer")

        assert granted is True
        session.commit.assert_awaited_once()

    async def test_grant_access_duplicate_returns_false(self):
        """A failing insert returns False and rolls back once."""
        failing_execute = AsyncMock(side_effect=Exception("unique constraint"))
        session_factory, session = _make_mock_session_factory(failing_execute)

        granted = await grant_access(session_factory, "kb-1", "user2", "viewer")

        assert granted is False
        session.rollback.assert_awaited_once()
|
||||
|
||||
|
||||
class TestRevokeAccess:
    """Tests for revoke_access."""

    async def test_revoke_access_success(self):
        """One deleted row means True, with a single commit."""
        delete_result = MagicMock(rowcount=1)
        session_factory, session = _make_mock_session_factory(
            AsyncMock(return_value=delete_result)
        )

        revoked = await revoke_access(session_factory, "kb-1", "user2")

        assert revoked is True
        session.commit.assert_awaited_once()

    async def test_revoke_access_not_found_returns_false(self):
        """Zero deleted rows means False."""
        delete_result = MagicMock(rowcount=0)
        session_factory, session = _make_mock_session_factory(
            AsyncMock(return_value=delete_result)
        )

        revoked = await revoke_access(session_factory, "kb-1", "stranger")

        assert revoked is False
|
||||
|
||||
|
||||
class TestListAcl:
    """Tests for list_acl."""

    async def test_list_acl_returns_entries(self):
        """Each ACL row becomes one dict, in row order."""
        owner_row = MagicMock(user_id="user1", role="owner", created_at=None)
        viewer_row = MagicMock(user_id="user2", role="viewer", created_at=None)
        session_factory, _ = _make_mock_session_factory(
            _make_execute_returning_scalars([owner_row, viewer_row])
        )

        entries = await list_acl(session_factory, "kb-1")

        assert len(entries) == 2
        first, second = entries
        assert first["user_id"] == "user1"
        assert first["role"] == "owner"
        assert second["user_id"] == "user2"
        assert second["role"] == "viewer"

    async def test_list_acl_empty(self):
        """No ACL rows yields an empty list."""
        session_factory, _ = _make_mock_session_factory(
            _make_execute_returning_scalars([])
        )

        entries = await list_acl(session_factory, "kb-1")

        assert entries == []
|
||||
|
|
@ -0,0 +1,288 @@
|
|||
"""U2 测试 — KB 持久化存储。
|
||||
|
||||
测试场景:
|
||||
1. KB 元数据写入 PG,重启后仍存在(mock 验证 commit)
|
||||
2. owner 用户可查询自己的 KB
|
||||
3. delete_kb CASCADE 删除关联数据
|
||||
4. 文档 CRUD 操作
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from agentkit.rag_platform.models import (
|
||||
DocumentStatus,
|
||||
KnowledgeBase,
|
||||
)
|
||||
from agentkit.rag_platform.store import KBStore
|
||||
|
||||
|
||||
def _make_kb_row(**overrides):
    """Build a MagicMock standing in for a KBModel row.

    Args:
        **overrides: Attribute values replacing the defaults below.

    Returns:
        A MagicMock carrying the KB attributes as plain values.
    """
    attrs = {
        **dict(
            id="kb-001",
            name="test-kb",
            description="",
            owner="user1",
            status="active",
            default_query_mode="blend",
            default_hit_processing="model_opt",
            caching_disabled=False,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        ),
        **overrides,
    }
    row = MagicMock()
    # configure_mock (not the constructor) so "name" becomes a normal
    # attribute instead of the mock's own name.
    row.configure_mock(**attrs)
    return row
|
||||
|
||||
|
||||
def _make_doc_row(**overrides):
    """Build a MagicMock standing in for a DocumentModel row.

    Args:
        **overrides: Attribute values replacing the defaults below.

    Returns:
        A MagicMock carrying the document attributes as plain values.
    """
    attrs = {
        **dict(
            id="doc-001",
            kb_id="kb-001",
            filename="test.pdf",
            file_type="pdf",
            file_size=1024,
            status="pending",
            error_message=None,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        ),
        **overrides,
    }
    row = MagicMock()
    row.configure_mock(**attrs)
    return row
|
||||
|
||||
|
||||
def _make_mock_session_factory():
    """Build a fake async session factory for the KBStore tests.

    Returns:
        A ``(factory, session)`` tuple: ``factory()`` is an async context
        manager yielding ``session``. On the session, ``add`` is a plain
        MagicMock while commit/rollback/refresh/delete/flush are AsyncMocks.
    """
    session = AsyncMock()
    session.add = MagicMock()  # called without await by the store
    for coroutine_name in ("commit", "rollback", "refresh", "delete", "flush"):
        setattr(session, coroutine_name, AsyncMock())

    @asynccontextmanager
    async def open_session():
        yield session

    return open_session, session
|
||||
|
||||
|
||||
def _make_mock_execute(result_rows):
    """Return an AsyncMock ``execute`` serving ``result_rows`` through ``scalars()``.

    On the awaited result, ``scalars().all()`` returns ``result_rows`` and
    ``scalars().first()`` returns its first element, or None when empty.
    """
    first_row = None
    if result_rows:
        first_row = result_rows[0]

    scalar_view = MagicMock(
        **{"all.return_value": result_rows, "first.return_value": first_row}
    )
    query_result = MagicMock(**{"scalars.return_value": scalar_view})
    return AsyncMock(return_value=query_result)
|
||||
|
||||
|
||||
class TestKBStoreCreate:
    """Tests for KBStore.create_kb."""

    @staticmethod
    def _wire_session(mock_session):
        """Make the mocked flush/refresh populate the fields a real DB would.

        ``flush`` assigns an id to the first object handed to ``add`` (the KB
        row), so code reading ``kb.id`` after the flush sees a real value.
        ``refresh`` fills the attributes the domain model needs.
        """

        def _flush():
            kb_obj = mock_session.add.call_args_list[0].args[0]
            kb_obj.id = "kb-001"

        async def _refresh(obj):
            obj.id = "kb-001"
            obj.status = "active"
            obj.created_at = datetime.now(timezone.utc)
            obj.updated_at = datetime.now(timezone.utc)

        mock_session.flush.side_effect = _flush
        mock_session.refresh.side_effect = _refresh

    async def test_create_kb_returns_knowledge_base(self):
        """create_kb returns a KnowledgeBase and commits exactly once."""
        sf, mock_session = _make_mock_session_factory()
        self._wire_session(mock_session)

        store = KBStore(sf)
        kb = await store.create_kb(name="test", owner="user1", description="desc")

        assert isinstance(kb, KnowledgeBase)
        assert kb.name == "test"
        assert kb.owner == "user1"
        assert kb.description == "desc"
        # Two add() calls: the KB row and its owner ACL row.
        assert mock_session.add.call_count == 2
        mock_session.commit.assert_awaited_once()

    async def test_create_kb_creates_owner_acl(self):
        """create_kb adds an owner ACL row tied to the new KB before committing."""
        sf, mock_session = _make_mock_session_factory()
        self._wire_session(mock_session)

        store = KBStore(sf)
        await store.create_kb(name="test", owner="user1")

        # The second add() call carries the ACL entry.
        acl_obj = mock_session.add.call_args_list[1].args[0]
        assert acl_obj.user_id == "user1"
        assert acl_obj.role == "owner"
        # The ACL row must reference the id the KB received at flush time.
        assert acl_obj.kb_id == "kb-001"
        # Both rows are persisted by one commit (same transaction).
        mock_session.commit.assert_awaited_once()
|
||||
|
||||
|
||||
class TestKBStoreQuery:
    """Tests for KBStore.get_kb and KBStore.list_kbs."""

    async def test_get_kb_returns_none_if_not_found(self):
        """get_kb returns None when the query yields no row."""
        sf, session = _make_mock_session_factory()
        session.execute = _make_mock_execute([])

        found = await KBStore(sf).get_kb("nonexistent")

        assert found is None

    async def test_get_kb_returns_knowledge_base(self):
        """get_kb converts the found row into a KnowledgeBase."""
        sf, session = _make_mock_session_factory()
        session.execute = _make_mock_execute([_make_kb_row()])

        found = await KBStore(sf).get_kb("kb-001")

        assert found is not None
        assert found.id == "kb-001"
        assert found.name == "test-kb"

    async def test_list_kbs_filters_by_owner(self):
        """list_kbs(owner=...) converts every row the query returns.

        The owner filter itself runs in SQL and is not exercised by the mock.
        """
        sf, session = _make_mock_session_factory()
        session.execute = _make_mock_execute(
            [_make_kb_row(id="kb-1"), _make_kb_row(id="kb-2", name="second")]
        )

        listed = await KBStore(sf).list_kbs(owner="user1")

        assert len(listed) == 2
        assert listed[0].id == "kb-1"

    async def test_list_kbs_returns_empty(self):
        """list_kbs returns an empty list when there are no rows."""
        sf, session = _make_mock_session_factory()
        session.execute = _make_mock_execute([])

        listed = await KBStore(sf).list_kbs()

        assert listed == []
|
||||
|
||||
|
||||
class TestKBStoreDelete:
    """Tests for KBStore.delete_kb."""

    async def test_delete_kb_returns_true(self):
        """An existing KB is deleted, committed, and True is returned."""
        sf, session = _make_mock_session_factory()
        existing = _make_kb_row()
        session.execute = _make_mock_execute([existing])

        deleted = await KBStore(sf).delete_kb("kb-001")

        assert deleted is True
        session.delete.assert_awaited_once_with(existing)
        session.commit.assert_awaited_once()

    async def test_delete_kb_returns_false_if_not_found(self):
        """A missing KB yields False."""
        sf, session = _make_mock_session_factory()
        session.execute = _make_mock_execute([])

        deleted = await KBStore(sf).delete_kb("nonexistent")

        assert deleted is False
|
||||
|
||||
|
||||
class TestDocumentOperations:
    """Tests for the document CRUD methods of KBStore."""

    async def test_add_document(self):
        """add_document stores one row and commits once."""
        sf, session = _make_mock_session_factory()

        async def fill_db_defaults(obj):
            # Stand-in for the values a real refresh would load.
            obj.id = "doc-001"
            obj.status = "pending"
            obj.created_at = datetime.now(timezone.utc)
            obj.updated_at = datetime.now(timezone.utc)

        session.refresh.side_effect = fill_db_defaults

        created = await KBStore(sf).add_document("kb-001", "test.pdf", "pdf", 1024)

        assert created.filename == "test.pdf"
        assert created.kb_id == "kb-001"
        session.add.assert_called_once()
        session.commit.assert_awaited_once()

    async def test_list_documents(self):
        """list_documents converts every returned row, keeping order."""
        sf, session = _make_mock_session_factory()
        session.execute = _make_mock_execute(
            [_make_doc_row(id="d1"), _make_doc_row(id="d2", filename="second.pdf")]
        )

        listed = await KBStore(sf).list_documents("kb-001")

        assert len(listed) == 2
        assert listed[0].id == "d1"

    async def test_update_document_status(self):
        """update_document_status writes the new status and commits."""
        sf, session = _make_mock_session_factory()
        existing = _make_doc_row()
        session.execute = _make_mock_execute([existing])

        updated = await KBStore(sf).update_document_status(
            "doc-001", DocumentStatus.indexed
        )

        assert updated is not None
        assert existing.status == "indexed"
        session.commit.assert_awaited_once()

    async def test_update_document_status_not_found(self):
        """update_document_status returns None for an unknown document."""
        sf, session = _make_mock_session_factory()
        session.execute = _make_mock_execute([])

        updated = await KBStore(sf).update_document_status(
            "nonexistent", DocumentStatus.failed
        )

        assert updated is None

    async def test_delete_document(self):
        """delete_document deletes the found row and returns True."""
        sf, session = _make_mock_session_factory()
        existing = _make_doc_row()
        session.execute = _make_mock_execute([existing])

        deleted = await KBStore(sf).delete_document("doc-001")

        assert deleted is True
        session.delete.assert_awaited_once_with(existing)
|
||||
Loading…
Reference in New Issue