test(admin): U10 — E2E + security isolation + quota enforcement tests
23 integration tests across 3 files: - test_e2e_admin_flow: 5 end-to-end lifecycle tests (department, user, LLM config, skill management, usage dashboard) - test_security_isolation: 7 department isolation tests + non-admin 403 tests (cross-dept skill/KB access, multi-dept union, admin sees all, removed user loses access, disabled dept, API key client) - test_quota_enforcement: 10 quota tests (token/cost/whitelist limits, multi-dept strictest-wins, real gateway integration, usage recording) 418 admin tests pass, no regressions.
This commit is contained in:
parent
e5a92427a4
commit
5e977539c7
|
|
@ -0,0 +1,560 @@
|
|||
"""End-to-end admin workflow integration tests (U10).
|
||||
|
||||
Verifies complete admin workflows end-to-end via the FastAPI TestClient:
|
||||
|
||||
- Department lifecycle: create → update → bind skill → bind KB →
|
||||
disable → enable → unbind skill → delete.
|
||||
- User lifecycle: create → assign department → reset password →
|
||||
remove department → soft delete.
|
||||
- LLM config lifecycle: add provider → set API key → set fallback →
|
||||
set quota → list → delete provider.
|
||||
- Skill management flow: disable skill → verify excluded from
|
||||
``GET /skills`` → enable → verify included.
|
||||
- Usage dashboard flow: query usage endpoints with empty data →
|
||||
verify empty results + CSV header.
|
||||
|
||||
The tests mount a minimal FastAPI app with the ``admin_router`` (and
|
||||
the public ``skills`` router for the skill-management flow). The
|
||||
``_require_admin`` dependency is overridden so the tests don't need
|
||||
real JWTs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import io
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from agentkit.server.admin.context import DepartmentContext
|
||||
from agentkit.server.admin.kb_service import set_kb_service
|
||||
from agentkit.server.admin.llm_config_service import (
|
||||
LlmConfigService,
|
||||
set_llm_config_service,
|
||||
)
|
||||
from agentkit.server.admin.skill_service import set_skill_service
|
||||
from agentkit.server.admin.usage_service import set_usage_service
|
||||
from agentkit.server.auth.models import init_auth_db
|
||||
from agentkit.server.auth.session_service import SessionService, set_session_service
|
||||
from agentkit.server.routes import admin as admin_routes_module
|
||||
from agentkit.server.routes import skills as skills_routes_module
|
||||
from agentkit.skills.registry import SkillRegistry
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test data
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_VALID_SKILL_YAML = """\
|
||||
name: e2e_test_skill
|
||||
agent_type: simple_generation
|
||||
version: "1.0.0"
|
||||
description: "E2E test skill"
|
||||
task_mode: llm_generate
|
||||
execution_mode: direct
|
||||
max_steps: 1
|
||||
prompt:
|
||||
identity: "E2E"
|
||||
instructions: "Handle test"
|
||||
tools: []
|
||||
"""
|
||||
|
||||
|
||||
def _sample_agentkit_config() -> dict[str, Any]:
|
||||
"""A minimal agentkit.yaml-style config for testing."""
|
||||
return {
|
||||
"server": {"host": "0.0.0.0", "port": 8001},
|
||||
"llm": {
|
||||
"providers": {
|
||||
"openai": {
|
||||
"type": "openai",
|
||||
"api_key": "sk-test-12345678",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"models": {"gpt-4o": {}},
|
||||
"max_tokens": 4096,
|
||||
"timeout": 120.0,
|
||||
},
|
||||
},
|
||||
"model_aliases": {"gpt4": "openai/gpt-4o"},
|
||||
"fallbacks": {},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def tmp_auth_db(tmp_path: Path) -> Path:
|
||||
db_path = tmp_path / "e2e_admin.db"
|
||||
await init_auth_db(db_path)
|
||||
return db_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_agentkit_yaml(tmp_path: Path) -> Path:
|
||||
"""Create a temporary agentkit.yaml config file."""
|
||||
path = tmp_path / "agentkit.yaml"
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
yaml.dump(_sample_agentkit_config(), f, default_flow_style=False, allow_unicode=True)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_skills_dir(tmp_path: Path) -> str:
|
||||
"""A temp skills directory for YAML files."""
|
||||
d = tmp_path / "skills"
|
||||
d.mkdir()
|
||||
return str(d)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def skill_registry() -> SkillRegistry:
|
||||
return SkillRegistry()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_service(tmp_auth_db: Path):
|
||||
"""Install a SessionService singleton backed by the temp DB.
|
||||
|
||||
Required so that ``UserService.reset_password`` can find the
|
||||
SessionService via ``get_session_service()`` and revoke sessions.
|
||||
"""
|
||||
svc = SessionService(db_path=tmp_auth_db)
|
||||
set_session_service(svc)
|
||||
yield svc
|
||||
set_session_service(None)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_singletons():
|
||||
"""Reset singletons before/after each test to avoid state leakage."""
|
||||
set_llm_config_service(None)
|
||||
set_skill_service(None)
|
||||
set_kb_service(None)
|
||||
set_usage_service(None)
|
||||
yield
|
||||
set_llm_config_service(None)
|
||||
set_skill_service(None)
|
||||
set_kb_service(None)
|
||||
set_usage_service(None)
|
||||
|
||||
|
||||
def _make_admin_user() -> dict[str, Any]:
|
||||
return {"user_id": "admin-1", "username": "admin", "role": "admin"}
|
||||
|
||||
|
||||
def _raise_forbidden() -> dict[str, Any]:
|
||||
raise HTTPException(status_code=403, detail="Admin permission required")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_app(
|
||||
tmp_auth_db: Path,
|
||||
tmp_agentkit_yaml: Path,
|
||||
tmp_skills_dir: str,
|
||||
skill_registry: SkillRegistry,
|
||||
) -> FastAPI:
|
||||
"""A FastAPI app with admin + skills routers mounted.
|
||||
|
||||
The ``_require_admin`` dependency is overridden to return a fake
|
||||
admin user. The :class:`LlmConfigService` singleton is pre-populated
|
||||
so the routes use the temp YAML file.
|
||||
"""
|
||||
set_llm_config_service(LlmConfigService(tmp_agentkit_yaml))
|
||||
|
||||
app = FastAPI()
|
||||
app.state.auth_db_path = str(tmp_auth_db)
|
||||
app.state.skill_registry = skill_registry
|
||||
|
||||
class _FakeServerConfig:
|
||||
skill_paths = [tmp_skills_dir]
|
||||
|
||||
app.state.server_config = _FakeServerConfig()
|
||||
|
||||
app.include_router(admin_routes_module.admin_router, prefix="/api/v1")
|
||||
app.include_router(skills_routes_module.router, prefix="/api/v1")
|
||||
|
||||
# Default: allow admin access.
|
||||
app.dependency_overrides[admin_routes_module._require_admin] = lambda: _make_admin_user()
|
||||
# Admin context for skills route (bypass department filtering).
|
||||
app.dependency_overrides[skills_routes_module.get_department_context] = lambda: (
|
||||
DepartmentContext(user_id="admin-1", department_ids=[], is_admin=True)
|
||||
)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_client(
|
||||
admin_app: FastAPI,
|
||||
session_service: SessionService,
|
||||
) -> TestClient:
|
||||
"""TestClient with admin access and SessionService installed."""
|
||||
return TestClient(admin_app)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _create_department(client: TestClient, name: str, description: str = "") -> dict:
|
||||
resp = client.post(
|
||||
"/api/v1/admin/departments",
|
||||
json={"name": name, "description": description},
|
||||
)
|
||||
assert resp.status_code == 201, resp.text
|
||||
return resp.json()
|
||||
|
||||
|
||||
def _create_user(
|
||||
client: TestClient,
|
||||
*,
|
||||
username: str,
|
||||
email: str,
|
||||
password: str = "Secret123!",
|
||||
role: str = "member",
|
||||
department_ids: list[str] | None = None,
|
||||
) -> dict:
|
||||
payload: dict[str, Any] = {
|
||||
"username": username,
|
||||
"email": email,
|
||||
"password": password,
|
||||
"role": role,
|
||||
}
|
||||
if department_ids is not None:
|
||||
payload["department_ids"] = department_ids
|
||||
resp = client.post("/api/v1/admin/users", json=payload)
|
||||
assert resp.status_code == 201, resp.text
|
||||
return resp.json()
|
||||
|
||||
|
||||
def _write_skill_yaml(skills_dir: str, name: str, content: str) -> str:
|
||||
path = os.path.join(skills_dir, f"{name}.yaml")
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
return path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# E2E admin flow tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestE2EAdminFlow:
|
||||
"""End-to-end admin workflow test."""
|
||||
|
||||
def test_full_department_lifecycle(
|
||||
self,
|
||||
admin_client: TestClient,
|
||||
tmp_skills_dir: str,
|
||||
skill_registry: SkillRegistry,
|
||||
):
|
||||
"""Create department → update → bind skill → bind KB → disable →
|
||||
enable → unbind skill → delete."""
|
||||
# 1. Create department "HR".
|
||||
hr = _create_department(admin_client, "HR", "Human Resources")
|
||||
dept_id = hr["id"]
|
||||
assert hr["name"] == "HR"
|
||||
assert hr["is_active"] is True
|
||||
|
||||
# 2. GET /admin/departments → HR in list.
|
||||
resp = admin_client.get("/api/v1/admin/departments")
|
||||
assert resp.status_code == 200
|
||||
names = {d["name"] for d in resp.json()}
|
||||
assert "HR" in names
|
||||
|
||||
# 3. PATCH /admin/departments/{id} → update name to "Human Resources".
|
||||
resp = admin_client.patch(
|
||||
f"/api/v1/admin/departments/{dept_id}",
|
||||
json={"name": "Human Resources", "description": "HR department"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["name"] == "Human Resources"
|
||||
|
||||
# 4. POST /admin/departments/{id}/skills/code_reviewer → bind skill.
|
||||
resp = admin_client.post(f"/api/v1/admin/departments/{dept_id}/skills/code_reviewer")
|
||||
assert resp.status_code == 201
|
||||
assert resp.json()["skill_name"] == "code_reviewer"
|
||||
|
||||
# 5. GET /admin/departments/{id}/skills → ["code_reviewer"].
|
||||
resp = admin_client.get(f"/api/v1/admin/departments/{dept_id}/skills")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == ["code_reviewer"]
|
||||
|
||||
# 6. POST /admin/departments/{id}/kb/source-1 → bind KB.
|
||||
resp = admin_client.post(f"/api/v1/admin/departments/{dept_id}/kb/source-1")
|
||||
assert resp.status_code == 201
|
||||
assert resp.json()["kb_source_id"] == "source-1"
|
||||
|
||||
# 7. GET /admin/departments/{id}/kb → ["source-1"].
|
||||
resp = admin_client.get(f"/api/v1/admin/departments/{dept_id}/kb")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == ["source-1"]
|
||||
|
||||
# 8. POST /admin/departments/{id}/disable → disabled.
|
||||
resp = admin_client.post(f"/api/v1/admin/departments/{dept_id}/disable")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["is_active"] is False
|
||||
|
||||
# 9. POST /admin/departments/{id}/enable → enabled.
|
||||
resp = admin_client.post(f"/api/v1/admin/departments/{dept_id}/enable")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["is_active"] is True
|
||||
|
||||
# 10. DELETE /admin/departments/{id}/skills/code_reviewer → unbound.
|
||||
resp = admin_client.delete(f"/api/v1/admin/departments/{dept_id}/skills/code_reviewer")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"unbound": True}
|
||||
assert admin_client.get(f"/api/v1/admin/departments/{dept_id}/skills").json() == []
|
||||
|
||||
# 11. DELETE /admin/departments/{id} → deleted.
|
||||
resp = admin_client.delete(f"/api/v1/admin/departments/{dept_id}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"deleted": True}
|
||||
assert admin_client.get(f"/api/v1/admin/departments/{dept_id}").status_code == 404
|
||||
|
||||
def test_full_user_lifecycle(self, admin_client: TestClient):
|
||||
"""Create user → assign department → reset password → remove
|
||||
department → delete."""
|
||||
# 1. Create department "Engineering".
|
||||
eng = _create_department(admin_client, "Engineering")
|
||||
eng_id = eng["id"]
|
||||
|
||||
# 2. POST /admin/users → create user "alice".
|
||||
alice = _create_user(
|
||||
admin_client,
|
||||
username="alice",
|
||||
email="alice@example.com",
|
||||
)
|
||||
alice_id = alice["id"]
|
||||
assert alice["username"] == "alice"
|
||||
assert alice["departments"] == []
|
||||
|
||||
# 3. POST /admin/users/{id}/departments/{dept_id} → assign.
|
||||
resp = admin_client.post(f"/api/v1/admin/users/{alice_id}/departments/{eng_id}")
|
||||
assert resp.status_code == 201
|
||||
assert resp.json() == {"assigned": True}
|
||||
|
||||
# 4. GET /admin/users/{id} → has department.
|
||||
resp = admin_client.get(f"/api/v1/admin/users/{alice_id}")
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert len(body["departments"]) == 1
|
||||
assert body["departments"][0]["name"] == "Engineering"
|
||||
|
||||
# 5. POST /admin/users/{id}/reset-password → reset.
|
||||
resp = admin_client.post(
|
||||
f"/api/v1/admin/users/{alice_id}/reset-password",
|
||||
json={"new_password": "NewSecret456!"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"reset": True}
|
||||
|
||||
# 6. DELETE /admin/users/{id}/departments/{dept_id} → remove.
|
||||
resp = admin_client.delete(f"/api/v1/admin/users/{alice_id}/departments/{eng_id}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"removed": True}
|
||||
|
||||
# Verify removal.
|
||||
resp = admin_client.get(f"/api/v1/admin/users/{alice_id}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["departments"] == []
|
||||
|
||||
# 7. DELETE /admin/users/{id} → soft delete.
|
||||
resp = admin_client.delete(f"/api/v1/admin/users/{alice_id}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"deleted": True}
|
||||
|
||||
# Second delete on the now-inactive user returns 404.
|
||||
resp = admin_client.delete(f"/api/v1/admin/users/{alice_id}")
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_llm_config_lifecycle(
|
||||
self,
|
||||
admin_client: TestClient,
|
||||
tmp_agentkit_yaml: Path,
|
||||
):
|
||||
"""Add provider → set API key → set fallback → set quota →
|
||||
delete provider."""
|
||||
# 1. POST /admin/llm/providers → create "test-provider".
|
||||
resp = admin_client.post(
|
||||
"/api/v1/admin/llm/providers",
|
||||
json={
|
||||
"name": "test-provider",
|
||||
"type": "openai",
|
||||
"api_key": "sk-test-abcdef1234",
|
||||
"base_url": "https://api.test.com",
|
||||
"models": {"test-model": {}},
|
||||
"max_tokens": 2048,
|
||||
"timeout": 60.0,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201, resp.text
|
||||
assert resp.json()["name"] == "test-provider"
|
||||
assert resp.json()["api_key"].startswith("****")
|
||||
|
||||
# 2. POST /admin/llm/providers/test-provider/api-key → set key.
|
||||
resp = admin_client.post(
|
||||
"/api/v1/admin/llm/providers/test-provider/api-key",
|
||||
json={"api_key": "sk-brand-new-key-9876"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["api_key"].startswith("****")
|
||||
|
||||
# 3. PUT /admin/llm/fallbacks/gpt-4 → set fallback chain.
|
||||
resp = admin_client.put(
|
||||
"/api/v1/admin/llm/fallbacks/gpt-4",
|
||||
json={"chain": ["openai", "test-provider"]},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["model"] == "gpt-4"
|
||||
assert body["chain"] == ["openai", "test-provider"]
|
||||
|
||||
# 4. PUT /admin/departments/{id}/quotas → set token_limit.
|
||||
dept_id = str(uuid.uuid4())
|
||||
resp = admin_client.put(
|
||||
f"/api/v1/admin/departments/{dept_id}/quotas",
|
||||
json={
|
||||
"quota_type": "token_limit",
|
||||
"limit_value": 1000,
|
||||
"period": "daily",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["quota_type"] == "token_limit"
|
||||
assert body["limit_value"] == 1000
|
||||
|
||||
# 5. GET /admin/llm/providers → provider in list.
|
||||
resp = admin_client.get("/api/v1/admin/llm/providers")
|
||||
assert resp.status_code == 200
|
||||
names = {p["name"] for p in resp.json()}
|
||||
assert "test-provider" in names
|
||||
assert "openai" in names
|
||||
|
||||
# 6. DELETE /admin/llm/providers/test-provider → fails (used in fallback).
|
||||
resp = admin_client.delete("/api/v1/admin/llm/providers/test-provider")
|
||||
assert resp.status_code == 400
|
||||
assert "fallback" in resp.json()["detail"].lower()
|
||||
|
||||
# Remove the fallback chain first, then delete.
|
||||
admin_client.delete("/api/v1/admin/llm/fallbacks/gpt-4")
|
||||
resp = admin_client.delete("/api/v1/admin/llm/providers/test-provider")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"deleted": True}
|
||||
|
||||
# Confirm it's gone.
|
||||
resp = admin_client.get("/api/v1/admin/llm/providers")
|
||||
names = {p["name"] for p in resp.json()}
|
||||
assert "test-provider" not in names
|
||||
|
||||
def test_skill_management_flow(
|
||||
self,
|
||||
admin_client: TestClient,
|
||||
tmp_skills_dir: str,
|
||||
skill_registry: SkillRegistry,
|
||||
):
|
||||
"""Disable skill → verify excluded → enable → verify included."""
|
||||
# 1. Create a test skill YAML in tmp_skills_dir and register it.
|
||||
_write_skill_yaml(tmp_skills_dir, "e2e_test_skill", _VALID_SKILL_YAML)
|
||||
from agentkit.skills.loader import SkillLoader
|
||||
|
||||
SkillLoader(skill_registry=skill_registry).load_from_file(
|
||||
os.path.join(tmp_skills_dir, "e2e_test_skill.yaml")
|
||||
)
|
||||
|
||||
# Verify the skill is initially listed.
|
||||
resp = admin_client.get("/api/v1/skills")
|
||||
assert resp.status_code == 200
|
||||
names = [s["name"] for s in resp.json()]
|
||||
assert "e2e_test_skill" in names
|
||||
|
||||
# 2. POST /admin/skills/{name}/disable → disabled.
|
||||
resp = admin_client.post("/api/v1/admin/skills/e2e_test_skill/disable")
|
||||
assert resp.status_code == 200, resp.text
|
||||
body = resp.json()
|
||||
assert body["skill_name"] == "e2e_test_skill"
|
||||
assert body["is_disabled"] is True
|
||||
|
||||
# 3. Verify via GET /skills that the skill is excluded.
|
||||
resp = admin_client.get("/api/v1/skills")
|
||||
assert resp.status_code == 200
|
||||
names = [s["name"] for s in resp.json()]
|
||||
assert "e2e_test_skill" not in names
|
||||
|
||||
# 4. POST /admin/skills/{name}/enable → enabled.
|
||||
resp = admin_client.post("/api/v1/admin/skills/e2e_test_skill/enable")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["enabled"] is True
|
||||
|
||||
# Verify via GET /skills that the skill is included again.
|
||||
resp = admin_client.get("/api/v1/skills")
|
||||
assert resp.status_code == 200
|
||||
names = [s["name"] for s in resp.json()]
|
||||
assert "e2e_test_skill" in names
|
||||
|
||||
def test_usage_dashboard_flow(self, admin_client: TestClient):
|
||||
"""Query usage endpoints with empty data → verify empty results."""
|
||||
# The admin_app fixture doesn't install an llm_gateway on app.state,
|
||||
# so the usage routes will return 500. We need to install a stub
|
||||
# gateway with an empty InMemoryUsageStore.
|
||||
from agentkit.llm.providers.usage_store import InMemoryUsageStore
|
||||
|
||||
class _StubTracker:
|
||||
def __init__(self, store):
|
||||
self.store = store
|
||||
|
||||
class _StubGateway:
|
||||
def __init__(self, store):
|
||||
self._usage_tracker = _StubTracker(store)
|
||||
|
||||
store = InMemoryUsageStore()
|
||||
# Install the stub gateway on the app state.
|
||||
admin_client.app.state.llm_gateway = _StubGateway(store)
|
||||
|
||||
# 1. GET /admin/usage/summary → 200 with zeros.
|
||||
resp = admin_client.get("/api/v1/admin/usage/summary")
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["total_tokens"] == 0
|
||||
assert body["total_cost"] == 0.0
|
||||
assert body["total_requests"] == 0
|
||||
|
||||
# 2. GET /admin/usage/timeseries → 200 with empty list.
|
||||
resp = admin_client.get("/api/v1/admin/usage/timeseries")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
# 3. GET /admin/usage/by-model → 200 with empty list.
|
||||
resp = admin_client.get("/api/v1/admin/usage/by-model")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
# 4. GET /admin/usage/top-users → 200 with empty list.
|
||||
resp = admin_client.get("/api/v1/admin/usage/top-users")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
# 5. GET /admin/usage/export?format=csv → 200 with CSV header.
|
||||
resp = admin_client.get("/api/v1/admin/usage/export", params={"format": "csv"})
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["content-type"].startswith("text/csv")
|
||||
reader = csv.DictReader(io.StringIO(resp.text))
|
||||
rows = list(reader)
|
||||
assert rows == []
|
||||
# Header should still be present.
|
||||
assert "timestamp" in resp.text
|
||||
assert "user_id" in resp.text
|
||||
assert "department_id" in resp.text
|
||||
|
|
@ -0,0 +1,425 @@
|
|||
"""Quota enforcement integration tests (U10).
|
||||
|
||||
Verifies quota enforcement end-to-end through the LLMGateway:
|
||||
|
||||
- Token limit exceeded → QuotaExceededError raised.
|
||||
- Cost limit exceeded → QuotaExceededError raised.
|
||||
- Model not in whitelist → QuotaExceededError raised.
|
||||
- No quota set → request allowed.
|
||||
- Multi-department: strictest-wins (one exceeds, other doesn't → rejected).
|
||||
- Integration test with real LLMGateway + mock provider + InMemoryUsageStore.
|
||||
|
||||
These tests use a real :class:`LLMGateway` with a :class:`FakeProvider`
|
||||
(mock LLM provider) and a real :class:`QuotaService` backed by a temp
|
||||
SQLite auth DB. No external services (Redis, real LLM API) are required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.llm.gateway import LLMGateway, QuotaExceededError
|
||||
from agentkit.llm.protocol import (
|
||||
LLMProvider,
|
||||
LLMRequest,
|
||||
LLMResponse,
|
||||
TokenUsage,
|
||||
)
|
||||
from agentkit.llm.providers.usage_store import InMemoryUsageStore
|
||||
from agentkit.server.admin.quota_service import (
|
||||
get_quota_service,
|
||||
set_quota_service,
|
||||
)
|
||||
from agentkit.server.auth.models import init_auth_db
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test doubles
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class FakeProvider(LLMProvider):
|
||||
"""A minimal LLMProvider that returns a fixed response.
|
||||
|
||||
The response usage (prompt_tokens, completion_tokens) can be
|
||||
customized per-instance to simulate different token consumption.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str = "fake",
|
||||
prompt_tokens: int = 100,
|
||||
completion_tokens: int = 50,
|
||||
) -> None:
|
||||
self._name = name
|
||||
self._prompt_tokens = prompt_tokens
|
||||
self._completion_tokens = completion_tokens
|
||||
self.last_request: LLMRequest | None = None
|
||||
self.call_count = 0
|
||||
|
||||
async def chat(self, request: LLMRequest) -> LLMResponse:
|
||||
self.last_request = request
|
||||
self.call_count += 1
|
||||
return LLMResponse(
|
||||
content=f"response from {self._name}",
|
||||
model=request.model,
|
||||
usage=TokenUsage(
|
||||
prompt_tokens=self._prompt_tokens,
|
||||
completion_tokens=self._completion_tokens,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store() -> InMemoryUsageStore:
|
||||
return InMemoryUsageStore()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gateway(store: InMemoryUsageStore) -> LLMGateway:
|
||||
"""A real LLMGateway with a FakeProvider registered as "openai"."""
|
||||
gw = LLMGateway(usage_store=store)
|
||||
gw.register_provider("openai", FakeProvider("openai"))
|
||||
return gw
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def fresh_db(tmp_path: Path) -> Path:
|
||||
"""A brand-new auth DB on a fresh path (no data)."""
|
||||
db_path = tmp_path / "quota_enforcement.db"
|
||||
await init_auth_db(db_path)
|
||||
return db_path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_quota_singleton():
|
||||
"""Reset the QuotaService singleton before and after each test."""
|
||||
set_quota_service(None)
|
||||
yield
|
||||
set_quota_service(None)
|
||||
|
||||
|
||||
def _random_dept_id() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quota enforcement tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQuotaEnforcement:
|
||||
"""Tests for quota enforcement in LLM calls."""
|
||||
|
||||
async def test_token_limit_blocks_request(
|
||||
self, gateway: LLMGateway, store: InMemoryUsageStore, fresh_db: Path
|
||||
):
|
||||
"""When department exceeds token limit, LLM call raises QuotaExceededError."""
|
||||
dept_id = _random_dept_id()
|
||||
svc = get_quota_service()
|
||||
# Set a 100-token daily limit.
|
||||
await svc.set_quota(fresh_db, dept_id, "token_limit", 100, period="daily")
|
||||
|
||||
# Pre-populate usage with 90 tokens (just under the limit).
|
||||
gateway._usage_tracker.record(
|
||||
agent_name="prev",
|
||||
model="openai/gpt-4o",
|
||||
usage=TokenUsage(prompt_tokens=60, completion_tokens=30),
|
||||
cost=0.0,
|
||||
latency_ms=10,
|
||||
user_id="u1",
|
||||
department_id=dept_id,
|
||||
)
|
||||
|
||||
# The FakeProvider would use 100+50=150 tokens, but the quota
|
||||
# check happens BEFORE the provider call. Since current usage
|
||||
# (90) + nothing is checked — the gateway checks current_usage
|
||||
# >= limit, which is 90 < 100, so this would actually pass.
|
||||
#
|
||||
# To force a block, we pre-populate usage AT the limit (100
|
||||
# tokens). The check is `current_usage >= limit`, so 100 >= 100
|
||||
# → blocked.
|
||||
store._records.clear()
|
||||
gateway._usage_tracker.record(
|
||||
agent_name="prev",
|
||||
model="openai/gpt-4o",
|
||||
usage=TokenUsage(prompt_tokens=70, completion_tokens=30),
|
||||
cost=0.0,
|
||||
latency_ms=10,
|
||||
user_id="u1",
|
||||
department_id=dept_id,
|
||||
)
|
||||
|
||||
with pytest.raises(QuotaExceededError) as exc_info:
|
||||
await gateway.chat(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
model="openai/gpt-4o",
|
||||
user_id="u1",
|
||||
department_ids=[dept_id],
|
||||
db_path=fresh_db,
|
||||
)
|
||||
err = exc_info.value
|
||||
assert err.department_id == dept_id
|
||||
assert err.quota_type == "token_limit"
|
||||
assert err.period == "daily"
|
||||
assert err.limit == 100
|
||||
assert err.current == 100 # 70 prompt + 30 completion
|
||||
|
||||
async def test_cost_limit_blocks_request(
|
||||
self, gateway: LLMGateway, store: InMemoryUsageStore, fresh_db: Path
|
||||
):
|
||||
"""When department exceeds cost limit, LLM call raises QuotaExceededError."""
|
||||
dept_id = _random_dept_id()
|
||||
svc = get_quota_service()
|
||||
# cost_limit is in cents. Set 100 cents ($1.00) daily limit.
|
||||
await svc.set_quota(fresh_db, dept_id, "cost_limit", 100, period="daily")
|
||||
|
||||
# Pre-populate usage with $1.50 cost = 150 cents, exceeding the
|
||||
# 100-cent limit.
|
||||
gateway._usage_tracker.record(
|
||||
agent_name="prev",
|
||||
model="openai/gpt-4o",
|
||||
usage=TokenUsage(prompt_tokens=100, completion_tokens=50),
|
||||
cost=1.50, # $1.50 = 150 cents
|
||||
latency_ms=10,
|
||||
user_id="u1",
|
||||
department_id=dept_id,
|
||||
)
|
||||
|
||||
with pytest.raises(QuotaExceededError) as exc_info:
|
||||
await gateway.chat(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
model="openai/gpt-4o",
|
||||
user_id="u1",
|
||||
department_ids=[dept_id],
|
||||
db_path=fresh_db,
|
||||
)
|
||||
err = exc_info.value
|
||||
assert err.quota_type == "cost_limit"
|
||||
assert err.period == "daily"
|
||||
assert err.limit == 100
|
||||
# current is in cents (150 cents = $1.50).
|
||||
assert err.current == 150.0
|
||||
|
||||
async def test_model_whitelist_blocks_unlisted_model(self, gateway: LLMGateway, fresh_db: Path):
|
||||
"""When model not in whitelist, LLM call is rejected."""
|
||||
dept_id = _random_dept_id()
|
||||
svc = get_quota_service()
|
||||
# Whitelist only allows "claude" — gateway is calling "gpt-4o".
|
||||
await svc.set_quota(fresh_db, dept_id, "model_whitelist", ["claude"], period="daily")
|
||||
|
||||
with pytest.raises(QuotaExceededError) as exc_info:
|
||||
await gateway.chat(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
model="openai/gpt-4o",
|
||||
user_id="u1",
|
||||
department_ids=[dept_id],
|
||||
db_path=fresh_db,
|
||||
)
|
||||
err = exc_info.value
|
||||
assert err.quota_type == "model_whitelist"
|
||||
assert err.department_id == dept_id
|
||||
# For model_whitelist, current is the rejected model name.
|
||||
assert err.current == "openai/gpt-4o"
|
||||
|
||||
async def test_no_quota_allows_all(self, gateway: LLMGateway, fresh_db: Path):
|
||||
"""Without any quota set, all requests are allowed."""
|
||||
dept_id = _random_dept_id()
|
||||
# No quota set — request should succeed.
|
||||
response = await gateway.chat(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
model="openai/gpt-4o",
|
||||
user_id="u1",
|
||||
department_ids=[dept_id],
|
||||
db_path=fresh_db,
|
||||
)
|
||||
assert response.content == "response from openai"
|
||||
assert response.usage.total_tokens == 150 # 100 prompt + 50 completion
|
||||
|
||||
async def test_multi_department_strictest_wins(
|
||||
self, gateway: LLMGateway, store: InMemoryUsageStore, fresh_db: Path
|
||||
):
|
||||
"""User in depts A+B: A has quota, B doesn't → A's quota applies.
|
||||
|
||||
Strictest-wins: if ANY department fails ANY check, the request
|
||||
is rejected.
|
||||
"""
|
||||
dept_a = _random_dept_id()
|
||||
dept_b = _random_dept_id()
|
||||
svc = get_quota_service()
|
||||
# Set a 1-token limit on dept A only; dept B has no quota.
|
||||
await svc.set_quota(fresh_db, dept_a, "token_limit", 1, period="daily")
|
||||
|
||||
# Pre-populate usage for dept A so it exceeds the 1-token limit.
|
||||
gateway._usage_tracker.record(
|
||||
agent_name="prev",
|
||||
model="openai/gpt-4o",
|
||||
usage=TokenUsage(prompt_tokens=10, completion_tokens=5),
|
||||
cost=0.0,
|
||||
latency_ms=10,
|
||||
user_id="u1",
|
||||
department_id=dept_a,
|
||||
)
|
||||
|
||||
with pytest.raises(QuotaExceededError) as exc_info:
|
||||
await gateway.chat(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
model="openai/gpt-4o",
|
||||
user_id="u1",
|
||||
department_ids=[dept_a, dept_b],
|
||||
db_path=fresh_db,
|
||||
)
|
||||
# The error should reference dept_a (the one that exceeded).
|
||||
assert exc_info.value.department_id == dept_a
|
||||
assert exc_info.value.quota_type == "token_limit"
|
||||
|
||||
async def test_quota_check_with_real_gateway(
|
||||
self, gateway: LLMGateway, store: InMemoryUsageStore, fresh_db: Path
|
||||
):
|
||||
"""Integration test with real LLMGateway + mock provider.
|
||||
|
||||
Verifies the full flow:
|
||||
1. Quota check happens before the provider call.
|
||||
2. On success, usage is recorded with the correct department_id.
|
||||
3. The usage record carries user_id + department_id.
|
||||
"""
|
||||
dept_id = _random_dept_id()
|
||||
svc = get_quota_service()
|
||||
# Set a generous token limit (1M tokens) — should not block.
|
||||
await svc.set_quota(fresh_db, dept_id, "token_limit", 1_000_000, period="daily")
|
||||
|
||||
# Make the LLM call.
|
||||
response = await gateway.chat(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
model="openai/gpt-4o",
|
||||
user_id="u1",
|
||||
department_ids=[dept_id],
|
||||
db_path=fresh_db,
|
||||
)
|
||||
assert response.content == "response from openai"
|
||||
|
||||
# Verify usage was recorded with the correct attributes.
|
||||
summary = store.get_usage()
|
||||
assert len(summary.records) == 1
|
||||
rec = summary.records[0]
|
||||
assert rec.user_id == "u1"
|
||||
assert rec.department_id == dept_id
|
||||
assert rec.model == "gpt-4o"
|
||||
assert rec.total_tokens == 150 # 100 prompt + 50 completion
|
||||
|
||||
# Verify the quota check counted this usage (next call should
|
||||
# still pass since the limit is 1M tokens).
|
||||
response2 = await gateway.chat(
|
||||
messages=[{"role": "user", "content": "hi again"}],
|
||||
model="openai/gpt-4o",
|
||||
user_id="u1",
|
||||
department_ids=[dept_id],
|
||||
db_path=fresh_db,
|
||||
)
|
||||
assert response2.content == "response from openai"
|
||||
|
||||
# Now there should be 2 usage records.
|
||||
summary = store.get_usage()
|
||||
assert len(summary.records) == 2
|
||||
# All records should carry the department_id.
|
||||
assert all(r.department_id == dept_id for r in summary.records)
|
||||
|
||||
async def test_quota_check_skipped_without_db_path(self, gateway: LLMGateway, fresh_db: Path):
|
||||
"""When db_path is None, no quota check is performed."""
|
||||
dept_id = _random_dept_id()
|
||||
svc = get_quota_service()
|
||||
# Set a tiny quota that would normally block.
|
||||
await svc.set_quota(fresh_db, dept_id, "token_limit", 1, period="daily")
|
||||
|
||||
# Call without db_path — should succeed (no quota check).
|
||||
response = await gateway.chat(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
model="openai/gpt-4o",
|
||||
user_id="u1",
|
||||
department_ids=[dept_id],
|
||||
db_path=None,
|
||||
)
|
||||
assert response.content == "response from openai"
|
||||
|
||||
async def test_quota_check_skipped_without_department_ids(
|
||||
self, gateway: LLMGateway, fresh_db: Path
|
||||
):
|
||||
"""When department_ids is None, no quota check is performed."""
|
||||
response = await gateway.chat(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
model="openai/gpt-4o",
|
||||
user_id="u1",
|
||||
department_ids=None,
|
||||
db_path=fresh_db,
|
||||
)
|
||||
assert response.content == "response from openai"
|
||||
|
||||
async def test_model_whitelist_allows_listed_model(self, gateway: LLMGateway, fresh_db: Path):
|
||||
"""Model in whitelist → request allowed."""
|
||||
dept_id = _random_dept_id()
|
||||
svc = get_quota_service()
|
||||
# Whitelist uses the full resolved model identifier (provider/model).
|
||||
await svc.set_quota(
|
||||
fresh_db,
|
||||
dept_id,
|
||||
"model_whitelist",
|
||||
["openai/gpt-4o"],
|
||||
period="daily",
|
||||
)
|
||||
response = await gateway.chat(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
model="openai/gpt-4o",
|
||||
user_id="u1",
|
||||
department_ids=[dept_id],
|
||||
db_path=fresh_db,
|
||||
)
|
||||
assert response.content == "response from openai"
|
||||
|
||||
async def test_quota_check_uses_correct_period_window(
|
||||
self, gateway: LLMGateway, store: InMemoryUsageStore, fresh_db: Path
|
||||
):
|
||||
"""Quota check uses the daily window (since 00:00 UTC today).
|
||||
|
||||
The quota check happens BEFORE the LLM call, using the current
|
||||
accumulated usage. So:
|
||||
- 1st call: usage=0, check 0 >= 150 → False → allowed. After
|
||||
the call, usage=150.
|
||||
- 2nd call: usage=150, check 150 >= 150 → True → blocked.
|
||||
"""
|
||||
dept_id = _random_dept_id()
|
||||
svc = get_quota_service()
|
||||
# Set a 150-token daily limit (the FakeProvider uses 150 tokens
|
||||
# per call: 100 prompt + 50 completion).
|
||||
await svc.set_quota(fresh_db, dept_id, "token_limit", 150, period="daily")
|
||||
|
||||
# First call: current usage is 0, under the 150 limit → allowed.
|
||||
response = await gateway.chat(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
model="openai/gpt-4o",
|
||||
user_id="u1",
|
||||
department_ids=[dept_id],
|
||||
db_path=fresh_db,
|
||||
)
|
||||
assert response.content == "response from openai"
|
||||
|
||||
# Second call: current usage is now 150 (from the first call),
|
||||
# which is >= the 150-token limit → blocked.
|
||||
with pytest.raises(QuotaExceededError) as exc_info:
|
||||
await gateway.chat(
|
||||
messages=[{"role": "user", "content": "hi again"}],
|
||||
model="openai/gpt-4o",
|
||||
user_id="u1",
|
||||
department_ids=[dept_id],
|
||||
db_path=fresh_db,
|
||||
)
|
||||
assert exc_info.value.quota_type == "token_limit"
|
||||
assert exc_info.value.current == 150 # accumulated from first call
|
||||
assert exc_info.value.limit == 150
|
||||
|
|
@ -0,0 +1,483 @@
|
|||
"""Security isolation integration tests for department-scoped resources (U10).
|
||||
|
||||
Verifies department-based access control end-to-end through the full
|
||||
request stack (route → DepartmentContext → filtering → response):
|
||||
|
||||
- User in dept A cannot see skills/KB bound to dept B.
|
||||
- User in depts A+B sees the union of both departments' resources.
|
||||
- Admin sees all resources regardless of department bindings.
|
||||
- User removed from a department loses access to that department's
|
||||
resources.
|
||||
- Disabled department's resources are not visible to its users.
|
||||
- API-key client (no user_id) sees only global resources.
|
||||
- Non-admin user gets 403 on all admin endpoints.
|
||||
|
||||
The tests mount a minimal FastAPI app with the ``skills`` and
|
||||
``kb-management`` routers (for isolation verification) plus the
|
||||
``admin_router`` (for the 403 checks). The ``get_department_context``
|
||||
dependency is overridden per-test to simulate different callers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from agentkit.server.admin.context import DepartmentContext
|
||||
from agentkit.server.auth.models import init_auth_db
|
||||
from agentkit.server.routes import admin as admin_routes_module
|
||||
from agentkit.server.routes import kb_management as kb_routes
|
||||
from agentkit.server.routes import skills as skills_routes
|
||||
from agentkit.skills.base import Skill, SkillConfig
|
||||
from agentkit.skills.registry import SkillRegistry
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def tmp_auth_db(tmp_path: Path) -> Path:
|
||||
db_path = tmp_path / "security_isolation.db"
|
||||
await init_auth_db(db_path)
|
||||
return db_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def skill_registry() -> SkillRegistry:
|
||||
"""A SkillRegistry pre-loaded with three test skills.
|
||||
|
||||
- ``hr_skill`` — will be bound to department A.
|
||||
- ``eng_skill`` — will be bound to department B.
|
||||
- ``global_skill`` — has NO department binding (global).
|
||||
"""
|
||||
registry = SkillRegistry()
|
||||
for name in ("hr_skill", "eng_skill", "global_skill"):
|
||||
config = SkillConfig(
|
||||
name=name,
|
||||
agent_type="test_type",
|
||||
task_mode="llm_generate",
|
||||
description=f"Test skill {name}",
|
||||
prompt={"identity": name, "instructions": "test"},
|
||||
)
|
||||
registry.register(Skill(config=config))
|
||||
return registry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def kb_store():
|
||||
"""Reset the module-level KB source store singleton."""
|
||||
kb_routes._source_store = kb_routes.KnowledgeSourceStore()
|
||||
return kb_routes._source_store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(
|
||||
tmp_auth_db: Path,
|
||||
skill_registry: SkillRegistry,
|
||||
kb_store: kb_routes.KnowledgeSourceStore,
|
||||
) -> FastAPI:
|
||||
"""A FastAPI app with skills + kb-management + admin routers mounted.
|
||||
|
||||
The ``get_department_context`` dependency is overridden per-test via
|
||||
``app.dependency_overrides``. The default override is
|
||||
"unauthenticated caller".
|
||||
"""
|
||||
application = FastAPI()
|
||||
application.state.auth_db_path = str(tmp_auth_db)
|
||||
application.state.skill_registry = skill_registry
|
||||
|
||||
application.include_router(skills_routes.router, prefix="/api/v1")
|
||||
application.include_router(kb_routes.router, prefix="/api/v1")
|
||||
application.include_router(admin_routes_module.admin_router, prefix="/api/v1")
|
||||
|
||||
# Default: unauthenticated caller.
|
||||
application.dependency_overrides[skills_routes.get_department_context] = (
|
||||
_unauthenticated_context
|
||||
)
|
||||
application.dependency_overrides[kb_routes.get_department_context] = _unauthenticated_context
|
||||
# Default: admin access allowed (used only for the admin 403 test
|
||||
# which overrides this).
|
||||
application.dependency_overrides[admin_routes_module._require_admin] = lambda: (
|
||||
_make_admin_user()
|
||||
)
|
||||
return application
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app: FastAPI) -> TestClient:
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Department-context dependency overrides
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _unauthenticated_context(request: Request) -> DepartmentContext:
|
||||
"""Simulate an unauthenticated caller (no current_user)."""
|
||||
return DepartmentContext(user_id=None, department_ids=[], is_admin=False)
|
||||
|
||||
|
||||
def _ctx_for_user(
|
||||
user_id: str | None,
|
||||
department_ids: list[str],
|
||||
is_admin: bool = False,
|
||||
):
|
||||
"""Build a dependency override returning a fixed DepartmentContext."""
|
||||
|
||||
async def _override(request: Request) -> DepartmentContext:
|
||||
return DepartmentContext(
|
||||
user_id=user_id,
|
||||
department_ids=list(department_ids),
|
||||
is_admin=is_admin,
|
||||
)
|
||||
|
||||
return _override
|
||||
|
||||
|
||||
def _set_caller(
|
||||
app: FastAPI,
|
||||
user_id: str | None,
|
||||
department_ids: list[str],
|
||||
is_admin: bool = False,
|
||||
) -> None:
|
||||
"""Install the dependency overrides for both skills + kb routers."""
|
||||
override = _ctx_for_user(user_id, department_ids, is_admin)
|
||||
app.dependency_overrides[skills_routes.get_department_context] = override
|
||||
app.dependency_overrides[kb_routes.get_department_context] = override
|
||||
|
||||
|
||||
def _make_admin_user() -> dict[str, Any]:
|
||||
return {"user_id": "admin-1", "username": "admin", "role": "admin"}
|
||||
|
||||
|
||||
def _raise_forbidden() -> dict[str, Any]:
|
||||
raise HTTPException(status_code=403, detail="Admin permission required")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DB helpers (synchronous sqlite3 — no event-loop mixing with TestClient)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _create_department(db_path: Path, name: str) -> str:
|
||||
dept_id = str(uuid.uuid4())
|
||||
with sqlite3.connect(str(db_path)) as db:
|
||||
db.execute(
|
||||
"INSERT INTO departments (id, name, description, is_active, created_at) "
|
||||
"VALUES (?, ?, ?, ?, ?)",
|
||||
(dept_id, name, "", 1, _now_iso()),
|
||||
)
|
||||
db.commit()
|
||||
return dept_id
|
||||
|
||||
|
||||
def _disable_department(db_path: Path, department_id: str) -> None:
|
||||
with sqlite3.connect(str(db_path)) as db:
|
||||
db.execute(
|
||||
"UPDATE departments SET is_active = 0 WHERE id = ?",
|
||||
(department_id,),
|
||||
)
|
||||
db.commit()
|
||||
|
||||
|
||||
def _bind_skill(db_path: Path, department_id: str, skill_name: str) -> None:
|
||||
with sqlite3.connect(str(db_path)) as db:
|
||||
db.execute(
|
||||
"INSERT INTO department_skill_bindings (id, department_id, skill_name, created_at) "
|
||||
"VALUES (?, ?, ?, ?)",
|
||||
(str(uuid.uuid4()), department_id, skill_name, _now_iso()),
|
||||
)
|
||||
db.commit()
|
||||
|
||||
|
||||
def _bind_kb(db_path: Path, department_id: str, kb_source_id: str) -> None:
|
||||
with sqlite3.connect(str(db_path)) as db:
|
||||
db.execute(
|
||||
"INSERT INTO department_kb_bindings (id, department_id, kb_source_id, created_at) "
|
||||
"VALUES (?, ?, ?, ?)",
|
||||
(str(uuid.uuid4()), department_id, kb_source_id, _now_iso()),
|
||||
)
|
||||
db.commit()
|
||||
|
||||
|
||||
def _assign_user_to_department(db_path: Path, user_id: str, department_id: str) -> None:
|
||||
with sqlite3.connect(str(db_path)) as db:
|
||||
db.execute(
|
||||
"INSERT INTO user_departments (user_id, department_id, created_at) VALUES (?, ?, ?)",
|
||||
(user_id, department_id, _now_iso()),
|
||||
)
|
||||
db.commit()
|
||||
|
||||
|
||||
def _remove_user_from_department(db_path: Path, user_id: str, department_id: str) -> None:
|
||||
with sqlite3.connect(str(db_path)) as db:
|
||||
db.execute(
|
||||
"DELETE FROM user_departments WHERE user_id = ? AND department_id = ?",
|
||||
(user_id, department_id),
|
||||
)
|
||||
db.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test fixture: departments A and B with skill/KB bindings
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dept_setup(tmp_auth_db: Path, kb_store: kb_routes.KnowledgeSourceStore):
|
||||
"""Create departments A and B, bind skills and KB sources.
|
||||
|
||||
Layout:
|
||||
- Department A: bound to ``hr_skill`` and KB source ``hr_kb``
|
||||
- Department B: bound to ``eng_skill`` and KB source ``eng_kb``
|
||||
- ``global_skill`` and ``global_kb`` have NO bindings (global)
|
||||
"""
|
||||
dept_a = _create_department(tmp_auth_db, "HR")
|
||||
dept_b = _create_department(tmp_auth_db, "Engineering")
|
||||
_bind_skill(tmp_auth_db, dept_a, "hr_skill")
|
||||
_bind_skill(tmp_auth_db, dept_b, "eng_skill")
|
||||
# global_skill intentionally has no binding.
|
||||
|
||||
# Create KB sources in the in-memory store.
|
||||
hr_kb = kb_store.add_source("HR KB", "local", {})
|
||||
eng_kb = kb_store.add_source("Engineering KB", "local", {})
|
||||
global_kb = kb_store.add_source("Global KB", "local", {})
|
||||
_bind_kb(tmp_auth_db, dept_a, hr_kb.id)
|
||||
_bind_kb(tmp_auth_db, dept_b, eng_kb.id)
|
||||
# global_kb intentionally has no binding.
|
||||
|
||||
return {
|
||||
"dept_a": dept_a,
|
||||
"dept_b": dept_b,
|
||||
"hr_kb_id": hr_kb.id,
|
||||
"eng_kb_id": eng_kb.id,
|
||||
"global_kb_id": global_kb.id,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Department isolation tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDepartmentIsolation:
|
||||
"""Security tests for department-based resource isolation."""
|
||||
|
||||
def test_user_cannot_see_other_department_skills(
|
||||
self, app: FastAPI, client: TestClient, dept_setup: dict
|
||||
):
|
||||
"""User in dept A cannot see skills bound to dept B."""
|
||||
_set_caller(app, user_id="alice", department_ids=[dept_setup["dept_a"]])
|
||||
resp = client.get("/api/v1/skills")
|
||||
assert resp.status_code == 200
|
||||
names = {s["name"] for s in resp.json()}
|
||||
# Alice sees hr_skill (dept A) + global_skill, NOT eng_skill (dept B).
|
||||
assert "hr_skill" in names
|
||||
assert "global_skill" in names
|
||||
assert "eng_skill" not in names
|
||||
|
||||
def test_user_cannot_see_other_department_kb(
|
||||
self, app: FastAPI, client: TestClient, dept_setup: dict
|
||||
):
|
||||
"""User in dept A cannot see KB sources bound to dept B."""
|
||||
_set_caller(app, user_id="alice", department_ids=[dept_setup["dept_a"]])
|
||||
resp = client.get("/api/v1/kb-management/sources")
|
||||
assert resp.status_code == 200
|
||||
ids = {s["id"] for s in resp.json()["sources"]}
|
||||
# Alice sees hr_kb (dept A) + global_kb, NOT eng_kb (dept B).
|
||||
assert dept_setup["hr_kb_id"] in ids
|
||||
assert dept_setup["global_kb_id"] in ids
|
||||
assert dept_setup["eng_kb_id"] not in ids
|
||||
|
||||
def test_user_in_multiple_departments_sees_union(
|
||||
self, app: FastAPI, client: TestClient, dept_setup: dict
|
||||
):
|
||||
"""User in depts A+B sees skills from both."""
|
||||
_set_caller(
|
||||
app,
|
||||
user_id="alice",
|
||||
department_ids=[dept_setup["dept_a"], dept_setup["dept_b"]],
|
||||
)
|
||||
resp = client.get("/api/v1/skills")
|
||||
assert resp.status_code == 200
|
||||
names = {s["name"] for s in resp.json()}
|
||||
# Alice sees hr_skill (A) + eng_skill (B) + global_skill.
|
||||
assert names == {"hr_skill", "eng_skill", "global_skill"}
|
||||
|
||||
def test_admin_sees_all_resources(self, app: FastAPI, client: TestClient, dept_setup: dict):
|
||||
"""Admin user sees all resources regardless of department bindings."""
|
||||
_set_caller(
|
||||
app,
|
||||
user_id="admin-1",
|
||||
department_ids=[],
|
||||
is_admin=True,
|
||||
)
|
||||
# Admin sees all skills.
|
||||
resp = client.get("/api/v1/skills")
|
||||
assert resp.status_code == 200
|
||||
names = {s["name"] for s in resp.json()}
|
||||
assert names == {"hr_skill", "eng_skill", "global_skill"}
|
||||
|
||||
# Admin sees all KB sources.
|
||||
resp = client.get("/api/v1/kb-management/sources")
|
||||
assert resp.status_code == 200
|
||||
ids = {s["id"] for s in resp.json()["sources"]}
|
||||
assert ids == {
|
||||
dept_setup["hr_kb_id"],
|
||||
dept_setup["eng_kb_id"],
|
||||
dept_setup["global_kb_id"],
|
||||
}
|
||||
|
||||
def test_user_removed_from_department_loses_access(
|
||||
self,
|
||||
app: FastAPI,
|
||||
client: TestClient,
|
||||
tmp_auth_db: Path,
|
||||
dept_setup: dict,
|
||||
):
|
||||
"""User removed from dept A can no longer see dept A's skills."""
|
||||
user_id = "user-removal"
|
||||
dept_a = dept_setup["dept_a"]
|
||||
|
||||
# Initially assign to dept A.
|
||||
_assign_user_to_department(tmp_auth_db, user_id, dept_a)
|
||||
_set_caller(app, user_id=user_id, department_ids=[dept_a])
|
||||
|
||||
resp = client.get("/api/v1/skills")
|
||||
names = {s["name"] for s in resp.json()}
|
||||
assert "hr_skill" in names
|
||||
|
||||
# Remove from dept A — simulate the context change.
|
||||
_remove_user_from_department(tmp_auth_db, user_id, dept_a)
|
||||
_set_caller(app, user_id=user_id, department_ids=[])
|
||||
|
||||
resp = client.get("/api/v1/skills")
|
||||
names = {s["name"] for s in resp.json()}
|
||||
assert "hr_skill" not in names
|
||||
assert "global_skill" in names
|
||||
|
||||
def test_disabled_department_excluded(
|
||||
self,
|
||||
app: FastAPI,
|
||||
client: TestClient,
|
||||
tmp_auth_db: Path,
|
||||
dept_setup: dict,
|
||||
):
|
||||
"""Disabled department's resources are not visible to its users.
|
||||
|
||||
The ``_fetch_user_department_ids`` helper in
|
||||
:mod:`agentkit.server.admin.context` filters out disabled
|
||||
departments (``is_active=0``). We simulate this by disabling
|
||||
dept A in the DB and updating the caller's context to reflect
|
||||
the now-empty department list.
|
||||
"""
|
||||
user_id = "user-disabled-dept"
|
||||
dept_a = dept_setup["dept_a"]
|
||||
|
||||
# Initially assign to dept A and verify access.
|
||||
_assign_user_to_department(tmp_auth_db, user_id, dept_a)
|
||||
_set_caller(app, user_id=user_id, department_ids=[dept_a])
|
||||
resp = client.get("/api/v1/skills")
|
||||
names = {s["name"] for s in resp.json()}
|
||||
assert "hr_skill" in names
|
||||
|
||||
# Disable dept A in the DB.
|
||||
_disable_department(tmp_auth_db, dept_a)
|
||||
# Simulate the context change: the next request's
|
||||
# ``get_department_context`` would re-query user_departments
|
||||
# and find no *active* departments.
|
||||
_set_caller(app, user_id=user_id, department_ids=[])
|
||||
|
||||
resp = client.get("/api/v1/skills")
|
||||
names = {s["name"] for s in resp.json()}
|
||||
assert "hr_skill" not in names
|
||||
assert "global_skill" in names
|
||||
|
||||
def test_api_key_client_sees_only_global(
|
||||
self, app: FastAPI, client: TestClient, dept_setup: dict
|
||||
):
|
||||
"""API key client (no user_id) sees only global resources."""
|
||||
# API-key client → user_id=None, department_ids=[], is_admin=False.
|
||||
_set_caller(app, user_id=None, department_ids=[], is_admin=False)
|
||||
|
||||
# Skills: only global_skill.
|
||||
resp = client.get("/api/v1/skills")
|
||||
assert resp.status_code == 200
|
||||
names = {s["name"] for s in resp.json()}
|
||||
assert names == {"global_skill"}
|
||||
|
||||
# KB sources: only global_kb.
|
||||
resp = client.get("/api/v1/kb-management/sources")
|
||||
assert resp.status_code == 200
|
||||
ids = {s["id"] for s in resp.json()["sources"]}
|
||||
assert ids == {dept_setup["global_kb_id"]}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Non-admin access tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNonAdminAccess:
|
||||
"""Non-admin user gets 403 on all admin endpoints."""
|
||||
|
||||
def test_non_admin_cannot_access_admin_endpoints(self, app: FastAPI, tmp_auth_db: Path):
|
||||
"""Non-admin user gets 403 on a representative sample of admin endpoints."""
|
||||
# Override _require_admin to raise 403 (simulating a non-admin caller).
|
||||
app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
|
||||
client = TestClient(app)
|
||||
|
||||
# Representative sample of admin endpoints across all resource types.
|
||||
endpoints = [
|
||||
("GET", "/api/v1/admin/departments"),
|
||||
("POST", "/api/v1/admin/departments"),
|
||||
("GET", "/api/v1/admin/users"),
|
||||
("POST", "/api/v1/admin/users"),
|
||||
("GET", "/api/v1/admin/llm/providers"),
|
||||
("POST", "/api/v1/admin/llm/providers"),
|
||||
("GET", "/api/v1/admin/usage/summary"),
|
||||
("GET", "/api/v1/admin/usage/timeseries"),
|
||||
("GET", "/api/v1/admin/usage/by-model"),
|
||||
("GET", "/api/v1/admin/usage/top-users"),
|
||||
("GET", "/api/v1/admin/usage/export"),
|
||||
]
|
||||
|
||||
for method, path in endpoints:
|
||||
if method == "GET":
|
||||
resp = client.get(path)
|
||||
elif method == "POST":
|
||||
# Use a minimal body where required; we expect 403 before
|
||||
# body validation runs.
|
||||
body: dict[str, Any] | None = {}
|
||||
if path == "/api/v1/admin/departments":
|
||||
body = {"name": "X"}
|
||||
elif path == "/api/v1/admin/users":
|
||||
body = {
|
||||
"username": "x",
|
||||
"email": "x@x.com",
|
||||
"password": "Pw123!",
|
||||
}
|
||||
elif path == "/api/v1/admin/llm/providers":
|
||||
body = {
|
||||
"name": "x",
|
||||
"type": "openai",
|
||||
"api_key": "sk-x",
|
||||
}
|
||||
resp = client.post(path, json=body)
|
||||
assert resp.status_code == 403, (
|
||||
f"{method} {path} returned {resp.status_code}, expected 403"
|
||||
)
|
||||
Loading…
Reference in New Issue