598 lines
19 KiB
Python
598 lines
19 KiB
Python
import asyncio
|
||
import csv
|
||
import io
|
||
import logging
|
||
import uuid
|
||
from datetime import datetime, timedelta, timezone
|
||
|
||
from sqlalchemy import func, select, and_, cast, Integer
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy.orm import selectinload
|
||
|
||
from app.database import AsyncSessionLocal
|
||
from app.models.citation_record import CitationRecord
|
||
from app.models.query import Query
|
||
from app.models.query_task import QueryTask
|
||
from app.workers.citation_engine import CitationEngine
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
async def _verify_query_ownership(
    db: AsyncSession,
    query_id: uuid.UUID,
    user_id: uuid.UUID,
) -> Query | None:
    """Look up a query by ID, scoped to a specific owner.

    Args:
        db: Database session used to run the lookup.
        query_id: ID of the query to fetch.
        user_id: ID of the user who must own the query.

    Returns:
        The ``Query`` row whose ``id`` and ``user_id`` both match, or
        ``None`` when no such row exists.
    """
    ownership_filter = and_(Query.id == query_id, Query.user_id == user_id)
    rows = await db.execute(select(Query).where(ownership_filter))
    return rows.scalar_one_or_none()
|
||
|
||
|
||
async def get_citations(
    db: AsyncSession,
    user_id: uuid.UUID,
    query_id: uuid.UUID | None = None,
    platform: str | None = None,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
    skip: int = 0,
    limit: int = 20,
) -> tuple[list[CitationRecord], int]:
    """Return one page of a user's citation records and the total match count.

    Args:
        db: Database session.
        user_id: Only citations of queries owned by this user are returned.
        query_id: Optional query to restrict to; if the user does not own
            it, the result is empty.
        platform: Optional exact platform filter.
        start_date: Optional inclusive lower bound on ``queried_at``.
        end_date: Optional inclusive upper bound on ``queried_at``.
        skip: Number of rows to skip (pagination offset).
        limit: Maximum number of rows in the page.

    Returns:
        ``(records, total)`` where ``records`` is the requested page ordered
        by ``queried_at`` descending and ``total`` counts every row matching
        the filters, ignoring pagination.
    """
    # Explicit ownership check when a specific query is requested.
    if query_id is not None:
        owned_query = await _verify_query_ownership(db, query_id, user_id)
        if owned_query is None:
            return [], 0

    # Base filter: citations that belong to the user's queries.
    filters = [Query.user_id == user_id]
    if query_id is not None:
        filters.append(CitationRecord.query_id == query_id)
    if platform is not None:
        filters.append(CitationRecord.platform == platform)
    if start_date is not None:
        filters.append(CitationRecord.queried_at >= start_date)
    if end_date is not None:
        filters.append(CitationRecord.queried_at <= end_date)

    page_rows = await db.execute(
        select(CitationRecord)
        .join(Query, CitationRecord.query_id == Query.id)
        .where(and_(*filters))
        .order_by(CitationRecord.queried_at.desc())
        .offset(skip)
        .limit(limit)
    )
    page = page_rows.scalars().all()

    total_rows = await db.execute(
        select(func.count())
        .select_from(CitationRecord)
        .join(Query, CitationRecord.query_id == Query.id)
        .where(and_(*filters))
    )

    return list(page), total_rows.scalar_one()
|
||
|
||
|
||
def _empty_citation_stats() -> dict:
    """Return the statistics payload used when nothing matches the filters."""
    return {
        "total_queries": 0,
        "total_citations": 0,
        "citation_rate": 0.0,
        "avg_position": None,
        "by_platform": {},
        "trend": [],
    }


