fix: 审计问题修复第二轮

安全:
- H2: CORS生产环境localhost警告验证器
- M12: JWT_SECRET已有≥32字符验证(确认)

代码质量:
- H4: 11处Any类型替换为具体联合类型
- H5: 4个模型测试文件(47个测试),模型覆盖率32%→64%
- M11: Alembic迁移脚本(6个缺失表),修复迁移链分支

测试: 717 passed
This commit is contained in:
chiguyong 2026-05-26 07:34:07 +08:00
parent aeaa50e89e
commit 0a39ce6ef1
12 changed files with 871 additions and 20 deletions

View File

@ -14,7 +14,7 @@ import enum
# revision identifiers, used by Alembic.
revision: str = "f7a8b9c0de56"
down_revision: Union[str, None] = "e5f7g9h1cd45"
down_revision: Union[str, None] = "810a29804f5a"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

View File

@ -0,0 +1,157 @@
"""Add missing tables: brands, competitors, api_keys, usage_records, platform_rule_versions, detection_tasks
Revision ID: g1h2i3j4kl56
Revises: f7a8b9c0de56
Create Date: 2026-05-26 10:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
revision: str = "g1h2i3j4kl56"
down_revision: Union[str, Sequence[str], None] = "f7a8b9c0de56"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"brands",
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("name", sa.String(50), nullable=False),
sa.Column("aliases", postgresql.JSONB(), server_default="[]", nullable=False),
sa.Column("website", sa.String(500), nullable=True),
sa.Column("industry", sa.String(50), nullable=True),
sa.Column("platforms", postgresql.JSONB(), server_default="[]", nullable=False),
sa.Column("frequency", sa.String(20), server_default="weekly", nullable=False),
sa.Column("status", sa.String(20), server_default="active", nullable=False),
sa.Column("last_queried_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("next_query_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("idx_brands_user_id", "brands", ["user_id"])
op.create_table(
"competitors",
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
sa.Column("brand_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("name", sa.String(50), nullable=False),
sa.Column("aliases", postgresql.JSONB(), server_default="[]", nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["brand_id"], ["brands.id"], ondelete="CASCADE"),
)
op.create_index("idx_competitors_brand_id", "competitors", ["brand_id"])
op.create_table(
"api_keys",
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("engine_type", sa.String(20), nullable=False),
sa.Column("encrypted_key", sa.String(500), nullable=False),
sa.Column("key_hint", sa.String(50), nullable=False),
sa.Column("key_source", sa.String(10), server_default="user", nullable=True),
sa.Column("status", sa.String(20), server_default="active", nullable=True),
sa.Column("priority", sa.Integer(), server_default="0", nullable=True),
sa.Column("last_verified_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("idx_api_keys_user_id", "api_keys", ["user_id"])
op.create_index("idx_api_keys_user_engine", "api_keys", ["user_id", "engine_type"])
op.create_index("idx_api_keys_engine_status", "api_keys", ["engine_type", "status"])
op.create_table(
"usage_records",
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("brand_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("engine_type", sa.String(20), nullable=False),
sa.Column("query", sa.String(500), nullable=False),
sa.Column("input_tokens", sa.Integer(), server_default="0", nullable=True),
sa.Column("output_tokens", sa.Integer(), server_default="0", nullable=True),
sa.Column("cost", sa.Float(), server_default="0.0", nullable=True),
sa.Column("extra_data", postgresql.JSONB(), server_default="{}", nullable=True),
sa.Column("timestamp", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["brand_id"], ["brands.id"], ondelete="SET NULL"),
)
op.create_index("idx_usage_records_user_id", "usage_records", ["user_id"])
op.create_index("idx_usage_records_timestamp", "usage_records", ["timestamp"])
op.create_index("idx_usage_records_user_engine", "usage_records", ["user_id", "engine_type"])
op.create_index("idx_usage_records_user_timestamp", "usage_records", ["user_id", "timestamp"])
op.create_index("idx_usage_records_engine_timestamp", "usage_records", ["engine_type", "timestamp"])
op.create_table(
"platform_rule_versions",
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
sa.Column("rule_id", sa.String(100), nullable=False),
sa.Column("platform", sa.String(50), nullable=False),
sa.Column("version", sa.Integer(), nullable=False),
sa.Column("rule_data", postgresql.JSONB(), nullable=False),
sa.Column("change_summary", sa.String(500), nullable=True),
sa.Column("created_by", sa.String(100), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("idx_rule_versions_rule_id", "platform_rule_versions", ["rule_id"])
op.create_index("idx_rule_versions_platform", "platform_rule_versions", ["platform"])
op.create_table(
"detection_tasks",
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
sa.Column("brand_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("name", sa.String(200), nullable=False),
sa.Column("frequency", sa.String(20), nullable=False),
sa.Column("engines", postgresql.JSONB(), server_default="[]", nullable=False),
sa.Column("queries", postgresql.JSONB(), server_default="[]", nullable=False),
sa.Column("competitor_names", postgresql.JSONB(), nullable=True),
sa.Column("is_active", sa.Boolean(), server_default="true", nullable=False),
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("next_run_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["brand_id"], ["brands.id"], ondelete="CASCADE"),
)
op.create_index("idx_detection_tasks_brand_id", "detection_tasks", ["brand_id"])
op.create_index("idx_detection_tasks_user_id", "detection_tasks", ["user_id"])
op.create_index("idx_detection_tasks_is_active", "detection_tasks", ["is_active"])
def downgrade() -> None:
op.drop_index("idx_detection_tasks_is_active", table_name="detection_tasks")
op.drop_index("idx_detection_tasks_user_id", table_name="detection_tasks")
op.drop_index("idx_detection_tasks_brand_id", table_name="detection_tasks")
op.drop_table("detection_tasks")
op.drop_index("idx_rule_versions_platform", table_name="platform_rule_versions")
op.drop_index("idx_rule_versions_rule_id", table_name="platform_rule_versions")
op.drop_table("platform_rule_versions")
op.drop_index("idx_usage_records_engine_timestamp", table_name="usage_records")
op.drop_index("idx_usage_records_user_timestamp", table_name="usage_records")
op.drop_index("idx_usage_records_user_engine", table_name="usage_records")
op.drop_index("idx_usage_records_timestamp", table_name="usage_records")
op.drop_index("idx_usage_records_user_id", table_name="usage_records")
op.drop_table("usage_records")
op.drop_index("idx_api_keys_engine_status", table_name="api_keys")
op.drop_index("idx_api_keys_user_engine", table_name="api_keys")
op.drop_index("idx_api_keys_user_id", table_name="api_keys")
op.drop_table("api_keys")
op.drop_index("idx_competitors_brand_id", table_name="competitors")
op.drop_table("competitors")
op.drop_index("idx_brands_user_id", table_name="brands")
op.drop_table("brands")

View File

@ -3,7 +3,6 @@
import logging
import uuid
from datetime import datetime, timezone
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@ -44,7 +43,7 @@ class AgentConfigManager:
self,
agent_name: str,
key: str,
value: Any,
value: dict | list | str | int | float | bool | None,
updated_by: str | None = None,
):
"""设置单个配置项"""
@ -183,7 +182,7 @@ class AgentConfigManager:
result = await db.execute(stmt)
return result.scalar_one_or_none()
def _wrap_value(self, value: Any) -> dict:
def _wrap_value(self, value: dict | list | str | int | float | bool | None) -> dict:
"""将任意值包装为 JSONB 兼容的 dict"""
if isinstance(value, dict):
return value

View File

@ -3,7 +3,6 @@
import logging
import re
from pathlib import Path
from typing import Any
import yaml
@ -196,7 +195,7 @@ class PipelineLoader:
return visited_count == len(stage_names)
@staticmethod
def resolve_variables(template: Any, context: dict) -> Any:
def resolve_variables(template: str | dict | list, context: dict) -> str | dict | list | int | float | bool | None:
"""
解析${var.path}格式的变量引用
@ -229,7 +228,7 @@ class PipelineLoader:
return template
@staticmethod
def _resolve_string(template: str, context: dict) -> Any:
def _resolve_string(template: str, context: dict) -> str | dict | list | int | float | bool | None:
"""
解析字符串中的变量引用
@ -257,7 +256,7 @@ class PipelineLoader:
return VARIABLE_PATTERN.sub(replacer, template)
@staticmethod
def _resolve_path(path: str, context: dict) -> Any:
def _resolve_path(path: str, context: dict) -> str | dict | list | int | float | bool | None:
"""
解析点分路径从context中获取值
@ -265,7 +264,7 @@ class PipelineLoader:
context["stages"]["step1"]["outputs"]["result"]
"""
parts = path.split(".")
current: Any = context
current: str | dict | list | int | float | bool | None = context
for part in parts:
if isinstance(current, dict):

View File

@ -28,6 +28,21 @@ class Settings(BaseSettings):
TONGYI_API_KEY: str = ""
CORS_ORIGINS: str = "http://localhost:3000,http://localhost:3001"
@field_validator("CORS_ORIGINS")
@classmethod
def validate_cors_origins(cls, v: str) -> str:
import os
if os.getenv("ENVIRONMENT", "development") == "production":
origins = [o.strip() for o in v.split(",") if o.strip()]
localhost_origins = [o for o in origins if "localhost" in o or "127.0.0.1" in o]
if localhost_origins:
print(
f"[WARNING] CORS_ORIGINS contains localhost in production: {localhost_origins}. "
"This is a security risk. Please configure proper production origins.",
file=sys.stderr,
)
return v
# ---- LLM Provider 配置 ----
DEFAULT_LLM_PROVIDER: str = "openai"
DEFAULT_LLM_MODEL: str = "qwen3-coder-plus"

View File

@ -1,7 +1,7 @@
"""平台规则管理 Schema - 定义规则管理的请求响应结构"""
from datetime import datetime
from typing import Any, Optional
from typing import Optional
from pydantic import BaseModel, Field
@ -275,8 +275,8 @@ class ContentValidateRequest(BaseModel):
class RuleDiff(BaseModel):
"""规则差异"""
field: str
old_value: Optional[Any] = None
new_value: Optional[Any] = None
old_value: dict | list | str | int | float | bool | None = None
new_value: dict | list | str | int | float | bool | None = None
class RuleDiffResponse(BaseModel):

View File

@ -7,7 +7,6 @@
"""
import json
import logging
from typing import Any
import redis.asyncio as aioredis
@ -45,7 +44,7 @@ class CacheService:
logger.warning("Cache GET failed for key=%s: %s", key, exc)
return None
async def get_json(self, key: str) -> Any | None:
async def get_json(self, key: str) -> dict | list | str | int | float | bool | None:
"""从缓存读取并反序列化 JSON 值。"""
raw = await self.get(key)
if raw is None:
@ -62,7 +61,7 @@ class CacheService:
except Exception as exc:
logger.warning("Cache SET failed for key=%s: %s", key, exc)
async def set_json(self, key: str, value: Any, expire: int = 300) -> None:
async def set_json(self, key: str, value: dict | list | str | int | float | bool, expire: int = 300) -> None:
"""序列化为 JSON 后写入缓存。"""
try:
await self.set(key, json.dumps(value, default=str), expire=expire)

