import logging
import uuid
from datetime import datetime, timezone

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import case, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.api.deps import get_current_user
from app.database import get_db
from app.models.lifecycle import LifecycleProject, ProjectStage
from app.models.organization import Organization, OrgMember
from app.models.user import User
from app.schemas.lifecycle import (
    ProjectCreateRequest,
    ProjectResponse,
    ProjectStatsResponse,
    QuickStartResponse,
    StageDetailResponse,
    StageUpdateRequest,
    TimelineEvent,
)

router = APIRouter()
logger = logging.getLogger(__name__)

# Display names per stage number (1-based): brand infrastructure, content
# production, AI adaptation/optimization, authority-signal building, ongoing
# operations. Used for timeline descriptions.
STAGE_NAMES = {
    1: "品牌基建",
    2: "内容生产",
    3: "AI适配优化",
    4: "权威信号构建",
    5: "持续运维",
}

# String keys used when reporting the current-stage distribution in stats.
# NOTE(review): these keys do not obviously line up with STAGE_NAMES (e.g.
# stage 1 is "brand infrastructure", not "diagnosis") — confirm with the client.
STAGE_INT_TO_STR = {
    1: "diagnosis",
    2: "strategy",
    3: "content",
    4: "publishing",
    5: "monitoring",
}


# ---------- helpers ----------


def _utcnow() -> datetime:
    """Return the current UTC time as a naive datetime.

    Drop-in replacement for the deprecated ``datetime.utcnow()``: the value is
    kept naive so stage timestamps are written exactly as before.
    NOTE(review): if the timestamp columns are timezone-aware, drop the
    ``.replace(tzinfo=None)``.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


async def _get_or_create_org(db: AsyncSession, user: User) -> uuid.UUID:
    """Return the user's organization_id; create a default org if missing.

    When the user has no organization, a "free"-plan org is created, the user
    is linked to it, and an "owner" membership is added. Changes are flushed
    but not committed; the caller owns the transaction.
    """
    if user.organization_id:
        return user.organization_id

    slug = f"default-{user.id.hex[:8]}"
    org = Organization(
        name=f"{user.name or user.email.split('@')[0]}'s Organization",
        slug=slug,
        plan="free",
    )
    db.add(org)
    await db.flush()  # assigns org.id

    # link user
    user.organization_id = org.id
    db.add(user)

    # add owner membership
    membership = OrgMember(
        organization_id=org.id,
        user_id=user.id,
        role="owner",
    )
    db.add(membership)
    await db.flush()

    return org.id


async def _load_project_with_stages(
    db: AsyncSession, project_id: uuid.UUID, org_id: uuid.UUID
) -> LifecycleProject | None:
    """Load a project scoped to ``org_id`` with its stages eager-loaded.

    Returns None when no project with that id belongs to the organization.
    """
    stmt = (
        select(LifecycleProject)
        .where(
            LifecycleProject.id == project_id,
            LifecycleProject.organization_id == org_id,
        )
        .options(selectinload(LifecycleProject.stages))
    )
    result = await db.execute(stmt)
    return result.scalar_one_or_none()


async def _count_org_contents(db: AsyncSession, org_id: uuid.UUID) -> int:
    """Best-effort count of Content rows owned by the organization.

    Returns 0 when the content model is not installed or the query fails.
    The query runs inside a SAVEPOINT so a failure cannot abort the enclosing
    transaction (on PostgreSQL a failed statement would otherwise make every
    later query in the request fail).
    """
    try:
        from app.models.content import Content
    except ImportError:
        return 0

    try:
        async with db.begin_nested():
            result = await db.execute(
                select(func.count())
                .select_from(Content)
                .where(Content.organization_id == org_id)
            )
            return result.scalar() or 0
    except Exception:
        logger.warning("contents_produced unavailable for org %s", org_id, exc_info=True)
        return 0


async def _avg_ai_citation_rate(db: AsyncSession, org_id: uuid.UUID) -> float | None:
    """Best-effort share of the org's citation records marked ``cited``.

    Queries are owned by users rather than organizations, so records are
    scoped through the organization's members. Returns the ratio rounded to
    4 places, or None when there are no records, the models are not
    installed, or the query fails (the query runs in a SAVEPOINT, as in
    ``_count_org_contents``).
    """
    try:
        from app.models.citation_record import CitationRecord
        from app.models.query import Query as QueryModel
    except ImportError:
        return None

    member_ids = select(OrgMember.user_id).where(OrgMember.organization_id == org_id)
    stmt = (
        select(
            func.count().label("total_citations"),
            func.count().filter(CitationRecord.cited.is_(True)).label("cited_count"),
        )
        .select_from(CitationRecord)
        .join(QueryModel, CitationRecord.query_id == QueryModel.id)
        .where(QueryModel.user_id.in_(member_ids))
    )
    try:
        async with db.begin_nested():
            row = (await db.execute(stmt)).one()
    except Exception:
        logger.warning("avg_ai_citation_rate unavailable for org %s", org_id, exc_info=True)
        return None

    total_citations = row.total_citations or 0
    cited_count = row.cited_count or 0
    if total_citations <= 0:
        return None
    return round(cited_count / total_citations, 4)


# ---------- endpoints ----------


@router.get("/projects/", response_model=list[ProjectResponse])
async def list_projects(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List the organization's projects, newest first, with stages loaded.

    Returns ``[]`` when the user has no organization.
    """
    org_id = current_user.organization_id
    if not org_id:
        return []

    stmt = (
        select(LifecycleProject)
        .where(LifecycleProject.organization_id == org_id)
        .options(selectinload(LifecycleProject.stages))
        .order_by(LifecycleProject.created_at.desc())
    )
    result = await db.execute(stmt)
    return result.scalars().all()


