"""U2 tests — per-KB ACL logic.

Scenarios covered:

1. An owner user can query their own KB.
2. A user who is not authorized for a KB is denied when querying it.
3. ACL filtering is applied during agent retrieval (only results from
   authorized KBs are returned).
4. grant/revoke operations.
5. Listing the ACL entries of a KB.
"""
from __future__ import annotations
|
||
|
||
from contextlib import asynccontextmanager
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
from agentkit.rag_platform.acl import (
|
||
filter_kb_by_user_acl,
|
||
grant_access,
|
||
list_acl,
|
||
revoke_access,
|
||
)
|
||
|
||
|
||
def _make_mock_session_factory(execute_result=None):
    """Build a fake async session factory for the ACL functions under test.

    Args:
        execute_result: Optional mock installed as ``session.execute``.
            Despite its name, this is the ``execute`` callable itself, not
            the value it returns. When ``None``, the session keeps the
            ``execute`` attribute that ``AsyncMock`` creates automatically.

    Returns:
        A ``(factory, session)`` pair. ``factory()`` is an async context
        manager that yields the same shared ``session`` mock every time it is
        entered, so tests can inspect ``session.commit`` / ``session.rollback``
        after the code under test has run.
    """
    session = AsyncMock(commit=AsyncMock(), rollback=AsyncMock())
    if execute_result is not None:
        session.execute = execute_result

    @asynccontextmanager
    async def session_scope():
        yield session

    return session_scope, session
def _make_execute_returning_rows(rows):
    """Return an awaitable ``execute`` mock whose result exposes ``.all()``.

    For select queries: awaiting the returned mock yields a result object
    whose ``all()`` returns *rows* unchanged (the tests pass lists of
    1-tuples such as ``[("kb-1",), ("kb-2",)]``).
    """
    query_result = MagicMock(**{"all.return_value": rows})
    return AsyncMock(return_value=query_result)
def _make_execute_returning_scalars(rows):
    """Return an awaitable ``execute`` mock whose result exposes ``.scalars().all()``.

    For ORM select queries: awaiting the returned mock yields a result whose
    ``scalars()`` call returns an object whose ``all()`` returns *rows*.
    """
    scalar_view = MagicMock(**{"all.return_value": rows})
    query_result = MagicMock(**{"scalars.return_value": scalar_view})
    return AsyncMock(return_value=query_result)
class TestFilterKbByUserAcl:
    """Tests for ``filter_kb_by_user_acl`` (ACL filtering of KB ids).

    The database is replaced by mocks: ``session.execute`` returns the rows the
    ACL query would have produced. These tests pin how the function turns those
    rows into its result (empty input, exact contents, sort order).

    NOTE(review): the test methods are ``async def`` with no explicit marker;
    this presumably relies on pytest-asyncio (or a similar plugin) running in
    auto mode — confirm in the project's pytest configuration.
    """

    async def test_empty_kb_ids_returns_empty(self):
        """An empty ``all_kb_ids`` list yields an empty result."""
        sf, _ = _make_mock_session_factory()
        result = await filter_kb_by_user_acl(sf, "user1", [])
        assert result == []

    async def test_authorized_kb_returned(self):
        """Exactly the KBs the user is authorized for are returned."""
        # The mocked ACL query authorizes kb-1 and kb-2 only; kb-3 is unauthorized.
        mock_execute = _make_execute_returning_rows([("kb-1",), ("kb-2",)])
        sf, _ = _make_mock_session_factory(mock_execute)

        result = await filter_kb_by_user_acl(sf, "user1", ["kb-1", "kb-2", "kb-3"])

        # An exact comparison (rather than per-id membership checks) also
        # catches duplicated or unexpected extra ids; the expected order is the
        # sorted order pinned by test_result_is_sorted.
        assert result == ["kb-1", "kb-2"]
        mock_execute.assert_awaited()

    async def test_unauthorized_kb_filtered_out(self):
        """A user with no ACL entries gets no KBs back."""
        mock_execute = _make_execute_returning_rows([])  # no ACL entries
        sf, _ = _make_mock_session_factory(mock_execute)

        result = await filter_kb_by_user_acl(sf, "stranger", ["kb-1", "kb-2"])

        assert result == []

    async def test_result_is_sorted(self):
        """Returned KB ids are sorted, whatever order the query produced."""
        mock_execute = _make_execute_returning_rows([("kb-3",), ("kb-1",)])
        sf, _ = _make_mock_session_factory(mock_execute)

        result = await filter_kb_by_user_acl(sf, "user1", ["kb-1", "kb-2", "kb-3"])

        assert result == ["kb-1", "kb-3"]
