419 lines
13 KiB
Python
419 lines
13 KiB
Python
import uuid
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from httpx import ASGITransport, AsyncClient
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
from app.api.deps import get_current_user, get_db
|
|
from app.database import Base
|
|
from app.main import app
|
|
from app.models.user import User
|
|
from app.services.api_key_manager import APIKeyManager, KeySource
|
|
from app.services.auth import hash_password
|
|
from app.services.usage_tracker import UsageTracker
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_engine():
    """Yield an async SQLAlchemy engine over a fresh in-memory SQLite database.

    ``StaticPool`` makes the engine reuse a single connection. An in-memory
    SQLite database exists only as long as its connection, so this keeps the
    schema created here visible to every session opened on the engine.
    Every table registered on ``Base.metadata`` is created before the engine
    is handed to the test, and the engine is disposed afterwards.
    """
    db_engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )

    async with db_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield db_engine

    await db_engine.dispose()
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield a single ``AsyncSession`` bound to ``async_engine``.

    ``expire_on_commit=False`` keeps committed ORM objects (e.g. the
    ``test_user`` fixture's user) readable without an implicit reload.
    ``autoflush=False`` means pending changes are written only on an explicit
    flush or commit. ``autocommit=False`` matches SQLAlchemy 2.x's only
    supported value. The session is closed when the ``async with`` exits.
    """
    factory = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        autoflush=False,
        autocommit=False,
        expire_on_commit=False,
    )

    async with factory() as session:
        yield session
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return an active, email-verified ``free``-plan user.

    The user is committed and then refreshed from the database before being
    returned. ``max_queries=5`` is presumably the limit the quota endpoint
    reports for this user — verify against the usage API.
    """
    fields = dict(
        id=uuid.uuid4(),
        email="test@example.com",
        password_hash=hash_password("Test@123456"),
        name="Test User",
        plan="free",
        max_queries=5,
        is_active=True,
        email_verified=True,
    )
    user = User(**fields)

    async_session.add(user)
    await async_session.commit()
    await async_session.refresh(user)
    return user
|
|
|
|
|
|
@pytest.fixture
def key_manager():
    """Provide a freshly constructed ``APIKeyManager`` for each test.

    This is a plain synchronous fixture, so it uses ``pytest.fixture``;
    ``pytest_asyncio.fixture`` is meant for ``async def`` fixtures. Tests
    install the instance into the API-keys router themselves via
    ``set_key_manager``.
    """
    return APIKeyManager()
|
|
|
|
|
|
@pytest.fixture
def usage_tracker(test_user):
    """Provide a ``UsageTracker`` seeded with two usage records for ``test_user``.

    One ``deepseek`` and one ``chatgpt`` record are recorded, both for
    ``brand-1``. The usage tests rely on this seed, e.g.
    ``total_queries == 2`` and both engine types appearing in ``/by-engine``.

    This is a synchronous fixture, so it uses ``pytest.fixture`` rather than
    ``pytest_asyncio.fixture``. It can still depend on the async
    ``test_user`` fixture.
    """
    user_id = str(test_user.id)
    tracker = UsageTracker()

    seed_records = (
        {
            "engine_type": "deepseek",
            "query": "test query",
            "input_tokens": 100,
            "output_tokens": 200,
            "cost": 0.01,
        },
        {
            "engine_type": "chatgpt",
            "query": "test query 2",
            "input_tokens": 500,
            "output_tokens": 1000,
            "cost": 0.05,
        },
    )
    for record in seed_records:
        tracker.record(user_id=user_id, brand_id="brand-1", **record)

    return tracker
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_client(async_session, test_user, key_manager, usage_tracker):
    """Yield an ``AsyncClient`` talking to the ASGI app with DB and auth overridden.

    While the client is alive:
      * ``get_db`` yields the test's ``async_session``;
      * ``get_current_user`` returns ``test_user``, bypassing real auth.

    ``key_manager`` and ``usage_tracker`` are requested but not used here.
    Tests install them into the routers themselves.

    Whatever was in ``app.dependency_overrides`` beforehand is restored in a
    ``finally`` block. An error while closing the client therefore cannot
    leak the auth bypass into later tests, and overrides installed by other
    fixtures are not wiped.
    """

    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    saved_overrides = dict(app.dependency_overrides)
    try:
        app.dependency_overrides[get_db] = override_get_db
        app.dependency_overrides[get_current_user] = override_get_current_user

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        # Restore the prior state instead of only clearing it.
        app.dependency_overrides.clear()
        app.dependency_overrides.update(saved_overrides)
|
|
|
|
|
|
class TestAddAPIKey:
    """Tests for creating keys via ``POST /api/v1/api-keys/``."""

    @staticmethod
    def _activate(manager):
        # Point the API-keys router at this test's manager instance.
        from app.api.api_keys import set_key_manager

        set_key_manager(manager)

    @pytest.mark.asyncio
    async def test_add_key_success(self, async_client, key_manager):
        self._activate(key_manager)
        payload = {
            "engine_type": "chatgpt",
            "api_key": "sk-abcdef1234567890",
            "source": "user",
        }

        resp = await async_client.post("/api/v1/api-keys/", json=payload)

        assert resp.status_code == 200
        body = resp.json()
        assert body["engine_type"] == "chatgpt"
        assert "key_hint" in body
        assert body["key_hint"].startswith("sk-")
        assert body["status"] == "active"
        # The plaintext key must not be a field of the response.
        assert "api_key" not in body

    @pytest.mark.asyncio
    async def test_add_key_no_plaintext_in_response(self, async_client, key_manager):
        self._activate(key_manager)
        secret = "sk-super-secret-key-12345678"

        resp = await async_client.post(
            "/api/v1/api-keys/",
            json={"engine_type": "deepseek", "api_key": secret, "source": "user"},
        )

        assert resp.status_code == 200
        # Search the raw body text, not just parsed fields, so the secret
        # cannot slip out through any nested or unexpected attribute.
        assert secret not in resp.text
|
|
|
|
|
|
class TestListAPIKeys:
    """Tests for ``GET /api/v1/api-keys/``: listing, engine filter, masking."""

    @staticmethod
    def _install_and_seed(manager, *, include_deepseek=True):
        # Register the manager with the router, then store a USER-sourced
        # chatgpt key and, optionally, an ENV-sourced deepseek key.
        from app.api.api_keys import set_key_manager

        set_key_manager(manager)
        manager.add_key("chatgpt", "sk-abcdef1234567890", source=KeySource.USER)
        if include_deepseek:
            manager.add_key("deepseek", "dsk-xyz9876543210", source=KeySource.ENV)

    @pytest.mark.asyncio
    async def test_list_keys_success(self, async_client, key_manager):
        self._install_and_seed(key_manager)

        resp = await async_client.get("/api/v1/api-keys/")

        assert resp.status_code == 200
        listing = resp.json()
        assert "items" in listing
        assert len(listing["items"]) == 2

    @pytest.mark.asyncio
    async def test_list_keys_filter_by_engine(self, async_client, key_manager):
        self._install_and_seed(key_manager)

        resp = await async_client.get("/api/v1/api-keys/?engine_type=chatgpt")

        assert resp.status_code == 200
        items = resp.json()["items"]
        assert len(items) == 1
        assert items[0]["engine_type"] == "chatgpt"

    @pytest.mark.asyncio
    async def test_list_keys_masked(self, async_client, key_manager):
        self._install_and_seed(key_manager, include_deepseek=False)

        resp = await async_client.get("/api/v1/api-keys/")

        assert resp.status_code == 200
        for item in resp.json()["items"]:
            # Neither the raw key nor an ``encrypted_key`` field may be
            # exposed; only the masked hint is expected.
            for secret_field in ("api_key", "encrypted_key"):
                assert secret_field not in item
            assert "key_hint" in item
|
|
|
|
|
|
class TestDeleteAPIKey:
    """Tests for deleting keys addressed as ``/api/v1/api-keys/<engine>/<key_hint>``."""

    @staticmethod
    def _activate(manager):
        # Point the API-keys router at this test's manager instance.
        from app.api.api_keys import set_key_manager

        set_key_manager(manager)

    @pytest.mark.asyncio
    async def test_delete_key_success(self, async_client, key_manager):
        self._activate(key_manager)
        stored = key_manager.add_key(
            "chatgpt", "sk-abcdef1234567890", source=KeySource.USER
        )
        target_url = f"/api/v1/api-keys/chatgpt/{stored.key_hint}"

        resp = await async_client.delete(target_url)

        assert resp.status_code == 200
        assert resp.json()["deleted"] is True

    @pytest.mark.asyncio
    async def test_delete_key_not_found(self, async_client, key_manager):
        self._activate(key_manager)

        resp = await async_client.delete("/api/v1/api-keys/chatgpt/nonexistent")

        assert resp.status_code == 404
|
|
|
|
|
|
class TestVerifyAPIKey:
    """Tests for ``POST /api/v1/api-keys/verify``."""

    @staticmethod
    async def _verify(client, manager, engine_type):
        # Install the manager, then ask the endpoint to verify ``engine_type``.
        from app.api.api_keys import set_key_manager

        set_key_manager(manager)
        return await client.post(
            "/api/v1/api-keys/verify",
            json={"engine_type": engine_type},
        )

    @pytest.mark.asyncio
    async def test_verify_key_success(self, async_client, key_manager):
        from app.api.api_keys import set_key_manager

        set_key_manager(key_manager)
        key_manager.add_key("chatgpt", "sk-abcdef1234567890", source=KeySource.USER)

        resp = await async_client.post(
            "/api/v1/api-keys/verify",
            json={"engine_type": "chatgpt"},
        )

        assert resp.status_code == 200
        result = resp.json()
        assert result["engine_type"] == "chatgpt"
        assert result["status"] == "active"

    @pytest.mark.asyncio
    async def test_verify_key_no_key_configured(self, async_client, key_manager):
        # No key was added for this engine, so verification must 404.
        resp = await self._verify(async_client, key_manager, "perplexity")

        assert resp.status_code == 404
|
|
|
|
|
|
class TestGetEngines:
    """Tests for the engine catalogue at ``GET /api/v1/api-keys/engines``."""

    @pytest.mark.asyncio
    async def test_get_engines_success(self, async_client):
        resp = await async_client.get("/api/v1/api-keys/engines")

        assert resp.status_code == 200
        payload = resp.json()
        assert "engines" in payload
        assert len(payload["engines"]) > 0

        first = payload["engines"][0]
        for field in ("type", "cost_tier", "has_free_tier", "requires_own_key"):
            assert field in first

    @pytest.mark.asyncio
    async def test_get_engines_contains_deepseek(self, async_client):
        resp = await async_client.get("/api/v1/api-keys/engines")

        assert resp.status_code == 200
        catalogue = resp.json()["engines"]
        assert "deepseek" in [entry["type"] for entry in catalogue]

        deepseek = next(entry for entry in catalogue if entry["type"] == "deepseek")
        assert deepseek["cost_tier"] == "free"
        assert deepseek["has_free_tier"] is True
        assert deepseek["requires_own_key"] is False
        for price_field in ("input_price", "output_price"):
            assert price_field in deepseek
|
|
|
|
|
|
class TestUsageSummary:
    """Tests for ``GET /api/v1/usage/summary`` over different periods."""

    @staticmethod
    async def _fetch(client, tracker, path):
        # Install the seeded tracker into the usage router, then GET ``path``.
        from app.api.usage import set_usage_tracker

        set_usage_tracker(tracker)
        return await client.get(path)

    @pytest.mark.asyncio
    async def test_get_usage_summary(self, async_client, usage_tracker):
        resp = await self._fetch(
            async_client, usage_tracker, "/api/v1/usage/summary?period=month"
        )

        assert resp.status_code == 200
        summary = resp.json()
        assert summary["period"] == "month"
        # Both records seeded by the ``usage_tracker`` fixture are counted.
        assert summary["total_queries"] == 2
        assert summary["total_cost"] > 0
        assert "by_engine" in summary

    @pytest.mark.asyncio
    async def test_get_usage_summary_day_period(self, async_client, usage_tracker):
        resp = await self._fetch(
            async_client, usage_tracker, "/api/v1/usage/summary?period=day"
        )

        assert resp.status_code == 200
        assert resp.json()["period"] == "day"
|
|
|
|
|
|
class TestUsageQuota:
    """Tests for the per-user quota report at ``GET /api/v1/usage/quota``."""

    @staticmethod
    async def _quota(client, tracker):
        # Install the seeded tracker, then fetch the quota report.
        from app.api.usage import set_usage_tracker

        set_usage_tracker(tracker)
        return await client.get("/api/v1/usage/quota")

    @pytest.mark.asyncio
    async def test_get_quota(self, async_client, usage_tracker):
        resp = await self._quota(async_client, usage_tracker)

        assert resp.status_code == 200
        report = resp.json()
        for field in ("used", "limit", "usage_percentage", "status"):
            assert field in report
        assert report["status"] in ("ok", "warning", "exceeded")

    @pytest.mark.asyncio
    async def test_get_quota_ok_status(self, async_client, usage_tracker):
        resp = await self._quota(async_client, usage_tracker)

        assert resp.status_code == 200
        report = resp.json()
        assert report["status"] == "ok"
        assert report["usage_percentage"] < 100
|
|
|
|
|
|
class TestUsageByEngine:
    """Tests for the per-engine usage breakdown at ``GET /api/v1/usage/by-engine``."""

    @staticmethod
    async def _breakdown(client, tracker):
        # Install the seeded tracker, then fetch the per-engine breakdown.
        from app.api.usage import set_usage_tracker

        set_usage_tracker(tracker)
        return await client.get("/api/v1/usage/by-engine")

    @pytest.mark.asyncio
    async def test_get_usage_by_engine(self, async_client, usage_tracker):
        resp = await self._breakdown(async_client, usage_tracker)

        assert resp.status_code == 200
        payload = resp.json()
        assert "engines" in payload
        assert len(payload["engines"]) > 0

        first = payload["engines"][0]
        for field in ("type", "queries", "cost"):
            assert field in first

    @pytest.mark.asyncio
    async def test_get_usage_by_engine_contains_data(self, async_client, usage_tracker):
        resp = await self._breakdown(async_client, usage_tracker)

        assert resp.status_code == 200
        seen_types = [row["type"] for row in resp.json()["engines"]]
        # Both engines seeded by the ``usage_tracker`` fixture must appear.
        for expected in ("deepseek", "chatgpt"):
            assert expected in seen_types
|
|
|
|
|
|
class TestUnauthorizedAccess:
    """Protected endpoints must reject requests carrying an invalid bearer token.

    These tests deliberately avoid the ``async_client`` fixture so that
    ``get_current_user`` is not replaced and the app's real auth dependency
    handles the request.
    """

    @staticmethod
    async def _get_with_invalid_token(async_session, path):
        """GET ``path`` with a bogus bearer token, overriding only ``get_db``.

        Any stray ``get_current_user`` override is removed for the request so
        that real authentication runs. The previous override state is
        restored in ``finally``, so a failing request or assertion can no
        longer leak the ``get_db`` override into later tests.

        Returns the ``httpx`` response.
        """

        async def override_get_db():
            yield async_session

        saved_overrides = dict(app.dependency_overrides)
        try:
            app.dependency_overrides.pop(get_current_user, None)
            app.dependency_overrides[get_db] = override_get_db

            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                headers = {"Authorization": "Bearer invalid_token"}
                return await client.get(path, headers=headers)
        finally:
            app.dependency_overrides.clear()
            app.dependency_overrides.update(saved_overrides)

    @pytest.mark.asyncio
    async def test_api_keys_unauthorized_returns_401(self, async_session):
        response = await self._get_with_invalid_token(
            async_session, "/api/v1/api-keys/"
        )
        assert response.status_code == 401

    @pytest.mark.asyncio
    async def test_usage_unauthorized_returns_401(self, async_session):
        response = await self._get_with_invalid_token(
            async_session, "/api/v1/usage/summary"
        )
        assert response.status_code == 401
|