geo/backend/app/services/citation/citation.py

666 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import asyncio
import csv
import io
import logging
import uuid
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from app.services.scoring.scoring_service import ScoringResultV2
from sqlalchemy import func, select, and_, cast, Integer
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.database import AsyncSessionLocal
from app.models.citation_record import CitationRecord
from app.models.query import Query
from app.models.query_task import QueryTask
from app.services.ai_engine.platform_bridge import execute_single_platform as _execute_single_platform_bridge
# Module-scoped logger; its name is this module's import path.
logger = logging.getLogger(name=__name__)
async def _verify_query_ownership(
    db: AsyncSession,
    query_id: uuid.UUID,
    user_id: uuid.UUID,
) -> Query | None:
    """Fetch a query only if it is owned by the given user.

    Args:
        db: Active async database session.
        query_id: Primary key of the query to look up.
        user_id: The user who must own the query.

    Returns:
        The matching ``Query`` row, or ``None`` when no query with this id
        belongs to ``user_id`` (a query owned by someone else is treated
        exactly like a missing one).
    """
    ownership_stmt = (
        select(Query)
        .where(Query.id == query_id)
        .where(Query.user_id == user_id)
    )
    rows = await db.execute(ownership_stmt)
    return rows.scalar_one_or_none()
async def get_citations(
    db: AsyncSession,
    user_id: uuid.UUID,
    query_id: uuid.UUID | None = None,
    platform: str | None = None,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
    skip: int = 0,
    limit: int = 20,
) -> tuple[list[CitationRecord], int]:
    """Return one page of the user's citation records plus the unpaged total.

    Only records whose parent ``Query`` belongs to ``user_id`` are visible.
    Each optional argument narrows the result further; both date bounds are
    inclusive.

    Returns:
        ``(records, total)``: ``records`` is ordered by ``queried_at`` newest
        first and sliced with ``skip``/``limit``; ``total`` counts every match
        before slicing. An unknown or foreign ``query_id`` yields ``([], 0)``.
    """
    filters = [Query.user_id == user_id]
    if query_id is not None:
        filters.append(CitationRecord.query_id == query_id)
        # Explicit ownership check: a foreign query id short-circuits before
        # any citation query is issued.
        if await _verify_query_ownership(db, query_id, user_id) is None:
            return [], 0
    if platform is not None:
        filters.append(CitationRecord.platform == platform)
    if start_date is not None:
        filters.append(CitationRecord.queried_at >= start_date)
    if end_date is not None:
        filters.append(CitationRecord.queried_at <= end_date)

    where_clause = and_(*filters)
    join_on = CitationRecord.query_id == Query.id

    page_stmt = (
        select(CitationRecord)
        .join(Query, join_on)
        .where(where_clause)
        .order_by(CitationRecord.queried_at.desc())
        .offset(skip)
        .limit(limit)
    )
    page = (await db.execute(page_stmt)).scalars().all()

    total_stmt = (
        select(func.count())
        .select_from(CitationRecord)
        .join(Query, join_on)
        .where(where_clause)
    )
    total = (await db.execute(total_stmt)).scalar_one()
    return list(page), total
def _empty_citation_stats() -> dict:
    """Build the all-zero stats payload returned when the filters match nothing."""
    return {
        "total_queries": 0,
        "total_citations": 0,
        "citation_rate": 0.0,
        "avg_position": None,
        "by_platform": {},
        "trend": [],
    }


def _round_avg(value) -> float | None:
    """Round an SQL ``AVG()`` result to one decimal place and return a float.

    Depending on the database/driver, ``AVG()`` may come back as a
    ``Decimal`` (e.g. from a NUMERIC result type) rather than a float.
    Rounding first keeps the value's own rounding semantics; the final
    ``float()`` gives callers the same numeric type on every dialect.
    ``None`` (nothing to average) is passed through unchanged.
    """
    if value is None:
        return None
    return float(round(value, 1))


