"""Database migration verification tests (Alembic CLI checks and live schema inspection)."""
import asyncio
import os
import re
import subprocess
import sys
from pathlib import Path

import pytest
# Backend project root: Alembic commands are run from here, migration scripts
# are read from <root>/alembic/versions, and the `app` package is imported
# from it. Defaults to the original developer checkout; set the
# GEO_BACKEND_ROOT environment variable to run the suite from another location.
PROJECT_ROOT = os.environ.get(
    "GEO_BACKEND_ROOT", "/Users/Chiguyong/Code/Fischer/geo/backend"
)


class TestDatabaseMigration:
    """Verify that the backend database schema is migrated and consistent.

    - Alembic CLI checks (``current``, ``check``, ``history``) run as
      subprocesses with ``PROJECT_ROOT`` as the working directory.
    - Live-schema checks connect through ``app.database.engine`` and are
      skipped (not failed) when importing, connecting or inspecting raises,
      e.g. because the database is unreachable.
    - The duplicate-revision check only reads migration scripts on disk.
    """

    # Tables that must exist in a fully migrated database.
    REQUIRED_TABLES = (
        "users", "brands", "competitors", "queries",
        "citations", "alerts", "alert_settings", "suggestions",
        "organizations", "subscriptions", "subscription_plans",
        "agent_configs", "agent_executions", "knowledge_bases",
        "content_versions", "platform_rules", "analytics_events",
        "distribution_schedules", "client_brands",
    )

    # Matches the revision assignment in an Alembic migration script, both the
    # classic ``revision = '<id>'`` and the annotated ``revision: str = "<id>"``.
    # Anchored at line start so ``down_revision = ...`` is not matched.
    _REVISION_RE = re.compile(
        r"""^revision\s*(?::[^=\n]*)?=\s*['"]([^'"]+)['"]""", re.MULTILINE
    )

    # ------------------------------------------------------------------
    # Helpers (not collected by pytest: names do not start with "test")
    # ------------------------------------------------------------------

    @staticmethod
    def _run_alembic(*args):
        """Run ``alembic <args>`` from PROJECT_ROOT and capture its text output.

        The exit status is not checked here; callers assert on it so the
        failure message can include Alembic's own error output.
        """
        return subprocess.run(
            ["alembic", *args],
            cwd=PROJECT_ROOT,
            capture_output=True,
            text=True,
            check=False,
        )

    @staticmethod
    def _inspect_database(inspect_fn):
        """Call ``inspect_fn(inspector)`` on a live DB connection and return its result.

        ``inspector`` is a synchronous SQLAlchemy ``Inspector`` bound to a
        connection taken from ``app.database.engine``. Any exception raised
        while importing the app, connecting or inspecting skips the calling
        test instead of failing it.
        """
        # `from app.database import ...` needs the directory that CONTAINS the
        # `app` package on sys.path, i.e. PROJECT_ROOT itself; PROJECT_ROOT/app
        # would only make `database` importable, not `app.database`.
        # Insert it once rather than growing sys.path on every call.
        if PROJECT_ROOT not in sys.path:
            sys.path.insert(0, PROJECT_ROOT)

        async def run():
            from sqlalchemy import inspect

            from app.database import engine

            try:
                async with engine.connect() as conn:
                    # inspect() cannot be applied to an AsyncConnection (it
                    # raises NoInspectionAvailable); run it on the underlying
                    # sync connection via run_sync instead.
                    return await conn.run_sync(
                        lambda sync_conn: inspect_fn(inspect(sync_conn))
                    )
            finally:
                # Every asyncio.run() creates and closes its own event loop;
                # drop pooled connections so none outlive this loop and get
                # reused from a later one.
                await engine.dispose()

        try:
            return asyncio.run(run())
        except Exception as e:  # best-effort: no usable database => skip
            pytest.skip(f"无法连接数据库: {e}")

    @classmethod
    def _revision_id(cls, path):
        """Return the Alembic revision ID declared in migration script *path*.

        Reads the ``revision = ...`` assignment from the script. If none is
        found, falls back to the filename part before the first underscore,
        matching names such as ``059724556401_add_missing_sentiment_fields.py``.
        """
        match = cls._REVISION_RE.search(path.read_text(encoding="utf-8"))
        if match:
            return match.group(1)
        return path.stem.split("_", 1)[0]

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_alembic_current_shows_no_errors(self):
        """`alembic current` should exit successfully."""
        result = self._run_alembic("current")
        assert result.returncode == 0, f"alembic current failed: {result.stderr}"

    def test_all_required_tables_exist(self):
        """Every table listed in REQUIRED_TABLES should exist in the database."""
        existing_tables = set(
            self._inspect_database(lambda inspector: inspector.get_table_names())
        )
        missing_tables = [t for t in self.REQUIRED_TABLES if t not in existing_tables]
        assert not missing_tables, f"缺失表: {missing_tables}"

    def test_migration_head_matches_models(self):
        """`alembic check` should pass, i.e. the migration head matches the models.

        `alembic check` exits non-zero when autogenerate would produce new
        upgrade operations (the command requires Alembic 1.9+).
        """
        result = self._run_alembic("check")
        assert result.returncode == 0, f"alembic check failed: {result.stderr}"

    def test_no_duplicate_migration_versions(self):
        """Each migration script in alembic/versions should declare a unique revision."""
        versions_dir = Path(PROJECT_ROOT) / "alembic" / "versions"
        if not versions_dir.exists():
            pytest.skip("alembic/versions/ 目录不存在")

        # revision ID -> names of the files declaring it
        files_by_revision = {}
        for path in sorted(versions_dir.glob("*.py")):
            if path.name.startswith("__"):
                continue  # e.g. __init__.py is not a migration script
            files_by_revision.setdefault(self._revision_id(path), []).append(path.name)

        duplicates = {
            revision: names
            for revision, names in files_by_revision.items()
            if len(names) > 1
        }
        assert not duplicates, f"发现重复版本号: {duplicates}"

    def test_foreign_keys_integrity(self):
        """The `brands` table should have a foreign key on `user_id`."""
        foreign_keys = self._inspect_database(
            lambda inspector: inspector.get_foreign_keys("brands")
        )
        has_user_fk = any(
            "user_id" in fk["constrained_columns"] for fk in foreign_keys
        )
        assert has_user_fk, "brands 表缺少 user_id 外键"

    def test_alembic_history_shows_migrations(self):
        """`alembic history` should succeed and list at least one migration."""
        result = self._run_alembic("history")
        assert result.returncode == 0, f"alembic history failed: {result.stderr}"
        assert result.stdout.strip(), "alembic history 为空"