# geo/backend/tests/test_api/test_api_keys_api.py
#
# API tests for the /api/v1/api-keys and /api/v1/usage endpoints.

import uuid
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from app.api.deps import get_current_user, get_db
from app.database import Base
from app.main import app
from app.models.user import User
from app.services.api_key_manager import APIKeyManager, KeySource
from app.services.auth import hash_password
from app.services.usage_tracker import UsageTracker
@pytest_asyncio.fixture
async def async_engine():
    """Yield an async SQLAlchemy engine backed by an in-memory SQLite database.

    The engine uses ``StaticPool`` together with ``check_same_thread=False``,
    so the single in-memory database is shared by every session built on it.
    All tables registered on ``Base.metadata`` are created before the engine
    is yielded, and the engine is disposed once the test is finished.
    """
    db_engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    async with db_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)
    yield db_engine
    await db_engine.dispose()
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield one ``AsyncSession`` bound to ``async_engine`` for the test.

    With ``expire_on_commit=False``, ORM objects stay readable after a commit.
    With ``autoflush=False``, pending changes are written only on an explicit
    flush or commit. The session is closed when the ``async with`` block exits.
    """
    make_session = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
    )
    async with make_session() as session:
        yield session
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return an active, email-verified free-plan user.

    The user has ``max_queries=5`` and password ``Test@123456``. It is
    committed and then refreshed, so database-populated fields are loaded
    onto the returned instance.
    """
    hashed = hash_password("Test@123456")
    new_user = User(
        id=uuid.uuid4(),
        email="test@example.com",
        password_hash=hashed,
        name="Test User",
        plan="free",
        max_queries=5,
        is_active=True,
        email_verified=True,
    )
    async_session.add(new_user)
    await async_session.commit()
    await async_session.refresh(new_user)
    return new_user
@pytest.fixture
def key_manager():
    """Return a fresh, empty ``APIKeyManager`` for the current test.

    This is a plain synchronous fixture, so it uses ``pytest.fixture`` rather
    than ``pytest_asyncio.fixture``. Tests install it into the router with
    ``app.api.api_keys.set_key_manager`` before issuing requests.
    """
    return APIKeyManager()
@pytest.fixture
def usage_tracker(test_user):
    """Return a ``UsageTracker`` preloaded with two records for ``test_user``.

    Both records use ``brand_id="brand-1"``:

    * deepseek — 100 input / 200 output tokens, cost 0.01
    * chatgpt  — 500 input / 1000 output tokens, cost 0.05

    This is a synchronous fixture, so it uses ``pytest.fixture``. It can still
    depend on the async ``test_user`` fixture, which pytest-asyncio resolves.
    Tests install the tracker with ``app.api.usage.set_usage_tracker``.
    """
    uid = str(test_user.id)
    tracker = UsageTracker()
    tracker.record(
        user_id=uid,
        brand_id="brand-1",
        engine_type="deepseek",
        query="test query",
        input_tokens=100,
        output_tokens=200,
        cost=0.01,
    )
    tracker.record(
        user_id=uid,
        brand_id="brand-1",
        engine_type="chatgpt",
        query="test query 2",
        input_tokens=500,
        output_tokens=1000,
        cost=0.05,
    )
    return tracker
@pytest_asyncio.fixture
async def async_client(async_session, test_user, key_manager, usage_tracker):
    """Yield an ``AsyncClient`` wired to the ASGI app with DB and auth overridden.

    The fixture replaces two dependencies:

    * ``get_db`` yields the in-memory test session.
    * ``get_current_user`` returns ``test_user``, so no token is needed.

    ``key_manager`` and ``usage_tracker`` are not used here directly. They are
    presumably requested so those fixtures exist for the test, which installs
    them itself.

    The overrides are cleared in a ``finally`` so they cannot leak into later
    tests, even if client shutdown raises.
    """

    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        app.dependency_overrides.clear()
