geo/backend/tests/test_infrastructure/test_database_migration.py

191 lines
6.0 KiB
Python

import asyncio
import os
import re
import shutil
import subprocess
import sys
from pathlib import Path

import pytest
# Backend project root: the directory alembic is run from, which holds the
# ``alembic/`` folder and the ``app`` package. It can be overridden with the
# GEO_BACKEND_ROOT environment variable; by default it is derived from this
# file's location (tests/test_infrastructure/<this file> -> three levels up),
# so the suite works on any checkout instead of one developer's machine.
PROJECT_ROOT = os.environ.get("GEO_BACKEND_ROOT") or str(
    Path(__file__).resolve().parent.parent.parent
)
# Table names taken from the migration a79329c23b20_initial_complete_schema.py.
# The order is kept exactly as originally extracted (five runs, each alphabetical);
# it is also the order in which test_all_required_tables_exist reports missing tables.
REQUIRED_TABLES = [
    "agent_registry", "api_keys", "brands", "optimization_insights",
    "organizations", "platform_rule_versions", "platform_rules", "publish_records",

    "alert_settings", "alerts", "competitor_insights", "competitors",
    "content_metrics", "detection_tasks", "diagnosis_records", "monitoring_records",
    "schema_suggestions", "suggestions", "trend_insights", "usage_records",
    "users",

    "agent_configs", "brand_knowledge", "content_baselines", "geo_plans",
    "knowledge_bases", "knowledge_search_logs", "lifecycle_projects", "org_members",
    "payment_orders", "queries", "subscriptions",

    "agent_tasks", "citation_records", "contents", "geo_plan_actions",
    "keywords", "knowledge_documents", "project_stages", "query_tasks",

    "agent_task_logs", "attribution_records", "content_reviews", "content_versions",
    "distribution_schedules", "knowledge_chunks", "knowledge_entities", "knowledge_relations",
]
# Upper bound for one alembic invocation so an unreachable database cannot hang the suite.
_ALEMBIC_TIMEOUT_SECONDS = 120

# Substrings in alembic's stderr that indicate the database could not be reached
# (DNS resolution failure or refused TCP connection) rather than a migration problem.
_DB_UNAVAILABLE_MARKERS = (
    "gaierror",
    "ConnectionRefusedError",
    "Connection refused",
)

# Matches the ``revision = "..."`` / ``revision: str = "..."`` line of an Alembic
# version script. ``down_revision`` does not match because it does not start the line
# with "revision", and the "Revision ID:" docstring header differs in case.
_REVISION_LINE_RE = re.compile(
    r"""^revision\s*(?::[^=\n]*)?=\s*['"]([^'"]+)['"]""", re.MULTILINE
)


def _run_alembic(*args, missing_ok=True):
    """Run ``alembic <args>`` from PROJECT_ROOT and return the CompletedProcess.

    Skips the calling test when no ``alembic`` executable is on PATH (or fails it
    when *missing_ok* is false), and fails the test if the command runs longer
    than ``_ALEMBIC_TIMEOUT_SECONDS``.
    """
    executable = shutil.which("alembic")
    if executable is None:
        message = "未找到 alembic 可执行文件"
        if missing_ok:
            pytest.skip(message)
        pytest.fail(message)
    try:
        return subprocess.run(
            [executable, *args],
            cwd=PROJECT_ROOT,
            capture_output=True,
            text=True,
            timeout=_ALEMBIC_TIMEOUT_SECONDS,
        )
    except subprocess.TimeoutExpired:
        pytest.fail(f"alembic {' '.join(args)} 超时 ({_ALEMBIC_TIMEOUT_SECONDS}s)")


def _skip_if_db_unavailable(result, command):
    """Skip the calling test if alembic *command* failed because the DB was unreachable."""
    if result.returncode != 0 and any(
        marker in result.stderr for marker in _DB_UNAVAILABLE_MARKERS
    ):
        pytest.skip(f"数据库不可用,跳过 alembic {command} 测试")


