429 lines
14 KiB
Python
429 lines
14 KiB
Python
import asyncio
|
|
import csv
|
|
import io
|
|
import logging
|
|
import uuid
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
from sqlalchemy import func, select, and_, cast, Integer
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.database import AsyncSessionLocal
|
|
from app.models.citation_record import CitationRecord
|
|
from app.models.query import Query
|
|
from app.models.query_task import QueryTask
|
|
from app.workers.citation_engine import CitationEngine
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def _verify_query_ownership(
    db: AsyncSession,
    query_id: uuid.UUID,
    user_id: uuid.UUID,
) -> Query | None:
    """Fetch the query with ``query_id`` only if it belongs to ``user_id``.

    Returns the matching ``Query`` row, or ``None`` when no row has both
    that id and that owner.
    """
    owned_query_stmt = (
        select(Query)
        .where(Query.id == query_id)
        .where(Query.user_id == user_id)
    )
    rows = await db.execute(owned_query_stmt)
    return rows.scalar_one_or_none()
|
|
|
|
|
|
async def get_citations(
    db: AsyncSession,
    user_id: uuid.UUID,
    query_id: uuid.UUID | None = None,
    platform: str | None = None,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
    skip: int = 0,
    limit: int = 20,
) -> tuple[list[CitationRecord], int]:
    """Return one page of the user's citation records plus the total match count.

    Records are restricted to queries owned by ``user_id`` and optionally
    narrowed by query, platform and an inclusive ``queried_at`` range. The
    page is ordered newest first. When ``query_id`` is given but is not owned
    by the user, the result is ``([], 0)``.
    """
    # Every filter is ANDed together; ownership is always the first one.
    filters = [Query.user_id == user_id]

    if query_id is not None:
        # Explicit ownership check in addition to the join-based filter.
        if await _verify_query_ownership(db, query_id, user_id) is None:
            return [], 0
        filters.append(CitationRecord.query_id == query_id)
    if platform is not None:
        filters.append(CitationRecord.platform == platform)
    if start_date is not None:
        filters.append(CitationRecord.queried_at >= start_date)
    if end_date is not None:
        filters.append(CitationRecord.queried_at <= end_date)

    ownership_join = CitationRecord.query_id == Query.id

    # Page of records, newest first.
    page_stmt = select(CitationRecord).join(Query, ownership_join)
    page_stmt = page_stmt.where(and_(*filters))
    page_stmt = page_stmt.order_by(CitationRecord.queried_at.desc())
    page_stmt = page_stmt.offset(skip).limit(limit)
    page_rows = await db.execute(page_stmt)
    records = page_rows.scalars().all()

    # Total number of records matching the same filters, ignoring pagination.
    total_stmt = select(func.count()).select_from(CitationRecord)
    total_stmt = total_stmt.join(Query, ownership_join).where(and_(*filters))
    total_rows = await db.execute(total_stmt)
    match_count = total_rows.scalar_one()

    return [*records], match_count
|
|
|
|
|
|
def _rounded_float(value, ndigits: int) -> float | None:
    """Round an aggregate to ``ndigits`` as a plain float; pass ``None`` through.

    SQL AVG may come back as ``Decimal`` or ``float`` depending on the driver;
    converting first keeps the returned statistics one consistent numeric type.
    """
    if value is None:
        return None
    return round(float(value), ndigits)


