feat: Phase1 Week3-4 - 引用模式识别+定时检测任务调度

后端(TDD):
- 引用模式识别引擎(4个分析器+报告生成)
  - ContentStructureAnalyzer: FAQ/列表/表格/引用块检测
  - AuthoritySignalAnalyzer: 数据引用/专家引用/认证标记
  - CitationFormatAnalyzer: 直接/间接/对比引用
  - EnginePreferenceAnalyzer: 引擎偏好分析
- 定时检测任务调度服务
  - DetectionTask模型(hourly/daily/weekly)
  - DetectionSchedulerService(CRUD+触发+执行)
  - 检测API端点(5个)
  - Schema定义
- 34+21=55个测试全部通过

前端:
- AI引擎分析页面(引用率/引擎结果/上下文详情)
This commit is contained in:
chiguyong 2026-05-25 11:00:50 +08:00
parent 1ec5ea42da
commit 9d67a801be
10 changed files with 2076 additions and 0 deletions

View File

@ -0,0 +1,152 @@
import logging
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import and_, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.database import get_db
from app.models.brand import Brand
from app.models.user import User
from app.schemas.detection_task import (
DetectionTaskCreate,
DetectionTaskListResponse,
DetectionTaskResponse,
DetectionTaskUpdate,
DetectionTriggerResponse,
)
from app.services.detection_scheduler import DetectionSchedulerService, TaskNotFoundError
logger = logging.getLogger(__name__)
router = APIRouter()
async def verify_brand_ownership(
brand_id: uuid.UUID,
current_user: User,
db: AsyncSession,
) -> Brand:
stmt = select(Brand).where(
and_(
Brand.id == brand_id,
Brand.user_id == current_user.id,
)
)
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
if not brand:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"品牌 {brand_id} 不存在或不属于当前用户",
)
return brand
@router.post("/tasks", response_model=DetectionTaskResponse, status_code=status.HTTP_201_CREATED)
async def create_detection_task(
data: DetectionTaskCreate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
await verify_brand_ownership(data.brand_id, current_user, db)
service = DetectionSchedulerService()
try:
task = await service.create_task(
task_data=data.model_dump(exclude={"brand_id"}),
brand_id=data.brand_id,
user_id=current_user.id,
db=db,
)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=str(e),
)
return task
@router.get("/tasks", response_model=DetectionTaskListResponse)
async def get_detection_tasks(
brand_id: uuid.UUID = Query(..., description="按品牌筛选"),
skip: int = Query(0, ge=0),
limit: int = Query(20, ge=1, le=100),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
await verify_brand_ownership(brand_id, current_user, db)
service = DetectionSchedulerService()
tasks = await service.get_tasks(brand_id, current_user.id, db)
return {"items": tasks[skip : skip + limit], "total": len(tasks)}
@router.put("/tasks/{task_id}", response_model=DetectionTaskResponse)
async def update_detection_task(
task_id: uuid.UUID,
data: DetectionTaskUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
service = DetectionSchedulerService()
try:
task = await service.update_task(
task_id=task_id,
task_data=data.model_dump(exclude_unset=True),
user_id=current_user.id,
db=db,
)
except TaskNotFoundError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="检测任务不存在",
)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=str(e),
)
return task
@router.delete("/tasks/{task_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_detection_task(
task_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
service = DetectionSchedulerService()
result = await service.delete_task(task_id, current_user.id, db)
if not result:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="检测任务不存在",
)
return None
@router.post("/tasks/{task_id}/trigger", response_model=DetectionTriggerResponse)
async def trigger_detection_task(
task_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
service = DetectionSchedulerService()
result = await service.trigger_task(task_id, current_user.id, db)
if result.get("status") == "error":
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=result.get("message", "检测任务不存在"),
)
return result

View File

@ -39,6 +39,7 @@ from app.api.platform_rules import router as platform_rules_router
from app.api.image import router as image_router
from app.api.knowledge_graph import router as knowledge_graph_router
from app.api.ai_engines import router as ai_engines_router
from app.api.detection import router as detection_router
from app.config import settings
from app.database import engine, Base
from app.schemas.common import ErrorResponse, ErrorCode
@ -167,6 +168,7 @@ app.include_router(platform_rules_router)
app.include_router(image_router, prefix="/api/v1")
app.include_router(knowledge_graph_router, prefix="/api/v1/knowledge-bases")
app.include_router(ai_engines_router, prefix="/api/v1/ai-engines", tags=["AI引擎查询"])
app.include_router(detection_router, prefix="/api/v1/detection", tags=["定时检测任务"])
@app.get("/health", tags=["可观测性"])

View File

