249 lines
8.2 KiB
Python
249 lines
8.2 KiB
Python
import uuid
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|
from sqlalchemy import select, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.api.deps import get_current_user
|
|
from app.database import get_db
|
|
from app.models.user import User
|
|
from app.models.brand import Brand
|
|
from app.models.schema_suggestion import SchemaSuggestion
|
|
from app.schemas.schema_suggestion import (
|
|
SchemaAdviseRequest,
|
|
SchemaSuggestionResponse,
|
|
SchemaSuggestionList,
|
|
SchemaValidationResult,
|
|
SchemaStatusUpdateRequest,
|
|
)
|
|
from app.services.schema.schema_advisor_service import SchemaAdvisorService
|
|
from app.services.scoring.scoring_service import ScoringService
|
|
|
|
# Router carrying the schema-suggestion endpoints declared in this module
# (POST /advise, GET /brand/{brand_id}, GET /{suggestion_id},
# PUT /{suggestion_id}/status). It is presumably included by the app under a
# prefix elsewhere — TODO confirm where it is mounted.
router = APIRouter()
|
|
|
|
|
|
async def _get_brand_with_access(
    brand_id: uuid.UUID,
    db: AsyncSession,
    current_user: User,
) -> Brand:
    """Fetch a brand owned by ``current_user`` or raise a 404.

    The query matches on both the brand id and the owning user id, so a brand
    that exists but belongs to another user is reported exactly like a brand
    that does not exist at all.

    Args:
        brand_id: Primary key of the brand to load.
        db: Async database session.
        current_user: Authenticated user who must own the brand.

    Returns:
        The matching ``Brand`` row.

    Raises:
        HTTPException: 404 when no brand matches both the id and the owner.
    """
    # Successive .where() calls are combined with AND.
    ownership_query = (
        select(Brand)
        .where(Brand.id == brand_id)
        .where(Brand.user_id == current_user.id)
    )
    owned_brand = (await db.execute(ownership_query)).scalar_one_or_none()

    if owned_brand:
        return owned_brand

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="品牌不存在",
    )
|
|
|
|
|
|
async def _get_brand_diagnosis_data(
    db: AsyncSession,
    user_id: uuid.UUID,
    brand: Brand,
) -> dict:
    """Aggregate stored citation data for ``brand`` into a v2 score dict.

    Loads the user's queries whose ``target_brand`` equals ``brand.name``,
    every ``CitationRecord`` attached to those queries, and the brand's
    competitors. It then passes the aggregated counts to
    ``ScoringService.calculate_v2``.

    Args:
        db: Async database session used for all lookups.
        user_id: Owner whose queries are aggregated.
        brand: Brand being scored. Its ``name`` is matched against
            ``Query.target_brand`` and its ``id`` selects competitors.

    Returns:
        ``calculate_v2(...).to_dict()``. When the user has no matching
        queries, the result of scoring all-zero inputs is returned.
    """
    import logging

    from app.models.query import Query as QueryModel
    from app.models.citation_record import CitationRecord
    from app.models.competitor import Competitor
    from app.schemas.scoring import CitationResult
    from app.services.analysis.sentiment_service import get_sentiment_service

    logger = logging.getLogger(__name__)
    sentiment_labels = ("positive", "neutral", "negative")
    scoring_service = ScoringService()

    queries_result = await db.execute(
        select(QueryModel).where(
            QueryModel.user_id == user_id,
            QueryModel.target_brand == brand.name,
        )
    )
    query_ids = [q.id for q in queries_result.scalars().all()]

    if not query_ids:
        empty_result = scoring_service.calculate_v2(
            mentioned_count=0,
            total_queries=0,
            positions=[],
            sentiment_counts={"positive": 0, "neutral": 0, "negative": 0},
            citations=[],
            brand_mentions=0,
            competitor_mentions={},
        )
        return empty_result.to_dict()

    citations_result = await db.execute(
        select(CitationRecord).where(CitationRecord.query_id.in_(query_ids))
    )
    all_citations = list(citations_result.scalars().all())

    # NOTE(review): this counts citation *records*, not queries. It stays
    # consistent with ``mentioned_count`` below, which also counts records.
    total_queries = len(all_citations)
    brand_citations = [c for c in all_citations if c.cited]

    competitor_result = await db.execute(
        select(Competitor).where(Competitor.brand_id == brand.id)
    )
    competitor_names = [c.name for c in competitor_result.scalars().all()]

    # Competitors are only counted on records where the brand itself was
    # cited, i.e. within ``brand_citations``.
    competitor_mentions: dict[str, int] = {}
    for comp_name in competitor_names:
        count = sum(
            1
            for c in brand_citations
            if c.competitor_brands and comp_name in c.competitor_brands
        )
        if count > 0:
            competitor_mentions[comp_name] = count

    sentiment_service = get_sentiment_service()
    sentiment_counts = {label: 0 for label in sentiment_labels}
    # Resolved label for each entry of ``brand_citations`` (same order), so the
    # per-citation results below carry the sentiment that was actually counted
    # instead of a hard-coded "neutral".
    resolved_sentiments: list[str] = []
    # NOTE(review): records without a stored label are analyzed one at a time;
    # consider batching if this becomes a latency problem.
    for citation in brand_citations:
        if citation.sentiment in sentiment_labels:
            label = citation.sentiment
        else:
            # Fall back to "neutral" when there is no text to analyze, the
            # analysis fails, or it returns an unrecognized label.
            label = "neutral"
            content = citation.raw_response or citation.citation_text or ""
            if content.strip():
                try:
                    analyzed = await sentiment_service.analyze(
                        brand_name=brand.name,
                        content=content,
                    )
                    candidate = analyzed.sentiment
                except Exception:
                    # Best-effort: a failed analysis must not break scoring,
                    # but it should leave a trace instead of vanishing.
                    logger.warning(
                        "Sentiment analysis failed for brand %r; counting as neutral",
                        brand.name,
                        exc_info=True,
                    )
                else:
                    if candidate in sentiment_labels:
                        label = candidate
        sentiment_counts[label] += 1
        resolved_sentiments.append(label)

    citation_results = [
        CitationResult(
            cited=c.cited,
            position=c.citation_position,
            citation_text=c.citation_text,
            sentiment=label,
            confidence=c.confidence or 0.0,
        )
        for c, label in zip(brand_citations, resolved_sentiments)
    ]

    # Every entry of ``brand_citations`` is already cited; no re-filter needed.
    positions = [c.citation_position for c in brand_citations]

    v2_result = scoring_service.calculate_v2(
        mentioned_count=len(brand_citations),
        total_queries=total_queries,
        positions=positions,
        sentiment_counts=sentiment_counts,
        citations=citation_results,
        brand_mentions=len(brand_citations),
        competitor_mentions=competitor_mentions,
    )
    return v2_result.to_dict()
