geo/backend/app/services/admin.py

197 lines
6.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import uuid
from datetime import datetime, timedelta, timezone

from sqlalchemy import case, func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.citation_record import CitationRecord
from app.models.query import Query
from app.models.subscription import Subscription
from app.models.user import User
from app.services.subscription import PLANS


async def get_system_stats(db: AsyncSession) -> dict:
    """Return platform-wide counters for the admin dashboard.

    Returns:
        A dict with:
          - ``total_users`` / ``total_queries`` / ``total_citations``: row counts.
          - ``citation_rate``: percentage (rounded to 2 decimals) of citation
            records whose ``cited`` is true; ``0.0`` when there are none.
          - ``today_active_users``: number of distinct ``Query.user_id`` values
            whose ``last_queried_at`` is at or after today's midnight (UTC).
    """
    total_users = (
        await db.execute(select(func.count()).select_from(User))
    ).scalar_one()
    total_queries = (
        await db.execute(select(func.count()).select_from(Query))
    ).scalar_one()

    # Count all citation records and the cited subset in ONE statement so both
    # numbers come from the same snapshot (two separate queries could observe
    # rows inserted in between and yield a rate above 100%) and to save a
    # round trip. COUNT(CASE ...) counts only non-NULL values, i.e. cited rows.
    citation_row = (
        await db.execute(
            select(
                func.count(),
                func.count(case((CitationRecord.cited.is_(True), 1))),
            ).select_from(CitationRecord)
        )
    ).one()
    total_citations, cited_count = citation_row
    citation_rate = (
        round(cited_count / total_citations * 100, 2) if total_citations > 0 else 0.0
    )

    # Naive UTC midnight of the current day; equivalent to the original
    # datetime.utcnow()-based value without the deprecated API.
    # NOTE(review): the day boundary is UTC midnight -- confirm this is the
    # intended timezone for "today".
    today_start = datetime.now(timezone.utc).replace(
        hour=0, minute=0, second=0, microsecond=0, tzinfo=None
    )
    # NOTE(review): "active" here means the user owns a Query whose
    # last_queried_at falls today; confirm whether scheduled (non-user-initiated)
    # runs also update last_queried_at.
    today_active_users = (
        await db.execute(
            select(func.count(func.distinct(Query.user_id))).where(
                Query.last_queried_at >= today_start
            )
        )
    ).scalar_one()

    return {
        "total_users": total_users,
        "total_queries": total_queries,
        "total_citations": total_citations,
        "citation_rate": citation_rate,
        "today_active_users": today_active_users,
    }


def _escape_like(term: str, escape_char: str = "/") -> str:
    """Return *term* with LIKE metacharacters escaped so it matches literally.

    The escape character itself is doubled first, then ``%`` and ``_`` are
    prefixed with it. Pass the same ``escape_char`` as ``escape=`` to ``ilike``.
    """
    return (
        term.replace(escape_char, escape_char + escape_char)
        .replace("%", escape_char + "%")
        .replace("_", escape_char + "_")
    )


async def get_users(
    db: AsyncSession, skip: int = 0, limit: int = 20, search: str | None = None
) -> dict:
    """Return one page of users for the admin list plus the total match count.

    Args:
        db: Async database session.
        skip: Number of users to skip (OFFSET).
        limit: Maximum number of users to return (LIMIT).
        search: Optional case-insensitive substring matched against ``email``
            or ``name``. ``%`` and ``_`` in it are matched literally.

    Returns:
        ``{"items": [...], "total": n}``. ``items`` are user summaries ordered
        newest first, each with the user's number of queries. ``total`` is the
        count of all users matching ``search``, ignoring pagination.
    """
    list_stmt = select(User)
    count_stmt = select(func.count()).select_from(User)
    if search:
        # Escape the user-supplied text: "_" is common in e-mail addresses and
        # would otherwise act as a single-character wildcard.
        like_escape = "/"
        pattern = f"%{_escape_like(search, like_escape)}%"
        matches = User.email.ilike(pattern, escape=like_escape) | User.name.ilike(
            pattern, escape=like_escape
        )
        list_stmt = list_stmt.where(matches)
        count_stmt = count_stmt.where(matches)

    # Secondary sort on id makes the order total, so OFFSET pagination does not
    # repeat or skip users that share the same created_at.
    list_stmt = (
        list_stmt.order_by(User.created_at.desc(), User.id.desc())
        .offset(skip)
        .limit(limit)
    )
    users = list((await db.execute(list_stmt)).scalars().all())
    total = (await db.execute(count_stmt)).scalar_one()

    if not users:
        return {"items": [], "total": total}

    # Avoid N+1 queries: fetch the query count for every user on this page in
    # a single grouped query. Users with no queries are absent from the result.
    per_user_count_stmt = (
        select(Query.user_id, func.count().label("cnt"))
        .where(Query.user_id.in_([u.id for u in users]))
        .group_by(Query.user_id)
    )
    qc_result = await db.execute(per_user_count_stmt)
    query_counts: dict = {row.user_id: row.cnt for row in qc_result.all()}

    items = [
        {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "plan": user.plan,
            "is_active": user.is_active,
            "is_admin": user.is_admin,
            "email_verified": user.email_verified,
            "query_count": query_counts.get(user.id, 0),
            "created_at": user.created_at,
        }
        for user in users
    ]
    return {"items": items, "total": total}


