geo/backend/app/api/schema_advisor.py

249 lines
8.2 KiB
Python

import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
from app.models.schema_suggestion import SchemaSuggestion
from app.schemas.schema_suggestion import (
SchemaAdviseRequest,
SchemaSuggestionResponse,
SchemaSuggestionList,
SchemaValidationResult,
SchemaStatusUpdateRequest,
)
from app.services.schema.schema_advisor_service import SchemaAdvisorService
from app.services.scoring.scoring_service import ScoringService
# Router collecting the schema-advisor endpoints defined in this module.
router = APIRouter()
async def _get_brand_with_access(
    brand_id: uuid.UUID,
    db: AsyncSession,
    current_user: User,
) -> Brand:
    """Fetch a brand that belongs to ``current_user``.

    Ownership is part of the query filter (``Brand.user_id``), so a brand owned
    by somebody else is reported exactly like a brand that does not exist.

    Args:
        brand_id: Primary key of the brand to load.
        db: Async database session.
        current_user: The authenticated user who must own the brand.

    Returns:
        The matching ``Brand`` row.

    Raises:
        HTTPException: 404 when no brand with ``brand_id`` is owned by the user.
    """
    ownership_query = select(Brand).where(
        Brand.id == brand_id,
        Brand.user_id == current_user.id,
    )
    owned_brand = (await db.execute(ownership_query)).scalar_one_or_none()
    if not owned_brand:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="品牌不存在",
        )
    return owned_brand
# Sentiment buckets understood by ScoringService.calculate_v2 (as used below).
_SENTIMENT_LABELS = ("positive", "neutral", "negative")


def _count_competitor_mentions(citations: list, competitor_names: list[str]) -> dict[str, int]:
    """Count, per competitor name, the cited records whose ``competitor_brands`` contain it.

    Only records with ``cited`` set are considered, and competitors with a zero
    count are omitted from the returned mapping.
    """
    # NOTE(review): competitors appearing in answers where the brand itself was
    # not cited are ignored here — confirm this is the intended share-of-voice rule.
    mentions: dict[str, int] = {}
    for name in competitor_names:
        hits = sum(
            1
            for record in citations
            if record.cited and record.competitor_brands and name in record.competitor_brands
        )
        if hits:
            mentions[name] = hits
    return mentions


async def _resolve_citation_sentiment(sentiment_service, brand_name: str, citation) -> str:
    """Return one of ``_SENTIMENT_LABELS`` for a single brand citation.

    A stored ``citation.sentiment`` is used when it is a known label. Otherwise
    the sentiment service analyzes ``raw_response`` (falling back to
    ``citation_text``). Empty content, a service failure, or an unknown label
    from the service all degrade to ``"neutral"``, so one bad record cannot
    abort the whole diagnosis.
    """
    import logging

    stored = citation.sentiment
    if stored and stored in _SENTIMENT_LABELS:
        return stored

    content = citation.raw_response or citation.citation_text or ""
    if not content.strip():
        return "neutral"

    try:
        result = await sentiment_service.analyze(
            brand_name=brand_name,
            content=content,
        )
        label = result.sentiment
    except Exception:
        # Best-effort enrichment: record the failure instead of hiding it,
        # but keep the request alive with a neutral label.
        logging.getLogger(__name__).warning(
            "Sentiment analysis failed for brand %r; counting citation as neutral",
            brand_name,
            exc_info=True,
        )
        return "neutral"

    return label if label in _SENTIMENT_LABELS else "neutral"


async def _get_brand_diagnosis_data(
    db: AsyncSession,
    user_id: uuid.UUID,
    brand: Brand,
) -> dict:
    """Compute the v2 visibility score for ``brand`` from the user's tracked queries.

    Loads the user's queries targeting ``brand.name``, their citation records,
    and the brand's competitors. It then resolves a sentiment for every cited
    record and feeds everything to ``ScoringService.calculate_v2``.

    Args:
        db: Async database session.
        user_id: Owner of the queries to aggregate.
        brand: Brand being diagnosed; its ``name`` matches ``Query.target_brand``.

    Returns:
        ``calculate_v2(...).to_dict()``. When the user has no matching queries,
        this is an all-zero score with the same shape.
    """
    # Local imports, as in the original module layout.
    from app.models.query import Query as QueryModel
    from app.models.citation_record import CitationRecord
    from app.models.competitor import Competitor
    from app.schemas.scoring import CitationResult
    from app.services.analysis.sentiment_service import get_sentiment_service

    queries_result = await db.execute(
        select(QueryModel).where(
            QueryModel.user_id == user_id,
            QueryModel.target_brand == brand.name,
        )
    )
    queries = list(queries_result.scalars().all())
    if not queries:
        # Nothing tracked yet: return a zeroed score so callers see a stable shape.
        return ScoringService().calculate_v2(
            mentioned_count=0,
            total_queries=0,
            positions=[],
            sentiment_counts=dict.fromkeys(_SENTIMENT_LABELS, 0),
            citations=[],
            brand_mentions=0,
            competitor_mentions={},
        ).to_dict()

    query_ids = [q.id for q in queries]
    citations_result = await db.execute(
        select(CitationRecord).where(CitationRecord.query_id.in_(query_ids))
    )
    all_citations = list(citations_result.scalars().all())
    # The denominator counts CitationRecord rows, not Query rows
    # (presumably one row per recorded answer — confirm against the model).
    total_queries = len(all_citations)
    brand_citations = [c for c in all_citations if c.cited]

    competitor_result = await db.execute(
        select(Competitor).where(Competitor.brand_id == brand.id)
    )
    competitor_names = [c.name for c in competitor_result.scalars().all()]
    competitor_mentions = _count_competitor_mentions(all_citations, competitor_names)

    sentiment_service = get_sentiment_service()
    sentiment_counts = dict.fromkeys(_SENTIMENT_LABELS, 0)
    citation_results = []
    for citation in brand_citations:
        label = await _resolve_citation_sentiment(sentiment_service, brand.name, citation)
        sentiment_counts[label] += 1
        citation_results.append(
            CitationResult(
                cited=citation.cited,
                position=citation.citation_position,
                citation_text=citation.citation_text,
                # Carry the resolved label so per-citation data agrees with
                # sentiment_counts instead of always claiming "neutral".
                sentiment=label,
                confidence=citation.confidence or 0.0,
            )
        )

    # brand_citations are already filtered on `cited`. Drop records without a
    # recorded rank so the scorer only receives real positions
    # (assumes calculate_v2 expects numeric positions — confirm).
    positions = [
        c.citation_position for c in brand_citations if c.citation_position is not None
    ]

    v2_result = ScoringService().calculate_v2(
        mentioned_count=len(brand_citations),
        total_queries=total_queries,
        positions=positions,
        sentiment_counts=sentiment_counts,
        citations=citation_results,
        brand_mentions=len(brand_citations),
        competitor_mentions=competitor_mentions,
    )
    return v2_result.to_dict()