async def get_citation_stats(
    db: AsyncSession,
    user_id: uuid.UUID,
    query_id: uuid.UUID | None = None,
) -> dict:
    """Aggregate citation statistics for a user, optionally for a single query.

    Args:
        db: Async session used for all aggregate queries.
        user_id: Owner whose queries are included.
        query_id: When given, restrict the statistics to this query. If the
            query is not owned by ``user_id`` an all-zero result is returned.

    Returns:
        A dict with the keys:

        * ``total_queries`` - number of citation records matched (each record
          is one executed lookup, not one ``Query`` row).
        * ``total_citations`` - matched records whose ``cited`` flag is true.
        * ``citation_rate`` - ``total_citations / total_queries`` rounded to
          two decimals, ``0.0`` when there are no records.
        * ``avg_position`` - mean ``citation_position`` over cited records
          that have a position, rounded to one decimal, or ``None``.
        * ``by_platform`` - per-platform ``queries``, ``citations``, ``rate``
          and ``avg_position``.
        * ``trend`` - weekly citation counts for the last 30 days as
          ``{"date": ..., "citations": ...}`` entries.
    """
    base_conditions = [Query.user_id == user_id]
    if query_id is not None:
        # Verify ownership before computing anything for an explicit query.
        if await _verify_query_ownership(db, query_id, user_id) is None:
            return {
                "total_queries": 0,
                "total_citations": 0,
                "citation_rate": 0.0,
                "avg_position": None,
                "by_platform": {},
                "trend": [],
            }
        base_conditions.append(CitationRecord.query_id == query_id)

    base_where = and_(*base_conditions)

    # Total number of citation records in scope.
    total_queries_stmt = (
        select(func.count())
        .select_from(CitationRecord)
        .join(Query, CitationRecord.query_id == Query.id)
        .where(base_where)
    )
    total_queries_result = await db.execute(total_queries_stmt)
    total_queries = total_queries_result.scalar_one()

    # Records in scope where the brand was actually cited.
    total_citations_stmt = (
        select(func.count())
        .select_from(CitationRecord)
        .join(Query, CitationRecord.query_id == Query.id)
        .where(base_where, CitationRecord.cited.is_(True))
    )
    total_citations_result = await db.execute(total_citations_stmt)
    total_citations = total_citations_result.scalar_one()

    citation_rate = total_citations / total_queries if total_queries > 0 else 0.0

    # Average position (only for cited records with a position).
    avg_pos_stmt = (
        select(func.avg(CitationRecord.citation_position))
        .join(Query, CitationRecord.query_id == Query.id)
        .where(
            base_where,
            CitationRecord.cited.is_(True),
            CitationRecord.citation_position.isnot(None),
        )
    )
    avg_pos_result = await db.execute(avg_pos_stmt)
    avg_position = _rounded_float(avg_pos_result.scalar_one(), 1)

    # Per-platform breakdown. Unlike the overall average, this AVG has no
    # ``cited`` filter; it covers every row in scope that has a position.
    platform_stmt = (
        select(
            CitationRecord.platform,
            func.count().label("queries"),
            func.sum(cast(CitationRecord.cited, Integer)).label("citations"),
            func.avg(CitationRecord.citation_position).label("avg_position"),
        )
        .join(Query, CitationRecord.query_id == Query.id)
        .where(base_where)
        .group_by(CitationRecord.platform)
    )
    platform_result = await db.execute(platform_stmt)
    by_platform = {}
    for row in platform_result.all():
        queries = row.queries
        # SUM is NULL for an empty group; coerce to int like the trend does.
        citations = int(row.citations or 0)
        rate = citations / queries if queries > 0 else 0.0
        by_platform[row.platform] = {
            "queries": queries,
            "citations": citations,
            "rate": round(rate, 2),
            "avg_position": _rounded_float(row.avg_position, 1),
        }

    # Trend: past 30 days grouped by week.
    # A naive UTC timestamp is used on purpose so the comparison does not mix
    # aware values with naive datetimes from the database. This is the
    # non-deprecated equivalent of ``datetime.utcnow()``.
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    thirty_days_ago = now - timedelta(days=30)

    # Cross-database week grouping expression; PostgreSQL is assumed when the
    # session has no bind to inspect.
    dialect = db.bind.dialect.name if db.bind else "postgresql"
    if dialect == "postgresql":
        week_expr = func.date_trunc("week", CitationRecord.queried_at)
    else:
        # SQLite compatible week grouping (YYYY-WW format).
        week_expr = func.strftime("%Y-%W", CitationRecord.queried_at)

    trend_stmt = (
        select(
            week_expr.label("week_start"),
            func.sum(cast(CitationRecord.cited, Integer)).label("citations"),
        )
        .join(Query, CitationRecord.query_id == Query.id)
        .where(
            base_where,
            CitationRecord.queried_at >= thirty_days_ago,
        )
        .group_by(week_expr)
        .order_by(week_expr)
    )
    trend_result = await db.execute(trend_stmt)
    trend = []
    for row in trend_result.all():
        week_start = row.week_start
        # A datetime bucket is reported as an ISO date; any other bucket
        # value (e.g. the strftime string) is reported via str().
        if isinstance(week_start, datetime):
            date_str = week_start.date().isoformat()
        else:
            date_str = str(week_start)
        trend.append({
            "date": date_str,
            "citations": int(row.citations or 0),
        })

    return {
        "total_queries": total_queries,
        "total_citations": total_citations,
        "citation_rate": round(citation_rate, 2),
        "avg_position": avg_position,
        "by_platform": by_platform,
        "trend": trend,
    }
|
|
|
|
|
|
# Strong references to in-flight background runs. The event loop keeps only
# weak references to tasks, so a task nobody else references can be
# garbage-collected before it finishes.
_background_tasks: set[asyncio.Task] = set()


