geo/backend/app/agent_framework/agents/citation_detector.py

314 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import time
import uuid
from datetime import datetime, timedelta, timezone
from sqlalchemy import select
from app.agent_framework.base import BaseAgent
from app.agent_framework.protocol import (
AgentCapability,
AgentType,
TaskMessage,
TaskResult,
TaskStatus,
)
from app.database import AsyncSessionLocal
from app.models.citation_record import CitationRecord
from app.models.query import Query
from app.models.query_task import QueryTask
from app.services.ai_engine.platform_bridge import execute_single_platform
# Module-level logger, named after this module's import path (``__name__``).
logger: logging.Logger = logging.getLogger(__name__)
class CitationDetectorAgent(BaseAgent):
    """Agent that checks whether a target brand is cited in AI platform answers.

    Supported task types:

    * ``citation_detect`` -- load the ``Query`` named by
      ``input_data["query_id"]``, run detection on each of its platforms,
      persist one ``CitationRecord`` per platform, update the per-platform
      ``QueryTask`` status rows and reschedule the query.
    * ``citation_detect_single`` -- run a one-off detection on one platform
      from ``keyword`` / ``platform`` / ``target_brand`` in ``input_data``;
      no database session is opened for this mode.
    """

    # Platforms used when a Query has no explicit platform list.
    _DEFAULT_PLATFORMS = ("wenxin", "kimi")

    # Re-query interval per Query.frequency; unknown/None falls back to weekly.
    _FREQUENCY_DELTAS = {
        "daily": timedelta(days=1),
        "weekly": timedelta(days=7),
        "monthly": timedelta(days=30),
    }
    _DEFAULT_FREQUENCY_DELTA = timedelta(days=7)

    def __init__(self):
        super().__init__(
            name="citation_detector",
            agent_type=AgentType.CITATION_DETECTOR,
            version="1.0.0",
        )

    def get_capabilities(self) -> AgentCapability:
        """Advertise the task types this agent handles and its concurrency limit."""
        return AgentCapability(
            agent_name=self.name,
            agent_type=self.agent_type,
            version=self.version,
            supported_tasks=["citation_detect", "citation_detect_single"],
            max_concurrency=3,
            description="AI平台引用检测Agent检测目标品牌在各AI平台回答中的引用情况",
        )

    async def execute(self, task: TaskMessage) -> TaskResult:
        """Dispatch *task* to its handler and wrap the outcome in a TaskResult.

        Handler errors are not raised to the caller. Any ``Exception``,
        including an unsupported ``task_type``, is logged with its traceback
        and returned as a ``TaskStatus.FAILED`` result carrying ``str(exc)``.
        """
        started_at = datetime.now(timezone.utc)
        start_time = time.monotonic()
        handlers = {
            "citation_detect": self._execute_full_detect,
            "citation_detect_single": self._execute_single_detect,
        }
        try:
            handler = handlers.get(task.task_type)
            if handler is None:
                raise ValueError(f"Unsupported task type: {task.task_type}")
            output = await handler(task)
            return self._build_result(
                task, TaskStatus.COMPLETED, started_at, start_time, output_data=output
            )
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("CitationDetector task %s failed: %s", task.task_id, e)
            return self._build_result(
                task, TaskStatus.FAILED, started_at, start_time, error_message=str(e)
            )

    def _build_result(
        self,
        task: TaskMessage,
        status: TaskStatus,
        started_at: datetime,
        start_time: float,
        *,
        output_data: dict | None = None,
        error_message: str | None = None,
    ) -> TaskResult:
        """Assemble a TaskResult for *task*.

        Args:
            start_time: a ``time.monotonic()`` reading; elapsed seconds are
                measured from it and rounded to 2 decimals in ``metrics``.
        """
        elapsed = time.monotonic() - start_time
        return TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=status,
            output_data=output_data,
            error_message=error_message,
            started_at=started_at,
            completed_at=datetime.now(timezone.utc),
            metrics={
                "elapsed_seconds": round(elapsed, 2),
                "task_type": task.task_type,
            },
        )

    async def _execute_full_detect(self, task: TaskMessage) -> dict:
        """Run detection on every platform of the Query named by ``input_data["query_id"]``.

        Returns:
            A summary dict with the query id, keyword, record counts and one
            ``{id, platform, cited, confidence, match_type}`` entry per record.

        Raises:
            ValueError: ``query_id`` is missing/empty, or no such Query exists.
        """
        input_data = task.input_data or {}
        query_id = input_data.get("query_id")
        if not query_id:
            raise ValueError("input_data must contain 'query_id'")
        async with AsyncSessionLocal() as db:
            # NOTE(review): query_id comes straight from input_data; if Query.id
            # is a UUID column and callers send a str, this relies on the column
            # type coercing it -- confirm.
            result = await db.execute(select(Query).where(Query.id == query_id))
            query = result.scalar_one_or_none()
            if query is None:
                raise ValueError(f"Query {query_id} not found")
            await self.report_progress(
                task_id=task.task_id,
                progress=0.1,
                message=f"Starting citation detection for query '{query.keyword}'",
            )
            records = await self._detect_all_platforms(db, query, progress_task=task)
            await self.report_progress(
                task_id=task.task_id,
                progress=1.0,
                message=f"Detection completed: {len(records)} records found",
            )
            return {
                "query_id": str(query_id),
                "keyword": query.keyword,
                "total_records": len(records),
                "cited_count": sum(1 for r in records if r.cited),
                "records": [
                    {
                        "id": str(r.id),
                        "platform": r.platform,
                        "cited": r.cited,
                        "confidence": r.confidence,
                        "match_type": r.match_type,
                    }
                    for r in records
                ],
            }

    async def _execute_single_detect(self, task: TaskMessage) -> dict:
        """Run a one-off detection on a single platform.

        Reads ``keyword``, ``platform``, ``target_brand`` and optional
        ``brand_aliases`` from ``task.input_data``. Returns the platform
        result dict without its ``raw_response`` entry.

        Raises:
            ValueError: any of ``keyword``, ``platform``, ``target_brand`` is
                missing or empty.
        """
        input_data = task.input_data or {}
        keyword = input_data.get("keyword")
        platform = input_data.get("platform")
        target_brand = input_data.get("target_brand")
        # ``or []`` also normalises an explicit None, matching the Query path.
        brand_aliases = input_data.get("brand_aliases") or []
        if not all([keyword, platform, target_brand]):
            raise ValueError(
                "input_data must contain 'keyword', 'platform', 'target_brand'"
            )
        await self.report_progress(
            task_id=task.task_id,
            progress=0.2,
            message=f"Querying platform '{platform}' with keyword '{keyword}'",
        )
        result = await execute_single_platform(
            keyword=keyword,
            platform=platform,
            target_brand=target_brand,
            brand_aliases=brand_aliases,
        )
        await self.report_progress(
            task_id=task.task_id,
            progress=1.0,
            message=f"Single platform detection completed on '{platform}'",
        )
        return {k: v for k, v in result.items() if k != "raw_response"}

    async def execute_query_compat(self, query: Query, db) -> list[CitationRecord]:
        """Backward-compatible entry point: detect citations for *query* on session *db*.

        Same per-platform persistence and rescheduling as the
        ``citation_detect`` task, but without progress reporting.
        """
        return await self._detect_all_platforms(db, query)

    async def execute_single_platform_compat(
        self, keyword: str, platform: str, target_brand: str, brand_aliases: list
    ) -> dict:
        """Backward-compatible passthrough to ``execute_single_platform``."""
        return await execute_single_platform(
            keyword=keyword,
            platform=platform,
            target_brand=target_brand,
            brand_aliases=brand_aliases,
        )

    async def _detect_all_platforms(
        self, db, query: Query, progress_task: TaskMessage | None = None
    ) -> list[CitationRecord]:
        """Detect citations on each platform of *query*, then reschedule it.

        When *progress_task* is given, a progress update in [0.1, 0.9) is
        reported before each platform. After the loop, ``last_queried_at`` and
        ``next_query_at`` are set and committed.

        NOTE(review): ``query`` attributes are read after several commits. This
        relies on the session not expiring instances on commit (e.g.
        ``expire_on_commit=False``) -- confirm the session configuration.
        """
        platforms = query.platforms or list(self._DEFAULT_PLATFORMS)
        brand_aliases = query.brand_aliases or []
        total = len(platforms)
        records: list[CitationRecord] = []
        for i, platform_name in enumerate(platforms):
            if progress_task is not None:
                await self.report_progress(
                    task_id=progress_task.task_id,
                    progress=0.1 + 0.8 * (i / total),
                    message=f"Detecting on platform '{platform_name}' ({i+1}/{total})",
                )
            records.append(
                await self._detect_platform(db, query, platform_name, brand_aliases)
            )
        query.last_queried_at = self._utcnow()
        query.next_query_at = self._calculate_next_query_at(query.frequency)
        await db.commit()
        return records

    async def _detect_platform(
        self, db, query: Query, platform_name: str, brand_aliases: list
    ) -> CitationRecord:
        """Detect on one platform and persist the outcome.

        The platform's QueryTask is marked ``running``, then ``success`` or
        ``failed``. A CitationRecord is always stored; on failure it has
        ``cited=False`` and the error text as ``raw_response``.
        """
        task_obj = await self._get_or_create_task(db, query.id, platform_name)
        task_obj.status = "running"
        task_obj.started_at = self._utcnow()
        task_obj.error_message = None
        await db.commit()
        try:
            detect_result = await execute_single_platform(
                keyword=query.keyword,
                platform=platform_name,
                target_brand=query.target_brand,
                brand_aliases=brand_aliases,
            )
            record = CitationRecord.from_citation_result(
                query_id=query.id,
                platform=platform_name,
                result=detect_result,
            )
        except Exception as e:
            logger.exception("平台 %s 查询失败: %s", platform_name, e)
            task_obj.status = "failed"
            task_obj.error_message = str(e)
            record = CitationRecord.from_citation_result(
                query_id=query.id,
                platform=platform_name,
                result={"cited": False, "raw_response": str(e)},
            )
        else:
            task_obj.status = "success"
        task_obj.completed_at = self._utcnow()
        db.add(record)
        # The commit sits outside the try on purpose. A DB failure propagates
        # with its real error, instead of being "handled" by a second commit
        # on a session that already needs a rollback.
        await db.commit()
        return record

    async def _get_or_create_task(self, db, query_id: uuid.UUID, platform: str) -> QueryTask:
        """Return the QueryTask for (*query_id*, *platform*).

        If none exists, a ``pending`` row is created, committed and refreshed.
        """
        stmt = select(QueryTask).where(
            QueryTask.query_id == query_id,
            QueryTask.platform == platform,
        )
        result = await db.execute(stmt)
        task_obj = result.scalar_one_or_none()
        if not task_obj:
            task_obj = QueryTask(
                query_id=query_id,
                platform=platform,
                status="pending",
            )
            db.add(task_obj)
            await db.commit()
            await db.refresh(task_obj)
        return task_obj

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a *naive* datetime.

        This is the same value the deprecated ``datetime.utcnow()`` returned.
        Naive values are kept because the persisted timestamps were always
        naive UTC (presumably timezone-less DB columns -- confirm before
        switching to aware datetimes).
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @classmethod
    def _calculate_next_query_at(cls, frequency: str | None) -> datetime:
        """Return the next naive-UTC query time for *frequency*.

        Recognised values are ``daily``/``weekly``/``monthly``. ``None`` or
        any unknown value schedules one week ahead.
        """
        delta = cls._FREQUENCY_DELTAS.get(
            frequency or "weekly", cls._DEFAULT_FREQUENCY_DELTA
        )
        return cls._utcnow() + delta