@router.post("/advise", response_model=SchemaSuggestionList)
async def generate_schema_advise(
    request: SchemaAdviseRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Generate schema suggestions for one of the caller's brands.

    The brand's diagnosis data is computed first. It is then passed, together
    with basic brand details, to ``SchemaAdvisorService.generate_suggestions``.

    Raises:
        HTTPException: 404 when the brand does not exist or is not owned by
            the caller.
    """
    target_brand = await _get_brand_with_access(request.brand_id, db, current_user)
    diagnosis = await _get_brand_diagnosis_data(db, current_user.id, target_brand)

    # Missing optional brand fields are sent to the advisor as empty strings.
    profile = dict(
        name=target_brand.name,
        website=target_brand.website or "",
        industry=target_brand.industry or "",
    )

    advisor = SchemaAdvisorService()
    generated = await advisor.generate_suggestions(
        db=db,
        brand_id=target_brand.id,
        diagnosis_data=diagnosis,
        brand_info=profile,
        target_url=request.target_url,
        focus_dimensions=request.focus_dimensions,
    )

    payload = [SchemaSuggestionResponse.model_validate(item) for item in generated]
    return SchemaSuggestionList(suggestions=payload, total=len(generated))
@router.get("/brand/{brand_id}", response_model=SchemaSuggestionList)
async def get_brand_schema_suggestions(
    brand_id: uuid.UUID,
    status_filter: str | None = Query(None, alias="status", description="按状态筛选"),
    schema_type: str | None = Query(None, description="按Schema类型筛选"),
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List a brand's stored schema suggestions, with optional filters and paging.

    The ``status`` query parameter maps to ``status_filter``. ``total`` in the
    response is the count reported by the service, not the page length.

    Raises:
        HTTPException: 404 when the brand does not exist or is not owned by
            the caller.
    """
    await _get_brand_with_access(brand_id, db, current_user)

    page, total_count = await SchemaAdvisorService().get_suggestions(
        db=db,
        brand_id=brand_id,
        status_filter=status_filter,
        schema_type=schema_type,
        skip=skip,
        limit=limit,
    )

    validated = [SchemaSuggestionResponse.model_validate(row) for row in page]
    return SchemaSuggestionList(suggestions=validated, total=total_count)
@router.get("/{suggestion_id}", response_model=SchemaSuggestionResponse)
async def get_schema_suggestion_detail(
    suggestion_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a single schema suggestion, provided its brand belongs to the caller.

    Raises:
        HTTPException: 404 "建议不存在" when the suggestion does not exist, or
            404 "品牌不存在" when its brand is not owned by the caller.
    """
    service = SchemaAdvisorService()
    suggestion = await service.get_suggestion_by_id(db, suggestion_id)
    if not suggestion:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="建议不存在",
        )
    # Called only for its authorization side effect (raises 404 if not owned);
    # the returned Brand is not needed.
    await _get_brand_with_access(suggestion.brand_id, db, current_user)
    return SchemaSuggestionResponse.model_validate(suggestion)
@router.put("/{suggestion_id}/status", response_model=SchemaSuggestionResponse)
async def update_schema_suggestion_status(
    suggestion_id: uuid.UUID,
    status_update: SchemaStatusUpdateRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Change a suggestion's status to one of "pending", "applied" or "dismissed".

    The status value is validated before any database access. Ownership is
    checked through the suggestion's brand before the update is applied.

    Raises:
        HTTPException: 400 for an unsupported status value; 404 when the
            suggestion does not exist or its brand is not owned by the caller.
    """
    # A tuple, not a set: the error message below lists these values, and set
    # iteration order over strings varies between processes (hash randomization).
    valid_statuses = ("pending", "applied", "dismissed")
    if status_update.status not in valid_statuses:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"无效的状态值,支持: {', '.join(valid_statuses)}",
        )
    service = SchemaAdvisorService()
    suggestion = await service.get_suggestion_by_id(db, suggestion_id)
    if not suggestion:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="建议不存在",
        )
    await _get_brand_with_access(suggestion.brand_id, db, current_user)
    updated = await service.update_status(db, suggestion_id, status_update.status)
    return SchemaSuggestionResponse.model_validate(updated)