View File

@ -1,10 +1,10 @@
import time
from dataclasses import dataclass, field
from typing import Optional, Any, List
from typing import Optional, List
from app.services.content.rule_validator import RuleValidator
from app.services.content.sensitive_filter import SensitiveFilter
from app.services.content.seo_optimizer import SEOOptimizer
from app.services.content.rule_validator import RuleValidator, ValidationResult
from app.services.content.sensitive_filter import SensitiveFilter, FilterResult
from app.services.content.seo_optimizer import SEOOptimizer, OptimizationResult
from app.services.content.html_generator import HTMLGenerator
@ -12,7 +12,7 @@ from app.services.content.html_generator import HTMLGenerator
class PipelineStage:
name: str
passed: bool
result: Any = None
result: ValidationResult | FilterResult | OptimizationResult | PipelineOutput | None = None
duration: float = 0.0
error: Optional[str] = None

View File

@ -0,0 +1,217 @@
import uuid
from datetime import datetime
import pytest
from sqlalchemy import select
from app.models.lifecycle import LifecycleProject, ProjectStage
from app.models.organization import Organization
from app.models.user import User
class TestLifecycleProjectModel:
def test_lifecycle_project_table_name(self):
assert LifecycleProject.__tablename__ == "lifecycle_projects"
def test_project_stage_table_name(self):
assert ProjectStage.__tablename__ == "project_stages"
def test_lifecycle_project_has_required_fields(self):
fields = LifecycleProject.__table__.columns.keys()
assert "id" in fields
assert "organization_id" in fields
assert "brand_name" in fields
assert "brand_aliases" in fields
assert "current_stage" in fields
assert "status" in fields
assert "created_by" in fields
assert "created_at" in fields
assert "updated_at" in fields
def test_project_stage_has_required_fields(self):
fields = ProjectStage.__table__.columns.keys()
assert "id" in fields
assert "project_id" in fields
assert "stage_number" in fields
assert "status" in fields
assert "started_at" in fields
assert "completed_at" in fields
assert "notes" in fields
assert "metrics" in fields
def test_lifecycle_project_field_types(self):
columns = LifecycleProject.__table__.columns
id_type = str(columns["id"].type).upper()
assert "UUID" in id_type or "CHAR" in id_type
org_id_type = str(columns["organization_id"].type).upper()
assert "UUID" in org_id_type or "CHAR" in org_id_type
brand_name_type = str(columns["brand_name"].type).upper()
assert "VARCHAR" in brand_name_type or "STRING" in brand_name_type
assert "INTEGER" in str(columns["current_stage"].type).upper()
status_type = str(columns["status"].type).upper()
assert "VARCHAR" in status_type or "STRING" in status_type
def test_project_stage_field_types(self):
columns = ProjectStage.__table__.columns
id_type = str(columns["id"].type).upper()
assert "UUID" in id_type or "CHAR" in id_type
project_id_type = str(columns["project_id"].type).upper()
assert "UUID" in project_id_type or "CHAR" in project_id_type
assert "INTEGER" in str(columns["stage_number"].type).upper()
status_type = str(columns["status"].type).upper()
assert "VARCHAR" in status_type or "STRING" in status_type
def test_lifecycle_project_relationships_defined(self):
relationships = LifecycleProject.__mapper__.relationships
rel_keys = relationships.keys()
assert "stages" in rel_keys
assert "organization" in rel_keys
assert "creator" in rel_keys
def test_project_stage_relationships_defined(self):
relationships = ProjectStage.__mapper__.relationships
rel_keys = relationships.keys()
assert "project" in rel_keys
def test_lifecycle_project_default_stage(self):
columns = LifecycleProject.__table__.columns
stage_col = columns["current_stage"]
assert stage_col.server_default is not None
def test_lifecycle_project_default_status(self):
columns = LifecycleProject.__table__.columns
status_col = columns["status"]
assert status_col.server_default is not None
def test_project_stage_default_status(self):
columns = ProjectStage.__table__.columns
status_col = columns["status"]
assert status_col.server_default is not None
@pytest.mark.asyncio
async def test_lifecycle_project_create(self, async_session, test_user):
org = Organization(
name="Lifecycle Test Org",
slug="lifecycle-test-org",
)
async_session.add(org)
await async_session.commit()
await async_session.refresh(org)
project = LifecycleProject(
id=uuid.uuid4(),
organization_id=org.id,
brand_name="Test Brand",
brand_aliases=["TB", "TestBrand"],
current_stage=1,
status="active",
created_by=test_user.id,
)
async_session.add(project)
await async_session.commit()
await async_session.refresh(project)
assert project.id is not None
assert project.organization_id == org.id
assert project.brand_name == "Test Brand"
assert project.brand_aliases == ["TB", "TestBrand"]
assert project.current_stage == 1
assert project.status == "active"
assert project.created_by == test_user.id
assert project.created_at is not None
assert project.updated_at is not None
@pytest.mark.asyncio
async def test_lifecycle_project_default_values(self, async_session, test_user):
org = Organization(
name="Default Lifecycle Org",
slug="default-lifecycle-org",
)
async_session.add(org)
await async_session.commit()
await async_session.refresh(org)
project = LifecycleProject(
organization_id=org.id,
brand_name="Default Brand",
created_by=test_user.id,
)
async_session.add(project)
await async_session.commit()
await async_session.refresh(project)
assert project.brand_aliases == []
assert project.current_stage == 1
assert project.status == "active"
@pytest.mark.asyncio
async def test_project_stage_create(self, async_session, test_user):
org = Organization(
name="Stage Test Org",
slug="stage-test-org",
)
async_session.add(org)
await async_session.commit()
await async_session.refresh(org)
project = LifecycleProject(
organization_id=org.id,
brand_name="Stage Test Brand",
created_by=test_user.id,
)
async_session.add(project)
await async_session.commit()
await async_session.refresh(project)
stage = ProjectStage(
id=uuid.uuid4(),
project_id=project.id,
stage_number=1,
status="in_progress",
notes="Starting brand awareness phase",
metrics={"awareness_score": 45},
)
async_session.add(stage)
await async_session.commit()
await async_session.refresh(stage)
assert stage.id is not None
assert stage.project_id == project.id
assert stage.stage_number == 1
assert stage.status == "in_progress"
assert stage.notes == "Starting brand awareness phase"
assert stage.metrics == {"awareness_score": 45}
assert stage.started_at is None
assert stage.completed_at is None
@pytest.mark.asyncio
async def test_project_stage_default_values(self, async_session, test_user):
org = Organization(
name="Default Stage Org",
slug="default-stage-org",
)
async_session.add(org)
await async_session.commit()
await async_session.refresh(org)
project = LifecycleProject(
organization_id=org.id,
brand_name="Default Stage Brand",
created_by=test_user.id,
)
async_session.add(project)
await async_session.commit()
await async_session.refresh(project)
stage = ProjectStage(
project_id=project.id,
stage_number=2,
)
async_session.add(stage)
await async_session.commit()
await async_session.refresh(stage)
assert stage.status == "pending"
assert stage.notes is None
assert stage.metrics is None

