"""U2 测试 — Per-KB ACL 逻辑。 测试场景: 1. owner 用户可查询自己的 KB 2. 非 authorized 用户查询 KB 被拒绝 3. Agent 检索时 ACL 过滤生效(仅返回授权 KB 的结果) 4. grant/revoke 操作 """ from __future__ import annotations from contextlib import asynccontextmanager from unittest.mock import AsyncMock, MagicMock from agentkit.rag_platform.acl import ( filter_kb_by_user_acl, grant_access, list_acl, revoke_access, ) def _make_mock_session_factory(execute_result=None): """创建 mock session factory。""" mock_session = AsyncMock() mock_session.commit = AsyncMock() mock_session.rollback = AsyncMock() if execute_result is not None: mock_session.execute = execute_result @asynccontextmanager async def factory(): yield mock_session return factory, mock_session def _make_execute_returning_rows(rows): """创建 mock execute 返回指定行(用于 select 查询)。""" mock_result = MagicMock() mock_result.all.return_value = rows return AsyncMock(return_value=mock_result) def _make_execute_returning_scalars(rows): """创建 mock execute 返回 scalars(用于 ORM select)。""" mock_result = MagicMock() mock_scalars = MagicMock() mock_scalars.all.return_value = rows mock_result.scalars.return_value = mock_scalars return AsyncMock(return_value=mock_result) class TestFilterKbByUserAcl: """ACL 过滤测试。""" async def test_empty_kb_ids_returns_empty(self): """all_kb_ids 为空时返回空列表。""" sf, _ = _make_mock_session_factory() result = await filter_kb_by_user_acl(sf, "user1", []) assert result == [] async def test_authorized_kb_returned(self): """用户有权限的 KB 被返回。""" mock_execute = _make_execute_returning_rows([("kb-1",), ("kb-2",)]) sf, _ = _make_mock_session_factory(mock_execute) result = await filter_kb_by_user_acl(sf, "user1", ["kb-1", "kb-2", "kb-3"]) assert "kb-1" in result assert "kb-2" in result assert "kb-3" not in result # 未授权 async def test_unauthorized_kb_filtered_out(self): """非 authorized 用户的 KB 被过滤掉。""" mock_execute = _make_execute_returning_rows([]) # 无 ACL 条目 sf, _ = _make_mock_session_factory(mock_execute) result = await filter_kb_by_user_acl(sf, "stranger", ["kb-1", "kb-2"]) assert result == [] async def test_result_is_sorted(self): """返回结果按 ID 排序。""" mock_execute = _make_execute_returning_rows([("kb-3",), ("kb-1",)]) sf, _ = _make_mock_session_factory(mock_execute) result = await filter_kb_by_user_acl(sf, "user1", ["kb-1", "kb-2", "kb-3"]) assert result == ["kb-1", "kb-3"] class TestGrantAccess: """授权测试。""" async def test_grant_access_success(self): """grant_access 成功授权。""" mock_execute = AsyncMock() # insert 不返回结果 sf, mock_session = _make_mock_session_factory(mock_execute) result = await grant_access(sf, "kb-1", "user2", "viewer") assert result is True mock_session.commit.assert_awaited_once() async def test_grant_access_duplicate_returns_false(self): """grant_access 已存在条目返回 False。""" mock_execute = AsyncMock(side_effect=Exception("unique constraint")) sf, mock_session = _make_mock_session_factory(mock_execute) result = await grant_access(sf, "kb-1", "user2", "viewer") assert result is False mock_session.rollback.assert_awaited_once() class TestRevokeAccess: """撤销授权测试。""" async def test_revoke_access_success(self): """revoke_access 成功撤销。""" mock_result = MagicMock() mock_result.rowcount = 1 mock_execute = AsyncMock(return_value=mock_result) sf, mock_session = _make_mock_session_factory(mock_execute) result = await revoke_access(sf, "kb-1", "user2") assert result is True mock_session.commit.assert_awaited_once() async def test_revoke_access_not_found_returns_false(self): """revoke_access 条目不存在返回 False。""" mock_result = MagicMock() mock_result.rowcount = 0 mock_execute = AsyncMock(return_value=mock_result) sf, mock_session = _make_mock_session_factory(mock_execute) result = await revoke_access(sf, "kb-1", "stranger") assert result is False class TestListAcl: """ACL 列表测试。""" async def test_list_acl_returns_entries(self): """list_acl 返回 ACL 条目列表。""" acl_row1 = MagicMock(user_id="user1", role="owner", created_at=None) acl_row2 = MagicMock(user_id="user2", role="viewer", created_at=None) mock_execute = _make_execute_returning_scalars([acl_row1, acl_row2]) sf, _ = _make_mock_session_factory(mock_execute) result = await list_acl(sf, "kb-1") assert len(result) == 2 assert result[0]["user_id"] == "user1" assert result[0]["role"] == "owner" assert result[1]["user_id"] == "user2" assert result[1]["role"] == "viewer" async def test_list_acl_empty(self): """list_acl 无条目返回空列表。""" mock_execute = _make_execute_returning_scalars([]) sf, _ = _make_mock_session_factory(mock_execute) result = await list_acl(sf, "kb-1") assert result == []