async def get_citation_stats(
    db: AsyncSession,
    user_id: uuid.UUID,
    query_id: uuid.UUID | None = None,
    brand_id: uuid.UUID | None = None,
) -> dict:
    """
    Get citation statistics for a user's queries.

    Args:
        db: Database session
        user_id: User ID; only citation records of this user's queries count
        query_id: Optional query ID to filter by specific query (must be
            owned by ``user_id``)
        brand_id: Optional brand ID to filter by specific brand (must be
            owned by ``user_id``); queries are matched on
            ``Query.target_brand == Brand.name``
    Returns:
        dict with keys:
        - ``total_queries``: number of matching citation records
        - ``total_citations``: matching records whose ``cited`` is true
        - ``citation_rate``: ``total_citations / total_queries`` (2 decimals)
        - ``avg_position``: mean position of cited records (1 decimal) or None
        - ``by_platform``: per platform ``queries``, ``citations``, ``rate``
          and ``avg_position`` (averaged over all records with a position)
        - ``trend``: cited records per week over the last 30 days
        An unknown brand or an unknown/foreign query yields the all-zero payload.
    """
    base_conditions = [Query.user_id == user_id]

    if brand_id is not None:
        # Imported locally as in the original code (possibly to avoid an
        # import cycle — unverified).
        from app.models.brand import Brand
        brand_stmt = select(Brand.name).where(
            Brand.id == brand_id,
            Brand.user_id == user_id,
        )
        brand_result = await db.execute(brand_stmt)
        brand_name = brand_result.scalar_one_or_none()
        if brand_name is None:
            return _empty_citation_stats()
        # Queries reference their brand by name, not by foreign key.
        base_conditions.append(Query.target_brand == brand_name)

    if query_id is not None:
        query = await _verify_query_ownership(db, query_id, user_id)
        if query is None:
            return _empty_citation_stats()
        base_conditions.append(CitationRecord.query_id == query_id)

    base_where = and_(*base_conditions)
    join_on = CitationRecord.query_id == Query.id

    # Total records and cited records
    total_queries_stmt = (
        select(func.count())
        .select_from(CitationRecord)
        .join(Query, join_on)
        .where(base_where)
    )
    total_queries = (await db.execute(total_queries_stmt)).scalar_one()

    total_citations_stmt = (
        select(func.count())
        .select_from(CitationRecord)
        .join(Query, join_on)
        .where(base_where, CitationRecord.cited.is_(True))
    )
    total_citations = (await db.execute(total_citations_stmt)).scalar_one()
    citation_rate = total_citations / total_queries if total_queries > 0 else 0.0

    # Average position (only for cited records with a position)
    avg_pos_stmt = (
        select(func.avg(CitationRecord.citation_position))
        .select_from(CitationRecord)
        .join(Query, join_on)
        .where(
            base_where,
            CitationRecord.cited.is_(True),
            CitationRecord.citation_position.isnot(None),
        )
    )
    avg_position = _round_avg((await db.execute(avg_pos_stmt)).scalar_one())

    # Per-platform breakdown
    platform_stmt = (
        select(
            CitationRecord.platform,
            func.count().label("queries"),
            func.sum(cast(CitationRecord.cited, Integer)).label("citations"),
            func.avg(CitationRecord.citation_position).label("avg_position"),
        )
        .select_from(CitationRecord)
        .join(Query, join_on)
        .where(base_where)
        .group_by(CitationRecord.platform)
    )
    platform_result = await db.execute(platform_stmt)
    by_platform = {}
    for row in platform_result.all():
        queries = row.queries
        citations = row.citations or 0
        rate = citations / queries if queries > 0 else 0.0
        by_platform[row.platform] = {
            "queries": queries,
            "citations": citations,
            "rate": round(rate, 2),
            "avg_position": _round_avg(row.avg_position),
        }

    # Trend: past 30 days grouped by week. The cutoff is a naive UTC datetime
    # so it is not mixed with the naive datetimes stored in the database;
    # utcnow() is deprecated since Python 3.12, so derive the same value from
    # an aware "now" with the tzinfo stripped.
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    thirty_days_ago = now - timedelta(days=30)

    # Cross-database week grouping expression
    dialect = db.bind.dialect.name if db.bind else "postgresql"
    if dialect == "postgresql":
        week_expr = func.date_trunc("week", CitationRecord.queried_at)
    else:
        # SQLite compatible week grouping; buckets come back as "YYYY-WW" strings.
        week_expr = func.strftime("%Y-%W", CitationRecord.queried_at)

    trend_stmt = (
        select(
            week_expr.label("week_start"),
            func.sum(cast(CitationRecord.cited, Integer)).label("citations"),
        )
        .select_from(CitationRecord)
        .join(Query, join_on)
        .where(
            base_where,
            CitationRecord.queried_at >= thirty_days_ago,
        )
        .group_by(week_expr)
        .order_by(week_expr)
    )
    trend_result = await db.execute(trend_stmt)
    trend = []
    for row in trend_result.all():
        week_start = row.week_start
        if isinstance(week_start, datetime):
            date_str = week_start.date().isoformat()
        else:
            date_str = str(week_start)
        trend.append({
            "date": date_str,
            "citations": int(row.citations or 0),
        })

    return {
        "total_queries": total_queries,
        "total_citations": total_citations,
        "citation_rate": round(citation_rate, 2),
        "avg_position": avg_position,
        "by_platform": by_platform,
        "trend": trend,
    }