View File

@ -0,0 +1,158 @@
import uuid
from datetime import datetime
import pytest
from sqlalchemy import select
from app.models.organization import Organization, OrgMember
from app.models.user import User
class TestOrganizationModel:
def test_organization_table_name(self):
assert Organization.__tablename__ == "organizations"
def test_org_member_table_name(self):
assert OrgMember.__tablename__ == "org_members"
def test_organization_has_required_fields(self):
fields = Organization.__table__.columns.keys()
assert "id" in fields
assert "name" in fields
assert "slug" in fields
assert "description" in fields
assert "logo_url" in fields
assert "plan" in fields
assert "max_members" in fields
assert "created_at" in fields
assert "updated_at" in fields
def test_org_member_has_required_fields(self):
fields = OrgMember.__table__.columns.keys()
assert "id" in fields
assert "organization_id" in fields
assert "user_id" in fields
assert "role" in fields
assert "joined_at" in fields
assert "invited_by" in fields
def test_organization_field_types(self):
columns = Organization.__table__.columns
id_type = str(columns["id"].type).upper()
assert "UUID" in id_type or "CHAR" in id_type
name_type = str(columns["name"].type).upper()
assert "VARCHAR" in name_type or "STRING" in name_type
slug_type = str(columns["slug"].type).upper()
assert "VARCHAR" in slug_type or "STRING" in slug_type
assert "INTEGER" in str(columns["max_members"].type).upper()
def test_org_member_field_types(self):
columns = OrgMember.__table__.columns
id_type = str(columns["id"].type).upper()
assert "UUID" in id_type or "CHAR" in id_type
org_id_type = str(columns["organization_id"].type).upper()
assert "UUID" in org_id_type or "CHAR" in org_id_type
user_id_type = str(columns["user_id"].type).upper()
assert "UUID" in user_id_type or "CHAR" in user_id_type
role_type = str(columns["role"].type).upper()
assert "VARCHAR" in role_type or "STRING" in role_type
def test_organization_relationships_defined(self):
relationships = Organization.__mapper__.relationships
rel_keys = relationships.keys()
assert "members" in rel_keys
assert "users" in rel_keys
def test_org_member_relationships_defined(self):
relationships = OrgMember.__mapper__.relationships
rel_keys = relationships.keys()
assert "organization" in rel_keys
assert "user" in rel_keys
@pytest.mark.asyncio
async def test_organization_create(self, async_session, test_user):
org = Organization(
id=uuid.uuid4(),
name="Test Org",
slug="test-org",
description="A test organization",
logo_url="https://example.com/logo.png",
plan="free",
max_members=5,
)
async_session.add(org)
await async_session.commit()
await async_session.refresh(org)
assert org.id is not None
assert org.name == "Test Org"
assert org.slug == "test-org"
assert org.description == "A test organization"
assert org.logo_url == "https://example.com/logo.png"
assert org.plan == "free"
assert org.max_members == 5
assert org.created_at is not None
assert org.updated_at is not None
@pytest.mark.asyncio
async def test_organization_default_values(self, async_session):
org = Organization(
name="Default Org",
slug="default-org",
)
async_session.add(org)
await async_session.commit()
await async_session.refresh(org)
assert org.plan == "free"
assert org.max_members == 5
assert org.description is None
assert org.logo_url is None
@pytest.mark.asyncio
async def test_org_member_create(self, async_session, test_user):
org = Organization(
name="Member Test Org",
slug="member-test-org",
)
async_session.add(org)
await async_session.commit()
await async_session.refresh(org)
member = OrgMember(
id=uuid.uuid4(),
organization_id=org.id,
user_id=test_user.id,
role="admin",
)
async_session.add(member)
await async_session.commit()
await async_session.refresh(member)
assert member.id is not None
assert member.organization_id == org.id
assert member.user_id == test_user.id
assert member.role == "admin"
assert member.joined_at is not None
@pytest.mark.asyncio
async def test_org_member_default_role(self, async_session, test_user):
org = Organization(
name="Role Test Org",
slug="role-test-org",
)
async_session.add(org)
await async_session.commit()
await async_session.refresh(org)
member = OrgMember(
organization_id=org.id,
user_id=test_user.id,
)
async_session.add(member)
await async_session.commit()
await async_session.refresh(member)
assert member.role == "viewer"
assert member.invited_by is None