class TestAddAPIKey:
    """Tests for POST /api/v1/api-keys/ (adding a key)."""

    ADD_URL = "/api/v1/api-keys/"

    @staticmethod
    def _install(manager):
        """Install ``manager`` via ``app.api.api_keys.set_key_manager``."""
        from app.api.api_keys import set_key_manager

        set_key_manager(manager)

    @pytest.mark.asyncio
    async def test_add_key_success(self, async_client, key_manager):
        self._install(key_manager)
        payload = {
            "engine_type": "chatgpt",
            "api_key": "sk-abcdef1234567890",
            "source": "user",
        }

        resp = await async_client.post(self.ADD_URL, json=payload)

        assert resp.status_code == 200
        body = resp.json()
        assert body["engine_type"] == "chatgpt"
        assert "key_hint" in body
        assert body["key_hint"].startswith("sk-")
        assert body["status"] == "active"
        # The submitted key must not come back as a response field.
        assert "api_key" not in body

    @pytest.mark.asyncio
    async def test_add_key_no_plaintext_in_response(self, async_client, key_manager):
        self._install(key_manager)
        secret = "sk-super-secret-key-12345678"

        resp = await async_client.post(
            self.ADD_URL,
            json={"engine_type": "deepseek", "api_key": secret, "source": "user"},
        )

        assert resp.status_code == 200
        # Check the raw body, not only parsed fields, for any leak of the secret.
        assert secret not in resp.text
class TestListAPIKeys:
    """Tests for GET /api/v1/api-keys/ (listing keys)."""

    LIST_URL = "/api/v1/api-keys/"

    @staticmethod
    def _install(manager):
        """Install ``manager`` via ``app.api.api_keys.set_key_manager``."""
        from app.api.api_keys import set_key_manager

        set_key_manager(manager)

    @pytest.mark.asyncio
    async def test_list_keys_success(self, async_client, key_manager):
        self._install(key_manager)
        key_manager.add_key("chatgpt", "sk-abcdef1234567890", source=KeySource.USER)
        key_manager.add_key("deepseek", "dsk-xyz9876543210", source=KeySource.ENV)

        resp = await async_client.get(self.LIST_URL)

        assert resp.status_code == 200
        payload = resp.json()
        assert "items" in payload
        assert len(payload["items"]) == 2

    @pytest.mark.asyncio
    async def test_list_keys_filter_by_engine(self, async_client, key_manager):
        self._install(key_manager)
        key_manager.add_key("chatgpt", "sk-abcdef1234567890", source=KeySource.USER)
        key_manager.add_key("deepseek", "dsk-xyz9876543210", source=KeySource.ENV)

        resp = await async_client.get(f"{self.LIST_URL}?engine_type=chatgpt")

        assert resp.status_code == 200
        items = resp.json()["items"]
        assert len(items) == 1
        assert items[0]["engine_type"] == "chatgpt"

    @pytest.mark.asyncio
    async def test_list_keys_masked(self, async_client, key_manager):
        self._install(key_manager)
        key_manager.add_key("chatgpt", "sk-abcdef1234567890", source=KeySource.USER)

        resp = await async_client.get(self.LIST_URL)

        assert resp.status_code == 200
        # Every listed entry exposes only the hint, never the key material.
        for entry in resp.json()["items"]:
            assert "api_key" not in entry
            assert "encrypted_key" not in entry
            assert "key_hint" in entry
class TestDeleteAPIKey:
    """Tests for DELETE /api/v1/api-keys/{engine_type}/{key_hint}."""

    @staticmethod
    def _install(manager):
        """Install ``manager`` via ``app.api.api_keys.set_key_manager``."""
        from app.api.api_keys import set_key_manager

        set_key_manager(manager)

    @pytest.mark.asyncio
    async def test_delete_key_success(self, async_client, key_manager):
        self._install(key_manager)
        created = key_manager.add_key(
            "chatgpt", "sk-abcdef1234567890", source=KeySource.USER
        )
        # Keys are addressed by the hint returned from add_key.
        target = f"/api/v1/api-keys/chatgpt/{created.key_hint}"

        resp = await async_client.delete(target)

        assert resp.status_code == 200
        assert resp.json()["deleted"] is True

    @pytest.mark.asyncio
    async def test_delete_key_not_found(self, async_client, key_manager):
        self._install(key_manager)

        resp = await async_client.delete("/api/v1/api-keys/chatgpt/nonexistent")

        assert resp.status_code == 404