@ -29,6 +29,7 @@ from app.models.competitor import Competitor
from app.models.suggestion import Suggestion
from app.models.alert import Alert
from app.models.alert_setting import AlertSetting
from app.models.detection_task import DetectionTask
__all__ = [
"User",
@ -68,4 +69,5 @@ __all__ = [
"Suggestion",
"Alert",
"AlertSetting",
"DetectionTask",
]

View File

@ -0,0 +1,65 @@
import uuid
from datetime import datetime, timedelta, timezone
from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String, Uuid, func
from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base, JSONType
VALID_FREQUENCIES = {"hourly", "daily", "weekly"}
FREQUENCY_DELTAS = {
"hourly": timedelta(hours=1),
"daily": timedelta(days=1),
"weekly": timedelta(weeks=1),
}
class DetectionTask(Base):
__tablename__ = "detection_tasks"
id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
)
brand_id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
ForeignKey("brands.id", ondelete="CASCADE"),
nullable=False,
)
user_id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
nullable=False,
index=True,
)
name: Mapped[str] = mapped_column(String(200), nullable=False)
frequency: Mapped[str] = mapped_column(String(20), nullable=False)
engines: Mapped[list] = mapped_column(JSONType, default=list, nullable=False)
queries: Mapped[list] = mapped_column(JSONType, default=list, nullable=False)
competitor_names: Mapped[list | None] = mapped_column(JSONType, nullable=True)
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
last_run_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
next_run_at: Mapped[datetime | None] = mapped_column(
DateTime, nullable=True, default=lambda: datetime.now(timezone.utc)
)
created_at: Mapped[datetime] = mapped_column(
server_default=func.now(),
nullable=False,
)
updated_at: Mapped[datetime] = mapped_column(
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
__table_args__ = (
Index("idx_detection_tasks_brand_id", "brand_id"),
Index("idx_detection_tasks_user_id", "user_id"),
Index("idx_detection_tasks_is_active", "is_active"),
)
def compute_next_run_at(self) -> datetime:
delta = FREQUENCY_DELTAS.get(self.frequency, timedelta(days=1))
base = self.last_run_at or datetime.now(timezone.utc)
return base + delta

View File

@ -0,0 +1,51 @@
import uuid
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, Field
class DetectionTaskCreate(BaseModel):
brand_id: uuid.UUID
name: str = Field(..., min_length=1, max_length=200)
frequency: str = Field(..., pattern="^(hourly|daily|weekly)$")
engines: list[str] = Field(default=[], min_length=1)
queries: list[str] = Field(default=[], min_length=1)
competitor_names: Optional[list[str]] = None
class DetectionTaskUpdate(BaseModel):
name: Optional[str] = Field(None, min_length=1, max_length=200)
frequency: Optional[str] = Field(None, pattern="^(hourly|daily|weekly)$")
engines: Optional[list[str]] = None
queries: Optional[list[str]] = None
competitor_names: Optional[list[str]] = None
is_active: Optional[bool] = None
class DetectionTaskResponse(BaseModel):
id: uuid.UUID
brand_id: uuid.UUID
user_id: uuid.UUID
name: str
frequency: str
engines: list[str]
queries: list[str]
competitor_names: Optional[list[str]] = None
is_active: bool
last_run_at: Optional[datetime] = None
next_run_at: Optional[datetime] = None
created_at: datetime
updated_at: datetime
model_config = {"from_attributes": True}
class DetectionTaskListResponse(BaseModel):
items: list[DetectionTaskResponse]
total: int
class DetectionTriggerResponse(BaseModel):
status: str
message: Optional[str] = None

View File

@ -0,0 +1,380 @@
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Any
from app.services.ai_engine.base import AIQueryResult
@dataclass
class CitationPattern:
pattern_type: str
pattern_name: str
frequency: float
confidence: float
description: str
details: dict[str, Any] = field(default_factory=dict)
@dataclass
class PatternAnalysisReport:
brand_id: str
query: str
total_results: int
patterns: list[CitationPattern]
content_structure_insights: dict[str, Any]
authority_signal_insights: dict[str, Any]
citation_format_insights: dict[str, Any]
engine_preferences: dict[str, Any]
recommendations: list[str] = field(default_factory=list)
_FAQ_PATTERNS = [
re.compile(r"Q:\s*.+?\s*A:\s*", re.DOTALL | re.IGNORECASE),
re.compile(r"问题[:]\s*.+?\s*回答[:]\s*", re.DOTALL),
re.compile(r"常见问题", re.IGNORECASE),
re.compile(r"FAQ", re.IGNORECASE),
]
_LIST_PATTERNS = [
re.compile(r"(?:^|\n)\s*\d+\.\s+", re.MULTILINE),
re.compile(r"(?:^|\n)\s*[-*]\s+", re.MULTILINE),
]
_TABLE_PATTERNS = [
re.compile(r"\|.+\|.+\|", re.MULTILINE),
re.compile(r"-{4,}\s+-{4,}", re.MULTILINE),
]
_QUOTE_PATTERNS = [
re.compile(r"^\s*>\s+", re.MULTILINE),
re.compile(r"[\"\u201c].+?[\"\u201d]"),
]
_DATA_CITATION_PATTERNS = [
re.compile(r"\d+%", re.MULTILINE),
re.compile(r"(?:study|research|survey|report|data|statistics)", re.IGNORECASE),
re.compile(r"(?:according to|based on|shown by|indicates?|reveals?)", re.IGNORECASE),
re.compile(r"\b(?:19|20)\d{2}\b"),
]
_EXPERT_CITATION_PATTERNS = [
re.compile(r"(?:Dr\.|Professor|Prof\.)\s+\w+", re.IGNORECASE),
re.compile(r"(?:expert|analyst|researcher|scientist)\s+\w+", re.IGNORECASE),
re.compile(r"(?:from|at)\s+(?:Harvard|Stanford|MIT|Oxford|Cambridge|Yale|Princeton)", re.IGNORECASE),
]
_CERTIFICATION_PATTERNS = [
re.compile(r"(?:ISO\s*\d+|FDA\s+approved|CE\s+certif|UL\s+listed|SOC\s*2|HIPAA|GDPR)", re.IGNORECASE),
re.compile(r"(?:certified|certification|compliant|accredited)", re.IGNORECASE),
]
_COMPARISON_PATTERNS = [
re.compile(r"(?:compared?\s+to|vs\.?|versus|while\s+.+\s*,\s*.+)", re.IGNORECASE),
re.compile(r"(?:优于|相比|对比|不如|胜过)"),
]
_DIRECT_CITATION_PATTERNS = [
re.compile(r"(?:states?|claims?|says?|mentions?|notes?|reports?|announces?)", re.IGNORECASE),
re.compile(r"(?:according to|as stated by|as reported by)", re.IGNORECASE),
]
_CONTENT_RULES: list[tuple[str, list[re.Pattern[str]], float, str]] = [
("faq_format", _FAQ_PATTERNS, 0.8, "FAQ format detected in AI responses"),
("list_format", _LIST_PATTERNS, 0.8, "List format detected in AI responses"),
("table_format", _TABLE_PATTERNS, 0.8, "Table format detected in AI responses"),
("quote_block", _QUOTE_PATTERNS, 0.7, "Quote block detected in AI responses"),
]
_AUTHORITY_RULES: list[tuple[str, list[re.Pattern[str]], float, str, int]] = [
("data_citation", _DATA_CITATION_PATTERNS, 0.85, "Data citation signals detected in AI responses", 2),
("expert_citation", _EXPERT_CITATION_PATTERNS, 0.8, "Expert citation signals detected in AI responses", 1),
("certification_mark", _CERTIFICATION_PATTERNS, 0.9, "Certification marks detected in AI responses", 1),
]
_RECOMMENDATION_RULES: list[tuple[str, str, float, str]] = [
("content_structure", "faq_format", 0.3, "Consider adding FAQ sections to improve AI citation probability"),
("content_structure", "list_format", 0.3, "Use structured lists to make content more extractable by AI engines"),
("authority_signal", "data_citation", 0.3, "Include data citations and statistics to increase authority signals"),
("authority_signal", "expert_citation", 0.3, "Add expert quotes and references to strengthen E-E-A-T signals"),
("citation_format", "direct_citation", 0.2, "Optimize content for direct citation by AI engines"),
]
def _matches_any(text: str, patterns: list[re.Pattern[str]]) -> bool:
return any(p.search(text) for p in patterns)
def _count_matches(text: str, patterns: list[re.Pattern[str]]) -> int:
return sum(1 for p in patterns if p.search(text))
def _build_type_insights(
patterns: list[CitationPattern],
pattern_type: str,
top_key: str,
) -> dict[str, Any]:
filtered = [p for p in patterns if p.pattern_type == pattern_type]
insights: dict[str, Any] = {f"{p.pattern_name}_frequency": p.frequency for p in filtered}
if filtered:
best = max(filtered, key=lambda p: p.frequency)
if best.frequency > 0:
insights[top_key] = best.pattern_name
return insights
class ContentStructureAnalyzer:
def analyze(self, results: list[AIQueryResult]) -> list[CitationPattern]:
if not results:
return self._empty_patterns()
counts: dict[str, int] = {name: 0 for name, _, _, _ in _CONTENT_RULES}
for r in results:
for name, patterns, _, _ in _CONTENT_RULES:
if _matches_any(r.raw_response, patterns):
counts[name] += 1
total = len(results)
return [
CitationPattern(
pattern_type="content_structure",
pattern_name=name,
frequency=counts[name] / total,
confidence=conf if counts[name] > 0 else 0.0,
description=desc,
details={"count": counts[name], "total": total},
)
for name, _, conf, desc in _CONTENT_RULES
]
def _empty_patterns(self) -> list[CitationPattern]:
return [
CitationPattern(
pattern_type="content_structure",
pattern_name=name,
frequency=0.0,
confidence=0.0,
description="",
details={},
)
for name, _, _, _ in _CONTENT_RULES
]
class AuthoritySignalAnalyzer:
def analyze(self, results: list[AIQueryResult]) -> list[CitationPattern]:
if not results:
return self._empty_patterns()
counts: dict[str, int] = {name: 0 for name, _, _, _, _ in _AUTHORITY_RULES}
extra: dict[str, int] = {name: 0 for name, _, _, _, _ in _AUTHORITY_RULES}
for r in results:
for name, patterns, _, _, threshold in _AUTHORITY_RULES:
match_count = _count_matches(r.raw_response, patterns)
if match_count >= threshold:
counts[name] += 1
extra[name] += match_count
total = len(results)
return [
CitationPattern(
pattern_type="authority_signal",
pattern_name=name,
frequency=counts[name] / total,
confidence=conf if counts[name] > 0 else 0.0,
description=desc,
details={"count": counts[name], "total": total, "match_count": extra[name]},
)
for name, _, conf, desc, _ in _AUTHORITY_RULES
]
def _empty_patterns(self) -> list[CitationPattern]:
return [
CitationPattern(
pattern_type="authority_signal",
pattern_name=name,
frequency=0.0,
confidence=0.0,
description="",
details={},
)
for name, _, _, _, _ in _AUTHORITY_RULES
]
class CitationFormatAnalyzer:
def analyze(self, results: list[AIQueryResult]) -> list[CitationPattern]:
if not results:
return self._empty_patterns()
direct_count = 0
indirect_count = 0
comparison_count = 0
for r in results:
if r.has_brand_citation:
if _matches_any(r.raw_response, _COMPARISON_PATTERNS) and r.has_competitor_citation:
comparison_count += 1
elif r.brand_context and _matches_any(r.brand_context, _DIRECT_CITATION_PATTERNS):
direct_count += 1
else:
indirect_count += 1
total = len(results)
return [
CitationPattern(
pattern_type="citation_format",
pattern_name="direct_citation",
frequency=direct_count / total,
confidence=0.9 if direct_count > 0 else 0.0,
description="Direct citation format detected",
details={"count": direct_count, "total": total},
),
CitationPattern(
pattern_type="citation_format",
pattern_name="indirect_citation",
frequency=indirect_count / total,
confidence=0.7 if indirect_count > 0 else 0.0,
description="Indirect citation format detected",
details={"count": indirect_count, "total": total},
),
CitationPattern(
pattern_type="citation_format",
pattern_name="comparison_citation",
frequency=comparison_count / total,
confidence=0.85 if comparison_count > 0 else 0.0,
description="Comparison citation format detected",
details={"count": comparison_count, "total": total},
),
]
def _empty_patterns(self) -> list[CitationPattern]:
return [
CitationPattern(
pattern_type="citation_format",
pattern_name=name,
frequency=0.0,
confidence=0.0,
description="",
details={},
)
for name in ("direct_citation", "indirect_citation", "comparison_citation")
]
class EnginePreferenceAnalyzer:
def analyze(self, results: list[AIQueryResult]) -> dict[str, dict[str, Any]]:
if not results:
return {}
engine_data: dict[str, dict[str, Any]] = {}
for r in results:
engine_name = r.engine_type.value
if engine_name not in engine_data:
engine_data[engine_name] = {
"results": [],
"citation_positions": [],
"format_hits": {"faq": 0, "list": 0, "table": 0},
}
entry = engine_data[engine_name]
entry["results"].append(r)
if r.has_brand_citation:
for c in r.citations:
entry["citation_positions"].append(c.position)
if _matches_any(r.raw_response, _FAQ_PATTERNS):
entry["format_hits"]["faq"] += 1
if _matches_any(r.raw_response, _LIST_PATTERNS):
entry["format_hits"]["list"] += 1
if _matches_any(r.raw_response, _TABLE_PATTERNS):
entry["format_hits"]["table"] += 1
prefs: dict[str, dict[str, Any]] = {}
for engine_name, data in engine_data.items():
total = len(data["results"])
cited = sum(1 for r in data["results"] if r.has_brand_citation)
positions = data["citation_positions"]
avg_pos = sum(positions) / len(positions) if positions else -1
prefs[engine_name] = {
"citation_rate": cited / total if total > 0 else 0.0,
"avg_citation_position": avg_pos,
"format_preferences": {
fmt: count / total if total > 0 else 0.0
for fmt, count in data["format_hits"].items()
},
}
return prefs
class CitationPatternEngine:
def __init__(self) -> None:
self.content_analyzer = ContentStructureAnalyzer()
self.authority_analyzer = AuthoritySignalAnalyzer()
self.format_analyzer = CitationFormatAnalyzer()
self.engine_analyzer = EnginePreferenceAnalyzer()
def analyze(
self,
results: list[AIQueryResult],
brand_id: str,
query: str,
) -> PatternAnalysisReport:
if not results:
return PatternAnalysisReport(
brand_id=brand_id,
query=query,
total_results=0,
patterns=[],
content_structure_insights={},
authority_signal_insights={},
citation_format_insights={},
engine_preferences={},
recommendations=[],
)
patterns: list[CitationPattern] = []
patterns.extend(self.content_analyzer.analyze(results))
patterns.extend(self.authority_analyzer.analyze(results))
patterns.extend(self.format_analyzer.analyze(results))
engine_prefs = self.engine_analyzer.analyze(results)
recommendations = self._generate_recommendations(patterns, engine_prefs)
return PatternAnalysisReport(
brand_id=brand_id,
query=query,
total_results=len(results),
patterns=patterns,
content_structure_insights=_build_type_insights(patterns, "content_structure", "dominant_format"),
authority_signal_insights=_build_type_insights(patterns, "authority_signal", "strongest_signal"),
citation_format_insights=_build_type_insights(patterns, "citation_format", "primary_format"),
engine_preferences=engine_prefs,
recommendations=recommendations,
)
def _generate_recommendations(
self,
patterns: list[CitationPattern],
engine_prefs: dict[str, dict],
) -> list[str]:
recommendations: list[str] = []
pattern_map = {(p.pattern_type, p.pattern_name): p for p in patterns}
for ptype, pname, threshold, message in _RECOMMENDATION_RULES:
p = pattern_map.get((ptype, pname))
if p and p.frequency < threshold:
recommendations.append(message)
for engine_name, prefs in engine_prefs.items():
if prefs["citation_rate"] < 0.3:
recommendations.append(
f"Low citation rate on {engine_name} ({prefs['citation_rate']:.0%}), "
f"consider optimizing content for this engine"
)
return recommendations

View File

@ -0,0 +1,227 @@
import logging
import uuid
from datetime import datetime, timezone
from sqlalchemy import and_, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.ai_engines import _build_adapters
from app.models.brand import Brand
from app.models.detection_task import VALID_FREQUENCIES, DetectionTask
from app.services.ai_engine.base import EngineType
from app.services.ai_engine.batch_query import BatchQueryService
logger = logging.getLogger(__name__)
class TaskNotFoundError(Exception):
pass
class DetectionSchedulerService:
async def create_task(
self,
task_data: dict,
brand_id: uuid.UUID,
user_id: uuid.UUID,
db: AsyncSession,
) -> DetectionTask:
frequency = task_data.get("frequency", "daily")
self._validate_frequency(frequency)
task = DetectionTask(
brand_id=brand_id,
user_id=user_id,
name=task_data["name"],
frequency=frequency,
engines=task_data.get("engines", []),
queries=task_data.get("queries", []),
competitor_names=task_data.get("competitor_names"),
)
task.next_run_at = task.compute_next_run_at()
db.add(task)
await db.commit()
await db.refresh(task)
logger.info(f"Created detection task: {task.id}, name={task.name}")
return task
async def update_task(
self,
task_id: uuid.UUID,
task_data: dict,
user_id: uuid.UUID,
db: AsyncSession,
) -> DetectionTask:
task = await self._get_task_by_id(task_id, user_id, db)
if not task:
raise TaskNotFoundError(f"Detection task {task_id} not found")
frequency = task_data.get("frequency", task.frequency)
self._validate_frequency(frequency)
for field, value in task_data.items():
setattr(task, field, value)
task.next_run_at = task.compute_next_run_at()
await db.commit()
await db.refresh(task)
logger.info(f"Updated detection task: {task.id}")
return task
async def delete_task(
self,
task_id: uuid.UUID,
user_id: uuid.UUID,
db: AsyncSession,
) -> bool:
task = await self._get_task_by_id(task_id, user_id, db)
if not task:
return False
await db.delete(task)
await db.commit()
logger.info(f"Deleted detection task: {task_id}")
return True
async def get_tasks(
self,
brand_id: uuid.UUID,
user_id: uuid.UUID,
db: AsyncSession,
) -> list[DetectionTask]:
stmt = (
select(DetectionTask)
.where(
and_(
DetectionTask.brand_id == brand_id,
DetectionTask.user_id == user_id,
)
)
.order_by(DetectionTask.created_at.desc())
)
result = await db.execute(stmt)
return list(result.scalars().all())
async def trigger_task(
self,
task_id: uuid.UUID,
user_id: uuid.UUID,
db: AsyncSession,
) -> dict:
task = await self._get_task_by_id(task_id, user_id, db)
if not task:
return {"status": "error", "message": "Task not found"}
return await self.execute_task(task, db)
async def execute_task(
self,
task: DetectionTask,
db: AsyncSession,
) -> dict:
try:
results = await self._run_batch_query(task, db)
await self._generate_alerts_if_needed(task, results, db)
task.last_run_at = datetime.now(timezone.utc)
task.next_run_at = task.compute_next_run_at()
await db.commit()
await db.refresh(task)
logger.info(f"Detection task executed: {task.id}")
return {"status": "success", "results": results}
except Exception as e:
logger.error(f"Detection task execution failed: {task.id}, error={e}")
return {"status": "error", "message": str(e)}
async def _run_batch_query(
self,
task: DetectionTask,
db: AsyncSession,
) -> list:
adapters = _build_adapters()
batch_service = BatchQueryService(adapters)
all_results = []
brand = await self._get_brand(task.brand_id, db)
brand_name = brand.name if brand else "Unknown"
for query_text in task.queries:
engines = [
EngineType(e) for e in task.engines if e in EngineType._value2member_map_
]
results = await batch_service.query_batch(
engines=engines,
query=query_text,
brand_name=brand_name,
competitor_names=task.competitor_names,
)
all_results.extend(results)
return all_results
async def _generate_alerts_if_needed(
self,
task: DetectionTask,
results: list,
db: AsyncSession,
) -> list:
if not results:
return []
brand = await self._get_brand(task.brand_id, db)
if not brand:
return []
brand_cited = sum(1 for r in results if r.has_brand_citation)
competitor_cited = sum(1 for r in results if r.has_competitor_citation)
alerts = []
if competitor_cited > 0 and brand_cited == 0:
from app.services.alert_engine import AlertEngine
alert_engine = AlertEngine(db)
alert = await alert_engine._create_alert(
brand_id=task.brand_id,
user_id=task.user_id,
alert_type="competitor_overtake",
severity="warning",
title=f"{brand.name} 未被引用但竞品被引用",
message=f"检测任务「{task.name}」执行后发现品牌未被AI引擎引用但竞品被引用了 {competitor_cited} 次。",
data={"task_id": str(task.id), "competitor_cited": competitor_cited},
)
if alert:
alerts.append(alert)
return alerts
@staticmethod
def _validate_frequency(frequency: str) -> None:
if frequency not in VALID_FREQUENCIES:
raise ValueError(
f"Invalid frequency: {frequency}. Must be one of {VALID_FREQUENCIES}"
)
async def _get_task_by_id(
self,
task_id: uuid.UUID,
user_id: uuid.UUID,
db: AsyncSession,
) -> DetectionTask | None:
stmt = select(DetectionTask).where(
and_(
DetectionTask.id == task_id,
DetectionTask.user_id == user_id,
)
)
result = await db.execute(stmt)
return result.scalar_one_or_none()
async def _get_brand(
self,
brand_id: uuid.UUID,
db: AsyncSession,
) -> Brand | None:
stmt = select(Brand).where(Brand.id == brand_id)
result = await db.execute(stmt)
return result.scalar_one_or_none()

View File

@ -0,0 +1,202 @@
import uuid
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from app.api.deps import get_current_user, get_db
from app.database import Base
from app.main import app
from app.models.brand import Brand
from app.models.user import User
from app.services.auth import hash_password
@pytest_asyncio.fixture
async def async_engine():
engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest_asyncio.fixture
async def async_session(async_engine):
session_maker = async_sessionmaker(
async_engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
autocommit=False,
)
async with session_maker() as session:
yield session
@pytest_asyncio.fixture
async def test_user(async_session):
user = User(
id=uuid.uuid4(),
email="test@example.com",
password_hash=hash_password("Test@123456"),
name="Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
)
async_session.add(user)
await async_session.commit()
await async_session.refresh(user)
return user
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
brand = Brand(
id=uuid.uuid4(),
user_id=test_user.id,
name="Test Brand",
aliases=["TestBrand", "TB"],
website="https://testbrand.com",
industry="technology",
platforms=["wenxin", "kimi"],
frequency="weekly",
status="active",
)
async_session.add(brand)
await async_session.commit()
await async_session.refresh(brand)
return brand
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
async def override_get_db():
yield async_session
async def override_get_current_user():
return test_user
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_get_current_user
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
yield client
app.dependency_overrides.clear()
class TestDetectionTaskAPI:
@pytest.mark.asyncio
async def test_create_detection_task(self, async_client, test_brand):
task_data = {
"brand_id": str(test_brand.id),
"name": "每日品牌检测",
"frequency": "daily",
"engines": ["chatgpt", "perplexity"],
"queries": ["最佳保险品牌", "保险推荐"],
"competitor_names": ["竞品A"],
}
response = await async_client.post("/api/v1/detection/tasks", json=task_data)
assert response.status_code == 201
data = response.json()
assert data["name"] == "每日品牌检测"
assert data["frequency"] == "daily"
assert data["engines"] == ["chatgpt", "perplexity"]
assert data["queries"] == ["最佳保险品牌", "保险推荐"]
assert data["is_active"] is True
assert "id" in data
@pytest.mark.asyncio
async def test_get_detection_tasks(self, async_client, test_brand):
response = await async_client.get(
"/api/v1/detection/tasks", params={"brand_id": str(test_brand.id)}
)
assert response.status_code == 200
data = response.json()
assert "items" in data
assert "total" in data
@pytest.mark.asyncio
async def test_update_detection_task(self, async_client, test_brand):
create_data = {
"brand_id": str(test_brand.id),
"name": "原始任务",
"frequency": "weekly",
"engines": ["chatgpt"],
"queries": ["查询1"],
}
create_resp = await async_client.post("/api/v1/detection/tasks", json=create_data)
task_id = create_resp.json()["id"]
update_data = {
"name": "更新后任务",
"frequency": "daily",
"engines": ["chatgpt", "perplexity"],
}
response = await async_client.put(f"/api/v1/detection/tasks/{task_id}", json=update_data)
assert response.status_code == 200
data = response.json()
assert data["name"] == "更新后任务"
assert data["frequency"] == "daily"
@pytest.mark.asyncio
async def test_delete_detection_task(self, async_client, test_brand):
create_data = {
"brand_id": str(test_brand.id),
"name": "待删除任务",
"frequency": "daily",
"engines": ["chatgpt"],
"queries": ["查询1"],
}
create_resp = await async_client.post("/api/v1/detection/tasks", json=create_data)
task_id = create_resp.json()["id"]
response = await async_client.delete(f"/api/v1/detection/tasks/{task_id}")
assert response.status_code == 204
@pytest.mark.asyncio
async def test_trigger_detection_task(self, async_client, test_brand):
create_data = {
"brand_id": str(test_brand.id),
"name": "手动触发任务",
"frequency": "daily",
"engines": ["chatgpt"],
"queries": ["查询1"],
}
create_resp = await async_client.post("/api/v1/detection/tasks", json=create_data)
task_id = create_resp.json()["id"]
response = await async_client.post(f"/api/v1/detection/tasks/{task_id}/trigger")
assert response.status_code == 200
data = response.json()
assert data["status"] == "success"
@pytest.mark.asyncio
async def test_unauthorized_access(self, async_session):
async def override_get_db():
yield async_session
app.dependency_overrides[get_db] = override_get_db
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
headers = {"Authorization": "Bearer invalid_token"}
response = await client.get("/api/v1/detection/tasks", headers=headers)
assert response.status_code == 401
app.dependency_overrides.clear()

View File

@ -0,0 +1,619 @@
from datetime import UTC, datetime
import pytest
from app.services.ai_engine.base import AIQueryResult, CitationInfo, EngineType
from app.services.citation_pattern import (
AuthoritySignalAnalyzer,
CitationFormatAnalyzer,
CitationPattern,
CitationPatternEngine,
ContentStructureAnalyzer,
EnginePreferenceAnalyzer,
PatternAnalysisReport,
)
def _make_result(
engine_type: EngineType = EngineType.CHATGPT,
raw_response: str = "test response",
citations: list[CitationInfo] | None = None,
has_brand_citation: bool = False,
has_competitor_citation: bool = False,
brand_context: str | None = None,
competitor_contexts: list[str] | None = None,
metadata: dict | None = None,
) -> AIQueryResult:
return AIQueryResult(
engine_type=engine_type,
query="test query",
raw_response=raw_response,
citations=citations or [],
has_brand_citation=has_brand_citation,
has_competitor_citation=has_competitor_citation,
brand_context=brand_context,
competitor_contexts=competitor_contexts or [],
response_time_ms=1000,
timestamp=datetime(2025, 1, 1, tzinfo=UTC),
metadata=metadata or {},
)
class TestCitationPatternDataStructure:
def test_create_citation_pattern(self):
pattern = CitationPattern(
pattern_type="content_structure",
pattern_name="faq_format",
frequency=0.6,
confidence=0.8,
description="FAQ format detected in responses",
details={"count": 3, "total": 5},
)
assert pattern.pattern_type == "content_structure"
assert pattern.pattern_name == "faq_format"
assert pattern.frequency == 0.6
assert pattern.confidence == 0.8
assert pattern.details["count"] == 3
def test_frequency_range(self):
pattern = CitationPattern(
pattern_type="authority_signal",
pattern_name="data_citation",
frequency=0.0,
confidence=0.5,
description="",
details={},
)
assert 0.0 <= pattern.frequency <= 1.0
def test_confidence_range(self):
pattern = CitationPattern(
pattern_type="citation_format",
pattern_name="direct_citation",
frequency=0.5,
confidence=1.0,
description="",
details={},
)
assert 0.0 <= pattern.confidence <= 1.0
def test_pattern_types(self):
valid_types = {"content_structure", "authority_signal", "citation_format", "engine_preference"}
for pt in valid_types:
pattern = CitationPattern(
pattern_type=pt,
pattern_name="test",
frequency=0.5,
confidence=0.5,
description="",
details={},
)
assert pattern.pattern_type in valid_types
class TestPatternAnalysisReportDataStructure:
def test_create_report(self):
report = PatternAnalysisReport(
brand_id="brand-123",
query="test query",
total_results=5,
patterns=[],
content_structure_insights={},
authority_signal_insights={},
citation_format_insights={},
engine_preferences={},
recommendations=[],
)
assert report.brand_id == "brand-123"
assert report.total_results == 5
assert report.patterns == []
assert report.recommendations == []
def test_report_with_patterns(self):
patterns = [
CitationPattern(
pattern_type="content_structure",
pattern_name="faq_format",
frequency=0.7,
confidence=0.9,
description="FAQ detected",
details={"count": 7},
)
]
report = PatternAnalysisReport(
brand_id="b1",
query="q",
total_results=10,
patterns=patterns,
content_structure_insights={"faq_frequency": 0.7},
authority_signal_insights={},
citation_format_insights={},
engine_preferences={},
recommendations=["Add FAQ sections"],
)
assert len(report.patterns) == 1
assert report.content_structure_insights["faq_frequency"] == 0.7
assert len(report.recommendations) == 1
class TestContentStructureAnalyzer:
@pytest.fixture
def analyzer(self):
return ContentStructureAnalyzer()
def test_faq_format_detection(self, analyzer):
faq_response = """
Q: What is GEO?
A: GEO stands for Generative Engine Optimization.
Q: How does GEO work?
A: GEO works by optimizing content for AI engines.
常见问题:
问题: 什么是SEO?
回答: SEO是搜索引擎优化
"""
results = [_make_result(raw_response=faq_response)]
patterns = analyzer.analyze(results)
faq_pattern = [p for p in patterns if p.pattern_name == "faq_format"]
assert len(faq_pattern) == 1
assert faq_pattern[0].frequency > 0
assert faq_pattern[0].pattern_type == "content_structure"
def test_list_format_detection(self, analyzer):
list_response = """
Here are the top features:
1. Fast performance
2. Easy to use
3. Affordable price
- Benefit A
- Benefit B
- Benefit C
"""
results = [_make_result(raw_response=list_response)]
patterns = analyzer.analyze(results)
list_pattern = [p for p in patterns if p.pattern_name == "list_format"]
assert len(list_pattern) == 1
assert list_pattern[0].frequency > 0
def test_table_format_detection(self, analyzer):
table_response = """
| Feature | Plan A | Plan B |
|---------|--------|--------|
| Price | $10 | $20 |
Comparison table:
Item Value
---- -----
A 100
B 200
"""
results = [_make_result(raw_response=table_response)]
patterns = analyzer.analyze(results)
table_pattern = [p for p in patterns if p.pattern_name == "table_format"]
assert len(table_pattern) == 1
assert table_pattern[0].frequency > 0
def test_quote_block_detection(self, analyzer):
quote_response = """
According to the expert:
> "This is the best solution on the market."
As stated in the report:
"The company leads in innovation."
"""
results = [_make_result(raw_response=quote_response)]
patterns = analyzer.analyze(results)
quote_pattern = [p for p in patterns if p.pattern_name == "quote_block"]
assert len(quote_pattern) == 1
assert quote_pattern[0].frequency > 0
def test_no_structure_detected(self, analyzer):
plain_response = "This is a plain text response without any special formatting."
results = [_make_result(raw_response=plain_response)]
patterns = analyzer.analyze(results)
for p in patterns:
assert p.frequency == 0.0
def test_multiple_results_aggregation(self, analyzer):
results = [
_make_result(raw_response="Q: What? A: Something."),
_make_result(raw_response="Plain text without structure."),
_make_result(raw_response="1. Item one\n2. Item two"),
]
patterns = analyzer.analyze(results)
faq_pattern = [p for p in patterns if p.pattern_name == "faq_format"][0]
list_pattern = [p for p in patterns if p.pattern_name == "list_format"][0]
assert faq_pattern.frequency == pytest.approx(1 / 3, rel=0.01)
assert list_pattern.frequency == pytest.approx(1 / 3, rel=0.01)
class TestAuthoritySignalAnalyzer:
@pytest.fixture
def analyzer(self):
return AuthoritySignalAnalyzer()
def test_data_citation_detection(self, analyzer):
data_response = """
According to a 2024 study by MIT, 78% of companies adopted AI.
Research from Stanford University shows that productivity increased by 45%.
Data from the World Health Organization indicates a 30% reduction.
Statistics show that 90% of users prefer this approach.
"""
results = [_make_result(raw_response=data_response)]
patterns = analyzer.analyze(results)
data_pattern = [p for p in patterns if p.pattern_name == "data_citation"]
assert len(data_pattern) == 1
assert data_pattern[0].frequency > 0
assert data_pattern[0].details["match_count"] > 0
def test_expert_citation_detection(self, analyzer):
expert_response = """
Dr. Smith from Harvard notes that this trend will continue.
Professor Johnson at Stanford recommends this approach.
Expert analyst Jane Doe suggests using this method.
According to industry expert John, the market will grow.
"""
results = [_make_result(raw_response=expert_response)]
patterns = analyzer.analyze(results)
expert_pattern = [p for p in patterns if p.pattern_name == "expert_citation"]
assert len(expert_pattern) == 1
assert expert_pattern[0].frequency > 0
def test_certification_mark_detection(self, analyzer):
cert_response = """
The product is ISO 9001 certified and FDA approved.
It has received CE certification and UL listed status.
SOC 2 Type II compliant and HIPAA compliant.
"""
results = [_make_result(raw_response=cert_response)]
patterns = analyzer.analyze(results)
cert_pattern = [p for p in patterns if p.pattern_name == "certification_mark"]
assert len(cert_pattern) == 1
assert cert_pattern[0].frequency > 0
def test_no_authority_signals(self, analyzer):
plain_response = "This is a basic response without any authority signals."
results = [_make_result(raw_response=plain_response)]
patterns = analyzer.analyze(results)
for p in patterns:
assert p.frequency == 0.0
def test_multiple_results_aggregation(self, analyzer):
results = [
_make_result(raw_response="According to a 2024 study, 80% agree."),
_make_result(raw_response="No authority signals here."),
_make_result(raw_response="Dr. Lee from Oxford confirms the findings."),
]
patterns = analyzer.analyze(results)
data_pattern = [p for p in patterns if p.pattern_name == "data_citation"][0]
expert_pattern = [p for p in patterns if p.pattern_name == "expert_citation"][0]
assert data_pattern.frequency == pytest.approx(1 / 3, rel=0.01)
assert expert_pattern.frequency == pytest.approx(1 / 3, rel=0.01)
class TestCitationFormatAnalyzer:
@pytest.fixture
def analyzer(self):
return CitationFormatAnalyzer()
def test_direct_citation_detection(self, analyzer):
direct_response = """
According to BrandX, their product is the best in class.
BrandX states that they have over 1 million users.
BrandX claims to be the industry leader.
"""
results = [
_make_result(
raw_response=direct_response,
has_brand_citation=True,
brand_context="BrandX states that they have over 1 million users",
)
]
patterns = analyzer.analyze(results)
direct_pattern = [p for p in patterns if p.pattern_name == "direct_citation"]
assert len(direct_pattern) == 1
assert direct_pattern[0].frequency > 0
def test_indirect_citation_detection(self, analyzer):
indirect_response = """
Some leading solutions in this space include comprehensive platforms
that offer multiple features. One such platform provides AI-powered
analytics and real-time monitoring capabilities.
"""
results = [
_make_result(
raw_response=indirect_response,
has_brand_citation=True,
brand_context="One such platform provides AI-powered analytics",
)
]
patterns = analyzer.analyze(results)
indirect_pattern = [p for p in patterns if p.pattern_name == "indirect_citation"]
assert len(indirect_pattern) == 1
def test_comparison_citation_detection(self, analyzer):
comparison_response = """
Compared to BrandX, CompetitorY offers better pricing.
While BrandX focuses on enterprise, CompetitorY targets SMBs.
BrandX vs CompetitorY: BrandX has more features but CompetitorY is cheaper.
"""
results = [
_make_result(
raw_response=comparison_response,
has_brand_citation=True,
has_competitor_citation=True,
brand_context="Compared to BrandX",
competitor_contexts=["CompetitorY offers better pricing"],
)
]
patterns = analyzer.analyze(results)
comparison_pattern = [p for p in patterns if p.pattern_name == "comparison_citation"]
assert len(comparison_pattern) == 1
assert comparison_pattern[0].frequency > 0
def test_no_citation_format(self, analyzer):
results = [
_make_result(
raw_response="Generic response without citations.",
has_brand_citation=False,
)
]
patterns = analyzer.analyze(results)
for p in patterns:
assert p.frequency == 0.0
def test_citation_with_position_info(self, analyzer):
results = [
_make_result(
raw_response="BrandX is mentioned here.",
has_brand_citation=True,
brand_context="BrandX is mentioned here",
citations=[
CitationInfo(
source_url="https://example.com",
source_title="Example",
citation_context="BrandX is mentioned here",
confidence=0.9,
position=0,
)
],
)
]
patterns = analyzer.analyze(results)
direct_pattern = [p for p in patterns if p.pattern_name == "direct_citation"][0]
assert direct_pattern.frequency > 0
class TestEnginePreferenceAnalyzer:
@pytest.fixture
def analyzer(self):
return EnginePreferenceAnalyzer()
def test_single_engine_citation_rate(self, analyzer):
results = [
_make_result(engine_type=EngineType.CHATGPT, has_brand_citation=True),
_make_result(engine_type=EngineType.CHATGPT, has_brand_citation=False),
]
prefs = analyzer.analyze(results)
assert "chatgpt" in prefs
assert prefs["chatgpt"]["citation_rate"] == 0.5
def test_multi_engine_preferences(self, analyzer):
results = [
_make_result(engine_type=EngineType.CHATGPT, has_brand_citation=True),
_make_result(engine_type=EngineType.PERPLEXITY, has_brand_citation=True),
_make_result(engine_type=EngineType.KIMI, has_brand_citation=False),
_make_result(engine_type=EngineType.DEEPSEEK, has_brand_citation=True),
]
prefs = analyzer.analyze(results)
assert len(prefs) == 4
for engine_name, engine_data in prefs.items():
assert "citation_rate" in engine_data
assert "avg_citation_position" in engine_data
assert "format_preferences" in engine_data
def test_citation_position_preference(self, analyzer):
results = [
_make_result(
engine_type=EngineType.CHATGPT,
has_brand_citation=True,
citations=[
CitationInfo(
source_url="https://example.com",
source_title="Example",
citation_context="test",
confidence=0.9,
position=0,
)
],
),
_make_result(
engine_type=EngineType.PERPLEXITY,
has_brand_citation=True,
citations=[
CitationInfo(
source_url="https://example.com",
source_title="Example",
citation_context="test",
confidence=0.9,
position=5,
)
],
),
]
prefs = analyzer.analyze(results)
assert prefs["chatgpt"]["avg_citation_position"] < prefs["perplexity"]["avg_citation_position"]
def test_format_preferences(self, analyzer):
results = [
_make_result(
engine_type=EngineType.CHATGPT,
raw_response="Q: What? A: Something.\n1. First item\n2. Second item",
),
_make_result(
engine_type=EngineType.PERPLEXITY,
raw_response="Plain text without structure.",
),
]
prefs = analyzer.analyze(results)
assert "faq" in prefs["chatgpt"]["format_preferences"]
assert "list" in prefs["chatgpt"]["format_preferences"]
def test_empty_results(self, analyzer):
prefs = analyzer.analyze([])
assert prefs == {}
class TestCitationPatternEngine:
@pytest.fixture
def engine(self):
return CitationPatternEngine()
def test_full_analysis_flow(self, engine):
results = [
_make_result(
engine_type=EngineType.CHATGPT,
raw_response="Q: What is GEO? A: GEO is optimization for AI.\n1. First benefit\n2. Second benefit",
has_brand_citation=True,
brand_context="BrandX is mentioned",
citations=[
CitationInfo(
source_url="https://example.com",
source_title="Example",
citation_context="BrandX is mentioned",
confidence=0.9,
position=0,
)
],
),
_make_result(
engine_type=EngineType.PERPLEXITY,
raw_response="According to a 2024 study, 80% of companies use AI. Dr. Smith recommends BrandX.",
has_brand_citation=True,
brand_context="Dr. Smith recommends BrandX",
citations=[
CitationInfo(
source_url="https://example.com",
source_title="Example",
citation_context="Dr. Smith recommends BrandX",
confidence=0.8,
position=2,
)
],
),
]
report = engine.analyze(results, brand_id="brand-123", query="what is geo")
assert isinstance(report, PatternAnalysisReport)
assert report.brand_id == "brand-123"
assert report.query == "what is geo"
assert report.total_results == 2
assert len(report.patterns) > 0
assert isinstance(report.content_structure_insights, dict)
assert isinstance(report.authority_signal_insights, dict)
assert isinstance(report.citation_format_insights, dict)
assert isinstance(report.engine_preferences, dict)
assert isinstance(report.recommendations, list)
def test_empty_input_handling(self, engine):
report = engine.analyze([], brand_id="brand-1", query="test")
assert report.total_results == 0
assert report.patterns == []
assert report.content_structure_insights == {}
assert report.authority_signal_insights == {}
assert report.citation_format_insights == {}
assert report.engine_preferences == {}
assert report.recommendations == []
def test_single_engine_result(self, engine):
results = [
_make_result(
engine_type=EngineType.CHATGPT,
raw_response="Q: What? A: This.",
has_brand_citation=True,
brand_context="BrandX is great",
)
]
report = engine.analyze(results, brand_id="b1", query="q")
assert report.total_results == 1
assert "chatgpt" in report.engine_preferences
assert len(report.patterns) > 0
def test_multi_engine_aggregation(self, engine):
results = [
_make_result(
engine_type=EngineType.CHATGPT,
raw_response="Q: What? A: This. BrandX is the best.",
has_brand_citation=True,
brand_context="BrandX is the best",
),
_make_result(
engine_type=EngineType.PERPLEXITY,
raw_response="1. First point\n2. Second point\nBrandX offers great value.",
has_brand_citation=True,
brand_context="BrandX offers great value",
),
_make_result(
engine_type=EngineType.KIMI,
raw_response="Plain response. No citations here.",
has_brand_citation=False,
),
]
report = engine.analyze(results, brand_id="b1", query="q")
assert report.total_results == 3
assert len(report.engine_preferences) == 3
chatgpt_rate = report.engine_preferences["chatgpt"]["citation_rate"]
kimi_rate = report.engine_preferences["kimi"]["citation_rate"]
assert chatgpt_rate == 1.0
assert kimi_rate == 0.0
def test_pattern_report_generation(self, engine):
results = [
_make_result(
engine_type=EngineType.CHATGPT,
raw_response="Q: What is GEO? A: GEO optimization.\nAccording to 2024 research, 75% agree.",
has_brand_citation=True,
brand_context="BrandX provides GEO",
citations=[
CitationInfo(
source_url="https://brandx.com",
source_title="BrandX",
citation_context="BrandX provides GEO",
confidence=0.95,
position=0,
)
],
),
]
report = engine.analyze(results, brand_id="b1", query="geo optimization")
assert len(report.patterns) > 0
pattern_types = {p.pattern_type for p in report.patterns}
assert "content_structure" in pattern_types
assert "authority_signal" in pattern_types
assert "citation_format" in pattern_types
def test_recommendations_generated(self, engine):
results = [
_make_result(
engine_type=EngineType.CHATGPT,
raw_response="Q: What? A: Something.",
has_brand_citation=False,
),
]
report = engine.analyze(results, brand_id="b1", query="q")
assert isinstance(report.recommendations, list)
def test_insights_populated(self, engine):
results = [
_make_result(
engine_type=EngineType.CHATGPT,
raw_response="Q: What? A: This.\n1. Item\nISO 9001 certified.",
has_brand_citation=True,
brand_context="BrandX is here",
),
]
report = engine.analyze(results, brand_id="b1", query="q")
assert len(report.content_structure_insights) > 0
assert len(report.authority_signal_insights) > 0
assert len(report.citation_format_insights) > 0

View File

@ -0,0 +1,376 @@
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.models.brand import Brand
from app.models.user import User
from app.services.auth import hash_password
@pytest_asyncio.fixture
async def async_engine():
engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest_asyncio.fixture
async def async_session(async_engine):
session_maker = async_sessionmaker(
async_engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
autocommit=False,
)
async with session_maker() as session:
yield session
@pytest_asyncio.fixture
async def test_user(async_session):
user = User(
id=uuid.uuid4(),
email="test@example.com",
password_hash=hash_password("Test@123456"),
name="Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
)
async_session.add(user)
await async_session.commit()
await async_session.refresh(user)
return user
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
brand = Brand(
id=uuid.uuid4(),
user_id=test_user.id,
name="Test Brand",
aliases=["TestBrand", "TB"],
website="https://testbrand.com",
industry="technology",
platforms=["wenxin", "kimi"],
frequency="weekly",
status="active",
)
async_session.add(brand)
await async_session.commit()
await async_session.refresh(brand)
return brand
class TestDetectionTaskModel:
@pytest.mark.asyncio
async def test_create_detection_task(self, async_session, test_brand, test_user):
from app.models.detection_task import DetectionTask
task = DetectionTask(
brand_id=test_brand.id,
user_id=test_user.id,
name="每日品牌检测",
frequency="daily",
engines=["chatgpt", "perplexity"],
queries=["最佳保险品牌", "保险推荐"],
competitor_names=["竞品A", "竞品B"],
)
async_session.add(task)
await async_session.commit()
await async_session.refresh(task)
assert task.id is not None
assert task.brand_id == test_brand.id
assert task.user_id == test_user.id
assert task.name == "每日品牌检测"
assert task.frequency == "daily"
assert task.engines == ["chatgpt", "perplexity"]
assert task.queries == ["最佳保险品牌", "保险推荐"]
assert task.competitor_names == ["竞品A", "竞品B"]
assert task.is_active is True
assert task.last_run_at is None
assert task.next_run_at is not None
assert task.created_at is not None
assert task.updated_at is not None
@pytest.mark.asyncio
async def test_detection_task_default_values(self, async_session, test_brand, test_user):
from app.models.detection_task import DetectionTask
task = DetectionTask(
brand_id=test_brand.id,
user_id=test_user.id,
name="简单检测",
frequency="weekly",
engines=["chatgpt"],
queries=["测试查询"],
)
async_session.add(task)
await async_session.commit()
await async_session.refresh(task)
assert task.is_active is True
assert task.competitor_names is None
assert task.last_run_at is None
class TestDetectionSchedulerService:
@pytest.mark.asyncio
async def test_create_task(self, async_session, test_brand, test_user):
from app.services.detection_scheduler import DetectionSchedulerService
service = DetectionSchedulerService()
task_data = {
"name": "每日品牌检测",
"frequency": "daily",
"engines": ["chatgpt", "perplexity"],
"queries": ["最佳保险品牌", "保险推荐"],
"competitor_names": ["竞品A"],
}
task = await service.create_task(task_data, test_brand.id, test_user.id, async_session)
assert task.id is not None
assert task.name == "每日品牌检测"
assert task.frequency == "daily"
assert task.brand_id == test_brand.id
assert task.user_id == test_user.id
@pytest.mark.asyncio
async def test_update_task(self, async_session, test_brand, test_user):
from app.models.detection_task import DetectionTask
from app.services.detection_scheduler import DetectionSchedulerService
task = DetectionTask(
brand_id=test_brand.id,
user_id=test_user.id,
name="旧名称",
frequency="weekly",
engines=["chatgpt"],
queries=["查询1"],
)
async_session.add(task)
await async_session.commit()
await async_session.refresh(task)
service = DetectionSchedulerService()
update_data = {
"name": "新名称",
"frequency": "daily",
"engines": ["chatgpt", "perplexity"],
}
updated = await service.update_task(task.id, update_data, test_user.id, async_session)
assert updated.name == "新名称"
assert updated.frequency == "daily"
assert updated.engines == ["chatgpt", "perplexity"]
@pytest.mark.asyncio
async def test_delete_task(self, async_session, test_brand, test_user):
from app.models.detection_task import DetectionTask
from app.services.detection_scheduler import DetectionSchedulerService
task = DetectionTask(
brand_id=test_brand.id,
user_id=test_user.id,
name="待删除",
frequency="weekly",
engines=["chatgpt"],
queries=["查询1"],
)
async_session.add(task)
await async_session.commit()
await async_session.refresh(task)
service = DetectionSchedulerService()
result = await service.delete_task(task.id, test_user.id, async_session)
assert result is True
stmt = select(DetectionTask).where(DetectionTask.id == task.id)
db_result = await async_session.execute(stmt)
assert db_result.scalar_one_or_none() is None
@pytest.mark.asyncio
async def test_get_tasks(self, async_session, test_brand, test_user):
from app.models.detection_task import DetectionTask
from app.services.detection_scheduler import DetectionSchedulerService
for i in range(3):
task = DetectionTask(
brand_id=test_brand.id,
user_id=test_user.id,
name=f"任务{i}",
frequency="daily",
engines=["chatgpt"],
queries=[f"查询{i}"],
)
async_session.add(task)
await async_session.commit()
service = DetectionSchedulerService()
tasks = await service.get_tasks(test_brand.id, test_user.id, async_session)
assert len(tasks) == 3
@pytest.mark.asyncio
async def test_trigger_task(self, async_session, test_brand, test_user):
from app.models.detection_task import DetectionTask
from app.services.detection_scheduler import DetectionSchedulerService
task = DetectionTask(
brand_id=test_brand.id,
user_id=test_user.id,
name="手动触发测试",
frequency="daily",
engines=["chatgpt"],
queries=["测试查询"],
)
async_session.add(task)
await async_session.commit()
await async_session.refresh(task)
service = DetectionSchedulerService()
with patch.object(service, "execute_task", new_callable=AsyncMock) as mock_execute:
mock_execute.return_value = {"status": "success", "results": []}
result = await service.trigger_task(task.id, test_user.id, async_session)
assert result["status"] == "success"
mock_execute.assert_called_once()
@pytest.mark.asyncio
async def test_frequency_validation_hourly(self, async_session, test_brand, test_user):
from app.services.detection_scheduler import DetectionSchedulerService
service = DetectionSchedulerService()
task_data = {
"name": "每小时检测",
"frequency": "hourly",
"engines": ["chatgpt"],
"queries": ["查询"],
}
task = await service.create_task(task_data, test_brand.id, test_user.id, async_session)
assert task.frequency == "hourly"
assert task.next_run_at is not None
@pytest.mark.asyncio
async def test_frequency_validation_daily(self, async_session, test_brand, test_user):
from app.services.detection_scheduler import DetectionSchedulerService
service = DetectionSchedulerService()
task_data = {
"name": "每日检测",
"frequency": "daily",
"engines": ["chatgpt"],
"queries": ["查询"],
}
task = await service.create_task(task_data, test_brand.id, test_user.id, async_session)
assert task.frequency == "daily"
assert task.next_run_at is not None
@pytest.mark.asyncio
async def test_frequency_validation_weekly(self, async_session, test_brand, test_user):
from app.services.detection_scheduler import DetectionSchedulerService
service = DetectionSchedulerService()
task_data = {
"name": "每周检测",
"frequency": "weekly",
"engines": ["chatgpt"],
"queries": ["查询"],
}
task = await service.create_task(task_data, test_brand.id, test_user.id, async_session)
assert task.frequency == "weekly"
assert task.next_run_at is not None
@pytest.mark.asyncio
async def test_frequency_validation_invalid(self, async_session, test_brand, test_user):
from app.services.detection_scheduler import DetectionSchedulerService
service = DetectionSchedulerService()
task_data = {
"name": "无效频率",
"frequency": "monthly",
"engines": ["chatgpt"],
"queries": ["查询"],
}
with pytest.raises(ValueError, match="frequency"):
await service.create_task(task_data, test_brand.id, test_user.id, async_session)
@pytest.mark.asyncio
async def test_execute_task_flow(self, async_session, test_brand, test_user):
from app.models.detection_task import DetectionTask
from app.services.detection_scheduler import DetectionSchedulerService
task = DetectionTask(
brand_id=test_brand.id,
user_id=test_user.id,
name="执行流程测试",
frequency="daily",
engines=["chatgpt"],
queries=["测试查询"],
competitor_names=["竞品A"],
)
async_session.add(task)
await async_session.commit()
await async_session.refresh(task)
service = DetectionSchedulerService()
mock_batch_result = [
MagicMock(
engine_type=MagicMock(value="chatgpt"),
has_brand_citation=True,
has_competitor_citation=False,
)
]
with patch.object(
service, "_run_batch_query", new_callable=AsyncMock, return_value=mock_batch_result
), patch.object(
service, "_generate_alerts_if_needed", new_callable=AsyncMock, return_value=[]
):
result = await service.execute_task(task, async_session)
assert result["status"] == "success"
assert "results" in result
await async_session.refresh(task)
assert task.last_run_at is not None
assert task.next_run_at is not None
@pytest.mark.asyncio
async def test_delete_task_not_found(self, async_session, test_user):
from app.services.detection_scheduler import DetectionSchedulerService
service = DetectionSchedulerService()
result = await service.delete_task(uuid.uuid4(), test_user.id, async_session)
assert result is False
@pytest.mark.asyncio
async def test_update_task_not_found(self, async_session, test_user):
from app.services.detection_scheduler import DetectionSchedulerService, TaskNotFoundError
service = DetectionSchedulerService()
with pytest.raises(TaskNotFoundError):
await service.update_task(uuid.uuid4(), {"name": "新名称"}, test_user.id, async_session)
@pytest.mark.asyncio
async def test_get_tasks_empty(self, async_session, test_brand, test_user):
from app.services.detection_scheduler import DetectionSchedulerService
service = DetectionSchedulerService()
tasks = await service.get_tasks(test_brand.id, test_user.id, async_session)
assert tasks == []