geo/backend/app/services/citation.py

429 lines
14 KiB
Python

import asyncio
import csv
import io
import logging
import uuid
from datetime import datetime, timedelta, timezone
from sqlalchemy import func, select, and_, cast, Integer
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import AsyncSessionLocal
from app.models.citation_record import CitationRecord
from app.models.query import Query
from app.models.query_task import QueryTask
from app.workers.citation_engine import CitationEngine
# Module-level logger named after this module; the background task runner logs its failures here.
logger = logging.getLogger(__name__)
async def _verify_query_ownership(
    db: AsyncSession,
    query_id: uuid.UUID,
    user_id: uuid.UUID,
) -> Query | None:
    """Load the Query ``query_id`` only if it is owned by ``user_id``.

    Returns:
        The matching ``Query`` row, or ``None`` when no query with that id
        belongs to the given user.
    """
    owned_query_stmt = select(Query).where(
        Query.id == query_id,
        Query.user_id == user_id,
    )
    return (await db.execute(owned_query_stmt)).scalar_one_or_none()
async def get_citations(
    db: AsyncSession,
    user_id: uuid.UUID,
    query_id: uuid.UUID | None = None,
    platform: str | None = None,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
    skip: int = 0,
    limit: int = 20,
) -> tuple[list[CitationRecord], int]:
    """Return one page of the user's citation records and the unpaginated total.

    Records are restricted to queries owned by ``user_id`` (through a join on
    ``Query``). They can optionally be narrowed to one query or platform and to
    an inclusive ``queried_at`` window. Rows are ordered newest first.

    Returns:
        ``(records, total)`` where ``total`` counts every matching record
        regardless of ``skip``/``limit``; ``([], 0)`` when ``query_id`` is given
        but is not owned by ``user_id``.
    """
    if query_id is not None:
        # Explicit ownership check before any citation lookup runs.
        owned = await _verify_query_ownership(db, query_id, user_id)
        if owned is None:
            return [], 0

    # (value, clause-builder) pairs; a clause is added only when its value is set.
    optional_filters = (
        (query_id, lambda v: CitationRecord.query_id == v),
        (platform, lambda v: CitationRecord.platform == v),
        (start_date, lambda v: CitationRecord.queried_at >= v),
        (end_date, lambda v: CitationRecord.queried_at <= v),
    )
    where_clause = and_(
        Query.user_id == user_id,
        *(build(value) for value, build in optional_filters if value is not None),
    )

    def scoped(stmt):
        # Join to Query so the ownership condition on Query.user_id can apply.
        return stmt.join(Query, CitationRecord.query_id == Query.id).where(where_clause)

    page_stmt = (
        scoped(select(CitationRecord))
        .order_by(CitationRecord.queried_at.desc())
        .offset(skip)
        .limit(limit)
    )
    count_stmt = scoped(select(func.count()).select_from(CitationRecord))

    page = (await db.execute(page_stmt)).scalars().all()
    total = (await db.execute(count_stmt)).scalar_one()
    return list(page), total
def _stats_round(value, ndigits: int = 1) -> float | None:
    """Round a SQL aggregate (int, float or Decimal) to a plain float; ``None`` passes through."""
    # PostgreSQL AVG() returns Decimal; normalise so callers always get float.
    return round(float(value), ndigits) if value is not None else None


def _empty_citation_stats() -> dict:
    """Stats payload returned when the requested query is not visible to the user."""
    return {
        "total_queries": 0,
        "total_citations": 0,
        "citation_rate": 0.0,
        "avg_position": None,
        "by_platform": {},
        "trend": [],
    }


def _week_start_expr(db: AsyncSession):
    """SQL expression mapping ``CitationRecord.queried_at`` to the Monday starting its week.

    On PostgreSQL, ``date_trunc('week', ...)`` truncates to the Monday of the ISO
    week. On other dialects (written for SQLite), the same Monday is produced as a
    ``YYYY-MM-DD`` string with ``date(x, 'weekday 0', '-6 days')``: it advances to
    the next Sunday (or stays on a Sunday), then steps back six days. Both
    backends therefore yield the same week-start date.
    """
    dialect = db.bind.dialect.name if db.bind else "postgresql"
    if dialect == "postgresql":
        return func.date_trunc("week", CitationRecord.queried_at)
    return func.date(CitationRecord.queried_at, "weekday 0", "-6 days")


