# ---- backend/app/api/detection.py ----
# REST endpoints for creating, listing, updating, deleting and manually
# triggering scheduled detection tasks.
import logging
import uuid

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import and_, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.deps import get_current_user
from app.database import get_db
from app.models.brand import Brand
from app.models.user import User
from app.schemas.detection_task import (
    DetectionTaskCreate,
    DetectionTaskListResponse,
    DetectionTaskResponse,
    DetectionTaskUpdate,
    DetectionTriggerResponse,
)
from app.services.detection_scheduler import DetectionSchedulerService, TaskNotFoundError

logger = logging.getLogger(__name__)

router = APIRouter()

# Message DetectionSchedulerService.trigger_task puts in its error dict when the
# task lookup fails; any other error message means the run itself failed.
_TASK_NOT_FOUND_MESSAGE = "Task not found"


async def verify_brand_ownership(
    brand_id: uuid.UUID,
    current_user: User,
    db: AsyncSession,
) -> Brand:
    """Return the brand if it exists and belongs to ``current_user``.

    Raises:
        HTTPException: 404 when no brand with this id is owned by the user.
    """
    stmt = select(Brand).where(
        and_(
            Brand.id == brand_id,
            Brand.user_id == current_user.id,
        )
    )
    result = await db.execute(stmt)
    brand = result.scalar_one_or_none()

    if not brand:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"品牌 {brand_id} 不存在或不属于当前用户",
        )

    return brand


