import asyncio
import csv
import io
import logging
import uuid
from datetime import datetime, timedelta, timezone

from sqlalchemy import Integer, and_, cast, func, literal_column, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import AsyncSessionLocal
from app.models.citation_record import CitationRecord
from app.models.query import Query
from app.models.query_task import QueryTask
from app.workers.citation_engine import CitationEngine

logger = logging.getLogger(__name__)

# Strong references to fire-and-forget background tasks. The event loop only
# keeps weak references to tasks, so a task nobody references can be
# garbage-collected before it finishes. Each task removes itself when done.
_background_tasks: set[asyncio.Task] = set()


def _utcnow() -> datetime:
    """Return the current UTC time as a *naive* datetime.

    Same value as the deprecated ``datetime.utcnow()``. Naive is kept on
    purpose so comparisons against naive timestamps read back from the
    database do not mix naive and aware datetimes.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _round_avg(value) -> float | None:
    """Round an SQL ``AVG()`` result to one decimal place.

    The value is converted to ``float`` first, because PostgreSQL returns
    ``Decimal`` for ``AVG()``. ``None`` (no rows to average) is passed
    through unchanged.
    """
    return round(float(value), 1) if value is not None else None


def _empty_stats() -> dict:
    """Return the stats payload used when the requested query is not accessible."""
    return {
        "total_queries": 0,
        "total_citations": 0,
        "citation_rate": 0.0,
        "avg_position": None,
        "by_platform": {},
        "trend": [],
    }


async def _verify_query_ownership(
    db: AsyncSession,
    query_id: uuid.UUID,
    user_id: uuid.UUID,
) -> Query | None:
    """Return the ``Query`` with ``query_id`` if it belongs to ``user_id``, else ``None``."""
    stmt = select(Query).where(Query.id == query_id, Query.user_id == user_id)
    result = await db.execute(stmt)
    return result.scalar_one_or_none()


async def get_citations(
    db: AsyncSession,
    user_id: uuid.UUID,
    query_id: uuid.UUID | None = None,
    platform: str | None = None,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
    skip: int = 0,
    limit: int = 20,
) -> tuple[list[CitationRecord], int]:
    """Return one page of the user's citation records plus the total match count.

    Records are restricted to queries owned by ``user_id`` and optionally
    filtered by query, platform and an inclusive ``queried_at`` date range.
    Results are newest first.

    Returns:
        ``(records, total)``. ``total`` is the unpaginated count. If
        ``query_id`` is given but not owned by the user, ``([], 0)``.
    """
    # Base filter: citations belonging to the user's queries.
    conditions = [Query.user_id == user_id]
    if query_id is not None:
        # Explicit ownership check lets us short-circuit with an empty page.
        if await _verify_query_ownership(db, query_id, user_id) is None:
            return [], 0
        conditions.append(CitationRecord.query_id == query_id)
    if platform is not None:
        conditions.append(CitationRecord.platform == platform)
    if start_date is not None:
        conditions.append(CitationRecord.queried_at >= start_date)
    if end_date is not None:
        conditions.append(CitationRecord.queried_at <= end_date)
    where_clause = and_(*conditions)

    stmt = (
        select(CitationRecord)
        .join(Query, CitationRecord.query_id == Query.id)
        .where(where_clause)
        .order_by(CitationRecord.queried_at.desc())
        .offset(skip)
        .limit(limit)
    )
    result = await db.execute(stmt)
    items = list(result.scalars().all())

    count_stmt = (
        select(func.count())
        .select_from(CitationRecord)
        .join(Query, CitationRecord.query_id == Query.id)
        .where(where_clause)
    )
    total = (await db.execute(count_stmt)).scalar_one()
    return items, total


async def get_citation_stats(
    db: AsyncSession,
    user_id: uuid.UUID,
    query_id: uuid.UUID | None = None,
) -> dict:
    """Aggregate citation statistics for the user's queries.

    Args:
        db: Async database session.
        user_id: Owner whose citation records are aggregated.
        query_id: Optional single query to restrict the stats to.

    Returns:
        A dict with these keys:

        * ``total_queries``: number of citation records.
        * ``total_citations``: records with ``cited`` true.
        * ``citation_rate``: ratio of the two, rounded to 2 decimals.
        * ``avg_position``: mean ``citation_position`` over cited records
          with a position, or ``None``.
        * ``by_platform``: the same figures per platform.
        * ``trend``: weekly cited counts for the last 30 days.

        An all-zero payload is returned if ``query_id`` is not owned by the user.
    """
    if query_id is not None:
        if await _verify_query_ownership(db, query_id, user_id) is None:
            return _empty_stats()

    base_conditions = [Query.user_id == user_id]
    if query_id is not None:
        base_conditions.append(CitationRecord.query_id == query_id)
    base_where = and_(*base_conditions)

    # Totals: every record counts as one "query"; cited ones count as citations.
    total_queries_stmt = (
        select(func.count())
        .select_from(CitationRecord)
        .join(Query, CitationRecord.query_id == Query.id)
        .where(base_where)
    )
    total_queries = (await db.execute(total_queries_stmt)).scalar_one()

    total_citations_stmt = (
        select(func.count())
        .select_from(CitationRecord)
        .join(Query, CitationRecord.query_id == Query.id)
        .where(base_where, CitationRecord.cited.is_(True))
    )
    total_citations = (await db.execute(total_citations_stmt)).scalar_one()

    citation_rate = total_citations / total_queries if total_queries > 0 else 0.0

    # Average position, only over cited records that actually have a position.
    avg_pos_stmt = (
        select(func.avg(CitationRecord.citation_position))
        .join(Query, CitationRecord.query_id == Query.id)
        .where(
            base_where,
            CitationRecord.cited.is_(True),
            CitationRecord.citation_position.isnot(None),
        )
    )
    avg_position = _round_avg((await db.execute(avg_pos_stmt)).scalar_one())

    # Per-platform breakdown.
    platform_stmt = (
        select(
            CitationRecord.platform,
            func.count().label("queries"),
            func.sum(cast(CitationRecord.cited, Integer)).label("citations"),
            func.avg(CitationRecord.citation_position).label("avg_position"),
        )
        .join(Query, CitationRecord.query_id == Query.id)
        .where(base_where)
        .group_by(CitationRecord.platform)
    )
    platform_result = await db.execute(platform_stmt)
    by_platform: dict[str, dict] = {}
    for row in platform_result.all():
        queries = row.queries
        citations = int(row.citations or 0)  # SUM over zero matches is NULL
        rate = citations / queries if queries > 0 else 0.0
        by_platform[row.platform] = {
            "queries": queries,
            "citations": citations,
            "rate": round(rate, 2),
            "avg_position": _round_avg(row.avg_position),
        }

    # Trend: past 30 days, grouped by week. Naive UTC to match naive DB values.
    thirty_days_ago = _utcnow() - timedelta(days=30)

    dialect = db.bind.dialect.name if db.bind else "postgresql"
    if dialect == "postgresql":
        # Render 'week' inline rather than as a bound parameter. Otherwise
        # SELECT, GROUP BY and ORDER BY each get their own parameter ($1, $2,
        # ...), and PostgreSQL (with server-side params, e.g. asyncpg) no longer
        # sees them as the same expression and rejects the GROUP BY.
        week_expr = func.date_trunc(literal_column("'week'"), CitationRecord.queried_at)
    else:
        # SQLite-compatible week grouping (YYYY-WW string).
        week_expr = func.strftime("%Y-%W", CitationRecord.queried_at)

    trend_stmt = (
        select(
            week_expr.label("week_start"),
            func.sum(cast(CitationRecord.cited, Integer)).label("citations"),
        )
        .join(Query, CitationRecord.query_id == Query.id)
        .where(base_where, CitationRecord.queried_at >= thirty_days_ago)
        .group_by(week_expr)
        .order_by(week_expr)
    )
    trend_result = await db.execute(trend_stmt)
    trend = []
    for row in trend_result.all():
        week_start = row.week_start
        # PostgreSQL yields a datetime; SQLite yields the "YYYY-WW" string.
        if isinstance(week_start, datetime):
            date_str = week_start.date().isoformat()
        else:
            date_str = str(week_start)
        trend.append({"date": date_str, "citations": int(row.citations or 0)})

    return {
        "total_queries": total_queries,
        "total_citations": total_citations,
        "citation_rate": round(citation_rate, 2),
        "avg_position": avg_position,
        "by_platform": by_platform,
        "trend": trend,
    }


async def trigger_query_now(
    db: AsyncSession,
    user_id: uuid.UUID,
    query_id: uuid.UUID,
) -> QueryTask:
    """Create one pending ``QueryTask`` per configured platform and run them in the background.

    Returns:
        The task created for the first platform, refreshed after commit.

    Raises:
        ValueError: if the query does not exist or is not owned by the user,
            is not ``"active"``, or has no platforms configured.
    """
    query = await _verify_query_ownership(db, query_id, user_id)
    if query is None:
        raise ValueError("Query not found")
    if query.status != "active":
        raise ValueError("Query is not active")

    platforms = list(query.platforms or [])
    if not platforms:
        raise ValueError("No platforms configured for this query")

    # Snapshot what the background job needs *before* committing. If the
    # session expires instances on commit, reading these afterwards would
    # trigger implicit IO, which AsyncSession does not allow.
    keyword = query.keyword
    target_brand = query.target_brand
    brand_aliases = list(query.brand_aliases or [])

    tasks = [
        QueryTask(query_id=query_id, platform=platform, status="pending")
        for platform in platforms
    ]
    db.add_all(tasks)
    await db.flush()  # assigns primary keys while the instances are still loaded
    task_ids = [task.id for task in tasks]
    await db.commit()

    first_task = tasks[0]
    await db.refresh(first_task)

    # Execute the freshly created tasks immediately in the background. Keep a
    # strong reference so the task cannot be garbage-collected mid-run.
    background = asyncio.create_task(
        _execute_query_tasks(
            query_id=query_id,
            platforms=platforms,
            keyword=keyword,
            target_brand=target_brand,
            brand_aliases=brand_aliases,
            task_ids=task_ids,
        )
    )
    _background_tasks.add(background)
    background.add_done_callback(_background_tasks.discard)

    return first_task


async def _execute_query_tasks(
    query_id: uuid.UUID,
    platforms: list,
    keyword: str,
    target_brand: str,
    brand_aliases: list,
    task_ids: list | None = None,
) -> None:
    """Run pending query tasks in a fresh session and record their outcomes.

    Each selected task is committed as ``"running"`` before the engine call.
    On success a ``CitationRecord`` is added when the engine returns a truthy
    result, and the task becomes ``"success"``. On failure the task becomes
    ``"failed"`` with the error message. One failing task does not stop the
    remaining ones.

    Args:
        query_id: Query whose tasks are executed.
        platforms: Only pending tasks on these platforms are selected.
        keyword, target_brand, brand_aliases: Passed to the citation engine.
        task_ids: If given, only these task ids are executed. This prevents
            picking up stale pending tasks or tasks owned by a concurrent
            trigger. ``None`` keeps the old behaviour of taking every
            matching pending task.
    """
    engine = CitationEngine()
    try:
        async with AsyncSessionLocal() as db:
            stmt = select(QueryTask).where(
                QueryTask.query_id == query_id,
                QueryTask.status == "pending",
                QueryTask.platform.in_(platforms),
            )
            if task_ids is not None:
                stmt = stmt.where(QueryTask.id.in_(task_ids))
            result = await db.execute(stmt)
            tasks = result.scalars().all()

            # Read id/platform now: the commits below may expire the instances,
            # and reloading attributes lazily is implicit IO (not allowed async).
            pending = [(task, task.id, task.platform) for task in tasks]

            for task, task_id, platform in pending:
                try:
                    task.status = "running"
                    task.started_at = _utcnow()
                    task.error_message = None
                    await db.commit()

                    citation_result = await engine.execute_single_platform(
                        keyword=keyword,
                        platform=platform,
                        target_brand=target_brand,
                        brand_aliases=brand_aliases or [],
                    )

                    if citation_result:
                        db.add(
                            CitationRecord(
                                query_id=query_id,
                                platform=platform,
                                cited=citation_result.get("cited", False),
                                citation_position=citation_result.get("position"),
                                citation_text=citation_result.get("citation_text"),
                                competitor_brands=citation_result.get("competitor_brands", []),
                                raw_response=citation_result.get("raw_response", ""),
                                confidence=citation_result.get("confidence"),
                                match_type=citation_result.get("match_type"),
                            )
                        )

                    task.status = "success"
                    task.completed_at = _utcnow()
                    await db.commit()
                except Exception as e:
                    logger.exception("查询任务执行失败: %s, 错误: %s", task_id, e)
                    # Discard the partial unit of work, then persist the failure.
                    await db.rollback()
                    task.status = "failed"
                    task.error_message = str(e)
                    task.completed_at = _utcnow()
                    await db.commit()
    except Exception as e:
        logger.exception("查询引擎执行失败: %s", e)
    finally:
        await engine.close()


# Platform code -> display name used in CSV exports.
PLATFORM_NAMES = {
    "wenxin": "文心一言",
    "kimi": "Kimi",
    "tongyi": "通义千问",
    "doubao": "豆包",
    "qingyan": "智谱清言",
    "tiangong": "天工AI",
    "xinghuo": "讯飞星火",
    "baidu_ai": "百度AI搜索",
    "yuanbao": "腾讯元宝",
}

# match_type code -> display label in CSV exports; unknown/None -> "".
_MATCH_TYPE_LABELS = {
    "exact": "精确匹配",
    "alias": "别名匹配",
    "fuzzy": "模糊匹配",
}

# Leading characters that make spreadsheet apps interpret a cell as a formula
# (OWASP "CSV Injection").
_CSV_FORMULA_PREFIXES = ("=", "+", "-", "@", "\t", "\r")


def _csv_safe(value):
    """Neutralise formula injection in a free-text CSV cell.

    Strings that start with a formula trigger get a leading single quote.
    Any other value is returned unchanged.
    """
    if isinstance(value, str) and value.startswith(_CSV_FORMULA_PREFIXES):
        return "'" + value
    return value


async def export_citations_csv(
    db: AsyncSession,
    user_id: uuid.UUID,
    query_id: uuid.UUID,
) -> str:
    """Export all citation records of one query as CSV text with a summary footer.

    Rows are ordered newest first. Untrusted free-text cells (keyword, brand,
    citation text, competitor brands) are escaped against spreadsheet formula
    injection.

    Raises:
        ValueError: if the query does not exist or is not owned by the user.
    """
    query = await _verify_query_ownership(db, query_id, user_id)
    if query is None:
        raise ValueError("Query not found")

    stmt = (
        select(CitationRecord)
        .where(CitationRecord.query_id == query_id)
        .order_by(CitationRecord.queried_at.desc())
    )
    result = await db.execute(stmt)
    records = result.scalars().all()

    output = io.StringIO()
    writer = csv.writer(output)
    writer.writerow([
        "查询关键词", "目标品牌", "查询日期", "查询平台", "是否引用",
        "引用位置", "引用文本", "匹配置信度", "匹配类型", "竞争品牌",
    ])

    keyword_cell = _csv_safe(query.keyword)
    brand_cell = _csv_safe(query.target_brand)

    total_queries = len(records)
    total_citations = 0
    total_position = 0
    position_count = 0

    for record in records:
        # Summary statistics: only cited records with a position count toward the average.
        if record.cited:
            total_citations += 1
            if record.citation_position is not None:
                total_position += record.citation_position
                position_count += 1

        date_str = record.queried_at.strftime("%Y-%m-%d %H:%M:%S") if record.queried_at else ""
        confidence_str = f"{record.confidence:.2f}" if record.confidence is not None else ""
        competitors = (
            ", ".join(str(brand) for brand in record.competitor_brands)
            if record.competitor_brands
            else ""
        )

        writer.writerow([
            keyword_cell,
            brand_cell,
            date_str,
            PLATFORM_NAMES.get(record.platform, record.platform),
            "是" if record.cited else "否",
            record.citation_position if record.citation_position is not None else "",
            _csv_safe(record.citation_text or ""),
            confidence_str,
            _MATCH_TYPE_LABELS.get(record.match_type, ""),
            _csv_safe(competitors),
        ])

    # Summary footer.
    writer.writerow([])
    writer.writerow(["汇总统计"])
    writer.writerow(["总查询次数", total_queries])
    writer.writerow(["引用次数", total_citations])
    citation_rate = (total_citations / total_queries * 100) if total_queries > 0 else 0.0
    writer.writerow(["引用率", f"{citation_rate:.1f}%"])
    avg_position = (total_position / position_count) if position_count > 0 else 0.0
    writer.writerow(["平均引用位置", f"{avg_position:.1f}"])
    writer.writerow(["报告生成时间", datetime.now().strftime("%Y-%m-%d %H:%M:%S")])

    return output.getvalue()