class TestVerifyAPIKey:
    """Tests for POST /api/v1/api-keys/verify."""

    VERIFY_URL = "/api/v1/api-keys/verify"

    @staticmethod
    def _install(manager):
        """Install ``manager`` via ``app.api.api_keys.set_key_manager``."""
        from app.api.api_keys import set_key_manager

        set_key_manager(manager)

    @pytest.mark.asyncio
    async def test_verify_key_success(self, async_client, key_manager):
        self._install(key_manager)
        key_manager.add_key("chatgpt", "sk-abcdef1234567890", source=KeySource.USER)

        resp = await async_client.post(self.VERIFY_URL, json={"engine_type": "chatgpt"})

        assert resp.status_code == 200
        result = resp.json()
        assert result["engine_type"] == "chatgpt"
        assert result["status"] == "active"

    @pytest.mark.asyncio
    async def test_verify_key_no_key_configured(self, async_client, key_manager):
        # No key is registered for perplexity, so the endpoint must 404.
        self._install(key_manager)

        resp = await async_client.post(
            self.VERIFY_URL, json={"engine_type": "perplexity"}
        )

        assert resp.status_code == 404
class TestGetEngines:
    """Tests for GET /api/v1/api-keys/engines (engine catalogue)."""

    ENGINES_URL = "/api/v1/api-keys/engines"

    @pytest.mark.asyncio
    async def test_get_engines_success(self, async_client):
        resp = await async_client.get(self.ENGINES_URL)

        assert resp.status_code == 200
        payload = resp.json()
        assert "engines" in payload
        engines = payload["engines"]
        assert len(engines) > 0
        first = engines[0]
        for field in ("type", "cost_tier", "has_free_tier", "requires_own_key"):
            assert field in first

    @pytest.mark.asyncio
    async def test_get_engines_contains_deepseek(self, async_client):
        resp = await async_client.get(self.ENGINES_URL)

        assert resp.status_code == 200
        engines = resp.json()["engines"]
        engine_types = [entry["type"] for entry in engines]
        assert "deepseek" in engine_types
        # Take the first entry whose type is deepseek.
        deepseek = engines[engine_types.index("deepseek")]
        assert deepseek["cost_tier"] == "free"
        assert deepseek["has_free_tier"] is True
        assert deepseek["requires_own_key"] is False
        for price_field in ("input_price", "output_price"):
            assert price_field in deepseek
class TestUsageSummary:
    """Tests for GET /api/v1/usage/summary."""

    @staticmethod
    def _install(tracker):
        """Install ``tracker`` via ``app.api.usage.set_usage_tracker``."""
        from app.api.usage import set_usage_tracker

        set_usage_tracker(tracker)

    @pytest.mark.asyncio
    async def test_get_usage_summary(self, async_client, usage_tracker):
        self._install(usage_tracker)

        resp = await async_client.get("/api/v1/usage/summary?period=month")

        assert resp.status_code == 200
        summary = resp.json()
        assert summary["period"] == "month"
        # The usage_tracker fixture records exactly two queries for test_user.
        assert summary["total_queries"] == 2
        assert summary["total_cost"] > 0
        assert "by_engine" in summary

    @pytest.mark.asyncio
    async def test_get_usage_summary_day_period(self, async_client, usage_tracker):
        self._install(usage_tracker)

        resp = await async_client.get("/api/v1/usage/summary?period=day")

        assert resp.status_code == 200
        assert resp.json()["period"] == "day"
