fix: 审计问题修复第二轮
安全: - H2: CORS生产环境localhost警告验证器 - M12: JWT_SECRET已有≥32字符验证(确认) 代码质量: - H4: 11处Any类型替换为具体联合类型 - H5: 4个模型测试文件(47个测试),模型覆盖率32%→64% - M11: Alembic迁移脚本(6个缺失表),修复迁移链分支 测试: 717 passed
This commit is contained in:
parent
aeaa50e89e
commit
0a39ce6ef1
|
|
@ -14,7 +14,7 @@ import enum
|
|||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "f7a8b9c0de56"
|
||||
down_revision: Union[str, None] = "e5f7g9h1cd45"
|
||||
down_revision: Union[str, None] = "810a29804f5a"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,157 @@
|
|||
"""Add missing tables: brands, competitors, api_keys, usage_records, platform_rule_versions, detection_tasks
|
||||
|
||||
Revision ID: g1h2i3j4kl56
|
||||
Revises: f7a8b9c0de56
|
||||
Create Date: 2026-05-26 10:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "g1h2i3j4kl56"
|
||||
down_revision: Union[str, Sequence[str], None] = "f7a8b9c0de56"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"brands",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("name", sa.String(50), nullable=False),
|
||||
sa.Column("aliases", postgresql.JSONB(), server_default="[]", nullable=False),
|
||||
sa.Column("website", sa.String(500), nullable=True),
|
||||
sa.Column("industry", sa.String(50), nullable=True),
|
||||
sa.Column("platforms", postgresql.JSONB(), server_default="[]", nullable=False),
|
||||
sa.Column("frequency", sa.String(20), server_default="weekly", nullable=False),
|
||||
sa.Column("status", sa.String(20), server_default="active", nullable=False),
|
||||
sa.Column("last_queried_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("next_query_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("idx_brands_user_id", "brands", ["user_id"])
|
||||
|
||||
op.create_table(
|
||||
"competitors",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
|
||||
sa.Column("brand_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("name", sa.String(50), nullable=False),
|
||||
sa.Column("aliases", postgresql.JSONB(), server_default="[]", nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["brand_id"], ["brands.id"], ondelete="CASCADE"),
|
||||
)
|
||||
op.create_index("idx_competitors_brand_id", "competitors", ["brand_id"])
|
||||
|
||||
op.create_table(
|
||||
"api_keys",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("engine_type", sa.String(20), nullable=False),
|
||||
sa.Column("encrypted_key", sa.String(500), nullable=False),
|
||||
sa.Column("key_hint", sa.String(50), nullable=False),
|
||||
sa.Column("key_source", sa.String(10), server_default="user", nullable=True),
|
||||
sa.Column("status", sa.String(20), server_default="active", nullable=True),
|
||||
sa.Column("priority", sa.Integer(), server_default="0", nullable=True),
|
||||
sa.Column("last_verified_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("idx_api_keys_user_id", "api_keys", ["user_id"])
|
||||
op.create_index("idx_api_keys_user_engine", "api_keys", ["user_id", "engine_type"])
|
||||
op.create_index("idx_api_keys_engine_status", "api_keys", ["engine_type", "status"])
|
||||
|
||||
op.create_table(
|
||||
"usage_records",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("brand_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("engine_type", sa.String(20), nullable=False),
|
||||
sa.Column("query", sa.String(500), nullable=False),
|
||||
sa.Column("input_tokens", sa.Integer(), server_default="0", nullable=True),
|
||||
sa.Column("output_tokens", sa.Integer(), server_default="0", nullable=True),
|
||||
sa.Column("cost", sa.Float(), server_default="0.0", nullable=True),
|
||||
sa.Column("extra_data", postgresql.JSONB(), server_default="{}", nullable=True),
|
||||
sa.Column("timestamp", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["brand_id"], ["brands.id"], ondelete="SET NULL"),
|
||||
)
|
||||
op.create_index("idx_usage_records_user_id", "usage_records", ["user_id"])
|
||||
op.create_index("idx_usage_records_timestamp", "usage_records", ["timestamp"])
|
||||
op.create_index("idx_usage_records_user_engine", "usage_records", ["user_id", "engine_type"])
|
||||
op.create_index("idx_usage_records_user_timestamp", "usage_records", ["user_id", "timestamp"])
|
||||
op.create_index("idx_usage_records_engine_timestamp", "usage_records", ["engine_type", "timestamp"])
|
||||
|
||||
op.create_table(
|
||||
"platform_rule_versions",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
|
||||
sa.Column("rule_id", sa.String(100), nullable=False),
|
||||
sa.Column("platform", sa.String(50), nullable=False),
|
||||
sa.Column("version", sa.Integer(), nullable=False),
|
||||
sa.Column("rule_data", postgresql.JSONB(), nullable=False),
|
||||
sa.Column("change_summary", sa.String(500), nullable=True),
|
||||
sa.Column("created_by", sa.String(100), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("idx_rule_versions_rule_id", "platform_rule_versions", ["rule_id"])
|
||||
op.create_index("idx_rule_versions_platform", "platform_rule_versions", ["platform"])
|
||||
|
||||
op.create_table(
|
||||
"detection_tasks",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
|
||||
sa.Column("brand_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("name", sa.String(200), nullable=False),
|
||||
sa.Column("frequency", sa.String(20), nullable=False),
|
||||
sa.Column("engines", postgresql.JSONB(), server_default="[]", nullable=False),
|
||||
sa.Column("queries", postgresql.JSONB(), server_default="[]", nullable=False),
|
||||
sa.Column("competitor_names", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), server_default="true", nullable=False),
|
||||
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("next_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["brand_id"], ["brands.id"], ondelete="CASCADE"),
|
||||
)
|
||||
op.create_index("idx_detection_tasks_brand_id", "detection_tasks", ["brand_id"])
|
||||
op.create_index("idx_detection_tasks_user_id", "detection_tasks", ["user_id"])
|
||||
op.create_index("idx_detection_tasks_is_active", "detection_tasks", ["is_active"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("idx_detection_tasks_is_active", table_name="detection_tasks")
|
||||
op.drop_index("idx_detection_tasks_user_id", table_name="detection_tasks")
|
||||
op.drop_index("idx_detection_tasks_brand_id", table_name="detection_tasks")
|
||||
op.drop_table("detection_tasks")
|
||||
|
||||
op.drop_index("idx_rule_versions_platform", table_name="platform_rule_versions")
|
||||
op.drop_index("idx_rule_versions_rule_id", table_name="platform_rule_versions")
|
||||
op.drop_table("platform_rule_versions")
|
||||
|
||||
op.drop_index("idx_usage_records_engine_timestamp", table_name="usage_records")
|
||||
op.drop_index("idx_usage_records_user_timestamp", table_name="usage_records")
|
||||
op.drop_index("idx_usage_records_user_engine", table_name="usage_records")
|
||||
op.drop_index("idx_usage_records_timestamp", table_name="usage_records")
|
||||
op.drop_index("idx_usage_records_user_id", table_name="usage_records")
|
||||
op.drop_table("usage_records")
|
||||
|
||||
op.drop_index("idx_api_keys_engine_status", table_name="api_keys")
|
||||
op.drop_index("idx_api_keys_user_engine", table_name="api_keys")
|
||||
op.drop_index("idx_api_keys_user_id", table_name="api_keys")
|
||||
op.drop_table("api_keys")
|
||||
|
||||
op.drop_index("idx_competitors_brand_id", table_name="competitors")
|
||||
op.drop_table("competitors")
|
||||
|
||||
op.drop_index("idx_brands_user_id", table_name="brands")
|
||||
op.drop_table("brands")
|
||||
|
|
@ -3,7 +3,6 @@
|
|||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
|
@ -44,7 +43,7 @@ class AgentConfigManager:
|
|||
self,
|
||||
agent_name: str,
|
||||
key: str,
|
||||
value: Any,
|
||||
value: dict | list | str | int | float | bool | None,
|
||||
updated_by: str | None = None,
|
||||
):
|
||||
"""设置单个配置项"""
|
||||
|
|
@ -183,7 +182,7 @@ class AgentConfigManager:
|
|||
result = await db.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
def _wrap_value(self, value: Any) -> dict:
|
||||
def _wrap_value(self, value: dict | list | str | int | float | bool | None) -> dict:
|
||||
"""将任意值包装为 JSONB 兼容的 dict"""
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
|
|
@ -196,7 +195,7 @@ class PipelineLoader:
|
|||
return visited_count == len(stage_names)
|
||||
|
||||
@staticmethod
|
||||
def resolve_variables(template: Any, context: dict) -> Any:
|
||||
def resolve_variables(template: str | dict | list, context: dict) -> str | dict | list | int | float | bool | None:
|
||||
"""
|
||||
解析${var.path}格式的变量引用。
|
||||
|
||||
|
|
@ -229,7 +228,7 @@ class PipelineLoader:
|
|||
return template
|
||||
|
||||
@staticmethod
|
||||
def _resolve_string(template: str, context: dict) -> Any:
|
||||
def _resolve_string(template: str, context: dict) -> str | dict | list | int | float | bool | None:
|
||||
"""
|
||||
解析字符串中的变量引用。
|
||||
|
||||
|
|
@ -257,7 +256,7 @@ class PipelineLoader:
|
|||
return VARIABLE_PATTERN.sub(replacer, template)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_path(path: str, context: dict) -> Any:
|
||||
def _resolve_path(path: str, context: dict) -> str | dict | list | int | float | bool | None:
|
||||
"""
|
||||
解析点分路径,从context中获取值。
|
||||
|
||||
|
|
@ -265,7 +264,7 @@ class PipelineLoader:
|
|||
→ context["stages"]["step1"]["outputs"]["result"]
|
||||
"""
|
||||
parts = path.split(".")
|
||||
current: Any = context
|
||||
current: str | dict | list | int | float | bool | None = context
|
||||
|
||||
for part in parts:
|
||||
if isinstance(current, dict):
|
||||
|
|
|
|||
|
|
@ -28,6 +28,21 @@ class Settings(BaseSettings):
|
|||
TONGYI_API_KEY: str = ""
|
||||
CORS_ORIGINS: str = "http://localhost:3000,http://localhost:3001"
|
||||
|
||||
@field_validator("CORS_ORIGINS")
|
||||
@classmethod
|
||||
def validate_cors_origins(cls, v: str) -> str:
|
||||
import os
|
||||
if os.getenv("ENVIRONMENT", "development") == "production":
|
||||
origins = [o.strip() for o in v.split(",") if o.strip()]
|
||||
localhost_origins = [o for o in origins if "localhost" in o or "127.0.0.1" in o]
|
||||
if localhost_origins:
|
||||
print(
|
||||
f"[WARNING] CORS_ORIGINS contains localhost in production: {localhost_origins}. "
|
||||
"This is a security risk. Please configure proper production origins.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return v
|
||||
|
||||
# ---- LLM Provider 配置 ----
|
||||
DEFAULT_LLM_PROVIDER: str = "openai"
|
||||
DEFAULT_LLM_MODEL: str = "qwen3-coder-plus"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""平台规则管理 Schema - 定义规则管理的请求响应结构"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
|
|
@ -275,8 +275,8 @@ class ContentValidateRequest(BaseModel):
|
|||
class RuleDiff(BaseModel):
|
||||
"""规则差异"""
|
||||
field: str
|
||||
old_value: Optional[Any] = None
|
||||
new_value: Optional[Any] = None
|
||||
old_value: dict | list | str | int | float | bool | None = None
|
||||
new_value: dict | list | str | int | float | bool | None = None
|
||||
|
||||
|
||||
class RuleDiffResponse(BaseModel):
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@
|
|||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
|
|
@ -45,7 +44,7 @@ class CacheService:
|
|||
logger.warning("Cache GET failed for key=%s: %s", key, exc)
|
||||
return None
|
||||
|
||||
async def get_json(self, key: str) -> Any | None:
|
||||
async def get_json(self, key: str) -> dict | list | str | int | float | bool | None:
|
||||
"""从缓存读取并反序列化 JSON 值。"""
|
||||
raw = await self.get(key)
|
||||
if raw is None:
|
||||
|
|
@ -62,7 +61,7 @@ class CacheService:
|
|||
except Exception as exc:
|
||||
logger.warning("Cache SET failed for key=%s: %s", key, exc)
|
||||
|
||||
async def set_json(self, key: str, value: Any, expire: int = 300) -> None:
|
||||
async def set_json(self, key: str, value: dict | list | str | int | float | bool, expire: int = 300) -> None:
|
||||
"""序列化为 JSON 后写入缓存。"""
|
||||
try:
|
||||
await self.set(key, json.dumps(value, default=str), expire=expire)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Any, List
|
||||
from typing import Optional, List
|
||||
|
||||
from app.services.content.rule_validator import RuleValidator
|
||||
from app.services.content.sensitive_filter import SensitiveFilter
|
||||
from app.services.content.seo_optimizer import SEOOptimizer
|
||||
from app.services.content.rule_validator import RuleValidator, ValidationResult
|
||||
from app.services.content.sensitive_filter import SensitiveFilter, FilterResult
|
||||
from app.services.content.seo_optimizer import SEOOptimizer, OptimizationResult
|
||||
from app.services.content.html_generator import HTMLGenerator
|
||||
|
||||
|
||||
|
|
@ -12,7 +12,7 @@ from app.services.content.html_generator import HTMLGenerator
|
|||
class PipelineStage:
|
||||
name: str
|
||||
passed: bool
|
||||
result: Any = None
|
||||
result: ValidationResult | FilterResult | OptimizationResult | PipelineOutput | None = None
|
||||
duration: float = 0.0
|
||||
error: Optional[str] = None
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,217 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.lifecycle import LifecycleProject, ProjectStage
|
||||
from app.models.organization import Organization
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class TestLifecycleProjectModel:
|
||||
|
||||
def test_lifecycle_project_table_name(self):
|
||||
assert LifecycleProject.__tablename__ == "lifecycle_projects"
|
||||
|
||||
def test_project_stage_table_name(self):
|
||||
assert ProjectStage.__tablename__ == "project_stages"
|
||||
|
||||
def test_lifecycle_project_has_required_fields(self):
|
||||
fields = LifecycleProject.__table__.columns.keys()
|
||||
assert "id" in fields
|
||||
assert "organization_id" in fields
|
||||
assert "brand_name" in fields
|
||||
assert "brand_aliases" in fields
|
||||
assert "current_stage" in fields
|
||||
assert "status" in fields
|
||||
assert "created_by" in fields
|
||||
assert "created_at" in fields
|
||||
assert "updated_at" in fields
|
||||
|
||||
def test_project_stage_has_required_fields(self):
|
||||
fields = ProjectStage.__table__.columns.keys()
|
||||
assert "id" in fields
|
||||
assert "project_id" in fields
|
||||
assert "stage_number" in fields
|
||||
assert "status" in fields
|
||||
assert "started_at" in fields
|
||||
assert "completed_at" in fields
|
||||
assert "notes" in fields
|
||||
assert "metrics" in fields
|
||||
|
||||
def test_lifecycle_project_field_types(self):
|
||||
columns = LifecycleProject.__table__.columns
|
||||
id_type = str(columns["id"].type).upper()
|
||||
assert "UUID" in id_type or "CHAR" in id_type
|
||||
org_id_type = str(columns["organization_id"].type).upper()
|
||||
assert "UUID" in org_id_type or "CHAR" in org_id_type
|
||||
brand_name_type = str(columns["brand_name"].type).upper()
|
||||
assert "VARCHAR" in brand_name_type or "STRING" in brand_name_type
|
||||
assert "INTEGER" in str(columns["current_stage"].type).upper()
|
||||
status_type = str(columns["status"].type).upper()
|
||||
assert "VARCHAR" in status_type or "STRING" in status_type
|
||||
|
||||
def test_project_stage_field_types(self):
|
||||
columns = ProjectStage.__table__.columns
|
||||
id_type = str(columns["id"].type).upper()
|
||||
assert "UUID" in id_type or "CHAR" in id_type
|
||||
project_id_type = str(columns["project_id"].type).upper()
|
||||
assert "UUID" in project_id_type or "CHAR" in project_id_type
|
||||
assert "INTEGER" in str(columns["stage_number"].type).upper()
|
||||
status_type = str(columns["status"].type).upper()
|
||||
assert "VARCHAR" in status_type or "STRING" in status_type
|
||||
|
||||
def test_lifecycle_project_relationships_defined(self):
|
||||
relationships = LifecycleProject.__mapper__.relationships
|
||||
rel_keys = relationships.keys()
|
||||
assert "stages" in rel_keys
|
||||
assert "organization" in rel_keys
|
||||
assert "creator" in rel_keys
|
||||
|
||||
def test_project_stage_relationships_defined(self):
|
||||
relationships = ProjectStage.__mapper__.relationships
|
||||
rel_keys = relationships.keys()
|
||||
assert "project" in rel_keys
|
||||
|
||||
def test_lifecycle_project_default_stage(self):
|
||||
columns = LifecycleProject.__table__.columns
|
||||
stage_col = columns["current_stage"]
|
||||
assert stage_col.server_default is not None
|
||||
|
||||
def test_lifecycle_project_default_status(self):
|
||||
columns = LifecycleProject.__table__.columns
|
||||
status_col = columns["status"]
|
||||
assert status_col.server_default is not None
|
||||
|
||||
def test_project_stage_default_status(self):
|
||||
columns = ProjectStage.__table__.columns
|
||||
status_col = columns["status"]
|
||||
assert status_col.server_default is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifecycle_project_create(self, async_session, test_user):
|
||||
org = Organization(
|
||||
name="Lifecycle Test Org",
|
||||
slug="lifecycle-test-org",
|
||||
)
|
||||
async_session.add(org)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(org)
|
||||
|
||||
project = LifecycleProject(
|
||||
id=uuid.uuid4(),
|
||||
organization_id=org.id,
|
||||
brand_name="Test Brand",
|
||||
brand_aliases=["TB", "TestBrand"],
|
||||
current_stage=1,
|
||||
status="active",
|
||||
created_by=test_user.id,
|
||||
)
|
||||
async_session.add(project)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(project)
|
||||
|
||||
assert project.id is not None
|
||||
assert project.organization_id == org.id
|
||||
assert project.brand_name == "Test Brand"
|
||||
assert project.brand_aliases == ["TB", "TestBrand"]
|
||||
assert project.current_stage == 1
|
||||
assert project.status == "active"
|
||||
assert project.created_by == test_user.id
|
||||
assert project.created_at is not None
|
||||
assert project.updated_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifecycle_project_default_values(self, async_session, test_user):
|
||||
org = Organization(
|
||||
name="Default Lifecycle Org",
|
||||
slug="default-lifecycle-org",
|
||||
)
|
||||
async_session.add(org)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(org)
|
||||
|
||||
project = LifecycleProject(
|
||||
organization_id=org.id,
|
||||
brand_name="Default Brand",
|
||||
created_by=test_user.id,
|
||||
)
|
||||
async_session.add(project)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(project)
|
||||
|
||||
assert project.brand_aliases == []
|
||||
assert project.current_stage == 1
|
||||
assert project.status == "active"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_stage_create(self, async_session, test_user):
|
||||
org = Organization(
|
||||
name="Stage Test Org",
|
||||
slug="stage-test-org",
|
||||
)
|
||||
async_session.add(org)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(org)
|
||||
|
||||
project = LifecycleProject(
|
||||
organization_id=org.id,
|
||||
brand_name="Stage Test Brand",
|
||||
created_by=test_user.id,
|
||||
)
|
||||
async_session.add(project)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(project)
|
||||
|
||||
stage = ProjectStage(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
stage_number=1,
|
||||
status="in_progress",
|
||||
notes="Starting brand awareness phase",
|
||||
metrics={"awareness_score": 45},
|
||||
)
|
||||
async_session.add(stage)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(stage)
|
||||
|
||||
assert stage.id is not None
|
||||
assert stage.project_id == project.id
|
||||
assert stage.stage_number == 1
|
||||
assert stage.status == "in_progress"
|
||||
assert stage.notes == "Starting brand awareness phase"
|
||||
assert stage.metrics == {"awareness_score": 45}
|
||||
assert stage.started_at is None
|
||||
assert stage.completed_at is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_stage_default_values(self, async_session, test_user):
|
||||
org = Organization(
|
||||
name="Default Stage Org",
|
||||
slug="default-stage-org",
|
||||
)
|
||||
async_session.add(org)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(org)
|
||||
|
||||
project = LifecycleProject(
|
||||
organization_id=org.id,
|
||||
brand_name="Default Stage Brand",
|
||||
created_by=test_user.id,
|
||||
)
|
||||
async_session.add(project)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(project)
|
||||
|
||||
stage = ProjectStage(
|
||||
project_id=project.id,
|
||||
stage_number=2,
|
||||
)
|
||||
async_session.add(stage)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(stage)
|
||||
|
||||
assert stage.status == "pending"
|
||||
assert stage.notes is None
|
||||
assert stage.metrics is None
|
||||
|
|
@ -0,0 +1,158 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.organization import Organization, OrgMember
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class TestOrganizationModel:
|
||||
|
||||
def test_organization_table_name(self):
|
||||
assert Organization.__tablename__ == "organizations"
|
||||
|
||||
def test_org_member_table_name(self):
|
||||
assert OrgMember.__tablename__ == "org_members"
|
||||
|
||||
def test_organization_has_required_fields(self):
|
||||
fields = Organization.__table__.columns.keys()
|
||||
assert "id" in fields
|
||||
assert "name" in fields
|
||||
assert "slug" in fields
|
||||
assert "description" in fields
|
||||
assert "logo_url" in fields
|
||||
assert "plan" in fields
|
||||
assert "max_members" in fields
|
||||
assert "created_at" in fields
|
||||
assert "updated_at" in fields
|
||||
|
||||
def test_org_member_has_required_fields(self):
|
||||
fields = OrgMember.__table__.columns.keys()
|
||||
assert "id" in fields
|
||||
assert "organization_id" in fields
|
||||
assert "user_id" in fields
|
||||
assert "role" in fields
|
||||
assert "joined_at" in fields
|
||||
assert "invited_by" in fields
|
||||
|
||||
def test_organization_field_types(self):
|
||||
columns = Organization.__table__.columns
|
||||
id_type = str(columns["id"].type).upper()
|
||||
assert "UUID" in id_type or "CHAR" in id_type
|
||||
name_type = str(columns["name"].type).upper()
|
||||
assert "VARCHAR" in name_type or "STRING" in name_type
|
||||
slug_type = str(columns["slug"].type).upper()
|
||||
assert "VARCHAR" in slug_type or "STRING" in slug_type
|
||||
assert "INTEGER" in str(columns["max_members"].type).upper()
|
||||
|
||||
def test_org_member_field_types(self):
|
||||
columns = OrgMember.__table__.columns
|
||||
id_type = str(columns["id"].type).upper()
|
||||
assert "UUID" in id_type or "CHAR" in id_type
|
||||
org_id_type = str(columns["organization_id"].type).upper()
|
||||
assert "UUID" in org_id_type or "CHAR" in org_id_type
|
||||
user_id_type = str(columns["user_id"].type).upper()
|
||||
assert "UUID" in user_id_type or "CHAR" in user_id_type
|
||||
role_type = str(columns["role"].type).upper()
|
||||
assert "VARCHAR" in role_type or "STRING" in role_type
|
||||
|
||||
def test_organization_relationships_defined(self):
|
||||
relationships = Organization.__mapper__.relationships
|
||||
rel_keys = relationships.keys()
|
||||
assert "members" in rel_keys
|
||||
assert "users" in rel_keys
|
||||
|
||||
def test_org_member_relationships_defined(self):
|
||||
relationships = OrgMember.__mapper__.relationships
|
||||
rel_keys = relationships.keys()
|
||||
assert "organization" in rel_keys
|
||||
assert "user" in rel_keys
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_organization_create(self, async_session, test_user):
|
||||
org = Organization(
|
||||
id=uuid.uuid4(),
|
||||
name="Test Org",
|
||||
slug="test-org",
|
||||
description="A test organization",
|
||||
logo_url="https://example.com/logo.png",
|
||||
plan="free",
|
||||
max_members=5,
|
||||
)
|
||||
async_session.add(org)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(org)
|
||||
|
||||
assert org.id is not None
|
||||
assert org.name == "Test Org"
|
||||
assert org.slug == "test-org"
|
||||
assert org.description == "A test organization"
|
||||
assert org.logo_url == "https://example.com/logo.png"
|
||||
assert org.plan == "free"
|
||||
assert org.max_members == 5
|
||||
assert org.created_at is not None
|
||||
assert org.updated_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_organization_default_values(self, async_session):
|
||||
org = Organization(
|
||||
name="Default Org",
|
||||
slug="default-org",
|
||||
)
|
||||
async_session.add(org)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(org)
|
||||
|
||||
assert org.plan == "free"
|
||||
assert org.max_members == 5
|
||||
assert org.description is None
|
||||
assert org.logo_url is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_org_member_create(self, async_session, test_user):
|
||||
org = Organization(
|
||||
name="Member Test Org",
|
||||
slug="member-test-org",
|
||||
)
|
||||
async_session.add(org)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(org)
|
||||
|
||||
member = OrgMember(
|
||||
id=uuid.uuid4(),
|
||||
organization_id=org.id,
|
||||
user_id=test_user.id,
|
||||
role="admin",
|
||||
)
|
||||
async_session.add(member)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(member)
|
||||
|
||||
assert member.id is not None
|
||||
assert member.organization_id == org.id
|
||||
assert member.user_id == test_user.id
|
||||
assert member.role == "admin"
|
||||
assert member.joined_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_org_member_default_role(self, async_session, test_user):
|
||||
org = Organization(
|
||||
name="Role Test Org",
|
||||
slug="role-test-org",
|
||||
)
|
||||
async_session.add(org)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(org)
|
||||
|
||||
member = OrgMember(
|
||||
organization_id=org.id,
|
||||
user_id=test_user.id,
|
||||
)
|
||||
async_session.add(member)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(member)
|
||||
|
||||
assert member.role == "viewer"
|
||||
assert member.invited_by is None
|
||||
|
|
@ -0,0 +1,123 @@
|
|||
import uuid
|
||||
from datetime import date, datetime
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class TestSubscriptionModel:
|
||||
|
||||
def test_subscription_table_name(self):
|
||||
assert Subscription.__tablename__ == "subscriptions"
|
||||
|
||||
def test_subscription_has_required_fields(self):
|
||||
fields = Subscription.__table__.columns.keys()
|
||||
assert "id" in fields
|
||||
assert "user_id" in fields
|
||||
assert "plan" in fields
|
||||
assert "status" in fields
|
||||
assert "start_date" in fields
|
||||
assert "end_date" in fields
|
||||
assert "amount" in fields
|
||||
assert "payment_method" in fields
|
||||
assert "payment_id" in fields
|
||||
assert "created_at" in fields
|
||||
|
||||
def test_subscription_field_types(self):
|
||||
columns = Subscription.__table__.columns
|
||||
assert "UUID" in str(columns["id"].type).upper()
|
||||
assert "UUID" in str(columns["user_id"].type).upper()
|
||||
assert "VARCHAR" in str(columns["plan"].type).upper() or "STRING" in str(columns["plan"].type).upper()
|
||||
assert "VARCHAR" in str(columns["status"].type).upper() or "STRING" in str(columns["status"].type).upper()
|
||||
assert "DATE" in str(columns["start_date"].type).upper()
|
||||
assert "DATE" in str(columns["end_date"].type).upper()
|
||||
assert "NUMERIC" in str(columns["amount"].type).upper()
|
||||
|
||||
def test_subscription_relationships_defined(self):
|
||||
relationships = Subscription.__mapper__.relationships
|
||||
rel_keys = relationships.keys()
|
||||
assert "user" in rel_keys
|
||||
|
||||
def test_subscription_plan_field_allows_values(self):
|
||||
columns = Subscription.__table__.columns
|
||||
plan_col = columns["plan"]
|
||||
assert plan_col.nullable is False
|
||||
|
||||
def test_subscription_status_default(self):
|
||||
columns = Subscription.__table__.columns
|
||||
status_col = columns["status"]
|
||||
assert status_col.default is not None
|
||||
assert status_col.default.arg == "active"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscription_create(self, async_session, test_user):
|
||||
subscription = Subscription(
|
||||
id=uuid.uuid4(),
|
||||
user_id=test_user.id,
|
||||
plan="pro",
|
||||
status="active",
|
||||
start_date=date(2025, 1, 1),
|
||||
end_date=date(2025, 12, 31),
|
||||
amount=99.99,
|
||||
payment_method="credit_card",
|
||||
payment_id="pay_abc123",
|
||||
)
|
||||
async_session.add(subscription)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(subscription)
|
||||
|
||||
assert subscription.id is not None
|
||||
assert subscription.user_id == test_user.id
|
||||
assert subscription.plan == "pro"
|
||||
assert subscription.status == "active"
|
||||
assert subscription.start_date == date(2025, 1, 1)
|
||||
assert subscription.end_date == date(2025, 12, 31)
|
||||
assert subscription.amount is not None
|
||||
assert subscription.payment_method == "credit_card"
|
||||
assert subscription.payment_id == "pay_abc123"
|
||||
assert subscription.created_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscription_default_status(self, async_session, test_user):
|
||||
subscription = Subscription(
|
||||
user_id=test_user.id,
|
||||
plan="free",
|
||||
start_date=date(2025, 1, 1),
|
||||
end_date=date(2025, 12, 31),
|
||||
)
|
||||
async_session.add(subscription)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(subscription)
|
||||
|
||||
assert subscription.status == "active"
|
||||
assert subscription.amount is None
|
||||
assert subscription.payment_method is None
|
||||
assert subscription.payment_id is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscription_query_by_user(self, async_session, test_user):
|
||||
sub1 = Subscription(
|
||||
user_id=test_user.id,
|
||||
plan="free",
|
||||
start_date=date(2025, 1, 1),
|
||||
end_date=date(2025, 6, 30),
|
||||
)
|
||||
sub2 = Subscription(
|
||||
user_id=test_user.id,
|
||||
plan="pro",
|
||||
start_date=date(2025, 7, 1),
|
||||
end_date=date(2025, 12, 31),
|
||||
)
|
||||
async_session.add(sub1)
|
||||
async_session.add(sub2)
|
||||
await async_session.commit()
|
||||
|
||||
result = await async_session.execute(
|
||||
select(Subscription).where(Subscription.user_id == test_user.id)
|
||||
)
|
||||
subscriptions = result.scalars().all()
|
||||
|
||||
assert len(subscriptions) == 2
|
||||
|
|
@ -0,0 +1,184 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.brand import Brand
|
||||
from app.models.suggestion import Suggestion
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class TestSuggestionModel:
|
||||
|
||||
def test_suggestion_table_name(self):
|
||||
assert Suggestion.__tablename__ == "suggestions"
|
||||
|
||||
def test_suggestion_has_required_fields(self):
|
||||
fields = Suggestion.__table__.columns.keys()
|
||||
assert "id" in fields
|
||||
assert "brand_id" in fields
|
||||
assert "type" in fields
|
||||
assert "priority" in fields
|
||||
assert "title" in fields
|
||||
assert "description" in fields
|
||||
assert "action" in fields
|
||||
assert "expected_impact" in fields
|
||||
assert "difficulty" in fields
|
||||
assert "status" in fields
|
||||
assert "generated_at" in fields
|
||||
assert "updated_at" in fields
|
||||
assert "batch_id" in fields
|
||||
assert "source" in fields
|
||||
|
||||
def test_suggestion_field_types(self):
|
||||
columns = Suggestion.__table__.columns
|
||||
id_type = str(columns["id"].type).upper()
|
||||
assert "UUID" in id_type or "CHAR" in id_type
|
||||
brand_id_type = str(columns["brand_id"].type).upper()
|
||||
assert "UUID" in brand_id_type or "CHAR" in brand_id_type
|
||||
type_type = str(columns["type"].type).upper()
|
||||
assert "VARCHAR" in type_type or "STRING" in type_type
|
||||
priority_type = str(columns["priority"].type).upper()
|
||||
assert "VARCHAR" in priority_type or "STRING" in priority_type
|
||||
title_type = str(columns["title"].type).upper()
|
||||
assert "VARCHAR" in title_type or "STRING" in title_type
|
||||
assert "TEXT" in str(columns["description"].type).upper()
|
||||
difficulty_type = str(columns["difficulty"].type).upper()
|
||||
assert "VARCHAR" in difficulty_type or "STRING" in difficulty_type
|
||||
status_type = str(columns["status"].type).upper()
|
||||
assert "VARCHAR" in status_type or "STRING" in status_type
|
||||
|
||||
def test_suggestion_relationships_defined(self):
|
||||
relationships = Suggestion.__mapper__.relationships
|
||||
rel_keys = relationships.keys()
|
||||
assert "brand" in rel_keys
|
||||
|
||||
def test_suggestion_priority_default(self):
|
||||
columns = Suggestion.__table__.columns
|
||||
priority_col = columns["priority"]
|
||||
assert priority_col.default is not None
|
||||
assert priority_col.default.arg == "medium"
|
||||
|
||||
def test_suggestion_status_default(self):
|
||||
columns = Suggestion.__table__.columns
|
||||
status_col = columns["status"]
|
||||
assert status_col.default is not None
|
||||
assert status_col.default.arg == "pending"
|
||||
|
||||
def test_suggestion_difficulty_default(self):
|
||||
columns = Suggestion.__table__.columns
|
||||
difficulty_col = columns["difficulty"]
|
||||
assert difficulty_col.default is not None
|
||||
assert difficulty_col.default.arg == "medium"
|
||||
|
||||
def test_suggestion_source_default(self):
|
||||
columns = Suggestion.__table__.columns
|
||||
source_col = columns["source"]
|
||||
assert source_col.default is not None
|
||||
assert source_col.default.arg == "rule"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_suggestion_create(self, async_session, test_user):
|
||||
brand = Brand(
|
||||
user_id=test_user.id,
|
||||
name="Suggestion Test Brand",
|
||||
platforms=["wenxin"],
|
||||
)
|
||||
async_session.add(brand)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(brand)
|
||||
|
||||
suggestion = Suggestion(
|
||||
id=uuid.uuid4(),
|
||||
brand_id=brand.id,
|
||||
type="content_optimization",
|
||||
priority="high",
|
||||
title="Optimize content structure",
|
||||
description="Improve heading hierarchy for better AI citation",
|
||||
action="Restructure headings using H2/H3 hierarchy",
|
||||
expected_impact="20% increase in citation rate",
|
||||
difficulty="easy",
|
||||
status="pending",
|
||||
source="rule",
|
||||
)
|
||||
async_session.add(suggestion)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(suggestion)
|
||||
|
||||
assert suggestion.id is not None
|
||||
assert suggestion.brand_id == brand.id
|
||||
assert suggestion.type == "content_optimization"
|
||||
assert suggestion.priority == "high"
|
||||
assert suggestion.title == "Optimize content structure"
|
||||
assert suggestion.description == "Improve heading hierarchy for better AI citation"
|
||||
assert suggestion.action == "Restructure headings using H2/H3 hierarchy"
|
||||
assert suggestion.expected_impact == "20% increase in citation rate"
|
||||
assert suggestion.difficulty == "easy"
|
||||
assert suggestion.status == "pending"
|
||||
assert suggestion.source == "rule"
|
||||
assert suggestion.generated_at is not None
|
||||
assert suggestion.updated_at is not None
|
||||
assert suggestion.batch_id is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_suggestion_default_values(self, async_session, test_user):
|
||||
brand = Brand(
|
||||
user_id=test_user.id,
|
||||
name="Default Suggestion Brand",
|
||||
platforms=["kimi"],
|
||||
)
|
||||
async_session.add(brand)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(brand)
|
||||
|
||||
suggestion = Suggestion(
|
||||
brand_id=brand.id,
|
||||
type="competitor_gap",
|
||||
title="Fill competitor gap",
|
||||
description="Address missing topics that competitors cover",
|
||||
)
|
||||
async_session.add(suggestion)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(suggestion)
|
||||
|
||||
assert suggestion.priority == "medium"
|
||||
assert suggestion.difficulty == "medium"
|
||||
assert suggestion.status == "pending"
|
||||
assert suggestion.source == "rule"
|
||||
assert suggestion.action is None
|
||||
assert suggestion.expected_impact is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_suggestion_query_by_brand(self, async_session, test_user):
|
||||
brand = Brand(
|
||||
user_id=test_user.id,
|
||||
name="Query Suggestion Brand",
|
||||
platforms=["wenxin"],
|
||||
)
|
||||
async_session.add(brand)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(brand)
|
||||
|
||||
s1 = Suggestion(
|
||||
brand_id=brand.id,
|
||||
type="content_optimization",
|
||||
title="Suggestion 1",
|
||||
description="Desc 1",
|
||||
)
|
||||
s2 = Suggestion(
|
||||
brand_id=brand.id,
|
||||
type="platform_targeting",
|
||||
title="Suggestion 2",
|
||||
description="Desc 2",
|
||||
)
|
||||
async_session.add(s1)
|
||||
async_session.add(s2)
|
||||
await async_session.commit()
|
||||
|
||||
result = await async_session.execute(
|
||||
select(Suggestion).where(Suggestion.brand_id == brand.id)
|
||||
)
|
||||
suggestions = result.scalars().all()
|
||||
|
||||
assert len(suggestions) == 2
|
||||
Loading…
Reference in New Issue