View File

@ -0,0 +1,123 @@
import uuid
from datetime import date, datetime
import pytest
from sqlalchemy import select
from app.models.subscription import Subscription
from app.models.user import User
class TestSubscriptionModel:
def test_subscription_table_name(self):
assert Subscription.__tablename__ == "subscriptions"
def test_subscription_has_required_fields(self):
fields = Subscription.__table__.columns.keys()
assert "id" in fields
assert "user_id" in fields
assert "plan" in fields
assert "status" in fields
assert "start_date" in fields
assert "end_date" in fields
assert "amount" in fields
assert "payment_method" in fields
assert "payment_id" in fields
assert "created_at" in fields
def test_subscription_field_types(self):
columns = Subscription.__table__.columns
assert "UUID" in str(columns["id"].type).upper()
assert "UUID" in str(columns["user_id"].type).upper()
assert "VARCHAR" in str(columns["plan"].type).upper() or "STRING" in str(columns["plan"].type).upper()
assert "VARCHAR" in str(columns["status"].type).upper() or "STRING" in str(columns["status"].type).upper()
assert "DATE" in str(columns["start_date"].type).upper()
assert "DATE" in str(columns["end_date"].type).upper()
assert "NUMERIC" in str(columns["amount"].type).upper()
def test_subscription_relationships_defined(self):
relationships = Subscription.__mapper__.relationships
rel_keys = relationships.keys()
assert "user" in rel_keys
def test_subscription_plan_field_allows_values(self):
columns = Subscription.__table__.columns
plan_col = columns["plan"]
assert plan_col.nullable is False
def test_subscription_status_default(self):
columns = Subscription.__table__.columns
status_col = columns["status"]
assert status_col.default is not None
assert status_col.default.arg == "active"
@pytest.mark.asyncio
async def test_subscription_create(self, async_session, test_user):
subscription = Subscription(
id=uuid.uuid4(),
user_id=test_user.id,
plan="pro",
status="active",
start_date=date(2025, 1, 1),
end_date=date(2025, 12, 31),
amount=99.99,
payment_method="credit_card",
payment_id="pay_abc123",
)
async_session.add(subscription)
await async_session.commit()
await async_session.refresh(subscription)
assert subscription.id is not None
assert subscription.user_id == test_user.id
assert subscription.plan == "pro"
assert subscription.status == "active"
assert subscription.start_date == date(2025, 1, 1)
assert subscription.end_date == date(2025, 12, 31)
assert subscription.amount is not None
assert subscription.payment_method == "credit_card"
assert subscription.payment_id == "pay_abc123"
assert subscription.created_at is not None
@pytest.mark.asyncio
async def test_subscription_default_status(self, async_session, test_user):
subscription = Subscription(
user_id=test_user.id,
plan="free",
start_date=date(2025, 1, 1),
end_date=date(2025, 12, 31),
)
async_session.add(subscription)
await async_session.commit()
await async_session.refresh(subscription)
assert subscription.status == "active"
assert subscription.amount is None
assert subscription.payment_method is None
assert subscription.payment_id is None
@pytest.mark.asyncio
async def test_subscription_query_by_user(self, async_session, test_user):
sub1 = Subscription(
user_id=test_user.id,
plan="free",
start_date=date(2025, 1, 1),
end_date=date(2025, 6, 30),
)
sub2 = Subscription(
user_id=test_user.id,
plan="pro",
start_date=date(2025, 7, 1),
end_date=date(2025, 12, 31),
)
async_session.add(sub1)
async_session.add(sub2)
await async_session.commit()
result = await async_session.execute(
select(Subscription).where(Subscription.user_id == test_user.id)
)
subscriptions = result.scalars().all()
assert len(subscriptions) == 2