def _inspect_database(inspect_fn):
    """Return ``inspect_fn(inspector)`` evaluated against the app's async engine.

    Any exception while importing the app's database module, connecting, or
    inspecting skips the calling test (the suite treats these as "DB not available").
    """
    # ``from app.database import ...`` needs the directory that *contains* ``app`` on
    # sys.path; the ``app`` directory itself is kept too, as the original setup did.
    # Guard against re-inserting the same entries on every call.
    for path in (PROJECT_ROOT, os.path.join(PROJECT_ROOT, "app")):
        if path not in sys.path:
            sys.path.insert(0, path)

    async def _run():
        from app.database import engine
        from sqlalchemy import inspect

        try:
            async with engine.connect() as conn:
                return await conn.run_sync(
                    lambda sync_conn: inspect_fn(inspect(sync_conn))
                )
        finally:
            # Every asyncio.run() starts a new event loop; SQLAlchemy requires an
            # AsyncEngine to be disposed before it is reused on a different loop,
            # otherwise the next test would receive stale pooled connections.
            await engine.dispose()

    try:
        return asyncio.run(_run())
    except Exception as e:
        pytest.skip(f"无法连接数据库: {e}")


def _revision_id(script):
    """Return the revision id declared in an Alembic version *script* (a Path).

    Falls back to the filename prefix before the first "_" when the script has
    no ``revision = ...`` line.
    """
    match = _REVISION_LINE_RE.search(script.read_text(encoding="utf-8"))
    return match.group(1) if match else script.stem.split("_")[0]


class TestDatabaseMigration:
    """Database migration verification tests."""

    def test_alembic_current_shows_no_errors(self):
        """``alembic current`` should exit successfully (skipped if the DB is unreachable)."""
        result = _run_alembic("current")
        _skip_if_db_unavailable(result, "current")
        assert result.returncode == 0, f"alembic current failed: {result.stderr}"

    def test_all_required_tables_exist(self):
        """Every table listed in REQUIRED_TABLES should exist in the database."""
        existing_tables = set(
            _inspect_database(lambda inspector: inspector.get_table_names())
        )
        missing_tables = [t for t in REQUIRED_TABLES if t not in existing_tables]
        assert not missing_tables, f"缺失表: {missing_tables}"

    def test_migration_head_matches_models(self):
        """``alembic check`` should find no pending operations (skipped if the DB is unreachable)."""
        result = _run_alembic("check")
        _skip_if_db_unavailable(result, "check")
        assert result.returncode == 0, f"alembic check failed: {result.stderr}"

    def test_no_duplicate_migration_versions(self):
        """Each Alembic version script should declare a unique revision id."""
        versions_dir = Path(PROJECT_ROOT) / "alembic" / "versions"
        if not versions_dir.is_dir():
            pytest.skip("alembic/versions/ 目录不存在")

        first_seen = {}
        duplicates = []
        for script in sorted(versions_dir.glob("*.py")):
            if script.name.startswith("_"):
                continue  # e.g. __init__.py is not a migration script
            revision = _revision_id(script)
            if revision in first_seen:
                duplicates.append((revision, script.name, first_seen[revision]))
            else:
                first_seen[revision] = script.name
        assert not duplicates, f"发现重复版本号: {duplicates}"

    def test_foreign_keys_integrity(self):
        """The ``brands`` table should have a foreign key covering ``user_id``."""
        foreign_keys = _inspect_database(
            lambda inspector: inspector.get_foreign_keys("brands")
        )
        has_user_fk = any(
            "user_id" in fk["constrained_columns"] for fk in foreign_keys
        )
        assert has_user_fk, "brands 表缺少 user_id 外键"

    def test_alembic_history_shows_migrations(self):
        """``alembic history`` should succeed and print a non-empty history."""
        result = _run_alembic("history")
        _skip_if_db_unavailable(result, "history")
        assert result.returncode == 0, f"alembic history failed: {result.stderr}"
        assert result.stdout.strip(), "alembic history 为空"

    @pytest.mark.skipif(
        os.environ.get("CI") is None,
        reason="alembic check drift detection 仅在 CI 环境中运行(需要完整数据库)",
    )
    def test_alembic_check_detects_drift(self):
        """``alembic check`` must report no drift between models and migrations (CI only).

        Unlike test_migration_head_matches_models, this never skips: an
        unreachable database or a missing alembic executable fails the test.
        """
        result = _run_alembic("check", missing_ok=False)
        # Exit code 0 means no drift; non-zero means pending autogenerate
        # operations were found or the command itself failed.
        assert result.returncode == 0, (
            f"检测到迁移漂移!模型定义与迁移文件不一致:\n{result.stderr}\n{result.stdout}"
        )