@router.post("/tasks", response_model=DetectionTaskResponse, status_code=status.HTTP_201_CREATED)
async def create_detection_task(
    data: DetectionTaskCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a detection task for a brand owned by the current user.

    Raises:
        HTTPException: 404 if the brand is not owned by the user, 422 if the
            service rejects the payload with ``ValueError``.
    """
    await verify_brand_ownership(data.brand_id, current_user, db)

    service = DetectionSchedulerService()
    try:
        task = await service.create_task(
            task_data=data.model_dump(exclude={"brand_id"}),
            brand_id=data.brand_id,
            user_id=current_user.id,
            db=db,
        )
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=str(e),
        ) from e

    return task


@router.get("/tasks", response_model=DetectionTaskListResponse)
async def get_detection_tasks(
    brand_id: uuid.UUID = Query(..., description="按品牌筛选"),
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List the current user's detection tasks for one brand, paginated."""
    await verify_brand_ownership(brand_id, current_user, db)

    service = DetectionSchedulerService()
    tasks = await service.get_tasks(brand_id, current_user.id, db)

    # The service returns every task for the brand; the page is cut here, in memory.
    return {"items": tasks[skip : skip + limit], "total": len(tasks)}


@router.put("/tasks/{task_id}", response_model=DetectionTaskResponse)
async def update_detection_task(
    task_id: uuid.UUID,
    data: DetectionTaskUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Apply the fields present in the request body to an existing task.

    Raises:
        HTTPException: 404 if the task is not found for this user, 422 if the
            service rejects the new values with ``ValueError``.
    """
    service = DetectionSchedulerService()
    try:
        task = await service.update_task(
            task_id=task_id,
            # exclude_unset: only fields the client actually sent are updated
            task_data=data.model_dump(exclude_unset=True),
            user_id=current_user.id,
            db=db,
        )
    except TaskNotFoundError as e:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="检测任务不存在",
        ) from e
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=str(e),
        ) from e

    return task


@router.delete("/tasks/{task_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_detection_task(
    task_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete one of the current user's tasks; 404 if it does not exist."""
    service = DetectionSchedulerService()
    deleted = await service.delete_task(task_id, current_user.id, db)

    if not deleted:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="检测任务不存在",
        )

    return None


@router.post("/tasks/{task_id}/trigger", response_model=DetectionTriggerResponse)
async def trigger_detection_task(
    task_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Run a detection task immediately and return the service's status dict.

    Raises:
        HTTPException: 404 when the task does not exist for this user, 500
            when the task exists but its execution failed.
    """
    service = DetectionSchedulerService()
    result = await service.trigger_task(task_id, current_user.id, db)

    if result.get("status") == "error":
        detail = result.get("message", "检测任务不存在")
        # Only a failed lookup is a 404; a failed run must not be reported as
        # "not found", otherwise clients cannot tell the two apart.
        not_found = "message" not in result or detail == _TASK_NOT_FOUND_MESSAGE
        if not not_found:
            logger.error("Detection task %s failed when triggered: %s", task_id, detail)
        raise HTTPException(
            status_code=(
                status.HTTP_404_NOT_FOUND
                if not_found
                else status.HTTP_500_INTERNAL_SERVER_ERROR
            ),
            detail=detail,
        )

    return result


# NOTE: the same change set wires this router into the application
# (backend/app/main.py):
#     from app.api.detection import router as detection_router
#     app.include_router(detection_router, prefix="/api/v1/detection", tags=["定时检测任务"])
# and exports the ORM model from backend/app/models/__init__.py:
#     from app.models.detection_task import DetectionTask   (plus "DetectionTask" in __all__)
# ---- backend/app/models/detection_task.py ----
# ORM model for a recurring brand-detection task.
import uuid
from datetime import datetime, timedelta, timezone

from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String, Uuid, func
from sqlalchemy.orm import Mapped, mapped_column

from app.database import Base, JSONType

# Interval added to the reference time to obtain the next run, per frequency.
FREQUENCY_DELTAS = {
    "hourly": timedelta(hours=1),
    "daily": timedelta(days=1),
    "weekly": timedelta(weeks=1),
}

# Derived from FREQUENCY_DELTAS so the accepted names and the intervals cannot drift apart.
VALID_FREQUENCIES = set(FREQUENCY_DELTAS)


class DetectionTask(Base):
    """A scheduled detection job: which queries to run on which engines for a brand."""

    __tablename__ = "detection_tasks"

    id: Mapped[uuid.UUID] = mapped_column(
        Uuid(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )
    brand_id: Mapped[uuid.UUID] = mapped_column(
        Uuid(as_uuid=True),
        ForeignKey("brands.id", ondelete="CASCADE"),
        nullable=False,
    )
    # Indexed through __table_args__ below; no column-level index=True, which
    # would create a second, redundant index on the same column.
    user_id: Mapped[uuid.UUID] = mapped_column(
        Uuid(as_uuid=True),
        nullable=False,
    )
    name: Mapped[str] = mapped_column(String(200), nullable=False)
    frequency: Mapped[str] = mapped_column(String(20), nullable=False)
    engines: Mapped[list] = mapped_column(JSONType, default=list, nullable=False)
    queries: Mapped[list] = mapped_column(JSONType, default=list, nullable=False)
    competitor_names: Mapped[list | None] = mapped_column(JSONType, nullable=True)
    is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
    # Both run timestamps are assigned timezone-aware UTC values, so the
    # columns are declared timezone-aware as well.
    last_run_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    next_run_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
        default=lambda: datetime.now(timezone.utc),
    )
    created_at: Mapped[datetime] = mapped_column(
        server_default=func.now(),
        nullable=False,
    )
    updated_at: Mapped[datetime] = mapped_column(
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )

    __table_args__ = (
        Index("idx_detection_tasks_brand_id", "brand_id"),
        Index("idx_detection_tasks_user_id", "user_id"),
        Index("idx_detection_tasks_is_active", "is_active"),
    )

    def compute_next_run_at(self) -> datetime:
        """Return the last run time (or the current UTC time if never run) plus one interval.

        An unrecognised frequency falls back to a one-day interval.
        """
        delta = FREQUENCY_DELTAS.get(self.frequency, timedelta(days=1))
        base = self.last_run_at or datetime.now(timezone.utc)
        return base + delta


# ---- backend/app/schemas/detection_task.py ----
# Request / response schemas for the detection-task API.
from typing import Optional

from pydantic import BaseModel, Field


class DetectionTaskCreate(BaseModel):
    """Payload for creating a task; at least one engine and one query are required."""

    brand_id: uuid.UUID
    name: str = Field(..., min_length=1, max_length=200)
    frequency: str = Field(..., pattern="^(hourly|daily|weekly)$")
    # Required rather than defaulting to []: a default is not validated, so an
    # empty default would bypass min_length when the field is omitted.
    engines: list[str] = Field(..., min_length=1)
    queries: list[str] = Field(..., min_length=1)
    competitor_names: Optional[list[str]] = None


class DetectionTaskUpdate(BaseModel):
    """Partial update; omitted fields are left unchanged."""

    name: Optional[str] = Field(None, min_length=1, max_length=200)
    frequency: Optional[str] = Field(None, pattern="^(hourly|daily|weekly)$")
    # Same lower bound as on create, so an update cannot blank the lists.
    engines: Optional[list[str]] = Field(None, min_length=1)
    queries: Optional[list[str]] = Field(None, min_length=1)
    competitor_names: Optional[list[str]] = None
    is_active: Optional[bool] = None


class DetectionTaskResponse(BaseModel):
    """Serialised task, built directly from the ORM object."""

    id: uuid.UUID
    brand_id: uuid.UUID
    user_id: uuid.UUID
    name: str
    frequency: str
    engines: list[str]
    queries: list[str]
    competitor_names: Optional[list[str]] = None
    is_active: bool
    last_run_at: Optional[datetime] = None
    next_run_at: Optional[datetime] = None
    created_at: datetime
    updated_at: datetime

    model_config = {"from_attributes": True}


class DetectionTaskListResponse(BaseModel):
    """One page of tasks plus the total number of tasks before paging."""

    items: list[DetectionTaskResponse]
    total: int


class DetectionTriggerResponse(BaseModel):
    """Outcome of a manual trigger."""

    status: str
    message: Optional[str] = None
# ---- backend/app/services/citation_pattern.py ----
# Regex-based heuristics that summarise how AI engine responses cite a brand.
import re
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Used in annotations only, so it is not imported at runtime.
    from app.services.ai_engine.base import AIQueryResult


@dataclass
class CitationPattern:
    """One pattern and how often it was seen across the analysed results."""

    pattern_type: str  # this module emits "content_structure", "authority_signal", "citation_format"
    pattern_name: str
    frequency: float  # matching results divided by total results
    confidence: float  # per-rule constant, 0.0 when the pattern never matched
    description: str
    details: dict[str, Any] = field(default_factory=dict)


@dataclass
class PatternAnalysisReport:
    """Everything CitationPatternEngine.analyze derives for one brand/query pair."""

    brand_id: str
    query: str
    total_results: int
    patterns: list[CitationPattern]
    content_structure_insights: dict[str, Any]
    authority_signal_insights: dict[str, Any]
    citation_format_insights: dict[str, Any]
    engine_preferences: dict[str, Any]
    recommendations: list[str] = field(default_factory=list)


_FAQ_PATTERNS = [
    re.compile(r"Q:\s*.+?\s*A:\s*", re.DOTALL | re.IGNORECASE),
    # NOTE(review): "[::]" repeats the ASCII colon; a full-width colon may have
    # been intended in these two classes — confirm against the original file.
    re.compile(r"问题[::]\s*.+?\s*回答[::]\s*", re.DOTALL),
    re.compile(r"常见问题", re.IGNORECASE),
    re.compile(r"FAQ", re.IGNORECASE),
]

_LIST_PATTERNS = [
    re.compile(r"(?:^|\n)\s*\d+\.\s+", re.MULTILINE),
    re.compile(r"(?:^|\n)\s*[-*]\s+", re.MULTILINE),
]

_TABLE_PATTERNS = [
    re.compile(r"\|.+\|.+\|", re.MULTILINE),
    re.compile(r"-{4,}\s+-{4,}", re.MULTILINE),
]

_QUOTE_PATTERNS = [
    re.compile(r"^\s*>\s+", re.MULTILINE),
    re.compile(r"[\"\u201c].+?[\"\u201d]"),
]

_DATA_CITATION_PATTERNS = [
    re.compile(r"\d+%", re.MULTILINE),
    re.compile(r"(?:study|research|survey|report|data|statistics)", re.IGNORECASE),
    re.compile(r"(?:according to|based on|shown by|indicates?|reveals?)", re.IGNORECASE),
    re.compile(r"\b(?:19|20)\d{2}\b"),
]

_EXPERT_CITATION_PATTERNS = [
    re.compile(r"(?:Dr\.|Professor|Prof\.)\s+\w+", re.IGNORECASE),
    re.compile(r"(?:expert|analyst|researcher|scientist)\s+\w+", re.IGNORECASE),
    re.compile(r"(?:from|at)\s+(?:Harvard|Stanford|MIT|Oxford|Cambridge|Yale|Princeton)", re.IGNORECASE),
]

_CERTIFICATION_PATTERNS = [
    re.compile(r"(?:ISO\s*\d+|FDA\s+approved|CE\s+certif|UL\s+listed|SOC\s*2|HIPAA|GDPR)", re.IGNORECASE),
    re.compile(r"(?:certified|certification|compliant|accredited)", re.IGNORECASE),
]

_COMPARISON_PATTERNS = [
    re.compile(r"(?:compared?\s+to|vs\.?|versus|while\s+.+\s*,\s*.+)", re.IGNORECASE),
    re.compile(r"(?:优于|相比|对比|不如|胜过)"),
]

_DIRECT_CITATION_PATTERNS = [
    re.compile(r"(?:states?|claims?|says?|mentions?|notes?|reports?|announces?)", re.IGNORECASE),
    re.compile(r"(?:according to|as stated by|as reported by)", re.IGNORECASE),
]

# (pattern name, regexes, confidence when seen, description)
_CONTENT_RULES: list[tuple[str, list[re.Pattern[str]], float, str]] = [
    ("faq_format", _FAQ_PATTERNS, 0.8, "FAQ format detected in AI responses"),
    ("list_format", _LIST_PATTERNS, 0.8, "List format detected in AI responses"),
    ("table_format", _TABLE_PATTERNS, 0.8, "Table format detected in AI responses"),
    ("quote_block", _QUOTE_PATTERNS, 0.7, "Quote block detected in AI responses"),
]

# (pattern name, regexes, confidence when seen, description, min. distinct regexes that must match)
_AUTHORITY_RULES: list[tuple[str, list[re.Pattern[str]], float, str, int]] = [
    ("data_citation", _DATA_CITATION_PATTERNS, 0.85, "Data citation signals detected in AI responses", 2),
    ("expert_citation", _EXPERT_CITATION_PATTERNS, 0.8, "Expert citation signals detected in AI responses", 1),
    ("certification_mark", _CERTIFICATION_PATTERNS, 0.9, "Certification marks detected in AI responses", 1),
]

# (pattern type, pattern name, frequency below which the advice is given, advice)
_RECOMMENDATION_RULES: list[tuple[str, str, float, str]] = [
    ("content_structure", "faq_format", 0.3, "Consider adding FAQ sections to improve AI citation probability"),
    ("content_structure", "list_format", 0.3, "Use structured lists to make content more extractable by AI engines"),
    ("authority_signal", "data_citation", 0.3, "Include data citations and statistics to increase authority signals"),
    ("authority_signal", "expert_citation", 0.3, "Add expert quotes and references to strengthen E-E-A-T signals"),
    ("citation_format", "direct_citation", 0.2, "Optimize content for direct citation by AI engines"),
]

# Formats tracked per engine, in the order they appear in "format_preferences".
_ENGINE_FORMAT_PATTERNS: tuple[tuple[str, list[re.Pattern[str]]], ...] = (
    ("faq", _FAQ_PATTERNS),
    ("list", _LIST_PATTERNS),
    ("table", _TABLE_PATTERNS),
)

_CITATION_FORMAT_NAMES = ("direct_citation", "indirect_citation", "comparison_citation")


def _matches_any(text: "str | None", patterns: list[re.Pattern[str]]) -> bool:
    """Return True if at least one pattern is found in ``text`` (None/empty never matches)."""
    if not text:
        return False
    return any(p.search(text) for p in patterns)


def _count_matches(text: "str | None", patterns: list[re.Pattern[str]]) -> int:
    """Return how many distinct patterns are found in ``text`` (0 for None/empty)."""
    if not text:
        return 0
    return sum(1 for p in patterns if p.search(text))


def _zero_patterns(pattern_type: str, names: "tuple[str, ...] | list[str]") -> list[CitationPattern]:
    """Build one all-zero CitationPattern per name, used when there is nothing to analyse."""
    return [
        CitationPattern(
            pattern_type=pattern_type,
            pattern_name=name,
            frequency=0.0,
            confidence=0.0,
            description="",
            details={},
        )
        for name in names
    ]


def _build_type_insights(
    patterns: list[CitationPattern],
    pattern_type: str,
    top_key: str,
) -> dict[str, Any]:
    """Map ``<name>_frequency`` to each pattern's frequency for one pattern type.

    ``top_key`` is added with the most frequent pattern's name, but only when
    that frequency is above zero.
    """
    filtered = [p for p in patterns if p.pattern_type == pattern_type]
    insights: dict[str, Any] = {f"{p.pattern_name}_frequency": p.frequency for p in filtered}
    if filtered:
        best = max(filtered, key=lambda p: p.frequency)
        if best.frequency > 0:
            insights[top_key] = best.pattern_name
    return insights


class ContentStructureAnalyzer:
    """Measures how often responses use FAQ, list, table or quote formatting."""

    def analyze(self, results: list["AIQueryResult"]) -> list[CitationPattern]:
        """Return one pattern per content rule with the share of matching responses."""
        if not results:
            return self._empty_patterns()

        total = len(results)
        found: list[CitationPattern] = []
        for name, regexes, conf, desc in _CONTENT_RULES:
            count = sum(1 for r in results if _matches_any(r.raw_response, regexes))
            found.append(
                CitationPattern(
                    pattern_type="content_structure",
                    pattern_name=name,
                    frequency=count / total,
                    confidence=conf if count > 0 else 0.0,
                    description=desc,
                    details={"count": count, "total": total},
                )
            )
        return found

    def _empty_patterns(self) -> list[CitationPattern]:
        return _zero_patterns("content_structure", [name for name, _, _, _ in _CONTENT_RULES])


class AuthoritySignalAnalyzer:
    """Measures data, expert and certification signals in responses."""

    def analyze(self, results: list["AIQueryResult"]) -> list[CitationPattern]:
        """Return one pattern per authority rule.

        A response counts for a rule when at least ``threshold`` of the rule's
        regexes match it.
        """
        if not results:
            return self._empty_patterns()

        total = len(results)
        found: list[CitationPattern] = []
        for name, regexes, conf, desc, threshold in _AUTHORITY_RULES:
            count = 0
            matched_regexes = 0
            for r in results:
                n = _count_matches(r.raw_response, regexes)
                if n >= threshold:
                    count += 1
                    # NOTE(review): accumulated only for responses that reach the
                    # threshold — nesting could not be recovered from the diff; confirm.
                    matched_regexes += n
            found.append(
                CitationPattern(
                    pattern_type="authority_signal",
                    pattern_name=name,
                    frequency=count / total,
                    confidence=conf if count > 0 else 0.0,
                    description=desc,
                    details={"count": count, "total": total, "match_count": matched_regexes},
                )
            )
        return found

    def _empty_patterns(self) -> list[CitationPattern]:
        return _zero_patterns("authority_signal", [name for name, _, _, _, _ in _AUTHORITY_RULES])


class CitationFormatAnalyzer:
    """Classifies each brand-citing response as comparison, direct or indirect."""

    def analyze(self, results: list["AIQueryResult"]) -> list[CitationPattern]:
        """Return the direct / indirect / comparison citation shares.

        Responses without a brand citation count towards the total only.
        """
        if not results:
            return self._empty_patterns()

        counts = dict.fromkeys(_CITATION_FORMAT_NAMES, 0)
        for r in results:
            if not r.has_brand_citation:
                continue
            # Comparison wins over direct; anything else that cites the brand is indirect.
            if _matches_any(r.raw_response, _COMPARISON_PATTERNS) and r.has_competitor_citation:
                counts["comparison_citation"] += 1
            elif r.brand_context and _matches_any(r.brand_context, _DIRECT_CITATION_PATTERNS):
                counts["direct_citation"] += 1
            else:
                counts["indirect_citation"] += 1

        total = len(results)
        specs = (
            ("direct_citation", 0.9, "Direct citation format detected"),
            ("indirect_citation", 0.7, "Indirect citation format detected"),
            ("comparison_citation", 0.85, "Comparison citation format detected"),
        )
        return [
            CitationPattern(
                pattern_type="citation_format",
                pattern_name=name,
                frequency=counts[name] / total,
                confidence=conf if counts[name] > 0 else 0.0,
                description=desc,
                details={"count": counts[name], "total": total},
            )
            for name, conf, desc in specs
        ]

    def _empty_patterns(self) -> list[CitationPattern]:
        return _zero_patterns("citation_format", _CITATION_FORMAT_NAMES)


class EnginePreferenceAnalyzer:
    """Aggregates citation rate, citation position and format usage per engine."""

    def analyze(self, results: list["AIQueryResult"]) -> dict[str, dict[str, Any]]:
        """Return per-engine statistics keyed by ``engine_type.value``.

        ``avg_citation_position`` is -1 when the engine produced no citation positions.
        """
        if not results:
            return {}

        stats: dict[str, dict[str, Any]] = {}
        for r in results:
            entry = stats.setdefault(
                r.engine_type.value,
                {
                    "total": 0,
                    "cited": 0,
                    "positions": [],
                    "format_hits": {fmt: 0 for fmt, _ in _ENGINE_FORMAT_PATTERNS},
                },
            )
            entry["total"] += 1
            if r.has_brand_citation:
                entry["cited"] += 1
                entry["positions"].extend(c.position for c in r.citations)
            # NOTE(review): format hits are counted for every response, cited or
            # not — nesting could not be recovered from the diff; confirm.
            for fmt, regexes in _ENGINE_FORMAT_PATTERNS:
                if _matches_any(r.raw_response, regexes):
                    entry["format_hits"][fmt] += 1

        prefs: dict[str, dict[str, Any]] = {}
        for engine_name, entry in stats.items():
            total = entry["total"]  # always >= 1: an entry exists only once a result was seen
            positions = entry["positions"]
            prefs[engine_name] = {
                "citation_rate": entry["cited"] / total,
                "avg_citation_position": sum(positions) / len(positions) if positions else -1,
                "format_preferences": {
                    fmt: hits / total for fmt, hits in entry["format_hits"].items()
                },
            }
        return prefs


class CitationPatternEngine:
    """Runs all analyzers and assembles a PatternAnalysisReport."""

    def __init__(self) -> None:
        self.content_analyzer = ContentStructureAnalyzer()
        self.authority_analyzer = AuthoritySignalAnalyzer()
        self.format_analyzer = CitationFormatAnalyzer()
        self.engine_analyzer = EnginePreferenceAnalyzer()

    def analyze(
        self,
        results: list["AIQueryResult"],
        brand_id: str,
        query: str,
    ) -> PatternAnalysisReport:
        """Analyse ``results``; an empty input yields a report with no patterns at all."""
        if not results:
            return PatternAnalysisReport(
                brand_id=brand_id,
                query=query,
                total_results=0,
                patterns=[],
                content_structure_insights={},
                authority_signal_insights={},
                citation_format_insights={},
                engine_preferences={},
                recommendations=[],
            )

        patterns: list[CitationPattern] = [
            *self.content_analyzer.analyze(results),
            *self.authority_analyzer.analyze(results),
            *self.format_analyzer.analyze(results),
        ]
        engine_prefs = self.engine_analyzer.analyze(results)

        return PatternAnalysisReport(
            brand_id=brand_id,
            query=query,
            total_results=len(results),
            patterns=patterns,
            content_structure_insights=_build_type_insights(patterns, "content_structure", "dominant_format"),
            authority_signal_insights=_build_type_insights(patterns, "authority_signal", "strongest_signal"),
            citation_format_insights=_build_type_insights(patterns, "citation_format", "primary_format"),
            engine_preferences=engine_prefs,
            recommendations=self._generate_recommendations(patterns, engine_prefs),
        )

    def _generate_recommendations(
        self,
        patterns: list[CitationPattern],
        engine_prefs: dict[str, dict],
    ) -> list[str]:
        """Return advice for under-represented patterns, then for engines citing below 30%."""
        recommendations: list[str] = []

        pattern_map = {(p.pattern_type, p.pattern_name): p for p in patterns}
        for ptype, pname, threshold, message in _RECOMMENDATION_RULES:
            p = pattern_map.get((ptype, pname))
            if p and p.frequency < threshold:
                recommendations.append(message)

        for engine_name, prefs in engine_prefs.items():
            if prefs["citation_rate"] < 0.3:
                recommendations.append(
                    f"Low citation rate on {engine_name} ({prefs['citation_rate']:.0%}), "
                    f"consider optimizing content for this engine"
                )

        return recommendations
# ---- backend/app/services/detection_scheduler.py ----
# Service layer for scheduled detection tasks: CRUD, execution and alerting.
import logging
import uuid
from datetime import datetime, timezone

from sqlalchemy import and_, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.ai_engines import _build_adapters
from app.models.brand import Brand
from app.models.detection_task import VALID_FREQUENCIES, DetectionTask
from app.services.ai_engine.base import EngineType
from app.services.ai_engine.batch_query import BatchQueryService

logger = logging.getLogger(__name__)

# Columns update_task may change. Identity, ownership and run-bookkeeping
# columns are deliberately absent so a caller-supplied dict cannot overwrite them.
_UPDATABLE_FIELDS = frozenset(
    {"name", "frequency", "engines", "queries", "competitor_names", "is_active"}
)


class TaskNotFoundError(Exception):
    """Raised when a task does not exist or is not owned by the given user."""


class DetectionSchedulerService:
    """Creates, updates, deletes, lists and executes DetectionTask rows."""

    async def create_task(
        self,
        task_data: dict,
        brand_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> DetectionTask:
        """Persist a new task and schedule its first run.

        Raises:
            ValueError: if the frequency is not in VALID_FREQUENCIES.
            KeyError: if ``task_data`` has no "name".
        """
        frequency = task_data.get("frequency", "daily")
        self._validate_frequency(frequency)

        task = DetectionTask(
            brand_id=brand_id,
            user_id=user_id,
            name=task_data["name"],
            frequency=frequency,
            engines=task_data.get("engines", []),
            queries=task_data.get("queries", []),
            competitor_names=task_data.get("competitor_names"),
        )
        task.next_run_at = task.compute_next_run_at()
        db.add(task)
        await db.commit()
        await db.refresh(task)
        logger.info("Created detection task: %s, name=%s", task.id, task.name)
        return task

    async def update_task(
        self,
        task_id: uuid.UUID,
        task_data: dict,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> DetectionTask:
        """Apply the updatable keys of ``task_data`` to the user's task.

        Keys outside ``_UPDATABLE_FIELDS`` are ignored.

        Raises:
            TaskNotFoundError: if the task does not exist for this user.
            ValueError: if the resulting frequency is invalid.
        """
        task = await self._get_task_by_id(task_id, user_id, db)
        if not task:
            raise TaskNotFoundError(f"Detection task {task_id} not found")

        frequency = task_data.get("frequency", task.frequency)
        self._validate_frequency(frequency)
        frequency_changed = frequency != task.frequency

        for field, value in task_data.items():
            if field in _UPDATABLE_FIELDS:
                setattr(task, field, value)

        # Reschedule only when the interval changed (or nothing is scheduled):
        # recomputing on every edit would postpone the first run of a task
        # that has never run, because the base time would be "now" again.
        if frequency_changed or task.next_run_at is None:
            task.next_run_at = task.compute_next_run_at()
        await db.commit()
        await db.refresh(task)
        logger.info("Updated detection task: %s", task.id)
        return task

    async def delete_task(
        self,
        task_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> bool:
        """Delete the user's task; return False when it does not exist."""
        task = await self._get_task_by_id(task_id, user_id, db)
        if not task:
            return False

        await db.delete(task)
        await db.commit()
        logger.info("Deleted detection task: %s", task_id)
        return True

    async def get_tasks(
        self,
        brand_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> list[DetectionTask]:
        """Return all of the user's tasks for a brand, newest first."""
        stmt = (
            select(DetectionTask)
            .where(
                and_(
                    DetectionTask.brand_id == brand_id,
                    DetectionTask.user_id == user_id,
                )
            )
            .order_by(DetectionTask.created_at.desc())
        )
        result = await db.execute(stmt)
        return list(result.scalars().all())

    async def trigger_task(
        self,
        task_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> dict:
        """Run the user's task now; returns an error dict when it is not found."""
        task = await self._get_task_by_id(task_id, user_id, db)
        if not task:
            return {"status": "error", "message": "Task not found"}

        return await self.execute_task(task, db)

    async def execute_task(
        self,
        task: DetectionTask,
        db: AsyncSession,
    ) -> dict:
        """Query the engines, raise alerts if needed and advance the schedule.

        Never raises: any failure is logged with its traceback, the session is
        rolled back, and ``{"status": "error", "message": ...}`` is returned.
        """
        # Read the id up front: after a rollback the instance is expired and an
        # attribute access would need a database round trip.
        task_id = task.id
        try:
            results = await self._run_batch_query(task, db)
            await self._generate_alerts_if_needed(task, results, db)

            task.last_run_at = datetime.now(timezone.utc)
            task.next_run_at = task.compute_next_run_at()
            await db.commit()
            await db.refresh(task)

            logger.info("Detection task executed: %s", task_id)
            return {"status": "success", "results": results}
        except Exception as e:
            # Deliberate catch-all boundary: one failing task must not crash its caller.
            logger.exception("Detection task execution failed: %s", task_id)
            # Leave the session usable and drop any half-written changes.
            await db.rollback()
            return {"status": "error", "message": str(e)}

    async def _run_batch_query(
        self,
        task: DetectionTask,
        db: AsyncSession,
    ) -> list:
        """Run every query of the task on its recognised engines and collect the results."""
        # Resolve engine names once; they do not change between queries.
        engines = []
        for engine_name in task.engines:
            try:
                engines.append(EngineType(engine_name))
            except ValueError:
                logger.warning(
                    "Detection task %s: ignoring unknown engine %r", task.id, engine_name
                )
        if not engines:
            logger.warning("Detection task %s has no usable engines", task.id)
            return []

        adapters = _build_adapters()
        batch_service = BatchQueryService(adapters)

        brand = await self._get_brand(task.brand_id, db)
        brand_name = brand.name if brand else "Unknown"

        all_results = []
        for query_text in task.queries:
            results = await batch_service.query_batch(
                engines=engines,
                query=query_text,
                brand_name=brand_name,
                competitor_names=task.competitor_names,
            )
            all_results.extend(results)

        return all_results

    async def _generate_alerts_if_needed(
        self,
        task: DetectionTask,
        results: list,
        db: AsyncSession,
    ) -> list:
        """Create a "competitor_overtake" alert when only competitors were cited.

        Returns the created alerts (empty when there are no results, the brand
        is gone, or the condition is not met).
        """
        if not results:
            return []

        brand = await self._get_brand(task.brand_id, db)
        if not brand:
            return []

        brand_cited = sum(1 for r in results if r.has_brand_citation)
        competitor_cited = sum(1 for r in results if r.has_competitor_citation)

        alerts = []
        if competitor_cited > 0 and brand_cited == 0:
            from app.services.alert_engine import AlertEngine

            alert_engine = AlertEngine(db)
            # NOTE(review): this calls a private method of AlertEngine; switch to
            # a public entry point if one exists.
            alert = await alert_engine._create_alert(
                brand_id=task.brand_id,
                user_id=task.user_id,
                alert_type="competitor_overtake",
                severity="warning",
                title=f"{brand.name} 未被引用但竞品被引用",
                message=f"检测任务「{task.name}」执行后发现:品牌未被AI引擎引用,但竞品被引用了 {competitor_cited} 次。",
                data={"task_id": str(task.id), "competitor_cited": competitor_cited},
            )
            if alert:
                alerts.append(alert)

        return alerts

    @staticmethod
    def _validate_frequency(frequency: str) -> None:
        """Raise ValueError unless ``frequency`` is one of VALID_FREQUENCIES."""
        if frequency not in VALID_FREQUENCIES:
            raise ValueError(
                f"Invalid frequency: {frequency}. Must be one of {VALID_FREQUENCIES}"
            )

    async def _get_task_by_id(
        self,
        task_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> DetectionTask | None:
        """Fetch a task by id, restricted to rows owned by ``user_id``."""
        stmt = select(DetectionTask).where(
            and_(
                DetectionTask.id == task_id,
                DetectionTask.user_id == user_id,
            )
        )
        result = await db.execute(stmt)
        return result.scalar_one_or_none()

    async def _get_brand(
        self,
        brand_id: uuid.UUID,
        db: AsyncSession,
    ) -> Brand | None:
        """Fetch a brand by id, without an ownership check."""
        stmt = select(Brand).where(Brand.id == brand_id)
        result = await db.execute(stmt)
        return result.scalar_one_or_none()
password_hash=hash_password("Test@123456"), + name="Test User", + plan="free", + max_queries=5, + is_active=True, + email_verified=True, + ) + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def test_brand(async_session, test_user): + brand = Brand( + id=uuid.uuid4(), + user_id=test_user.id, + name="Test Brand", + aliases=["TestBrand", "TB"], + website="https://testbrand.com", + industry="technology", + platforms=["wenxin", "kimi"], + frequency="weekly", + status="active", + ) + async_session.add(brand) + await async_session.commit() + await async_session.refresh(brand) + return brand + + +@pytest_asyncio.fixture +async def async_client(async_session, test_user): + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return test_user + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +class TestDetectionTaskAPI: + @pytest.mark.asyncio + async def test_create_detection_task(self, async_client, test_brand): + task_data = { + "brand_id": str(test_brand.id), + "name": "每日品牌检测", + "frequency": "daily", + "engines": ["chatgpt", "perplexity"], + "queries": ["最佳保险品牌", "保险推荐"], + "competitor_names": ["竞品A"], + } + response = await async_client.post("/api/v1/detection/tasks", json=task_data) + + assert response.status_code == 201 + data = response.json() + assert data["name"] == "每日品牌检测" + assert data["frequency"] == "daily" + assert data["engines"] == ["chatgpt", "perplexity"] + assert data["queries"] == ["最佳保险品牌", "保险推荐"] + assert data["is_active"] is True + assert "id" in data + + @pytest.mark.asyncio + async def test_get_detection_tasks(self, async_client, test_brand): + response = await 
async_client.get( + "/api/v1/detection/tasks", params={"brand_id": str(test_brand.id)} + ) + + assert response.status_code == 200 + data = response.json() + assert "items" in data + assert "total" in data + + @pytest.mark.asyncio + async def test_update_detection_task(self, async_client, test_brand): + create_data = { + "brand_id": str(test_brand.id), + "name": "原始任务", + "frequency": "weekly", + "engines": ["chatgpt"], + "queries": ["查询1"], + } + create_resp = await async_client.post("/api/v1/detection/tasks", json=create_data) + task_id = create_resp.json()["id"] + + update_data = { + "name": "更新后任务", + "frequency": "daily", + "engines": ["chatgpt", "perplexity"], + } + response = await async_client.put(f"/api/v1/detection/tasks/{task_id}", json=update_data) + + assert response.status_code == 200 + data = response.json() + assert data["name"] == "更新后任务" + assert data["frequency"] == "daily" + + @pytest.mark.asyncio + async def test_delete_detection_task(self, async_client, test_brand): + create_data = { + "brand_id": str(test_brand.id), + "name": "待删除任务", + "frequency": "daily", + "engines": ["chatgpt"], + "queries": ["查询1"], + } + create_resp = await async_client.post("/api/v1/detection/tasks", json=create_data) + task_id = create_resp.json()["id"] + + response = await async_client.delete(f"/api/v1/detection/tasks/{task_id}") + + assert response.status_code == 204 + + @pytest.mark.asyncio + async def test_trigger_detection_task(self, async_client, test_brand): + create_data = { + "brand_id": str(test_brand.id), + "name": "手动触发任务", + "frequency": "daily", + "engines": ["chatgpt"], + "queries": ["查询1"], + } + create_resp = await async_client.post("/api/v1/detection/tasks", json=create_data) + task_id = create_resp.json()["id"] + + response = await async_client.post(f"/api/v1/detection/tasks/{task_id}/trigger") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "success" + + @pytest.mark.asyncio + async def 
test_unauthorized_access(self, async_session): + async def override_get_db(): + yield async_session + + app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + headers = {"Authorization": "Bearer invalid_token"} + response = await client.get("/api/v1/detection/tasks", headers=headers) + assert response.status_code == 401 + + app.dependency_overrides.clear() diff --git a/backend/tests/test_services/test_citation_pattern.py b/backend/tests/test_services/test_citation_pattern.py new file mode 100644 index 0000000..25e5923 --- /dev/null +++ b/backend/tests/test_services/test_citation_pattern.py @@ -0,0 +1,619 @@ +from datetime import UTC, datetime + +import pytest + +from app.services.ai_engine.base import AIQueryResult, CitationInfo, EngineType +from app.services.citation_pattern import ( + AuthoritySignalAnalyzer, + CitationFormatAnalyzer, + CitationPattern, + CitationPatternEngine, + ContentStructureAnalyzer, + EnginePreferenceAnalyzer, + PatternAnalysisReport, +) + + +def _make_result( + engine_type: EngineType = EngineType.CHATGPT, + raw_response: str = "test response", + citations: list[CitationInfo] | None = None, + has_brand_citation: bool = False, + has_competitor_citation: bool = False, + brand_context: str | None = None, + competitor_contexts: list[str] | None = None, + metadata: dict | None = None, +) -> AIQueryResult: + return AIQueryResult( + engine_type=engine_type, + query="test query", + raw_response=raw_response, + citations=citations or [], + has_brand_citation=has_brand_citation, + has_competitor_citation=has_competitor_citation, + brand_context=brand_context, + competitor_contexts=competitor_contexts or [], + response_time_ms=1000, + timestamp=datetime(2025, 1, 1, tzinfo=UTC), + metadata=metadata or {}, + ) + + +class TestCitationPatternDataStructure: + def test_create_citation_pattern(self): + pattern = CitationPattern( + 
pattern_type="content_structure", + pattern_name="faq_format", + frequency=0.6, + confidence=0.8, + description="FAQ format detected in responses", + details={"count": 3, "total": 5}, + ) + assert pattern.pattern_type == "content_structure" + assert pattern.pattern_name == "faq_format" + assert pattern.frequency == 0.6 + assert pattern.confidence == 0.8 + assert pattern.details["count"] == 3 + + def test_frequency_range(self): + pattern = CitationPattern( + pattern_type="authority_signal", + pattern_name="data_citation", + frequency=0.0, + confidence=0.5, + description="", + details={}, + ) + assert 0.0 <= pattern.frequency <= 1.0 + + def test_confidence_range(self): + pattern = CitationPattern( + pattern_type="citation_format", + pattern_name="direct_citation", + frequency=0.5, + confidence=1.0, + description="", + details={}, + ) + assert 0.0 <= pattern.confidence <= 1.0 + + def test_pattern_types(self): + valid_types = {"content_structure", "authority_signal", "citation_format", "engine_preference"} + for pt in valid_types: + pattern = CitationPattern( + pattern_type=pt, + pattern_name="test", + frequency=0.5, + confidence=0.5, + description="", + details={}, + ) + assert pattern.pattern_type in valid_types + + +class TestPatternAnalysisReportDataStructure: + def test_create_report(self): + report = PatternAnalysisReport( + brand_id="brand-123", + query="test query", + total_results=5, + patterns=[], + content_structure_insights={}, + authority_signal_insights={}, + citation_format_insights={}, + engine_preferences={}, + recommendations=[], + ) + assert report.brand_id == "brand-123" + assert report.total_results == 5 + assert report.patterns == [] + assert report.recommendations == [] + + def test_report_with_patterns(self): + patterns = [ + CitationPattern( + pattern_type="content_structure", + pattern_name="faq_format", + frequency=0.7, + confidence=0.9, + description="FAQ detected", + details={"count": 7}, + ) + ] + report = PatternAnalysisReport( + 
brand_id="b1", + query="q", + total_results=10, + patterns=patterns, + content_structure_insights={"faq_frequency": 0.7}, + authority_signal_insights={}, + citation_format_insights={}, + engine_preferences={}, + recommendations=["Add FAQ sections"], + ) + assert len(report.patterns) == 1 + assert report.content_structure_insights["faq_frequency"] == 0.7 + assert len(report.recommendations) == 1 + + +class TestContentStructureAnalyzer: + @pytest.fixture + def analyzer(self): + return ContentStructureAnalyzer() + + def test_faq_format_detection(self, analyzer): + faq_response = """ + Q: What is GEO? + A: GEO stands for Generative Engine Optimization. + + Q: How does GEO work? + A: GEO works by optimizing content for AI engines. + + 常见问题: + 问题: 什么是SEO? + 回答: SEO是搜索引擎优化。 + """ + results = [_make_result(raw_response=faq_response)] + patterns = analyzer.analyze(results) + faq_pattern = [p for p in patterns if p.pattern_name == "faq_format"] + assert len(faq_pattern) == 1 + assert faq_pattern[0].frequency > 0 + assert faq_pattern[0].pattern_type == "content_structure" + + def test_list_format_detection(self, analyzer): + list_response = """ + Here are the top features: + 1. Fast performance + 2. Easy to use + 3. 
Affordable price + + - Benefit A + - Benefit B + - Benefit C + """ + results = [_make_result(raw_response=list_response)] + patterns = analyzer.analyze(results) + list_pattern = [p for p in patterns if p.pattern_name == "list_format"] + assert len(list_pattern) == 1 + assert list_pattern[0].frequency > 0 + + def test_table_format_detection(self, analyzer): + table_response = """ + | Feature | Plan A | Plan B | + |---------|--------|--------| + | Price | $10 | $20 | + + Comparison table: + Item Value + ---- ----- + A 100 + B 200 + """ + results = [_make_result(raw_response=table_response)] + patterns = analyzer.analyze(results) + table_pattern = [p for p in patterns if p.pattern_name == "table_format"] + assert len(table_pattern) == 1 + assert table_pattern[0].frequency > 0 + + def test_quote_block_detection(self, analyzer): + quote_response = """ + According to the expert: + > "This is the best solution on the market." + + As stated in the report: + "The company leads in innovation." + """ + results = [_make_result(raw_response=quote_response)] + patterns = analyzer.analyze(results) + quote_pattern = [p for p in patterns if p.pattern_name == "quote_block"] + assert len(quote_pattern) == 1 + assert quote_pattern[0].frequency > 0 + + def test_no_structure_detected(self, analyzer): + plain_response = "This is a plain text response without any special formatting." + results = [_make_result(raw_response=plain_response)] + patterns = analyzer.analyze(results) + for p in patterns: + assert p.frequency == 0.0 + + def test_multiple_results_aggregation(self, analyzer): + results = [ + _make_result(raw_response="Q: What? A: Something."), + _make_result(raw_response="Plain text without structure."), + _make_result(raw_response="1. Item one\n2. 
Item two"), + ] + patterns = analyzer.analyze(results) + faq_pattern = [p for p in patterns if p.pattern_name == "faq_format"][0] + list_pattern = [p for p in patterns if p.pattern_name == "list_format"][0] + assert faq_pattern.frequency == pytest.approx(1 / 3, rel=0.01) + assert list_pattern.frequency == pytest.approx(1 / 3, rel=0.01) + + +class TestAuthoritySignalAnalyzer: + @pytest.fixture + def analyzer(self): + return AuthoritySignalAnalyzer() + + def test_data_citation_detection(self, analyzer): + data_response = """ + According to a 2024 study by MIT, 78% of companies adopted AI. + Research from Stanford University shows that productivity increased by 45%. + Data from the World Health Organization indicates a 30% reduction. + Statistics show that 90% of users prefer this approach. + """ + results = [_make_result(raw_response=data_response)] + patterns = analyzer.analyze(results) + data_pattern = [p for p in patterns if p.pattern_name == "data_citation"] + assert len(data_pattern) == 1 + assert data_pattern[0].frequency > 0 + assert data_pattern[0].details["match_count"] > 0 + + def test_expert_citation_detection(self, analyzer): + expert_response = """ + Dr. Smith from Harvard notes that this trend will continue. + Professor Johnson at Stanford recommends this approach. + Expert analyst Jane Doe suggests using this method. + According to industry expert John, the market will grow. + """ + results = [_make_result(raw_response=expert_response)] + patterns = analyzer.analyze(results) + expert_pattern = [p for p in patterns if p.pattern_name == "expert_citation"] + assert len(expert_pattern) == 1 + assert expert_pattern[0].frequency > 0 + + def test_certification_mark_detection(self, analyzer): + cert_response = """ + The product is ISO 9001 certified and FDA approved. + It has received CE certification and UL listed status. + SOC 2 Type II compliant and HIPAA compliant. 
+ """ + results = [_make_result(raw_response=cert_response)] + patterns = analyzer.analyze(results) + cert_pattern = [p for p in patterns if p.pattern_name == "certification_mark"] + assert len(cert_pattern) == 1 + assert cert_pattern[0].frequency > 0 + + def test_no_authority_signals(self, analyzer): + plain_response = "This is a basic response without any authority signals." + results = [_make_result(raw_response=plain_response)] + patterns = analyzer.analyze(results) + for p in patterns: + assert p.frequency == 0.0 + + def test_multiple_results_aggregation(self, analyzer): + results = [ + _make_result(raw_response="According to a 2024 study, 80% agree."), + _make_result(raw_response="No authority signals here."), + _make_result(raw_response="Dr. Lee from Oxford confirms the findings."), + ] + patterns = analyzer.analyze(results) + data_pattern = [p for p in patterns if p.pattern_name == "data_citation"][0] + expert_pattern = [p for p in patterns if p.pattern_name == "expert_citation"][0] + assert data_pattern.frequency == pytest.approx(1 / 3, rel=0.01) + assert expert_pattern.frequency == pytest.approx(1 / 3, rel=0.01) + + +class TestCitationFormatAnalyzer: + @pytest.fixture + def analyzer(self): + return CitationFormatAnalyzer() + + def test_direct_citation_detection(self, analyzer): + direct_response = """ + According to BrandX, their product is the best in class. + BrandX states that they have over 1 million users. + BrandX claims to be the industry leader. 
+ """ + results = [ + _make_result( + raw_response=direct_response, + has_brand_citation=True, + brand_context="BrandX states that they have over 1 million users", + ) + ] + patterns = analyzer.analyze(results) + direct_pattern = [p for p in patterns if p.pattern_name == "direct_citation"] + assert len(direct_pattern) == 1 + assert direct_pattern[0].frequency > 0 + + def test_indirect_citation_detection(self, analyzer): + indirect_response = """ + Some leading solutions in this space include comprehensive platforms + that offer multiple features. One such platform provides AI-powered + analytics and real-time monitoring capabilities. + """ + results = [ + _make_result( + raw_response=indirect_response, + has_brand_citation=True, + brand_context="One such platform provides AI-powered analytics", + ) + ] + patterns = analyzer.analyze(results) + indirect_pattern = [p for p in patterns if p.pattern_name == "indirect_citation"] + assert len(indirect_pattern) == 1 + + def test_comparison_citation_detection(self, analyzer): + comparison_response = """ + Compared to BrandX, CompetitorY offers better pricing. + While BrandX focuses on enterprise, CompetitorY targets SMBs. + BrandX vs CompetitorY: BrandX has more features but CompetitorY is cheaper. 
+ """ + results = [ + _make_result( + raw_response=comparison_response, + has_brand_citation=True, + has_competitor_citation=True, + brand_context="Compared to BrandX", + competitor_contexts=["CompetitorY offers better pricing"], + ) + ] + patterns = analyzer.analyze(results) + comparison_pattern = [p for p in patterns if p.pattern_name == "comparison_citation"] + assert len(comparison_pattern) == 1 + assert comparison_pattern[0].frequency > 0 + + def test_no_citation_format(self, analyzer): + results = [ + _make_result( + raw_response="Generic response without citations.", + has_brand_citation=False, + ) + ] + patterns = analyzer.analyze(results) + for p in patterns: + assert p.frequency == 0.0 + + def test_citation_with_position_info(self, analyzer): + results = [ + _make_result( + raw_response="BrandX is mentioned here.", + has_brand_citation=True, + brand_context="BrandX is mentioned here", + citations=[ + CitationInfo( + source_url="https://example.com", + source_title="Example", + citation_context="BrandX is mentioned here", + confidence=0.9, + position=0, + ) + ], + ) + ] + patterns = analyzer.analyze(results) + direct_pattern = [p for p in patterns if p.pattern_name == "direct_citation"][0] + assert direct_pattern.frequency > 0 + + +class TestEnginePreferenceAnalyzer: + @pytest.fixture + def analyzer(self): + return EnginePreferenceAnalyzer() + + def test_single_engine_citation_rate(self, analyzer): + results = [ + _make_result(engine_type=EngineType.CHATGPT, has_brand_citation=True), + _make_result(engine_type=EngineType.CHATGPT, has_brand_citation=False), + ] + prefs = analyzer.analyze(results) + assert "chatgpt" in prefs + assert prefs["chatgpt"]["citation_rate"] == 0.5 + + def test_multi_engine_preferences(self, analyzer): + results = [ + _make_result(engine_type=EngineType.CHATGPT, has_brand_citation=True), + _make_result(engine_type=EngineType.PERPLEXITY, has_brand_citation=True), + _make_result(engine_type=EngineType.KIMI, has_brand_citation=False), 
+ _make_result(engine_type=EngineType.DEEPSEEK, has_brand_citation=True), + ] + prefs = analyzer.analyze(results) + assert len(prefs) == 4 + for engine_name, engine_data in prefs.items(): + assert "citation_rate" in engine_data + assert "avg_citation_position" in engine_data + assert "format_preferences" in engine_data + + def test_citation_position_preference(self, analyzer): + results = [ + _make_result( + engine_type=EngineType.CHATGPT, + has_brand_citation=True, + citations=[ + CitationInfo( + source_url="https://example.com", + source_title="Example", + citation_context="test", + confidence=0.9, + position=0, + ) + ], + ), + _make_result( + engine_type=EngineType.PERPLEXITY, + has_brand_citation=True, + citations=[ + CitationInfo( + source_url="https://example.com", + source_title="Example", + citation_context="test", + confidence=0.9, + position=5, + ) + ], + ), + ] + prefs = analyzer.analyze(results) + assert prefs["chatgpt"]["avg_citation_position"] < prefs["perplexity"]["avg_citation_position"] + + def test_format_preferences(self, analyzer): + results = [ + _make_result( + engine_type=EngineType.CHATGPT, + raw_response="Q: What? A: Something.\n1. First item\n2. Second item", + ), + _make_result( + engine_type=EngineType.PERPLEXITY, + raw_response="Plain text without structure.", + ), + ] + prefs = analyzer.analyze(results) + assert "faq" in prefs["chatgpt"]["format_preferences"] + assert "list" in prefs["chatgpt"]["format_preferences"] + + def test_empty_results(self, analyzer): + prefs = analyzer.analyze([]) + assert prefs == {} + + +class TestCitationPatternEngine: + @pytest.fixture + def engine(self): + return CitationPatternEngine() + + def test_full_analysis_flow(self, engine): + results = [ + _make_result( + engine_type=EngineType.CHATGPT, + raw_response="Q: What is GEO? A: GEO is optimization for AI.\n1. First benefit\n2. 
Second benefit", + has_brand_citation=True, + brand_context="BrandX is mentioned", + citations=[ + CitationInfo( + source_url="https://example.com", + source_title="Example", + citation_context="BrandX is mentioned", + confidence=0.9, + position=0, + ) + ], + ), + _make_result( + engine_type=EngineType.PERPLEXITY, + raw_response="According to a 2024 study, 80% of companies use AI. Dr. Smith recommends BrandX.", + has_brand_citation=True, + brand_context="Dr. Smith recommends BrandX", + citations=[ + CitationInfo( + source_url="https://example.com", + source_title="Example", + citation_context="Dr. Smith recommends BrandX", + confidence=0.8, + position=2, + ) + ], + ), + ] + report = engine.analyze(results, brand_id="brand-123", query="what is geo") + assert isinstance(report, PatternAnalysisReport) + assert report.brand_id == "brand-123" + assert report.query == "what is geo" + assert report.total_results == 2 + assert len(report.patterns) > 0 + assert isinstance(report.content_structure_insights, dict) + assert isinstance(report.authority_signal_insights, dict) + assert isinstance(report.citation_format_insights, dict) + assert isinstance(report.engine_preferences, dict) + assert isinstance(report.recommendations, list) + + def test_empty_input_handling(self, engine): + report = engine.analyze([], brand_id="brand-1", query="test") + assert report.total_results == 0 + assert report.patterns == [] + assert report.content_structure_insights == {} + assert report.authority_signal_insights == {} + assert report.citation_format_insights == {} + assert report.engine_preferences == {} + assert report.recommendations == [] + + def test_single_engine_result(self, engine): + results = [ + _make_result( + engine_type=EngineType.CHATGPT, + raw_response="Q: What? 
A: This.", + has_brand_citation=True, + brand_context="BrandX is great", + ) + ] + report = engine.analyze(results, brand_id="b1", query="q") + assert report.total_results == 1 + assert "chatgpt" in report.engine_preferences + assert len(report.patterns) > 0 + + def test_multi_engine_aggregation(self, engine): + results = [ + _make_result( + engine_type=EngineType.CHATGPT, + raw_response="Q: What? A: This. BrandX is the best.", + has_brand_citation=True, + brand_context="BrandX is the best", + ), + _make_result( + engine_type=EngineType.PERPLEXITY, + raw_response="1. First point\n2. Second point\nBrandX offers great value.", + has_brand_citation=True, + brand_context="BrandX offers great value", + ), + _make_result( + engine_type=EngineType.KIMI, + raw_response="Plain response. No citations here.", + has_brand_citation=False, + ), + ] + report = engine.analyze(results, brand_id="b1", query="q") + assert report.total_results == 3 + assert len(report.engine_preferences) == 3 + chatgpt_rate = report.engine_preferences["chatgpt"]["citation_rate"] + kimi_rate = report.engine_preferences["kimi"]["citation_rate"] + assert chatgpt_rate == 1.0 + assert kimi_rate == 0.0 + + def test_pattern_report_generation(self, engine): + results = [ + _make_result( + engine_type=EngineType.CHATGPT, + raw_response="Q: What is GEO? 
A: GEO optimization.\nAccording to 2024 research, 75% agree.", + has_brand_citation=True, + brand_context="BrandX provides GEO", + citations=[ + CitationInfo( + source_url="https://brandx.com", + source_title="BrandX", + citation_context="BrandX provides GEO", + confidence=0.95, + position=0, + ) + ], + ), + ] + report = engine.analyze(results, brand_id="b1", query="geo optimization") + assert len(report.patterns) > 0 + pattern_types = {p.pattern_type for p in report.patterns} + assert "content_structure" in pattern_types + assert "authority_signal" in pattern_types + assert "citation_format" in pattern_types + + def test_recommendations_generated(self, engine): + results = [ + _make_result( + engine_type=EngineType.CHATGPT, + raw_response="Q: What? A: Something.", + has_brand_citation=False, + ), + ] + report = engine.analyze(results, brand_id="b1", query="q") + assert isinstance(report.recommendations, list) + + def test_insights_populated(self, engine): + results = [ + _make_result( + engine_type=EngineType.CHATGPT, + raw_response="Q: What? A: This.\n1. 
Item\nISO 9001 certified.", + has_brand_citation=True, + brand_context="BrandX is here", + ), + ] + report = engine.analyze(results, brand_id="b1", query="q") + assert len(report.content_structure_insights) > 0 + assert len(report.authority_signal_insights) > 0 + assert len(report.citation_format_insights) > 0 diff --git a/backend/tests/test_services/test_detection_scheduler.py b/backend/tests/test_services/test_detection_scheduler.py new file mode 100644 index 0000000..6638cac --- /dev/null +++ b/backend/tests/test_services/test_detection_scheduler.py @@ -0,0 +1,376 @@ +import uuid +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import pytest_asyncio +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.pool import StaticPool + +from app.database import Base +from app.models.brand import Brand +from app.models.user import User +from app.services.auth import hash_password + + +@pytest_asyncio.fixture +async def async_engine(): + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def async_session(async_engine): + session_maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + async with session_maker() as session: + yield session + + +@pytest_asyncio.fixture +async def test_user(async_session): + user = User( + id=uuid.uuid4(), + email="test@example.com", + password_hash=hash_password("Test@123456"), + name="Test User", + plan="free", + max_queries=5, + is_active=True, + email_verified=True, + ) + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async 
def test_brand(async_session, test_user): + brand = Brand( + id=uuid.uuid4(), + user_id=test_user.id, + name="Test Brand", + aliases=["TestBrand", "TB"], + website="https://testbrand.com", + industry="technology", + platforms=["wenxin", "kimi"], + frequency="weekly", + status="active", + ) + async_session.add(brand) + await async_session.commit() + await async_session.refresh(brand) + return brand + + +class TestDetectionTaskModel: + @pytest.mark.asyncio + async def test_create_detection_task(self, async_session, test_brand, test_user): + from app.models.detection_task import DetectionTask + + task = DetectionTask( + brand_id=test_brand.id, + user_id=test_user.id, + name="每日品牌检测", + frequency="daily", + engines=["chatgpt", "perplexity"], + queries=["最佳保险品牌", "保险推荐"], + competitor_names=["竞品A", "竞品B"], + ) + async_session.add(task) + await async_session.commit() + await async_session.refresh(task) + + assert task.id is not None + assert task.brand_id == test_brand.id + assert task.user_id == test_user.id + assert task.name == "每日品牌检测" + assert task.frequency == "daily" + assert task.engines == ["chatgpt", "perplexity"] + assert task.queries == ["最佳保险品牌", "保险推荐"] + assert task.competitor_names == ["竞品A", "竞品B"] + assert task.is_active is True + assert task.last_run_at is None + assert task.next_run_at is not None + assert task.created_at is not None + assert task.updated_at is not None + + @pytest.mark.asyncio + async def test_detection_task_default_values(self, async_session, test_brand, test_user): + from app.models.detection_task import DetectionTask + + task = DetectionTask( + brand_id=test_brand.id, + user_id=test_user.id, + name="简单检测", + frequency="weekly", + engines=["chatgpt"], + queries=["测试查询"], + ) + async_session.add(task) + await async_session.commit() + await async_session.refresh(task) + + assert task.is_active is True + assert task.competitor_names is None + assert task.last_run_at is None + + +class TestDetectionSchedulerService: + 
@pytest.mark.asyncio
async def test_create_task(self, async_session, test_brand, test_user):
    """create_task returns a task with an id and the requested name, frequency and owners."""
    from app.services.detection_scheduler import DetectionSchedulerService

    scheduler = DetectionSchedulerService()
    payload = dict(
        name="每日品牌检测",
        frequency="daily",
        engines=["chatgpt", "perplexity"],
        queries=["最佳保险品牌", "保险推荐"],
        competitor_names=["竞品A"],
    )
    created = await scheduler.create_task(
        payload, test_brand.id, test_user.id, async_session
    )

    assert created.id is not None
    assert created.name == "每日品牌检测"
    assert created.frequency == "daily"
    assert created.brand_id == test_brand.id
    assert created.user_id == test_user.id


@pytest.mark.asyncio
async def test_update_task(self, async_session, test_brand, test_user):
    """update_task applies new name, frequency and engines to a stored task."""
    from app.models.detection_task import DetectionTask
    from app.services.detection_scheduler import DetectionSchedulerService

    # Seed one task directly through the session so the update has a target.
    stored = DetectionTask(
        brand_id=test_brand.id,
        user_id=test_user.id,
        name="旧名称",
        frequency="weekly",
        engines=["chatgpt"],
        queries=["查询1"],
    )
    async_session.add(stored)
    await async_session.commit()
    await async_session.refresh(stored)

    scheduler = DetectionSchedulerService()
    changes = dict(
        name="新名称",
        frequency="daily",
        engines=["chatgpt", "perplexity"],
    )
    updated = await scheduler.update_task(
        stored.id, changes, test_user.id, async_session
    )

    assert updated.name == "新名称"
    assert updated.frequency == "daily"
    assert updated.engines == ["chatgpt", "perplexity"]


@pytest.mark.asyncio
async def test_delete_task(self, async_session, test_brand, test_user):
    """delete_task returns True and the row can no longer be selected."""
    from app.models.detection_task import DetectionTask
    from app.services.detection_scheduler import DetectionSchedulerService

    doomed = DetectionTask(
        brand_id=test_brand.id,
        user_id=test_user.id,
        name="待删除",
        frequency="weekly",
        engines=["chatgpt"],
        queries=["查询1"],
    )
    async_session.add(doomed)
    await async_session.commit()
    await async_session.refresh(doomed)

    scheduler = DetectionSchedulerService()
    was_deleted = await scheduler.delete_task(doomed.id, test_user.id, async_session)
    assert was_deleted is True

    # Look the row up again by primary key; it must be gone.
    lookup = await async_session.execute(
        select(DetectionTask).where(DetectionTask.id == doomed.id)
    )
    assert lookup.scalar_one_or_none() is None


@pytest.mark.asyncio
async def test_get_tasks(self, async_session, test_brand, test_user):
    """get_tasks returns all three tasks stored for the brand and user."""
    from app.models.detection_task import DetectionTask
    from app.services.detection_scheduler import DetectionSchedulerService

    async_session.add_all(
        [
            DetectionTask(
                brand_id=test_brand.id,
                user_id=test_user.id,
                name=f"任务{idx}",
                frequency="daily",
                engines=["chatgpt"],
                queries=[f"查询{idx}"],
            )
            for idx in range(3)
        ]
    )
    await async_session.commit()

    scheduler = DetectionSchedulerService()
    found = await scheduler.get_tasks(test_brand.id, test_user.id, async_session)
    assert len(found) == 3


@pytest.mark.asyncio
async def test_trigger_task(self, async_session, test_brand, test_user):
    """trigger_task calls execute_task once and hands back its result."""
    from app.models.detection_task import DetectionTask
    from app.services.detection_scheduler import DetectionSchedulerService

    manual = DetectionTask(
        brand_id=test_brand.id,
        user_id=test_user.id,
        name="手动触发测试",
        frequency="daily",
        engines=["chatgpt"],
        queries=["测试查询"],
    )
    async_session.add(manual)
    await async_session.commit()
    await async_session.refresh(manual)

    scheduler = DetectionSchedulerService()
    # execute_task is replaced so no real engine query is issued.
    with patch.object(
        scheduler,
        "execute_task",
        new_callable=AsyncMock,
        return_value={"status": "success", "results": []},
    ) as mock_execute:
        outcome = await scheduler.trigger_task(manual.id, test_user.id, async_session)

    assert outcome["status"] == "success"
    mock_execute.assert_called_once()


@pytest.mark.asyncio
async def test_frequency_validation_hourly(self, async_session, test_brand, test_user):
    """create_task accepts frequency "hourly" and sets next_run_at."""
    from app.services.detection_scheduler import DetectionSchedulerService

    scheduler = DetectionSchedulerService()
    payload = dict(
        name="每小时检测",
        frequency="hourly",
        engines=["chatgpt"],
        queries=["查询"],
    )
    created = await scheduler.create_task(
        payload, test_brand.id, test_user.id, async_session
    )
    assert created.frequency == "hourly"
    assert created.next_run_at is not None


@pytest.mark.asyncio
async def test_frequency_validation_daily(self, async_session, test_brand, test_user):
    """create_task accepts frequency "daily" and sets next_run_at."""
    from app.services.detection_scheduler import DetectionSchedulerService

    scheduler = DetectionSchedulerService()
    payload = dict(
        name="每日检测",
        frequency="daily",
        engines=["chatgpt"],
        queries=["查询"],
    )
    created = await scheduler.create_task(
        payload, test_brand.id, test_user.id, async_session
    )
    assert created.frequency == "daily"
    assert created.next_run_at is not None


@pytest.mark.asyncio
async def test_frequency_validation_weekly(self, async_session, test_brand, test_user):
    """create_task accepts frequency "weekly" and sets next_run_at."""
    from app.services.detection_scheduler import DetectionSchedulerService

    scheduler = DetectionSchedulerService()
    payload = dict(
        name="每周检测",
        frequency="weekly",
        engines=["chatgpt"],
        queries=["查询"],
    )
    created = await scheduler.create_task(
        payload, test_brand.id, test_user.id, async_session
    )
    assert created.frequency == "weekly"
    assert created.next_run_at is not None


@pytest.mark.asyncio
async def test_frequency_validation_invalid(self, async_session, test_brand, test_user):
    """create_task raises ValueError mentioning "frequency" for "monthly"."""
    from app.services.detection_scheduler import DetectionSchedulerService

    scheduler = DetectionSchedulerService()
    payload = dict(
        name="无效频率",
        frequency="monthly",
        engines=["chatgpt"],
        queries=["查询"],
    )
    with pytest.raises(ValueError, match="frequency"):
        await scheduler.create_task(payload, test_brand.id, test_user.id, async_session)


@pytest.mark.asyncio
async def test_execute_task_flow(self, async_session, test_brand, test_user):
    """execute_task reports success and stamps last_run_at / next_run_at on the task."""
    from app.models.detection_task import DetectionTask
    from app.services.detection_scheduler import DetectionSchedulerService

    runnable = DetectionTask(
        brand_id=test_brand.id,
        user_id=test_user.id,
        name="执行流程测试",
        frequency="daily",
        engines=["chatgpt"],
        queries=["测试查询"],
        competitor_names=["竞品A"],
    )
    async_session.add(runnable)
    await async_session.commit()
    await async_session.refresh(runnable)

    scheduler = DetectionSchedulerService()

    # One canned engine result stands in for the real batch query output.
    canned_result = MagicMock(
        engine_type=MagicMock(value="chatgpt"),
        has_brand_citation=True,
        has_competitor_citation=False,
    )
    mock_batch_result = [canned_result]

    # Both collaborators are stubbed so only execute_task's own flow runs.
    with (
        patch.object(
            scheduler,
            "_run_batch_query",
            new_callable=AsyncMock,
            return_value=mock_batch_result,
        ),
        patch.object(
            scheduler,
            "_generate_alerts_if_needed",
            new_callable=AsyncMock,
            return_value=[],
        ),
    ):
        outcome = await scheduler.execute_task(runnable, async_session)

    assert outcome["status"] == "success"
    assert "results" in outcome

    await async_session.refresh(runnable)
    assert runnable.last_run_at is not None
    assert runnable.next_run_at is not None


@pytest.mark.asyncio
async def test_delete_task_not_found(self, async_session, test_user):
    """delete_task returns False for a task id that was never stored."""
    from app.services.detection_scheduler import DetectionSchedulerService

    scheduler = DetectionSchedulerService()
    unknown_id = uuid.uuid4()
    was_deleted = await scheduler.delete_task(unknown_id, test_user.id, async_session)
    assert was_deleted is False


@pytest.mark.asyncio
async def test_update_task_not_found(self, async_session, test_user):
    """update_task raises TaskNotFoundError for a task id that was never stored."""
    from app.services.detection_scheduler import DetectionSchedulerService, TaskNotFoundError

    scheduler = DetectionSchedulerService()
    with pytest.raises(TaskNotFoundError):
        await scheduler.update_task(
            uuid.uuid4(), {"name": "新名称"}, test_user.id, async_session
        )


@pytest.mark.asyncio
async def test_get_tasks_empty(self, async_session, test_brand, test_user):
    """get_tasks returns an empty list when the brand has no tasks."""
    from app.services.detection_scheduler import DetectionSchedulerService

    scheduler = DetectionSchedulerService()
    found = await scheduler.get_tasks(test_brand.id, test_user.id, async_session)
    assert found == []