async def trigger_query_now(
    db: AsyncSession,
    user_id: uuid.UUID,
    query_id: uuid.UUID,
) -> QueryTask:
    """Create pending tasks for every platform of a query and start them now.

    One ``QueryTask`` with status ``"pending"`` is created per configured
    platform and committed; execution is then scheduled in the background on
    the running event loop without being awaited.

    Args:
        db: Async session used to load the query and persist the tasks.
        user_id: Owner the query must belong to.
        query_id: Query to trigger.

    Returns:
        The task created for the first configured platform, refreshed after
        the commit.

    Raises:
        ValueError: If the query does not exist for this user, is not in the
            ``"active"`` status, or has no platforms configured.
    """
    query = await _verify_query_ownership(db, query_id, user_id)
    if query is None:
        raise ValueError("Query not found")

    if query.status != "active":
        raise ValueError("Query is not active")

    platforms = query.platforms or []
    if not platforms:
        raise ValueError("No platforms configured for this query")

    # Snapshot everything the background run needs before committing. If the
    # session expires objects on commit, reading these attributes afterwards
    # would trigger an implicit reload on an async session.
    keyword = query.keyword
    target_brand = query.target_brand
    brand_aliases = query.brand_aliases or []

    first_task = None
    for platform in platforms:
        task = QueryTask(
            query_id=query_id,
            platform=platform,
            status="pending",
        )
        db.add(task)
        if first_task is None:
            first_task = task

    await db.commit()
    # ``platforms`` is non-empty here, so ``first_task`` is always set.
    await db.refresh(first_task)

    # Start executing the new tasks in the background immediately. Keep a
    # reference until completion so the run is not collected mid-execution.
    background = asyncio.create_task(
        _execute_query_tasks(
            query_id=query_id,
            platforms=platforms,
            keyword=keyword,
            target_brand=target_brand,
            brand_aliases=brand_aliases,
        )
    )
    _background_tasks.add(background)
    background.add_done_callback(_background_tasks.discard)

    return first_task