async def get_citation_stats(
    db: AsyncSession,
    user_id: uuid.UUID,
    query_id: uuid.UUID | None = None,
) -> dict:
    """Aggregate citation statistics over the user's queries (optionally a single query).

    Returns a dict with:
        total_queries: number of citation records (one per platform run).
        total_citations: records whose ``cited`` flag is true.
        citation_rate: ``total_citations / total_queries`` rounded to 2 places.
        avg_position: mean ``citation_position`` of cited records (1 place), or None.
        by_platform: per-platform ``queries``/``citations``/``rate``/``avg_position``.
        trend: per-week cited counts for the last 30 days, ``{"date", "citations"}``
            items ordered by week, where ``date`` is the week's Monday (YYYY-MM-DD).

    If ``query_id`` is given but not owned by ``user_id``, an all-zero payload is returned.
    """
    if query_id is not None and await _verify_query_ownership(db, query_id, user_id) is None:
        return _empty_citation_stats()

    base_conditions = [Query.user_id == user_id]
    if query_id is not None:
        base_conditions.append(CitationRecord.query_id == query_id)
    base_where = and_(*base_conditions)
    cited_as_int = cast(CitationRecord.cited, Integer)

    # Totals in a single round-trip: count all records and sum the cited flag.
    # (SUM skips NULLs, matching the previous ``cited IS true`` count.)
    totals_stmt = (
        select(
            func.count().label("record_count"),
            func.sum(cited_as_int).label("cited_count"),
        )
        .select_from(CitationRecord)
        .join(Query, CitationRecord.query_id == Query.id)
        .where(base_where)
    )
    totals = (await db.execute(totals_stmt)).one()
    total_queries = totals.record_count
    total_citations = int(totals.cited_count or 0)
    citation_rate = total_citations / total_queries if total_queries > 0 else 0.0

    # Average position, only over cited records that actually have a position.
    avg_pos_stmt = (
        select(func.avg(CitationRecord.citation_position))
        .select_from(CitationRecord)
        .join(Query, CitationRecord.query_id == Query.id)
        .where(
            base_where,
            CitationRecord.cited.is_(True),
            CitationRecord.citation_position.isnot(None),
        )
    )
    avg_position = _stats_round((await db.execute(avg_pos_stmt)).scalar_one())

    # Per-platform breakdown.
    platform_stmt = (
        select(
            CitationRecord.platform,
            func.count().label("queries"),
            func.sum(cited_as_int).label("citations"),
            func.avg(CitationRecord.citation_position).label("avg_position"),
        )
        .join(Query, CitationRecord.query_id == Query.id)
        .where(base_where)
        .group_by(CitationRecord.platform)
    )
    by_platform = {}
    for row in (await db.execute(platform_stmt)).all():
        citations = int(row.citations or 0)
        by_platform[row.platform] = {
            "queries": row.queries,
            "citations": citations,
            "rate": round(citations / row.queries, 2) if row.queries > 0 else 0.0,
            "avg_position": _stats_round(row.avg_position),
        }

    # Trend: past 30 days grouped by week. Naive UTC "now" is used because the
    # stored timestamps are compared as naive datetimes.
    thirty_days_ago = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(days=30)
    week_expr = _week_start_expr(db)
    trend_stmt = (
        select(
            week_expr.label("week_start"),
            func.sum(cited_as_int).label("citations"),
        )
        .join(Query, CitationRecord.query_id == Query.id)
        .where(base_where, CitationRecord.queried_at >= thirty_days_ago)
        .group_by(week_expr)
        .order_by(week_expr)
    )
    trend_rows = (await db.execute(trend_stmt)).all()
    trend = [
        {
            # PostgreSQL returns a datetime; SQLite already returns "YYYY-MM-DD".
            "date": (
                row.week_start.date().isoformat()
                if isinstance(row.week_start, datetime)
                else str(row.week_start)
            ),
            "citations": int(row.citations or 0),
        }
        for row in trend_rows
    ]

    return {
        "total_queries": total_queries,
        "total_citations": total_citations,
        "citation_rate": round(citation_rate, 2),
        "avg_position": avg_position,
        "by_platform": by_platform,
        "trend": trend,
    }
# Strong references to in-flight background query runs. The event loop keeps only
# weak references to tasks, so a task nobody references can be garbage-collected
# before it finishes. Each task removes itself from this set when done.
_BACKGROUND_QUERY_TASKS: set[asyncio.Task] = set()


