# Listing metadata: 314 lines, 11 KiB, Python
import logging
|
||
import time
|
||
import uuid
|
||
from datetime import datetime, timedelta, timezone
|
||
|
||
from sqlalchemy import select
|
||
|
||
from app.agent_framework.base import BaseAgent
|
||
from app.agent_framework.protocol import (
|
||
AgentCapability,
|
||
AgentType,
|
||
TaskMessage,
|
||
TaskResult,
|
||
TaskStatus,
|
||
)
|
||
from app.database import AsyncSessionLocal
|
||
from app.models.citation_record import CitationRecord
|
||
from app.models.query import Query
|
||
from app.models.query_task import QueryTask
|
||
from app.services.ai_engine.platform_bridge import execute_single_platform
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class CitationDetectorAgent(BaseAgent):
    """Agent that detects whether a target brand is cited in AI-platform answers.

    Two task types are handled by :meth:`execute`:

    * ``"citation_detect"`` - load a stored ``Query`` by id, run detection on
      each of its platforms and persist one ``CitationRecord`` plus the
      ``QueryTask`` status per platform.
    * ``"citation_detect_single"`` - run one ad-hoc detection from the values
      in ``task.input_data``; nothing is written to the database.
    """

    # Platforms queried when a Query row has no platforms configured.
    _DEFAULT_PLATFORMS = ("wenxin", "kimi")

    def __init__(self):
        """Register the agent under its fixed name, type and version."""
        super().__init__(
            name="citation_detector",
            agent_type=AgentType.CITATION_DETECTOR,
            version="1.0.0",
        )

    def get_capabilities(self) -> AgentCapability:
        """Return the capability descriptor advertised by this agent."""
        return AgentCapability(
            agent_name=self.name,
            agent_type=self.agent_type,
            version=self.version,
            supported_tasks=["citation_detect", "citation_detect_single"],
            max_concurrency=3,
            description="AI平台引用检测Agent:检测目标品牌在各AI平台回答中的引用情况",
        )

    async def execute(self, task: TaskMessage) -> TaskResult:
        """Run ``task`` and wrap the outcome in a ``TaskResult``.

        Any ``Exception`` raised while handling the task (including an
        unsupported ``task_type``) is logged with its traceback and reported
        as a ``TaskStatus.FAILED`` result instead of propagating.
        """
        started_at = datetime.now(timezone.utc)
        start_time = time.monotonic()

        try:
            if task.task_type == "citation_detect":
                output = await self._execute_full_detect(task)
            elif task.task_type == "citation_detect_single":
                output = await self._execute_single_detect(task)
            else:
                raise ValueError(f"Unsupported task type: {task.task_type}")

            # Built inside the try on purpose: a failure while constructing
            # the success result is reported as FAILED, as before.
            return self._build_result(
                task,
                TaskStatus.COMPLETED,
                started_at,
                start_time,
                output_data=output,
            )
        except Exception as e:
            # logger.exception keeps the traceback; message text is unchanged.
            logger.exception(
                "CitationDetector task %s failed: %s", task.task_id, e
            )
            return self._build_result(
                task,
                TaskStatus.FAILED,
                started_at,
                start_time,
                error_message=str(e),
            )

    def _build_result(
        self,
        task: TaskMessage,
        status,
        started_at: datetime,
        start_time: float,
        output_data=None,
        error_message=None,
    ) -> TaskResult:
        """Assemble the ``TaskResult`` shared by the success and failure paths.

        ``start_time`` is the ``time.monotonic()`` reading taken when the task
        began; the elapsed time is measured at the moment of this call.
        """
        elapsed = time.monotonic() - start_time
        return TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=status,
            output_data=output_data,
            error_message=error_message,
            started_at=started_at,
            completed_at=datetime.now(timezone.utc),
            metrics={
                "elapsed_seconds": round(elapsed, 2),
                "task_type": task.task_type,
            },
        )

    async def _execute_full_detect(self, task: TaskMessage) -> dict:
        """Run detection for every platform of the stored query.

        Reads ``query_id`` from ``task.input_data``, loads the ``Query`` and
        delegates the per-platform work to :meth:`_run_query_detection`,
        reporting progress along the way.

        Returns:
            A summary dict with the query id, keyword, record counts and one
            entry per created record.

        Raises:
            ValueError: if ``query_id`` is missing or no such query exists.
        """
        query_id = task.input_data.get("query_id")
        if not query_id:
            raise ValueError("input_data must contain 'query_id'")

        async with AsyncSessionLocal() as db:
            result = await db.execute(select(Query).where(Query.id == query_id))
            query = result.scalar_one_or_none()
            if not query:
                raise ValueError(f"Query {query_id} not found")

            await self.report_progress(
                task_id=task.task_id,
                progress=0.1,
                message=f"Starting citation detection for query '{query.keyword}'",
            )

            records = await self._run_query_detection(
                query, db, progress_task_id=task.task_id
            )

            await self.report_progress(
                task_id=task.task_id,
                progress=1.0,
                message=f"Detection completed: {len(records)} records found",
            )

            record_summaries = [
                {
                    "id": str(r.id),
                    "platform": r.platform,
                    "cited": r.cited,
                    "confidence": r.confidence,
                    "match_type": r.match_type,
                }
                for r in records
            ]

            return {
                "query_id": str(query_id),
                "keyword": query.keyword,
                "total_records": len(records),
                "cited_count": sum(1 for r in records if r.cited),
                "records": record_summaries,
            }

    async def _execute_single_detect(self, task: TaskMessage) -> dict:
        """Run one ad-hoc detection described entirely by ``task.input_data``.

        Returns:
            The detection result with the ``"raw_response"`` key removed.

        Raises:
            ValueError: if ``keyword``, ``platform`` or ``target_brand`` is
                missing or falsy.
        """
        keyword = task.input_data.get("keyword")
        platform = task.input_data.get("platform")
        target_brand = task.input_data.get("target_brand")
        # ``or []`` also covers an explicit null value, matching how the
        # stored-query path normalises its aliases.
        brand_aliases = task.input_data.get("brand_aliases") or []

        if not all([keyword, platform, target_brand]):
            raise ValueError(
                "input_data must contain 'keyword', 'platform', 'target_brand'"
            )

        await self.report_progress(
            task_id=task.task_id,
            progress=0.2,
            message=f"Querying platform '{platform}' with keyword '{keyword}'",
        )

        result = await execute_single_platform(
            keyword=keyword,
            platform=platform,
            target_brand=target_brand,
            brand_aliases=brand_aliases,
        )

        await self.report_progress(
            task_id=task.task_id,
            progress=1.0,
            message=f"Single platform detection completed on '{platform}'",
        )

        return {k: v for k, v in result.items() if k != "raw_response"}

    async def execute_query_compat(self, query: Query, db) -> list[CitationRecord]:
        """Run detection for ``query`` on the caller's session, without progress reports.

        Returns the ``CitationRecord`` objects created, one per platform.
        """
        return await self._run_query_detection(query, db)

    async def execute_single_platform_compat(
        self, keyword: str, platform: str, target_brand: str, brand_aliases: list
    ) -> dict:
        """Forward directly to ``execute_single_platform`` and return its result."""
        return await execute_single_platform(
            keyword=keyword,
            platform=platform,
            target_brand=target_brand,
            brand_aliases=brand_aliases,
        )

    async def _run_query_detection(
        self, query: Query, db, progress_task_id=None
    ) -> list[CitationRecord]:
        """Detect on every platform of ``query`` and update its schedule fields.

        Args:
            query: The query whose keyword, brand and platforms are used.
            db: Session used for all reads and commits.
            progress_task_id: When not ``None``, progress is reported under
                this task id before each platform, scaled from 0.1 up to 0.9.

        Returns:
            One ``CitationRecord`` per platform, in platform order.
        """
        platforms = query.platforms or list(self._DEFAULT_PLATFORMS)
        brand_aliases = query.brand_aliases or []
        total = len(platforms)

        records: list[CitationRecord] = []
        for index, platform_name in enumerate(platforms):
            if progress_task_id is not None:
                await self.report_progress(
                    task_id=progress_task_id,
                    progress=0.1 + 0.8 * (index / total),
                    message=f"Detecting on platform '{platform_name}' ({index + 1}/{total})",
                )
            record = await self._detect_on_platform(
                db, query, platform_name, brand_aliases
            )
            records.append(record)

        query.last_queried_at = self._utcnow_naive()
        query.next_query_at = self._calculate_next_query_at(query.frequency)
        await db.commit()

        return records

    async def _detect_on_platform(
        self, db, query: Query, platform_name: str, brand_aliases: list
    ) -> CitationRecord:
        """Run detection on one platform and persist the record and task status.

        The ``QueryTask`` row is marked ``"running"`` first, then ``"success"``
        or ``"failed"``. On failure a non-cited ``CitationRecord`` carrying the
        error text as ``raw_response`` is stored so every platform still
        yields exactly one record.
        """
        task_obj = await self._get_or_create_task(db, query.id, platform_name)
        task_obj.status = "running"
        task_obj.started_at = self._utcnow_naive()
        task_obj.error_message = None
        await db.commit()

        try:
            detect_result = await execute_single_platform(
                keyword=query.keyword,
                platform=platform_name,
                target_brand=query.target_brand,
                brand_aliases=brand_aliases,
            )
            record = CitationRecord.from_citation_result(
                query_id=query.id,
                platform=platform_name,
                result=detect_result,
            )
            db.add(record)

            task_obj.status = "success"
            task_obj.completed_at = self._utcnow_naive()
            await db.commit()
        except Exception as e:
            # NOTE(review): if the failure came from db.commit() itself the
            # session may need a rollback before the commit below can succeed
            # - confirm against the session configuration.
            logger.exception("平台 %s 查询失败: %s", platform_name, e)
            task_obj.status = "failed"
            task_obj.error_message = str(e)
            task_obj.completed_at = self._utcnow_naive()

            record = CitationRecord.from_citation_result(
                query_id=query.id,
                platform=platform_name,
                result={"cited": False, "raw_response": str(e)},
            )
            db.add(record)
            await db.commit()

        return record

    async def _get_or_create_task(self, db, query_id: uuid.UUID, platform: str) -> QueryTask:
        """Return the ``QueryTask`` for ``(query_id, platform)``, creating it if absent.

        A newly created task starts in ``"pending"`` status and is committed
        and refreshed before being returned.
        """
        stmt = select(QueryTask).where(
            QueryTask.query_id == query_id,
            QueryTask.platform == platform,
        )
        result = await db.execute(stmt)
        task_obj = result.scalar_one_or_none()

        if not task_obj:
            task_obj = QueryTask(
                query_id=query_id,
                platform=platform,
                status="pending",
            )
            db.add(task_obj)
            await db.commit()
            await db.refresh(task_obj)

        return task_obj

    @staticmethod
    def _utcnow_naive() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Produces the same kind of value as the deprecated
        ``datetime.utcnow()`` so stored timestamps keep their existing form.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _calculate_next_query_at(frequency: str | None) -> datetime:
        """Return the naive UTC time at which the query should next run.

        ``"daily"``, ``"weekly"`` and ``"monthly"`` map to 1, 7 and 30 days;
        ``None``, an empty string or any unknown value falls back to 7 days.
        """
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        freq_map = {
            "daily": timedelta(days=1),
            "weekly": timedelta(days=7),
            "monthly": timedelta(days=30),
        }
        delta = freq_map.get(frequency or "weekly", timedelta(days=7))
        return now + delta
|