class TestGrantAccess:
    """Tests for ``grant_access``."""

    async def test_grant_access_success(self):
        """A successful insert returns True and commits exactly once."""
        mock_execute = AsyncMock()  # the insert produces no result rows
        sf, mock_session = _make_mock_session_factory(mock_execute)

        result = await grant_access(sf, "kb-1", "user2", "viewer")

        assert result is True
        mock_execute.assert_awaited()
        mock_session.commit.assert_awaited_once()
        # The happy path must not roll anything back.
        mock_session.rollback.assert_not_awaited()

    async def test_grant_access_duplicate_returns_false(self):
        """A rejected insert (e.g. duplicate entry) returns False and rolls back."""
        # Simulates the database rejecting a duplicate ACL entry. A generic
        # Exception is raised, so this passes only if grant_access treats any
        # insert failure as "not granted".
        mock_execute = AsyncMock(side_effect=Exception("unique constraint"))
        sf, mock_session = _make_mock_session_factory(mock_execute)

        result = await grant_access(sf, "kb-1", "user2", "viewer")

        assert result is False
        mock_session.rollback.assert_awaited_once()
        # A failed insert must not be committed.
        mock_session.commit.assert_not_awaited()
class TestRevokeAccess:
    """Tests for ``revoke_access``.

    The mocked ``execute`` result carries a ``rowcount``: 1 simulates a deleted
    ACL entry, 0 simulates that no matching entry existed.
    """

    @staticmethod
    def _execute_with_rowcount(count):
        """Build an awaitable ``execute`` mock whose result reports *count* rows."""
        outcome = MagicMock(rowcount=count)
        return AsyncMock(return_value=outcome)

    async def test_revoke_access_success(self):
        """Revoking an existing entry returns True and commits once."""
        sf, session = _make_mock_session_factory(self._execute_with_rowcount(1))

        revoked = await revoke_access(sf, "kb-1", "user2")

        assert revoked is True
        session.commit.assert_awaited_once()

    async def test_revoke_access_not_found_returns_false(self):
        """Revoking an entry that does not exist returns False."""
        sf, _session = _make_mock_session_factory(self._execute_with_rowcount(0))

        revoked = await revoke_access(sf, "kb-1", "stranger")

        assert revoked is False
class TestListAcl:
    """Tests for ``list_acl``."""

    async def test_list_acl_returns_entries(self):
        """Each ORM row becomes a mapping with ``user_id`` and ``role`` keys, in order."""
        owner_row = MagicMock(user_id="user1", role="owner", created_at=None)
        viewer_row = MagicMock(user_id="user2", role="viewer", created_at=None)
        sf, _ = _make_mock_session_factory(
            _make_execute_returning_scalars([owner_row, viewer_row])
        )

        entries = await list_acl(sf, "kb-1")

        assert len(entries) == 2
        expected = [("user1", "owner"), ("user2", "viewer")]
        for entry, (user_id, role) in zip(entries, expected):
            assert entry["user_id"] == user_id
            assert entry["role"] == role

    async def test_list_acl_empty(self):
        """A KB with no ACL rows yields an empty list."""
        sf, _ = _make_mock_session_factory(_make_execute_returning_scalars([]))

        entries = await list_acl(sf, "kb-1")

        assert entries == []