async def trigger_query_now(
    db: AsyncSession,
    user_id: uuid.UUID,
    query_id: uuid.UUID,
) -> QueryTask:
    """Queue one pending QueryTask per configured platform and start running them now.

    The tasks are committed, then executed in a background asyncio task via
    ``_execute_query_tasks``; this coroutine does not wait for them.

    Returns:
        The first created ``QueryTask`` (refreshed after commit).

    Raises:
        ValueError: if the query is not found / not owned by ``user_id``, is not
            ``"active"``, or has no platforms configured.
    """
    query = await _verify_query_ownership(db, query_id, user_id)
    if query is None:
        raise ValueError("Query not found")
    if query.status != "active":
        raise ValueError("Query is not active")
    platforms = list(query.platforms or [])
    if not platforms:
        raise ValueError("No platforms configured for this query")

    # Snapshot what the background job needs *before* committing: a commit may
    # expire ORM attributes (expire_on_commit), and implicitly reloading them from
    # async code is not allowed with AsyncSession.
    keyword = query.keyword
    target_brand = query.target_brand
    brand_aliases = list(query.brand_aliases or [])

    tasks = [
        QueryTask(query_id=query_id, platform=platform, status="pending")
        for platform in platforms
    ]
    db.add_all(tasks)
    await db.commit()
    first_task = tasks[0]
    await db.refresh(first_task)

    # Run the queries immediately in the background, keeping a reference so the
    # task cannot be garbage-collected mid-execution.
    background = asyncio.create_task(
        _execute_query_tasks(
            query_id=query_id,
            platforms=platforms,
            keyword=keyword,
            target_brand=target_brand,
            brand_aliases=brand_aliases,
        )
    )
    _BACKGROUND_QUERY_TASKS.add(background)
    background.add_done_callback(_BACKGROUND_QUERY_TASKS.discard)
    return first_task
async def _execute_query_tasks(
    query_id: uuid.UUID,
    platforms: list,
    keyword: str,
    target_brand: str,
    brand_aliases: list,
):
    """Execute the pending QueryTask rows of ``query_id`` (meant to run in the background).

    Uses its own session from ``AsyncSessionLocal`` rather than a caller's session.
    Every ``pending`` task of the query whose platform is in ``platforms`` goes
    through these steps:

    1. It is marked ``running``.
    2. ``CitationEngine.execute_single_platform`` runs for its platform.
    3. A ``CitationRecord`` is stored when the engine returns a non-empty result.
    4. The task is marked ``success``.

    Any exception marks that task ``failed`` with the error text. Processing
    then continues with the next task.

    Exceptions are logged rather than propagated (apart from anything raised by
    ``engine.close()``). The engine is always closed.
    """

    def _utc_now_naive() -> datetime:
        # Naive UTC timestamp, equivalent to the deprecated datetime.utcnow().
        return datetime.now(timezone.utc).replace(tzinfo=None)

    engine = CitationEngine()
    try:
        async with AsyncSessionLocal() as db:
            stmt = select(QueryTask).where(
                QueryTask.query_id == query_id,
                QueryTask.status == "pending",
                QueryTask.platform.in_(platforms),
            )
            tasks = (await db.execute(stmt)).scalars().all()
            for task in tasks:
                # Snapshot identifiers up front. rollback() expires every instance in
                # the session, and re-reading an expired attribute from async code may
                # trigger an implicit load that AsyncSession forbids, so the error
                # path must not depend on reloading them.
                task_id = task.id
                platform = task.platform
                try:
                    task.status = "running"
                    task.started_at = _utc_now_naive()
                    task.error_message = None
                    await db.commit()
                    citation_result = await engine.execute_single_platform(
                        keyword=keyword,
                        platform=platform,
                        target_brand=target_brand,
                        brand_aliases=brand_aliases or [],
                    )
                    if citation_result:
                        db.add(
                            CitationRecord(
                                query_id=query_id,
                                platform=platform,
                                cited=citation_result.get("cited", False),
                                citation_position=citation_result.get("position"),
                                citation_text=citation_result.get("citation_text"),
                                competitor_brands=citation_result.get("competitor_brands", []),
                                raw_response=citation_result.get("raw_response", ""),
                                confidence=citation_result.get("confidence"),
                                match_type=citation_result.get("match_type"),
                            )
                        )
                    task.status = "success"
                    task.completed_at = _utc_now_naive()
                    await db.commit()
                except Exception as exc:
                    logger.exception("查询任务执行失败: %s, 错误: %s", task_id, exc)
                    try:
                        await db.rollback()
                        task.status = "failed"
                        task.error_message = str(exc)
                        task.completed_at = _utc_now_naive()
                        await db.commit()
                    except Exception:
                        # Persisting the failure itself failed; log it and keep going
                        # so the remaining platforms still run.
                        logger.exception("任务失败状态保存失败: %s", task_id)
                        await db.rollback()
    except Exception as exc:
        logger.exception("查询引擎执行失败: %s", exc)
    finally:
        await engine.close()