async def get_user_detail(db: AsyncSession, user_id: uuid.UUID) -> dict | None:
    """Build the admin detail view for a single user.

    Returns None when no user has ``user_id``. Otherwise returns a dict with
    the user's profile (``user``), all of the user's queries ordered by
    ``created_at`` descending (``queries``), and the 10 citation records with
    the latest ``queried_at`` among those queries (``recent_citations``).
    """
    found = (
        await db.execute(select(User).where(User.id == user_id))
    ).scalar_one_or_none()
    if found is None:
        return None

    own_queries_stmt = (
        select(Query)
        .where(Query.user_id == user_id)
        .order_by(Query.created_at.desc())
    )
    query_rows = (await db.execute(own_queries_stmt)).scalars().all()

    # Citations are linked to users only through their parent Query.
    latest_citations_stmt = (
        select(CitationRecord)
        .join(Query, CitationRecord.query_id == Query.id)
        .where(Query.user_id == user_id)
        .order_by(CitationRecord.queried_at.desc())
        .limit(10)
    )
    citation_rows = (await db.execute(latest_citations_stmt)).scalars().all()

    profile = {
        "id": found.id,
        "email": found.email,
        "name": found.name,
        "plan": found.plan,
        "is_active": found.is_active,
        "is_admin": found.is_admin,
        "email_verified": found.email_verified,
        "max_queries": found.max_queries,
        "created_at": found.created_at,
        "updated_at": found.updated_at,
    }

    query_summaries = []
    for row in query_rows:
        query_summaries.append(
            {
                "id": row.id,
                "keyword": row.keyword,
                "target_brand": row.target_brand,
                "status": row.status,
                "frequency": row.frequency,
                "created_at": row.created_at,
            }
        )

    citation_summaries = []
    for rec in citation_rows:
        citation_summaries.append(
            {
                "id": rec.id,
                "platform": rec.platform,
                "cited": rec.cited,
                "citation_position": rec.citation_position,
                "queried_at": rec.queried_at,
            }
        )

    return {
        "user": profile,
        "queries": query_summaries,
        "recent_citations": citation_summaries,
    }


async def toggle_user_active(db: AsyncSession, user_id: uuid.UUID) -> dict | None:
    """Flip a user's ``is_active`` flag and commit the change.

    Args:
        db: Async database session; this function commits it.
        user_id: Primary key of the user to toggle.

    Returns:
        None if no such user exists, otherwise a dict with ``id``, the new
        ``is_active`` value, and a user-facing status ``message``.
    """
    user_result = await db.execute(select(User).where(User.id == user_id))
    user = user_result.scalar_one_or_none()
    if user is None:
        return None

    new_state = not user.is_active
    user.is_active = new_state
    # Build the response BEFORE committing. With SQLAlchemy's default
    # expire_on_commit=True, commit() expires every loaded attribute, and
    # reading one afterwards triggers an implicit lazy load that AsyncSession
    # cannot perform (raises MissingGreenlet). Snapshotting first is correct
    # regardless of how the session factory is configured.
    response = {
        "id": user.id,
        "is_active": new_state,
        # "User enabled" / "User disabled"
        "message": "用户已启用" if new_state else "用户已禁用",
    }
    await db.commit()
    return response


async def update_user_plan(db: AsyncSession, user_id: uuid.UUID, plan: str) -> dict | None:
    """Assign a subscription plan to a user and sync their query quota.

    Args:
        db: Async database session; this function commits it.
        user_id: Primary key of the user to update.
        plan: Key into ``PLANS``; its ``max_queries`` becomes the user's quota.

    Returns:
        None if no such user exists, otherwise a dict with ``id``, ``plan``,
        ``max_queries`` and a user-facing ``message``.

    Raises:
        ValueError: If ``plan`` is not a key of ``PLANS``. This is checked
            before any database access.
    """
    plan_data = PLANS.get(plan)
    if plan_data is None:
        raise ValueError(f"Invalid plan: {plan}")

    user_result = await db.execute(select(User).where(User.id == user_id))
    user = user_result.scalar_one_or_none()
    if user is None:
        return None

    user.plan = plan
    user.max_queries = plan_data["max_queries"]
    # Snapshot the response before commit: under the default
    # expire_on_commit=True, reading attributes after commit() would trigger
    # an implicit lazy load, which AsyncSession cannot do (MissingGreenlet).
    response = {
        "id": user.id,
        "plan": plan,
        "max_queries": plan_data["max_queries"],
        # "User plan updated to <plan name>"
        "message": f"用户套餐已更新为{plan_data['name']}",
    }
    await db.commit()
    return response