# Strong references to fire-and-forget tasks started by trigger_query_now.
# The event loop keeps only weak references to tasks, so a task that nobody
# holds can be garbage-collected before it finishes (see asyncio.create_task).
_background_tasks: set[asyncio.Task] = set()


async def trigger_query_now(
    db: AsyncSession,
    user_id: uuid.UUID,
    query_id: uuid.UUID,
) -> QueryTask:
    """Queue one pending task per configured platform and start running them now.

    Args:
        db: Active async database session; this function commits it.
        user_id: User who must own the query.
        query_id: Query to execute immediately.

    Returns:
        The ``QueryTask`` created for the first configured platform,
        refreshed after commit.

    Raises:
        ValueError: If the query is missing or not owned by ``user_id``, is
            not ``"active"``, or has no platforms configured.
    """
    query = await _verify_query_ownership(db, query_id, user_id)
    if query is None:
        raise ValueError("Query not found")
    if query.status != "active":
        raise ValueError("Query is not active")
    platforms = list(query.platforms or [])
    if not platforms:
        raise ValueError("No platforms configured for this query")

    # Copy everything the background job needs BEFORE committing: if the
    # session expires instances on commit, reading these ORM attributes
    # afterwards would need a lazy load, which AsyncSession cannot do
    # implicitly.
    keyword = query.keyword
    target_brand = query.target_brand
    brand_aliases = list(query.brand_aliases or [])

    tasks = [
        QueryTask(query_id=query_id, platform=platform, status="pending")
        for platform in platforms
    ]
    db.add_all(tasks)
    await db.commit()
    first_task = tasks[0]
    await db.refresh(first_task)

    # Start executing the queued tasks in the background right away.
    runner = asyncio.create_task(
        _execute_query_tasks(
            query_id=query_id,
            platforms=platforms,
            keyword=keyword,
            target_brand=target_brand,
            brand_aliases=brand_aliases,
            user_id=user_id,
        )
    )
    _background_tasks.add(runner)
    runner.add_done_callback(_background_tasks.discard)
    return first_task