@router.get("/projects/stats", response_model=ProjectStatsResponse)
async def project_stats(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Aggregate project statistics for the user's organization.

    Returns all-zero stats when the user has no organization. The content and
    citation metrics are best-effort: 0 / None when their data is unavailable.
    """
    org_id = current_user.organization_id
    if not org_id:
        return ProjectStatsResponse(
            total_projects=0,
            active_projects=0,
            completed_projects=0,
            contents_produced=0,
            avg_ai_citation_rate=None,
            current_stage_distribution={},
            stage_distribution={},
            completion_rate=0.0,
        )

    org_filter = LifecycleProject.organization_id == org_id

    # Headline counters in a single pass over the org's projects.
    counts = (
        await db.execute(
            select(
                func.count().label("total"),
                func.count().filter(LifecycleProject.status == "active").label("active"),
                func.count().filter(LifecycleProject.status == "completed").label("completed"),
                # completion rate numerator: projects where current_stage >= 5
                func.count().filter(LifecycleProject.current_stage >= 5).label("done"),
            ).where(org_filter)
        )
    ).one()
    total = counts.total
    completion_rate = round((counts.done or 0) / total, 4) if total > 0 else 0.0

    # Stage-status distribution across all stages of the org's projects.
    dist_stmt = (
        select(ProjectStage.status, func.count().label("cnt"))
        .join(LifecycleProject, ProjectStage.project_id == LifecycleProject.id)
        .where(org_filter)
        .group_by(ProjectStage.status)
    )
    stage_distribution = {r.status: r.cnt for r in (await db.execute(dist_stmt)).all()}

    # Best-effort metrics; each is isolated in its own savepoint.
    contents_produced = await _count_org_contents(db, org_id)
    avg_ai_citation_rate = await _avg_ai_citation_rate(db, org_id)

    # Current-stage distribution, keyed by the stage's string name.
    current_stage_dist_stmt = (
        select(LifecycleProject.current_stage, func.count().label("cnt"))
        .where(org_filter)
        .group_by(LifecycleProject.current_stage)
    )
    current_stage_distribution: dict[str, int] = {}
    for r in (await db.execute(current_stage_dist_stmt)).all():
        # Unknown stage numbers fall back to their decimal string.
        stage_key = STAGE_INT_TO_STR.get(r.current_stage, str(r.current_stage))
        current_stage_distribution[stage_key] = (
            current_stage_distribution.get(stage_key, 0) + r.cnt
        )

    return ProjectStatsResponse(
        total_projects=total,
        active_projects=counts.active,
        completed_projects=counts.completed or 0,
        contents_produced=contents_produced,
        avg_ai_citation_rate=avg_ai_citation_rate,
        current_stage_distribution=current_stage_distribution,
        stage_distribution=stage_distribution,
        completion_rate=completion_rate,
    )


@router.get("/projects/{project_id}/timeline", response_model=list[TimelineEvent])
async def project_timeline(
    project_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the project's creation and stage start/complete events, oldest first.

    Returns ``[]`` when the user has no organization; 404 when the project is
    not in the user's organization.
    """
    org_id = current_user.organization_id
    if not org_id:
        return []

    project = await _load_project_with_stages(db, project_id, org_id)
    if not project:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Project not found")

    # project creation event
    events: list[TimelineEvent] = [
        TimelineEvent(
            event_type="project_created",
            description=f"项目「{project.brand_name}」创建",
            timestamp=project.created_at,
            stage_number=None,
        )
    ]

    # stage events (only for timestamps that are set)
    for stage in sorted(project.stages, key=lambda s: s.stage_number):
        stage_label = STAGE_NAMES.get(stage.stage_number, f"阶段 {stage.stage_number}")
        if stage.started_at:
            events.append(
                TimelineEvent(
                    event_type="stage_started",
                    description=f"{stage_label} 开始",
                    timestamp=stage.started_at,
                    stage_number=stage.stage_number,
                )
            )
        if stage.completed_at:
            events.append(
                TimelineEvent(
                    event_type="stage_completed",
                    description=f"{stage_label} 完成",
                    timestamp=stage.completed_at,
                    stage_number=stage.stage_number,
                )
            )

    # NOTE(review): if created_at is timezone-aware while stage timestamps are
    # naive (or vice versa), this sort raises TypeError — confirm column types.
    events.sort(key=lambda e: e.timestamp)
    return events


@router.post(
    "/projects/quick-start",
    response_model=QuickStartResponse,
    status_code=status.HTTP_201_CREATED,
)
async def quick_start(
    body: ProjectCreateRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create a project with all five stages, starting stage 1 immediately.

    Creates a default organization for the user first if they have none.
    Commits the transaction and returns the freshly reloaded project.
    """
    org_id = await _get_or_create_org(db, current_user)

    # create project
    project = LifecycleProject(
        organization_id=org_id,
        brand_name=body.brand_name,
        brand_aliases=[],
        current_stage=1,
        status="active",
        created_by=current_user.id,
    )
    db.add(project)
    await db.flush()  # assigns project.id for the stage rows
    # Capture the PK now: after commit() the instance may be expired, and
    # reading an expired attribute on an AsyncSession triggers implicit IO
    # (MissingGreenlet) unless the session uses expire_on_commit=False.
    project_id = project.id

    # create 5 stages: stage 1 active and started, the rest pending
    now = _utcnow()
    db.add_all(
        [
            ProjectStage(
                project_id=project_id,
                stage_number=i,
                status="active" if i == 1 else "pending",
                started_at=now if i == 1 else None,
            )
            for i in range(1, 6)
        ]
    )
    await db.commit()  # flushes the pending stages

    # reload with stages
    project = await _load_project_with_stages(db, project_id, org_id)

    return QuickStartResponse(
        project=ProjectResponse.model_validate(project),
        message=f"项目「{body.brand_name}」已创建,从品牌基建阶段开始",
    )


@router.get("/projects/{project_id}/stages", response_model=list[StageDetailResponse])
async def list_stages(
    project_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the project's stages ordered by stage number.

    Returns ``[]`` when the user has no organization; 404 when the project is
    not in the user's organization.
    """
    org_id = current_user.organization_id
    if not org_id:
        return []

    project = await _load_project_with_stages(db, project_id, org_id)
    if not project:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Project not found")

    return sorted(project.stages, key=lambda s: s.stage_number)


@router.put("/projects/{project_id}/stages/{stage_number}", response_model=StageDetailResponse)
async def update_stage(
    project_id: uuid.UUID,
    stage_number: int,
    body: StageUpdateRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Update a stage's status, notes and/or metrics, then commit.

    Fields left as None in ``body`` are not touched. Setting status "active"
    stamps ``started_at`` if unset. Setting "completed" stamps
    ``completed_at``, and advances the project's ``current_stage`` past this
    stage when it is the current or a later stage.

    Raises 403 when the user has no organization, and 404 when the project or
    stage is not found.
    """
    org_id = current_user.organization_id
    if not org_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="用户未关联组织,无法修改项目阶段",
        )

    # verify project ownership
    project = await _load_project_with_stages(db, project_id, org_id)
    if not project:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Project not found")

    stage = next((s for s in project.stages if s.stage_number == stage_number), None)
    if stage is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Stage not found")

    now = _utcnow()

    if body.status is not None:
        stage.status = body.status
        if body.status == "active" and stage.started_at is None:
            stage.started_at = now
        if body.status == "completed":
            stage.completed_at = now
            # advance project current_stage
            # NOTE(review): completing the final stage sets current_stage to 6,
            # which has no STAGE_NAMES / STAGE_INT_TO_STR entry; confirm that
            # callers treat it as "all stages done".
            if stage_number >= project.current_stage:
                project.current_stage = stage_number + 1

    if body.notes is not None:
        stage.notes = body.notes
    if body.metrics is not None:
        stage.metrics = body.metrics

    await db.commit()
    await db.refresh(stage)
    return stage