async def get_citation_stats(
    db: AsyncSession,
    user_id: uuid.UUID,
    query_id: uuid.UUID | None = None,
    brand_id: uuid.UUID | None = None,
) -> dict:
    """Aggregate citation statistics for a user's queries.

    Args:
        db: Database session.
        user_id: Owner whose queries are aggregated.
        query_id: Optional query to restrict to. If the user does not own
            it, an all-zero payload is returned.
        brand_id: Optional brand to restrict to. The brand is resolved to
            its name and matched against ``Query.target_brand``. If the
            user has no such brand, an all-zero payload is returned.

    Returns:
        A dict with these keys:
        ``total_queries`` - number of matching citation records (one per
        executed platform query);
        ``total_citations`` - how many of those records have ``cited`` true;
        ``citation_rate`` - ``total_citations / total_queries`` rounded to
        two decimals (``0.0`` when there are no records);
        ``avg_position`` - average ``citation_position`` over cited records
        that have a position, rounded to one decimal, or ``None``;
        ``by_platform`` - per-platform ``queries`` / ``citations`` /
        ``rate`` / ``avg_position``;
        ``trend`` - cited counts per week over the last 30 days.
    """
    base_conditions = [Query.user_id == user_id]

    # A brand filter is applied by name: resolve the brand (scoped to the
    # user) and match it against the query's target brand.
    if brand_id is not None:
        from app.models.brand import Brand

        brand_stmt = select(Brand.name).where(
            Brand.id == brand_id,
            Brand.user_id == user_id,
        )
        brand_result = await db.execute(brand_stmt)
        brand_name = brand_result.scalar_one_or_none()
        if brand_name is None:
            return _empty_citation_stats()
        base_conditions.append(Query.target_brand == brand_name)

    # A query filter requires the query to belong to the user.
    if query_id is not None:
        query = await _verify_query_ownership(db, query_id, user_id)
        if query is None:
            return _empty_citation_stats()
        base_conditions.append(CitationRecord.query_id == query_id)

    base_where = and_(*base_conditions)

    # Total executed queries (= citation records) and total cited records.
    total_queries_stmt = (
        select(func.count())
        .select_from(CitationRecord)
        .join(Query, CitationRecord.query_id == Query.id)
        .where(base_where)
    )
    total_queries = (await db.execute(total_queries_stmt)).scalar_one()

    total_citations_stmt = (
        select(func.count())
        .select_from(CitationRecord)
        .join(Query, CitationRecord.query_id == Query.id)
        .where(base_where, CitationRecord.cited.is_(True))
    )
    total_citations = (await db.execute(total_citations_stmt)).scalar_one()

    citation_rate = total_citations / total_queries if total_queries > 0 else 0.0

    # Average position, only over cited records that carry a position.
    avg_pos_stmt = (
        select(func.avg(CitationRecord.citation_position))
        .join(Query, CitationRecord.query_id == Query.id)
        .where(
            base_where,
            CitationRecord.cited.is_(True),
            CitationRecord.citation_position.isnot(None),
        )
    )
    avg_position = (await db.execute(avg_pos_stmt)).scalar_one()
    avg_position = round(avg_position, 1) if avg_position is not None else None

    # Per-platform breakdown.
    platform_stmt = (
        select(
            CitationRecord.platform,
            func.count().label("queries"),
            func.sum(cast(CitationRecord.cited, Integer)).label("citations"),
            func.avg(CitationRecord.citation_position).label("avg_position"),
        )
        .join(Query, CitationRecord.query_id == Query.id)
        .where(base_where)
        .group_by(CitationRecord.platform)
    )
    platform_result = await db.execute(platform_stmt)
    by_platform = {}
    for row in platform_result.all():
        queries = row.queries
        # SUM over zero cited rows may come back as NULL.
        citations = row.citations or 0
        rate = citations / queries if queries > 0 else 0.0
        plat_avg_pos = row.avg_position
        plat_avg_pos = round(plat_avg_pos, 1) if plat_avg_pos is not None else None
        by_platform[row.platform] = {
            "queries": queries,
            "citations": citations,
            "rate": round(rate, 2),
            "avg_position": plat_avg_pos,
        }

    # Trend: past 30 days grouped by week. The cutoff is a *naive* UTC
    # datetime so it is not mixed with naive datetimes from the database;
    # this is the non-deprecated equivalent of ``datetime.utcnow()``.
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    thirty_days_ago = now - timedelta(days=30)

    # Week bucketing differs per database dialect.
    dialect = db.bind.dialect.name if db.bind else "postgresql"
    if dialect == "postgresql":
        week_expr = func.date_trunc("week", CitationRecord.queried_at)
    else:
        # SQLite-compatible week grouping (YYYY-WW format).
        week_expr = func.strftime("%Y-%W", CitationRecord.queried_at)

    trend_stmt = (
        select(
            week_expr.label("week_start"),
            func.sum(cast(CitationRecord.cited, Integer)).label("citations"),
        )
        .join(Query, CitationRecord.query_id == Query.id)
        .where(
            base_where,
            CitationRecord.queried_at >= thirty_days_ago,
        )
        .group_by(week_expr)
        .order_by(week_expr)
    )
    trend_result = await db.execute(trend_stmt)
    trend = []
    for row in trend_result.all():
        week_start = row.week_start
        # date_trunc yields a datetime; the strftime fallback yields a string.
        if isinstance(week_start, datetime):
            date_str = week_start.date().isoformat()
        else:
            date_str = str(week_start)
        trend.append({
            "date": date_str,
            "citations": int(row.citations or 0),
        })

    return {
        "total_queries": total_queries,
        "total_citations": total_citations,
        "citation_rate": round(citation_rate, 2),
        "avg_position": avg_position,
        "by_platform": by_platform,
        "trend": trend,
    }
