feat: Phase1 Week3-4 - 引用模式识别+定时检测任务调度
后端(TDD): - 引用模式识别引擎(4个分析器+报告生成) - ContentStructureAnalyzer: FAQ/列表/表格/引用块检测 - AuthoritySignalAnalyzer: 数据引用/专家引用/认证标记 - CitationFormatAnalyzer: 直接/间接/对比引用 - EnginePreferenceAnalyzer: 引擎偏好分析 - 定时检测任务调度服务 - DetectionTask模型(hourly/daily/weekly) - DetectionSchedulerService(CRUD+触发+执行) - 检测API端点(5个) - Schema定义 - 34+21=55个测试全部通过 前端: - AI引擎分析页面(引用率/引擎结果/上下文详情)
This commit is contained in:
parent
1ec5ea42da
commit
9d67a801be
|
|
@ -0,0 +1,152 @@
|
|||
import logging
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models.brand import Brand
|
||||
from app.models.user import User
|
||||
from app.schemas.detection_task import (
|
||||
DetectionTaskCreate,
|
||||
DetectionTaskListResponse,
|
||||
DetectionTaskResponse,
|
||||
DetectionTaskUpdate,
|
||||
DetectionTriggerResponse,
|
||||
)
|
||||
from app.services.detection_scheduler import DetectionSchedulerService, TaskNotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def verify_brand_ownership(
    brand_id: uuid.UUID,
    current_user: User,
    db: AsyncSession,
) -> Brand:
    """Return the brand with ``brand_id`` if it belongs to ``current_user``.

    Args:
        brand_id: Primary key of the brand to look up.
        current_user: The authenticated user making the request.
        db: Async database session used for the lookup.

    Returns:
        The matching ``Brand`` row.

    Raises:
        HTTPException: 404 when no brand matches both the id and the owner.
            The same response is used for "missing" and "owned by someone
            else".
    """
    owned_brand_query = select(Brand).where(
        and_(
            Brand.id == brand_id,
            Brand.user_id == current_user.id,
        )
    )
    brand = (await db.execute(owned_brand_query)).scalar_one_or_none()

    if brand:
        return brand

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"品牌 {brand_id} 不存在或不属于当前用户",
    )
