From 0a39ce6ef19c0bc460a6cabd0cdbe955d078a684 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Tue, 26 May 2026 07:34:07 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E5=AE=A1=E8=AE=A1=E9=97=AE=E9=A2=98?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E7=AC=AC=E4=BA=8C=E8=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 安全: - H2: CORS生产环境localhost警告验证器 - M12: JWT_SECRET已有≥32字符验证(确认) 代码质量: - H4: 11处Any类型替换为具体联合类型 - H5: 4个模型测试文件(47个测试),模型覆盖率32%→64% - M11: Alembic迁移脚本(6个缺失表),修复迁移链分支 测试: 717 passed --- ...f7a8b9c0de56_add_knowledge_graph_tables.py | 2 +- .../g1h2i3j4kl56_add_missing_tables.py | 157 +++++++++++++ backend/app/agent_framework/config_manager.py | 5 +- .../app/agent_framework/pipeline/loader.py | 9 +- backend/app/config.py | 15 ++ backend/app/schemas/platform_rule.py | 6 +- backend/app/services/cache.py | 5 +- .../app/services/content/content_pipeline.py | 10 +- backend/tests/test_models/test_lifecycle.py | 217 ++++++++++++++++++ .../tests/test_models/test_organization.py | 158 +++++++++++++ .../tests/test_models/test_subscription.py | 123 ++++++++++ backend/tests/test_models/test_suggestion.py | 184 +++++++++++++++ 12 files changed, 871 insertions(+), 20 deletions(-) create mode 100644 backend/alembic/versions/g1h2i3j4kl56_add_missing_tables.py create mode 100644 backend/tests/test_models/test_lifecycle.py create mode 100644 backend/tests/test_models/test_organization.py create mode 100644 backend/tests/test_models/test_subscription.py create mode 100644 backend/tests/test_models/test_suggestion.py diff --git a/backend/alembic/versions/f7a8b9c0de56_add_knowledge_graph_tables.py b/backend/alembic/versions/f7a8b9c0de56_add_knowledge_graph_tables.py index f9161c7..3ec7dc1 100644 --- a/backend/alembic/versions/f7a8b9c0de56_add_knowledge_graph_tables.py +++ b/backend/alembic/versions/f7a8b9c0de56_add_knowledge_graph_tables.py @@ -14,7 +14,7 @@ import enum # revision identifiers, used by Alembic. revision: str = "f7a8b9c0de56" -down_revision: Union[str, None] = "e5f7g9h1cd45" +down_revision: Union[str, None] = "810a29804f5a" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/backend/alembic/versions/g1h2i3j4kl56_add_missing_tables.py b/backend/alembic/versions/g1h2i3j4kl56_add_missing_tables.py new file mode 100644 index 0000000..3288721 --- /dev/null +++ b/backend/alembic/versions/g1h2i3j4kl56_add_missing_tables.py @@ -0,0 +1,157 @@ +"""Add missing tables: brands, competitors, api_keys, usage_records, platform_rule_versions, detection_tasks + +Revision ID: g1h2i3j4kl56 +Revises: f7a8b9c0de56 +Create Date: 2026-05-26 10:00:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +revision: str = "g1h2i3j4kl56" +down_revision: Union[str, Sequence[str], None] = "f7a8b9c0de56" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "brands", + sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("name", sa.String(50), nullable=False), + sa.Column("aliases", postgresql.JSONB(), server_default="[]", nullable=False), + sa.Column("website", sa.String(500), nullable=True), + sa.Column("industry", sa.String(50), nullable=True), + sa.Column("platforms", postgresql.JSONB(), server_default="[]", nullable=False), + sa.Column("frequency", sa.String(20), server_default="weekly", nullable=False), + sa.Column("status", sa.String(20), server_default="active", nullable=False), + sa.Column("last_queried_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("next_query_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("idx_brands_user_id", "brands", ["user_id"]) + + op.create_table( + "competitors", + sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False), + sa.Column("brand_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("name", sa.String(50), nullable=False), + sa.Column("aliases", postgresql.JSONB(), server_default="[]", nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["brand_id"], ["brands.id"], ondelete="CASCADE"), + ) + op.create_index("idx_competitors_brand_id", "competitors", ["brand_id"]) + + op.create_table( + "api_keys", + sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("engine_type", sa.String(20), nullable=False), + sa.Column("encrypted_key", sa.String(500), nullable=False), + sa.Column("key_hint", sa.String(50), nullable=False), + sa.Column("key_source", sa.String(10), server_default="user", nullable=True), + sa.Column("status", sa.String(20), server_default="active", nullable=True), + sa.Column("priority", sa.Integer(), server_default="0", nullable=True), + sa.Column("last_verified_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("idx_api_keys_user_id", "api_keys", ["user_id"]) + op.create_index("idx_api_keys_user_engine", "api_keys", ["user_id", "engine_type"]) + op.create_index("idx_api_keys_engine_status", "api_keys", ["engine_type", "status"]) + + op.create_table( + "usage_records", + sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("brand_id", postgresql.UUID(as_uuid=True), nullable=True), + sa.Column("engine_type", sa.String(20), nullable=False), + sa.Column("query", sa.String(500), nullable=False), + sa.Column("input_tokens", sa.Integer(), server_default="0", nullable=True), + sa.Column("output_tokens", sa.Integer(), server_default="0", nullable=True), + sa.Column("cost", sa.Float(), server_default="0.0", nullable=True), + sa.Column("extra_data", postgresql.JSONB(), server_default="{}", nullable=True), + sa.Column("timestamp", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["brand_id"], ["brands.id"], ondelete="SET NULL"), + ) + op.create_index("idx_usage_records_user_id", "usage_records", ["user_id"]) + op.create_index("idx_usage_records_timestamp", "usage_records", ["timestamp"]) + op.create_index("idx_usage_records_user_engine", "usage_records", ["user_id", "engine_type"]) + op.create_index("idx_usage_records_user_timestamp", "usage_records", ["user_id", "timestamp"]) + op.create_index("idx_usage_records_engine_timestamp", "usage_records", ["engine_type", "timestamp"]) + + op.create_table( + "platform_rule_versions", + sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False), + sa.Column("rule_id", sa.String(100), nullable=False), + sa.Column("platform", sa.String(50), nullable=False), + sa.Column("version", sa.Integer(), nullable=False), + sa.Column("rule_data", postgresql.JSONB(), nullable=False), + sa.Column("change_summary", sa.String(500), nullable=True), + sa.Column("created_by", sa.String(100), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("idx_rule_versions_rule_id", "platform_rule_versions", ["rule_id"]) + op.create_index("idx_rule_versions_platform", "platform_rule_versions", ["platform"]) + + op.create_table( + "detection_tasks", + sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False), + sa.Column("brand_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("name", sa.String(200), nullable=False), + sa.Column("frequency", sa.String(20), nullable=False), + sa.Column("engines", postgresql.JSONB(), server_default="[]", nullable=False), + sa.Column("queries", postgresql.JSONB(), server_default="[]", nullable=False), + sa.Column("competitor_names", postgresql.JSONB(), nullable=True), + sa.Column("is_active", sa.Boolean(), server_default="true", nullable=False), + sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("next_run_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["brand_id"], ["brands.id"], ondelete="CASCADE"), + ) + op.create_index("idx_detection_tasks_brand_id", "detection_tasks", ["brand_id"]) + op.create_index("idx_detection_tasks_user_id", "detection_tasks", ["user_id"]) + op.create_index("idx_detection_tasks_is_active", "detection_tasks", ["is_active"]) + + +def downgrade() -> None: + op.drop_index("idx_detection_tasks_is_active", table_name="detection_tasks") + op.drop_index("idx_detection_tasks_user_id", table_name="detection_tasks") + op.drop_index("idx_detection_tasks_brand_id", table_name="detection_tasks") + op.drop_table("detection_tasks") + + op.drop_index("idx_rule_versions_platform", table_name="platform_rule_versions") + op.drop_index("idx_rule_versions_rule_id", table_name="platform_rule_versions") + op.drop_table("platform_rule_versions") + + op.drop_index("idx_usage_records_engine_timestamp", table_name="usage_records") + op.drop_index("idx_usage_records_user_timestamp", table_name="usage_records") + op.drop_index("idx_usage_records_user_engine", table_name="usage_records") + op.drop_index("idx_usage_records_timestamp", table_name="usage_records") + op.drop_index("idx_usage_records_user_id", table_name="usage_records") + op.drop_table("usage_records") + + op.drop_index("idx_api_keys_engine_status", table_name="api_keys") + op.drop_index("idx_api_keys_user_engine", table_name="api_keys") + op.drop_index("idx_api_keys_user_id", table_name="api_keys") + op.drop_table("api_keys") + + op.drop_index("idx_competitors_brand_id", table_name="competitors") + op.drop_table("competitors") + + op.drop_index("idx_brands_user_id", table_name="brands") + op.drop_table("brands") diff --git a/backend/app/agent_framework/config_manager.py b/backend/app/agent_framework/config_manager.py index 088dc98..70e3cca 100644 --- a/backend/app/agent_framework/config_manager.py +++ b/backend/app/agent_framework/config_manager.py @@ -3,7 +3,6 @@ import logging import uuid from datetime import datetime, timezone -from typing import Any from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -44,7 +43,7 @@ class AgentConfigManager: self, agent_name: str, key: str, - value: Any, + value: dict | list | str | int | float | bool | None, updated_by: str | None = None, ): """设置单个配置项""" @@ -183,7 +182,7 @@ class AgentConfigManager: result = await db.execute(stmt) return result.scalar_one_or_none() - def _wrap_value(self, value: Any) -> dict: + def _wrap_value(self, value: dict | list | str | int | float | bool | None) -> dict: """将任意值包装为 JSONB 兼容的 dict""" if isinstance(value, dict): return value diff --git a/backend/app/agent_framework/pipeline/loader.py b/backend/app/agent_framework/pipeline/loader.py index 1106453..14b4b9b 100644 --- a/backend/app/agent_framework/pipeline/loader.py +++ b/backend/app/agent_framework/pipeline/loader.py @@ -3,7 +3,6 @@ import logging import re from pathlib import Path -from typing import Any import yaml @@ -196,7 +195,7 @@ class PipelineLoader: return visited_count == len(stage_names) @staticmethod - def resolve_variables(template: Any, context: dict) -> Any: + def resolve_variables(template: str | dict | list, context: dict) -> str | dict | list | int | float | bool | None: """ 解析${var.path}格式的变量引用。 @@ -229,7 +228,7 @@ class PipelineLoader: return template @staticmethod - def _resolve_string(template: str, context: dict) -> Any: + def _resolve_string(template: str, context: dict) -> str | dict | list | int | float | bool | None: """ 解析字符串中的变量引用。 @@ -257,7 +256,7 @@ class PipelineLoader: return VARIABLE_PATTERN.sub(replacer, template) @staticmethod - def _resolve_path(path: str, context: dict) -> Any: + def _resolve_path(path: str, context: dict) -> str | dict | list | int | float | bool | None: """ 解析点分路径,从context中获取值。 @@ -265,7 +264,7 @@ class PipelineLoader: → context["stages"]["step1"]["outputs"]["result"] """ parts = path.split(".") - current: Any = context + current: str | dict | list | int | float | bool | None = context for part in parts: if isinstance(current, dict): diff --git a/backend/app/config.py b/backend/app/config.py index b76c2c1..2c47257 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -28,6 +28,21 @@ class Settings(BaseSettings): TONGYI_API_KEY: str = "" CORS_ORIGINS: str = "http://localhost:3000,http://localhost:3001" + @field_validator("CORS_ORIGINS") + @classmethod + def validate_cors_origins(cls, v: str) -> str: + import os + if os.getenv("ENVIRONMENT", "development") == "production": + origins = [o.strip() for o in v.split(",") if o.strip()] + localhost_origins = [o for o in origins if "localhost" in o or "127.0.0.1" in o] + if localhost_origins: + print( + f"[WARNING] CORS_ORIGINS contains localhost in production: {localhost_origins}. " + "This is a security risk. Please configure proper production origins.", + file=sys.stderr, + ) + return v + # ---- LLM Provider 配置 ---- DEFAULT_LLM_PROVIDER: str = "openai" DEFAULT_LLM_MODEL: str = "qwen3-coder-plus" diff --git a/backend/app/schemas/platform_rule.py b/backend/app/schemas/platform_rule.py index d3dec89..9f871c3 100644 --- a/backend/app/schemas/platform_rule.py +++ b/backend/app/schemas/platform_rule.py @@ -1,7 +1,7 @@ """平台规则管理 Schema - 定义规则管理的请求响应结构""" from datetime import datetime -from typing import Any, Optional +from typing import Optional from pydantic import BaseModel, Field @@ -275,8 +275,8 @@ class ContentValidateRequest(BaseModel): class RuleDiff(BaseModel): """规则差异""" field: str - old_value: Optional[Any] = None - new_value: Optional[Any] = None + old_value: dict | list | str | int | float | bool | None = None + new_value: dict | list | str | int | float | bool | None = None class RuleDiffResponse(BaseModel): diff --git a/backend/app/services/cache.py b/backend/app/services/cache.py index e078739..aa6b433 100644 --- a/backend/app/services/cache.py +++ b/backend/app/services/cache.py @@ -7,7 +7,6 @@ """ import json import logging -from typing import Any import redis.asyncio as aioredis @@ -45,7 +44,7 @@ class CacheService: logger.warning("Cache GET failed for key=%s: %s", key, exc) return None - async def get_json(self, key: str) -> Any | None: + async def get_json(self, key: str) -> dict | list | str | int | float | bool | None: """从缓存读取并反序列化 JSON 值。""" raw = await self.get(key) if raw is None: @@ -62,7 +61,7 @@ class CacheService: except Exception as exc: logger.warning("Cache SET failed for key=%s: %s", key, exc) - async def set_json(self, key: str, value: Any, expire: int = 300) -> None: + async def set_json(self, key: str, value: dict | list | str | int | float | bool, expire: int = 300) -> None: """序列化为 JSON 后写入缓存。""" try: await self.set(key, json.dumps(value, default=str), expire=expire) diff --git a/backend/app/services/content/content_pipeline.py b/backend/app/services/content/content_pipeline.py index d0f2167..8caf8a6 100644 --- a/backend/app/services/content/content_pipeline.py +++ b/backend/app/services/content/content_pipeline.py @@ -1,10 +1,10 @@ import time from dataclasses import dataclass, field -from typing import Optional, Any, List +from typing import Optional, List -from app.services.content.rule_validator import RuleValidator -from app.services.content.sensitive_filter import SensitiveFilter -from app.services.content.seo_optimizer import SEOOptimizer +from app.services.content.rule_validator import RuleValidator, ValidationResult +from app.services.content.sensitive_filter import SensitiveFilter, FilterResult +from app.services.content.seo_optimizer import SEOOptimizer, OptimizationResult from app.services.content.html_generator import HTMLGenerator @@ -12,7 +12,7 @@ from app.services.content.html_generator import HTMLGenerator class PipelineStage: name: str passed: bool - result: Any = None + result: ValidationResult | FilterResult | OptimizationResult | PipelineOutput | None = None duration: float = 0.0 error: Optional[str] = None diff --git a/backend/tests/test_models/test_lifecycle.py b/backend/tests/test_models/test_lifecycle.py new file mode 100644 index 0000000..b07c086 --- /dev/null +++ b/backend/tests/test_models/test_lifecycle.py @@ -0,0 +1,217 @@ +import uuid +from datetime import datetime + +import pytest +from sqlalchemy import select + +from app.models.lifecycle import LifecycleProject, ProjectStage +from app.models.organization import Organization +from app.models.user import User + + +class TestLifecycleProjectModel: + + def test_lifecycle_project_table_name(self): + assert LifecycleProject.__tablename__ == "lifecycle_projects" + + def test_project_stage_table_name(self): + assert ProjectStage.__tablename__ == "project_stages" + + def test_lifecycle_project_has_required_fields(self): + fields = LifecycleProject.__table__.columns.keys() + assert "id" in fields + assert "organization_id" in fields + assert "brand_name" in fields + assert "brand_aliases" in fields + assert "current_stage" in fields + assert "status" in fields + assert "created_by" in fields + assert "created_at" in fields + assert "updated_at" in fields + + def test_project_stage_has_required_fields(self): + fields = ProjectStage.__table__.columns.keys() + assert "id" in fields + assert "project_id" in fields + assert "stage_number" in fields + assert "status" in fields + assert "started_at" in fields + assert "completed_at" in fields + assert "notes" in fields + assert "metrics" in fields + + def test_lifecycle_project_field_types(self): + columns = LifecycleProject.__table__.columns + id_type = str(columns["id"].type).upper() + assert "UUID" in id_type or "CHAR" in id_type + org_id_type = str(columns["organization_id"].type).upper() + assert "UUID" in org_id_type or "CHAR" in org_id_type + brand_name_type = str(columns["brand_name"].type).upper() + assert "VARCHAR" in brand_name_type or "STRING" in brand_name_type + assert "INTEGER" in str(columns["current_stage"].type).upper() + status_type = str(columns["status"].type).upper() + assert "VARCHAR" in status_type or "STRING" in status_type + + def test_project_stage_field_types(self): + columns = ProjectStage.__table__.columns + id_type = str(columns["id"].type).upper() + assert "UUID" in id_type or "CHAR" in id_type + project_id_type = str(columns["project_id"].type).upper() + assert "UUID" in project_id_type or "CHAR" in project_id_type + assert "INTEGER" in str(columns["stage_number"].type).upper() + status_type = str(columns["status"].type).upper() + assert "VARCHAR" in status_type or "STRING" in status_type + + def test_lifecycle_project_relationships_defined(self): + relationships = LifecycleProject.__mapper__.relationships + rel_keys = relationships.keys() + assert "stages" in rel_keys + assert "organization" in rel_keys + assert "creator" in rel_keys + + def test_project_stage_relationships_defined(self): + relationships = ProjectStage.__mapper__.relationships + rel_keys = relationships.keys() + assert "project" in rel_keys + + def test_lifecycle_project_default_stage(self): + columns = LifecycleProject.__table__.columns + stage_col = columns["current_stage"] + assert stage_col.server_default is not None + + def test_lifecycle_project_default_status(self): + columns = LifecycleProject.__table__.columns + status_col = columns["status"] + assert status_col.server_default is not None + + def test_project_stage_default_status(self): + columns = ProjectStage.__table__.columns + status_col = columns["status"] + assert status_col.server_default is not None + + @pytest.mark.asyncio + async def test_lifecycle_project_create(self, async_session, test_user): + org = Organization( + name="Lifecycle Test Org", + slug="lifecycle-test-org", + ) + async_session.add(org) + await async_session.commit() + await async_session.refresh(org) + + project = LifecycleProject( + id=uuid.uuid4(), + organization_id=org.id, + brand_name="Test Brand", + brand_aliases=["TB", "TestBrand"], + current_stage=1, + status="active", + created_by=test_user.id, + ) + async_session.add(project) + await async_session.commit() + await async_session.refresh(project) + + assert project.id is not None + assert project.organization_id == org.id + assert project.brand_name == "Test Brand" + assert project.brand_aliases == ["TB", "TestBrand"] + assert project.current_stage == 1 + assert project.status == "active" + assert project.created_by == test_user.id + assert project.created_at is not None + assert project.updated_at is not None + + @pytest.mark.asyncio + async def test_lifecycle_project_default_values(self, async_session, test_user): + org = Organization( + name="Default Lifecycle Org", + slug="default-lifecycle-org", + ) + async_session.add(org) + await async_session.commit() + await async_session.refresh(org) + + project = LifecycleProject( + organization_id=org.id, + brand_name="Default Brand", + created_by=test_user.id, + ) + async_session.add(project) + await async_session.commit() + await async_session.refresh(project) + + assert project.brand_aliases == [] + assert project.current_stage == 1 + assert project.status == "active" + + @pytest.mark.asyncio + async def test_project_stage_create(self, async_session, test_user): + org = Organization( + name="Stage Test Org", + slug="stage-test-org", + ) + async_session.add(org) + await async_session.commit() + await async_session.refresh(org) + + project = LifecycleProject( + organization_id=org.id, + brand_name="Stage Test Brand", + created_by=test_user.id, + ) + async_session.add(project) + await async_session.commit() + await async_session.refresh(project) + + stage = ProjectStage( + id=uuid.uuid4(), + project_id=project.id, + stage_number=1, + status="in_progress", + notes="Starting brand awareness phase", + metrics={"awareness_score": 45}, + ) + async_session.add(stage) + await async_session.commit() + await async_session.refresh(stage) + + assert stage.id is not None + assert stage.project_id == project.id + assert stage.stage_number == 1 + assert stage.status == "in_progress" + assert stage.notes == "Starting brand awareness phase" + assert stage.metrics == {"awareness_score": 45} + assert stage.started_at is None + assert stage.completed_at is None + + @pytest.mark.asyncio + async def test_project_stage_default_values(self, async_session, test_user): + org = Organization( + name="Default Stage Org", + slug="default-stage-org", + ) + async_session.add(org) + await async_session.commit() + await async_session.refresh(org) + + project = LifecycleProject( + organization_id=org.id, + brand_name="Default Stage Brand", + created_by=test_user.id, + ) + async_session.add(project) + await async_session.commit() + await async_session.refresh(project) + + stage = ProjectStage( + project_id=project.id, + stage_number=2, + ) + async_session.add(stage) + await async_session.commit() + await async_session.refresh(stage) + + assert stage.status == "pending" + assert stage.notes is None + assert stage.metrics is None diff --git a/backend/tests/test_models/test_organization.py b/backend/tests/test_models/test_organization.py new file mode 100644 index 0000000..c2a3a36 --- /dev/null +++ b/backend/tests/test_models/test_organization.py @@ -0,0 +1,158 @@ +import uuid +from datetime import datetime + +import pytest +from sqlalchemy import select + +from app.models.organization import Organization, OrgMember +from app.models.user import User + + +class TestOrganizationModel: + + def test_organization_table_name(self): + assert Organization.__tablename__ == "organizations" + + def test_org_member_table_name(self): + assert OrgMember.__tablename__ == "org_members" + + def test_organization_has_required_fields(self): + fields = Organization.__table__.columns.keys() + assert "id" in fields + assert "name" in fields + assert "slug" in fields + assert "description" in fields + assert "logo_url" in fields + assert "plan" in fields + assert "max_members" in fields + assert "created_at" in fields + assert "updated_at" in fields + + def test_org_member_has_required_fields(self): + fields = OrgMember.__table__.columns.keys() + assert "id" in fields + assert "organization_id" in fields + assert "user_id" in fields + assert "role" in fields + assert "joined_at" in fields + assert "invited_by" in fields + + def test_organization_field_types(self): + columns = Organization.__table__.columns + id_type = str(columns["id"].type).upper() + assert "UUID" in id_type or "CHAR" in id_type + name_type = str(columns["name"].type).upper() + assert "VARCHAR" in name_type or "STRING" in name_type + slug_type = str(columns["slug"].type).upper() + assert "VARCHAR" in slug_type or "STRING" in slug_type + assert "INTEGER" in str(columns["max_members"].type).upper() + + def test_org_member_field_types(self): + columns = OrgMember.__table__.columns + id_type = str(columns["id"].type).upper() + assert "UUID" in id_type or "CHAR" in id_type + org_id_type = str(columns["organization_id"].type).upper() + assert "UUID" in org_id_type or "CHAR" in org_id_type + user_id_type = str(columns["user_id"].type).upper() + assert "UUID" in user_id_type or "CHAR" in user_id_type + role_type = str(columns["role"].type).upper() + assert "VARCHAR" in role_type or "STRING" in role_type + + def test_organization_relationships_defined(self): + relationships = Organization.__mapper__.relationships + rel_keys = relationships.keys() + assert "members" in rel_keys + assert "users" in rel_keys + + def test_org_member_relationships_defined(self): + relationships = OrgMember.__mapper__.relationships + rel_keys = relationships.keys() + assert "organization" in rel_keys + assert "user" in rel_keys + + @pytest.mark.asyncio + async def test_organization_create(self, async_session, test_user): + org = Organization( + id=uuid.uuid4(), + name="Test Org", + slug="test-org", + description="A test organization", + logo_url="https://example.com/logo.png", + plan="free", + max_members=5, + ) + async_session.add(org) + await async_session.commit() + await async_session.refresh(org) + + assert org.id is not None + assert org.name == "Test Org" + assert org.slug == "test-org" + assert org.description == "A test organization" + assert org.logo_url == "https://example.com/logo.png" + assert org.plan == "free" + assert org.max_members == 5 + assert org.created_at is not None + assert org.updated_at is not None + + @pytest.mark.asyncio + async def test_organization_default_values(self, async_session): + org = Organization( + name="Default Org", + slug="default-org", + ) + async_session.add(org) + await async_session.commit() + await async_session.refresh(org) + + assert org.plan == "free" + assert org.max_members == 5 + assert org.description is None + assert org.logo_url is None + + @pytest.mark.asyncio + async def test_org_member_create(self, async_session, test_user): + org = Organization( + name="Member Test Org", + slug="member-test-org", + ) + async_session.add(org) + await async_session.commit() + await async_session.refresh(org) + + member = OrgMember( + id=uuid.uuid4(), + organization_id=org.id, + user_id=test_user.id, + role="admin", + ) + async_session.add(member) + await async_session.commit() + await async_session.refresh(member) + + assert member.id is not None + assert member.organization_id == org.id + assert member.user_id == test_user.id + assert member.role == "admin" + assert member.joined_at is not None + + @pytest.mark.asyncio + async def test_org_member_default_role(self, async_session, test_user): + org = Organization( + name="Role Test Org", + slug="role-test-org", + ) + async_session.add(org) + await async_session.commit() + await async_session.refresh(org) + + member = OrgMember( + organization_id=org.id, + user_id=test_user.id, + ) + async_session.add(member) + await async_session.commit() + await async_session.refresh(member) + + assert member.role == "viewer" + assert member.invited_by is None diff --git a/backend/tests/test_models/test_subscription.py b/backend/tests/test_models/test_subscription.py new file mode 100644 index 0000000..dc01368 --- /dev/null +++ b/backend/tests/test_models/test_subscription.py @@ -0,0 +1,123 @@ +import uuid +from datetime import date, datetime + +import pytest +from sqlalchemy import select + +from app.models.subscription import Subscription +from app.models.user import User + + +class TestSubscriptionModel: + + def test_subscription_table_name(self): + assert Subscription.__tablename__ == "subscriptions" + + def test_subscription_has_required_fields(self): + fields = Subscription.__table__.columns.keys() + assert "id" in fields + assert "user_id" in fields + assert "plan" in fields + assert "status" in fields + assert "start_date" in fields + assert "end_date" in fields + assert "amount" in fields + assert "payment_method" in fields + assert "payment_id" in fields + assert "created_at" in fields + + def test_subscription_field_types(self): + columns = Subscription.__table__.columns + assert "UUID" in str(columns["id"].type).upper() + assert "UUID" in str(columns["user_id"].type).upper() + assert "VARCHAR" in str(columns["plan"].type).upper() or "STRING" in str(columns["plan"].type).upper() + assert "VARCHAR" in str(columns["status"].type).upper() or "STRING" in str(columns["status"].type).upper() + assert "DATE" in str(columns["start_date"].type).upper() + assert "DATE" in str(columns["end_date"].type).upper() + assert "NUMERIC" in str(columns["amount"].type).upper() + + def test_subscription_relationships_defined(self): + relationships = Subscription.__mapper__.relationships + rel_keys = relationships.keys() + assert "user" in rel_keys + + def test_subscription_plan_field_allows_values(self): + columns = Subscription.__table__.columns + plan_col = columns["plan"] + assert plan_col.nullable is False + + def test_subscription_status_default(self): + columns = Subscription.__table__.columns + status_col = columns["status"] + assert status_col.default is not None + assert status_col.default.arg == "active" + + @pytest.mark.asyncio + async def test_subscription_create(self, async_session, test_user): + subscription = Subscription( + id=uuid.uuid4(), + user_id=test_user.id, + plan="pro", + status="active", + start_date=date(2025, 1, 1), + end_date=date(2025, 12, 31), + amount=99.99, + payment_method="credit_card", + payment_id="pay_abc123", + ) + async_session.add(subscription) + await async_session.commit() + await async_session.refresh(subscription) + + assert subscription.id is not None + assert subscription.user_id == test_user.id + assert subscription.plan == "pro" + assert subscription.status == "active" + assert subscription.start_date == date(2025, 1, 1) + assert subscription.end_date == date(2025, 12, 31) + assert subscription.amount is not None + assert subscription.payment_method == "credit_card" + assert subscription.payment_id == "pay_abc123" + assert subscription.created_at is not None + + @pytest.mark.asyncio + async def test_subscription_default_status(self, async_session, test_user): + subscription = Subscription( + user_id=test_user.id, + plan="free", + start_date=date(2025, 1, 1), + end_date=date(2025, 12, 31), + ) + async_session.add(subscription) + await async_session.commit() + await async_session.refresh(subscription) + + assert subscription.status == "active" + assert subscription.amount is None + assert subscription.payment_method is None + assert subscription.payment_id is None + + @pytest.mark.asyncio + async def test_subscription_query_by_user(self, async_session, test_user): + sub1 = Subscription( + user_id=test_user.id, + plan="free", + start_date=date(2025, 1, 1), + end_date=date(2025, 6, 30), + ) + sub2 = Subscription( + user_id=test_user.id, + plan="pro", + start_date=date(2025, 7, 1), + end_date=date(2025, 12, 31), + ) + async_session.add(sub1) + async_session.add(sub2) + await async_session.commit() + + result = await async_session.execute( + select(Subscription).where(Subscription.user_id == test_user.id) + ) + subscriptions = result.scalars().all() + + assert len(subscriptions) == 2 diff --git a/backend/tests/test_models/test_suggestion.py b/backend/tests/test_models/test_suggestion.py new file mode 100644 index 0000000..281c4bf --- /dev/null +++ b/backend/tests/test_models/test_suggestion.py @@ -0,0 +1,184 @@ +import uuid +from datetime import datetime + +import pytest +from sqlalchemy import select + +from app.models.brand import Brand +from app.models.suggestion import Suggestion +from app.models.user import User + + +class TestSuggestionModel: + + def test_suggestion_table_name(self): + assert Suggestion.__tablename__ == "suggestions" + + def test_suggestion_has_required_fields(self): + fields = Suggestion.__table__.columns.keys() + assert "id" in fields + assert "brand_id" in fields + assert "type" in fields + assert "priority" in fields + assert "title" in fields + assert "description" in fields + assert "action" in fields + assert "expected_impact" in fields + assert "difficulty" in fields + assert "status" in fields + assert "generated_at" in fields + assert "updated_at" in fields + assert "batch_id" in fields + assert "source" in fields + + def test_suggestion_field_types(self): + columns = Suggestion.__table__.columns + id_type = str(columns["id"].type).upper() + assert "UUID" in id_type or "CHAR" in id_type + brand_id_type = str(columns["brand_id"].type).upper() + assert "UUID" in brand_id_type or "CHAR" in brand_id_type + type_type = str(columns["type"].type).upper() + assert "VARCHAR" in type_type or "STRING" in type_type + priority_type = str(columns["priority"].type).upper() + assert "VARCHAR" in priority_type or "STRING" in priority_type + title_type = str(columns["title"].type).upper() + assert "VARCHAR" in title_type or "STRING" in title_type + assert "TEXT" in str(columns["description"].type).upper() + difficulty_type = str(columns["difficulty"].type).upper() + assert "VARCHAR" in difficulty_type or "STRING" in difficulty_type + status_type = str(columns["status"].type).upper() + assert "VARCHAR" in status_type or "STRING" in status_type + + def test_suggestion_relationships_defined(self): + relationships = Suggestion.__mapper__.relationships + rel_keys = relationships.keys() + assert "brand" in rel_keys + + def test_suggestion_priority_default(self): + columns = Suggestion.__table__.columns + priority_col = columns["priority"] + assert priority_col.default is not None + assert priority_col.default.arg == "medium" + + def test_suggestion_status_default(self): + columns = Suggestion.__table__.columns + status_col = columns["status"] + assert status_col.default is not None + assert status_col.default.arg == "pending" + + def test_suggestion_difficulty_default(self): + columns = Suggestion.__table__.columns + difficulty_col = columns["difficulty"] + assert difficulty_col.default is not None + assert difficulty_col.default.arg == "medium" + + def test_suggestion_source_default(self): + columns = Suggestion.__table__.columns + source_col = columns["source"] + assert source_col.default is not None + assert source_col.default.arg == "rule" + + @pytest.mark.asyncio + async def test_suggestion_create(self, async_session, test_user): + brand = Brand( + user_id=test_user.id, + name="Suggestion Test Brand", + platforms=["wenxin"], + ) + async_session.add(brand) + await async_session.commit() + await async_session.refresh(brand) + + suggestion = Suggestion( + id=uuid.uuid4(), + brand_id=brand.id, + type="content_optimization", + priority="high", + title="Optimize content structure", + description="Improve heading hierarchy for better AI citation", + action="Restructure headings using H2/H3 hierarchy", + expected_impact="20% increase in citation rate", + difficulty="easy", + status="pending", + source="rule", + ) + async_session.add(suggestion) + await async_session.commit() + await async_session.refresh(suggestion) + + assert suggestion.id is not None + assert suggestion.brand_id == brand.id + assert suggestion.type == "content_optimization" + assert suggestion.priority == "high" + assert suggestion.title == "Optimize content structure" + assert suggestion.description == "Improve heading hierarchy for better AI citation" + assert suggestion.action == "Restructure headings using H2/H3 hierarchy" + assert suggestion.expected_impact == "20% increase in citation rate" + assert suggestion.difficulty == "easy" + assert suggestion.status == "pending" + assert suggestion.source == "rule" + assert suggestion.generated_at is not None + assert suggestion.updated_at is not None + assert suggestion.batch_id is not None + + @pytest.mark.asyncio + async def test_suggestion_default_values(self, async_session, test_user): + brand = Brand( + user_id=test_user.id, + name="Default Suggestion Brand", + platforms=["kimi"], + ) + async_session.add(brand) + await async_session.commit() + await async_session.refresh(brand) + + suggestion = Suggestion( + brand_id=brand.id, + type="competitor_gap", + title="Fill competitor gap", + description="Address missing topics that competitors cover", + ) + async_session.add(suggestion) + await async_session.commit() + await async_session.refresh(suggestion) + + assert suggestion.priority == "medium" + assert suggestion.difficulty == "medium" + assert suggestion.status == "pending" + assert suggestion.source == "rule" + assert suggestion.action is None + assert suggestion.expected_impact is None + + @pytest.mark.asyncio + async def test_suggestion_query_by_brand(self, async_session, test_user): + brand = Brand( + user_id=test_user.id, + name="Query Suggestion Brand", + platforms=["wenxin"], + ) + async_session.add(brand) + await async_session.commit() + await async_session.refresh(brand) + + s1 = Suggestion( + brand_id=brand.id, + type="content_optimization", + title="Suggestion 1", + description="Desc 1", + ) + s2 = Suggestion( + brand_id=brand.id, + type="platform_targeting", + title="Suggestion 2", + description="Desc 2", + ) + async_session.add(s1) + async_session.add(s2) + await async_session.commit() + + result = await async_session.execute( + select(Suggestion).where(Suggestion.brand_id == brand.id) + ) + suggestions = result.scalars().all() + + assert len(suggestions) == 2