async def _execute_query_tasks(
    query_id: uuid.UUID,
    platforms: list,
    keyword: str,
    target_brand: str,
    brand_aliases: list,
    user_id: uuid.UUID | None = None,
):
    """Run the pending query tasks of ``query_id`` in the background.

    Each pending ``QueryTask`` whose platform is in ``platforms`` is executed
    through the Agent framework (``_execute_single_platform_via_agent`` falls
    back to the direct engine on failure). Status is persisted as
    ``running`` and then ``success`` or ``failed``. A ``CitationRecord`` is
    stored only when the platform call returns a truthy result. Uses its own
    ``AsyncSessionLocal`` session. Errors raised while processing tasks are
    logged rather than propagated.

    Args:
        query_id: Query whose pending tasks should run.
        platforms: Platforms to restrict the pending tasks to.
        keyword: Search keyword sent to each platform.
        target_brand: Brand to detect in platform answers.
        brand_aliases: Alternative brand names to detect.
        user_id: When given, the query must belong to this user or nothing runs.
    """
    from app.agent_framework.agents.citation_detector import CitationDetectorAgent

    agent = CitationDetectorAgent()
    try:
        async with AsyncSessionLocal() as db:
            # Make sure the query belongs to this user.
            if user_id is not None:
                query = await _verify_query_ownership(db, query_id, user_id)
                if query is None:
                    logger.error("查询 %s 不属于用户 %s，跳过执行", query_id, user_id)
                    return

            stmt = select(QueryTask).where(
                QueryTask.query_id == query_id,
                QueryTask.status == "pending",
                QueryTask.platform.in_(platforms),
            )
            result = await db.execute(stmt)
            # Snapshot id/platform of every task up front. rollback() in the
            # failure path expires ALL instances in the session; reading an
            # expired attribute afterwards needs an implicit lazy load, which
            # an AsyncSession cannot perform from plain async code. Without
            # the snapshot, one failure broke the error log line and every
            # later task in this loop.
            pending = [
                (task, task.id, task.platform) for task in result.scalars().all()
            ]

            for task, task_id, task_platform in pending:
                try:
                    task.status = "running"
                    task.started_at = datetime.utcnow()
                    task.error_message = None
                    await db.commit()

                    citation_result = await _execute_single_platform_via_agent(
                        agent=agent,
                        keyword=keyword,
                        platform=task_platform,
                        target_brand=target_brand,
                        brand_aliases=brand_aliases or [],
                    )
                    if citation_result:
                        record = CitationRecord.from_citation_result(
                            query_id=query_id,
                            platform=task_platform,
                            result=citation_result,
                        )
                        db.add(record)
                    task.status = "success"
                    task.completed_at = datetime.utcnow()
                    await db.commit()
                except Exception as e:
                    # Discard any partial work of this task before marking it failed.
                    await db.rollback()
                    task.status = "failed"
                    task.error_message = str(e)
                    task.completed_at = datetime.utcnow()
                    await db.commit()
                    logger.exception("查询任务执行失败: %s, 错误: %s", task_id, e)
    except Exception as e:
        logger.exception("查询引擎执行失败: %s", e)
    finally:
        await agent.close()
async def _execute_single_platform_via_agent(
    agent,
    keyword: str,
    platform: str,
    target_brand: str,
    brand_aliases: list,
) -> dict:
    """Detect brand citations on one platform, preferring the Agent framework.

    Calls ``agent.execute_single_platform_compat`` first. If it raises, logs
    a warning and retries the same request through the direct platform
    bridge. Exceptions from the fallback propagate to the caller.

    Args:
        agent: Object exposing an async ``execute_single_platform_compat``.
        keyword: Search keyword sent to the platform.
        platform: Platform code (e.g. a key of ``PLATFORM_NAMES``).
        target_brand: Brand to detect.
        brand_aliases: Alternative names of the brand.

    Returns:
        Whatever the agent (or, on fallback, the bridge) returns.
    """
    request = {
        "keyword": keyword,
        "platform": platform,
        "target_brand": target_brand,
        "brand_aliases": brand_aliases,
    }
    try:
        return await agent.execute_single_platform_compat(**request)
    except Exception as agent_err:
        # Separator between the error text and the fallback notice; the two
        # literals were previously glued together.
        logger.warning(
            "Agent 框架执行单平台检测失败 (%s): %s，回退到直接引擎",
            platform,
            agent_err,
        )
        return await _execute_single_platform_bridge(**request)