|
||
|
||
|
||
# Strong references to in-flight background runs. The event loop only keeps
# weak references to tasks, so a task nobody references may be
# garbage-collected before it finishes.
_BACKGROUND_QUERY_RUNS: set = set()


async def trigger_query_now(
    db: AsyncSession,
    user_id: uuid.UUID,
    query_id: uuid.UUID,
) -> QueryTask:
    """Queue one pending task per configured platform and start them now.

    Args:
        db: Database session.
        user_id: User requesting the run; must own the query.
        query_id: Query to execute.

    Returns:
        The ``QueryTask`` created for the first configured platform,
        refreshed after commit.

    Raises:
        ValueError: If the query does not exist for this user, is not
            active, or has no platforms configured.
    """
    query = await _verify_query_ownership(db, query_id, user_id)
    if query is None:
        raise ValueError("Query not found")
    if query.status != "active":
        raise ValueError("Query is not active")

    platforms = query.platforms or []
    if not platforms:
        raise ValueError("No platforms configured for this query")

    # Snapshot what the background run needs *before* committing: a commit
    # may expire loaded ORM attributes, and re-reading them afterwards would
    # need an implicit database round-trip on the async session.
    keyword = query.keyword
    target_brand = query.target_brand
    brand_aliases = query.brand_aliases or []

    tasks = [
        QueryTask(query_id=query_id, platform=platform, status="pending")
        for platform in platforms
    ]
    db.add_all(tasks)
    await db.commit()

    first_task = tasks[0]
    await db.refresh(first_task)

    # Run the queued tasks in the background right away, keeping a reference
    # until the run completes.
    background_run = asyncio.create_task(
        _execute_query_tasks(
            query_id=query_id,
            platforms=platforms,
            keyword=keyword,
            target_brand=target_brand,
            brand_aliases=brand_aliases,
            user_id=user_id,
        )
    )
    _BACKGROUND_QUERY_RUNS.add(background_run)
    background_run.add_done_callback(_BACKGROUND_QUERY_RUNS.discard)

    return first_task