# Human-readable (Chinese) display names keyed by platform identifier; used to
# label the platform column of the CSV export. Unknown identifiers fall back to
# the raw identifier at the call site.
PLATFORM_NAMES = dict(
    wenxin="文心一言",
    kimi="Kimi",
    tongyi="通义千问",
    doubao="豆包",
    qingyan="智谱清言",
    tiangong="天工AI",
    xinghuo="讯飞星火",
    baidu_ai="百度AI搜索",
    yuanbao="腾讯元宝",
)
# Display labels for CitationRecord.match_type in the CSV export.
_EXPORT_MATCH_TYPE_LABELS = {
    "exact": "精确匹配",
    "alias": "别名匹配",
    "fuzzy": "模糊匹配",
}

# Leading characters that spreadsheet applications interpret as a formula.
_CSV_FORMULA_PREFIXES = ("=", "+", "-", "@", "\t", "\r")


def _csv_safe_text(value: str) -> str:
    """Neutralise CSV/formula injection by prefixing formula-like text with a quote."""
    if value and value.startswith(_CSV_FORMULA_PREFIXES):
        return "'" + value
    return value


async def export_citations_csv(
    db: AsyncSession,
    user_id: uuid.UUID,
    query_id: uuid.UUID,
) -> str:
    """Render every citation record of one query as a CSV report string.

    The output has one row per record (newest first), then a blank row, then a
    summary section. The summary holds the total runs, the cited count, the
    citation rate, the average citation position and the generation time.

    Raises:
        ValueError: if the query does not exist or is not owned by ``user_id``.
    """
    query = await _verify_query_ownership(db, query_id, user_id)
    if query is None:
        raise ValueError("Query not found")
    stmt = (
        select(CitationRecord)
        .where(CitationRecord.query_id == query_id)
        .order_by(CitationRecord.queried_at.desc())
    )
    records = (await db.execute(stmt)).scalars().all()

    output = io.StringIO()
    writer = csv.writer(output)
    writer.writerow([
        "查询关键词",
        "目标品牌",
        "查询日期",
        "查询平台",
        "是否引用",
        "引用位置",
        "引用文本",
        "匹配置信度",
        "匹配类型",
        "竞争品牌",
    ])

    total_citations = 0
    total_position = 0
    position_count = 0
    for record in records:
        if record.cited:
            total_citations += 1
            if record.citation_position is not None:
                total_position += record.citation_position
                position_count += 1

        date_str = (
            record.queried_at.strftime("%Y-%m-%d %H:%M:%S") if record.queried_at else ""
        )
        confidence_str = (
            f"{record.confidence:.2f}" if record.confidence is not None else ""
        )
        competitors = (
            ", ".join(str(brand) for brand in record.competitor_brands)
            if record.competitor_brands
            else ""
        )
        writer.writerow([
            query.keyword,
            query.target_brand,
            date_str,
            PLATFORM_NAMES.get(record.platform, record.platform),
            # Previously both branches were "", leaving this column always blank.
            "是" if record.cited else "否",
            record.citation_position if record.citation_position is not None else "",
            # citation_text / competitor names come from engine output: escape them.
            _csv_safe_text(record.citation_text or ""),
            confidence_str,
            _EXPORT_MATCH_TYPE_LABELS.get(record.match_type, ""),
            _csv_safe_text(competitors),
        ])

    # 汇总统计 (summary section)
    total_queries = len(records)
    writer.writerow([])
    writer.writerow(["汇总统计"])
    writer.writerow(["总查询次数", total_queries])
    writer.writerow(["引用次数", total_citations])
    citation_rate = (total_citations / total_queries * 100) if total_queries > 0 else 0.0
    writer.writerow(["引用率", f"{citation_rate:.1f}%"])
    # Leave the cell blank when no position was recorded; "0.0" would look like real data.
    avg_position_str = f"{total_position / position_count:.1f}" if position_count > 0 else ""
    writer.writerow(["平均引用位置", avg_position_str])
    writer.writerow(["报告生成时间", datetime.now().strftime("%Y-%m-%d %H:%M:%S")])
    return output.getvalue()