|
|
|
|
|
|
def _naive_utcnow() -> datetime:
    """Return the current UTC time as a naive datetime.

    Non-deprecated equivalent of ``datetime.utcnow()``; the naive form is
    kept because that is what this module has always written to the
    timestamp columns.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


async def _execute_query_tasks(
    query_id: uuid.UUID,
    platforms: list,
    keyword: str,
    target_brand: str,
    brand_aliases: list,
) -> None:
    """Run the pending tasks of a query in the background.

    Opens its own session, loads the ``"pending"`` tasks of ``query_id`` for
    the given platforms and processes them one after another: each task is
    marked ``"running"``, the citation engine is called for its platform, a
    ``CitationRecord`` is stored when the engine returns a truthy result, and
    the task is marked ``"success"``. An error inside one task marks that task
    ``"failed"`` with the error text and processing continues with the next
    one. Errors outside the per-task handling are logged and swallowed. The
    engine is always closed at the end.
    """
    engine = CitationEngine()
    try:
        async with AsyncSessionLocal() as db:
            stmt = select(QueryTask).where(
                QueryTask.query_id == query_id,
                QueryTask.status == "pending",
                QueryTask.platform.in_(platforms),
            )
            result = await db.execute(stmt)
            tasks = result.scalars().all()

            for task in tasks:
                # Read identifying fields while the row is freshly loaded so
                # later commits/rollbacks cannot force a reload just to log
                # or dispatch.
                task_id = task.id
                task_platform = task.platform
                try:
                    task.status = "running"
                    task.started_at = _naive_utcnow()
                    task.error_message = None
                    await db.commit()

                    citation_result = await engine.execute_single_platform(
                        keyword=keyword,
                        platform=task_platform,
                        target_brand=target_brand,
                        brand_aliases=brand_aliases or [],
                    )

                    # A falsy engine result stores no record, but the task
                    # still counts as successful.
                    if citation_result:
                        record = CitationRecord(
                            query_id=query_id,
                            platform=task_platform,
                            cited=citation_result.get("cited", False),
                            citation_position=citation_result.get("position"),
                            citation_text=citation_result.get("citation_text"),
                            competitor_brands=citation_result.get("competitor_brands", []),
                            raw_response=citation_result.get("raw_response", ""),
                            confidence=citation_result.get("confidence"),
                            match_type=citation_result.get("match_type"),
                        )
                        db.add(record)

                    task.status = "success"
                    task.completed_at = _naive_utcnow()
                    await db.commit()

                except Exception as e:
                    # Discard the half-finished work of this task, then
                    # persist the failure so it is visible on the task row.
                    await db.rollback()
                    task.status = "failed"
                    task.error_message = str(e)
                    task.completed_at = _naive_utcnow()
                    await db.commit()
                    # Same message text as before, now with the traceback.
                    logger.error(
                        "查询任务执行失败: %s, 错误: %s", task_id, e, exc_info=True
                    )

    except Exception as e:
        # Best-effort background job: nothing awaits it, so log and stop.
        logger.error("查询引擎执行失败: %s", e, exc_info=True)
    finally:
        await engine.close()
|
|
|
|
|
|
# Maps the platform identifier stored on records to the display name used in
# exported reports. Identifiers missing from this table are shown unchanged.
PLATFORM_NAMES: dict[str, str] = {
    "wenxin": "文心一言",
    "kimi": "Kimi",
    "tongyi": "通义千问",
    "doubao": "豆包",
    "qingyan": "智谱清言",
    "tiangong": "天工AI",
    "xinghuo": "讯飞星火",
    "baidu_ai": "百度AI搜索",
    "yuanbao": "腾讯元宝",
}
|
|
|
|
|
|
# Display labels for the match types the export knows about; any other value
# (including None) is exported as an empty cell.
_MATCH_TYPE_LABELS = {
    "exact": "精确匹配",
    "alias": "别名匹配",
    "fuzzy": "模糊匹配",
}

# Leading characters that make spreadsheet applications evaluate a cell as a
# formula when a CSV file is opened.
_FORMULA_PREFIXES = ("=", "+", "-", "@", "\t", "\r")


def _csv_safe(value) -> str:
    """Return ``value`` as cell text that a spreadsheet will not run as a formula.

    ``None`` becomes an empty string. Text starting with a formula trigger
    character is prefixed with an apostrophe so it is shown literally.
    """
    text = "" if value is None else str(value)
    if text.startswith(_FORMULA_PREFIXES):
        return "'" + text
    return text


async def export_citations_csv(
    db: AsyncSession,
    user_id: uuid.UUID,
    query_id: uuid.UUID,
) -> str:
    """Render every citation record of one query as a CSV report.

    The report has a header row, one row per record (newest first), a blank
    row and a summary section with totals, citation rate, average citation
    position and the generation time.

    Args:
        db: Async session used to load the query and its records.
        user_id: Owner the query must belong to.
        query_id: Query whose records are exported.

    Returns:
        The complete CSV document as a string.

    Raises:
        ValueError: If the query does not exist for this user.
    """
    query = await _verify_query_ownership(db, query_id, user_id)
    if query is None:
        raise ValueError("Query not found")

    stmt = (
        select(CitationRecord)
        .where(CitationRecord.query_id == query_id)
        .order_by(CitationRecord.queried_at.desc())
    )
    result = await db.execute(stmt)
    records = result.scalars().all()

    output = io.StringIO()
    writer = csv.writer(output)
    writer.writerow([
        "查询关键词",
        "目标品牌",
        "查询日期",
        "查询平台",
        "是否引用",
        "引用位置",
        "引用文本",
        "匹配置信度",
        "匹配类型",
        "竞争品牌",
    ])

    # The keyword and brand are the same on every row; sanitise them once.
    keyword_cell = _csv_safe(query.keyword)
    brand_cell = _csv_safe(query.target_brand)

    total_queries = len(records)
    total_citations = 0
    total_position = 0
    position_count = 0

    for record in records:
        # Only cited records contribute to the summary; the average position
        # further requires a known position.
        if record.cited:
            total_citations += 1
            if record.citation_position is not None:
                total_position += record.citation_position
                position_count += 1

        date_str = ""
        if record.queried_at:
            date_str = record.queried_at.strftime("%Y-%m-%d %H:%M:%S")

        platform_name = PLATFORM_NAMES.get(record.platform, record.platform)
        match_type_display = _MATCH_TYPE_LABELS.get(record.match_type, "")

        confidence_str = ""
        if record.confidence is not None:
            confidence_str = f"{record.confidence:.2f}"

        # str() keeps the join from failing on a non-string list item.
        competitors = ", ".join(str(b) for b in (record.competitor_brands or []))

        writer.writerow([
            keyword_cell,
            brand_cell,
            date_str,
            platform_name,
            "是" if record.cited else "否",
            record.citation_position if record.citation_position is not None else "",
            # Free text coming from users or platform responses is
            # neutralised against formula injection.
            _csv_safe(record.citation_text),
            confidence_str,
            match_type_display,
            _csv_safe(competitors),
        ])

    # Summary statistics.
    writer.writerow([])
    writer.writerow(["汇总统计"])
    writer.writerow(["总查询次数", total_queries])
    writer.writerow(["引用次数", total_citations])
    citation_rate = (total_citations / total_queries * 100) if total_queries > 0 else 0.0
    writer.writerow(["引用率", f"{citation_rate:.1f}%"])
    # Reported as 0.0 when no cited record has a position.
    avg_position = (total_position / position_count) if position_count > 0 else 0.0
    writer.writerow(["平均引用位置", f"{avg_position:.1f}"])
    writer.writerow(["报告生成时间", datetime.now().strftime("%Y-%m-%d %H:%M:%S")])

    return output.getvalue()
|