fischer-agentkit/tests/unit/rag_platform/test_acl.py

171 lines
5.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""U2 测试 — Per-KB ACL 逻辑。
测试场景:
1. owner 用户可查询自己的 KB
2. 非 authorized 用户查询 KB 被拒绝
3. Agent 检索时 ACL 过滤生效(仅返回授权 KB 的结果)
4. grant/revoke 操作
"""
from __future__ import annotations
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock
from agentkit.rag_platform.acl import (
filter_kb_by_user_acl,
grant_access,
list_acl,
revoke_access,
)
def _make_mock_session_factory(execute_result=None):
    """Build a mock async session factory for the ACL tests.

    Args:
        execute_result: Optional replacement for ``session.execute``
            (normally an ``AsyncMock``). When ``None`` the session keeps
            the ``execute`` attribute that ``AsyncMock`` creates on demand.

    Returns:
        A ``(factory, mock_session)`` tuple. Calling ``factory()`` gives an
        async context manager that yields the same ``mock_session`` every
        time, so a test can inspect the calls made on it afterwards.
    """
    mock_session = AsyncMock()
    # Assigned explicitly because the tests assert on these two awaitables.
    mock_session.commit = AsyncMock()
    mock_session.rollback = AsyncMock()
    if execute_result is not None:
        mock_session.execute = execute_result

    @asynccontextmanager
    async def factory():
        yield mock_session

    return factory, mock_session
def _make_execute_returning_rows(rows: list) -> AsyncMock:
    """Build a mock ``execute`` whose awaited result exposes ``rows`` via ``.all()``.

    Intended for plain ``select`` queries where the caller reads
    ``result.all()`` directly.

    Args:
        rows: The object ``result.all()`` should return (the tests pass a
            list of row tuples).

    Returns:
        An ``AsyncMock`` that, when awaited, returns the prepared result.
    """
    mock_result = MagicMock()
    mock_result.all.return_value = rows
    return AsyncMock(return_value=mock_result)
def _make_execute_returning_scalars(rows: list) -> AsyncMock:
    """Build a mock ``execute`` whose result exposes ``rows`` via ``.scalars().all()``.

    Intended for ORM-style ``select`` queries where the caller reads
    ``result.scalars().all()``.

    Args:
        rows: The object ``result.scalars().all()`` should return.

    Returns:
        An ``AsyncMock`` that, when awaited, returns the prepared result.
    """
    mock_scalars = MagicMock()
    mock_scalars.all.return_value = rows
    mock_result = MagicMock()
    mock_result.scalars.return_value = mock_scalars
    return AsyncMock(return_value=mock_result)
class TestFilterKbByUserAcl:
    """Tests for ``filter_kb_by_user_acl`` (ACL filtering of KB ids)."""

    # NOTE(review): these are plain ``async def`` tests with no marker, so
    # they presumably rely on the project's pytest async configuration
    # (e.g. asyncio auto mode) — confirm.

    async def test_empty_kb_ids_returns_empty(self):
        """An empty ``all_kb_ids`` list yields an empty result."""
        sf, _ = _make_mock_session_factory()

        result = await filter_kb_by_user_acl(sf, "user1", [])

        assert result == []

    async def test_authorized_kb_returned(self):
        """KBs the user has ACL rows for are returned; others are not."""
        mock_execute = _make_execute_returning_rows([("kb-1",), ("kb-2",)])
        sf, _ = _make_mock_session_factory(mock_execute)

        result = await filter_kb_by_user_acl(sf, "user1", ["kb-1", "kb-2", "kb-3"])

        assert "kb-1" in result
        assert "kb-2" in result
        assert "kb-3" not in result  # not authorized

    async def test_unauthorized_kb_filtered_out(self):
        """A user with no ACL rows gets every KB filtered out."""
        mock_execute = _make_execute_returning_rows([])  # no ACL entries
        sf, _ = _make_mock_session_factory(mock_execute)

        result = await filter_kb_by_user_acl(sf, "stranger", ["kb-1", "kb-2"])

        assert result == []

    async def test_result_is_sorted(self):
        """The returned KB ids are sorted by id."""
        mock_execute = _make_execute_returning_rows([("kb-3",), ("kb-1",)])
        sf, _ = _make_mock_session_factory(mock_execute)

        result = await filter_kb_by_user_acl(sf, "user1", ["kb-1", "kb-2", "kb-3"])

        assert result == ["kb-1", "kb-3"]
class TestGrantAccess:
    """Tests for ``grant_access``."""

    async def test_grant_access_success(self):
        """A successful grant returns True and commits exactly once."""
        mock_execute = AsyncMock()  # the insert returns no result
        sf, mock_session = _make_mock_session_factory(mock_execute)

        result = await grant_access(sf, "kb-1", "user2", "viewer")

        assert result is True
        mock_session.commit.assert_awaited_once()

    async def test_grant_access_duplicate_returns_false(self):
        """When ``execute`` raises, the grant returns False and rolls back."""
        # Simulates an already-existing ACL entry via a raised exception.
        mock_execute = AsyncMock(side_effect=Exception("unique constraint"))
        sf, mock_session = _make_mock_session_factory(mock_execute)

        result = await grant_access(sf, "kb-1", "user2", "viewer")

        assert result is False
        mock_session.rollback.assert_awaited_once()
class TestRevokeAccess:
    """Tests for ``revoke_access``."""

    async def test_revoke_access_success(self):
        """A revoke that affects one row returns True and commits once."""
        mock_result = MagicMock()
        mock_result.rowcount = 1
        mock_execute = AsyncMock(return_value=mock_result)
        sf, mock_session = _make_mock_session_factory(mock_execute)

        result = await revoke_access(sf, "kb-1", "user2")

        assert result is True
        mock_session.commit.assert_awaited_once()

    async def test_revoke_access_not_found_returns_false(self):
        """A revoke that affects no rows returns False."""
        mock_result = MagicMock()
        mock_result.rowcount = 0
        mock_execute = AsyncMock(return_value=mock_result)
        sf, mock_session = _make_mock_session_factory(mock_execute)

        result = await revoke_access(sf, "kb-1", "stranger")

        assert result is False
class TestListAcl:
    """Tests for ``list_acl``."""

    async def test_list_acl_returns_entries(self):
        """``list_acl`` returns one dict per ACL row, in row order."""
        acl_row1 = MagicMock(user_id="user1", role="owner", created_at=None)
        acl_row2 = MagicMock(user_id="user2", role="viewer", created_at=None)
        mock_execute = _make_execute_returning_scalars([acl_row1, acl_row2])
        sf, _ = _make_mock_session_factory(mock_execute)

        result = await list_acl(sf, "kb-1")

        assert len(result) == 2
        assert result[0]["user_id"] == "user1"
        assert result[0]["role"] == "owner"
        assert result[1]["user_id"] == "user2"
        assert result[1]["role"] == "viewer"

    async def test_list_acl_empty(self):
        """``list_acl`` returns an empty list when there are no ACL rows."""
        mock_execute = _make_execute_returning_scalars([])
        sf, _ = _make_mock_session_factory(mock_execute)

        result = await list_acl(sf, "kb-1")

        assert result == []