View File

@ -0,0 +1,184 @@
import uuid
from datetime import datetime
import pytest
from sqlalchemy import select
from app.models.brand import Brand
from app.models.suggestion import Suggestion
from app.models.user import User
class TestSuggestionModel:
def test_suggestion_table_name(self):
assert Suggestion.__tablename__ == "suggestions"
def test_suggestion_has_required_fields(self):
fields = Suggestion.__table__.columns.keys()
assert "id" in fields
assert "brand_id" in fields
assert "type" in fields
assert "priority" in fields
assert "title" in fields
assert "description" in fields
assert "action" in fields
assert "expected_impact" in fields
assert "difficulty" in fields
assert "status" in fields
assert "generated_at" in fields
assert "updated_at" in fields
assert "batch_id" in fields
assert "source" in fields
def test_suggestion_field_types(self):
columns = Suggestion.__table__.columns
id_type = str(columns["id"].type).upper()
assert "UUID" in id_type or "CHAR" in id_type
brand_id_type = str(columns["brand_id"].type).upper()
assert "UUID" in brand_id_type or "CHAR" in brand_id_type
type_type = str(columns["type"].type).upper()
assert "VARCHAR" in type_type or "STRING" in type_type
priority_type = str(columns["priority"].type).upper()
assert "VARCHAR" in priority_type or "STRING" in priority_type
title_type = str(columns["title"].type).upper()
assert "VARCHAR" in title_type or "STRING" in title_type
assert "TEXT" in str(columns["description"].type).upper()
difficulty_type = str(columns["difficulty"].type).upper()
assert "VARCHAR" in difficulty_type or "STRING" in difficulty_type
status_type = str(columns["status"].type).upper()
assert "VARCHAR" in status_type or "STRING" in status_type
def test_suggestion_relationships_defined(self):
relationships = Suggestion.__mapper__.relationships
rel_keys = relationships.keys()
assert "brand" in rel_keys
def test_suggestion_priority_default(self):
columns = Suggestion.__table__.columns
priority_col = columns["priority"]
assert priority_col.default is not None
assert priority_col.default.arg == "medium"
def test_suggestion_status_default(self):
columns = Suggestion.__table__.columns
status_col = columns["status"]
assert status_col.default is not None
assert status_col.default.arg == "pending"
def test_suggestion_difficulty_default(self):
columns = Suggestion.__table__.columns
difficulty_col = columns["difficulty"]
assert difficulty_col.default is not None
assert difficulty_col.default.arg == "medium"
def test_suggestion_source_default(self):
columns = Suggestion.__table__.columns
source_col = columns["source"]
assert source_col.default is not None
assert source_col.default.arg == "rule"
@pytest.mark.asyncio
async def test_suggestion_create(self, async_session, test_user):
brand = Brand(
user_id=test_user.id,
name="Suggestion Test Brand",
platforms=["wenxin"],
)
async_session.add(brand)
await async_session.commit()
await async_session.refresh(brand)
suggestion = Suggestion(
id=uuid.uuid4(),
brand_id=brand.id,
type="content_optimization",
priority="high",
title="Optimize content structure",
description="Improve heading hierarchy for better AI citation",
action="Restructure headings using H2/H3 hierarchy",
expected_impact="20% increase in citation rate",
difficulty="easy",
status="pending",
source="rule",
)
async_session.add(suggestion)
await async_session.commit()
await async_session.refresh(suggestion)
assert suggestion.id is not None
assert suggestion.brand_id == brand.id
assert suggestion.type == "content_optimization"
assert suggestion.priority == "high"
assert suggestion.title == "Optimize content structure"
assert suggestion.description == "Improve heading hierarchy for better AI citation"
assert suggestion.action == "Restructure headings using H2/H3 hierarchy"
assert suggestion.expected_impact == "20% increase in citation rate"
assert suggestion.difficulty == "easy"
assert suggestion.status == "pending"
assert suggestion.source == "rule"
assert suggestion.generated_at is not None
assert suggestion.updated_at is not None
assert suggestion.batch_id is not None
@pytest.mark.asyncio
async def test_suggestion_default_values(self, async_session, test_user):
brand = Brand(
user_id=test_user.id,
name="Default Suggestion Brand",
platforms=["kimi"],
)
async_session.add(brand)
await async_session.commit()
await async_session.refresh(brand)
suggestion = Suggestion(
brand_id=brand.id,
type="competitor_gap",
title="Fill competitor gap",
description="Address missing topics that competitors cover",
)
async_session.add(suggestion)
await async_session.commit()
await async_session.refresh(suggestion)
assert suggestion.priority == "medium"
assert suggestion.difficulty == "medium"
assert suggestion.status == "pending"
assert suggestion.source == "rule"
assert suggestion.action is None
assert suggestion.expected_impact is None
@pytest.mark.asyncio
async def test_suggestion_query_by_brand(self, async_session, test_user):
brand = Brand(
user_id=test_user.id,
name="Query Suggestion Brand",
platforms=["wenxin"],
)
async_session.add(brand)
await async_session.commit()
await async_session.refresh(brand)
s1 = Suggestion(
brand_id=brand.id,
type="content_optimization",
title="Suggestion 1",
description="Desc 1",
)
s2 = Suggestion(
brand_id=brand.id,
type="platform_targeting",
title="Suggestion 2",
description="Desc 2",
)
async_session.add(s1)
async_session.add(s2)
await async_session.commit()
result = await async_session.execute(
select(Suggestion).where(Suggestion.brand_id == brand.id)
)
suggestions = result.scalars().all()
assert len(suggestions) == 2