|
|
|
|
|
|
@router.post("/advise", response_model=SchemaSuggestionList)
async def generate_schema_advise(
    request: SchemaAdviseRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Generate schema suggestions for one of the caller's brands.

    Verifies that ``request.brand_id`` belongs to the caller, gathers the
    brand's scoring data, and delegates generation to
    ``SchemaAdvisorService.generate_suggestions``.

    Returns:
        A ``SchemaSuggestionList`` whose ``total`` is the number of
        suggestions generated by this call.
    """
    target_brand = await _get_brand_with_access(request.brand_id, db, current_user)
    scoring_snapshot = await _get_brand_diagnosis_data(
        db, current_user.id, target_brand
    )

    # Optional brand fields are normalized to "" before reaching the service.
    profile = {
        "name": target_brand.name,
        "website": target_brand.website or "",
        "industry": target_brand.industry or "",
    }

    advisor = SchemaAdvisorService()
    generated = await advisor.generate_suggestions(
        db=db,
        brand_id=target_brand.id,
        diagnosis_data=scoring_snapshot,
        brand_info=profile,
        target_url=request.target_url,
        focus_dimensions=request.focus_dimensions,
    )

    serialized = [SchemaSuggestionResponse.model_validate(item) for item in generated]
    return SchemaSuggestionList(suggestions=serialized, total=len(generated))
|
|
|
|
|
|
@router.get("/brand/{brand_id}", response_model=SchemaSuggestionList)
async def get_brand_schema_suggestions(
    brand_id: uuid.UUID,
    status_filter: str | None = Query(None, alias="status", description="按状态筛选"),
    schema_type: str | None = Query(None, description="按Schema类型筛选"),
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List a page of stored schema suggestions for a brand the caller owns.

    The ``status`` query parameter is bound to ``status_filter`` so it does not
    shadow the imported ``status`` module. Filtering and paging are delegated
    to ``SchemaAdvisorService.get_suggestions``.

    Returns:
        A ``SchemaSuggestionList`` whose ``total`` is the count reported by the
        service (not just the size of the returned page).
    """
    # Ownership check only; raises 404 if the brand is not the caller's.
    await _get_brand_with_access(brand_id, db, current_user)

    page, total = await SchemaAdvisorService().get_suggestions(
        db=db,
        brand_id=brand_id,
        status_filter=status_filter,
        schema_type=schema_type,
        skip=skip,
        limit=limit,
    )

    return SchemaSuggestionList(
        suggestions=list(map(SchemaSuggestionResponse.model_validate, page)),
        total=total,
    )
|
|
|
|
|
|
@router.get("/{suggestion_id}", response_model=SchemaSuggestionResponse)
async def get_schema_suggestion_detail(
    suggestion_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return one schema suggestion, provided its brand belongs to the caller.

    Raises:
        HTTPException: 404 when the suggestion does not exist, or (from
            ``_get_brand_with_access``) when its brand is not owned by the
            caller.
    """
    service = SchemaAdvisorService()
    suggestion = await service.get_suggestion_by_id(db, suggestion_id)
    if not suggestion:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="建议不存在",
        )

    # Called only for its ownership check; the returned brand is not needed.
    await _get_brand_with_access(suggestion.brand_id, db, current_user)

    return SchemaSuggestionResponse.model_validate(suggestion)
|
|
|
|
|
|
@router.put("/{suggestion_id}/status", response_model=SchemaSuggestionResponse)
async def update_schema_suggestion_status(
    suggestion_id: uuid.UUID,
    status_update: SchemaStatusUpdateRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Change the status of a schema suggestion owned (via its brand) by the caller.

    Raises:
        HTTPException: 400 when ``status_update.status`` is not one of
            ``pending``, ``applied``, ``dismissed``. 404 when the suggestion
            does not exist, or (from ``_get_brand_with_access``) when its brand
            is not owned by the caller.
    """
    # A tuple rather than a set: set iteration order of str values varies
    # between processes (hash randomization), which made the list of allowed
    # values in the error message nondeterministic.
    valid_statuses = ("pending", "applied", "dismissed")
    if status_update.status not in valid_statuses:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"无效的状态值,支持: {', '.join(valid_statuses)}",
        )

    service = SchemaAdvisorService()
    suggestion = await service.get_suggestion_by_id(db, suggestion_id)
    if not suggestion:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="建议不存在",
        )

    # Ownership check only; raises 404 if the brand is not the caller's.
    await _get_brand_with_access(suggestion.brand_id, db, current_user)

    updated = await service.update_status(db, suggestion_id, status_update.status)
    return SchemaSuggestionResponse.model_validate(updated)
|