# Human-readable display names for platform codes. The export functions in
# this module look codes up here and fall back to the raw code when absent.
PLATFORM_NAMES = dict(
    wenxin="文心一言",
    kimi="Kimi",
    tongyi="通义千问",
    doubao="豆包",
    qingyan="智谱清言",
    tiangong="天工AI",
    xinghuo="讯飞星火",
    baidu_ai="百度AI搜索",
    yuanbao="腾讯元宝",
)
async def export_citations_pdf(
    db: AsyncSession,
    user_id: uuid.UUID,
    query_id: uuid.UUID | None = None,
    v2_result: ScoringResultV2 | None = None,
) -> bytes:
    """Generate a PDF report of the user's citation records.

    Sections: cover, summary statistics, per-platform distribution, a detail
    table and, when ``v2_result`` is given, an extra page with the V2 brand
    visibility score.

    Args:
        db: Active async database session.
        user_id: Owner whose citation records are included.
        query_id: Optional query to restrict the report to; must belong to
            ``user_id``.
        v2_result: Optional V2 scoring result to append.

    Returns:
        The PDF document as ``bytes``.

    Raises:
        ValueError: If ``query_id`` is given but not owned by ``user_id``.
    """
    import os
    from fpdf import FPDF

    # Verify ownership of the query (when a query_id is given).
    if query_id is not None:
        query = await _verify_query_ownership(db, query_id, user_id)
        if query is None:
            raise ValueError("Query not found")

    conditions = [Query.user_id == user_id]
    if query_id is not None:
        conditions.append(CitationRecord.query_id == query_id)

    # Eager-load the parent Query so r.query.keyword below needs no lazy load.
    stmt = (
        select(CitationRecord)
        .options(selectinload(CitationRecord.query))
        .join(Query, CitationRecord.query_id == Query.id)
        .where(and_(*conditions))
        .order_by(CitationRecord.queried_at.desc())
    )
    result = await db.execute(stmt)
    records = result.scalars().all()

    pdf = FPDF()
    pdf.add_page()

    # Load a CJK-capable font from the first known location that exists.
    font_paths = [
        "/System/Library/Fonts/PingFang.ttc",
        "/System/Library/Fonts/STHeiti Light.ttc",
        "/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc",
        "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
    ]
    font_loaded = False
    for fp in font_paths:
        if os.path.exists(fp):
            # fpdf2 embeds TrueType fonts as Unicode on its own; its ``uni``
            # argument is deprecated and only triggers a warning.
            pdf.add_font("Chinese", "", fp)
            pdf.set_font("Chinese", size=12)
            font_loaded = True
            break
    if not font_loaded:
        # Core fonts such as Helvetica cannot encode Chinese characters, so
        # the Chinese labels below will most likely fail to render; make the
        # missing font visible in the logs.
        logger.warning("未找到可用的中文字体，PDF 回退到 Helvetica，中文内容可能无法渲染")
        pdf.set_font("Helvetica", size=12)

    # Cover
    pdf.set_font_size(24)
    pdf.cell(0, 40, "GEO 品牌曝光度分析报告", new_x="LMARGIN", new_y="NEXT", align="C")
    pdf.set_font_size(12)
    pdf.cell(0, 10, f"生成日期: {datetime.now().strftime('%Y-%m-%d %H:%M')}", new_x="LMARGIN", new_y="NEXT", align="C")
    pdf.ln(20)

    # Summary statistics
    total = len(records)
    cited_count = sum(1 for r in records if r.cited)
    rate = f"{cited_count / total * 100:.1f}%" if total > 0 else "0%"
    # Averages every record that has a position, cited or not.
    positions = [r.citation_position for r in records if r.citation_position is not None]
    avg_pos = f"{sum(positions) / len(positions):.1f}" if positions else "-"

    pdf.set_font_size(16)
    pdf.cell(0, 12, "一、汇总统计", new_x="LMARGIN", new_y="NEXT")
    pdf.set_font_size(11)
    pdf.cell(0, 8, f"总查询次数: {total}", new_x="LMARGIN", new_y="NEXT")
    pdf.cell(0, 8, f"引用次数: {cited_count}", new_x="LMARGIN", new_y="NEXT")
    pdf.cell(0, 8, f"引用率: {rate}", new_x="LMARGIN", new_y="NEXT")
    pdf.cell(0, 8, f"平均引用位置: {avg_pos}", new_x="LMARGIN", new_y="NEXT")
    pdf.ln(10)

    # Platform distribution (first-seen order of platforms is kept)
    pdf.set_font_size(16)
    pdf.cell(0, 12, "二、平台分布", new_x="LMARGIN", new_y="NEXT")
    pdf.set_font_size(11)
    platform_stats: dict[str, dict[str, int]] = {}
    for r in records:
        stats = platform_stats.setdefault(r.platform, {"total": 0, "cited": 0})
        stats["total"] += 1
        if r.cited:
            stats["cited"] += 1
    for platform, stats in platform_stats.items():
        name = PLATFORM_NAMES.get(platform, platform)
        p_rate = f"{stats['cited'] / stats['total'] * 100:.1f}%" if stats['total'] > 0 else "0%"
        pdf.cell(0, 8, f" {name}: 查询{stats['total']}次, 引用{stats['cited']}次, 引用率{p_rate}", new_x="LMARGIN", new_y="NEXT")
    pdf.ln(10)

    # Detail table
    pdf.set_font_size(16)
    pdf.cell(0, 12, "三、详细数据", new_x="LMARGIN", new_y="NEXT")
    pdf.set_font_size(9)
    col_widths = [30, 25, 20, 20, 15, 80]
    headers = ["查询关键词", "平台", "是否引用", "置信度", "位置", "引用文本"]
    for width, header in zip(col_widths, headers):
        pdf.cell(width, 8, header, border=1, align="C")
    pdf.ln()
    for r in records:
        keyword = r.query.keyword if r.query else ""
        platform_name = PLATFORM_NAMES.get(r.platform, r.platform)
        # Previously both branches were "", leaving the "是否引用" column blank.
        cited_str = "是" if r.cited else "否"
        conf = f"{r.confidence:.2f}" if r.confidence is not None else "-"
        pos = str(r.citation_position) if r.citation_position is not None else "-"
        text = (r.citation_text or "")[:40]
        row_data = [keyword[:15], platform_name, cited_str, conf, pos, text]
        for width, cell_text in zip(col_widths, row_data):
            pdf.cell(width, 7, cell_text, border=1)
        pdf.ln()

    # Optional V2 brand visibility score page
    if v2_result is not None:
        pdf.add_page()
        pdf.set_font_size(16)
        pdf.cell(0, 12, "四、V2 品牌可见性评分", new_x="LMARGIN", new_y="NEXT")
        pdf.set_font_size(11)
        pdf.cell(0, 8, f"综合评分: {v2_result.overall_score:.2f}/100", new_x="LMARGIN", new_y="NEXT")
        pdf.cell(0, 8, f"健康等级: {v2_result.health_level}", new_x="LMARGIN", new_y="NEXT")
        pdf.ln(5)
        components = [
            ("提及率", v2_result.mention_rate),
            ("推荐排名", v2_result.recommendation_rank),
            ("情感倾向", v2_result.sentiment_score),
            ("引用质量", v2_result.citation_quality),
            ("竞品对比", v2_result.competitive_position),
        ]
        for label, comp in components:
            pdf.cell(0, 8, f"{label}: {comp.score:.2f}/{comp.max_score:.0f} ({comp.percentage:.1f}%)", new_x="LMARGIN", new_y="NEXT")

    # fpdf2 returns a bytearray; convert to match the declared return type.
    return bytes(pdf.output())