class TestUsageQuota:
    """Tests for GET /api/v1/usage/quota."""

    QUOTA_URL = "/api/v1/usage/quota"

    @staticmethod
    def _install(tracker):
        """Install ``tracker`` via ``app.api.usage.set_usage_tracker``."""
        from app.api.usage import set_usage_tracker

        set_usage_tracker(tracker)

    @pytest.mark.asyncio
    async def test_get_quota(self, async_client, usage_tracker):
        self._install(usage_tracker)

        resp = await async_client.get(self.QUOTA_URL)

        assert resp.status_code == 200
        quota = resp.json()
        for field in ("used", "limit", "usage_percentage", "status"):
            assert field in quota
        assert quota["status"] in ("ok", "warning", "exceeded")

    @pytest.mark.asyncio
    async def test_get_quota_ok_status(self, async_client, usage_tracker):
        # Two queries are recorded while test_user has max_queries=5.
        # NOTE(review): presumably that is the limit the endpoint uses; confirm.
        self._install(usage_tracker)

        resp = await async_client.get(self.QUOTA_URL)

        assert resp.status_code == 200
        quota = resp.json()
        assert quota["status"] == "ok"
        assert quota["usage_percentage"] < 100
class TestUsageByEngine:
    """Tests for GET /api/v1/usage/by-engine."""

    BY_ENGINE_URL = "/api/v1/usage/by-engine"

    @staticmethod
    def _install(tracker):
        """Install ``tracker`` via ``app.api.usage.set_usage_tracker``."""
        from app.api.usage import set_usage_tracker

        set_usage_tracker(tracker)

    @pytest.mark.asyncio
    async def test_get_usage_by_engine(self, async_client, usage_tracker):
        self._install(usage_tracker)

        resp = await async_client.get(self.BY_ENGINE_URL)

        assert resp.status_code == 200
        payload = resp.json()
        assert "engines" in payload
        rows = payload["engines"]
        assert len(rows) > 0
        first = rows[0]
        for field in ("type", "queries", "cost"):
            assert field in first

    @pytest.mark.asyncio
    async def test_get_usage_by_engine_contains_data(self, async_client, usage_tracker):
        self._install(usage_tracker)

        resp = await async_client.get(self.BY_ENGINE_URL)

        assert resp.status_code == 200
        seen = [row["type"] for row in resp.json()["engines"]]
        # The usage_tracker fixture records one deepseek and one chatgpt query.
        assert "deepseek" in seen
        assert "chatgpt" in seen
class TestUnauthorizedAccess:
    """Requests with an invalid bearer token must be rejected with 401.

    Only ``get_db`` is overridden here, so the real ``get_current_user``
    dependency runs and has to reject the token.
    """

    @staticmethod
    async def _get_with_invalid_token(async_session, url):
        """Send GET ``url`` with ``Authorization: Bearer invalid_token``.

        Returns the ``httpx`` response. Any ``get_current_user`` override is
        removed first, so real authentication is exercised. All dependency
        overrides are cleared in ``finally``; the assertion is made by the
        caller after cleanup, so a failing test cannot leak the ``get_db``
        override (bound to this test's session) into later tests.
        """

        async def override_get_db():
            yield async_session

        app.dependency_overrides.pop(get_current_user, None)
        app.dependency_overrides[get_db] = override_get_db
        try:
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                headers = {"Authorization": "Bearer invalid_token"}
                return await client.get(url, headers=headers)
        finally:
            app.dependency_overrides.clear()

    @pytest.mark.asyncio
    async def test_api_keys_unauthorized_returns_401(self, async_session):
        response = await self._get_with_invalid_token(async_session, "/api/v1/api-keys/")
        assert response.status_code == 401

    @pytest.mark.asyncio
    async def test_usage_unauthorized_returns_401(self, async_session):
        response = await self._get_with_invalid_token(async_session, "/api/v1/usage/summary")
        assert response.status_code == 401