|
||||
|
||||
|
||||
@router.post("/tasks", response_model=DetectionTaskResponse, status_code=status.HTTP_201_CREATED)
async def create_detection_task(
    data: DetectionTaskCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a detection task for a brand owned by the current user.

    Responds with 404 (via ``verify_brand_ownership``) when the brand is not
    the caller's, and with 422 when the service rejects the payload with a
    ``ValueError``.
    """
    await verify_brand_ownership(data.brand_id, current_user, db)

    scheduler = DetectionSchedulerService()
    try:
        # brand_id is passed separately, so it is stripped from the payload.
        return await scheduler.create_task(
            task_data=data.model_dump(exclude={"brand_id"}),
            brand_id=data.brand_id,
            user_id=current_user.id,
            db=db,
        )
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=str(exc),
        )
|
||||
|
||||
|
||||
@router.get("/tasks", response_model=DetectionTaskListResponse)
async def get_detection_tasks(
    brand_id: uuid.UUID = Query(..., description="按品牌筛选"),
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List the current user's detection tasks for one brand, paginated.

    ``total`` is the number of tasks before pagination; ``items`` is the
    ``skip``/``limit`` window of that list.
    """
    await verify_brand_ownership(brand_id, current_user, db)

    all_tasks = await DetectionSchedulerService().get_tasks(brand_id, current_user.id, db)

    # NOTE: pagination is applied in memory on the full result list.
    page = all_tasks[skip:][:limit]
    return {"items": page, "total": len(all_tasks)}
|
||||
|
||||
|
||||
@router.put("/tasks/{task_id}", response_model=DetectionTaskResponse)
async def update_detection_task(
    task_id: uuid.UUID,
    data: DetectionTaskUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Apply a partial update to one of the current user's detection tasks.

    Only fields present in the request body are forwarded to the service.
    Responds with 404 when the task is not found for this user and with 422
    when the service rejects the values with a ``ValueError``.
    """
    scheduler = DetectionSchedulerService()
    try:
        updated = await scheduler.update_task(
            task_id=task_id,
            task_data=data.model_dump(exclude_unset=True),
            user_id=current_user.id,
            db=db,
        )
    except TaskNotFoundError:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="检测任务不存在",
        )
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=str(exc),
        )
    else:
        return updated
|
||||
|
||||
|
||||
@router.delete("/tasks/{task_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_detection_task(
    task_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the current user's detection tasks.

    Responds with 204 and no body on success, or 404 when the service
    reports that nothing was deleted.
    """
    deleted = await DetectionSchedulerService().delete_task(task_id, current_user.id, db)

    if deleted:
        return None

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="检测任务不存在",
    )
|
||||
|
||||
|
||||
@router.post("/tasks/{task_id}/trigger", response_model=DetectionTriggerResponse)
async def trigger_detection_task(
    task_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Run one of the current user's detection tasks immediately.

    The service reports failures as ``{"status": "error", "message": ...}``
    for two different situations: the task does not exist for this user, or
    the task exists but its execution raised. Only the first is a 404; the
    second is surfaced as a 500 so clients do not mistake a failed run for a
    missing task.

    Raises:
        HTTPException: 404 when the task is not found, 500 when execution
            failed.
    """
    service = DetectionSchedulerService()
    result = await service.trigger_task(task_id, current_user.id, db)

    if result.get("status") != "error":
        return result

    message = result.get("message")

    # "Task not found" is the message DetectionSchedulerService.trigger_task
    # returns for a missing task; a missing message keeps the legacy 404.
    if message is None or message == "Task not found":
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=message or "检测任务不存在",
        )

    logger.error("Manual trigger of detection task %s failed: %s", task_id, message)
    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail=message,
    )
|
||||
|
|
@ -39,6 +39,7 @@ from app.api.platform_rules import router as platform_rules_router
|
|||
from app.api.image import router as image_router
|
||||
from app.api.knowledge_graph import router as knowledge_graph_router
|
||||
from app.api.ai_engines import router as ai_engines_router
|
||||
from app.api.detection import router as detection_router
|
||||
from app.config import settings
|
||||
from app.database import engine, Base
|
||||
from app.schemas.common import ErrorResponse, ErrorCode
|
||||
|
|
@ -167,6 +168,7 @@ app.include_router(platform_rules_router)
|
|||
app.include_router(image_router, prefix="/api/v1")
|
||||
app.include_router(knowledge_graph_router, prefix="/api/v1/knowledge-bases")
|
||||
app.include_router(ai_engines_router, prefix="/api/v1/ai-engines", tags=["AI引擎查询"])
|
||||
app.include_router(detection_router, prefix="/api/v1/detection", tags=["定时检测任务"])
|
||||
|
||||
|
||||
@app.get("/health", tags=["可观测性"])
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from app.models.competitor import Competitor
|
|||
from app.models.suggestion import Suggestion
|
||||
from app.models.alert import Alert
|
||||
from app.models.alert_setting import AlertSetting
|
||||
from app.models.detection_task import DetectionTask
|
||||
|
||||
__all__ = [
|
||||
"User",
|
||||
|
|
@ -68,4 +69,5 @@ __all__ = [
|
|||
"Suggestion",
|
||||
"Alert",
|
||||
"AlertSetting",
|
||||
"DetectionTask",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,65 @@
|
|||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String, Uuid, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database import Base, JSONType
|
||||
|
||||
# Interval between two consecutive runs for each supported frequency name.
FREQUENCY_DELTAS = {
    "hourly": timedelta(hours=1),
    "daily": timedelta(days=1),
    "weekly": timedelta(weeks=1),
}

# Accepted frequency names; derived from the table above so the two stay in sync.
VALID_FREQUENCIES = set(FREQUENCY_DELTAS)
|
||||
|
||||
|
||||
class DetectionTask(Base):
    """A recurring AI-engine detection job configured for one brand.

    Each row stores which engines to query, which query strings to send,
    optional competitor names, and the schedule bookkeeping
    (``last_run_at`` / ``next_run_at``).
    """

    __tablename__ = "detection_tasks"

    id: Mapped[uuid.UUID] = mapped_column(
        Uuid(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )
    brand_id: Mapped[uuid.UUID] = mapped_column(
        Uuid(as_uuid=True),
        ForeignKey("brands.id", ondelete="CASCADE"),
        nullable=False,
    )
    # Indexed through the named Index in __table_args__; a column-level
    # index=True here would create a second, identical index.
    user_id: Mapped[uuid.UUID] = mapped_column(
        Uuid(as_uuid=True),
        nullable=False,
    )
    name: Mapped[str] = mapped_column(String(200), nullable=False)
    # One of the keys of FREQUENCY_DELTAS ("hourly", "daily", "weekly").
    frequency: Mapped[str] = mapped_column(String(20), nullable=False)
    engines: Mapped[list] = mapped_column(JSONType, default=list, nullable=False)
    queries: Mapped[list] = mapped_column(JSONType, default=list, nullable=False)
    competitor_names: Mapped[list | None] = mapped_column(JSONType, nullable=True)
    is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
    # Both run timestamps are written as timezone-aware UTC values, so the
    # columns are declared timezone-aware to avoid mixing naive and aware
    # datetimes at the driver level.
    last_run_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    next_run_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
        default=lambda: datetime.now(timezone.utc),
    )
    created_at: Mapped[datetime] = mapped_column(
        server_default=func.now(),
        nullable=False,
    )
    updated_at: Mapped[datetime] = mapped_column(
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )

    __table_args__ = (
        Index("idx_detection_tasks_brand_id", "brand_id"),
        Index("idx_detection_tasks_user_id", "user_id"),
        Index("idx_detection_tasks_is_active", "is_active"),
    )

    def compute_next_run_at(self) -> datetime:
        """Return the next scheduled run time.

        The result is ``last_run_at`` plus the interval for ``frequency``.
        When the task has never run, the current UTC time is used as the
        base; an unknown frequency falls back to a one-day interval.
        """
        delta = FREQUENCY_DELTAS.get(self.frequency, timedelta(days=1))
        base = self.last_run_at or datetime.now(timezone.utc)
        return base + delta
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DetectionTaskCreate(BaseModel):
    """Request body for creating a detection task."""

    brand_id: uuid.UUID
    name: str = Field(..., min_length=1, max_length=200)
    frequency: str = Field(..., pattern="^(hourly|daily|weekly)$")
    # Required and non-empty. A default of [] would skip the min_length check
    # (defaults are not validated), letting a task be created with nothing
    # to query.
    engines: list[str] = Field(..., min_length=1)
    queries: list[str] = Field(..., min_length=1)
    competitor_names: Optional[list[str]] = None
|
||||
|
||||
|
||||
class DetectionTaskUpdate(BaseModel):
    """Request body for partially updating a detection task.

    Every field is optional; omitted fields are left untouched by the
    update endpoint (it dumps the model with ``exclude_unset=True``).
    """

    name: Optional[str] = Field(None, min_length=1, max_length=200)
    frequency: Optional[str] = Field(None, pattern="^(hourly|daily|weekly)$")
    # When provided, these must be non-empty, matching DetectionTaskCreate.
    engines: Optional[list[str]] = Field(None, min_length=1)
    queries: Optional[list[str]] = Field(None, min_length=1)
    competitor_names: Optional[list[str]] = None
    is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class DetectionTaskResponse(BaseModel):
    """Serialized view of a detection task returned by the API."""

    # Allows the model to be populated from object attributes (ORM rows).
    model_config = {"from_attributes": True}

    # Identity and ownership.
    id: uuid.UUID
    brand_id: uuid.UUID
    user_id: uuid.UUID

    # Task configuration.
    name: str
    frequency: str
    engines: list[str]
    queries: list[str]
    competitor_names: Optional[list[str]] = Field(default=None)
    is_active: bool

    # Scheduling state and audit timestamps.
    last_run_at: Optional[datetime] = Field(default=None)
    next_run_at: Optional[datetime] = Field(default=None)
    created_at: datetime
    updated_at: datetime
|
||||
|
||||
|
||||
class DetectionTaskListResponse(BaseModel):
    """One page of detection tasks plus the overall count."""

    # Tasks in the requested page.
    items: list[DetectionTaskResponse]
    # Count reported by the endpoint alongside the page.
    total: int
|
||||
|
||||
|
||||
class DetectionTriggerResponse(BaseModel):
    """Outcome of manually triggering a detection task."""

    # Outcome label produced by the scheduler service (e.g. "success").
    status: str
    # Optional human-readable detail.
    message: Optional[str] = Field(default=None)
|
||||
|
|
@ -0,0 +1,380 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from app.services.ai_engine.base import AIQueryResult
|
||||
|
||||
|
||||
@dataclass
class CitationPattern:
    """One detected pattern and how often it occurred in a result set."""

    # Pattern family, e.g. "content_structure", "authority_signal",
    # "citation_format".
    pattern_type: str
    # Specific pattern within the family, e.g. "faq_format".
    pattern_name: str
    # Share of analyzed results that exhibited the pattern.
    frequency: float
    # Confidence assigned by the analyzer; 0.0 when the pattern never occurred.
    confidence: float
    # Human-readable summary of the pattern.
    description: str
    # Extra counters supplied by the analyzer; a fresh dict per instance.
    details: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class PatternAnalysisReport:
    """Aggregated output of a citation-pattern analysis for one brand and query."""

    brand_id: str
    query: str
    # Number of AI query results that were analyzed.
    total_results: int
    # Every pattern produced by the individual analyzers.
    patterns: list[CitationPattern]
    # Per-family summaries built from ``patterns``.
    content_structure_insights: dict[str, Any]
    authority_signal_insights: dict[str, Any]
    citation_format_insights: dict[str, Any]
    # Per-engine statistics keyed by engine name.
    engine_preferences: dict[str, Any]
    # Suggested actions; a fresh list per instance.
    recommendations: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
# --- Content-structure patterns -------------------------------------------

# FAQ-style content: English "Q: ... A: ...", Chinese "问题: ... 回答: ...",
# or an explicit FAQ heading.
_FAQ_PATTERNS = [
    re.compile(r"Q:\s*.+?\s*A:\s*", re.DOTALL | re.IGNORECASE),
    # Accept both the ASCII colon and the full-width colon used in Chinese
    # text; the class previously listed the ASCII colon twice.
    re.compile(r"问题[:：]\s*.+?\s*回答[:：]\s*", re.DOTALL),
    re.compile(r"常见问题", re.IGNORECASE),
    re.compile(r"FAQ", re.IGNORECASE),
]

# Numbered ("1. ") and bulleted ("- " / "* ") list items at line start.
_LIST_PATTERNS = [
    re.compile(r"(?:^|\n)\s*\d+\.\s+", re.MULTILINE),
    re.compile(r"(?:^|\n)\s*[-*]\s+", re.MULTILINE),
]

# Pipe-delimited rows or dashed column rulers.
_TABLE_PATTERNS = [
    re.compile(r"\|.+\|.+\|", re.MULTILINE),
    re.compile(r"-{4,}\s+-{4,}", re.MULTILINE),
]

# Markdown block quotes or text enclosed in straight/curly double quotes.
_QUOTE_PATTERNS = [
    re.compile(r"^\s*>\s+", re.MULTILINE),
    re.compile(r"[\"\u201c].+?[\"\u201d]"),
]

# --- Authority-signal patterns --------------------------------------------

# Percentages, research vocabulary, attribution phrases, four-digit years.
_DATA_CITATION_PATTERNS = [
    re.compile(r"\d+%", re.MULTILINE),
    re.compile(r"(?:study|research|survey|report|data|statistics)", re.IGNORECASE),
    re.compile(r"(?:according to|based on|shown by|indicates?|reveals?)", re.IGNORECASE),
    re.compile(r"\b(?:19|20)\d{2}\b"),
]

# Academic titles, expert roles, and a fixed list of universities.
_EXPERT_CITATION_PATTERNS = [
    re.compile(r"(?:Dr\.|Professor|Prof\.)\s+\w+", re.IGNORECASE),
    re.compile(r"(?:expert|analyst|researcher|scientist)\s+\w+", re.IGNORECASE),
    re.compile(r"(?:from|at)\s+(?:Harvard|Stanford|MIT|Oxford|Cambridge|Yale|Princeton)", re.IGNORECASE),
]

# Named standards/regulations and generic certification vocabulary.
_CERTIFICATION_PATTERNS = [
    re.compile(r"(?:ISO\s*\d+|FDA\s+approved|CE\s+certif|UL\s+listed|SOC\s*2|HIPAA|GDPR)", re.IGNORECASE),
    re.compile(r"(?:certified|certification|compliant|accredited)", re.IGNORECASE),
]

# --- Citation-format patterns ---------------------------------------------

# English and Chinese comparison wording.
_COMPARISON_PATTERNS = [
    re.compile(r"(?:compared?\s+to|vs\.?|versus|while\s+.+\s*,\s*.+)", re.IGNORECASE),
    re.compile(r"(?:优于|相比|对比|不如|胜过)"),
]

# Reporting verbs and attribution phrases.
_DIRECT_CITATION_PATTERNS = [
    re.compile(r"(?:states?|claims?|says?|mentions?|notes?|reports?|announces?)", re.IGNORECASE),
    re.compile(r"(?:according to|as stated by|as reported by)", re.IGNORECASE),
]

# --- Rule tables ----------------------------------------------------------

# (pattern name, regexes, confidence when detected, description)
_CONTENT_RULES: list[tuple[str, list[re.Pattern[str]], float, str]] = [
    ("faq_format", _FAQ_PATTERNS, 0.8, "FAQ format detected in AI responses"),
    ("list_format", _LIST_PATTERNS, 0.8, "List format detected in AI responses"),
    ("table_format", _TABLE_PATTERNS, 0.8, "Table format detected in AI responses"),
    ("quote_block", _QUOTE_PATTERNS, 0.7, "Quote block detected in AI responses"),
]

# (pattern name, regexes, confidence when detected, description,
#  minimum number of distinct regexes that must match in one response)
_AUTHORITY_RULES: list[tuple[str, list[re.Pattern[str]], float, str, int]] = [
    ("data_citation", _DATA_CITATION_PATTERNS, 0.85, "Data citation signals detected in AI responses", 2),
    ("expert_citation", _EXPERT_CITATION_PATTERNS, 0.8, "Expert citation signals detected in AI responses", 1),
    ("certification_mark", _CERTIFICATION_PATTERNS, 0.9, "Certification marks detected in AI responses", 1),
]

# (pattern type, pattern name, frequency threshold, recommendation text);
# the recommendation is emitted when the pattern's frequency is below the
# threshold.
_RECOMMENDATION_RULES: list[tuple[str, str, float, str]] = [
    ("content_structure", "faq_format", 0.3, "Consider adding FAQ sections to improve AI citation probability"),
    ("content_structure", "list_format", 0.3, "Use structured lists to make content more extractable by AI engines"),
    ("authority_signal", "data_citation", 0.3, "Include data citations and statistics to increase authority signals"),
    ("authority_signal", "expert_citation", 0.3, "Add expert quotes and references to strengthen E-E-A-T signals"),
    ("citation_format", "direct_citation", 0.2, "Optimize content for direct citation by AI engines"),
]
|
||||
|
||||
|
||||
def _matches_any(text: str, patterns: list[re.Pattern[str]]) -> bool:
    """Return True if at least one of ``patterns`` matches somewhere in ``text``.

    A ``None`` text (e.g. a result that carries no response body) is treated
    as "no match" instead of raising ``TypeError``.
    """
    if text is None:
        return False
    return any(p.search(text) for p in patterns)
|
||||
|
||||
|
||||
def _count_matches(text: str, patterns: list[re.Pattern[str]]) -> int:
    """Return how many of ``patterns`` match ``text`` at least once.

    Each pattern contributes at most 1, however many times it occurs in the
    text. A ``None`` text counts as zero matches instead of raising
    ``TypeError``.
    """
    if text is None:
        return 0
    return sum(1 for p in patterns if p.search(text))
|
||||
|
||||
|
||||
def _build_type_insights(
    patterns: list[CitationPattern],
    pattern_type: str,
    top_key: str,
) -> dict[str, Any]:
    """Summarize the patterns of one family as a flat insights dict.

    Args:
        patterns: All detected patterns, of any family.
        pattern_type: The family to summarize.
        top_key: Key under which the most frequent pattern's name is stored.

    Returns:
        ``{"<pattern_name>_frequency": frequency, ...}`` for every pattern of
        the family, plus ``top_key`` naming the most frequent one when its
        frequency is above zero (the first one wins on ties).
    """
    summary: dict[str, Any] = {}
    leader = None

    for candidate in patterns:
        if candidate.pattern_type != pattern_type:
            continue
        summary[f"{candidate.pattern_name}_frequency"] = candidate.frequency
        # Strict comparison keeps the earliest pattern when frequencies tie.
        if leader is None or candidate.frequency > leader.frequency:
            leader = candidate

    if leader is not None and leader.frequency > 0:
        summary[top_key] = leader.pattern_name

    return summary
|
||||
|
||||
|
||||
class ContentStructureAnalyzer:
    """Measures how often AI responses use FAQ, list, table or quote layouts."""

    def analyze(self, results: list[AIQueryResult]) -> list[CitationPattern]:
        """Return one ``content_structure`` pattern per rule in ``_CONTENT_RULES``.

        A result counts toward a rule when any of the rule's regexes matches
        its ``raw_response``. ``frequency`` is the share of results that
        matched; ``confidence`` is the rule's value when at least one result
        matched and 0.0 otherwise.
        """
        if not results:
            return self._empty_patterns()

        sample_size = len(results)
        report: list[CitationPattern] = []

        for rule_name, regexes, base_confidence, summary in _CONTENT_RULES:
            hits = sum(1 for item in results if _matches_any(item.raw_response, regexes))
            report.append(
                CitationPattern(
                    pattern_type="content_structure",
                    pattern_name=rule_name,
                    frequency=hits / sample_size,
                    confidence=base_confidence if hits > 0 else 0.0,
                    description=summary,
                    details={"count": hits, "total": sample_size},
                )
            )

        return report

    def _empty_patterns(self) -> list[CitationPattern]:
        """Return zeroed placeholder patterns, one per content rule."""
        placeholders: list[CitationPattern] = []
        for rule in _CONTENT_RULES:
            placeholders.append(
                CitationPattern(
                    pattern_type="content_structure",
                    pattern_name=rule[0],
                    frequency=0.0,
                    confidence=0.0,
                    description="",
                    details={},
                )
            )
        return placeholders
|
||||
|
||||
|
||||
class AuthoritySignalAnalyzer:
    """Measures data, expert and certification signals in AI responses."""

    def analyze(self, results: list[AIQueryResult]) -> list[CitationPattern]:
        """Return one ``authority_signal`` pattern per rule in ``_AUTHORITY_RULES``.

        For each result, the number of the rule's regexes that match
        ``raw_response`` is compared with the rule's threshold. Results that
        reach the threshold increment ``count``, and their matched-regex
        totals are summed into ``match_count``.
        """
        if not results:
            return self._empty_patterns()

        sample_size = len(results)
        report: list[CitationPattern] = []

        for rule_name, regexes, base_confidence, summary, min_hits in _AUTHORITY_RULES:
            qualifying = 0
            regex_hits = 0
            for item in results:
                matched = _count_matches(item.raw_response, regexes)
                if matched < min_hits:
                    continue
                qualifying += 1
                regex_hits += matched

            report.append(
                CitationPattern(
                    pattern_type="authority_signal",
                    pattern_name=rule_name,
                    frequency=qualifying / sample_size,
                    confidence=base_confidence if qualifying > 0 else 0.0,
                    description=summary,
                    details={"count": qualifying, "total": sample_size, "match_count": regex_hits},
                )
            )

        return report

    def _empty_patterns(self) -> list[CitationPattern]:
        """Return zeroed placeholder patterns, one per authority rule."""
        placeholders: list[CitationPattern] = []
        for rule in _AUTHORITY_RULES:
            placeholders.append(
                CitationPattern(
                    pattern_type="authority_signal",
                    pattern_name=rule[0],
                    frequency=0.0,
                    confidence=0.0,
                    description="",
                    details={},
                )
            )
        return placeholders
|
||||
|
||||
|
||||
class CitationFormatAnalyzer:
    """Classifies brand citations as direct, indirect or comparison."""

    # (pattern name, confidence when detected, description), in output order.
    _FORMAT_SPECS = (
        ("direct_citation", 0.9, "Direct citation format detected"),
        ("indirect_citation", 0.7, "Indirect citation format detected"),
        ("comparison_citation", 0.85, "Comparison citation format detected"),
    )

    def analyze(self, results: list[AIQueryResult]) -> list[CitationPattern]:
        """Return the three ``citation_format`` patterns for ``results``.

        Only results with a brand citation are classified; each is assigned
        exactly one format. Frequencies are relative to all results, not only
        those that cite the brand.
        """
        if not results:
            return self._empty_patterns()

        tally = {name: 0 for name, _, _ in self._FORMAT_SPECS}
        for item in results:
            label = self._classify(item)
            if label is not None:
                tally[label] += 1

        sample_size = len(results)
        return [
            CitationPattern(
                pattern_type="citation_format",
                pattern_name=name,
                frequency=tally[name] / sample_size,
                confidence=base_confidence if tally[name] > 0 else 0.0,
                description=summary,
                details={"count": tally[name], "total": sample_size},
            )
            for name, base_confidence, summary in self._FORMAT_SPECS
        ]

    @staticmethod
    def _classify(item: AIQueryResult) -> str | None:
        """Return the citation format of one result, or None if the brand is not cited.

        Comparison wording together with a competitor citation takes
        precedence; otherwise direct-citation wording in the brand context
        decides between direct and indirect.
        """
        if not item.has_brand_citation:
            return None
        if _matches_any(item.raw_response, _COMPARISON_PATTERNS) and item.has_competitor_citation:
            return "comparison_citation"
        if item.brand_context and _matches_any(item.brand_context, _DIRECT_CITATION_PATTERNS):
            return "direct_citation"
        return "indirect_citation"

    def _empty_patterns(self) -> list[CitationPattern]:
        """Return zeroed placeholder patterns for the three formats."""
        placeholders: list[CitationPattern] = []
        for name, _, _ in self._FORMAT_SPECS:
            placeholders.append(
                CitationPattern(
                    pattern_type="citation_format",
                    pattern_name=name,
                    frequency=0.0,
                    confidence=0.0,
                    description="",
                    details={},
                )
            )
        return placeholders
|
||||
|
||||
|
||||
class EnginePreferenceAnalyzer:
    """Computes per-engine citation rate, citation position and format usage."""

    def analyze(self, results: list[AIQueryResult]) -> dict[str, dict[str, Any]]:
        """Return statistics keyed by engine name (``engine_type.value``).

        Engines appear in the order they are first seen in ``results``. An
        empty input yields an empty dict.
        """
        by_engine: dict[str, list[AIQueryResult]] = {}
        for item in results:
            by_engine.setdefault(item.engine_type.value, []).append(item)

        return {engine: self._summarize(batch) for engine, batch in by_engine.items()}

    @staticmethod
    def _summarize(batch: list[AIQueryResult]) -> dict[str, Any]:
        """Build the statistics dict for the results of a single engine.

        ``avg_citation_position`` is -1 when no citation positions were
        collected; positions are taken from every entry in ``citations`` of
        results that cite the brand.
        """
        format_rules = (
            ("faq", _FAQ_PATTERNS),
            ("list", _LIST_PATTERNS),
            ("table", _TABLE_PATTERNS),
        )
        format_hits = {label: 0 for label, _ in format_rules}
        positions: list = []
        cited = 0

        for item in batch:
            if item.has_brand_citation:
                cited += 1
                positions.extend(citation.position for citation in item.citations)
            for label, regexes in format_rules:
                if _matches_any(item.raw_response, regexes):
                    format_hits[label] += 1

        total = len(batch)
        if positions:
            avg_position = sum(positions) / len(positions)
        else:
            avg_position = -1

        return {
            "citation_rate": cited / total if total > 0 else 0.0,
            "avg_citation_position": avg_position,
            "format_preferences": {
                label: hits / total if total > 0 else 0.0
                for label, hits in format_hits.items()
            },
        }
|
||||
|
||||
|
||||
class CitationPatternEngine:
    """Runs all analyzers over a result set and assembles the final report."""

    def __init__(self) -> None:
        """Instantiate one analyzer of each kind."""
        self.content_analyzer = ContentStructureAnalyzer()
        self.authority_analyzer = AuthoritySignalAnalyzer()
        self.format_analyzer = CitationFormatAnalyzer()
        self.engine_analyzer = EnginePreferenceAnalyzer()

    def analyze(
        self,
        results: list[AIQueryResult],
        brand_id: str,
        query: str,
    ) -> PatternAnalysisReport:
        """Analyze ``results`` and return a ``PatternAnalysisReport``.

        With no results, the report is empty (no patterns, insights,
        engine preferences or recommendations).
        """
        if not results:
            return PatternAnalysisReport(
                brand_id=brand_id,
                query=query,
                total_results=0,
                patterns=[],
                content_structure_insights={},
                authority_signal_insights={},
                citation_format_insights={},
                engine_preferences={},
                recommendations=[],
            )

        # Content, authority, then format patterns, in that order.
        patterns: list[CitationPattern] = [
            *self.content_analyzer.analyze(results),
            *self.authority_analyzer.analyze(results),
            *self.format_analyzer.analyze(results),
        ]
        engine_prefs = self.engine_analyzer.analyze(results)
        advice = self._generate_recommendations(patterns, engine_prefs)

        content_insights = _build_type_insights(patterns, "content_structure", "dominant_format")
        authority_insights = _build_type_insights(patterns, "authority_signal", "strongest_signal")
        format_insights = _build_type_insights(patterns, "citation_format", "primary_format")

        return PatternAnalysisReport(
            brand_id=brand_id,
            query=query,
            total_results=len(results),
            patterns=patterns,
            content_structure_insights=content_insights,
            authority_signal_insights=authority_insights,
            citation_format_insights=format_insights,
            engine_preferences=engine_prefs,
            recommendations=advice,
        )

    def _generate_recommendations(
        self,
        patterns: list[CitationPattern],
        engine_prefs: dict[str, dict],
    ) -> list[str]:
        """Return recommendation strings for under-represented patterns and engines.

        A rule from ``_RECOMMENDATION_RULES`` fires when its pattern is
        present and its frequency is below the rule's threshold. After the
        rule-based entries, one entry is added per engine whose citation rate
        is below 0.3.
        """
        lookup = {(p.pattern_type, p.pattern_name): p for p in patterns}

        advice: list[str] = []
        for family, name, minimum, text in _RECOMMENDATION_RULES:
            found = lookup.get((family, name))
            if not found:
                continue
            if found.frequency < minimum:
                advice.append(text)

        for engine_name, prefs in engine_prefs.items():
            rate = prefs["citation_rate"]
            if rate < 0.3:
                advice.append(
                    f"Low citation rate on {engine_name} ({rate:.0%}), "
                    f"consider optimizing content for this engine"
                )

        return advice
|
||||
|
|
@ -0,0 +1,227 @@
|
|||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.ai_engines import _build_adapters
|
||||
from app.models.brand import Brand
|
||||
from app.models.detection_task import VALID_FREQUENCIES, DetectionTask
|
||||
from app.services.ai_engine.base import EngineType
|
||||
from app.services.ai_engine.batch_query import BatchQueryService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskNotFoundError(Exception):
    """Raised when a detection task cannot be found for the given user."""
|
||||
|
||||
|
||||
class DetectionSchedulerService:
    """Create, update, list, delete and run scheduled detection tasks.

    Every lookup by task id is scoped to the owning ``user_id``.
    """

    # Fields backed by NOT NULL columns on DetectionTask. An explicit null
    # for one of them in an update payload is rejected up front instead of
    # being written to the row and failing at commit time.
    _NON_NULLABLE_FIELDS = frozenset({"name", "frequency", "engines", "queries", "is_active"})

    async def create_task(
        self,
        task_data: dict,
        brand_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> DetectionTask:
        """Persist a new detection task and return it.

        Args:
            task_data: Must contain ``name``; may contain ``frequency``
                (defaults to ``"daily"``), ``engines``, ``queries`` and
                ``competitor_names``.
            brand_id: Brand the task belongs to.
            user_id: Owner of the task.
            db: Async database session.

        Raises:
            ValueError: If the frequency is not a supported value.
        """
        frequency = task_data.get("frequency", "daily")
        self._validate_frequency(frequency)

        task = DetectionTask(
            brand_id=brand_id,
            user_id=user_id,
            name=task_data["name"],
            frequency=frequency,
            engines=task_data.get("engines", []),
            queries=task_data.get("queries", []),
            competitor_names=task_data.get("competitor_names"),
        )
        task.next_run_at = task.compute_next_run_at()
        db.add(task)
        await db.commit()
        await db.refresh(task)
        logger.info("Created detection task: %s, name=%s", task.id, task.name)
        return task

    async def update_task(
        self,
        task_id: uuid.UUID,
        task_data: dict,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> DetectionTask:
        """Apply ``task_data`` to an existing task and reschedule it.

        Raises:
            TaskNotFoundError: If the task does not exist for ``user_id``.
            ValueError: If a non-nullable field is set to ``None`` or the
                resulting frequency is not a supported value.
        """
        task = await self._get_task_by_id(task_id, user_id, db)
        if not task:
            raise TaskNotFoundError(f"Detection task {task_id} not found")

        null_fields = sorted(
            name
            for name, value in task_data.items()
            if value is None and name in self._NON_NULLABLE_FIELDS
        )
        if null_fields:
            raise ValueError(f"Fields cannot be null: {', '.join(null_fields)}")

        frequency = task_data.get("frequency", task.frequency)
        self._validate_frequency(frequency)

        for field, value in task_data.items():
            setattr(task, field, value)

        task.next_run_at = task.compute_next_run_at()
        await db.commit()
        await db.refresh(task)
        logger.info("Updated detection task: %s", task.id)
        return task

    async def delete_task(
        self,
        task_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> bool:
        """Delete a task; return False when it does not exist for ``user_id``."""
        task = await self._get_task_by_id(task_id, user_id, db)
        if not task:
            return False

        await db.delete(task)
        await db.commit()
        logger.info("Deleted detection task: %s", task_id)
        return True

    async def get_tasks(
        self,
        brand_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> list[DetectionTask]:
        """Return all tasks of ``user_id`` for ``brand_id``, newest first."""
        stmt = (
            select(DetectionTask)
            .where(
                and_(
                    DetectionTask.brand_id == brand_id,
                    DetectionTask.user_id == user_id,
                )
            )
            .order_by(DetectionTask.created_at.desc())
        )
        result = await db.execute(stmt)
        return list(result.scalars().all())

    async def trigger_task(
        self,
        task_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> dict:
        """Run a task immediately.

        Returns the dict produced by ``execute_task``, or
        ``{"status": "error", "message": "Task not found"}`` when the task
        does not exist for ``user_id``.
        """
        task = await self._get_task_by_id(task_id, user_id, db)
        if not task:
            return {"status": "error", "message": "Task not found"}

        return await self.execute_task(task, db)

    async def execute_task(
        self,
        task: DetectionTask,
        db: AsyncSession,
    ) -> dict:
        """Run the task's queries, raise alerts if needed, and reschedule.

        Failures are reported in the returned dict rather than raised:
        ``{"status": "success", "results": [...]}`` on success and
        ``{"status": "error", "message": str(exc)}`` on any exception.
        """
        try:
            results = await self._run_batch_query(task, db)
            await self._generate_alerts_if_needed(task, results, db)

            task.last_run_at = datetime.now(timezone.utc)
            task.next_run_at = task.compute_next_run_at()
            await db.commit()
            await db.refresh(task)

            logger.info("Detection task executed: %s", task.id)
            return {"status": "success", "results": results}
        except Exception as e:
            # Deliberate catch-all: callers rely on the error dict. The
            # traceback is kept in the log so the failure stays diagnosable.
            logger.exception("Detection task execution failed: %s, error=%s", task.id, e)
            return {"status": "error", "message": str(e)}

    async def _run_batch_query(
        self,
        task: DetectionTask,
        db: AsyncSession,
    ) -> list:
        """Send every query of the task to its engines and collect the results.

        Engine names that are not valid ``EngineType`` values are ignored.
        When the brand row cannot be loaded, ``"Unknown"`` is used as the
        brand name.
        """
        adapters = _build_adapters()
        batch_service = BatchQueryService(adapters)

        brand = await self._get_brand(task.brand_id, db)
        brand_name = brand.name if brand else "Unknown"

        # The engine list does not depend on the query, so resolve it once.
        engines = [
            EngineType(e) for e in task.engines if e in EngineType._value2member_map_
        ]

        all_results = []
        for query_text in task.queries:
            results = await batch_service.query_batch(
                engines=engines,
                query=query_text,
                brand_name=brand_name,
                competitor_names=task.competitor_names,
            )
            all_results.extend(results)

        return all_results

    async def _generate_alerts_if_needed(
        self,
        task: DetectionTask,
        results: list,
        db: AsyncSession,
    ) -> list:
        """Create a "competitor_overtake" alert when only competitors were cited.

        Returns the list of alerts created (empty when there are no results,
        the brand cannot be loaded, or the condition is not met).
        """
        if not results:
            return []

        brand = await self._get_brand(task.brand_id, db)
        if not brand:
            return []

        brand_cited = sum(1 for r in results if r.has_brand_citation)
        competitor_cited = sum(1 for r in results if r.has_competitor_citation)

        alerts = []
        if competitor_cited > 0 and brand_cited == 0:
            from app.services.alert_engine import AlertEngine

            alert_engine = AlertEngine(db)
            alert = await alert_engine._create_alert(
                brand_id=task.brand_id,
                user_id=task.user_id,
                alert_type="competitor_overtake",
                severity="warning",
                title=f"{brand.name} 未被引用但竞品被引用",
                message=f"检测任务「{task.name}」执行后发现：品牌未被AI引擎引用，但竞品被引用了 {competitor_cited} 次。",
                data={"task_id": str(task.id), "competitor_cited": competitor_cited},
            )
            if alert:
                alerts.append(alert)

        return alerts

    @staticmethod
    def _validate_frequency(frequency: str) -> None:
        """Raise ``ValueError`` unless ``frequency`` is in ``VALID_FREQUENCIES``."""
        if frequency not in VALID_FREQUENCIES:
            raise ValueError(
                f"Invalid frequency: {frequency}. Must be one of {VALID_FREQUENCIES}"
            )

    async def _get_task_by_id(
        self,
        task_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> DetectionTask | None:
        """Return the task with ``task_id`` owned by ``user_id``, or None."""
        stmt = select(DetectionTask).where(
            and_(
                DetectionTask.id == task_id,
                DetectionTask.user_id == user_id,
            )
        )
        result = await db.execute(stmt)
        return result.scalar_one_or_none()

    async def _get_brand(
        self,
        brand_id: uuid.UUID,
        db: AsyncSession,
    ) -> Brand | None:
        """Return the brand with ``brand_id``, or None (no ownership check)."""
        stmt = select(Brand).where(Brand.id == brand_id)
        result = await db.execute(stmt)
        return result.scalar_one_or_none()
|
||||
|
|
@ -0,0 +1,202 @@
|
|||
import uuid
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.api.deps import get_current_user, get_db
|
||||
from app.database import Base
|
||||
from app.main import app
|
||||
from app.models.brand import Brand
|
||||
from app.models.user import User
|
||||
from app.services.auth import hash_password
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def async_engine():
    """Yield an in-memory SQLite async engine with all ``Base`` tables created.

    The engine is disposed once the consuming test has finished.
    """
    db_engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    async with db_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)
    yield db_engine
    await db_engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield one ``AsyncSession`` bound to the ``async_engine`` fixture."""
    factory_options = {
        "class_": AsyncSession,
        "expire_on_commit": False,
        "autoflush": False,
        "autocommit": False,
    }
    make_session = async_sessionmaker(async_engine, **factory_options)
    async with make_session() as db_session:
        yield db_session
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def test_user(async_session):
    """Insert an active, email-verified free-plan user and return it."""
    account = User(
        id=uuid.uuid4(),
        email="test@example.com",
        password_hash=hash_password("Test@123456"),
        name="Test User",
        plan="free",
        max_queries=5,
        is_active=True,
        email_verified=True,
    )
    async_session.add(account)
    await async_session.commit()
    # Reload so the returned instance reflects the committed row.
    await async_session.refresh(account)
    return account
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Insert an active brand owned by ``test_user`` and return it."""
    brand_fields = {
        "id": uuid.uuid4(),
        "user_id": test_user.id,
        "name": "Test Brand",
        "aliases": ["TestBrand", "TB"],
        "website": "https://testbrand.com",
        "industry": "technology",
        "platforms": ["wenxin", "kimi"],
        "frequency": "weekly",
        "status": "active",
    }
    owned_brand = Brand(**brand_fields)
    async_session.add(owned_brand)
    await async_session.commit()
    await async_session.refresh(owned_brand)
    return owned_brand
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an ``AsyncClient`` that talks to ``app`` in-process.

    ``get_db`` is overridden to hand out the test session and
    ``get_current_user`` to return ``test_user``, so requests are served as an
    authenticated user against the test database.
    """

    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user

    # ``app`` is shared module-level state: always remove the overrides, even if
    # building or entering the client raises, so they cannot leak into later tests.
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        app.dependency_overrides.clear()
|
||||
|
||||
|
||||
class TestDetectionTaskAPI:
    """HTTP-level tests for the detection task endpoints under ``/api/v1/detection/tasks``."""

    TASKS_URL = "/api/v1/detection/tasks"

    async def _create_task(self, client, brand, *, name, frequency):
        """POST a single-engine, single-query task for ``brand`` and return its id.

        The 201 is asserted here so that a broken setup step is reported as
        such instead of surfacing later as a ``KeyError`` on ``"id"``.
        """
        payload = {
            "brand_id": str(brand.id),
            "name": name,
            "frequency": frequency,
            "engines": ["chatgpt"],
            "queries": ["查询1"],
        }
        created = await client.post(self.TASKS_URL, json=payload)
        assert created.status_code == 201, created.text
        return created.json()["id"]

    @pytest.mark.asyncio
    async def test_create_detection_task(self, async_client, test_brand):
        """Creating a task returns 201 and echoes the submitted fields."""
        task_data = {
            "brand_id": str(test_brand.id),
            "name": "每日品牌检测",
            "frequency": "daily",
            "engines": ["chatgpt", "perplexity"],
            "queries": ["最佳保险品牌", "保险推荐"],
            "competitor_names": ["竞品A"],
        }
        response = await async_client.post(self.TASKS_URL, json=task_data)

        assert response.status_code == 201
        data = response.json()
        assert data["name"] == "每日品牌检测"
        assert data["frequency"] == "daily"
        assert data["engines"] == ["chatgpt", "perplexity"]
        assert data["queries"] == ["最佳保险品牌", "保险推荐"]
        assert data["is_active"] is True
        assert "id" in data

    @pytest.mark.asyncio
    async def test_get_detection_tasks(self, async_client, test_brand):
        """Listing tasks for a brand returns a payload with ``items`` and ``total``."""
        response = await async_client.get(
            self.TASKS_URL, params={"brand_id": str(test_brand.id)}
        )

        assert response.status_code == 200
        data = response.json()
        assert "items" in data
        assert "total" in data

    @pytest.mark.asyncio
    async def test_update_detection_task(self, async_client, test_brand):
        """PUT on an existing task returns 200 with the new name and frequency."""
        task_id = await self._create_task(
            async_client, test_brand, name="原始任务", frequency="weekly"
        )

        update_data = {
            "name": "更新后任务",
            "frequency": "daily",
            "engines": ["chatgpt", "perplexity"],
        }
        response = await async_client.put(f"{self.TASKS_URL}/{task_id}", json=update_data)

        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "更新后任务"
        assert data["frequency"] == "daily"

    @pytest.mark.asyncio
    async def test_delete_detection_task(self, async_client, test_brand):
        """DELETE on an existing task returns 204."""
        task_id = await self._create_task(
            async_client, test_brand, name="待删除任务", frequency="daily"
        )

        response = await async_client.delete(f"{self.TASKS_URL}/{task_id}")

        assert response.status_code == 204

    @pytest.mark.asyncio
    async def test_trigger_detection_task(self, async_client, test_brand):
        """Manually triggering a task returns 200 with ``status == "success"``."""
        task_id = await self._create_task(
            async_client, test_brand, name="手动触发任务", frequency="daily"
        )

        response = await async_client.post(f"{self.TASKS_URL}/{task_id}/trigger")

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "success"

    @pytest.mark.asyncio
    async def test_unauthorized_access(self, async_session):
        """A request carrying an invalid bearer token is answered with 401.

        Only ``get_db`` is overridden; ``get_current_user`` stays real so the
        token is actually checked.
        """

        async def override_get_db():
            yield async_session

        app.dependency_overrides[get_db] = override_get_db
        # try/finally: a failing assertion must not leave the override installed
        # on the shared ``app`` for whichever test runs next.
        try:
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                headers = {"Authorization": "Bearer invalid_token"}
                response = await client.get(self.TASKS_URL, headers=headers)
                assert response.status_code == 401
        finally:
            app.dependency_overrides.clear()
|
||||
|
|
@ -0,0 +1,619 @@
|
|||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.ai_engine.base import AIQueryResult, CitationInfo, EngineType
|
||||
from app.services.citation_pattern import (
|
||||
AuthoritySignalAnalyzer,
|
||||
CitationFormatAnalyzer,
|
||||
CitationPattern,
|
||||
CitationPatternEngine,
|
||||
ContentStructureAnalyzer,
|
||||
EnginePreferenceAnalyzer,
|
||||
PatternAnalysisReport,
|
||||
)
|
||||
|
||||
|
||||
def _make_result(
    engine_type: EngineType = EngineType.CHATGPT,
    raw_response: str = "test response",
    citations: list[CitationInfo] | None = None,
    has_brand_citation: bool = False,
    has_competitor_citation: bool = False,
    brand_context: str | None = None,
    competitor_contexts: list[str] | None = None,
    metadata: dict | None = None,
) -> AIQueryResult:
    """Build an ``AIQueryResult`` for tests.

    The query text, response time and timestamp are fixed; everything else
    comes from the arguments. Falsy ``citations``, ``competitor_contexts`` and
    ``metadata`` are replaced by fresh empty containers.
    """
    fixed_fields = {
        "query": "test query",
        "response_time_ms": 1000,
        "timestamp": datetime(2025, 1, 1, tzinfo=UTC),
    }
    caller_fields = {
        "engine_type": engine_type,
        "raw_response": raw_response,
        "citations": citations or [],
        "has_brand_citation": has_brand_citation,
        "has_competitor_citation": has_competitor_citation,
        "brand_context": brand_context,
        "competitor_contexts": competitor_contexts or [],
        "metadata": metadata or {},
    }
    return AIQueryResult(**fixed_fields, **caller_fields)
|
||||
|
||||
|
||||
class TestCitationPatternDataStructure:
    """Construction checks for the ``CitationPattern`` data structure."""

    def test_create_citation_pattern(self):
        """Fields passed to the constructor are readable back unchanged."""
        fields = {
            "pattern_type": "content_structure",
            "pattern_name": "faq_format",
            "frequency": 0.6,
            "confidence": 0.8,
            "description": "FAQ format detected in responses",
            "details": {"count": 3, "total": 5},
        }
        created = CitationPattern(**fields)
        assert created.pattern_type == "content_structure"
        assert created.pattern_name == "faq_format"
        assert created.frequency == 0.6
        assert created.confidence == 0.8
        assert created.details["count"] == 3

    def test_frequency_range(self):
        """A frequency of 0.0 is accepted and lies within [0, 1]."""
        created = CitationPattern(
            pattern_type="authority_signal",
            pattern_name="data_citation",
            frequency=0.0,
            confidence=0.5,
            description="",
            details={},
        )
        assert 0.0 <= created.frequency <= 1.0

    def test_confidence_range(self):
        """A confidence of 1.0 is accepted and lies within [0, 1]."""
        created = CitationPattern(
            pattern_type="citation_format",
            pattern_name="direct_citation",
            frequency=0.5,
            confidence=1.0,
            description="",
            details={},
        )
        assert 0.0 <= created.confidence <= 1.0

    def test_pattern_types(self):
        """Each of the four pattern type names can be used to build a pattern."""
        valid_types = {"content_structure", "authority_signal", "citation_format", "engine_preference"}
        shared_fields = {
            "pattern_name": "test",
            "frequency": 0.5,
            "confidence": 0.5,
            "description": "",
        }
        for type_name in valid_types:
            created = CitationPattern(pattern_type=type_name, details={}, **shared_fields)
            assert created.pattern_type in valid_types
|
||||
|
||||
|
||||
class TestPatternAnalysisReportDataStructure:
    """Construction checks for the ``PatternAnalysisReport`` data structure."""

    def test_create_report(self):
        """An empty report keeps its identifiers and empty collections."""
        empty_report = PatternAnalysisReport(
            brand_id="brand-123",
            query="test query",
            total_results=5,
            patterns=[],
            content_structure_insights={},
            authority_signal_insights={},
            citation_format_insights={},
            engine_preferences={},
            recommendations=[],
        )
        assert empty_report.brand_id == "brand-123"
        assert empty_report.total_results == 5
        assert empty_report.patterns == []
        assert empty_report.recommendations == []

    def test_report_with_patterns(self):
        """A report built with one pattern, one insight and one recommendation exposes them."""
        faq_pattern = CitationPattern(
            pattern_type="content_structure",
            pattern_name="faq_format",
            frequency=0.7,
            confidence=0.9,
            description="FAQ detected",
            details={"count": 7},
        )
        populated = PatternAnalysisReport(
            brand_id="b1",
            query="q",
            total_results=10,
            patterns=[faq_pattern],
            content_structure_insights={"faq_frequency": 0.7},
            authority_signal_insights={},
            citation_format_insights={},
            engine_preferences={},
            recommendations=["Add FAQ sections"],
        )
        assert len(populated.patterns) == 1
        assert populated.content_structure_insights["faq_frequency"] == 0.7
        assert len(populated.recommendations) == 1
|
||||
|
||||
|
||||
class TestContentStructureAnalyzer:
    """Tests for FAQ / list / table / quote-block detection in ``ContentStructureAnalyzer``."""

    @pytest.fixture
    def analyzer(self):
        return ContentStructureAnalyzer()

    @staticmethod
    def _named(patterns, name):
        """Return the patterns whose ``pattern_name`` equals ``name``."""
        return [p for p in patterns if p.pattern_name == name]

    def test_faq_format_detection(self, analyzer):
        """A Q:/A: response (English and Chinese) produces exactly one ``faq_format`` pattern."""
        faq_response = "\n".join(
            [
                "",
                "Q: What is GEO?",
                "A: GEO stands for Generative Engine Optimization.",
                "",
                "Q: How does GEO work?",
                "A: GEO works by optimizing content for AI engines.",
                "",
                "常见问题:",
                "问题: 什么是SEO?",
                "回答: SEO是搜索引擎优化。",
                "",
            ]
        )
        found = self._named(analyzer.analyze([_make_result(raw_response=faq_response)]), "faq_format")
        assert len(found) == 1
        assert found[0].frequency > 0
        assert found[0].pattern_type == "content_structure"

    def test_list_format_detection(self, analyzer):
        """Numbered and dashed lists produce exactly one ``list_format`` pattern."""
        list_response = "\n".join(
            [
                "",
                "Here are the top features:",
                "1. Fast performance",
                "2. Easy to use",
                "3. Affordable price",
                "",
                "- Benefit A",
                "- Benefit B",
                "- Benefit C",
                "",
            ]
        )
        found = self._named(analyzer.analyze([_make_result(raw_response=list_response)]), "list_format")
        assert len(found) == 1
        assert found[0].frequency > 0

    def test_table_format_detection(self, analyzer):
        """Pipe-delimited and dash-ruled tables produce exactly one ``table_format`` pattern."""
        table_response = "\n".join(
            [
                "",
                "| Feature | Plan A | Plan B |",
                "|---------|--------|--------|",
                "| Price | $10 | $20 |",
                "",
                "Comparison table:",
                "Item Value",
                "---- -----",
                "A 100",
                "B 200",
                "",
            ]
        )
        found = self._named(analyzer.analyze([_make_result(raw_response=table_response)]), "table_format")
        assert len(found) == 1
        assert found[0].frequency > 0

    def test_quote_block_detection(self, analyzer):
        """Block-quoted and inline quoted sentences produce exactly one ``quote_block`` pattern."""
        quote_response = "\n".join(
            [
                "",
                "According to the expert:",
                '> "This is the best solution on the market."',
                "",
                "As stated in the report:",
                '"The company leads in innovation."',
                "",
            ]
        )
        found = self._named(analyzer.analyze([_make_result(raw_response=quote_response)]), "quote_block")
        assert len(found) == 1
        assert found[0].frequency > 0

    def test_no_structure_detected(self, analyzer):
        """Plain prose yields only zero-frequency patterns."""
        plain_response = "This is a plain text response without any special formatting."
        reported = analyzer.analyze([_make_result(raw_response=plain_response)])
        assert all(p.frequency == 0.0 for p in reported)

    def test_multiple_results_aggregation(self, analyzer):
        """Frequency is the share of results showing the structure (one of three here)."""
        responses = (
            "Q: What? A: Something.",
            "Plain text without structure.",
            "1. Item one\n2. Item two",
        )
        reported = analyzer.analyze([_make_result(raw_response=text) for text in responses])
        faq = self._named(reported, "faq_format")[0]
        listing = self._named(reported, "list_format")[0]
        assert faq.frequency == pytest.approx(1 / 3, rel=0.01)
        assert listing.frequency == pytest.approx(1 / 3, rel=0.01)
|
||||
|
||||
|
||||
class TestAuthoritySignalAnalyzer:
    """Tests for data / expert / certification signal detection in ``AuthoritySignalAnalyzer``."""

    @pytest.fixture
    def analyzer(self):
        return AuthoritySignalAnalyzer()

    @staticmethod
    def _named(patterns, name):
        """Return the patterns whose ``pattern_name`` equals ``name``."""
        return [p for p in patterns if p.pattern_name == name]

    def test_data_citation_detection(self, analyzer):
        """Study / research / statistics sentences produce one ``data_citation`` with matches."""
        data_response = "\n".join(
            [
                "",
                "According to a 2024 study by MIT, 78% of companies adopted AI.",
                "Research from Stanford University shows that productivity increased by 45%.",
                "Data from the World Health Organization indicates a 30% reduction.",
                "Statistics show that 90% of users prefer this approach.",
                "",
            ]
        )
        found = self._named(analyzer.analyze([_make_result(raw_response=data_response)]), "data_citation")
        assert len(found) == 1
        assert found[0].frequency > 0
        assert found[0].details["match_count"] > 0

    def test_expert_citation_detection(self, analyzer):
        """Titled or labelled experts produce exactly one ``expert_citation`` pattern."""
        expert_response = "\n".join(
            [
                "",
                "Dr. Smith from Harvard notes that this trend will continue.",
                "Professor Johnson at Stanford recommends this approach.",
                "Expert analyst Jane Doe suggests using this method.",
                "According to industry expert John, the market will grow.",
                "",
            ]
        )
        found = self._named(analyzer.analyze([_make_result(raw_response=expert_response)]), "expert_citation")
        assert len(found) == 1
        assert found[0].frequency > 0

    def test_certification_mark_detection(self, analyzer):
        """Certification and compliance marks produce exactly one ``certification_mark`` pattern."""
        cert_response = "\n".join(
            [
                "",
                "The product is ISO 9001 certified and FDA approved.",
                "It has received CE certification and UL listed status.",
                "SOC 2 Type II compliant and HIPAA compliant.",
                "",
            ]
        )
        found = self._named(analyzer.analyze([_make_result(raw_response=cert_response)]), "certification_mark")
        assert len(found) == 1
        assert found[0].frequency > 0

    def test_no_authority_signals(self, analyzer):
        """A response without authority signals yields only zero-frequency patterns."""
        plain_response = "This is a basic response without any authority signals."
        reported = analyzer.analyze([_make_result(raw_response=plain_response)])
        assert all(p.frequency == 0.0 for p in reported)

    def test_multiple_results_aggregation(self, analyzer):
        """Frequency is the share of results showing the signal (one of three here)."""
        responses = (
            "According to a 2024 study, 80% agree.",
            "No authority signals here.",
            "Dr. Lee from Oxford confirms the findings.",
        )
        reported = analyzer.analyze([_make_result(raw_response=text) for text in responses])
        data_signal = self._named(reported, "data_citation")[0]
        expert_signal = self._named(reported, "expert_citation")[0]
        assert data_signal.frequency == pytest.approx(1 / 3, rel=0.01)
        assert expert_signal.frequency == pytest.approx(1 / 3, rel=0.01)
|
||||
|
||||
|
||||
class TestCitationFormatAnalyzer:
    """Tests for direct / indirect / comparison citation detection in ``CitationFormatAnalyzer``."""

    @pytest.fixture
    def analyzer(self):
        return CitationFormatAnalyzer()

    @staticmethod
    def _named(patterns, name):
        """Return the patterns whose ``pattern_name`` equals ``name``."""
        return [p for p in patterns if p.pattern_name == name]

    def test_direct_citation_detection(self, analyzer):
        """A response naming the brand as the source produces one ``direct_citation``."""
        direct_response = "\n".join(
            [
                "",
                "According to BrandX, their product is the best in class.",
                "BrandX states that they have over 1 million users.",
                "BrandX claims to be the industry leader.",
                "",
            ]
        )
        cited = _make_result(
            raw_response=direct_response,
            has_brand_citation=True,
            brand_context="BrandX states that they have over 1 million users",
        )
        found = self._named(analyzer.analyze([cited]), "direct_citation")
        assert len(found) == 1
        assert found[0].frequency > 0

    def test_indirect_citation_detection(self, analyzer):
        """A response alluding to the brand without naming it reports one ``indirect_citation``."""
        indirect_response = "\n".join(
            [
                "",
                "Some leading solutions in this space include comprehensive platforms",
                "that offer multiple features. One such platform provides AI-powered",
                "analytics and real-time monitoring capabilities.",
                "",
            ]
        )
        cited = _make_result(
            raw_response=indirect_response,
            has_brand_citation=True,
            brand_context="One such platform provides AI-powered analytics",
        )
        found = self._named(analyzer.analyze([cited]), "indirect_citation")
        assert len(found) == 1

    def test_comparison_citation_detection(self, analyzer):
        """A brand-versus-competitor response produces one ``comparison_citation``."""
        comparison_response = "\n".join(
            [
                "",
                "Compared to BrandX, CompetitorY offers better pricing.",
                "While BrandX focuses on enterprise, CompetitorY targets SMBs.",
                "BrandX vs CompetitorY: BrandX has more features but CompetitorY is cheaper.",
                "",
            ]
        )
        cited = _make_result(
            raw_response=comparison_response,
            has_brand_citation=True,
            has_competitor_citation=True,
            brand_context="Compared to BrandX",
            competitor_contexts=["CompetitorY offers better pricing"],
        )
        found = self._named(analyzer.analyze([cited]), "comparison_citation")
        assert len(found) == 1
        assert found[0].frequency > 0

    def test_no_citation_format(self, analyzer):
        """A result without a brand citation yields only zero-frequency patterns."""
        uncited = _make_result(
            raw_response="Generic response without citations.",
            has_brand_citation=False,
        )
        reported = analyzer.analyze([uncited])
        assert all(p.frequency == 0.0 for p in reported)

    def test_citation_with_position_info(self, analyzer):
        """A result carrying a positioned ``CitationInfo`` still counts as a direct citation."""
        source = CitationInfo(
            source_url="https://example.com",
            source_title="Example",
            citation_context="BrandX is mentioned here",
            confidence=0.9,
            position=0,
        )
        cited = _make_result(
            raw_response="BrandX is mentioned here.",
            has_brand_citation=True,
            brand_context="BrandX is mentioned here",
            citations=[source],
        )
        direct = self._named(analyzer.analyze([cited]), "direct_citation")[0]
        assert direct.frequency > 0
|
||||
|
||||
|
||||
class TestEnginePreferenceAnalyzer:
    """Tests for the per-engine statistics returned by ``EnginePreferenceAnalyzer``."""

    @pytest.fixture
    def analyzer(self):
        return EnginePreferenceAnalyzer()

    @staticmethod
    def _cited_at(engine_type, position):
        """Build a brand-citing result for ``engine_type`` with one citation at ``position``."""
        source = CitationInfo(
            source_url="https://example.com",
            source_title="Example",
            citation_context="test",
            confidence=0.9,
            position=position,
        )
        return _make_result(
            engine_type=engine_type,
            has_brand_citation=True,
            citations=[source],
        )

    def test_single_engine_citation_rate(self, analyzer):
        """One cited and one uncited ChatGPT result give a citation rate of 0.5."""
        hit = _make_result(engine_type=EngineType.CHATGPT, has_brand_citation=True)
        miss = _make_result(engine_type=EngineType.CHATGPT, has_brand_citation=False)
        prefs = analyzer.analyze([hit, miss])
        assert "chatgpt" in prefs
        assert prefs["chatgpt"]["citation_rate"] == 0.5

    def test_multi_engine_preferences(self, analyzer):
        """Four engines give four entries, each with rate, position and format keys."""
        outcomes = [
            (EngineType.CHATGPT, True),
            (EngineType.PERPLEXITY, True),
            (EngineType.KIMI, False),
            (EngineType.DEEPSEEK, True),
        ]
        results = [
            _make_result(engine_type=engine, has_brand_citation=cited)
            for engine, cited in outcomes
        ]
        prefs = analyzer.analyze(results)
        assert len(prefs) == 4
        for engine_data in prefs.values():
            assert "citation_rate" in engine_data
            assert "avg_citation_position" in engine_data
            assert "format_preferences" in engine_data

    def test_citation_position_preference(self, analyzer):
        """An engine citing at position 0 averages lower than one citing at position 5."""
        results = [
            self._cited_at(EngineType.CHATGPT, 0),
            self._cited_at(EngineType.PERPLEXITY, 5),
        ]
        prefs = analyzer.analyze(results)
        assert prefs["chatgpt"]["avg_citation_position"] < prefs["perplexity"]["avg_citation_position"]

    def test_format_preferences(self, analyzer):
        """FAQ and list content in a ChatGPT response appear in its format preferences."""
        structured = _make_result(
            engine_type=EngineType.CHATGPT,
            raw_response="Q: What? A: Something.\n1. First item\n2. Second item",
        )
        unstructured = _make_result(
            engine_type=EngineType.PERPLEXITY,
            raw_response="Plain text without structure.",
        )
        prefs = analyzer.analyze([structured, unstructured])
        chatgpt_formats = prefs["chatgpt"]["format_preferences"]
        assert "faq" in chatgpt_formats
        assert "list" in chatgpt_formats

    def test_empty_results(self, analyzer):
        """No results produce an empty preference mapping."""
        assert analyzer.analyze([]) == {}
|
||||
|
||||
|
||||
class TestCitationPatternEngine:
    """End-to-end tests for ``CitationPatternEngine.analyze`` and the report it returns."""

    @pytest.fixture
    def engine(self):
        return CitationPatternEngine()

    def test_full_analysis_flow(self, engine):
        """Two cited results from different engines yield a fully typed report."""
        chatgpt_source = CitationInfo(
            source_url="https://example.com",
            source_title="Example",
            citation_context="BrandX is mentioned",
            confidence=0.9,
            position=0,
        )
        perplexity_source = CitationInfo(
            source_url="https://example.com",
            source_title="Example",
            citation_context="Dr. Smith recommends BrandX",
            confidence=0.8,
            position=2,
        )
        chatgpt_result = _make_result(
            engine_type=EngineType.CHATGPT,
            raw_response="Q: What is GEO? A: GEO is optimization for AI.\n1. First benefit\n2. Second benefit",
            has_brand_citation=True,
            brand_context="BrandX is mentioned",
            citations=[chatgpt_source],
        )
        perplexity_result = _make_result(
            engine_type=EngineType.PERPLEXITY,
            raw_response="According to a 2024 study, 80% of companies use AI. Dr. Smith recommends BrandX.",
            has_brand_citation=True,
            brand_context="Dr. Smith recommends BrandX",
            citations=[perplexity_source],
        )

        report = engine.analyze(
            [chatgpt_result, perplexity_result], brand_id="brand-123", query="what is geo"
        )

        assert isinstance(report, PatternAnalysisReport)
        assert report.brand_id == "brand-123"
        assert report.query == "what is geo"
        assert report.total_results == 2
        assert len(report.patterns) > 0
        assert isinstance(report.content_structure_insights, dict)
        assert isinstance(report.authority_signal_insights, dict)
        assert isinstance(report.citation_format_insights, dict)
        assert isinstance(report.engine_preferences, dict)
        assert isinstance(report.recommendations, list)

    def test_empty_input_handling(self, engine):
        """No results give a report whose counters and collections are all empty."""
        report = engine.analyze([], brand_id="brand-1", query="test")
        assert report.total_results == 0
        assert report.patterns == []
        assert report.content_structure_insights == {}
        assert report.authority_signal_insights == {}
        assert report.citation_format_insights == {}
        assert report.engine_preferences == {}
        assert report.recommendations == []

    def test_single_engine_result(self, engine):
        """A single ChatGPT result is counted once and listed under ``chatgpt``."""
        only_result = _make_result(
            engine_type=EngineType.CHATGPT,
            raw_response="Q: What? A: This.",
            has_brand_citation=True,
            brand_context="BrandX is great",
        )
        report = engine.analyze([only_result], brand_id="b1", query="q")
        assert report.total_results == 1
        assert "chatgpt" in report.engine_preferences
        assert len(report.patterns) > 0

    def test_multi_engine_aggregation(self, engine):
        """Citation rates are computed per engine across three different engines."""
        results = [
            _make_result(
                engine_type=EngineType.CHATGPT,
                raw_response="Q: What? A: This. BrandX is the best.",
                has_brand_citation=True,
                brand_context="BrandX is the best",
            ),
            _make_result(
                engine_type=EngineType.PERPLEXITY,
                raw_response="1. First point\n2. Second point\nBrandX offers great value.",
                has_brand_citation=True,
                brand_context="BrandX offers great value",
            ),
            _make_result(
                engine_type=EngineType.KIMI,
                raw_response="Plain response. No citations here.",
                has_brand_citation=False,
            ),
        ]
        report = engine.analyze(results, brand_id="b1", query="q")
        per_engine = report.engine_preferences
        assert report.total_results == 3
        assert len(per_engine) == 3
        assert per_engine["chatgpt"]["citation_rate"] == 1.0
        assert per_engine["kimi"]["citation_rate"] == 0.0

    def test_pattern_report_generation(self, engine):
        """Structure, authority and format patterns all appear for a rich response."""
        brand_source = CitationInfo(
            source_url="https://brandx.com",
            source_title="BrandX",
            citation_context="BrandX provides GEO",
            confidence=0.95,
            position=0,
        )
        rich_result = _make_result(
            engine_type=EngineType.CHATGPT,
            raw_response="Q: What is GEO? A: GEO optimization.\nAccording to 2024 research, 75% agree.",
            has_brand_citation=True,
            brand_context="BrandX provides GEO",
            citations=[brand_source],
        )
        report = engine.analyze([rich_result], brand_id="b1", query="geo optimization")
        assert len(report.patterns) > 0
        seen_types = {p.pattern_type for p in report.patterns}
        assert "content_structure" in seen_types
        assert "authority_signal" in seen_types
        assert "citation_format" in seen_types

    def test_recommendations_generated(self, engine):
        """``recommendations`` is a list even when the brand is not cited."""
        uncited = _make_result(
            engine_type=EngineType.CHATGPT,
            raw_response="Q: What? A: Something.",
            has_brand_citation=False,
        )
        report = engine.analyze([uncited], brand_id="b1", query="q")
        assert isinstance(report.recommendations, list)

    def test_insights_populated(self, engine):
        """All three insight mappings are non-empty for a structured, cited response."""
        cited = _make_result(
            engine_type=EngineType.CHATGPT,
            raw_response="Q: What? A: This.\n1. Item\nISO 9001 certified.",
            has_brand_citation=True,
            brand_context="BrandX is here",
        )
        report = engine.analyze([cited], brand_id="b1", query="q")
        assert len(report.content_structure_insights) > 0
        assert len(report.authority_signal_insights) > 0
        assert len(report.citation_format_insights) > 0
|
||||
|
|
@ -0,0 +1,376 @@
|
|||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.database import Base
|
||||
from app.models.brand import Brand
|
||||
from app.models.user import User
|
||||
from app.services.auth import hash_password
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def async_engine():
    """Provide an in-memory SQLite async engine whose schema comes from ``Base.metadata``.

    The engine is disposed after the test that requested it.
    """
    engine_options = {
        "connect_args": {"check_same_thread": False},
        "poolclass": StaticPool,
    }
    memory_engine = create_async_engine("sqlite+aiosqlite:///:memory:", **engine_options)
    async with memory_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield memory_engine
    await memory_engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Provide a single ``AsyncSession`` on the test engine, closed after the test."""
    session_factory = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
    )
    async with session_factory() as open_session:
        yield open_session
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist a verified, active free-plan user and hand it to the test."""
    user_fields = {
        "id": uuid.uuid4(),
        "email": "test@example.com",
        "password_hash": hash_password("Test@123456"),
        "name": "Test User",
        "plan": "free",
        "max_queries": 5,
        "is_active": True,
        "email_verified": True,
    }
    persisted_user = User(**user_fields)
    async_session.add(persisted_user)
    await async_session.commit()
    await async_session.refresh(persisted_user)
    return persisted_user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Persist an active brand belonging to ``test_user`` and hand it to the test."""
    persisted_brand = Brand(
        id=uuid.uuid4(),
        user_id=test_user.id,
        name="Test Brand",
        aliases=["TestBrand", "TB"],
        website="https://testbrand.com",
        industry="technology",
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
    )
    async_session.add(persisted_brand)
    await async_session.commit()
    # Reload so the returned instance reflects the committed row.
    await async_session.refresh(persisted_brand)
    return persisted_brand
|
||||
|
||||
|
||||
class TestDetectionTaskModel:
    """Persistence tests for the ``DetectionTask`` ORM model."""

    @staticmethod
    async def _persist(session, task):
        """Add ``task`` to ``session``, commit, refresh from the database and return it."""
        session.add(task)
        await session.commit()
        await session.refresh(task)
        return task

    @pytest.mark.asyncio
    async def test_create_detection_task(self, async_session, test_brand, test_user):
        """A fully specified task round-trips and receives id, schedule and timestamps."""
        from app.models.detection_task import DetectionTask

        task = await self._persist(
            async_session,
            DetectionTask(
                brand_id=test_brand.id,
                user_id=test_user.id,
                name="每日品牌检测",
                frequency="daily",
                engines=["chatgpt", "perplexity"],
                queries=["最佳保险品牌", "保险推荐"],
                competitor_names=["竞品A", "竞品B"],
            ),
        )

        assert task.id is not None
        assert task.brand_id == test_brand.id
        assert task.user_id == test_user.id
        assert task.name == "每日品牌检测"
        assert task.frequency == "daily"
        assert task.engines == ["chatgpt", "perplexity"]
        assert task.queries == ["最佳保险品牌", "保险推荐"]
        assert task.competitor_names == ["竞品A", "竞品B"]
        assert task.is_active is True
        assert task.last_run_at is None
        assert task.next_run_at is not None
        assert task.created_at is not None
        assert task.updated_at is not None

    @pytest.mark.asyncio
    async def test_detection_task_default_values(self, async_session, test_brand, test_user):
        """Omitted optional fields default to active, no competitors and never run."""
        from app.models.detection_task import DetectionTask

        task = await self._persist(
            async_session,
            DetectionTask(
                brand_id=test_brand.id,
                user_id=test_user.id,
                name="简单检测",
                frequency="weekly",
                engines=["chatgpt"],
                queries=["测试查询"],
            ),
        )

        assert task.is_active is True
        assert task.competitor_names is None
        assert task.last_run_at is None
|
||||
|
||||
|
||||
class TestDetectionSchedulerService:
    """Tests for DetectionSchedulerService: CRUD, manual trigger, frequency checks, execution."""

    @staticmethod
    async def _store_task(session, *, refresh=True, **fields):
        """Insert a DetectionTask built from ``fields``, commit, and optionally refresh it."""
        from app.models.detection_task import DetectionTask

        record = DetectionTask(**fields)
        session.add(record)
        await session.commit()
        if refresh:
            await session.refresh(record)
        return record

    @staticmethod
    def _frequency_payload(name, frequency):
        """Build the minimal create payload used by the frequency tests."""
        return {
            "name": name,
            "frequency": frequency,
            "engines": ["chatgpt"],
            "queries": ["查询"],
        }

    @pytest.mark.asyncio
    async def test_create_task(self, async_session, test_brand, test_user):
        """create_task returns a task carrying the payload values and the given owner ids."""
        from app.services.detection_scheduler import DetectionSchedulerService

        scheduler = DetectionSchedulerService()
        payload = {
            "name": "每日品牌检测",
            "frequency": "daily",
            "engines": ["chatgpt", "perplexity"],
            "queries": ["最佳保险品牌", "保险推荐"],
            "competitor_names": ["竞品A"],
        }
        created = await scheduler.create_task(
            payload, test_brand.id, test_user.id, async_session
        )

        assert created.id is not None
        assert created.name == "每日品牌检测"
        assert created.frequency == "daily"
        assert created.brand_id == test_brand.id
        assert created.user_id == test_user.id

    @pytest.mark.asyncio
    async def test_update_task(self, async_session, test_brand, test_user):
        """update_task applies the changed name, frequency and engines."""
        from app.models.detection_task import DetectionTask  # noqa: F401
        from app.services.detection_scheduler import DetectionSchedulerService

        existing = await self._store_task(
            async_session,
            brand_id=test_brand.id,
            user_id=test_user.id,
            name="旧名称",
            frequency="weekly",
            engines=["chatgpt"],
            queries=["查询1"],
        )

        scheduler = DetectionSchedulerService()
        changes = {
            "name": "新名称",
            "frequency": "daily",
            "engines": ["chatgpt", "perplexity"],
        }
        updated = await scheduler.update_task(
            existing.id, changes, test_user.id, async_session
        )

        assert updated.name == "新名称"
        assert updated.frequency == "daily"
        assert updated.engines == ["chatgpt", "perplexity"]

    @pytest.mark.asyncio
    async def test_delete_task(self, async_session, test_brand, test_user):
        """delete_task returns True and the row can no longer be selected."""
        from app.models.detection_task import DetectionTask
        from app.services.detection_scheduler import DetectionSchedulerService

        doomed = await self._store_task(
            async_session,
            brand_id=test_brand.id,
            user_id=test_user.id,
            name="待删除",
            frequency="weekly",
            engines=["chatgpt"],
            queries=["查询1"],
        )

        scheduler = DetectionSchedulerService()
        outcome = await scheduler.delete_task(doomed.id, test_user.id, async_session)
        assert outcome is True

        # Query the table directly to confirm the row is gone.
        lookup = select(DetectionTask).where(DetectionTask.id == doomed.id)
        rows = await async_session.execute(lookup)
        assert rows.scalar_one_or_none() is None

    @pytest.mark.asyncio
    async def test_get_tasks(self, async_session, test_brand, test_user):
        """get_tasks returns all three tasks stored for the brand and user."""
        from app.models.detection_task import DetectionTask  # noqa: F401
        from app.services.detection_scheduler import DetectionSchedulerService

        for index in range(3):
            # Commit per task without refreshing, as the listing only needs the rows to exist.
            await self._store_task(
                async_session,
                refresh=False,
                brand_id=test_brand.id,
                user_id=test_user.id,
                name=f"任务{index}",
                frequency="daily",
                engines=["chatgpt"],
                queries=[f"查询{index}"],
            )

        scheduler = DetectionSchedulerService()
        listed = await scheduler.get_tasks(test_brand.id, test_user.id, async_session)
        assert len(listed) == 3

    @pytest.mark.asyncio
    async def test_trigger_task(self, async_session, test_brand, test_user):
        """trigger_task calls execute_task once and returns a success status."""
        from app.models.detection_task import DetectionTask  # noqa: F401
        from app.services.detection_scheduler import DetectionSchedulerService

        pending = await self._store_task(
            async_session,
            brand_id=test_brand.id,
            user_id=test_user.id,
            name="手动触发测试",
            frequency="daily",
            engines=["chatgpt"],
            queries=["测试查询"],
        )

        scheduler = DetectionSchedulerService()
        # Replace the real execution with an async stub so no engine is queried.
        with patch.object(
            scheduler,
            "execute_task",
            new_callable=AsyncMock,
            return_value={"status": "success", "results": []},
        ) as execute_stub:
            outcome = await scheduler.trigger_task(
                pending.id, test_user.id, async_session
            )

        assert outcome["status"] == "success"
        execute_stub.assert_called_once()

    @pytest.mark.asyncio
    async def test_frequency_validation_hourly(self, async_session, test_brand, test_user):
        """An hourly task is accepted and receives a next_run_at value."""
        from app.services.detection_scheduler import DetectionSchedulerService

        scheduler = DetectionSchedulerService()
        payload = self._frequency_payload("每小时检测", "hourly")
        created = await scheduler.create_task(
            payload, test_brand.id, test_user.id, async_session
        )
        assert created.frequency == "hourly"
        assert created.next_run_at is not None

    @pytest.mark.asyncio
    async def test_frequency_validation_daily(self, async_session, test_brand, test_user):
        """A daily task is accepted and receives a next_run_at value."""
        from app.services.detection_scheduler import DetectionSchedulerService

        scheduler = DetectionSchedulerService()
        payload = self._frequency_payload("每日检测", "daily")
        created = await scheduler.create_task(
            payload, test_brand.id, test_user.id, async_session
        )
        assert created.frequency == "daily"
        assert created.next_run_at is not None

    @pytest.mark.asyncio
    async def test_frequency_validation_weekly(self, async_session, test_brand, test_user):
        """A weekly task is accepted and receives a next_run_at value."""
        from app.services.detection_scheduler import DetectionSchedulerService

        scheduler = DetectionSchedulerService()
        payload = self._frequency_payload("每周检测", "weekly")
        created = await scheduler.create_task(
            payload, test_brand.id, test_user.id, async_session
        )
        assert created.frequency == "weekly"
        assert created.next_run_at is not None

    @pytest.mark.asyncio
    async def test_frequency_validation_invalid(self, async_session, test_brand, test_user):
        """An unsupported frequency makes create_task raise ValueError mentioning 'frequency'."""
        from app.services.detection_scheduler import DetectionSchedulerService

        scheduler = DetectionSchedulerService()
        payload = self._frequency_payload("无效频率", "monthly")
        with pytest.raises(ValueError, match="frequency"):
            await scheduler.create_task(
                payload, test_brand.id, test_user.id, async_session
            )

    @pytest.mark.asyncio
    async def test_execute_task_flow(self, async_session, test_brand, test_user):
        """execute_task reports success and leaves last_run_at / next_run_at populated."""
        from app.models.detection_task import DetectionTask  # noqa: F401
        from app.services.detection_scheduler import DetectionSchedulerService

        scheduled = await self._store_task(
            async_session,
            brand_id=test_brand.id,
            user_id=test_user.id,
            name="执行流程测试",
            frequency="daily",
            engines=["chatgpt"],
            queries=["测试查询"],
            competitor_names=["竞品A"],
        )

        scheduler = DetectionSchedulerService()

        # One fake engine result: brand cited, competitor not cited.
        fake_engine_result = MagicMock(
            engine_type=MagicMock(value="chatgpt"),
            has_brand_citation=True,
            has_competitor_citation=False,
        )
        fake_batch = [fake_engine_result]

        # Stub both the batch query and the alert generation helpers.
        with patch.object(
            scheduler, "_run_batch_query", new_callable=AsyncMock, return_value=fake_batch
        ):
            with patch.object(
                scheduler,
                "_generate_alerts_if_needed",
                new_callable=AsyncMock,
                return_value=[],
            ):
                outcome = await scheduler.execute_task(scheduled, async_session)

        assert outcome["status"] == "success"
        assert "results" in outcome

        await async_session.refresh(scheduled)
        assert scheduled.last_run_at is not None
        assert scheduled.next_run_at is not None

    @pytest.mark.asyncio
    async def test_delete_task_not_found(self, async_session, test_user):
        """Deleting a random, unknown task id returns False."""
        from app.services.detection_scheduler import DetectionSchedulerService

        scheduler = DetectionSchedulerService()
        outcome = await scheduler.delete_task(uuid.uuid4(), test_user.id, async_session)
        assert outcome is False

    @pytest.mark.asyncio
    async def test_update_task_not_found(self, async_session, test_user):
        """Updating a random, unknown task id raises TaskNotFoundError."""
        from app.services.detection_scheduler import DetectionSchedulerService, TaskNotFoundError

        scheduler = DetectionSchedulerService()
        with pytest.raises(TaskNotFoundError):
            await scheduler.update_task(
                uuid.uuid4(), {"name": "新名称"}, test_user.id, async_session
            )

    @pytest.mark.asyncio
    async def test_get_tasks_empty(self, async_session, test_brand, test_user):
        """get_tasks returns an empty list when no task has been stored."""
        from app.services.detection_scheduler import DetectionSchedulerService

        scheduler = DetectionSchedulerService()
        listed = await scheduler.get_tasks(test_brand.id, test_user.id, async_session)
        assert listed == []
|
||||
Loading…
Reference in New Issue