# Display labels for CitationRecord.match_type values in CSV exports; any
# other value (including None) is exported as an empty string.
_MATCH_TYPE_LABELS = {
    "exact": "精确匹配",
    "alias": "别名匹配",
    "fuzzy": "模糊匹配",
}


async def export_citations_csv(
    db: AsyncSession,
    user_id: uuid.UUID,
    query_id: uuid.UUID,
    v2_result: ScoringResultV2 | None = None,
) -> str:
    """Export a query's citation records as CSV text with a summary footer.

    One row per citation record (newest first). When ``v2_result`` is given,
    its scores are appended as extra columns on every row. A blank row and
    summary statistics (totals, rate, average position, generation time)
    follow the data.

    Args:
        db: Active async database session.
        user_id: User who must own the query.
        query_id: Query whose records are exported.
        v2_result: Optional V2 scoring result to include per row.

    Returns:
        The CSV document as a string.

    Raises:
        ValueError: If the query does not exist or is not owned by ``user_id``.
    """
    query = await _verify_query_ownership(db, query_id, user_id)
    if query is None:
        raise ValueError("Query not found")

    # Ownership was verified above, so filtering by query_id alone is safe.
    stmt = (
        select(CitationRecord)
        .where(CitationRecord.query_id == query_id)
        .order_by(CitationRecord.queried_at.desc())
    )
    result = await db.execute(stmt)
    records = result.scalars().all()

    output = io.StringIO()
    writer = csv.writer(output)
    headers = [
        "查询关键词",
        "目标品牌",
        "查询日期",
        "查询平台",
        "是否引用",
        "引用位置",
        "引用文本",
        "匹配置信度",
        "匹配类型",
        "竞争品牌",
    ]
    # V2 scores are identical for every row, so compute them once.
    v2_values = []
    if v2_result is not None:
        headers.extend([
            "overall_score",
            "health_level",
            "mention_rate",
            "recommendation_rank",
            "sentiment_score",
            "citation_quality",
            "competitive_position",
        ])
        v2_values = [
            round(v2_result.overall_score, 2),
            v2_result.health_level,
            round(v2_result.mention_rate.score, 2),
            round(v2_result.recommendation_rank.score, 2),
            round(v2_result.sentiment_score.score, 2),
            round(v2_result.citation_quality.score, 2),
            round(v2_result.competitive_position.score, 2),
        ]
    writer.writerow(headers)

    total_queries = len(records)
    total_citations = 0
    total_position = 0
    position_count = 0
    for record in records:
        if record.cited:
            total_citations += 1
        # Averages every record that has a position, cited or not.
        if record.citation_position is not None:
            total_position += record.citation_position
            position_count += 1

        date_str = ""
        if record.queried_at:
            date_str = record.queried_at.strftime("%Y-%m-%d %H:%M:%S")
        platform_name = PLATFORM_NAMES.get(record.platform, record.platform)
        match_type_display = _MATCH_TYPE_LABELS.get(record.match_type, "")
        confidence_str = ""
        if record.confidence is not None:
            confidence_str = f"{record.confidence:.2f}"

        row = [
            query.keyword,
            query.target_brand,
            date_str,
            platform_name,
            # Previously both branches were "", leaving the "是否引用" column blank.
            "是" if record.cited else "否",
            record.citation_position if record.citation_position is not None else "",
            record.citation_text or "",
            confidence_str,
            match_type_display,
            ", ".join(record.competitor_brands) if record.competitor_brands else "",
        ]
        row.extend(v2_values)
        writer.writerow(row)

    # Summary statistics
    writer.writerow([])
    writer.writerow(["汇总统计"])
    writer.writerow(["总查询次数", total_queries])
    writer.writerow(["引用次数", total_citations])
    citation_rate = (total_citations / total_queries * 100) if total_queries > 0 else 0.0
    writer.writerow(["引用率", f"{citation_rate:.1f}%"])
    avg_position = (total_position / position_count) if position_count > 0 else 0.0
    writer.writerow(["平均引用位置", f"{avg_position:.1f}"])
    writer.writerow(["报告生成时间", datetime.now().strftime("%Y-%m-%d %H:%M:%S")])
    return output.getvalue()