|
||
|
||
|
||
def _naive_utcnow() -> datetime:
    """Return the current UTC time as a naive datetime.

    Non-deprecated equivalent of ``datetime.utcnow()``.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


async def _execute_query_tasks(
    query_id: uuid.UUID,
    platforms: list,
    keyword: str,
    target_brand: str,
    brand_aliases: list,
    user_id: uuid.UUID | None = None,
) -> None:
    """Run the pending tasks of a query in the background.

    Opens its own database session, loads every pending ``QueryTask`` of
    ``query_id`` whose platform is in ``platforms``, and runs them one by
    one through a ``CitationEngine``. Each task is marked ``running``, then
    ``success`` (storing a ``CitationRecord`` when the engine returns a
    result) or ``failed`` with the error message.

    Args:
        query_id: Query whose pending tasks are executed.
        platforms: Platform identifiers to execute.
        keyword: Keyword passed to the engine.
        target_brand: Brand the engine looks for.
        brand_aliases: Alternative brand names passed to the engine.
        user_id: When given, the run is skipped unless this user owns the
            query.

    Errors are logged rather than raised; the engine is always closed.
    """
    engine = CitationEngine()
    try:
        async with AsyncSessionLocal() as db:
            # Make sure the query belongs to the user.
            if user_id is not None:
                query = await _verify_query_ownership(db, query_id, user_id)
                if query is None:
                    logger.error(f"查询 {query_id} 不属于用户 {user_id},跳过执行")
                    return

            stmt = select(QueryTask).where(
                QueryTask.query_id == query_id,
                QueryTask.status == "pending",
                QueryTask.platform.in_(platforms),
            )
            result = await db.execute(stmt)

            # Snapshot id and platform while the rows are freshly loaded.
            # A rollback expires every loaded attribute, and reading an
            # expired attribute on an async session needs implicit I/O, so
            # later reads of task.id / task.platform could fail after the
            # first task error.
            pending = [
                (task, task.id, task.platform)
                for task in result.scalars().all()
            ]

            for task, task_id, platform in pending:
                try:
                    task.status = "running"
                    task.started_at = _naive_utcnow()
                    task.error_message = None
                    await db.commit()

                    citation_result = await engine.execute_single_platform(
                        keyword=keyword,
                        platform=platform,
                        target_brand=target_brand,
                        brand_aliases=brand_aliases or [],
                    )

                    if citation_result:
                        record = CitationRecord(
                            query_id=query_id,
                            platform=platform,
                            cited=citation_result.get("cited", False),
                            citation_position=citation_result.get("position"),
                            citation_text=citation_result.get("citation_text"),
                            competitor_brands=citation_result.get("competitor_brands", []),
                            raw_response=citation_result.get("raw_response", ""),
                            confidence=citation_result.get("confidence"),
                            match_type=citation_result.get("match_type"),
                        )
                        db.add(record)

                    task.status = "success"
                    task.completed_at = _naive_utcnow()
                    await db.commit()

                except Exception as e:
                    # Discard any half-written state, then persist the failure.
                    await db.rollback()
                    task.status = "failed"
                    task.error_message = str(e)
                    task.completed_at = _naive_utcnow()
                    await db.commit()
                    logger.error(f"查询任务执行失败: {task_id}, 错误: {e}")

    except Exception as e:
        # Top-level boundary of a fire-and-forget task: log, never raise.
        logger.error(f"查询引擎执行失败: {e}")
    finally:
        await engine.close()
|
||
|
||
|
||
# Platform identifier -> human-readable display name used in exported reports.
# Identifiers without an entry are shown as-is by the exporters.
PLATFORM_NAMES = dict(
    wenxin="文心一言",
    kimi="Kimi",
    tongyi="通义千问",
    doubao="豆包",
    qingyan="智谱清言",
    tiangong="天工AI",
    xinghuo="讯飞星火",
    baidu_ai="百度AI搜索",
    yuanbao="腾讯元宝",
)
|
||
|
||
|
||
async def export_citations_pdf(
    db: AsyncSession,
    user_id: uuid.UUID,
    query_id: uuid.UUID | None = None,
) -> bytes:
    """Generate a PDF report of the user's citation records.

    The report has a cover title, a summary section, a per-platform
    breakdown and a detail table with one row per citation record.

    Args:
        db: Database session.
        user_id: Only citations of queries owned by this user are included.
        query_id: Optional query to restrict the report to.

    Returns:
        The rendered PDF document as ``bytes``.

    Raises:
        ValueError: If ``query_id`` is given but the user does not own it.
    """
    import os

    from fpdf import FPDF

    # Verify ownership when a specific query is requested.
    if query_id is not None:
        query = await _verify_query_ownership(db, query_id, user_id)
        if query is None:
            raise ValueError("Query not found")

    conditions = [Query.user_id == user_id]
    if query_id is not None:
        conditions.append(CitationRecord.query_id == query_id)

    # Eager-load the parent query: the detail table reads r.query.keyword.
    stmt = (
        select(CitationRecord)
        .options(selectinload(CitationRecord.query))
        .join(Query, CitationRecord.query_id == Query.id)
        .where(and_(*conditions))
        .order_by(CitationRecord.queried_at.desc())
    )
    result = await db.execute(stmt)
    records = result.scalars().all()

    pdf = FPDF()
    pdf.add_page()

    # Use the first CJK-capable font found among known macOS / Linux paths.
    font_paths = [
        "/System/Library/Fonts/PingFang.ttc",
        "/System/Library/Fonts/STHeiti Light.ttc",
        "/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc",
        "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
    ]
    font_path = next((fp for fp in font_paths if os.path.exists(fp)), None)
    if font_path is not None:
        pdf.add_font("Chinese", "", font_path)
        pdf.set_font("Chinese", size=12)
    else:
        # NOTE(review): a core font is unlikely to cover the Chinese text
        # written below, so rendering may fail on hosts without a CJK font.
        logger.warning(
            "No CJK font found for the PDF report; falling back to Helvetica"
        )
        pdf.set_font("Helvetica", size=12)

    # Cover
    pdf.set_font_size(24)
    pdf.cell(0, 40, "GEO 品牌曝光度分析报告", new_x="LMARGIN", new_y="NEXT", align="C")
    pdf.set_font_size(12)
    pdf.cell(0, 10, f"生成日期: {datetime.now().strftime('%Y-%m-%d %H:%M')}", new_x="LMARGIN", new_y="NEXT", align="C")
    pdf.ln(20)

    # Summary statistics
    total = len(records)
    cited_count = sum(1 for r in records if r.cited)
    rate = f"{cited_count / total * 100:.1f}%" if total > 0 else "0%"

    positions = [
        r.citation_position for r in records if r.citation_position is not None
    ]
    avg_pos = f"{sum(positions) / len(positions):.1f}" if positions else "-"

    pdf.set_font_size(16)
    pdf.cell(0, 12, "一、汇总统计", new_x="LMARGIN", new_y="NEXT")
    pdf.set_font_size(11)
    pdf.cell(0, 8, f"总查询次数: {total}", new_x="LMARGIN", new_y="NEXT")
    pdf.cell(0, 8, f"引用次数: {cited_count}", new_x="LMARGIN", new_y="NEXT")
    pdf.cell(0, 8, f"引用率: {rate}", new_x="LMARGIN", new_y="NEXT")
    pdf.cell(0, 8, f"平均引用位置: {avg_pos}", new_x="LMARGIN", new_y="NEXT")
    pdf.ln(10)

    # Per-platform distribution
    pdf.set_font_size(16)
    pdf.cell(0, 12, "二、平台分布", new_x="LMARGIN", new_y="NEXT")
    pdf.set_font_size(11)

    # Insertion order = order of first appearance in the records.
    platform_stats = {}
    for r in records:
        stats = platform_stats.setdefault(r.platform, {"total": 0, "cited": 0})
        stats["total"] += 1
        if r.cited:
            stats["cited"] += 1

    for platform, stats in platform_stats.items():
        name = PLATFORM_NAMES.get(platform, platform)
        p_rate = f"{stats['cited'] / stats['total'] * 100:.1f}%" if stats['total'] > 0 else "0%"
        pdf.cell(0, 8, f"  {name}: 查询{stats['total']}次, 引用{stats['cited']}次, 引用率{p_rate}", new_x="LMARGIN", new_y="NEXT")
    pdf.ln(10)

    # Detail table
    pdf.set_font_size(16)
    pdf.cell(0, 12, "三、详细数据", new_x="LMARGIN", new_y="NEXT")
    pdf.set_font_size(9)

    col_widths = [30, 25, 20, 20, 15, 80]
    headers = ["查询关键词", "平台", "是否引用", "置信度", "位置", "引用文本"]
    for width, header in zip(col_widths, headers):
        pdf.cell(width, 8, header, border=1, align="C")
    pdf.ln()

    for r in records:
        keyword = r.query.keyword if r.query else ""
        platform_name = PLATFORM_NAMES.get(r.platform, r.platform)
        cited_str = "是" if r.cited else "否"
        conf = f"{r.confidence:.2f}" if r.confidence is not None else "-"
        pos = str(r.citation_position) if r.citation_position is not None else "-"
        # Truncate free text so it fits its fixed-width cell.
        text = (r.citation_text or "")[:40]

        row_data = [keyword[:15], platform_name, cited_str, conf, pos, text]
        for width, value in zip(col_widths, row_data):
            pdf.cell(width, 7, value, border=1)
        pdf.ln()

    # output() may hand back a bytearray; normalise to the annotated type.
    return bytes(pdf.output())
|
||
|
||
|
||
async def export_citations_csv(
    db: AsyncSession,
    user_id: uuid.UUID,
    query_id: uuid.UUID,
) -> str:
    """Render all citation records of one owned query as CSV text.

    The output has a header row, one row per citation record (newest
    first), a blank row, and a summary section with totals, citation rate,
    average citation position and the generation time.

    Args:
        db: Database session.
        user_id: User requesting the export; must own the query.
        query_id: Query whose citation records are exported.

    Returns:
        The CSV document as a string.

    Raises:
        ValueError: If the query does not exist for this user.
    """
    query = await _verify_query_ownership(db, query_id, user_id)
    if query is None:
        raise ValueError("Query not found")

    rows = await db.execute(
        select(CitationRecord)
        .where(CitationRecord.query_id == query_id)
        .order_by(CitationRecord.queried_at.desc())
    )
    records = rows.scalars().all()

    # Display labels for known match types; anything else renders empty.
    match_type_labels = {
        "exact": "精确匹配",
        "alias": "别名匹配",
        "fuzzy": "模糊匹配",
    }

    buffer = io.StringIO()
    sheet = csv.writer(buffer)
    sheet.writerow([
        "查询关键词",
        "目标品牌",
        "查询日期",
        "查询平台",
        "是否引用",
        "引用位置",
        "引用文本",
        "匹配置信度",
        "匹配类型",
        "竞争品牌",
    ])

    for rec in records:
        queried = rec.queried_at.strftime("%Y-%m-%d %H:%M:%S") if rec.queried_at else ""
        confidence = f"{rec.confidence:.2f}" if rec.confidence is not None else ""
        competitors = ", ".join(rec.competitor_brands) if rec.competitor_brands else ""
        position = rec.citation_position if rec.citation_position is not None else ""

        sheet.writerow([
            query.keyword,
            query.target_brand,
            queried,
            PLATFORM_NAMES.get(rec.platform, rec.platform),
            "是" if rec.cited else "否",
            position,
            rec.citation_text or "",
            confidence,
            match_type_labels.get(rec.match_type, ""),
            competitors,
        ])

    # Summary: positions are averaged over cited records that have one.
    cited_records = [rec for rec in records if rec.cited]
    cited_positions = [
        rec.citation_position
        for rec in cited_records
        if rec.citation_position is not None
    ]
    record_total = len(records)
    cited_total = len(cited_records)

    if record_total > 0:
        rate_percent = cited_total / record_total * 100
    else:
        rate_percent = 0.0

    if cited_positions:
        mean_position = sum(cited_positions) / len(cited_positions)
    else:
        mean_position = 0.0

    sheet.writerow([])
    sheet.writerow(["汇总统计"])
    sheet.writerow(["总查询次数", record_total])
    sheet.writerow(["引用次数", cited_total])
    sheet.writerow(["引用率", f"{rate_percent:.1f}%"])
    sheet.writerow(["平均引用位置", f"{mean_position:.1f}"])
    sheet.writerow(["报告生成时间", datetime.now().strftime("%Y-%m-%d %H:%M:%S")])

    return buffer.getvalue()
|