import logging
import time
import uuid
from datetime import datetime, timedelta, timezone

from sqlalchemy import select

from app.agent_framework.base import BaseAgent
from app.agent_framework.protocol import (
    AgentCapability,
    AgentType,
    TaskMessage,
    TaskResult,
    TaskStatus,
)
from app.database import AsyncSessionLocal
from app.models.citation_record import CitationRecord
from app.models.query import Query
from app.models.query_task import QueryTask
from app.services.ai_engine.platform_bridge import execute_single_platform

logger = logging.getLogger(__name__)

# Platforms queried when a Query has no explicit ``platforms`` list.
_DEFAULT_PLATFORMS = ("wenxin", "kimi")

# Re-query interval per ``Query.frequency``; a missing frequency is treated as
# "weekly" and an unrecognised one falls back to _DEFAULT_FREQUENCY_DELTA.
_FREQUENCY_DELTAS = {
    "daily": timedelta(days=1),
    "weekly": timedelta(days=7),
    "monthly": timedelta(days=30),
}
_DEFAULT_FREQUENCY_DELTA = timedelta(days=7)


def _utcnow() -> datetime:
    """Return the current UTC time as a *naive* datetime.

    Drop-in replacement for the deprecated ``datetime.utcnow()``: same value,
    tzinfo stripped. NOTE(review): kept naive because the DB timestamp columns
    written here presumably are timezone-naive -- confirm before switching to
    aware datetimes.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class CitationDetectorAgent(BaseAgent):
    """Agent that checks whether a target brand is cited in AI platform answers.

    Supported task types:

    * ``citation_detect`` -- load a stored Query by ``query_id``, run every
      configured platform, persist one CitationRecord per platform, update the
      per-platform QueryTask rows and reschedule the query.
    * ``citation_detect_single`` -- run one ad-hoc platform lookup from the
      task's input data; this agent opens no DB session for it.
    """

    def __init__(self):
        super().__init__(
            name="citation_detector",
            agent_type=AgentType.CITATION_DETECTOR,
            version="1.0.0",
        )

    def get_capabilities(self) -> AgentCapability:
        """Describe the task types and concurrency this agent supports."""
        return AgentCapability(
            agent_name=self.name,
            agent_type=self.agent_type,
            version=self.version,
            supported_tasks=["citation_detect", "citation_detect_single"],
            max_concurrency=3,
            description="AI平台引用检测Agent:检测目标品牌在各AI平台回答中的引用情况",
        )

    async def execute(self, task: TaskMessage) -> TaskResult:
        """Dispatch *task* by ``task_type`` and wrap the outcome in a TaskResult.

        Any ``Exception`` raised while handling the task (including an
        unsupported task type) is logged with its traceback and returned as a
        FAILED result carrying the error text; it is not re-raised.
        """
        started_at = datetime.now(timezone.utc)
        start_time = time.monotonic()
        try:
            if task.task_type == "citation_detect":
                output = await self._execute_full_detect(task)
            elif task.task_type == "citation_detect_single":
                output = await self._execute_single_detect(task)
            else:
                raise ValueError(f"Unsupported task type: {task.task_type}")
            return self._build_result(
                task, started_at, start_time, TaskStatus.COMPLETED, output, None
            )
        except Exception as e:
            # Task boundary: keep the traceback, report failure to the caller.
            logger.exception("CitationDetector task %s failed: %s", task.task_id, e)
            return self._build_result(
                task, started_at, start_time, TaskStatus.FAILED, None, str(e)
            )

    def _build_result(
        self,
        task: TaskMessage,
        started_at: datetime,
        start_time: float,
        status: TaskStatus,
        output: dict | None,
        error_message: str | None,
    ) -> TaskResult:
        """Build the TaskResult for *task*, timing it from ``start_time``.

        ``start_time`` is a ``time.monotonic()`` reading; ``started_at`` is the
        wall-clock start reported to callers.
        """
        elapsed = time.monotonic() - start_time
        return TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=status,
            output_data=output,
            error_message=error_message,
            started_at=started_at,
            completed_at=datetime.now(timezone.utc),
            metrics={
                "elapsed_seconds": round(elapsed, 2),
                "task_type": task.task_type,
            },
        )

    async def _execute_full_detect(self, task: TaskMessage) -> dict:
        """Run detection on every platform of the Query named by ``query_id``.

        Returns a summary dict with the query id, keyword, record counts and
        one entry per CitationRecord created.

        Raises:
            ValueError: if ``query_id`` is missing or no such Query exists.
        """
        query_id = task.input_data.get("query_id")
        if not query_id:
            raise ValueError("input_data must contain 'query_id'")

        async with AsyncSessionLocal() as db:
            result = await db.execute(select(Query).where(Query.id == query_id))
            query = result.scalar_one_or_none()
            if not query:
                raise ValueError(f"Query {query_id} not found")

            await self.report_progress(
                task_id=task.task_id,
                progress=0.1,
                message=f"Starting citation detection for query '{query.keyword}'",
            )

            records = await self._detect_all_platforms(db, query, progress_task=task)

            await self.report_progress(
                task_id=task.task_id,
                progress=1.0,
                message=f"Detection completed: {len(records)} records found",
            )

            record_summaries = [
                {
                    "id": str(r.id),
                    "platform": r.platform,
                    "cited": r.cited,
                    "confidence": r.confidence,
                    "match_type": r.match_type,
                }
                for r in records
            ]
            return {
                "query_id": str(query_id),
                "keyword": query.keyword,
                "total_records": len(records),
                "cited_count": sum(1 for r in records if r.cited),
                "records": record_summaries,
            }

    async def _execute_single_detect(self, task: TaskMessage) -> dict:
        """Run one ad-hoc platform lookup described entirely by ``input_data``.

        Returns the platform result with its ``raw_response`` key removed.

        Raises:
            ValueError: if ``keyword``, ``platform`` or ``target_brand`` is
                missing or empty.
        """
        keyword = task.input_data.get("keyword")
        platform = task.input_data.get("platform")
        target_brand = task.input_data.get("target_brand")
        # An explicit None is treated like "no aliases", matching the
        # ``query.brand_aliases or []`` handling used for stored queries.
        brand_aliases = task.input_data.get("brand_aliases") or []
        if not all([keyword, platform, target_brand]):
            raise ValueError(
                "input_data must contain 'keyword', 'platform', 'target_brand'"
            )

        await self.report_progress(
            task_id=task.task_id,
            progress=0.2,
            message=f"Querying platform '{platform}' with keyword '{keyword}'",
        )
        result = await execute_single_platform(
            keyword=keyword,
            platform=platform,
            target_brand=target_brand,
            brand_aliases=brand_aliases,
        )
        await self.report_progress(
            task_id=task.task_id,
            progress=1.0,
            message=f"Single platform detection completed on '{platform}'",
        )
        return {k: v for k, v in result.items() if k != "raw_response"}

    async def execute_query_compat(self, query: Query, db) -> list[CitationRecord]:
        """Legacy entry point: detect on all of *query*'s platforms using *db*.

        Same persistence behaviour as the ``citation_detect`` task, without
        progress reporting. Returns the CitationRecords created, one per
        platform.
        """
        return await self._detect_all_platforms(db, query)

    async def execute_single_platform_compat(
        self, keyword: str, platform: str, target_brand: str, brand_aliases: list
    ) -> dict:
        """Legacy passthrough to ``execute_single_platform`` (raw result)."""
        return await execute_single_platform(
            keyword=keyword,
            platform=platform,
            target_brand=target_brand,
            brand_aliases=brand_aliases,
        )

    async def _detect_all_platforms(
        self, db, query: Query, progress_task: TaskMessage | None = None
    ) -> list[CitationRecord]:
        """Detect on every platform of *query*, then reschedule the query.

        Falls back to _DEFAULT_PLATFORMS when ``query.platforms`` is empty.
        When *progress_task* is given, a progress update in the 0.1-0.9 range
        is reported before each platform. Commits ``last_queried_at`` /
        ``next_query_at`` after all platforms have run.
        """
        platforms = query.platforms or list(_DEFAULT_PLATFORMS)
        brand_aliases = query.brand_aliases or []
        total = len(platforms)

        records: list[CitationRecord] = []
        for i, platform_name in enumerate(platforms):
            if progress_task is not None:
                await self.report_progress(
                    task_id=progress_task.task_id,
                    progress=0.1 + 0.8 * (i / total),
                    message=f"Detecting on platform '{platform_name}' ({i + 1}/{total})",
                )
            records.append(
                await self._detect_platform(db, query, platform_name, brand_aliases)
            )

        query.last_queried_at = _utcnow()
        query.next_query_at = self._calculate_next_query_at(query.frequency)
        await db.commit()
        return records

    async def _detect_platform(
        self, db, query: Query, platform_name: str, brand_aliases: list
    ) -> CitationRecord:
        """Run one platform for *query*, persist its record and QueryTask status.

        A failure of the platform call (or of building the record from its
        result) is recorded as an uncited CitationRecord with the error text
        and a "failed" QueryTask. A failing DB commit is NOT swallowed: it
        propagates so the session error is not masked by a second commit.
        """
        task_obj = await self._get_or_create_task(db, query.id, platform_name)
        task_obj.status = "running"
        task_obj.started_at = _utcnow()
        task_obj.error_message = None
        await db.commit()

        # Only the external call and record construction are guarded; the
        # commit below is deliberately outside the try.
        try:
            detect_result = await execute_single_platform(
                keyword=query.keyword,
                platform=platform_name,
                target_brand=query.target_brand,
                brand_aliases=brand_aliases,
            )
            record = CitationRecord.from_citation_result(
                query_id=query.id,
                platform=platform_name,
                result=detect_result,
            )
        except Exception as e:
            logger.exception("平台 %s 查询失败: %s", platform_name, e)
            task_obj.status = "failed"
            task_obj.error_message = str(e)
            record = CitationRecord.from_citation_result(
                query_id=query.id,
                platform=platform_name,
                result={"cited": False, "raw_response": str(e)},
            )
        else:
            task_obj.status = "success"

        task_obj.completed_at = _utcnow()
        db.add(record)
        await db.commit()
        return record

    async def _get_or_create_task(self, db, query_id: uuid.UUID, platform: str) -> QueryTask:
        """Return the QueryTask for (query_id, platform), creating it if absent.

        A newly created row is committed and refreshed before being returned.
        NOTE(review): ``scalar_one_or_none`` raises if more than one row
        matches; concurrent first runs could race here unless the table has a
        unique constraint on (query_id, platform) -- confirm.
        """
        stmt = select(QueryTask).where(
            QueryTask.query_id == query_id,
            QueryTask.platform == platform,
        )
        result = await db.execute(stmt)
        task_obj = result.scalar_one_or_none()
        if not task_obj:
            task_obj = QueryTask(
                query_id=query_id,
                platform=platform,
                status="pending",
            )
            db.add(task_obj)
            await db.commit()
            await db.refresh(task_obj)
        return task_obj

    @staticmethod
    def _calculate_next_query_at(frequency: str | None) -> datetime:
        """Return the next run time (naive UTC) for a query of *frequency*.

        ``None`` is treated as "weekly"; unrecognised values also get 7 days.
        """
        delta = _FREQUENCY_DELTAS.get(frequency or "weekly", _DEFAULT_FREQUENCY_DELTA)
        return _utcnow() + delta